Optimizing RAG: Should you Fine-Tune?
Let’s dive into fine-tuning large language models (LLMs) for Retrieval-Augmented Generation (RAG).
What is Fine-Tuning?
So, you’ve got this massive, pre-trained language model like the recent Llama 3.1 405B or GPT-4. Think of it as a brainy student who’s read every book in the library. Fine-tuning is like giving this student some extra tutoring in a specific subject like linear algebra. We take this pre-trained model and train it a bit more on a specialized dataset, here one containing only math problems, to make it really good at that subject. This makes the model much stronger on the particular topic you’re interested in.
Fine-tuning is not a competitor to RAG but can complement it. When the retriever and the generator are both fine-tuned on domain data, the retriever tends to find more relevant documents and the generator tends to give more accurate answers, since both understand the topic better. This cuts down on errors and makes responses clearer and more precise. It helps RAG perform better in areas like customer support or technical advice, and it also improves how the model handles complex or rare questions. Plus, fine-tuning helps the retriever and generator work together smoothly. Still, there’s a big downside: fine-tuning is complex and costly. You need to build the right dataset, understand the fine-tuning process, and iterate until you find the right combination of parameters, all of which adds up to a lot of extra cost.
Techniques for Optimizing Fine-Tuning
Fortunately for us, there are various techniques to fine-tune your model more efficiently. Let’s look at the first technique, called quantization. This method makes our models run faster and use less memory. And there’s nothing new here. It’s all linked to how computers and bits work. To understand this, we need to understand what precision is. Precision is basically how many bits we use to represent numbers in our models. Floating-point formats like FP32 and FP16 use 32 and 16 bits, respectively, to represent each number. They are very accurate but also heavy on memory and computation. Integer formats like INT8 and INT4 use 8 and 4 bits, respectively, which is lighter and faster but less precise.
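To make the memory side concrete, here is a small back-of-the-envelope sketch. It only counts the bits needed to store the weights of a 405B-parameter model like the one mentioned earlier; the function name is made up for this example, and real deployments also need memory for activations, caches, and overhead.

```python
# Bits used to store one number in each format discussed above.
BITS_PER_VALUE = {"FP32": 32, "FP16": 16, "INT8": 8, "INT4": 4}


def weight_memory_gb(num_params: int, fmt: str) -> float:
    """Approximate memory for the weights alone, in gigabytes (10^9 bytes)."""
    total_bits = num_params * BITS_PER_VALUE[fmt]
    return total_bits / 8 / 1e9


params = 405_000_000_000  # 405B parameters
for fmt in BITS_PER_VALUE:
    print(f"{fmt}: {weight_memory_gb(params, fmt):,.1f} GB")
```

Each halving of the bit width halves the storage, which is why going from FP32 to INT4 is so attractive for very large models.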
Quantization transforms these high-precision floating-point numbers, like FP32, into lower-precision integers, like INT4. Why? Because it reduces memory usage, increases computational speed, and improves energy efficiency. So which numbers get quantized? Mainly two kinds: weights and activations. To put it simply, weights are like learned rules that tell the model which word should focus on which part of the sentence. For example, in the sentence “Louis has a YouTube channel, and you should subscribe to it to learn more about AI”, we definitely want the word “subscribe” to focus on “Louis” so the model knows we are not talking about another YouTube channel.
In other words, weights help the model determine which words are most important in relation to each other. They are learned during training and then stay fixed while the model runs.
On the other hand, activations are the dynamic outputs generated by the model as it processes information through its layers. When the model applies the weights to the input data, it produces activations, which represent the model’s current understanding of the input at each step. Since the model is constantly updating its understanding as it moves through each layer, activations change frequently and are more sensitive to precision.
In a typical model, there are more activations than weights, often with a ratio of about 3:1. Reducing the precision of weights (i.e., using fewer bits to represent them) has a smaller impact on the model’s accuracy because weights are relatively stable and don’t change much. However, reducing the precision of activations can have a bigger impact on accuracy because activations change with every step in the processing, and any loss in precision can add up across layers.
So by converting both weights and activations to lower precision, quantization allows us to deploy large models on smaller devices while still maintaining high performance.
The main issue with quantization is that you cannot perfectly represent large, precise numbers with smaller, less precise ones. FP32 is the default for a reason. For example, imagine you have the number 3.14159 (an approximation of pi) in FP32 format. If you naively convert this to a smaller format like INT8, it might just become 3, losing all the decimal detail. This loss of detail means your calculations can become less accurate. In machine learning, this could mean your model makes more mistakes because it can’t capture the small differences that matter. So, while quantization makes computations faster and uses less memory, it can also hurt the accuracy of your results. Fortunately, there are so many weights in a model that it usually still works out well. We also use many tricks to make this process better and find the best trade-off between speed and accuracy.
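One of those tricks is to rescale the numbers before rounding, so the available integer range is spread over the values you actually have. The sketch below uses absmax scaling, a common and simple scheme; the function names and the toy values are only for illustration.

```python
def quantize_absmax_int8(values):
    """Map floats to integers in [-127, 127] using one shared scale."""
    max_abs = max(abs(v) for v in values)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [round(v / scale) for v in values]
    return quantized, scale


def dequantize(quantized, scale):
    """Turn the integers back into approximate floats."""
    return [q * scale for q in quantized]


weights = [3.14159, -0.5, 0.02, 1.75]
q, scale = quantize_absmax_int8(weights)
restored = dequantize(q, scale)

for original, approx in zip(weights, restored):
    print(f"{original:+.5f} -> {approx:+.5f} (error {abs(original - approx):.5f})")
```

With scaling, 3.14159 no longer collapses to 3, but every value still picks up a rounding error of up to half the scale. That residual error is the accuracy cost discussed above.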
One popular library that does this is bitsandbytes. It offers two main methods: 8-bit quantization and 4-bit quantization. Imagine you have a lot of numbers, some big and some small. bitsandbytes keeps the big numbers (outliers) in a more detailed format like FP16 but stores the rest in a simpler 8-bit format like INT8. This means the model uses less memory and can process data faster without losing too much accuracy because it’s still careful with the big numbers.
bitsandbytes can also store the model’s weights in an even simpler 4-bit format. When the model needs to use these weights, it temporarily converts them to a more detailed format to ensure accuracy. This keeps memory usage very low while the model still performs well.
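Here is a toy sketch of why treating outliers separately helps. It is not the library’s actual code, which works on whole tensors on the GPU and is far more involved; the function names, the threshold of 6.0, and the values are assumptions made for this example.

```python
def quantize_with_outliers(values, threshold=6.0):
    """Quantize small values to 8-bit integers; keep large ones untouched.

    Returns (quantized, scale, outliers); `outliers` maps index -> original value.
    """
    outliers = {i: v for i, v in enumerate(values) if abs(v) >= threshold}
    regular = [v for i, v in enumerate(values) if i not in outliers]
    max_abs = max((abs(v) for v in regular), default=0.0)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    quantized = [0 if i in outliers else round(v / scale) for i, v in enumerate(values)]
    return quantized, scale, outliers


def dequantize_with_outliers(quantized, scale, outliers):
    """Rebuild floats, taking outliers from their full-precision copy."""
    return [outliers.get(i, q * scale) for i, q in enumerate(quantized)]


values = [0.12, -0.40, 45.0, 0.33, -0.07]
q, scale, outliers = quantize_with_outliers(values)
restored = dequantize_with_outliers(q, scale, outliers)

# For comparison: one shared scale for everything, outlier included.
naive_scale = max(abs(v) for v in values) / 127
naive = [round(v / naive_scale) * naive_scale for v in values]

print("outlier-aware:", [round(v, 4) for v in restored])
print("naive:        ", [round(v, 4) for v in naive])
```

In the naive version, the single large value stretches the scale so much that the small values are rounded very coarsely. Setting the outlier aside lets the remaining values use the full 8-bit range.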
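The store-small, compute-big idea can be sketched in a few lines. This example uses a plain uniform 4-bit grid and packs two codes per byte; real 4-bit schemes use more carefully chosen levels and block-wise scales, so treat the function names and the grid here as illustrative assumptions.

```python
def quantize_4bit(values):
    """Map floats to integer codes 0..15 on a uniform grid between min and max."""
    lo, hi = min(values), max(values)
    step = (hi - lo) / 15 if hi > lo else 1.0
    codes = [round((v - lo) / step) for v in values]
    return codes, lo, step


def pack_nibbles(codes):
    """Store two 4-bit codes in each byte."""
    if len(codes) % 2:
        codes = codes + [0]  # pad to an even length
    return bytes((codes[i] << 4) | codes[i + 1] for i in range(0, len(codes), 2))


def unpack_and_dequantize(packed, count, lo, step):
    """Expand the packed codes back to floats just before they are used."""
    codes = []
    for byte in packed:
        codes.extend((byte >> 4, byte & 0x0F))
    return [lo + c * step for c in codes[:count]]


weights = [-0.8, -0.1, 0.05, 0.3, 0.75, -0.45]
codes, lo, step = quantize_4bit(weights)
packed = pack_nibbles(codes)
restored = unpack_and_dequantize(packed, len(weights), lo, step)

print(len(packed), "bytes for", len(weights), "weights")
print([round(v, 3) for v in restored])
```

The weights sit in memory as half a byte each and only become higher-precision numbers at the moment they are needed for a computation.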
But it’s not just about choosing between 4-bit or 8-bit. When fine-tuning a model, there’s a cool trick called Parameter-Efficient Fine-Tuning (PEFT). Instead of tweaking every single part of the model, PEFT adjusts only a small number of parameters to get the model working better for your specific task. There are a few ways to do this. The first is to use adapters: you add a small new layer and fine-tune only that. Techniques like Low-Rank Adaptation, or LoRA, use some smart math: the original weights stay frozen, and the model learns two small matrices whose product is added on top of them. QLoRA (the Q stands for quantized) goes a step further by shrinking the frozen model weights down to 4-bit first and then applying LoRA. So, with QLoRA, you save a ton of space with 4-bit weights, but still do the heavy lifting in 16-bit to keep things accurate.
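To see why LoRA is cheap, here is a minimal pure-Python sketch of the low-rank update, following the usual formulation W + (alpha / r) * B A with a frozen W and two small trainable matrices. The sizes, the helper names, and the values of alpha and rank are toy assumptions for illustration.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_effective_weight(w, a, b, alpha, rank):
    """Return W + (alpha / rank) * B @ A without modifying W."""
    delta = matmul(b, a)
    scaling = alpha / rank
    return [[w_ij + scaling * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


random.seed(0)
d_out, d_in, rank, alpha = 64, 64, 2, 4  # toy sizes, chosen for illustration

w = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]    # frozen
a = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(rank)]  # trainable
b = [[0.0] * rank for _ in range(d_out)]                                 # trainable, starts at zero

full_params = d_out * d_in
lora_params = rank * d_in + d_out * rank
print(f"trainable parameters: {lora_params} instead of {full_params}")

# Because B starts at zero, the adapted layer initially matches the original.
assert lora_effective_weight(w, a, b, alpha, rank) == w
```

Even in this tiny layer, only 256 numbers are trained instead of 4,096, and the gap grows quickly with the size of the layer.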
So far, we’ve talked about what you do during fine-tuning, but what about after that? How do you deploy these models in a way that’s even better? To get around the accuracy issues that come from using lower-precision numbers, there are two solid methods: SmoothQuant and GPTQ.
Unlike QLoRA, SmoothQuant takes a different route by quantizing both weights and activations to 8-bit precision. Activations are harder to quantize than weights because of their outliers, so SmoothQuant rescales each channel to shift part of that difficulty from the activations onto the weights. This helps speed things up, saves memory, and still keeps accuracy high. SmoothQuant is built into popular tools like FasterTransformer, Amazon SageMaker, NVIDIA TensorRT, and ONNX, and its authors report speedups of up to 1.56 times while using only half the memory.
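The balancing step can be sketched with a toy example. It follows the per-channel scaling described in the SmoothQuant paper, where each input channel j gets a factor s_j = max|X_j|^alpha / max|W_j|^(1 - alpha); the activations are divided by it and the matching weight rows are multiplied by it. The helper names and the toy matrices are assumptions for this illustration.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def smoothing_scales(x, w, alpha=0.5):
    """One scale per input channel j: max|X[:, j]|**alpha / max|W[j, :]|**(1 - alpha)."""
    scales = []
    for j in range(len(w)):
        act_max = max(abs(row[j]) for row in x)
        weight_max = max(abs(v) for v in w[j])
        scales.append(act_max ** alpha / weight_max ** (1 - alpha))
    return scales


def smooth(x, w, scales):
    """Divide activation channels by the scales and multiply weight rows by them."""
    x_smooth = [[v / s for v, s in zip(row, scales)] for row in x]
    w_smooth = [[v * s for v in row] for row, s in zip(w, scales)]
    return x_smooth, w_smooth


# Toy activations (2 tokens x 3 channels); channel 1 holds an outlier.
x = [[0.5, 40.0, -0.2],
     [-0.3, 35.0, 0.1]]
# Toy weights (3 input channels x 2 outputs).
w = [[0.2, -0.1],
     [0.05, 0.03],
     [-0.3, 0.4]]

scales = smoothing_scales(x, w)
x_s, w_s = smooth(x, w, scales)

print("largest activation before:", max(abs(v) for row in x for v in row))
print("largest activation after: ", max(abs(v) for row in x_s for v in row))
print("same output:", matmul(x, w), "vs", matmul(x_s, w_s))
```

The layer’s output is mathematically unchanged, but the activation outlier has shrunk, so both tensors are now much friendlier to 8-bit rounding.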
By the way, ONNX makes it easy to move models between different tools and frameworks. This is especially useful if you’ve trained a model in one environment but need to deploy or fine-tune it in another. If you ever need to switch frameworks, ONNX can make that process a lot smoother.
GPTQ, a post-training quantization method for generative pre-trained transformers, also comes in after the model is fully trained. Rather than fine-tuning the model again, it quantizes the weights layer by layer, using a small calibration dataset and approximate second-order information to compensate for the error that rounding introduces. GPTQ is a weight-only method, and it works at very compact formats like 4-bit, allowing for a smaller, faster model that’s still accurate. This makes GPTQ an excellent choice for deployment when you need to balance accuracy, speed, and memory usage across various platforms.
But quantization isn’t the only trick for fine-tuning models more efficiently. There are other methods that also help, like model pruning.
Pruning is like trimming a tree. By cutting away the least useful parts, you make the model smaller and faster without losing much of its strength. This means it requires less computation but can still perform nearly as well. Tools like TensorFlow Model Optimization Toolkit and TorchPrune help with this by figuring out which weights in the network aren’t pulling their weight and safely removing them so your model stays lean and efficient. We often use importance estimation to figure out which weights to prune, which you can learn more about in the Minitron paper by NVIDIA if you are curious.
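As a minimal sketch, the simplest importance score is just the absolute value of each weight: the closer to zero, the less it contributes. This magnitude criterion is a stand-in chosen for illustration, not the method used by any of the tools or the paper named above, and the function name is made up.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the given fraction of weights with the smallest absolute value."""
    num_to_prune = int(len(weights) * sparsity)
    # Indices sorted from least to most important under the magnitude criterion.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned_indices = set(order[:num_to_prune])
    return [0.0 if i in pruned_indices else w for i, w in enumerate(weights)]


weights = [0.91, -0.02, 0.40, 0.003, -0.75, 0.08, -0.01, 0.33]
pruned = magnitude_prune(weights, sparsity=0.5)
print(pruned)
```

In practice, pruning is usually followed by a short round of retraining so the remaining weights can make up for what was removed.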
Another smart approach is Mixed Precision Training. It uses lower precision, like FP16, for most of the model to save time and memory, and higher precision, like FP32, where it really counts. Tools like NVIDIA’s Apex and TensorFlow’s Mixed Precision API handle this by automatically deciding which parts of the model can get by with lower precision and which need that extra detail, making the whole process faster and more memory-efficient.
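One reason some parts need that extra detail is that very small numbers, such as tiny gradients, simply vanish in FP16. The sketch below shows this with Python’s built-in half-precision conversion, along with loss scaling, a common companion trick in mixed precision setups. The scale factor and the gradient value are assumptions picked for the example.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


tiny_gradient = 1e-8
loss_scale = 65536.0  # hypothetical scale factor; a power of two

# Stored directly in FP16, the gradient underflows to zero.
print("direct FP16:", to_fp16(tiny_gradient))

# Scaled first, it survives FP16 and can be unscaled in higher precision.
scaled = to_fp16(tiny_gradient * loss_scale)
recovered = scaled / loss_scale
print("scaled, then unscaled:", recovered)
```

The first value is lost entirely, while the scaled one comes back within a fraction of a percent of the original. Mixed precision tools automate this kind of bookkeeping.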
Then there’s Knowledge Distillation, which Meta is really banking on with their Llama 3.1 models. They’re planning to use the massive 405B model to distill smaller ones. Basically, instead of training a smaller model from scratch on a smaller dataset, you train a smaller “student” model to mimic a larger “teacher” model. The student learns to be almost as good as the teacher while being way lighter. It also doesn’t require ground-truth labels, since the teacher’s outputs serve as the training signal. This makes distillation perfect for running on devices with limited power or memory, and it gets you a great model with much less data preparation.
Lastly, we have graph optimization techniques, like layer fusion, which streamlines the model’s computation graph by cutting out unnecessary steps. It takes multiple operations that the model would normally do one by one and fuses them into a single, more efficient operation. This reduces the number of times the model has to pass data between different layers, cutting down on the time and memory needed to get the job done. Tools like ONNX Runtime and NVIDIA TensorRT do this by merging adjacent operations into one. This speeds up the model and uses less memory, which is a big win when deploying on platforms with limited resources.
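Here is a minimal sketch of what “mimicking the teacher” can mean in practice. It uses one common formulation, the KL divergence between temperature-softened output distributions; this is an illustration with made-up logits and function names, not a description of Meta’s actual recipe.

```python
import math


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; higher temperature gives a softer distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened distributions, scaled by T^2."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student) if p > 0)
    return kl * temperature ** 2


teacher = [4.0, 1.0, 0.2]
close_student = [3.9, 1.1, 0.1]
far_student = [0.2, 1.0, 4.0]

print("close student:", distillation_loss(close_student, teacher))
print("far student:  ", distillation_loss(far_student, teacher))
```

Notice that no ground-truth label appears anywhere: the teacher’s output distribution is the training target, and the loss shrinks as the student’s predictions move toward it.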
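A tiny example shows the idea. Here a linear layer is followed by a per-output scale-and-shift step (think of a normalization layer with fixed statistics), and the two are folded into a single linear layer ahead of time. The function names and numbers are hypothetical; real runtimes apply this kind of rewrite automatically on the model graph.

```python
def linear(x, w, b):
    """y[i] = sum_j w[i][j] * x[j] + b[i]"""
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(w, b)]


def scale_shift(y, gamma, beta):
    """Per-output affine step applied after the linear layer."""
    return [g * v + bt for g, v, bt in zip(gamma, y, beta)]


def fuse(w, b, gamma, beta):
    """Fold the scale-and-shift into the linear layer's weights and bias."""
    fused_w = [[g * wij for wij in row] for row, g in zip(w, gamma)]
    fused_b = [g * bi + bt for g, bi, bt in zip(gamma, b, beta)]
    return fused_w, fused_b


w = [[0.5, -1.0], [2.0, 0.25]]
b = [0.1, -0.2]
gamma, beta = [1.5, 0.5], [0.0, 1.0]
x = [3.0, -2.0]

two_steps = scale_shift(linear(x, w, b), gamma, beta)
fused_w, fused_b = fuse(w, b, gamma, beta)
one_step = linear(x, fused_w, fused_b)

print("two operations:", two_steps)
print("one fused op:  ", one_step)
```

Both paths give the same result, but the fused version does it in one pass over the data instead of two, with no intermediate result to store.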
These are powerful techniques that make models run faster, use less memory, and stay accurate. They are especially useful for smaller companies and for deployment. These methods help deploy large models even on devices with limited resources, making AI more accessible and efficient.
Thus, optimizing LLMs involves a mix of hyperparameter tuning, advanced training techniques, and various forms of quantization that can happen at one or more of the usual stages: pre-training, fine-tuning, or production. Combined with other optimization strategies like pruning, mixed precision training, knowledge distillation, and graph optimization, these methods help ensure that the model is efficient, powerful, and ready for deployment!
We discuss these methods in much more depth with applied examples in our RAG course.
Thank you for reading through to the end, and I hope you found this article useful!