TODAY'S ISSUE
TODAY'S DAILY DOSE OF DATA SCIENCE
Extend the context length of LLMs
Consider this:
- GPT-3.5-turbo had a context window of 4,096 tokens.
- Later, GPT-4 took that to 8,192 tokens.
- Claude 2 reached 100,000 tokens.
- Llama 3.1: 128,000 tokens.
- Gemini: 1M+ tokens.
We have been making great progress in extending the context window of LLMs.
Today, let's understand some techniques that help us unlock this.
What's the challenge?
In a traditional transformer, the attention mechanism makes compute grow quadratically with sequence length: processing 4,096 tokens requires 64 times more attention computation than processing 512 tokens (8x the tokens means 8^2 = 64x the token pairs).
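A tiny Python check of that ratio. It counts only the multiply-adds in QK^T for one attention head; the helper name and the head dimension of 128 are arbitrary assumptions, and the head dimension cancels out of the ratio anyway:

```python
def attention_score_macs(n_tokens: int, d_head: int) -> int:
    """Multiply-adds for Q @ K^T in one head: an n x n matrix of d-dim dot products.
    Softmax is ignored; the later P @ V product is also n * n * d, so it scales the same way."""
    return n_tokens * n_tokens * d_head

D_HEAD = 128  # hypothetical head dimension; it cancels out of the ratio
ratio = attention_score_macs(4096, D_HEAD) / attention_score_macs(512, D_HEAD)
print(ratio)  # 64.0
```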
Thus, supporting a longer context window isn't as simple as making the matrices bigger.
But at least we have narrowed down the bottleneck.
If we can tame this quadratic cost, we remove the main obstacle to longer contexts.
A quick note: This bottleneck was already known back in 2017, when transformers were introduced. Since GPT-3, which alternated dense and locally banded sparse attention layers, many LLMs have used approaches that reduce the cost of attention computation.
1) Approximate attention using Sparse Attention
Instead of computing attention scores between all pairs of tokens, sparse attention limits that to a subset of tokens, which reduces the computations.
There are two common ways:
- Use local attention, where tokens attend only to their neighbors.
- Let the model learn which tokens to focus on.
As you may have guessed, there's a trade-off between computational complexity and performance.
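A minimal NumPy sketch of the first option, sliding-window (local) attention, where token i may only attend to tokens within `window` positions of it. The function name and window size are illustrative choices. For readability, it still builds the full n x n score matrix and masks it; a real sparse kernel skips the masked pairs altogether, which is where the savings come from, since the number of allowed pairs grows roughly as n x (2 x window + 1) instead of n^2.

```python
import numpy as np

def local_attention(q, k, v, window: int):
    """Sliding-window attention: token i attends only to tokens j with |i - j| <= window.
    This sketch masks a full n x n matrix for clarity; an efficient kernel would
    never compute the masked entries."""
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)                          # (n, n) raw scores
    idx = np.arange(n)
    mask = np.abs(idx[:, None] - idx[None, :]) <= window   # True = allowed pair
    scores = np.where(mask, scores, -np.inf)               # drop out-of-window pairs
    scores -= scores.max(axis=1, keepdims=True)            # numerical stability
    weights = np.exp(scores)                               # exp(-inf) = 0
    weights /= weights.sum(axis=1, keepdims=True)          # row-wise softmax
    return weights @ v, mask

rng = np.random.default_rng(0)
n, d = 8, 4
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
out, mask = local_attention(q, k, v, window=1)
print(out.shape)        # (8, 4)
print(int(mask.sum()))  # 22
```

With n = 8 and a window of 1, only 22 of the 64 token pairs are scored: the diagonal plus one neighbor on each side.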
2) Flash Attention
This is a fast and memory-efficient method that retains the exactness of the traditional attention mechanism.
The whole idea revolves around optimizing the data movement within GPU memory. Here are some background details and how it works.
In a GPU:
- A thread is the smallest unit of execution.
- A group of threads is called a block.
Also:
- All threads in a block execute the same kernel (a function, to simplify) and cooperate through a small, fast on-chip memory called SRAM (exposed to CUDA programs as shared memory).
- All blocks can access the GPU's large global memory, called HBM (high-bandwidth memory).
A note about SRAM and HBM:
- SRAM is scarce but extremely fast.
- HBM is much more abundant but slow (typically 8-15x slower).
A standard attention implementation moves large matrices back and forth between HBM and SRAM:
- First, the query (Q) and key (K) matrices are loaded from HBM, their product S = QK^T is computed by the threads, and S is written back to HBM.
- Next, S is read again to compute the softmax P, which is also written back to HBM.
- Finally, P and the value (V) matrix are read to compute the output, which is written to HBM.
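To make these steps concrete, here is plain (non-fused) attention in NumPy. It runs on the CPU, so there is no real HBM/SRAM distinction here; the point is the size of the intermediates S and P, which a standard GPU implementation writes to and reads back from HBM. The sequence length of 4,096 and head dimension of 64 are arbitrary choices.

```python
import math
import numpy as np

def standard_attention(q, k, v):
    """Standard (non-fused) attention. On a GPU, each step is typically a separate
    pass, and S and P are full n x n matrices stored in HBM between passes."""
    n, d = q.shape
    scale = 1.0 / math.sqrt(d)                  # Python float keeps float32 inputs as float32
    s = (q @ k.T) * scale                       # step 1: S = QK^T / sqrt(d), shape (n, n)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    p /= p.sum(axis=1, keepdims=True)           # step 2: P = softmax(S), shape (n, n)
    o = p @ v                                   # step 3: O = PV, shape (n, d)
    return o, s, p

n, d = 4096, 64
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((n, d), dtype=np.float32) for _ in range(3))
o, s, p = standard_attention(q, k, v)
print(s.shape, s.nbytes / 2**20)  # (4096, 4096) 64.0
```

At 4,096 tokens, each of S and P takes 64 MiB in float32 for a single head of a single sequence, and both grow quadratically with the number of tokens.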
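Flash attention avoids these round trips by splitting Q, K, and V into tiles that fit in SRAM, computing attention tile by tile with an incrementally updated ("online") softmax, and fusing all the steps into a single kernel, so the full attention matrix is never written to HBM.
This removes most of the redundant memory traffic; the original FlashAttention paper reports up to a 7.6x speedup of the attention computation on GPT-2 over a standard implementation.
Also, its memory usage scales linearly with sequence length rather than quadratically, since the full attention matrix is never stored (the computation itself is still quadratic).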
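The core trick can be sketched in NumPy: walk over K and V in blocks and keep a running max and running sum for each query row (an "online" softmax), so the result matches standard attention up to floating-point rounding without ever holding an n x n matrix. This is a simplified illustration, not the actual FlashAttention kernel: it tiles only K/V, runs on the CPU, and omits the SRAM management, Q tiling, and kernel fusion that make the real thing fast. The names `tiled_attention` and `block` are hypothetical.

```python
import math
import numpy as np

def tiled_attention(q, k, v, block: int = 64):
    """Exact attention computed over K/V blocks with an online softmax.
    Peak extra memory is (n, block) scores instead of (n, n)."""
    n, d = q.shape
    scale = 1.0 / math.sqrt(d)
    o = np.zeros((n, d))                 # running unnormalized output
    m = np.full((n, 1), -np.inf)         # running row-wise max of scores
    l = np.zeros((n, 1))                 # running row-wise sum of exp(score - m)
    for start in range(0, n, block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale           # scores against this block only
        m_new = np.maximum(m, s.max(axis=1, keepdims=True))
        p = np.exp(s - m_new)            # block weights relative to the new max
        correction = np.exp(m - m_new)   # rescales earlier partial sums (0 on first block)
        l = l * correction + p.sum(axis=1, keepdims=True)
        o = o * correction + p @ vb
        m = m_new
    return o / l                         # final softmax normalization

def reference_attention(q, k, v):
    """Standard attention that materializes the full n x n matrix, for comparison."""
    s = (q @ k.T) / math.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((256, 32)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v), reference_attention(q, k, v)))  # True
```

Because the block size is fixed, the extra memory grows linearly with n, while the number of multiply-adds is the same as in standard attention.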
What else?
While reducing the computational complexity is crucial, this is not sufficient.
See, with the above optimizations, it becomes practical to pass longer contexts without drastically increasing the computation cost.
However, the model must know how to comprehend longer contexts and the relative position of tokens.
That is why the selection of positional embeddings is crucial.
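Rotary positional embeddings (RoPE) are a popular choice (used, for example, in Llama): they encode each token's position by rotating its query and key vectors, so the attention score between two tokens depends on their relative position rather than their absolute positions.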
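A minimal NumPy sketch of the rotary idea, following the formulation where consecutive dimension pairs are rotated by an angle proportional to the token position, with per-pair frequencies base^(-2i/d) and base = 10,000. Some implementations pair dimension i with i + d/2 instead; the relative-position property is the same. The check at the end rotates a query and a key at two different pairs of positions with the same offset and gets the same score.

```python
import numpy as np

def rope(x, pos: int, base: float = 10000.0):
    """Rotate consecutive pairs (x[0], x[1]), (x[2], x[3]), ... of vector x
    by angle pos * theta_i, where theta_i = base ** (-2i / d)."""
    d = x.shape[-1]
    assert d % 2 == 0, "RoPE needs an even dimension"
    theta = base ** (-np.arange(0, d, 2) / d)   # one frequency per pair
    angle = pos * theta
    cos, sin = np.cos(angle), np.sin(angle)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin      # standard 2-D rotation of each pair
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)
a = rope(q, 10) @ rope(k, 7)      # query at position 10, key at position 7
b = rope(q, 103) @ rope(k, 100)   # same offset (3), different absolute positions
print(np.isclose(a, b))           # True
```

Rotations preserve vector length, so position is encoded without rescaling the token representation; only the angle between query and key, and therefore their dot product, carries positional information.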
If you want to learn more about RoPE, let us know. We can cover it in another issue.
In the meantime, if you want to dig into CUDA GPU programming, understand how GPUs work internally, and learn how CUDA programs are built, we covered it here: Implementing (Massively) Parallelized CUDA Programs From Scratch Using CUDA Programming.
Over to you: What are some other ways to extend the context length of LLMs?
PRIVACY PRESERVING ML
Train models on private data with federated learning
There's so much data on your mobile phone right now: images, text messages, etc.
And this is just one user: you.
But applications can have millions of users. The amount of data we can train ML models on is unfathomable.
The problem?
This data is private.
So it cannot simply be consolidated into a single place to train a model.
The solution?
Federated learning is a smart way to address this challenge.
The core idea is to ship the model to devices, train it on each device, and retrieve only the model updates, so the raw data never leaves the device.
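As a rough sketch of that loop, here is Federated Averaging (FedAvg), one widely used way to combine client updates (the text above does not commit to a specific aggregation rule). Assumptions, all for illustration: each client trains a tiny linear-regression model with plain gradient descent, the "devices" are simulated in memory, and the helper names `local_update` and `federated_averaging` are hypothetical.

```python
import numpy as np

def local_update(weights, x, y, lr=0.1, epochs=5):
    """On-device training (hypothetical stand-in): a few gradient-descent steps
    of linear regression on the client's private data."""
    w = weights.copy()
    for _ in range(epochs):
        grad = 2 * x.T @ (x @ w - y) / len(y)   # gradient of mean squared error
        w -= lr * grad
    return w

def federated_averaging(global_w, clients, rounds=20):
    """FedAvg: send the global weights to every client, train locally, then average
    the returned weights, weighting each client by its number of samples.
    Only weights travel back to the server, never the raw (x, y) data."""
    for _ in range(rounds):
        updates = [local_update(global_w, x, y) for x, y in clients]
        sizes = np.array([len(y) for _, y in clients], dtype=float)
        global_w = np.average(updates, axis=0, weights=sizes)
    return global_w

rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
clients = []
for n in (50, 200, 80):                          # three simulated devices of different sizes
    x = rng.standard_normal((n, 2))
    clients.append((x, x @ true_w + 0.01 * rng.standard_normal(n)))
w = federated_averaging(np.zeros(2), clients)
print(np.round(w, 2))                            # close to true_w = [2, -1]
```

Weighting by sample count lets larger clients count more. Real deployments typically add client sampling, update compression, and secure aggregation, which this sketch leaves out.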
But this isn't as simple as it sounds.
1) Since the model is trained on the client side, how do we reduce its size?
2) How do we aggregate the different models received from the clients?
3) [IMPORTANT] Privacy-sensitive datasets are often skewed by personal preferences and habits. For instance, in an image-related task:
- Some devices may only have pet images.
- Some devices may only have car images.
- Some people may love to travel and may primarily have travel-related images.
How do we handle such skewness in the data distribution?
Learn how to build federated learning systems (beginner-friendly and with implementation).
PRODUCTION ML (WITH IMPLEMENTATION)
5 techniques to test ML models in production
Despite rigorously testing an ML model locally (on validation and test sets), it could be a terrible idea to instantly replace the previous model with the new model.
A more reliable strategy is to test the model in production (yes, on real-world incoming data).
While this might sound risky, ML teams do it all the time, and it isn't that complicated.
There are many ways to do this.
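One simple example is a canary (or A/B) rollout: send a small, fixed share of users to the new model and compare its metrics against the legacy model before switching over. The sketch below shows only the routing step; the function name `route`, the SHA-256 bucketing, and the 5% share are illustrative assumptions, not a summary of the five techniques.

```python
import hashlib

def route(user_id: str, candidate_fraction: float) -> str:
    """Deterministically assign a user to the 'candidate' or 'legacy' model.
    Hashing the ID (instead of sampling per request) keeps each user on the
    same model across requests, which makes the comparison cleaner."""
    digest = hashlib.sha256(user_id.encode()).hexdigest()
    bucket = int(digest[:8], 16) / 0xFFFFFFFF    # roughly uniform in [0, 1]
    return "candidate" if bucket < candidate_fraction else "legacy"

users = [f"user-{i}" for i in range(10_000)]
share = sum(route(u, 0.05) == "candidate" for u in users) / len(users)
print(share)  # roughly 0.05
```

Raising `candidate_fraction` step by step while monitoring the candidate's metrics turns this into a gradual rollout; setting it to 0 sends everyone back to the legacy model.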
THAT'S A WRAP
No-Fluff Industry ML resources to succeed in DS/ML roles
At the end of the day, all businesses care about impact. That's it!
- Can you reduce costs?
- Drive revenue?
- Can you scale ML models?
- Predict trends before they happen?
In the past, we have covered several topics (with implementations) that align with these goals.
Here are some of them:
- Learn sophisticated graph architectures and how to train them on graph data in this crash course.
- So many real-world NLP systems rely on pairwise context scoring. Learn scalable approaches here.
- Run large models on small devices using Quantization techniques.
- Learn how to generate prediction intervals or sets with strong statistical guarantees for increasing trust using Conformal Predictions.
- Learn how to identify causal relationships and answer business questions using causal inference in this crash course.
- Learn how to scale and implement ML model training in this practical guide.
- Learn 5 techniques with implementation to reliably test ML models in production.
- Learn how to build and implement privacy-first ML systems using Federated Learning.
- Learn 6 techniques with implementation to compress ML models.
All these resources will help you cultivate key skills that businesses and companies care about the most.
SPONSOR US
Advertise to 450k+ data professionals
Our newsletter puts your products and services directly in front of an audience that matters: thousands of leaders, senior data scientists, machine learning engineers, data analysts, etc., around the world.