Reverse Engineering Retrieval Augmented Generation
With the rise of Large Language Models (LLMs), particularly since the release of ChatGPT, a new search technique has gained immense popularity. Instead of relying on traditional keyword-based searches, users can now ask questions as if they were speaking to a real person, receiving grammatically coherent answers that synthesize and summarize the top search results.
Here are a few scenarios where this approach can save time:
- Technical documentation: Quickly determine if a feature or functionality is supported. “Can we integrate OAuth2 authentication with our API using this framework?”
- Legal documents: Check whether a specific action complies with regulations. “Is it legal to use customer data for marketing under GDPR?”
- Customer support: Provide concise, product-related answers. “How do I reset my router password?”
This technique is called Retrieval-Augmented Generation (RAG). In this article, I’ll break down the RAG process step by step so you can implement it using the tools of your choice. We’ll start with the fully abstracted end result and then peel back one layer at a time until the whole process becomes clear.
Data Extraction With a Large Language Model
I like to think of an LLM as a virtual expert that you can interact with via a REST API. You send it a question, and it responds with an answer — just like sending a Slack message to a knowledgeable colleague.
LLMs are relatively easy to deploy thanks to Ollama. With Ollama, you can run an “LLM service” on your own machine, either natively or in a Docker container, pull language models like Llama 3.2, and expose them through a REST API (on port 11434 by default). You can get started by visiting Ollama’s download page.
```bash
curl http://localhost:11434/api/generate -d '{
  "model": "llama3.2",
  "prompt": "Why is the sky blue?",
  "stream": false
}'
```
```json
{
  "model": "llama3.2",
  "created_at": "2023-08-04T19:22:45.499127Z",
  "response": "The sky is blue because it is the color of the sky.",
  "done": true,
  "context": [1, 2, 3],
  "total_duration": 5043500667,
  "load_duration": 5025959,
  "prompt_eval_count": 26,
  "prompt_eval_duration": 325953000,
  "eval_count": 290,
  "eval_duration": 4709213000
}
```
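The same request can be made from Python, which the remaining sketches in this article use. The following is a minimal sketch using only the standard library. It assumes Ollama is listening on its default address, `http://localhost:11434`; the helper names `build_generate_payload` and `generate` are our own and are not part of any Ollama SDK.

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434"  # assumption: Ollama's default local address


def build_generate_payload(prompt: str, model: str = "llama3.2") -> dict:
    # "stream": False asks Ollama for one JSON object instead of a stream of chunks
    return {"model": model, "prompt": prompt, "stream": False}


def generate(prompt: str, model: str = "llama3.2", base_url: str = OLLAMA_URL) -> str:
    """Send a prompt to /api/generate and return only the generated text."""
    request = urllib.request.Request(
        f"{base_url}/api/generate",
        data=json.dumps(build_generate_payload(prompt, model)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        body = json.load(response)
    # The other fields (durations, token counts, context) are metadata
    return body["response"]
```

With a model pulled, `generate("Why is the sky blue?")` returns the text found in the `response` field of a JSON reply like the one above.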
Here’s an example of a message exchange between a developer at a tech company and an LLM, demonstrating how RAG works in practice:
Prompt:
Can you give me an overview of our authentication system?
Response:
Our authentication system is built on OAuth 2.0 and supports multi-factor authentication (MFA) for enhanced security. It integrates with our internal identity provider and external providers like Google and Microsoft. The system includes role-based access control (RBAC), allowing fine-grained permissions for various services. We use a token-based approach to authenticate users, with tokens expiring after a set period to reduce risk. For detailed implementation, refer to the ‘Authentication’ section in our developer documentation.
If you look closely at the exchange above, you’ll notice that the answer contains company-specific details that a general-purpose model could not know on its own. This context varies from one use case to another, so it isn’t built into the LLM; we have to supply it. In other words, if the context of the conversation were not hidden from us, the full message sent to the model would probably look like this:
Prompt:
Context:
- Our authentication system uses OAuth 2.0 with JSON Web Tokens (JWT) for secure access. It integrates with external identity providers (Google, Microsoft, Okta) for Single Sign-On (SSO) and uses LDAP internally for user storage and validation.
- Multi-Factor Authentication (MFA) is required for sensitive roles, offering options like SMS, authenticator apps, and hardware tokens, configurable by user group.
- Role-Based Access Control (RBAC) provides granular permissions via roles like “Admin” or “Viewer,” managed through the Admin Console.
- JWTs include claims for roles, permissions, and expiration time. Tokens expire after a set period and require re-authentication to reduce unauthorized access.
- OAuth 2.0 authentication flow redirects users to IdP login pages for credential validation. IdPs issue tokens validated by the system before granting access.
- Detailed logs capture login attempts with timestamp, IP, and method. Anomaly detection flags suspicious activity and alerts administrators.
- A RESTful API supports user management, role assignment, and token validation with clear documentation and examples for integration.
Question:
Can you give me an overview of our authentication system?
The server can add the context section to a user’s message programmatically before it’s sent to the LLM. However, it’s important to ensure that this process is secure. Users might craft their queries to manipulate or override the provided context, potentially gaining access to unauthorized information. This kind of attack is commonly called prompt injection; when it is used to bypass the model’s intended restrictions, it is often referred to as a “jailbreak.”
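A server-side helper that builds the message above could look like the following minimal sketch. The function name `build_rag_prompt` and the instruction text are our own assumptions. Encoding the question with `json.dumps` keeps it on a single quoted line, so a user cannot start a fake `Context:` section with a line break. This is one small mitigation, not a complete defense against prompt injection.

```python
import json

# Hypothetical instruction text placed before the context
INSTRUCTIONS = (
    "Answer the question using only the context below. "
    "If the context does not contain the answer, say that you don't know."
)


def build_rag_prompt(context_chunks: list[str], question: str) -> str:
    """Assemble the 'Context: / Question:' message shown above (hypothetical helper)."""
    context = "\n".join(f"- {chunk.strip()}" for chunk in context_chunks)
    # json.dumps quotes the question and escapes newlines, so user input stays on one
    # line and cannot open a section of its own. This reduces, but does not eliminate,
    # the risk of prompt injection.
    quoted_question = json.dumps(question.strip(), ensure_ascii=False)
    return f"{INSTRUCTIONS}\n\nContext:\n{context}\n\nQuestion:\n{quoted_question}"
```

The retrieved chunks are trusted server data, while the question is treated strictly as user-supplied data.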
The question then becomes: how can we retrieve the relevant parts of the documentation? Hypothetically, this could be achieved with a database designed to store frequently asked questions alongside their answers. Given an input question, the system could query this database to retrieve the top K similar questions, which would then be mapped to their corresponding answers.
With this in mind, we can design an LLM-based program to automatically extract a set of questions and answers from documents as they are submitted to the company’s archive. Here’s an example of how this extraction process might work:
Prompt:
Given the following document:
{{document}}
Generate questions and answers that follow this JSON Schema:
{{jsonSchema}}
Ensure that the questions reference answers using unique IDs, and answers are grouped separately.
Response:
```json
{
  "questions": [
    {
      "id": "q1",
      "question": "What authentication framework does our system use?",
      "answer_id": "a1"
    },
    {
      "id": "q2",
      "question": "How is our authentication system architected?",
      "answer_id": "a1"
    },
    {
      "id": "q3",
      "question": "What security features does our authentication system provide?",
      "answer_id": "a2"
    },
    {
      "id": "q4",
      "question": "How does multi-factor authentication work in our system?",
      "answer_id": "a2"
    }
  ],
  "answers": [
    {
      "id": "a1",
      "answer": "Our authentication system is built on the OAuth 2.0 framework and implements JSON Web Tokens (JWT) for secure access management. It supports integration with external identity providers (IdPs) such as Google, Microsoft, and Okta, enabling Single Sign-On (SSO). Internally, it uses an LDAP directory for user information storage and validation, with support for custom IdPs within the enterprise network."
    },
    {
      "id": "a2",
      "answer": "The system mandates Multi-Factor Authentication (MFA) for sensitive roles and offers options like SMS verification, authenticator apps, and hardware tokens, configurable by user group."
    }
  ]
}
```
Since an LLM’s response can take any shape or form, it can be critical to specify the output format, for example with a JSON Schema. By defining a clear structure, we make the response predictable and easier to process programmatically.
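To make this concrete, here is one possible JSON Schema for the output above, written as a Python dictionary, together with a parser that turns the model’s reply into a question-to-answer mapping. The schema is our assumption of what `{{jsonSchema}}` might contain, and `parse_qa_pairs` is a hypothetical helper. Recent Ollama versions can also enforce a schema like this when it is passed in the `format` field of an `/api/generate` request. A schema cannot express that every `answer_id` points at an existing answer, so the parser checks that separately.

```python
import json

# Hypothetical schema matching the example response: questions point at answers by ID
QA_SCHEMA = {
    "type": "object",
    "properties": {
        "questions": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "question": {"type": "string"},
                    "answer_id": {"type": "string"},
                },
                "required": ["id", "question", "answer_id"],
            },
        },
        "answers": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {"id": {"type": "string"}, "answer": {"type": "string"}},
                "required": ["id", "answer"],
            },
        },
    },
    "required": ["questions", "answers"],
}


def parse_qa_pairs(raw: str) -> dict[str, str]:
    """Map each extracted question to its answer text."""
    data = json.loads(raw)
    answers = {a["id"]: a["answer"] for a in data["answers"]}
    pairs = {}
    for q in data["questions"]:
        # JSON Schema alone cannot guarantee that answer_id refers to a real answer
        if q["answer_id"] not in answers:
            raise ValueError(f"{q['id']} references unknown answer {q['answer_id']}")
        pairs[q["question"]] = answers[q["answer_id"]]
    return pairs
```

Applied to the response above, `parse_qa_pairs` would map all four questions to the two answers, with q1 and q2 sharing answer a1.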
Data Processing With an Embedding Model
Everything looks good in theory, but there’s an important assumption we need to address: the ability of our database to store and query frequently asked questions. While we could compare a user’s input to every question stored in the database using an LLM, this approach would be computationally expensive, quickly becoming impractical. With potentially millions of questions in the database, performing such comparisons would require significant processing power. To handle this at scale, we need a much faster, more efficient querying method.
There’s an important criterion to consider before picking such a method: two questions can appear different but have nearly identical semantic meanings. For example:
- User input: How can I cancel my subscription?
- Stored question: How do I stop my membership?
As humans, we can easily recognize that these two questions likely have the same answer, but a keyword-based search can’t make that connection, because the two sentences share only filler words like “how” and “my.” To help the machine understand this, we need to transform the questions into a numeric representation, such as a vector of numbers, that can be compared with a mathematical formula. This can be done by using an embedding model, which is another type of AI model we’ll leverage alongside the LLM.
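A quick way to see the problem is to compare the two questions the way a naive keyword search would, as sets of lowercase words. This is only an illustration; the `words` helper is our own.

```python
import re


def words(text: str) -> set[str]:
    # Lowercase and keep only alphabetic tokens, as a naive keyword search might
    return set(re.findall(r"[a-z]+", text.lower()))


user_input = words("How can I cancel my subscription?")
stored = words("How do I stop my membership?")
shared = user_input & stored
print(sorted(shared))  # ['how', 'i', 'my']
```

None of the words that carry the meaning (“cancel” and “stop,” “subscription” and “membership”) overlap, so keyword matching has nothing useful to work with.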
An embedding model converts text into a vector of numbers, where questions with similar meanings produce similar vectors. Just like an LLM, an embedding model can be pulled and hosted over a REST API using Ollama. A popular model for this purpose is nomic-embed-text, but other embedding models work as well, as long as the same model is used to embed both the stored questions and the user’s input.
```bash
curl http://localhost:11434/api/embed -d '{
  "model": "nomic-embed-text",
  "input": "How can I cancel my subscription?"
}'
```
```json
{
  "model": "nomic-embed-text",
  "embeddings": [[
    0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
    0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348,
    ... (768 values in total)
  ]],
  "total_duration": 14143917,
  "load_duration": 1019500,
  "prompt_eval_count": 8
}
```
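To measure the similarity between two questions, we can embed them and apply the Cosine Similarity formula to the resulting embeddings $A$ and $B$, each with $n$ components:

$$\text{similarity}(A, B) = \cos(\theta) = \frac{A \cdot B}{\|A\| \, \|B\|} = \frac{\sum_{i=1}^{n} A_i B_i}{\sqrt{\sum_{i=1}^{n} A_i^2} \, \sqrt{\sum_{i=1}^{n} B_i^2}}$$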
This formula essentially compares the angle between the vectors. A similarity of 1 means the vectors point in the same direction, 0 means they are perpendicular (no similarity), and -1 means they are pointing in opposite directions.
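Translated directly into Python, the Cosine Similarity formula (the dot product divided by the product of the vector lengths) looks like the sketch below. It uses plain lists to stay dependency-free; in practice, a library such as NumPy would do the same arithmetic much faster. The example vectors are chosen so that the three cases described above come out exactly.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Dot product divided by the product of the vector lengths."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


print(cosine_similarity([3.0, 4.0], [6.0, 8.0]))    # 1.0 (same direction, different length)
print(cosine_similarity([1.0, 0.0], [0.0, 3.0]))    # 0.0 (perpendicular)
print(cosine_similarity([3.0, 4.0], [-3.0, -4.0]))  # -1.0 (opposite directions)
```

The first example shows why the measure suits embeddings: only the direction matters, not the magnitude.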
To clarify things further, I’ve put together a program that demonstrates how to check if two questions are similar using the concepts we just discussed:
```python
def are_questions_similar(q1: str, q2: str, threshold: float = 0.8) -> bool:
    # embed() calls the embedding model; cosine_similarity() applies the formula
    e1 = embed(q1)
    e2 = embed(q2)
    return cosine_similarity(e1, e2) > threshold
```
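The function above relies on two helpers. Here is a minimal sketch of both, assuming a local Ollama server with `nomic-embed-text` pulled. The `embed_fn` parameter is our addition, so that the logic can be tried with fake vectors when no model is running. The 0.8 threshold is only a starting point: suitable values depend on the embedding model and are best calibrated on a few question pairs you know to be similar or different.

```python
import json
import math
import urllib.request
from typing import Callable

Embedder = Callable[[str], list[float]]


def embed(text: str, model: str = "nomic-embed-text",
          base_url: str = "http://localhost:11434") -> list[float]:
    """Call Ollama's /api/embed endpoint and return the single embedding it produces."""
    request = urllib.request.Request(
        f"{base_url}/api/embed",
        data=json.dumps({"model": model, "input": text}).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)["embeddings"][0]


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def are_questions_similar(q1: str, q2: str, threshold: float = 0.8,
                          embed_fn: Embedder = embed) -> bool:
    # embed_fn is injectable so the logic can be tested without a running model
    return cosine_similarity(embed_fn(q1), embed_fn(q2)) > threshold
```

With Ollama running, `are_questions_similar("How can I cancel my subscription?", "How do I stop my membership?")` uses real embeddings; in tests, any function that maps a string to a list of floats can stand in for `embed`.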
Querying With Hierarchical Navigable Small World
While comparing embeddings with Cosine Similarity is far cheaper than asking an LLM to compare every pair of questions, it’s still not fast enough on its own: a naive search has to compute the similarity against every stored record, so its cost grows linearly with the size of the database. To speed this up further, we need a method to index vectors based on their similarity. This is where the Hierarchical Navigable Small World (HNSW) algorithm comes into play.
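For comparison, this is what the naive approach looks like: an exact top-k search that scores every stored vector. It’s a sketch with hypothetical names, but it makes the cost visible. With 768-dimensional embeddings and a million stored questions, a single query needs roughly 768 million multiply-add operations before any ranking happens.

```python
import heapq
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def brute_force_top_k(query: list[float], stored: dict[str, list[float]], top_k: int = 5):
    """Exact search: score every stored question and keep the k most similar.

    The generator touches every vector, so the cost grows linearly with len(stored).
    """
    scored = ((cosine_similarity(query, vector), question) for question, vector in stored.items())
    return heapq.nlargest(top_k, scored)  # list of (similarity, question), best first
```

This exact search is a useful baseline for checking how often an approximate index returns the same results.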
HNSW can seem complex at first, but here’s a simplified explanation of its core concept: it builds a graph by linking each stored vector to a handful of its most similar neighbors, as measured by a specific mathematical metric (Cosine Similarity in our case). To perform a search, the algorithm starts at an entry node and compares the input to that node’s neighbors. It moves to whichever neighbor is most similar to the input and repeats the process, stopping when none of the current node’s neighbors is more similar than the node itself. That node is returned as the best match found.
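This core move can be sketched as a greedy walk over a single-layer graph. It is a simplification: real HNSW implementations search several layers and keep a list of candidates (sized by a parameter usually called ef) instead of a single current node, which makes them far less likely to stop at a node that is only locally the best. The graph in the usage example is hand-built for illustration.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def greedy_search(query, vectors, neighbors, entry):
    """Walk the graph from `entry`, always moving to the neighbor most similar to `query`.

    Stops when no neighbor beats the current node and returns (node, similarity).
    """
    current = entry
    current_sim = cosine_similarity(query, vectors[current])
    while True:
        best, best_sim = current, current_sim
        for node in neighbors[current]:
            sim = cosine_similarity(query, vectors[node])
            if sim > best_sim:
                best, best_sim = node, sim
        if best == current:  # local optimum: no neighbor is more similar
            return current, current_sim
        current, current_sim = best, best_sim


# A tiny hand-built graph: each node is linked only to its neighbors in the chain a-b-c-d
vectors = {"a": [1.0, 0.0], "b": [1.0, 1.0], "c": [0.0, 1.0], "d": [-1.0, 1.0]}
neighbors = {"a": ["b"], "b": ["a", "c"], "c": ["b", "d"], "d": ["c"]}
node, similarity = greedy_search([-1.0, 1.0], vectors, neighbors, entry="a")
print(node)  # d (the walk goes a -> b -> c -> d)
```

If the link between b and c were missing, the same walk would stop at b even though d is a better match, which illustrates why the search is approximate.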
This explanation focuses on the main search mechanism while omitting some of the more advanced details, such as how the algorithm organizes the graph into multiple layers: sparse upper layers allow long jumps across the data, and a dense bottom layer handles the fine-grained search. For a deeper dive into how the algorithm operates, I recommend this detailed article by Pinecone. In addition, you can check out an excellent visualization tool called Feder to better understand how HNSW works.
Any database that can store vectors and supports HNSW-based indexing can be used to build a RAG application, such as:
- PostgreSQL with the pgvector extension
- Elasticsearch
- Solr
With that in mind, we can write a program that takes advantage of the database’s internal mechanisms to insert and query questions efficiently. We just need to ensure that questions are embedded before sending them to the database:
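```python
# Database is a placeholder for your vector store client; embed() calls the embedding model
def insert_question(db: Database, q: str) -> None:
    e = embed(q)
    db.questions.insert(e, q)

def query_questions(db: Database, q: str, top_k: int = 5):
    e = embed(q)  # use the same embedding model as when inserting
    return db.questions.query(e, top_k)
```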
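As one concrete example of such a database, here is a sketch of the same two functions on top of PostgreSQL with pgvector. It assumes pgvector 0.5.0 or later (the first release with HNSW indexes), a psycopg 3 connection object, and 768-dimensional embeddings from nomic-embed-text. The table and index names are hypothetical, and the embedding function is passed in explicitly so the sketch doesn’t depend on a particular model client.

```python
from typing import Callable, Sequence

Embed = Callable[[str], Sequence[float]]

# Hypothetical schema; 768 matches nomic-embed-text and must change with the model.
# vector_cosine_ops tells the HNSW index to use cosine distance.
SCHEMA_SQL = """
CREATE EXTENSION IF NOT EXISTS vector;
CREATE TABLE IF NOT EXISTS questions (
    id bigserial PRIMARY KEY,
    question text NOT NULL,
    embedding vector(768) NOT NULL
);
CREATE INDEX IF NOT EXISTS questions_embedding_hnsw
    ON questions USING hnsw (embedding vector_cosine_ops);
"""


def to_vector_literal(e: Sequence[float]) -> str:
    # pgvector accepts vectors written in text form, e.g. '[0.1,0.2,0.3]'
    return "[" + ",".join(repr(float(x)) for x in e) + "]"


def insert_question(conn, q: str, embed: Embed) -> None:
    conn.execute(
        "INSERT INTO questions (question, embedding) VALUES (%s, %s::vector)",
        (q, to_vector_literal(embed(q))),
    )


def query_questions(conn, q: str, embed: Embed, top_k: int = 5):
    v = to_vector_literal(embed(q))
    # <=> is pgvector's cosine distance (1 - cosine similarity);
    # ORDER BY distance with a LIMIT is the query shape the HNSW index can serve
    return conn.execute(
        "SELECT question, 1 - (embedding <=> %s::vector) AS similarity "
        "FROM questions ORDER BY embedding <=> %s::vector LIMIT %s",
        (v, v, top_k),
    ).fetchall()
```

Run the statements in `SCHEMA_SQL` once before inserting questions. Answers can live in another column of the same table, so a single query returns both.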
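Since HNSW explores only a small part of the graph rather than comparing the input against every stored vector, it performs an approximate search: it is very fast, but it may occasionally miss the true closest match. Because of this, it’s important to experiment with the algorithm’s parameters, such as the number of links per node (m) and the size of the candidate lists used while building and searching the index (ef_construction and ef_search in pgvector), to achieve an optimal balance between performance and accuracy.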
Recap
Now that we understand the components of a RAG application, let’s put all the pieces together and outline the data pipeline from creation to user delivery:
- When someone writes a new document, they upload it to the RAG system.
- A set of potential frequently asked questions (FAQs) and their corresponding answers are extracted from the document using an LLM.
- The extracted questions are converted into vectors using an embedding model.
- These vectors are indexed with HNSW and stored in a vector database alongside the answers.
- When a user submits a question, it is transformed into a vector using the same embedding model.
- The input vector is compared against stored vectors using HNSW, returning a set of similar questions.
- The database is queried for the answers corresponding to these similar questions.
- The answers are summarized with an LLM.
- The summary is returned to the user.
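The whole pipeline fits in a short, deliberately simplified sketch. It keeps everything in memory and uses exact search instead of HNSW, and the `llm` and `embed` arguments are placeholders for calls to real models (such as the Ollama endpoints shown earlier). All names here are our own.

```python
import json
import math
from dataclasses import dataclass, field
from typing import Callable

Embed = Callable[[str], list[float]]
Generate = Callable[[str], str]


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


@dataclass
class FaqIndex:
    """In-memory stand-in for the vector database (exact search instead of HNSW)."""
    rows: list[tuple[list[float], str]] = field(default_factory=list)  # (question vector, answer)

    def add(self, vector: list[float], answer: str) -> None:
        self.rows.append((vector, answer))

    def search(self, vector: list[float], top_k: int) -> list[str]:
        ranked = sorted(self.rows, key=lambda row: cosine_similarity(vector, row[0]), reverse=True)
        return [answer for _, answer in ranked[:top_k]]


EXTRACT_PROMPT = (
    "Given the following document:\n{document}\n"
    "Generate questions and answers as JSON with 'questions' and 'answers' lists."
)


def ingest(document: str, index: FaqIndex, llm: Generate, embed: Embed) -> None:
    # Extract Q&A pairs with the LLM, embed each question, store it next to its answer
    data = json.loads(llm(EXTRACT_PROMPT.format(document=document)))
    answers = {a["id"]: a["answer"] for a in data["answers"]}
    for q in data["questions"]:
        index.add(embed(q["question"]), answers[q["answer_id"]])


def answer(question: str, index: FaqIndex, llm: Generate, embed: Embed, top_k: int = 3) -> str:
    # Embed the question with the same model, fetch matching answers, summarize them
    candidates = list(dict.fromkeys(index.search(embed(question), top_k)))  # de-duplicate
    context = "\n".join(f"- {c}" for c in candidates)
    return llm(f"Summarize the context to answer the question.\nContext:\n{context}\nQuestion:\n{question}")
```

The de-duplication step matters because several extracted questions often share one answer, as q1 and q2 do in the extraction example.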
It’s important to note that while this is one way to implement RAG, there are many variations. The key is understanding the core principles and adapting them to suit your specific use case.