Over the last few days, you have likely at least heard of Kolmogorov-Arnold Networks (KANs). It's okay if you don't know what they are or how they work; this article is intended precisely for that.

KANs are gaining a lot of traction because they challenge traditional neural network design and offer an exciting new paradigm for designing and training networks.

In a nutshell, the idea is this: instead of stacking many linear layers on top of each other (with non-linearity deliberately introduced via activation functions), KANs take an alternative approach.

The authors used learnable univariate functions (parameterized as B-splines) that help us directly represent non-linear input transformations, with relatively fewer parameters than a traditional neural network.

In the above image, if you notice closely:

- In a neural network, activation functions are fixed across the entire network, and they exist on each **node** (shown below).

- In KAN, however, activation functions exist on edges (which correspond to the weights of a traditional neural network), and every edge has a different activation function.

Given this design, the authors argue that KANs can be more accurate and interpretable than traditional networks while using fewer parameters.

Going ahead, we shall dive into the technical details of how they work.

Let’s begin!

The universal approximation theorem (UAT) is the theoretical foundation behind the neural networks we use.

In simple words, it states that a neural network with just one hidden layer containing a finite number of neurons can approximate **ANY** continuous function to a reasonable accuracy on a compact subset of $\mathbb{R}^n$, given suitable activation functions.

The key points are:

- **Network Structure**: Single hidden layer.
- **Function Approximation**: Any continuous function can be approximated.
- **Activation Function**: Must be non-linear and suitable (e.g., sigmoid, ReLU).

Mathematically speaking, for any continuous function $f$ and $\epsilon > 0$, there always exists a neural network $\hat{f}$ such that:

$$|f(x) - \hat{f}(x)| < \epsilon \quad \text{for all } x \text{ in that compact subset}$$

It was proved only for sigmoid activations when it was first proposed in 1989. However, in 1991, it was shown to hold for a much broader class of activation functions.

The fundamental thought process involved here was that a piecewise function can approximate any continuous function on a **bounded set**.

For instance, the cosine curve between $-\frac{\pi}{2}$ and $\frac{\pi}{2}$ can be approximated with several steps fitted to this curve.

Naturally, the more steps we use, the better the approximation. To generalize, a neural network with a single hidden layer can map each of its neurons to one piece of such a piecewise function.
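As a concrete sketch (my own construction, not from the paper), here is how a piecewise-constant approximation of $\cos$ improves as we add steps:

```python
import numpy as np

def step_approx(f, a, b, n_steps, x):
    """Approximate f on [a, b] with n_steps constant pieces,
    each taking f's value at the midpoint of its interval."""
    edges = np.linspace(a, b, n_steps + 1)
    mids = (edges[:-1] + edges[1:]) / 2
    # Find which piece each x falls into and return that piece's value
    idx = np.clip(np.searchsorted(edges, x, side="right") - 1, 0, n_steps - 1)
    return f(mids)[idx]

a, b = -np.pi / 2, np.pi / 2
x = np.linspace(a, b, 1001)
for n in (4, 16, 64):
    err = np.max(np.abs(step_approx(np.cos, a, b, n, x) - np.cos(x)))
    print(f"{n:3d} steps -> max error {err:.4f}")   # error shrinks as steps grow
```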

By using weights and biases as gates, each neuron can determine if an input falls within its designated region.

For inputs falling in a neuron’s designated section, a large positive weight will push the neuron’s output towards $1$ when using a sigmoid activation function. Conversely, a large negative weight will push the output towards $0$ if the input does not belong to that section.

While today’s neural networks are much larger and more complex, and do not just estimate step functions, there is still some element of “gating” involved, wherein certain neurons activate based on specific input patterns.

Moreover, as more hidden layers are added, the potential for universal approximation grows exponentially: second-layer neurons form patterns based on first-layer patterns, and so on.

The Kolmogorov-Arnold representation theorem is another result concerned with approximating/representing continuous functions.

More formally, the Kolmogorov-Arnold representation theorem asserts that any multivariate continuous function can be represented as the composition of a **finite** number of continuous functions of a single variable.

Let’s break that down a bit:

- "multivariate continuous function" is a function that accepts multiple parameters:

- "a finite number of continuous functions of a single variable" means the following:

We don't stop here, though. The above sum is passed through one more function $\psi$:

Finally, we also apply a composition to the last step and make this more general as follows:

So, to summarize, the Kolmogorov-Arnold representation theorem asserts that any multivariate continuous function can be represented as the composition of continuous functions of a single variable.
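As a tangible (and entirely hypothetical) instance of this idea: for positive inputs, the product $x_1 x_2$ can be written using only univariate functions and addition, since $x_1 x_2 = \exp(\ln x_1 + \ln x_2)$:

```python
import math

phi = math.log   # inner univariate function, applied to each input separately
Phi = math.exp   # outer univariate function, applied to the sum

def f(x1, x2):
    """A multivariate function built purely from univariate pieces plus addition."""
    return Phi(phi(x1) + phi(x2))

print(f(3.0, 4.0))   # recovers x1 * x2 = 12.0 (up to float round-off)
```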

If we expand the sum terms, we get the following:

Here's a simple toy example:

Now, if you go back to the KAN network image shown earlier, this is precisely what's going on in the proposed network.

At first:

- The input $x_1$ is passed through univariate functions ($\phi_{11}, \phi_{12}, \cdots, \phi_{15}$) to get ($\phi_{11}(x_1), \phi_{12}(x_1), \cdots, \phi_{15}(x_1)$).
- The input $x_2$ is passed through univariate functions ($\phi_{21}, \phi_{22}, \cdots, \phi_{25}$) to get ($\phi_{21}(x_2), \phi_{22}(x_2), \cdots, \phi_{25}(x_2)$).

Next, the two corresponding outputs are aggregated (summed), so the $\psi$ function in this case is just the identity operation ($\psi(z) = z$):

This forms one KAN layer.

Next, we pass the above output through one more KAN layer.

So the above output is first passed through the $\phi$ function of the next layer and summed to get the final output:

In terms of the scope, the Universal Approximation Theorem applies explicitly to neural networks, while the Kolmogorov-Arnold theorem is a broader mathematical result, which the authors appear to extend to neural networks that supposedly estimate some unknown continuous function.

Talking about the approach, the Universal Approximation Theorem deals with **approximating** functions using neural networks. In contrast, the Kolmogorov-Arnold theorem provides a way to **exactly represent** continuous functions through sums of univariate functions, which is expected to make them more accurate.

Finally, from the application perspective, the Universal Approximation Theorem is widely used in machine learning to justify the power of neural networks, while the Kolmogorov-Arnold theorem is aligned towards a theoretical insight that is now being extended to machine learning through architectures like Kolmogorov Arnold Networks.

Thus, since the Kolmogorov-Arnold theorem is more aligned towards exactly representing the multivariate function, embedding it in neural networks is expected to be much more efficient and precise.

In fact, this exact representation could potentially lead to models that can be interpreted well and may require fewer resources for training, as the architecture would inherently incorporate the mathematical properties of the continuous functions being learned.

Okay, so far, I hope everything is clear.

Just to reiterate...

All we are doing in a single KAN layer is taking the input $(x_1, x_2, \cdots, x_n)$ and applying a transformation $\phi$ to it.

Thus, the transformation matrix $\phi^1$ (corresponding to the first layer) can be represented as follows:

In the above matrix:

- $n$ denotes the number of inputs.
- $m$ denotes the number of output nodes in that layer.
**[IMPORTANT] The individual entries are not numbers; they are univariate functions.** For instance:

- $\phi_{11}$ could be $2x^2 - 3x + 4$.
- $\phi_{12}$ could be $4x^3 + 5x^2 + x - 2$.
- and so on...

So to generate a transformation, all we have to do is take the input vector and pass it through the corresponding functions in the above transformation matrix:

This will result in the following vector:

The above is the output of the first layer, which is then passed through the next layer for another function transformation.

Thus, the entire KAN network can be condensed into one formula as follows:

Where:

- $x$ denotes the input vector
- $\phi^k$ denotes the function transformation matrix of layer $k$.
- $KAN(x)$ is the output of the KAN network.
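A minimal sketch of this bookkeeping in plain Python (the specific $\phi$ entries are made up for illustration): each layer is a grid of univariate callables, entry $(i, j)$ is applied to input $j$, and row-wise sums give the layer's outputs:

```python
import numpy as np

def kan_layer(Phi, x):
    """Apply a matrix of univariate functions: output_i = sum_j Phi[i][j](x[j])."""
    return np.array([sum(phi(xj) for phi, xj in zip(row, x)) for row in Phi])

# Hypothetical first layer: 2 inputs -> 2 outputs
Phi1 = [[lambda x: 2 * x**2 - 3 * x + 4, lambda x: x],
        [lambda x: x**3,                 lambda x: -x]]
# Hypothetical second layer: 2 inputs -> 1 output
Phi2 = [[lambda x: x, lambda x: 0.5 * x]]

x = np.array([1.0, 2.0])
h = kan_layer(Phi1, x)   # first transformation: [3 + 2, 1 - 2] = [5, -1]
y = kan_layer(Phi2, h)   # KAN(x) = Phi2(Phi1(x)) = [5 - 0.5] = [4.5]
print(h, y)
```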

The above formulation can appear quite similar to what we do in neural networks:

The only difference is that the parameters $\theta^i$ are linear transformations, while $\sigma$ denotes the activation function used for non-linearity, and the same activation function is used across all layers.

In the case of KANs, the matrices $\phi^k$ themselves are non-linear transformation matrices, and each univariate function can be quite different.

But you might be wondering how we train/generate the activation functions that are present on every edge.

More specifically, how do we estimate the univariate functions in each of the transformation matrices?

To understand this, we must first understand the concept of Bezier curves and B-Splines.

Imagine a gaming character that must pass through the following four points:

The most obvious way is to traverse them as follows:

But this movement does not appear natural, does it?

In computer graphics, we would desire a smooth traversal that looks something like this:

One way to do this is by estimating the coefficients of a higher-degree polynomial by solving a system of linear equations:

More specifically, we know the location of points $(A, B, C, D)$, so we can substitute them into the above function and determine the values of the coefficients $(a, b, c, d)$.

However, what if we have hundreds of data points?

To formalize a bit, if we have $N$ data points, we must solve a system of $N$ equations to determine the parameters. Solving such a system of linear equations is computationally expensive and almost infeasible for real-time computer graphics applications.

Bezier curves solve this problem efficiently. They provide a way to represent a smooth curve that passes near a set of control points without needing to solve a large system of equations.

For instance, consider we have these two points, and we need a traversal path:


Points $P_1$ and $P_2$ are called control points.

We can determine the curve as follows:

- When $t=0$, the position $P$ will be $P_1$.
- When $t=1$, the position $P$ will be $P_2$.
- We can vary the parameter $t$ from [0,1] and obtain the trajectory.

What if we have three points?

This time, we can extend the above case of 2 points to three points as follows:

- First, we determine a point between $P_1$ and $P_2$ as follows:

- Next, we can determine another point between $P_2$ and $P_3$ as follows:

- Finally, between points $Q_1$ and $Q_2$, we can create a linear interpolation:

The final curve looks like this:

- We vary $t$ from $[0,1]$ and obtain the red curve.
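This repeated interpolation is known as de Casteljau's construction, and it is easy to check numerically that it reproduces the closed-form quadratic $(1-t)^2 P_1 + 2t(1-t) P_2 + t^2 P_3$ (the points below are chosen arbitrarily):

```python
import numpy as np

def lerp(a, b, t):
    """Linear interpolation between two points."""
    return (1 - t) * a + t * b

def bezier3(p1, p2, p3, t):
    """Quadratic Bezier via repeated linear interpolation (de Casteljau)."""
    q1 = lerp(p1, p2, t)    # point between P1 and P2
    q2 = lerp(p2, p3, t)    # point between P2 and P3
    return lerp(q1, q2, t)  # final interpolation between Q1 and Q2

p1, p2, p3 = np.array([0.0, 0.0]), np.array([1.0, 2.0]), np.array([2.0, 0.0])
t = 0.3
closed_form = (1 - t)**2 * p1 + 2 * t * (1 - t) * p2 + t**2 * p3
print(bezier3(p1, p2, p3, t), closed_form)  # the two agree
```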

The final equation for the position of point $P$ is given by:

Moving on, the same idea can be extended to 4 points as well to get the following curve:

- We begin by interpolating between
- $P_0$ and $P_1$ to get $Q_1$.
- $P_1$ and $P_2$ to get $Q_2$.
- $P_2$ and $P_3$ to get $Q_3$.

- Next, we interpolate between:
- $Q_1$ and $Q_2$ to get $R_1$.
- $Q_2$ and $Q_3$ to get $R_2$.

- Finally, we interpolate between:
- $R_1$ and $R_2$ to get the final point $P$ (red curve above).

The final formula for the trajectory is given by:

To summarize, these are the trajectory formulas we have obtained so far for 2 points, 3 points, and 4 points:

If you notice closely, the coefficients in these formulas match the binomial coefficients in the expansion of $(1+x)^n$, so we can extend the traversal-curve formula to any number of points as follows:

In the above formula, each coefficient $c_{i,n}$ is a function of $t$; at every timestamp, it denotes the contribution of point $P_i$ to the final curve.

Let’s go back to the case where we had 4 points:

If we plot the coefficient terms (marked in yellow boxes above), we get the following plot:

In the above plot, at every timestamp $t$, the plot denotes the contribution of every point to the final curve. For instance:

- At $t=0$, except $P_0$, coefficients of all points are zero, which means the curve starts from $P_0$.
- For $t \in (0.25, 0.5)$, the coefficient of point $P_1$ is the largest, which means the curve is closest to $P_1$ in that interval.
- For $t \in (0.5, 0.75)$, the coefficient of point $P_2$ is the largest, which means the curve is closest to $P_2$ in that interval.
- At $t=1$, except $P_3$, coefficients of all points are zero, which means the curve ends at $P_3$.
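These observations are easy to verify numerically: the coefficients above are the Bernstein polynomials $c_{i,n}(t) = \binom{n}{i} t^i (1-t)^{n-i}$, evaluated here for the 4-point (degree 3) case:

```python
from math import comb

def bernstein(i, n, t):
    """Contribution c_{i,n}(t) of control point P_i to a degree-n Bezier curve."""
    return comb(n, i) * t**i * (1 - t)**(n - i)

n = 3  # four control points P0..P3
for t in (0.0, 0.4, 1.0):
    coeffs = [bernstein(i, n, t) for i in range(n + 1)]
    print(t, [round(c, 3) for c in coeffs])  # at any t, the coefficients sum to 1
```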

Everything looks good so far.

We have a standard formula for bezier curves that can be extended to any number of points:

However, the problem is still the same as before.

Having 100 data points will result in a polynomial of degree 99, which is computationally expensive. Moreover, the factorial terms in the binomial coefficients are equally hard to compute.

We need to find a better way.

B-Splines provide a more efficient way to represent curves, especially when dealing with a large number of data points.

Unlike high-degree polynomials, B-Splines use a series of lower-degree polynomial segments, which are connected smoothly.

In other words, instead of extending Bezier curves to tens of hundreds of data points, which leads to an equally high degree of the polynomial, we use multiple lower-degree polynomials and connect them together to form a smooth curve.

To exemplify, consider the set of data points in the image below:

Given that we have 6 points, we can generate a bezier curve of degree 5. That's always an option. However, as discussed above, this is still computationally expensive and not desired.

Instead, we can create curves of smaller degrees (say, 3), and then connect them.

For instance, a full B-spline can be created as follows:

- Some part of it can come from a curve of degree 3 from points $(P_1, P_2, P_3, P_4)$
- Some part of it can come from a curve of degree 3 from points $(P_2, P_3, P_4, P_5)$
- Some part of it can come from a curve of degree 3 from points $(P_3, P_4, P_5, P_6)$

These individual curves are represented below:

When we have $n$ control points (6 in the diagram above) and we create degree-$k$ polynomial Bezier curves, we get $(n-k)$ Bezier segments in the final B-spline.

While the actual mathematics behind B-splines is beyond this deep dive, this is a great lecture to understand them in detail:

In short, the core idea is to ensure certain continuity conditions at the points where the curves meet.

**Position Continuity ($C^0$ Continuity)**:

- Ensures that the curves meet at shared points.
- This means the end point of the first segment is the same as the starting point of the second segment, and so forth.

**Tangent Continuity ($C^1$ Continuity)**:

- Ensures that the direction of the curves is consistent at the join points.
- This is achieved by ensuring that the first derivative is continuous at the join points. It is not necessary to enforce this throughout each segment because the individual Bezier curves are polynomials, which are smooth everywhere.

**Curvature Continuity ($C^2$ Continuity)**:

- Ensures that the curvature of the curve is smooth across the joins.
- This involves higher-order derivatives and is more complex, but it provides even smoother transitions.

Similar to what we saw in Bezier curves, the final B-spline curve is represented as a linear combination of the points $P_i$:

- $P_i$: Control points that define the shape of the curve.
- $N_{i,k}(t)$: B-spline basis functions of degree $k$ associated with each control point $P_i$; they play the same role as the coefficient functions we saw earlier for Bezier curves, and they are **fixed**.
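Although the full derivation is out of scope, the basis functions $N_{i,k}(t)$ can be computed with the standard Cox-de Boor recursion. Here is a minimal sketch (the uniform clamped knot vector is chosen purely for illustration), which also checks that the basis functions at any $t$ sum to 1, so the curve is always a weighted average of the control points:

```python
def basis(i, k, t, knots):
    """Cox-de Boor recursion for the B-spline basis function N_{i,k}(t)."""
    if k == 0:
        return 1.0 if knots[i] <= t < knots[i + 1] else 0.0
    left = right = 0.0
    if knots[i + k] != knots[i]:
        left = (t - knots[i]) / (knots[i + k] - knots[i]) * basis(i, k - 1, t, knots)
    if knots[i + k + 1] != knots[i + 1]:
        right = (knots[i + k + 1] - t) / (knots[i + k + 1] - knots[i + 1]) \
                * basis(i + 1, k - 1, t, knots)
    return left + right

# 6 control points, degree 3 -> clamped knot vector of length 6 + 3 + 1
knots = [0, 0, 0, 0, 1, 2, 3, 3, 3, 3]
t = 1.5
vals = [basis(i, 3, t, knots) for i in range(6)]
print([round(v, 3) for v in vals], round(sum(vals), 3))  # weights sum to 1
```

Note that at $t = 1.5$ the first and last basis functions are exactly zero: each basis function is non-zero only on a few neighbouring knot spans, which is the locality property that becomes important later.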

With those details in mind, we are ready to understand how B-splines are used in KAN to train those higher-degree activation functions.

Recall that whenever we create B-splines, we already have our control points $(P_1, P_2, \dots, P_n)$, and the underlying basis functions based on the degree $k$.

And by varying the position of the control points, we tend to get different curves, as depicted in the video below:

**So here's the core idea of KANs:**

Let's make the positions of control points learnable in the activation function so that the model is free to learn any arbitrary shape activation function that fits the data best.

That's it.

By adjusting the positions of the control points during training, the KAN model can dynamically shape the activation functions that best fit the data.

- Start with the initial positions for the control points, just like we do with weights.
- During the training process, update the positions of these control points through backpropagation, similar to how weights in a neural network are updated.
- Use an optimization algorithm (e.g., gradient descent) to adjust the control points so that the model can minimize the loss function.
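The three steps above can be sketched in a few lines of NumPy (my own construction, not the authors' code): the basis functions stay fixed, and gradient descent moves only the coefficients $c$ (the control-point heights) to fit a 1-D curve:

```python
import numpy as np

rng = np.random.default_rng(0)

# Fixed degree-1 B-spline ("hat") basis on a uniform grid over [0, 1]
grid = np.linspace(0, 1, 8)                # control-point x-locations

def design(x):
    """Matrix N with N[j, i] = hat basis centred at grid[i], evaluated at x[j]."""
    h = grid[1] - grid[0]
    return np.maximum(0.0, 1.0 - np.abs(x[:, None] - grid[None, :]) / h)

x = np.linspace(0, 1, 200)
y = np.sin(2 * np.pi * x)                  # target curve to fit
N = design(x)

c = rng.normal(0.0, 0.1, grid.size)        # learnable control-point heights
lr = 0.5
loss0 = np.mean((N @ c - y) ** 2)
for _ in range(500):
    resid = N @ c - y                      # forward pass
    c -= lr * 2 * N.T @ resid / len(x)     # gradient step on c only
loss1 = np.mean((N @ c - y) ** 2)
print(f"loss before: {loss0:.4f}, after: {loss1:.4f}")
```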

Mathematically speaking, in KANs, every activation function $\phi(x)$ is defined as follows:

Within this:

- The computation involves a basis function $b(x)$ (similar to residual connections).
- $spline(x)$ is learnable; specifically, the parameters $c_i$, which denote the positions of the control points.
- There's another parameter $w$. The authors say that, in principle, $w$ is redundant since it can be absorbed into $b(x)$ and $spline(x)$. Yet, they still included this $w$ factor to better control the overall magnitude of the activation function.

**For initialization purposes:**

- Each activation function is initialized with $spline(x) \approx 0$. This is done by drawing the B-spline coefficients $c_i \sim \mathcal{N}(0, \sigma^2)$ with a small $\sigma$, typically around $0.1$.

- Moreover, $w$ is initialized according to the Xavier initialization:

Done!
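Putting these pieces together, here is a sketch of a single KAN edge activation in the spirit of the paper, $\phi(x) = w\,(b(x) + spline(x))$ with $b(x) = \text{silu}(x)$; a stand-in hat-function spline basis is used for brevity, with small random $c_i$ and a Xavier-initialized $w$ (all sizes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

G = 8                                  # illustrative grid size
grid = np.linspace(-1, 1, G)

def silu(x):
    """b(x): the residual-like basis function."""
    return x / (1 + np.exp(-x))

def spline(x, c):
    """sum_i c_i * B_i(x); hat functions stand in for real B-spline bases."""
    h = grid[1] - grid[0]
    B = np.maximum(0.0, 1.0 - np.abs(x[:, None] - grid[None, :]) / h)
    return B @ c

# Initialization: spline(x) ~ 0 via small coefficients; w via Xavier
c = rng.normal(0.0, 0.1, G)            # c_i ~ N(0, 0.1^2)
fan_in = fan_out = 1                   # hypothetical fan sizes
w = rng.normal(0.0, np.sqrt(2.0 / (fan_in + fan_out)))

def phi(x):
    """One learnable edge activation: phi(x) = w * (b(x) + spline(x))."""
    return w * (silu(x) + spline(x, c))

x = np.linspace(-1, 1, 5)
print(phi(x))                          # close to w * silu(x) at initialization
```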

Once this has been defined, the network can be trained just like any other neural network.

More specifically:

- We initialize each of the $\phi$ matrices as follows:

- Run the forward pass:

- Calculate the loss and run backpropagation.

Done!

Let's compare the number of parameters of a KAN and MLP, both with $L$ layers and an equal number of neurons in each layer ($N$):

- The number of edges from one layer to another is the same in both cases – $N^2$.
- While the edges of MLP just hold one weight, the edges of KAN hold more parameters because of splines.

- For a B-spline with $G$ control points and degree $k$, the number of basis functions is $G+k-1$. In KAN, as each basis function is associated with one parameter, the number of parameters is also $G+k-1$.

- So, every edge in a KAN holds $G+k-1$ parameters.
- The final parameter count comes out to be as follows (including all layers):

While MLPs appear to be more parameter-efficient than KANs here, a point to note is that, based on the authors' experiments, KANs usually don't require as large an $N$ as MLPs do. This saves parameters while also achieving better generalization.
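As a back-of-the-envelope check with illustrative (not paper-reported) sizes, a narrower KAN can still undercut a wider MLP in total parameter count:

```python
def mlp_params(N, L):
    """One weight per edge (biases ignored for simplicity)."""
    return L * N**2

def kan_params(N, L, G, k):
    """(G + k - 1) spline parameters per edge, as derived above."""
    return L * N**2 * (G + k - 1)

# Hypothetical sizes: a width-100 MLP vs a width-10 KAN, 3 layers each
print(mlp_params(100, 3))        # 30000
print(kan_params(10, 3, 10, 3))  # 3600
```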

The authors have presented many performance-related figures that compare the performance of KANs with MLP on various dummy/toy datasets.

Consider the following image:

- In both plots,
- KANs consistently outperform MLPs, achieving significantly lower test loss across a range of parameter counts, and at much lower network depth (number of layers).
- KANs demonstrate superior efficiency, with steeper declines in loss, particularly noticeable with fewer parameters.
- MLPs' performance almost stagnates as the number of parameters increases.

- The theoretical lines, $N^{-4}$ for KAN and $N^{-2}$ for ideal models (ID), show that KANs closely follow their expected theoretical performance.

Here are the results for two more functions:

Yet again, in both plots, KANs considerably outperform MLPs by achieving significantly lower test loss across a range of parameter counts, and at much lower network depth (number of layers).

These are some more results across various data shapes, and yet again, KANs' test RMSE is lower than that of MLP.

Moving on, another incredible thing that KANs possess is continual learning.

For more context, it is commonly observed that when a neural network is trained on one task and then shifted to training on a second task, it soon forgets how to perform the first task.

This happens because MLPs lack the notion of locally stored knowledge that is found in humans.

Quoting from the paper – “human brains have functionally distinct modules placed locally in space. When a new task is learned, structure re-organization only occurs in local regions responsible for relevant skills, leaving other regions intact. MLPs do not have this notion of locality, which is probably the reason for catastrophic forgetting.”

This is evident from the image below:

As depicted above, as new data is added, an MLP's learned fit drastically changes.

However, now look at the performance of KANs in the figure below:

As depicted above, as new data is added, KAN retains the previously learned fit and can also adapt to the new changes.

This happens due to the locality property of B-Splines. To understand better, consider the B-Spline below and notice what happens when I move this point in the animation below this image:

The left part of the B-splines is unaffected by any movement of the above control point.

Since spline bases are local, a sample will only affect a few nearby spline coefficients, leaving far-away coefficients intact (which is desirable since far-away regions may have already stored information that we want to preserve).

However, since MLPs usually use global activations, e.g., ReLU/Tanh/SiLU, etc., any local change may propagate uncontrollably to regions far away, destroying the information being stored there.

This makes intuitive sense as well.

That said, the authors do mention that they are unclear whether this can be generalized to more realistic setups, and they have left investigating this as future work.

Since KANs learn univariate functions at all levels, which can be inspected if needed, it is pretty easy to determine the structure learned by the network using those formulas.

For instance, consider the KAN network below, which learns $f(x, y) = xy$.

Inspecting the B-splines learned by KAN, we notice the following:

- Both inputs $x$ and $y$ get transformed to their squares in the first KAN layer:
- Activation #1 $\rightarrow$ maps $x$ to $x$.
- Activation #2 $\rightarrow$ maps $x$ to $x^2$.
- Activation #3 $\rightarrow$ maps $y$ to $y$.
- Activation #4 $\rightarrow$ maps $y$ to $y^2$.

- In the second layer:
- The sum of Activation #1 and #3, which is $(x+y)$, gets squared by activation #5. This outputs $(x+y)^2$.
- The sum of Activation #2 and #4, which is $(x^2+y^2)$, gets negated by activation #6. This outputs $-(x^2+y^2)$.

- At the output node, we get the final output, which is the sum of activation #5 and activation #6:

The additional "2" here (since $(x+y)^2 - (x^2+y^2) = 2xy$) is just a constant factor and can be absorbed by the learned activations.

For instance, it's possible that activation #5 and #6 learned the following:

- Activation #5 $\rightarrow$ maps input $a$ to $\frac{a^2}{2}$.
- Activation #6 $\rightarrow$ maps input $a$ to $\frac{-a}{2}$.

The scaling factors are not visible to us in the above figure but I hope you get the point.

By inspecting the nodes, we can determine the function learned by the KAN network.
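The algebra behind this learned structure is easy to verify: with the halved activations hypothesized above, $\frac{(x+y)^2}{2} - \frac{x^2+y^2}{2} = xy$ exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.normal(size=100), rng.normal(size=100)

act5 = lambda a: a**2 / 2    # hypothetical activation #5: squares and halves
act6 = lambda a: -a / 2      # hypothetical activation #6: negates and halves

# Layer 1 produced (x + y) and (x^2 + y^2); layer 2 combines them:
out = act5(x + y) + act6(x**2 + y**2)
print(np.max(np.abs(out - x * y)))   # zero up to float round-off
```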

The authors of KAN have been honest about one of its biggest bottlenecks, which is its slow training.

They noticed that KANs are usually 10x slower than MLPs when both have the same number of parameters.

However, the overall performance difference could be slight because, as discussed earlier, KANs usually don't require as many parameters as MLPs.

The authors mention that they did not try much to optimize KANs’ efficiency, and they treat bottlenecks more like an engineering problem, which can be improved in future iterations of KAN released by them or the community.

And this did come true.

Within a week's time, someone released an efficient implementation of KAN, called efficient-KAN, which can be found in the GitHub repo below.

As an exercise, by reading the README of the above repo, can you identify the major bottleneck of the original implementation of KAN, and how this implementation improved it?

With this, we come to an end of this deep dive on KANs.

If you have read this far, you may have understood that KANs are not based on an entirely novel idea, but rather on a comprehensible intuition about how neural networks can model data more efficiently.

While writing this article, I went through a bunch of videos, which I am including here for further reference:

- A quick introduction to KANs:

- For those who understand a bit about KANs, this video distills some of the ongoing discussions since it was released.

- Finally, this is a comprehensive lecture on KANs:

That said, I know we did not discuss much about the code in this article, but I do intend to cover that pretty soon, possibly in an article dedicated to implementing KANs from scratch using PyTorch.

Also, while writing this article, I found this repository, which is a curated list of awesome libraries, projects, tutorials, papers, and other resources related to the Kolmogorov-Arnold Network (KAN).

As always, thanks for reading!

Any questions?

Feel free to post them in the comments.

Or

If you wish to connect privately, feel free to initiate a chat here:

Due to advancements in open-source frameworks like PyTorch, utilizing GPUs to accelerate the training of deep learning models just takes one simple step, as demonstrated below:

While this encapsulation is as easy as it can get, the underlying **implementation** of how GPUs accelerate computing tasks is still unknown to many.

More specifically, what happens under the hood when we make a `.cuda()` call?

In this article, let’s understand the mechanics of GPU programming.

More specifically, we shall understand how CUDA, a programming interface developed by NVIDIA, allows developers to run processes on their GPU devices and how the **underlying implementations** work.

We shall also do a hands-on demo of CUDA programming and implement parallelized versions of various operations we typically perform in deep learning.

Let’s begin!

Simply put, CUDA (short for *Compute Unified Device Architecture*) is a parallel computing platform and application programming interface (API) model created by NVIDIA (as mentioned above).

It allows developers to use a CUDA-enabled graphics processing unit (GPU) for general-purpose processing.

💡

For more context, general-purpose processing means leveraging computational capabilities for a broad range of tasks beyond its specialized functions. In the context of GPUs, which were originally designed for handling graphics rendering in games, general-purpose processing allows us to adapt them to perform computation-intensive tasks that are not necessarily related to graphics.

In order to leverage the processing power of GPU, CUDA provides an interface implemented in C/C++. This allows us to access the GPU’s memory and run compute operations.

In the context of deep learning, these are typical mathematical operations and general operations like:

- Adding matrices/vectors
- Multiplying matrices
- Transforming a matrix by applying a function, such as an activation function, dropout, etc.
- Moving data from CPU to GPU, and then back to CPU.
- And more.

To put it another way, just like Pandas allows developers to interact seamlessly with tabular datasets through high-level data structures and operations, CUDA enables a similar level of abstraction but for processing on NVIDIA GPUs.

By abstracting away the underlying complexities associated with the GPU, such as memory management, thread handling, and block management, developers get to focus more on solving the computational problems at hand rather than the intricacies of the hardware they are running on.

Of course, as mentioned above, CUDA provides a C/C++ interface, and deep learning frameworks like PyTorch make this simpler by building a Python-based wrapper around it:

So when we develop deep learning models, we use a Python API, which, under the hood, has implemented C/C++ instructions provided by CUDA to talk to GPU.

Thus, just to be clear, the objective of this deep dive is to understand the CUDA -> GPU instructions and how they are implemented.

While some proficiency in C/C++ is good to have, it is not entirely necessary as the API design is quite intuitive. Yet, I will provide some supporting texts at every stage of the programming if you have never used C++ before.

When it comes to parallelization, the traditional approach with CPUs involves leveraging threads, which carry instructions that the processor can execute independently.

For instance, consider this simple for-loop written in C, which performs the operations of a typical linear layer:

Notice the for-loop in the above code. It iterates over the elements of the arrays one by one, computes the individual result, and stores it in the output array.
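In Python terms, that sequential pattern looks like the following (a hypothetical element-wise linear transform, `out[i] = w[i] * x[i] + b[i]`, mirroring the intent of the C loop):

```python
import numpy as np

def linear_forward(x, w, b):
    """Sequential version: each output element is computed one at a time."""
    out = np.empty_like(x)
    for i in range(len(x)):      # every iteration is independent of the others
        out[i] = w[i] * x[i] + b[i]
    return out

x, w, b = np.ones(5), np.arange(5.0), np.full(5, 0.5)
print(linear_forward(x, w, b))   # [0.5 1.5 2.5 3.5 4.5]
```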

Here, if you notice closely, all these individual operations are independent of each other.

Thus, they can be executed in parallel. However, the above implementation uses a sequential approach where each operation waits for the previous one to complete before starting.

Given that modern CPUs can handle multiple threads simultaneously through techniques like **multi-threading**, we can make use of it to allow the loop to spawn multiple threads, potentially one per core, to handle different parts of the array simultaneously.

It should be obvious that, typically, the more threads we add, the higher the level of parallelism we can achieve.

In deep learning, however, things are different.

The building blocks of deep learning models – vectors and matrices – can have millions of elements.

However, by their very nature, CPUs are limited in the degree of parallelism they can achieve because of a limited number of cores (even high-end consumer CPUs rarely have more than 16 cores).

GPUs are modern architectures that can run millions of threads in parallel. This enhances the run-time performance of these mathematical operations when dealing with massively sized vectors and matrices.

Let’s consider the matrix multiplication operation depicted below, which is quite prevalent in deep learning:

At its core, matrix multiplication is just a series of various independent vector dot products, as depicted in the animation below:

- $1^{st}$ row of left matrix is multiplied with:
- $1^{st}$ column of the right matrix.
- $2^{nd}$ column of the right matrix.
- $3^{rd}$ column of the right matrix.
- And so on.

- $2^{nd}$ row of left matrix is multiplied with:
- $1^{st}$ column of the right matrix.
- $2^{nd}$ column of the right matrix.
- $3^{rd}$ column of the right matrix.
- And so on.

- And so on...

If you look closely, all these operations are independent of one another. As a result, all these operations can be potentially executed in parallel.
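This decomposition can be sketched directly (plain NumPy; each $(i, j)$ output element is an independent dot product, which is exactly what a GPU can assign to its own thread):

```python
import numpy as np

def matmul_by_dots(A, B):
    """C[i, j] = dot(row i of A, column j of B); every (i, j) is independent."""
    m, n = A.shape[0], B.shape[1]
    C = np.empty((m, n))
    for i in range(m):
        for j in range(n):
            C[i, j] = A[i, :] @ B[:, j]   # one independent dot product
    return C

rng = np.random.default_rng(0)
A, B = rng.normal(size=(3, 4)), rng.normal(size=(4, 2))
print(np.allclose(matmul_by_dots(A, B), A @ B))   # True
```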

However, as discussed above, since CPUs have fairly limited threads, they are not ideal for exploiting the full potential of parallelizing these matrix multiplication operations.

In other words, CPUs typically handle only a few dozen threads simultaneously, which limits their efficiency in executing numerous independent operations at once, such as those required by large-scale matrix multiplications.

Here, it is essential to clarify one point: CPUs are not "bad."

There's a reason why all modern computers always come with a CPU, but they may or may not have a “fast” GPU. This is because the CPU and GPU are designed to accomplish completely different goals.

To begin, CPU computations are usually faster than GPU computations for a **single operation**. In other words, CPUs are designed to quickly execute a sequence of single-threaded operations. Thus, to maintain that speed, they can only execute a few threads in parallel.

In contrast, GPUs are designed to execute millions of threads in parallel at the cost of the speed of individual threads.

So while a single thread may take less time to execute on a CPU, GPUs can execute millions of threads in parallel, which tremendously boosts the overall run-time.

One real-life analogy I heard when I first learned about them many years back compared the CPU to an F1 car and the GPU to a bus.

If the objective is to move just one person from point A to point B, then the F1 car, i.e., the CPU, will be an ideal choice.

On the other hand, if the objective is to move many people from one point to another, then the bus, i.e., the GPU, will be an ideal choice. This is despite the fact that an F1 car will take less time to travel from point A to point B.

While the bus can transport everyone in one trip, an F1 car would require multiple trips.

Now that we have covered the basics, let’s get into the programming related details.

In this section, we shall learn how CUDA programs are written. We shall first start by understanding the components of CUDA programming and then get into the implementations.

There are three components in CUDA programming – host, device, and kernel. These three components are foundational to how CUDA interfaces with the hardware and manages computations.

Here's a brief overview of each of them:

In CUDA terminology, the "host" refers to the CPU and its memory. It's where your program starts and runs before offloading any parallel compute-intensive tasks to the GPU.

**The host controls the entire application flow**, it initiates data transfers to and from the GPU memory, and it launches GPU kernels.

In other words, it orchestrates the preparation and execution of GPU tasks from a higher level.

The "device" refers to the GPU itself and its associated memory. In the context of CUDA, when we mention the device, we're typically talking about the CUDA-enabled GPU that will perform the actual parallel computations.

The device executes the code specified in the CUDA kernels, which are functions written to run on the GPU.

It handles the intensive computational tasks that have been offloaded from the CPU, utilizing its massively parallel architecture to process data more efficiently for specific types of tasks.

A kernel in CUDA is a function written in CUDA C/C++ that runs on the GPU. This is the core piece of code that is executed in parallel by multiple threads on the CUDA device.

When a kernel is launched, the GPU executes it across many threads in parallel.

Each thread executes an instance of the kernel and operates on different data elements. The kernel defines the compute operations each thread will perform, making it the primary means of parallel computation in CUDA.

👉

If things are not clear, don't worry. They will become clear shortly when we dive into the CUDA programming.

So, to recap, here’s what this process looks like:

- **Preparation on the host**: The host CPU executes the main part of the CUDA program, setting up data in its own memory and preparing instructions.
- **Data transfer**: Before the GPU can begin processing, the necessary data must be transferred from the host’s memory to the device’s memory.
- **Kernel launch**: The host directs the device to execute a kernel, and the GPU schedules and runs the kernel across its many threads.
- **Post-processing**: After the GPU has finished executing the kernel, the results are typically transferred back to the host for further processing or output, completing the compute cycle.

When it comes to GPU computing, the key advantage is its ability to execute many operations (specified in the kernel) in parallel.

Thus, instead of executing the kernel just once and iterating through the computations one by one, we execute it $N$ times in parallel.

However, this parallel execution is not just about blasting multiple instances of the same operation on the GPU.

Instead, it’s about structuring the entire computation in a way that maximizes the GPU's architectural strengths—mainly, its capacity to handle a vast number of simultaneous threads.

We achieve this using the hierarchical organization of GPU computation, called threads, blocks, and grids.

A thread is the smallest unit of execution. Each thread executes the set of instructions specified in the kernel that are specific to the thread that is invoking the kernel.

Also, each thread is mapped to a single CUDA core for execution:

A block is a group of threads that execute the same kernel and can cooperate by sharing data and synchronizing their execution.

Blocks are, in essence, a way to organize threads into manageable, cooperative groups that can efficiently execute part of a larger problem. Typically, a single block can contain up to 1024 threads, but this may vary depending on the computing capability of the GPU.

Each block is scheduled onto a streaming multiprocessor (SM) for execution:

A grid is the highest level of thread organization in CUDA. The blocks within a grid can operate independently, meaning they do not share data directly nor synchronize with each other.

The entire kernel launched is executed as one grid, which is mapped onto the entire device:

Moving on...

Structurally speaking, Threads inside a Block can be organized in up to three dimensions, as depicted below:

Consider the 1D case shown above, where Threads are arranged in a single dimension.

Furthermore, as there is a limit on the number of Threads a block can hold, we can have many blocks, which, for simplicity, may also be arranged in a single dimension inside the entire grid:

Here are the parameters associated with this configuration, which we shall reference later in the kernel function:

- Every block has a width variable, `blockDim`.
- We can obtain the width along the x-axis using the `blockDim.x` variable.
- We can obtain the width along the y-axis using the `blockDim.y` variable (which will be `1` in the above case).
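As a plain-Python sketch (not real CUDA code; the names merely mirror CUDA's built-in variables), each thread in a 1D launch computes its global index as `blockIdx.x * blockDim.x + threadIdx.x`:

```python
def global_indices(grid_dim_x, block_dim_x):
    """Enumerate the global index every thread in a 1D launch would compute."""
    indices = []
    for block_idx in range(grid_dim_x):        # blockIdx.x
        for thread_idx in range(block_dim_x):  # threadIdx.x
            # the canonical CUDA expression:
            # idx = blockIdx.x * blockDim.x + threadIdx.x
            indices.append(block_idx * block_dim_x + thread_idx)
    return indices

# 3 blocks of 4 threads cover 12 consecutive data elements exactly once
print(global_indices(3, 4))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
```

On a real GPU, the two loops do not exist: every (block, thread) pair evaluates the expression simultaneously, each ending up with a distinct index into the data.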

In a recent article, we learned about LoRA, which stands for **L**ow-**R**ank **A**daptation.

It is a technique used to fine-tune large language models (LLMs) on new data. We also implemented it using PyTorch and the Huggingface PEFT library:

As also discussed in the article on vector databases, fine-tuning means adjusting the weights of a **pre-trained model** on a new dataset for better performance. This is depicted in the animation below:

The motivation for traditional fine-tuning is pretty simple.

When the model was developed, it was trained on a specific dataset that might not perfectly match the characteristics of the data a practitioner may want to use it on.

The original dataset might have had slightly different distributions, patterns, or levels of noise compared to the new dataset.

Fine-tuning allows the model to adapt to these differences, learning from the new data and adjusting its parameters to improve its performance on the specific task at hand.

However, a problem arises when we use the traditional fine-tuning technique on much larger models — LLMs, for instance.

This is because these models are huge — billions or even trillions of parameters, and hundreds of GBs in size.

Traditional fine-tuning is just not practically feasible here. In fact, not everyone can afford to do fine-tuning at such a scale due to a lack of massive infrastructure and the costs associated with such an endeavor.

We covered this in much more detail in the LoRA/QLoRA article, so I would recommend reading that:

LoRA has been among the most significant contributions to AI in recent years. As we discussed earlier, it completely redefined our approach to large model fine-tuning by modifying only a small subset of model parameters.

We also mentioned it in the 12 years of AI review we did recently (see year 2021).

Now, of course, it's been some time since LoRA was first introduced. Since then, many variants of LoRA have been proposed, each tailored to address specific challenges and improve upon the foundational technique.

The timeline of some of the most popular techniques introduced after LoRA is depicted below:

Going ahead, in this article, we will explore the LoRA family in-depth, discussing each variant's design philosophy, technical innovations, and the specific use cases they aim to address.

Let’s begin!

The core idea in LoRA, as also discussed in the earlier article, revolves around **training very few parameters** in comparison to the base model, say, full GPT-3, while preserving the performance that we would otherwise get with full-model fine-tuning (which we discussed above).

More specifically, two low-rank matrices $A$ and $B$ are added alongside specific layers, and these low-rank matrices contain the trainable parameters:

Mathematically, the adaptation is executed by adding a low-rank update $\Delta W = AB$ to the weight matrix $W$ of a transformer layer using the formula:

$$W' = W + \Delta W = W + AB$$

Here, $W'$ represents the adapted weight matrix, and $AB$ is the low-rank modification applied to $W$.

As depicted in the LoRA diagram above, the dimensions of matrices $A$ and $B$ are much smaller in size compared to $W$, leading to a significant reduction in the number of trainable parameters.

This low-rank update, despite its simplicity, proves to be remarkably effective in retaining the nuanced capabilities of the LLM while introducing the desired adaptations specific to a new task or dataset.
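As a rough NumPy sketch (the layer dimensions and rank are made-up, for illustration), the update and the parameter savings look like this:

```python
import numpy as np

d, k, r = 1024, 1024, 4       # hypothetical layer dimensions and LoRA rank

W = np.random.randn(d, k)     # frozen pre-trained weights
A = np.random.randn(d, r)     # trainable low-rank factor
B = np.zeros((r, k))          # trainable, initialized to zero (so AB = 0 at start)

W_adapted = W + A @ B         # the adapted weights: W' = W + AB

full_params = d * k           # parameters touched by full fine-tuning
lora_params = d * r + r * k   # trainable parameters under LoRA
print(full_params, lora_params)  # 1048576 vs 8192
```

Because $B$ is initialized to zero, the adapted model initially matches the base model exactly, and training only ever touches the small factors $A$ and $B$.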

This way, if there are plenty of users who wish to fine-tune an LLM model (say, from OpenAI), OpenAI must only store the above two matrices $A$ and $B$ (for all layers where this was introduced), which is pretty small in size.

However, the original weight matrix $W$ being common across all fine-tuned versions can have a central version, i.e., one that can be shared across all users.

As per the original paper on LoRA, they reduced the checkpoint size by roughly **10,000 times** — from 350GB to just **35MB**.

Moreover, they also observed a 25% speedup during training on the GPT-3 175B model compared to full fine-tuning, which is pretty obvious because we do not compute the gradient for the vast majority of the parameters.

Another key benefit is that it also introduces no inference latency. This is because of its simple linear design, which allows us to merge the trainable matrices ($A$ and $B$) with the frozen weights ($W$) at deployment, so one can run inference exactly the same way as they otherwise would.

A pretty cool thing about LoRA is that the hyperparameter $r$ can be orders of magnitude smaller than the dimensions of the corresponding weight matrix.

For instance, in the results table, compare the results of $r=1$ with that of other ranks:

In most cases, we notice that $r=1$ almost performs as well as any other higher rank, which is great!

In other words, this means that $A$ and $B$ can be a single column vector and a single row vector, respectively.

Next, let’s understand the variants of LoRA and how they differ from LoRA.

Building on the foundational Low-Rank Adaptation (LoRA) technique, the **LoRA-FA** method introduces a slight change that reduces the memory overhead associated with fine-tuning large language models (LLMs).

I have had the opportunity to work on several real-world machine learning projects, both as a full-time data scientist and as a part-time data scientist to whom companies would outsource their data science projects.

Building end-to-end data science and machine learning projects has taught me many invaluable lessons and exposed pitfalls and cautionary measures that I never found anyone talking about explicitly.

To be honest, the practical lessons I am about to share in this article are something I wish someone had told me when I started (or was progressing in) my career.

Thankfully, you don’t have to learn them the hard way.

So, in this blog post, I have put down eight pitfalls you might experience and cautionary measures you can take when working on data science projects.

In my experience, these pitfalls are almost always present, but they are never that obvious to observe, which ruins many projects.

Let’s begin!

If we were to visualize the decision rules (the conditions evaluated at every node) of ANY decision tree, we would ALWAYS find them to be perpendicular to the feature axes, as depicted below:

In other words, every decision tree progressively segregates feature space based on such perpendicular boundaries to split the data.

Of course, this is not a “problem” per se.

In fact, this perpendicular splitting is what makes it so powerful to perfectly overfit any dataset (read the overfitting experiment section here to learn more).

However, this also brings up a pretty interesting point that is often overlooked when fitting decision trees.

More specifically, what would happen if our dataset had a diagonal decision boundary, as depicted below:

It is easy to guess that in such a case, the decision boundary learned by a decision tree is expected to appear as follows:

In fact, if we plot this decision tree, we notice that it creates so many splits just to fit this easily separable dataset, which a model like logistic regression, support vector machine (SVM), or even a small neural network can easily handle:

It becomes more evident if we zoom into this decision tree and notice how close the thresholds of its split conditions are:

This is a bit concerning because it clearly shows that the decision tree is meticulously trying to mimic a diagonal decision boundary, which hints that it might not be the best model to proceed with.

**To double-check this, I often do the following:**

- Take the training data `(X, y)`:
    - Shape of `X`: `(n, m)`.
    - Shape of `y`: `(n, 1)`.
- Run PCA on `X` to project the data into an orthogonal space of `m` dimensions. This will give `X_pca`, whose shape will also be `(n, m)`.
- Fit a decision tree on `X_pca` and visualize it (*thankfully, decision trees are always visualizable*).
- If the decision tree depth is significantly smaller in this case, it validates that there is a diagonal separation.
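Here is a sketch of this check on a synthetic diagonally separated dataset (the data-generation details below are made up for illustration):

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.tree import DecisionTreeClassifier

# Synthetic data: two elongated clusters separated by a diagonal boundary
rng = np.random.default_rng(0)
n = 500
t = rng.normal(0, 5, n)                   # spread along the diagonal (1, 1)
s = rng.choice([-1.0, 1.0], n)            # class-dependent perpendicular shift
s_noisy = s + rng.normal(0, 0.2, n)
X = np.column_stack([t - s_noisy, t + s_noisy]) / np.sqrt(2)
y = (s > 0).astype(int)

# Tree on raw features: needs many axis-aligned splits to mimic the diagonal
tree_raw = DecisionTreeClassifier(random_state=0).fit(X, y)

# Tree on PCA projections: the boundary becomes (nearly) axis-aligned
X_pca = PCA(n_components=2).fit_transform(X)
tree_pca = DecisionTreeClassifier(random_state=0).fit(X_pca, y)

print(tree_raw.get_depth(), tree_pca.get_depth())  # PCA tree is much shallower
```

Since the data's dominant variance runs along the diagonal, PCA rotates that direction onto a feature axis, and a single perpendicular split suffices.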

For instance, the PCA projections on the above dataset are shown below:

It is clear that the decision boundary on PCA projections is **almost** perpendicular to the `X2` feature (the 2nd principal component).

Fitting a decision tree on this `X_pca` drastically reduces its depth, as depicted below:

This lets us determine that we might be better off using some other algorithm instead.

Or, we can spend some time engineering better features that the decision tree model can easily work with using its perpendicular data splits.

At this point, you may be thinking: why can’t we just use the decision tree trained on `X_pca`?

While nothing stops us from doing that, do note that PCA components are not interpretable, and maintaining feature interpretability can be important at times.

Thus, whenever you train your next decision tree model, consider spending some time inspecting what it’s doing.

Of course, the objective is not to discourage the use of decision trees. They are the building blocks of some of the most powerful ensemble models we use today.

The point is to bring forward the structural formulation of decision trees and why/when they might not be an ideal algorithm to work with.

A typical blueprint of any real-world machine learning (ML) project looks like the following:

- Formulate the problem statement
- Get the management’s approval
- Gather the required datasets
- Explore the gathered data
- Start building a model
- Make improvements
- Validate the model
- Test the model
- Improve it

Once you are satisfied:

- Productionize the ideal model
- Proceed with deployment
- Set up logging methods
- Hand over the model
- Go back to step 1

Of course, the above process can be a bit more comprehensive, but the overall blueprint from idea inception to handover to the team you built that solution for is almost the same across projects.

Also, in the above process:

- Steps 1-9 mainly highlight development in the local environment.
- Steps 10-13 are inclined towards the production environment.

To elaborate further, during the local development phase (steps 1-9), the model goes through rigorous engineering and testing to ensure its accuracy, robustness, and generalizability. We do this all the time.

Testing using validation/test sets is critical in this phase as it helps identify and rectify any issues before the model is sent for **productionisation** (which demands considerable engineering efforts).

For more context, productionisation is the phase where an ideal model is prepared for deployment in a production setting.

For instance, if we developed the model in Python, but the server we intend to deploy our model on runs a language other than Python, like C++ or Java, then making the model compatible with such an environment configuration is what productionisation involves.

We discussed this here in the following article (you can read it after this article):

Moreover, it's possible that we leveraged some classical ML models from sklearn. But these models are not production-friendly because sklearn is built on top of **NumPy, which can only run on a single core of a CPU**.

As a result, it provides sub-optimal performance.

The techniques we discussed in the following article help us make these models more production-friendly (you can read it after this article):

One more example could be that if the model is to be deployed on edge devices, we may want to reduce its size (discussed below).

In a nutshell, the two primary objectives of the productionisation phase are to optimize the model for deployment and ensure its robustness and reliability.

This involves testing the model against various edge cases and scenarios to ensure that it can handle unexpected inputs and situations gracefully.

Once the model is fully productionised, it is deployed to the production environment, where it begins to serve predictions to end-users or other systems.

Project over?

Not yet!

If we already have a model running in production, it could be a terrible idea to instantly replace the previous model with the updated model.

Instead, a more conservative and reliable strategy is to test the model in production (yes, on real-world incoming data) before completely substituting/discarding the previous version of the model.

Testing a model in production might appear risky, but ML teams do it all the time, and it isn't that complicated.

In the upcoming section, we shall discuss five commonly used techniques to test ML models in production.

We shall also implement these strategies, and in order to do that, we shall be using Modelbit, which we discussed in the following article:

It’s okay if you haven’t read it yet. We will do a quick overview of the model deployment steps in Modelbit.

👉

Let’s begin!

👉

Feel free to skip this section if you already know how Modelbit works.

The core objective behind model deployment is to obtain an API endpoint for our deployed model, which can be later used for inference purposes:

Modelbit lets us seamlessly deploy ML models directly from our Python notebooks (or Git, as we would see ahead in this article) and obtain a REST API.

Since Modelbit is a relatively new service, let’s understand the general workflow to generate an API endpoint when deploying a model with Modelbit.

The image below depicts the steps involved in deploying models with Modelbit:

- Step 1) We connect the Jupyter kernel to Modelbit.
- Step 2) Next, we train the ML model.
- Step 3) We define the inference function. Simply put, this function contains the code that will be executed at inference. Thus, it will be responsible for returning the prediction.
- Step 4) [OPTIONAL] Here, we specify the version of Python and other open-source libraries we used while training the model.
- Step 5) Lastly, we send it for deployment.

Once done, Modelbit returns the API endpoint, which we can integrate into any of the applications and serve end-users with.

Let’s implement this!

First, we must install the Modelbit package.

We can use `pip` to install the Modelbit package:

Done!

Also, to deploy and view our deployed models in the Modelbit dashboard, we must create a Modelbit account as well here: https://app.modelbit.com/signup.

Now, we can implement the steps depicted in the earlier animation.

First, we connect our Jupyter kernel to Modelbit. This is done as follows:

In an earlier article on PyTorch Lightning, we did not discuss multi-GPU training.

I mentioned that it will require you to know more background details about how it works, the strategies we use, how multiple GPUs remain in sync with one another during model training in a distributed setting, considerations, and more.

So today, we are continuing with that topic and will be understanding some of the core technicalities behind multi-GPU training, how it works under the hood, and implementation-specific details.

We shall also look at the key considerations for multi-GPU (or distributed) training, which, if not addressed appropriately, may lead to suboptimal performance, slow training, or even instability in training.

Let’s begin!

By default, deep learning models built with PyTorch are only trained on a single GPU, even if you have multiple GPUs available.

This does not mean that we cannot do multi-GPU training with PyTorch. We can. However, it does require us to explicitly utilize PyTorch's parallel processing capabilities.

Moreover, even if we were to utilize multiple GPUs with PyTorch, typical training procedures would always be restricted to a single machine. This limitation arises because PyTorch's default behavior is to use a single machine for model training.

Therefore, it becomes a severe bottleneck when working with larger datasets that require more computational power than what a single machine can provide.

However, acknowledging these default restrictions makes us realize that there is ample scope for further run-time optimization.

Multi-GPU training solves this.

In a gist (and as the name suggests), multi-GPU training enables us to distribute the workload of model training across multiple GPUs and even multiple machines if necessary.

This significantly reduces the training time for large datasets and complex models by leveraging the combined computational power of the available hardware.

While there are many ways (strategies) to achieve multi-GPU training, one of the most common ways is to let each GPU or machine process a portion of the input data **independently**.

This is also called data parallelism.

💡

In addition to data parallelism, other strategies such as model parallelism, pipeline parallelism, and hybrid parallelism can also be used to achieve multi-GPU training. Each of these strategies has its own advantages and disadvantages, and the choice of which strategy to use depends on the specific requirements of the model and the hardware available. We shall discuss a few of them in brief towards the end of the article.

In data parallelism, the idea is to divide the available data into smaller batches, and each batch is processed by a separate GPU.

Finally, the updates from each GPU are then aggregated and used to update the model parameters.

If you recall the deep dive on federated learning, the idea might appear very similar to what we discussed back then:

In federated learning, instead of a single centralized server processing all the data, the model is trained across multiple decentralized edge devices, each with its own data. The updates from these edge devices are then aggregated to improve the global model.

Similarly, in data parallelism, each GPU acts as a "mini-server," processing a portion of the data and updating the model parameters locally.

These local updates are then combined to update the global model. This parallel processing of data not only speeds up the training process but also allows for efficient use of resources in distributed environments.

One major difference is that in federated learning, we do not have direct access to the local dataset, whereas in data parallelism, the data is directly accessible.

Nonetheless, it is quite obvious to understand that this approach not only improves the efficiency of model training but also allows us to scale our training to handle larger datasets and more complex models than would be possible with a single GPU or machine.

As discussed above, **data parallelism** is a technique used in deep learning to parallelize the training of a model by splitting the data across multiple devices, such as GPUs or machines, and then combining the results.

Quite intuitively, this approach can significantly reduce the training time for large models and datasets by leveraging the computational power of multiple devices.
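To see why the aggregation step preserves correctness, here is a toy NumPy sketch (a hypothetical linear model with MSE loss; real frameworks perform this aggregation via an all-reduce): with equally sized shards, the average of per-shard gradients equals the full-batch gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
X, y = rng.normal(size=(8, 3)), rng.normal(size=8)
w = np.zeros(3)

# Split one batch across two "GPUs"
shards = [(X[:4], y[:4]), (X[4:], y[4:])]

# Each device computes the MSE gradient on its own shard independently
grads = [2 * Xs.T @ (Xs @ w - ys) / len(ys) for Xs, ys in shards]

# Aggregate (all-reduce) by averaging, then update the shared parameters
avg_grad = np.mean(grads, axis=0)
full_grad = 2 * X.T @ (X @ w - y) / len(y)
print(np.allclose(avg_grad, full_grad))  # True
w -= 0.1 * avg_grad
```

Because the averaged gradient matches what a single device would have computed on the whole batch, every replica stays in sync after each update.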

Many machine learning engineers and data scientists very quickly pivot to building a different type of model when they don't get satisfying results with one kind of model.

At times, we do not fully exploit all the possibilities of existing models and continue to move towards complex models when minor tweaks in simple models can achieve promising results.

Over the course of building so many ML models, I have utilized various techniques that uncover nuances and optimizations we could apply to significantly enhance model performance without necessarily increasing the model complexity.

Thus, in this article, I will share 11 such powerful techniques that will genuinely help you supercharge your ML models so that you extract maximum value from them.

I provide clear motivation behind their usage, as well as the corresponding code, so that you can start using them right away.

Let’s begin!

The biggest problem with most regression models is that they are sensitive to outliers.

Consider linear regression, for instance.

Even a few outliers can significantly impact Linear Regression performance, as shown below:

And it isn’t hard to identify the cause of this problem.

Essentially, the loss function (MSE) scales quadratically with the residual term (true − predicted).

Thus, even a few data points with a large residual can impact parameter estimation.

**Huber loss (used by Huber Regression) precisely addresses this problem.**

In a gist, it attempts to reduce the error contribution of data points with large residuals.

How?

One simple, intuitive, and obvious way to do this is by applying a threshold (δ) on the residual term:

- If the residual is smaller than the threshold, use MSE (no change here).
- Otherwise, use a loss function that has a smaller output than MSE — linear, for instance.

This is depicted below:

- For residuals smaller than the threshold (δ) → we use MSE.
- Otherwise, we use a linear loss function, which has a smaller output than MSE.

Mathematically, Huber loss is defined as follows:
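With residual $r = y - \hat{y}$, the standard form is quadratic inside the threshold and linear outside it:

```latex
L_\delta(r) =
\begin{cases}
  \frac{1}{2} r^2 & \text{if } |r| \le \delta \\[4pt]
  \delta \left( |r| - \frac{1}{2}\delta \right) & \text{otherwise}
\end{cases}
```

The $\frac{1}{2}\delta$ offset in the linear branch makes the two pieces meet smoothly at $|r| = \delta$.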

For instance, consider the 2D dummy dataset again which we saw earlier:

Currently, this dataset has no outliers, so let’s add a couple of them:

As a result, we get the following dataset:

Let’s fit a linear regression model on this dataset:

Next, we visualize the regression fit by plotting the `y_pred` values as follows:

This produces the following plot:

It is clear that the Linear Regression plot is affected by outliers.

Now, let’s look at the Huber regressor, which, as discussed earlier, uses a linear penalty for all residuals greater than the specified threshold $\delta$.

To train a Huber regressor model, we shall repeat the same steps as before, but this time, we shall use the `HuberRegressor` class from sklearn.
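A minimal sketch with a synthetic 1D dataset (the data and the injected outliers are made up for illustration; note that sklearn's `HuberRegressor` exposes the threshold as `epsilon`):

```python
import numpy as np
from sklearn.linear_model import HuberRegressor, LinearRegression

# Synthetic line y = 2x + 1 with two large outliers at high x
X = (np.arange(50) / 5).reshape(-1, 1)
rng = np.random.default_rng(42)
y = 2 * X.ravel() + 1 + rng.normal(0, 0.3, 50)
y[-2:] -= 30  # inject outliers

lr = LinearRegression().fit(X, y)
huber = HuberRegressor(epsilon=1.35).fit(X, y)

# The Huber slope stays much closer to the true slope (2)
print(lr.coef_[0], huber.coef_[0])
```

The linear fit's slope is dragged down by the two outliers, while the Huber fit penalizes their residuals only linearly and stays near the true line.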

As a result, we get the following regression fit.

It’s clear that Huber regression is more robust.

But how do we choose a good threshold $\delta$? While trial and error is one way, creating a residual plot can be quite helpful. This is depicted below:

Here’s how we create it:

**Step 1)** Train a linear regression model as you usually would on the outlier-included dataset.

**Step 2)** Compute the absolute value of the residuals (= true − predicted) on the training data.

**Step 3)** Plot the absolute `residuals` array for every data point.

As a result, we get the following plot:

Based on the above plot, it appears that $\delta=4$ will be a good threshold value.

One good thing is that we can create this plot for any dimensional dataset. The objective is just to plot (true-predicted) values, which will always be 1D.

Overall, the idea is to reduce the contribution of outliers in the regression fit.

One of the things that always makes me a bit cautious and skeptical when using kNN is its HIGH sensitivity to the parameter `k`.

To understand better, consider this dummy 2D dataset below. The red data point is a test instance we intend to generate a prediction for using kNN.

Say we set the value of `k=7`.

The prediction for the red instance is generated in two steps:

- First, we find the `7` nearest neighbors of the red data point.
- Next, we assign it to the class with the highest count among those 7 nearest neighbors.

This is depicted below:

The problem is that **step 2** is entirely based on the notion of class contribution — the class that maximally contributes to the `k` nearest neighbors is assigned to the data point.

But this notion fails miserably at times, especially when we have a class with few samples.

For instance, as shown below, with `k=7`, the red data point can **NEVER** be assigned to the yellow class, no matter how close it is to that cluster:

While it is easy to tweak the hyperparameter `k` visually in the above demo, this approach is infeasible in high-dimensional datasets.

There are two ways to address this.

Distance-weighted kNNs are a much more robust alternative to traditional kNNs.

As the name suggests, in step 2, they consider the distance to the nearest neighbor.

As a result, the closer a specific neighbor is, the more impact it will have on the final prediction.

For instance, consider this 2D dummy dataset:

Now, consider a dummy test instance:

Training a kNN classifier under the uniform (count-based) strategy, the model predicts `class 0` (the red class in the above dataset) as the output for the new instance:

However, the distance-weighted kNN is found to be more robust in its prediction, as demonstrated below:

This time, we get `Class 1` (the blue class) as the output, which is correct.

As per my observation, a distance-weighted kNN typically works much better than a traditional kNN. And this makes intuitive sense as well.

Yet, this may go unnoticed because, by default, the kNN implementation of sklearn uses `uniform` weighting.
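Here is a small sketch of the difference (the toy data is made up; the switch is the `weights` parameter of sklearn's `KNeighborsClassifier`):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Toy imbalanced data: five class-0 points and a tiny class-1 cluster
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1], [0.5, 0.5],  # class 0
              [5.0, 5.0], [5.2, 5.0]])                     # class 1
y = np.array([0, 0, 0, 0, 0, 1, 1])

test_point = [[4.8, 4.9]]  # sits right next to the tiny class-1 cluster

uniform = KNeighborsClassifier(n_neighbors=7, weights="uniform").fit(X, y)
weighted = KNeighborsClassifier(n_neighbors=7, weights="distance").fit(X, y)

print(uniform.predict(test_point))   # [0]: the majority count wins
print(weighted.predict(test_point))  # [1]: proximity wins
```

With `k=7`, every training point is a "neighbor," so the uniform vote is decided purely by class counts, while the distance-weighted vote is dominated by the two nearby class-1 points.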

Recall the diagram I started this kNN discussion with:

Here, one may argue that we must refrain from setting the hyperparameter `k` to any value greater than the minimum number of samples that belong to a class in the dataset.

Of course, I agree with this to an extent.

But let me tell you the downside of doing that.

Setting a very low value of `k` can be highly problematic in the case of extremely imbalanced datasets.

To give you more perspective, I have personally used kNN on datasets that had merely one or two instances for a particular class in the training set.

And I discovered that setting a low value of `k` (say, 1 or 2) led to suboptimal performance because the model could not evaluate the nearest-neighbor patterns as holistically as it could with a larger value of `k`.

In other words, setting a relatively larger value of `k` typically gives more informed predictions than using lower values.

But we just discussed above that if we set a large value of `k`, the majority class can dominate the classification result:

**To address this, I found dynamically updating the hyperparameter `k` to be much more effective.**

More specifically, there are three steps in this approach.

For every test data point:

**Step 1)** Begin with a standard value of `k` as we usually would and find the `k` nearest neighbors.

**Step 2)** Next, update the value of `k` as follows:

- For all **unique classes** that **appear in the `k` nearest neighbors**, find the total number of training instances they have.
- Update the value of `k` to the minimum of the current `k` and these per-class counts. For instance, if the two classes appearing in the neighborhood have 40 and 3 training instances, respectively:

$$ \Large k' = \min(k, 40, 3) $$

**Step 3)** Now perform **majority voting only on the first `k'` neighbors**.

This makes intuitive sense as well:

- If a minority class appears in the top
`k`

nearest neighbor, we must reduce the value of`k`

so that the majority class does not dominate. - If a minority class DOES NOT appear in the top
`k`

nearest neighbor, we will likely not update the value of`k`

and proceed with a holistic classification.
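The three steps above can be sketched in a few lines; the function and variable names below are illustrative, not the original implementation:

```python
from collections import Counter

def dynamic_knn_vote(neighbor_labels, class_counts, k):
    """Dynamic-k voting: neighbor_labels are the labels of the k nearest
    neighbors (ordered by distance); class_counts maps each class to its
    total number of training instances."""
    top_k = neighbor_labels[:k]
    # Step 2) shrink k to the size of the rarest class appearing in the top-k
    k_prime = min([k] + [class_counts[c] for c in set(top_k)])
    # Step 3) majority vote over the first k' neighbors only
    return Counter(top_k[:k_prime]).most_common(1)[0][0]

# Toy example: class "B" has only 2 training instances overall
labels = ["B", "B", "A", "A", "A", "A", "A"]
counts = {"A": 40, "B": 2}
print(dynamic_knn_vote(labels, counts, k=7))  # k' = min(7, 40, 2) = 2 → "B"
```

With plain `k = 7` voting, the majority class "A" would have won; shrinking to `k' = 2` lets the two closest (minority-class) neighbors decide.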

I used this approach in a couple of my research projects. If you want to learn more, here’s my research paper: **Interpretable Word Sense Disambiguation with Contextualized Embeddings**.

PyTorch is the go-to choice for researchers and practitioners for building deep learning models due to its flexibility, intuitive Pythonic API design, and ease of use.

It takes a programmer just three steps to create a deep learning model in PyTorch:

- First, we define a model class inherited from PyTorch’s `nn.Module` class.
- Moving on, we declare all the network components (layers, dropout, batch norm, etc.) in the `__init__()` method.
- Finally, we define the forward pass of the neural network in the `forward()` method.
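These three steps can be sketched as follows (the layer sizes are illustrative, not the article's exact code):

```python
import torch
import torch.nn as nn

class Net(nn.Module):                      # Step 1: inherit from nn.Module
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)     # Step 2: declare the components
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):                  # Step 3: define the forward pass
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

model = Net()
out = model(torch.randn(32, 784))          # a batch of 32 flattened images
print(out.shape)  # torch.Size([32, 10])
```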

That’s it!

Once we have defined the network, we can proceed with training the model by declaring the optimizer, loss function, etc., **without having to define the backward pass explicitly**.

More specifically, one can define the training loop as demonstrated below and train the model easily:
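Such a training loop can be sketched as follows (the model, data, and hyperparameters are placeholders):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

X = torch.randn(64, 4)                  # stand-in dataset
y = torch.randint(0, 2, (64,))

for epoch in range(3):
    optimizer.zero_grad()               # reset accumulated gradients
    loss = criterion(model(X), y)       # forward pass
    loss.backward()                     # backward pass (handled by autograd)
    optimizer.step()                    # update the parameters
```

Note that `loss.backward()` is all it takes: PyTorch's autograd derives the backward pass for us.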

As we saw above, defining the network was so simple and elegant, wasn’t it?

However, as our models grow more complex and larger, several challenges arise when using PyTorch:

With complex models, manually managing the training loop in PyTorch can become tedious.

This includes iterating over the dataset, performing forward and backward passes, and updating the model parameters, which, of course, is quite standardized, but it still needs review and maintenance.

Logging is crucial for monitoring the training process and analyzing model performance.

PyTorch does not provide built-in support for logging, requiring users to implement their own logging solutions or integrate external logging frameworks.

As models grow larger, training them on multiple GPUs or across multiple machines becomes necessary to reduce training time.

PyTorch provides support for distributed training, but the implementation can be complex, involving setting up processes, synchronizing gradients, and handling communication between processes.

Debugging distributed training can be challenging due to the complexity of the setup and the potential for issues to arise from communication between processes.

Mixed-precision training, which involves using lower precision (e.g., half-precision floating-point numbers) for certain parts of the training process, can help reduce memory usage and speed up training.

PyTorch supports mixed-precision training, but managing the precision of different operations manually is pretty challenging.
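For reference, the standard AMP pattern in PyTorch uses `autocast` and `GradScaler`; the model and data below are placeholders, and on a CPU-only machine the sketch silently falls back to full precision:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x = torch.randn(8, 16, device=device)
y = torch.randint(0, 4, (8,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, enabled=(device == "cuda")):
    loss = nn.functional.cross_entropy(model(x), y)  # float16 ops on GPU
scaler.scale(loss).backward()   # scale the loss to avoid float16 underflow
scaler.step(optimizer)          # unscale gradients, then step
scaler.update()
```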

We also saw this in a recent newsletter issue, where we went into full detail about mixed-precision training and how it works in PyTorch:

PyTorch natively supports running models on GPUs, but running models on TPUs (Tensor Processing Units) requires additional setup and configuration.

From the above discussion, it’s clear that PyTorch doesn’t provide out-of-the-box solutions for many important tasks, leading to boilerplate code and increased chances of errors.

Of course, these challenges may not be a major concern for all types of models, especially small or simple ones.

For small-scale projects, the overhead of managing training loops, logging, and distributed training may not outweigh the benefits of using PyTorch directly. However, as models grow in complexity and size, these challenges become more pronounced.

PyTorch Lightning resolves each of the above-discussed challenges with PyTorch.

You can think of PyTorch Lightning as a lightweight wrapper around PyTorch that abstracts away the boilerplate code, which we typically write with PyTorch, and makes the training process more streamlined and readable.

Just like Keras is a wrapper on TensorFlow, PyTorch Lightning is a wrapper on PyTorch, but one that makes training much more efficient than the traditional way.

Thus, one can use ANY PyTorch model as a PyTorch Lightning model.

As the library is an optimized wrapper around PyTorch, the developers claim to reduce the repeated (boilerplate) code by 70-80%, which minimizes the surface area for bugs and lets us focus on delivering value instead of engineering.

Moreover, as we shall see ahead, with PyTorch Lightning, we can define our model and training logic in a clear and concise manner, which lets us focus more on the research and less on the implementation details.

In fact, the utility is pretty evident from its popularity because its GitHub repo has over `26k` stars:

Revisiting the challenges with PyTorch that we discussed above, here’s how PyTorch Lightning addresses them.

- **Managing training loops**: PyTorch Lightning simplifies this process by providing a high-level abstraction for defining the training loop, reducing the amount of boilerplate code required.
- **Logging**: PyTorch Lightning integrates with popular logging frameworks like TensorBoard and Comet, making it easier to log training metrics and visualize them in real-time.
- **Handling distributed training**: PyTorch Lightning abstracts away the setup of processes and the synchronization of gradients, so scaling to multiple GPUs or machines requires only minor configuration changes.
- **Debugging in a distributed setting**: PyTorch Lightning provides tools and utilities to facilitate debugging in a distributed setting, making it easier to identify and resolve issues.
- **Mixed-precision training**: PyTorch Lightning simplifies mixed-precision training by providing utilities to automatically handle the precision of operations based on user-defined settings.
- **Running models on TPUs**: PyTorch Lightning supports running models on TPUs, abstracting away the complexity of the underlying TPU architecture and allowing users to focus on their model implementation.

Along with that, one of the best things about PyTorch Lightning is that it has a minimal API. In most cases, the `LightningModule` and `Trainer` classes are the only two APIs one must learn because the rest is just organized PyTorch.

If none of these things is clear yet, don’t worry. Let’s get into a complete walkthrough of using PyTorch Lightning.

Now that we understand what PyTorch Lightning is and the motivation to use it over PyTorch, let’s get into more details about its implementation and how PyTorch Lightning works.

More specifically:

- We shall begin with a standard PyTorch code, and learn how to convert that into PyTorch Lightning code.
- Next, we shall look at how we use the Trainer() class from PyTorch Lightning to simplify model training and define various methods for training, validation, testing and predicting. Here, we shall also learn how to log model training and integrate various performance metrics during training.
- Finally, we shall dive deep into the additional utilities offered by PyTorch Lightning, like mixed-precision training, callbacks, and code profiling for optimization.

Let’s begin!

In this section, let’s build a simple neural network on the MNIST dataset using PyTorch. Then, we will see how we can convert that code to a PyTorch Lightning code.

Here are the traditional steps to building a model in PyTorch:

First, we import the required packages from PyTorch:

Next, we load the MNIST dataset (train and test) and create their respective PyTorch dataloaders.

Moving on, we define a simple feedforward neural network architecture. This is demonstrated below:

Moving on, we shall initialize the model and define the loss function to train it — the `CrossEntropyLoss`.

To evaluate the model after every epoch, let’s define an `evaluate()` method that will iterate over the examples in the `testloader` and compute the accuracy. This is demonstrated below:
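A possible sketch of such an `evaluate()` method, with a stand-in loader in place of the real `testloader`:

```python
import torch
import torch.nn as nn

def evaluate(model, testloader):
    model.eval()
    correct = total = 0
    with torch.no_grad():                        # no gradients during eval
        for images, labels in testloader:
            preds = model(images).argmax(dim=1)  # predicted class per sample
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Toy usage with fake data shaped like flattened MNIST
model = nn.Linear(784, 10)
fake_loader = [(torch.randn(16, 784), torch.randint(0, 10, (16,)))]
acc = evaluate(model, fake_loader)
print(acc)
```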

Now, we will train the PyTorch model.

With this, we are done with the PyTorch model.

Now, if we go back to the above code, there’s too much boilerplate code here.

Simply put, boilerplate means the repetitive and standardized sections of code that are necessary for the functioning of the program, but they are not unique to the model we are training. Instead, this is something that we would almost always write in most other projects too.

For instance, the `accuracy()` method and the training loop contribute to the boilerplate code here.

While these boilerplate sections are essential for training a neural network, they can be cumbersome to write and maintain, and they are pretty repetitive as well.

Now that we have defined a network in PyTorch, let’s see how we can convert it to PyTorch Lightning with just two or three simple changes.

But first, we must install PyTorch Lightning, which we can do as follows:

Next, we import PyTorch Lightning as follows:

Over the last couple of weeks, we covered several details around vector databases, LLMs, and fine-tuning LLMs with LoRA. Moreover, we implemented LoRA from scratch, learned about RAG, its considerations, and much more.

There’s one thing that’s yet to be addressed in this series of articles.

To recall, there are broadly two popular ways to augment LLMs with additional data:

- RAG
- Fine-tuning using LoRA/QLoRA

Both of them have pros and cons, and different applications.

The question is:

Under what conditions does it make sense to proceed with RAG and when should one prefer Fine-tuning?

To continue this LLM series, I'm excited to bring you a special guest post by Damien Benveniste. He is the author of The AiEdge newsletter and was a Machine Learning Tech Lead at Meta.

In today’s machine learning deep dive, he will provide a detailed discussion on RAG vs. Fine-tuning. I personally learned a lot from this one, and I am sure you will learn a lot too.

Let’s begin!

Let’s say you have a business use case for using an LLM, but you need it to be a bit more specialized to your specific data.

I often see the problem being simplified to Fine-Tuning vs. Retrieval Augmented Generation (RAG), but this is somewhat of a false dichotomy.

Those are two ways to augment LLMs, but each approach has different applications, and they can even be used in a complementary manner.

Fine-tuning assumes you will continue training an LLM on a specific learning task. For instance, you may want to fine-tune an LLM on the following tasks:

- English-Spanish translation
- Custom support message routing
- Specialized question-answers
- Text sentiment analysis
- Named entity recognition
- …

Fine-tuning assumes you have training data to specialize an LLM on a specific learning task. That means you need to be able to identify the correct input data, the proper learning objective, and the right training process.

Regardless of the specific application, fine-tuning will require a training pipeline, a serving pipeline, and potentially a continuous training pipeline.

The training pipeline will need the following components:

- **The model registry**: The model registry will contain the original models or different versions of fine-tuned models. A model registry requires capturing the model version and additional metadata describing the model.
- **The quantization module**: It is typical to quantize models these days to save on memory costs, but this is not a requirement. If the model is quantized, the resulting model will need to be stored in the model registry. Quantization means converting the model weights from floats to integers. This operation can reduce the model size by a factor of 4.
- **The feature store**: As for any ML model training, the data needs to be carefully prepared and stored in a feature store. In the context of LLMs, the requirements around the feature store may be relaxed if the data at serving time is only generated by a user.
- **The data validation and pre-processing modules**: When the data is injected into the training pipeline, it needs to be validated and most likely pre-processed for training.
- **LoRA / QLoRA modifications**: It is now typical to fine-tune LLMs using Low-Rank Adapters (LoRA) or its Quantized version (QLoRA). The idea is to augment, within the model, some of the large matrices with smaller ones, specifically for the gradient computation. When we fine-tune, we only update the weights of those newly inserted matrices. The gradient matrices are much smaller and, therefore, require much less GPU memory space. Because the pre-trained weights are frozen, we don't need to compute the gradients for a vast majority of the parameters.
- **The LoRA / QLoRA registry**: With LoRA / QLoRA, we need a registry to keep track of the fine-tuned weights and their versions.
- **Model validation module**: Like any model training, the resulting model needs to be validated on validation data. This assumes we have the right data for the task, which can be tricky: we may want to fine-tune a model on a specific task while also retaining its original capabilities. For the specific task, you may have the right validation data, but you may be missing the data needed to verify the original capabilities, leaving you unable to assess whether the model is forgetting what it previously learned.
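The factor-of-4 claim about quantization is simple byte arithmetic: float32 weights take 4 bytes each, int8 weights take 1. The parameter count below is illustrative:

```python
n_params = 7_000_000_000            # e.g., a 7B-parameter model
fp32_bytes = n_params * 4           # 32-bit floats: 4 bytes per weight
int8_bytes = n_params * 1           # 8-bit integers: 1 byte per weight
print(fp32_bytes / int8_bytes)      # 4.0
print(fp32_bytes / 1e9, "GB ->", int8_bytes / 1e9, "GB")  # 28.0 GB -> 7.0 GB
```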

The serving pipeline will need the following components:

- Once ready for deployment, the model and its related LoRA / QLoRA weights are pulled from the model registry and passed through the deployment pipeline. We may have a series of tests, such as canary deployment, to make sure the model fits in a serving pipeline and an A/B test experiment to test it against the production model. After satisfactory results, we can propose the new model as a replacement for the production one.
- The model needs to be continuously monitored using typical ML monitoring systems.
- As soon as the users interact with the served model, this may be an opportunity to start aggregating data for the next training update.

Retrieval Augmented Generation (RAG) means that you expose an LLM to new data stored in a database.

We don’t modify the LLM; rather, we provide additional data context in the prompt for the LLM to answer questions with information on the subject.

The idea with RAG is to encode the data you want to expose to your LLM into embeddings and index that data into a vector database.

When a user asks a question, it is converted to an embedding, and we can use it to search for similar embeddings in the database. Once we find similar embeddings, we construct a prompt with the related data to provide context for an LLM to answer the question. The similarity here is usually measured using the cosine similarity metric.
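The retrieval step can be sketched as follows; the 3-d vectors are toy values standing in for real embeddings:

```python
import numpy as np

def cosine_similarity(a, b):
    # cos(θ) between two vectors: dot product over the product of norms
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# A tiny "vector database": document id -> embedding
index = {
    "doc_weather": np.array([0.9, 0.1, 0.0]),
    "doc_sports":  np.array([0.0, 0.2, 0.9]),
}
query = np.array([0.8, 0.2, 0.1])   # embedding of the user's question

# Rank stored documents by similarity to the query embedding
ranked = sorted(index, key=lambda k: cosine_similarity(query, index[k]),
                reverse=True)
print(ranked[0])  # doc_weather
```

The top-ranked documents' raw text would then be inserted into the prompt as context for the LLM.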

For RAG, we don’t have a training pipeline. We just need an indexing pipeline and a serving pipeline. The indexing pipeline is used to convert the data into vector representation and index it in a vector database:

In the pre-LLM era, whenever someone open-sourced any high-utility model for public use, in most cases, practitioners would fine-tune that model to their specific task.

👉

Of course, it’s not necessary for a model to be open-sourced for fine-tuning if the model inventors provide API-based fine-tuning instead and decide to keep the model closed.

As also discussed in the most recent article on vector databases, fine-tuning means adjusting the weights of a **pre-trained model** on a new dataset for better performance. This is neatly depicted in the animation below:

The motivation to do this is pretty simple.

When the model was developed, it was trained on a specific dataset that might not perfectly match the characteristics of the data a practitioner wants to use it on.

The original dataset might have had slightly different distributions, patterns, or levels of noise compared to the new dataset.

Fine-tuning allows the model to adapt to these differences, learning from the new data and adjusting its parameters to improve its performance on the specific task at hand.

For instance, consider **BERT**. It’s a Transformer-based language model, which is popularly used for text-to-embedding generation (**92k+ citations on the original paper**).

It’s open-source.

As we discussed in the vector database deep dive, BERT was pre-trained on a large corpus of text data, which might be very different from what someone else may want to use it on.

Thus, when using it on any downstream task, we can adjust the weights of the BERT model along with the augmented layers, so that it better aligns with the nuances and specificities of the new dataset.

The idea makes total practical sense. In fact, it has been successfully used for a long time now, not just after the release of BERT but even prior to that.

However, the primary reason why fine-tuning has been pretty successful in the past is that we had not been training models that were ridiculously massive.

Talking of BERT again, it has two variants:

- BERT-Base, which has 110M parameters (or .11B).
- BERT-Large, which has 340M parameters (or .34B).

This size isn’t overwhelmingly large, which makes it quite feasible to fine-tune it on a variety of datasets without requiring immense computational resources.

However, a problem arises when we use the same traditional fine-tuning technique on much larger models — LLMs, for instance.

This is because, as you may already know, these models are huge — billions or even trillions of parameters.

Consider **GPT-3**, for instance. It has 175B parameters, which is 510 times bigger than even the larger version of BERT called **BERT-Large**:

And to give you more perspective, I have successfully fine-tuned BERT-large in many of my projects on a single GPU cluster, like in this paper and this paper.

But it would have been impossible to do the same with GPT-3.

Moving on, while OpenAI has not revealed the exact number of parameters in GPT-4, it is suspected to be around 1.7 Trillion, which is roughly ten times bigger than GPT-3:

Traditional fine-tuning is just not practically feasible here, and in fact, not everyone can afford to do it due to a lack of massive infrastructure.

In fact, it’s not just about the availability of high computing power.

Consider this...

OpenAI trained GPT-3 and GPT-4 models in-house on massive GPU clusters, so they have access to them for sure.

However, they also provide a fine-tuning API to customize these models according to our application, which is currently available for the following models: `gpt-3.5-turbo-1106`, `gpt-3.5-turbo-0613`, `babbage-002`, `davinci-002`, and `gpt-4-0613`:

Going by traditional fine-tuning, for every customer wanting to have a customized version of any of these models, OpenAI would have to dedicate an entire GPU server to load it and also ensure that they maintain sufficient computing capabilities for fine-tuning requests.

Deploying such independent instances of fine-tuned models, each with 175B parameters, is prohibitively expensive.

To put it into perspective, a GPT-3 model checkpoint is estimated to consume about 350 GB. And this is the static memory of the model, which only includes model weights. It does not even consider the memory required during training, computing activations, running backpropagation, and more.
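That estimate is consistent with storing 175B parameters at 2 bytes each (e.g., float16), as a quick back-of-the-envelope check shows:

```python
n_params = 175e9                    # GPT-3 parameter count
bytes_per_param = 2                 # assuming 16-bit weights
checkpoint_gb = n_params * bytes_per_param / 1e9
print(checkpoint_gb)  # 350.0
```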

And to make things worse, what we discussed above is just for one customer, but they already have thousands of customers who create a customized version of OpenAI models that is fine-tuned to their dataset.

In fact, there are many other users who just want to explore the fine-tuning capabilities (for skill development or general exploration, maybe), but they may never want to use that model to serve any end-users.

This is crucial because OpenAI would still have to bear the cost of maintaining and serving the full fine-tuned model even if it receives no requests, because they have implemented an “*Only pay for what you use*” pricing strategy:

Thus, they earn nothing if the model is never used beyond fine-tuning, which itself is not that expensive, as marked in the above image.

From the discussion so far, it must be clear that such scenarios pose a significant challenge for traditional fine-tuning approaches.

The computational resources and time required to fine-tune these large models for individual customers would be immense.

Additionally, maintaining the infrastructure to support fine-tuning requests from potentially thousands of customers simultaneously would be a huge task for them.

**LoRA and QLoRA are two superb techniques to address this practical limitation.**

Today, I will walk you through:

- What they are
- How they work
- Why they are so effective
- How to implement them from scratch in PyTorch

The whole idea behind LoRA is pretty simple and smart.

In a gist, this technique allows us to efficiently fine-tune pre-trained neural networks. Yes, they don’t have to be LLMs necessarily as they can be used across a wide range of neural networks.

The core idea revolves around **training very few parameters** in comparison to the base model, say, full GPT-3, while preserving the performance that we would otherwise get with full-model fine-tuning (which we discussed above).

Let’s get into more detail in the upcoming sections.

First, let us understand an inspiring observation that will help us formulate the LoRA in an upcoming section.

Say the current weights of some random layer in the pre-trained model are $W$, of dimensions $d \times k$, and we wish to fine-tune the model on some other dataset.

During fine-tuning, the gradient update rule suggests that we must add $\Delta W$ to get the updated parameters:

For simplicity, you can think about $\Delta W$ as the update obtained after running gradient descent on the new dataset:

Also, instead of updating the original weights $W$, it is perfectly legal to maintain both matrices, $W$ and $\Delta W$.

During inference, we can compute the prediction on an input sample $x$ as follows:
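With $h$ denoting the layer’s output, this computation is typically written as:

$$ \Large h = Wx + \Delta W x = (W + \Delta W)x $$

i.e., the frozen pre-trained weights and the update act on $x$ separately, and their outputs are summed.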

**In fact, in all the model fine-tuning iterations, $W$ can be kept static, and all weight updates using gradient computation can be incorporated to $\Delta W$ instead.**

But you might be wondering...how does that even help?

The matrix $W$ is already huge, and we are talking about introducing another matrix that is equally big.

So, we must introduce some smart tricks to manipulate $\Delta W$ so that we can fulfill the fine-tuning objective while ensuring we do not consume high memory.

Now, we really can’t do much about $W$ as these weights refer to the pre-trained model. So all optimization (if we intend to use any) must be done on $\Delta W$ instead.

While doing so, we must also remember that currently, both $W$ and $\Delta W$ have the same dimensions.

But given that $W$ already is huge, we must ensure that $\Delta W$ does not end up being of the same dimensions, as this will defeat the entire purpose of efficient fine-tuning.

In other words, if we were to keep $\Delta W$ of the same dimensions as $W$, then it would have been better if we had fine-tuned the original model itself.

Now, you might be thinking...

But how can we even add two matrices if both have different dimensions?

It’s true, we can’t do that.

More specifically, during fine-tuning, the weight matrix $W$ is frozen, so it does not receive any gradient updates. Thus, all gradient updates are redirected to the $\Delta W$ matrix.
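To preview where this is heading: LoRA’s resolution is to parameterize $\Delta W$ as a product of two much smaller matrices, $\Delta W = BA$. A minimal sketch (the rank and dimensions are illustrative):

```python
import torch
import torch.nn as nn

d, k, r = 512, 512, 8               # layer dims and a small LoRA rank

W = torch.randn(d, k)               # pre-trained weights (frozen)
W.requires_grad_(False)

A = nn.Parameter(torch.randn(r, k) * 0.01)  # trainable, r x k
B = nn.Parameter(torch.zeros(d, r))         # trainable, d x r (zero init)

def forward(x):
    # h = Wx + ΔWx, where ΔW = B @ A has the same shape as W
    return x @ W.T + x @ (B @ A).T

full = d * k                        # parameters in a full ΔW
lora = r * (d + k)                  # parameters in B and A combined
print(full, lora, full / lora)      # 262144 8192 32.0
```

With rank 8, the trainable update here has 32× fewer parameters than a full $\Delta W$, which is exactly the efficiency win discussed above.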

It’s pretty likely that in the generative AI era (since the release of ChatGPT, to be more precise), you must have at least heard of the term “**vector databases**.”

It’s okay if you don’t know what they are, as this article is primarily intended to explain everything about vector databases in detail.

But given how popular they have become lately, I think it is crucial to be aware of what makes them so powerful that they gained so much popularity, and their practical utility not just in LLMs but in other applications as well.

Let’s dive in!

To begin, we must note that vector databases are NOT new.

In fact, they have existed for a pretty long time now. You have been indirectly interacting with them daily, even before they became widely popular lately. These include applications like recommendation systems and search engines.

Simply put, a vector database stores **unstructured data** (text, images, audio, video, etc.) in the form of **vector embeddings**.

Each data point, whether a word, a document, an image, or any other entity, is transformed into a numerical vector using ML techniques (which we shall see ahead).

This numerical vector is called an **embedding,** and the model is trained in such a way that these vectors capture the essential features and characteristics of the underlying data.

Considering word embeddings, for instance, we may discover that in the embedding space, the embeddings of fruits are found close to each other, while cities form another cluster, and so on.

This shows that embeddings can learn the semantic characteristics of entities they represent (provided they are trained appropriately).

Once stored in a vector database, we can retrieve original objects that are similar to the query we wish to run on our unstructured data.

In other words, encoding **unstructured data** allows us to run many sophisticated operations like similarity search, clustering, and classification over it, which otherwise is difficult with traditional databases.

To exemplify, when an e-commerce website provides recommendations for similar items or searches for a product based on the input query, we’re (**in most cases**) interacting with vector databases behind the scenes.

Before we get into the technical details, let me give you a couple of intuitive examples to understand vector databases and their immense utility.

Let's imagine we have a collection of photographs from various vacations we’ve taken over the years. Each photo captures different scenes, such as beaches, mountains, cities, and forests.

Now, we want to organize these photos in a way that makes it easier to find similar ones quickly.

Traditionally, we might organize them by the date they were taken or the location where they were shot.

However, we can take a more sophisticated approach by encoding them as vectors.

More specifically, instead of relying solely on dates or locations, we could represent each photo as a set of numerical vectors that capture the essence of the image.

💡

While Google Photos doesn't explicitly disclose the exact technical details of its backend systems, I speculate that it uses a vector database to facilitate its image search and organization features, which you may have already used many times.

Let’s say we use an algorithm that converts each photo into a vector based on its color composition, prominent shapes, textures, people, etc.

Each photo is now represented as a point in a multi-dimensional space, where the dimensions correspond to different visual features and elements in the image.

Now, when we want to find similar photos, say, based on our input text query, we encode the text query into a vector and compare it with image vectors.

Photos that match the query are expected to have vectors that are close together in this multi-dimensional space.

Suppose we wish to find images of mountains.

In that case, we can quickly find such photos by querying the vector database for images close to the vector representing the input query.

A point to note here is that a vector database is NOT just a database to keep track of embeddings.

Instead, it maintains both the embeddings and the raw data which generated those embeddings.

Why is that necessary, you may wonder?

Considering the above image retrieval task again, if our vector database is only composed of vectors, we would also need a way to reconstruct the image because that is what the end-user needs.

When a user queries for images of mountains, they would receive a list of vectors representing similar images, but without the actual images.

By storing both the embeddings (the vectors representing the images) and the raw image data, the vector database ensures that when a user queries for similar images, it not only returns the closest matching vectors but also provides access to the original images.

In this example, consider an all-text unstructured data, say thousands of news articles, and we wish to search for an answer from that data.

Traditional search methods rely on exact keyword search, which is entirely a brute-force approach and does not consider the inherent complexity of text data.

In other words, languages are incredibly nuanced, and each language provides various ways to express the same idea or ask the same question.

For instance, a simple inquiry like "What's the weather like today?" can be phrased in numerous ways, such as "How's the weather today?", "Is it sunny outside?", or "What are the current weather conditions?".

This linguistic diversity makes traditional keyword-based search methods inadequate.

As you may have already guessed, representing this data as vectors can be pretty helpful in this situation too.

Instead of relying solely on keywords and following a brute-force search, we can first represent text data in a high-dimensional vector space and store them in a vector database.

When users pose queries, the vector database can compare the vector representation of the query with that of the text data, **even if they don't share the exact same wording.**

At this point, if you are wondering how we even transform words (strings) into vectors (a list of numbers), let me explain.

We also covered this in a recent newsletter issue here but not in much detail, so let’s discuss those details here.

If you already know what embedding models are, feel free to skip this part.

To build models for language-oriented tasks, it is crucial to generate numerical representations (or vectors) for words.

This allows words to be processed and manipulated mathematically, so we can perform various computational operations on them.

The objective of embeddings is to capture semantic and syntactic relationships between words. This helps machines understand and reason about language more effectively.

In the pre-Transformers era, this was primarily done using pre-trained static embeddings.

Essentially, someone would train embeddings on, say, 100k, or 200k common words using deep learning techniques and open-source them.

Consequently, other researchers would utilize those embeddings in their projects.

The most popular models at that time (around 2013-2017) were:

- GloVe
- Word2Vec
- FastText, etc.

These embeddings genuinely showed some promising results in learning the relationships between words.

For instance, at that time, an experiment showed that the vector operation `(King - Man) + Woman` returned a vector near the word `Queen`.

That’s pretty interesting, isn’t it?

In fact, the following relationships were also found to be true:

- `Paris - France + Italy` ≈ `Rome`
- `Summer - Hot + Cold` ≈ `Winter`
- `Actor - Man + Woman` ≈ `Actress`
- and more.
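A toy illustration of this arithmetic; the 3-d vectors below are hand-crafted for the example, whereas real static embeddings like Word2Vec use hundreds of dimensions:

```python
import numpy as np

emb = {
    "king":  np.array([0.9, 0.8, 0.1]),
    "man":   np.array([0.5, 0.1, 0.1]),
    "woman": np.array([0.5, 0.1, 0.9]),
    "queen": np.array([0.9, 0.8, 0.9]),
    "apple": np.array([0.1, 0.9, 0.4]),
}

target = emb["king"] - emb["man"] + emb["woman"]

# Nearest word (by Euclidean distance) to the resulting vector
nearest = min(emb, key=lambda w: np.linalg.norm(emb[w] - target))
print(nearest)  # queen
```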

So, while these embeddings captured relative word representations, there was a major limitation.

Consider the following two sentences:

- Convert this data into a **table** in Excel.
- Put this bottle on the **table**.

Here, the word “**table**” conveys two entirely different meanings:

- The first sentence refers to a “**data**”-specific sense of the word “table.”
- The second sentence refers to a “**furniture**”-specific sense of the word “table.”

Yet, static embedding models assigned them the same representation.

Thus, these embeddings didn’t consider that a word may have different usages in different contexts.

But this was addressed in the Transformer era, which resulted in contextualized embedding models powered by Transformers, such as:

- **BERT**: A language model trained using two techniques:
    - Masked Language Modeling (MLM): Predict a missing word in the sentence, given the surrounding words.
    - Next Sentence Prediction (NSP). *We shall discuss it in a bit more detail shortly.*
- **DistilBERT**: A simple, effective, and lighter version of BERT, which is around 40% smaller:
    - Utilizes a common machine learning strategy called student-teacher theory.
    - Here, the student is the distilled version of BERT, and the teacher is the original BERT model.
    - The student model is supposed to replicate the teacher model’s behavior.
    - If you want to learn how this is implemented practically, we discussed it here:

**SentenceTransformer**: If you read the most recent deep dive on building classification models on ordinal data, we discussed this model there.- Essentially, the
**SentenceTransformer**model takes an entire sentence and generates an embedding for that sentence.

- Essentially, the

- This differs from the BERT and DistilBERT models, which produce an embedding for all words in the sentence.

There are more models, but we won't go into more detail here, and I hope you get the point.

The idea is that these models are quite capable of generating context-aware representations, thanks to their self-attention mechanism and appropriate training mechanism.

For instance, if we consider BERT again, we discussed above that it uses the masked language modeling (MLM) technique and next sentence prediction (NSP).

These steps are also called the **pre-training step** of BERT because they involve training the model on a large corpus of text data before fine-tuning it on specific downstream tasks.

💡

Pre-training, in the context of machine learning model training, refers to the initial phase of training where the model learns general language representations from a large corpus of text data. The goal of pre-training is to enable the model to capture the syntactic and semantic properties of language, such as grammar, context, and relationships between words. While the text itself is unlabeled, MLM and NSP are two tasks that help us train the model in a supervised fashion. Once the model is trained, we can use the language understanding capabilities that the model acquired from the pre-training phase, and fine-tune the model on task-specific data. The following animation depicts fine-tuning:

Moving on, let’s see how the pre-training objectives of masked language modeling (MLM) and next sentence prediction (NSP) help BERT generate embeddings.

- In MLM, BERT is trained to predict missing words in a sentence. To do this, a certain percentage of words in **most** (not all) sentences are randomly replaced with a special token, `[MASK]`.

- BERT then processes the masked sentence bidirectionally, meaning it considers both the left and right context of each masked word; that is why the name “**Bidirectional** Encoder Representations from Transformers (BERT).”

- For each masked word, BERT predicts what the original word is supposed to be from its context. It does this by assigning a probability distribution over the entire vocabulary and selecting the word with the highest probability as the predicted word.

- During training, BERT is optimized to minimize the difference between the predicted words and the actual masked words, using techniques like cross-entropy loss.

- In NSP, BERT is trained to determine whether two input sentences appear consecutively in a document or whether they are randomly paired sentences from different documents.

- During training, BERT receives pairs of sentences as input. Half of these pairs are consecutive sentences from the same document (positive examples), and the other half are randomly paired sentences from different documents (negative examples).

- BERT then learns to predict whether the second sentence follows the first sentence in the original document (`label 1`) or whether it is a randomly paired sentence (`label 0`).

- Similar to MLM, BERT is optimized to minimize the difference between the predicted labels and the actual labels, using techniques like binary cross-entropy loss.

💡

If we look back to MLM and NSP, in both cases, we did not need a labeled dataset to begin with. Instead, we used the structure of the text itself to create the training examples. This allows us to leverage large amounts of unlabeled text data, which is often more readily available than labeled data.
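Numerically, the MLM prediction step described above boils down to a softmax over the vocabulary followed by a cross-entropy loss. Here is a toy numpy sketch; the four-word vocabulary and the logits are made up for illustration:

```python
import numpy as np

# Toy vocabulary and made-up logits for a single [MASK] position; real BERT
# scores a ~30k-token vocabulary, but the mechanics are the same.
vocab = ["paris", "london", "cat", "the"]
logits = np.array([3.2, 1.1, -0.5, 0.3])

# Softmax turns the scores into a probability distribution over the vocabulary
probs = np.exp(logits - logits.max())
probs /= probs.sum()

predicted = vocab[int(np.argmax(probs))]       # highest-probability word wins

# Cross-entropy loss against the actual masked word
true_word = "paris"
loss = -np.log(probs[vocab.index(true_word)])
print(predicted, float(loss))
```

During pre-training, this loss is minimized over millions of masked positions, which is what forces the internal representations to encode context.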

Now, let's see how these pre-training objectives help BERT generate embeddings:

- **MLM:** By predicting masked words based on their context, BERT learns to capture the meaning and context of each word in a sentence. The embeddings generated by BERT reflect not just the individual meanings of words but also their relationships with surrounding words in the sentence.
- **NSP:** By determining whether sentences are consecutive or not, BERT learns to understand the relationship between different sentences in a document. This helps BERT generate embeddings that capture not just the meaning of individual sentences but also the broader context of a document or text passage.

With consistent training, the model learns how different words relate to each other in sentences. It learns which words often come together and how they fit into the overall meaning of a sentence.

This learning process helps BERT create embeddings for words and sentences, which are **contextualized**, unlike earlier embeddings like GloVe and Word2Vec:

Contextualized means that the embedding model can dynamically generate embeddings for a word based on the context it is used in.

As a result, if a word appears in a different context, the model returns a different representation.

This is precisely depicted in the image below for different uses of the word `Bank`.

For visualization purposes, the embeddings have been projected into 2d space using t-SNE.

As depicted above, the static embedding models (GloVe and Word2Vec) produce the same embedding for different usages of a word.

However, contextualized embedding models don’t.

In fact, contextualized embeddings understand the different meanings/senses of the word “Bank”:

- A financial institution
- Sloping land
- A long ridge, and more.

As a result, these contextualized embedding models address the major limitations of static embedding models.

The point of the above discussion is that modern embedding models are quite proficient at the encoding task.

As a result, they can easily transform documents, paragraphs, or sentences into a numerical vector that captures its semantic meaning and context.

A couple of sub-sections ago, we provided an input query, which was encoded, and then we searched the vector database for vectors that were **similar** to the input vector.

In other words, the objective was to return the **nearest neighbors** as measured by a similarity metric, which could be:

- Euclidean distance (the lower the value, the higher the similarity).
- Manhattan distance (the lower the value, the higher the similarity).
- Cosine similarity (the higher the value, the higher the similarity).

The idea resonates with what we do in a typical k-nearest neighbors (kNN) setting.

We can match the query vector across the already encoded vectors and return the most similar ones.
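This brute-force matching can be sketched in a few lines of numpy. The data here is random and purely illustrative; with unit-normalized vectors, cosine similarity reduces to a dot product:

```python
import numpy as np

rng = np.random.default_rng(0)
db = rng.normal(size=(10_000, 64))                 # 10k encoded vectors, 64-dim
db /= np.linalg.norm(db, axis=1, keepdims=True)    # unit-normalize

query = db[123] + 0.01 * rng.normal(size=64)       # a query very close to vector 123
query /= np.linalg.norm(query)

# Brute force: compare the query against EVERY stored vector -> O(N * D)
sims = db @ query                                  # cosine similarity for unit vectors
nearest = int(np.argmax(sims))
print(nearest)  # 123
```

Note that `db @ query` touches all 10,000 vectors; this is exactly the cost that grows linearly with the database size.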

The problem with this approach is that to find, say, only the first nearest neighbor, the input query must be matched across **all** vectors stored in the vector database.

This is computationally expensive, especially when dealing with large datasets, which may have millions of data points. As the size of the vector database grows, the time required to perform a nearest neighbor search increases proportionally.

But in scenarios where real-time or near-real-time responses are necessary, this brute-force approach becomes impractical.

In fact, this problem is also observed in typical relational databases. If we were to fetch rows that match a particular criterion, the whole table must be scanned.

Indexing the database provides a quick lookup mechanism, especially in cases where near real-time latency is paramount.

More specifically, when columns used in `WHERE` clauses or `JOIN` conditions are indexed, it can significantly speed up query performance.

A similar idea of **indexing** is also utilized in vector databases, which results in something we call an **approximate nearest neighbor (ANN)**, which is quite self-explanatory, isn't it?

Well, the core idea is to trade off a little accuracy for run time. Thus, approximate nearest neighbor algorithms are used to find close neighbors to a data point, though these may not always be the absolute closest neighbors.

That is why they are also called **non-exhaustive** search algorithms.

The motivation is that when we use vector search, exact matches aren't absolutely necessary in most cases.

Approximate nearest neighbor (ANN) algorithms utilize this observation and compromise a bit of accuracy for run-time efficiency.

Thus, instead of exhaustively searching through all vectors in the database to find the closest matches, ANN algorithms provide fast, sub-linear time complexity solutions that yield approximate nearest neighbors.

Let’s discuss these techniques in the next section!

While approximate nearest neighbor algorithms may sacrifice a certain degree of precision compared to exact nearest neighbor methods, they offer significant performance gains, particularly in scenarios where real-time or near-real-time responses are required.

**The core idea is to narrow down the search space for the query vector, thereby improving the run-time performance.**

The search space is reduced with the help of **indexing**. There are five popular indexing strategies here.

Let’s go through each of them.

A flat index is another name for the brute-force search we saw earlier, which is exactly what kNN does. Thus, all vectors are stored in a single index structure without any hierarchical organization.

That is why this indexing technique is called “flat” – it involves no indexing strategy, and stores the data vectors as they are, i.e., in a ‘flat’ data structure.

As it searches the entire vector database, it delivers the best accuracy of all indexing methods we shall see ahead. However, this approach is incredibly slow and impractical.

Nonetheless, I wouldn’t recommend adopting any other sophisticated approach over a flat index when the data conditions are favorable, such as having only a few data points to search over and a low-dimensional dataset.

But, of course, not all datasets are small, and using a flat index is impractical in most real-life situations.

Thus, we need more sophisticated approaches to index our vectors in the vector database.

The Inverted File Index (IVF) is probably one of the simplest and most intuitive indexing techniques. While it is commonly used in text retrieval systems, it can be adapted to vector databases for approximate nearest neighbor searches.

Here’s how!

Given a set of vectors in a high-dimensional space, the idea is to organize them into different partitions, typically using clustering algorithms like k-means.

As a result, each partition has a corresponding centroid, and every vector gets associated with **only one** partition corresponding to its nearest centroid.

Thus, every centroid maintains information about all the vectors that belong to its partition.

When searching for the nearest vector to the query vector, instead of searching across all vectors, we first find the closest **centroid** to the query vector.

The nearest neighbor is then searched in only those vectors that belong to the partition of the closest centroid found above.
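The two-stage search above can be sketched end-to-end with toy data. A crude hand-rolled k-means stands in for a real clustering implementation, and all parameter values are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(42)
db = rng.normal(size=(2_000, 32))   # toy database: 2k vectors, 32-dimensional

# --- "Indexing": partition the vectors with a few crude k-means iterations ---
k = 16
centroids = db[rng.choice(len(db), size=k, replace=False)].copy()
for _ in range(10):
    assign = np.argmin(((db[:, None, :] - centroids[None]) ** 2).sum(-1), axis=1)
    for c in range(k):
        if (assign == c).any():
            centroids[c] = db[assign == c].mean(axis=0)
# final assignment, consistent with the final centroids
assign = np.argmin(((db[:, None, :] - centroids[None]) ** 2).sum(-1), axis=1)

# --- Search: find the nearest centroid, then scan only its partition ---
def ivf_search(query):
    c = int(np.argmin(((centroids - query) ** 2).sum(-1)))           # O(k * D)
    members = np.where(assign == c)[0]                                # roughly N/k vectors
    best = members[np.argmin(((db[members] - query) ** 2).sum(-1))]   # O(N * D / k)
    return int(best), k + len(members)                                # result, comparisons made

result, comparisons = ivf_search(db[7])
print(result, comparisons)
```

The second returned value counts the comparisons actually performed (centroids plus one partition), which is far fewer than the 2,000 a flat index would require.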

Let’s estimate the search run-time difference it provides over using a flat index.

To reiterate, with a flat index, we compute the similarity of the query vector with all the vectors in the vector database.

If we have $N$ vectors and each vector is $D$ dimensional, the run-time complexity is $O(ND)$ to find the nearest vector.

Compare it to the Inverted File Index, wherein, we first compute the similarity of the query vector with all the **centroids** obtained using the clustering algorithm.

Let’s say there are $k$ centroids, a total of $N$ vectors, and each vector is $D$ dimensional.

Also, for simplicity, let’s say that the vectors are equally distributed across all partitions. Thus, each partition will have $\frac{N}{k}$ data points.

First, we compute the similarity of the query vector with all the **centroids**, whose run-time complexity is $O(kD)$.

Next, we compute the similarity of the query vector to the data points that belong to the centroid’s partition, with a run-time complexity of $O(\frac{ND}{k})$.

Thus, the overall run-time complexity comes out to be $O(kD+\frac{ND}{k})$.

To get some perspective, let’s assume we have 10M vectors in the vector database and divide them into $k=100$ partitions. Thus, each partition is expected to have roughly 100,000 data points.

In the flat index, we shall compare the input query across all data points – **10M**.

In IVF, we first compare the input query across all centroids (`100` comparisons), and then compare it to the vectors in the obtained partition (`100,000` comparisons). Thus, the total comparisons, in this case, will be `100,100`, making the search **nearly 100 times faster**.

Of course, it is essential to note that if some vectors are actually close to the input vector but still happen to be in the neighboring partition, we will miss them.

But recalling our objective, we were never aiming for the best solution but an approximate best (that's why we call it “approximate nearest neighbors”), so this accuracy tradeoff is something we willingly accept for better run-time performance.

In fact, if you remember the model compression deep dive, we followed the same idea there as well:

The idea of quantization in general refers to compressing the data while preserving the original information.

Thus, Product Quantization (PQ) is a technique used for vector compression for memory-efficient nearest neighbor search.

Let’s understand how it works in detail.

Say we have some vectors, and each vector is `256`-dimensional. Assuming each dimension is represented by a number that takes `32` bits, the memory consumed by each vector would be `256 x 32` bits = `8192` bits.

In Product Quantization (PQ), we first split all vectors into sub-vectors. A demonstration is shown below, where we split the vector into `M` (a parameter) segments, say `8`:

As a result, each segment will be `32`-dimensional.

Next, we run KMeans on each segment separately, which will generate `k` centroids for each segment.

Do note that each centroid will represent the centroid for the subspace (`32` dimensions) but not the entire vector space (`256` dimensions in this demo).

For instance, if `k=100`, this will generate `100*8` centroids in total.

Training complete.

Next, we move to the encoding step.

The idea is that for each segment of a vector in the entire database, we find the nearest centroid among the respective segment's `k` centroids that were obtained in the training step above.

For instance, consider the first segment of the `256`-dimensional data we started with:

We compare these segment vectors to the corresponding `k` centroids and find the nearest centroid for all segment vectors:

After obtaining the nearest centroid for each vector segment, we replace the entire segment with a unique `centroid ID`, which can be thought of as an index (a number from `0` to `k-1`) of the centroid in that sub-space.

We do this for all individual segments of the vectors:

This provides us with a quantized (or compressed) representation of all vectors in the vector database, which is composed of `centroid IDs`, also known as **PQ codes**.

To recap, we have encoded every vector in the vector database as a vector of `centroid IDs`, each a number from `0` to `k-1`, and **every centroid ID now consumes only 8 bits of memory**.

As there are `8` segments, the total memory consumed is `8*8=64` bits, which is 128 times lower memory usage than what we had earlier – `8192` bits.
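The memory arithmetic above can be checked in a couple of lines, using the same toy numbers from the text:

```python
# Memory per vector before and after Product Quantization (toy numbers from the text):
# a 256-dimensional float32 vector vs. its PQ code of 8 segment IDs at 8 bits each.
D, bits_per_dim = 256, 32
M, bits_per_id = 8, 8

original = D * bits_per_dim      # bits per raw vector
compressed = M * bits_per_id     # bits per PQ-encoded vector
ratio = original // compressed
print(original, compressed, ratio)  # 8192 64 128
```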

The memory saving scales immensely well when we are dealing with millions of vectors.

Of course, the encoded representation isn’t entirely accurate, but don’t worry about that as it is not that important for us to be perfectly precise on all fronts.

Now, you might be wondering, how exactly do we search for the nearest neighbor based on the encoded representations?

More specifically, given a new query vector `Q`, we have to find the vector that is most similar (or closest) to `Q` in our database.

We begin by splitting the query vector `Q` into `M` segments, as we did earlier.

Next, we calculate the distance between every segment of the vector `Q` and all the respective centroids of that segment obtained from the KMeans step above.

This gives us a distance matrix:

The final step is to estimate the distance of the query vector `Q` from the vectors in the vector database.

To do this, we go back to the PQ matrix we generated earlier:

Next, we look up the corresponding entries in the distance matrix generated above.

For instance, the first vector in the above PQ matrix is this:

To get the distance of our query vector to this vector, we check the corresponding entries in the distance matrix.

We sum all the segment-wise distances to get a rough estimate of the distance of the query vector `Q` from all the vectors in the vector database.

We repeat this for all vectors in the database, find the lowest distance, and return the corresponding vectors from the database.

Of course, it is important to note that the above PQ matrix lookup is still a brute-force search. This is because we look up all the distances in the distance matrix for all entries of the PQ matrix.

Moreover, since we are not estimating the vector-to-vector distances but rather vector-to-centroid distances, the obtained values are just approximated distances but not true distances.

Increasing the number of centroids and segments will increase the precision of the approximate nearest neighbor search, but it will also increase the run-time of the search algorithm.

**Here’s a summary of the product quantization approach:**

- Divide the vectors in the vector database into $M$ segments.
- Run KMeans on each segment. This will give `k` centroids per segment.
- Encode the vectors in the vector database by replacing each segment of the vector with the `centroid ID` of the cluster it belongs to. This generates a PQ matrix, which is immensely memory-efficient.
- Next, to determine the approximate nearest neighbor for a query vector `Q`, generate a distance matrix, where each entry denotes the **distance** of a segment of the vector `Q` to all the centroids.
- Go back to the PQ codes now, and look up the distances in the distance matrix above to get an estimate of the distance between all vectors in the database and the query vector `Q`. Select the vector with the minimum distance to get the approximate nearest neighbor.

Done!
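Here is a self-contained toy sketch of the whole pipeline summarized above: random data, a crude per-segment k-means, encoding to PQ codes, and the distance-table lookup at query time. All parameter values are arbitrary choices for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, M, k = 1_000, 64, 8, 32       # 1k vectors, 64-dim, 8 segments, 32 centroids/segment
seg = D // M                        # each segment is 8-dimensional here
db = rng.normal(size=(N, D))

# --- Training: a few crude k-means iterations per segment ---
codebooks = np.empty((M, k, seg))
for m in range(M):
    X = db[:, m * seg:(m + 1) * seg]
    C = X[rng.choice(N, size=k, replace=False)].copy()
    for _ in range(10):
        a = np.argmin(((X[:, None] - C[None]) ** 2).sum(-1), axis=1)
        for c in range(k):
            if (a == c).any():
                C[c] = X[a == c].mean(axis=0)
    codebooks[m] = C

# --- Encoding: replace each segment with its nearest centroid ID (the PQ codes) ---
codes = np.empty((N, M), dtype=np.uint8)
for m in range(M):
    X = db[:, m * seg:(m + 1) * seg]
    codes[:, m] = np.argmin(((X[:, None] - codebooks[m][None]) ** 2).sum(-1), axis=1)

# --- Search: build a segment-to-centroid distance table, then sum table lookups ---
def pq_search(q):
    table = np.stack([((codebooks[m] - q[m * seg:(m + 1) * seg]) ** 2).sum(-1)
                      for m in range(M)])             # shape (M, k)
    approx = table[np.arange(M), codes].sum(axis=1)   # approximate distance per vector
    return int(np.argmin(approx))

print(pq_search(db[42]))
```

Note that the search still scans every row of the PQ code matrix; the savings come from memory and from replacing full-vector distance computations with cheap table lookups.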

Approximate nearest neighbor search with product quantization is suitable for medium-sized systems, and it’s pretty clear that there is a tradeoff between precision and memory utilization.

Let’s understand some more effective ways to search for the nearest neighbor.

HNSW is possibly one of the most effective and efficient indexing methods designed specifically for nearest neighbor searches in high-dimensional spaces.

The core idea is to construct a graph structure, where each node represents a data vector, and edges connect nodes based on their similarity.

HNSW organizes the graph in such a way that facilitates fast search operations by efficiently navigating through the graph to find approximate nearest neighbors.

But before we understand HNSW, it is crucial to understand **NSW (Navigable Small World)**, which is foundational to the HNSW algorithm.

The upcoming discussion is based on the assumption that you have some idea about graphs.

While we cannot cover them in full detail, here are some details that will suffice to understand the upcoming concepts.

A graph is composed of vertices and edges, where edges connect vertices together. In this context, connected vertices are often referred to as neighbors.

Recalling what we discussed earlier about vectors, we know that similar vectors are usually located close to each other in the vector space.

Thus, if we represent these vectors as vertices of a graph, vertices that are close together (i.e., vectors with high similarity) should be connected as neighbors.

That said, even if two nodes are not directly connected, they should be reachable by traversing other vertices.

**This means that we must create a navigable graph.**

More formally, for a graph to be navigable, every vertex must have neighbors; otherwise, there will be no way to reach some vertices.

Also, while having neighbors is beneficial for traversal, at the same time, we want to avoid such situations where every node has too many neighbors.

This can be costly in terms of memory, storage, and computational complexity during search time.

Ideally, we want a navigable graph that resembles a small-world network, where each vertex has only a limited number of connections, and the average number of edge traversals between two randomly chosen vertices is low.

This type of graph is efficient for similarity search in large datasets.

If this is clear, we can understand how the Navigable Small World (NSW) algorithm works.

The first step in NSW is constructing the graph, which we call `G`.

This is done by shuffling the vectors randomly and constructing the graph by sequentially inserting vertices in a **random order**.

When adding a new vertex (`V`) to the graph (`G`), it shares an edge with the `K` existing vertices in the graph that are closest to it.

This demo will make it more clear.

Say we set `K=3`.

Initially, we insert the first vertex `A`. As there are no other vertices in the graph at this point, `A` remains unconnected.

Next, we add vertex `B`, connecting it to `A` since `A` is the only existing vertex, and it will anyway be among the top `K` closest vertices. Now the graph has two vertices `{A, B}`.

Next, when vertex `C` is inserted, it is connected to both `A` and `B`. The exact same process takes place for the vertex `D` as well.

Now, when vertex `E` is inserted into the graph, it connects only to the `K=3` closest vertices, which, in this case, are `A`, `B`, and `D`.

This sequential insertion process continues, gradually building the NSW graph.

The good thing is that as more and more vertices are added, connections formed in the early stages of the graph constructions may become longer-range links, which makes it easier to navigate long distances in small hops.

This is evident from the following graph, where the connections `A — C` and `B — D` span greater distances.

By constructing the graph in this manner, we get an NSW graph, which, most importantly, is navigable.

**In other words, any node can be reached from any other node in the graph in some hops.**

In the NSW graph (`G`) constructed above, the search process is conducted using a simple greedy search method that relies on local information at each step.

Say we want to find the nearest neighbor to the yellow node in the graph below:

To start the search, an entry point is randomly selected, which is also the beauty of this algorithm. In other words, a key advantage of NSW is that a search can be initiated from any vertex in the graph `G`.

Let’s choose the node `A` as the entry point:

After selecting the initial point, the algorithm iteratively finds a neighbor (i.e., a connected vertex) that is nearest to the query vector `Q`.

For instance, in this case, the vertex `A` has neighbors (`D`, `B`, `C`, and `E`). Thus, we shall compute the distance (or similarity, whichever metric you chose) of these `4` neighbors to the query vector `Q`.

In this case, node `C` is the closest, so we move from node `A` to node `C`.

Next, the search moves toward the vertex with the least distance to the query vector.

The only **unevaluated** neighbor of node `C` is `H`, which turns out to be closer to the query vector, so we now move to node `H`.

This process is repeated until no neighbor closer to the query vector can be found, which gives us the nearest neighbor in the graph for a query vector.
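The construction and greedy-search procedures above can be sketched together on toy 2-D points. The data and parameters are arbitrary, and, exactly as discussed, the returned node is only a local optimum, not a guaranteed exact nearest neighbor:

```python
import numpy as np

rng = np.random.default_rng(1)
points = rng.normal(size=(200, 2))             # 200 toy 2-D vectors
K = 3                                          # connect each new vertex to K nearest

# --- Construction: insert vertices sequentially, linking to the K nearest so far ---
edges = {0: set()}
for v in range(1, len(points)):
    dists = np.linalg.norm(points[:v] - points[v], axis=1)
    nearest = np.argsort(dists)[:K]
    edges[v] = set(int(u) for u in nearest)
    for u in nearest:
        edges[int(u)].add(v)                   # edges are undirected

# --- Greedy search: hop to whichever neighbor is closest to the query ---
def nsw_search(query, entry=0):
    current = entry
    while True:
        d_cur = np.linalg.norm(points[current] - query)
        d_best, best = min((np.linalg.norm(points[n] - query), n)
                           for n in edges[current])
        if d_best >= d_cur:
            return current                     # no neighbor is closer: local optimum
        current = best

query = rng.normal(size=2)
found = nsw_search(query)
exact = int(np.argmin(np.linalg.norm(points - query, axis=1)))
print(found, exact)   # these may differ: the greedy result is only approximate
```

Restarting `nsw_search` from several random entry points, as suggested below, is the usual way to reduce the chance of a poor local optimum.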

One thing I like about this search algorithm is how intuitive and easy to implement it is.

**That said, the search is still approximate, and it is not guaranteed that we will always find the closest neighbor, and it may return highly suboptimal results.**

For instance, consider the graph below, where node `A` is the entry point, and the yellow node is the vector we need the nearest neighbor for:

Following the above procedure of nearest neighbor search, we shall evaluate the neighbors of the node `A`, which are `C` and `B`.

It is clear that both nodes are more distant from the query vector than node `A` itself. Thus, the algorithm returns the node `A` as the final nearest neighbor.

To avoid such situations, it is recommended to repeat the search process with multiple entry points, which, of course, consumes more time.

While NSW is quite a promising and intuitive approach, another major issue is that we end up traversing the graph many times (or repeating the search multiple times) to arrive at an optimal approximate nearest neighbor node.

HNSW speeds up the search process by indexing the vector database into a more optimal graph structure, which is based on the idea of a **skip list**.

First, let me give you some details about the skip list data structure, as that is important here.

And to do this, let’s consider a pretty intuitive example that will make it pretty clear.

Say you wish to travel from New York to California.

If we were following an NSW approach, this journey would be like traveling from one city to another, say, via an intercity taxi, which takes many hops, but gradually moves us closer to our destination, as shown below:

Is that optimal?

No right?

Now, think about it for a second.

How can you more optimally cover this route, or how would you more optimally cover this in real life?

If you are thinking of flights, you are thinking in the right direction.

Think of **skip lists** as a way to plan your trip using different modes of transportation, some of which can cover large distances in a single hop.

So essentially, instead of hopping from one city to another, we could start by taking a flight from New York to a major city closer to California, say Denver.

This flight covers a longer distance in a single hop, analogous to skipping several vertices in the graph that we would have covered otherwise going from one city to another.

👉

Of course, I know there is a direct flight between New York and California. This is just for demonstration purposes so assume that there is no such flight between New York and California.

From Denver, we can take another faster mode of transport, which will involve fewer hops, like a train to reach California:

To add even more granularity, say, once we reach the train station in California, we wish to travel to some place within Los Angeles, California.

Now we need something that can take smaller hops, so a taxi is perfect here.

So what did we do here?

This combination of a long flight, a relatively shorter train trip, and a taxi to travel within the city allowed us to reach our destination in relatively few stops.

This is precisely what skip lists do.

Skip lists are a data structure that allows for efficient searching of elements in a sorted list. They are similar to linked lists but with an added layer of "skip pointers" that allow faster traversal.

This is what linked lists look like:

In a skip list, each element (or node) contains a value and a set of forward pointers that can "skip" over several elements in the list.

These forward pointers create multiple layers within the list (`layer 0`, `layer 1`, `layer 2` in the above visual), with each layer representing a different “skip distance”:

- The top layer (`layer 2`) can be thought of as a flight that can travel long distances in one hop.
- The middle layer (`layer 1`) can be thought of as a train that travels relatively shorter distances than a flight in one hop.
- The bottom layer (`layer 0`) can be thought of as a taxi that travels short distances in one hop.

The nodes that must be kept in each layer are decided using a probabilistic approach.

The basic idea is that nodes are included in higher layers with decreasing probability, resulting in fewer nodes at higher levels, while the bottom layer ALWAYS contains all nodes.

More specifically, before skip list construction, each node is randomly assigned an integer `L`, which indicates the **maximum** layer at which it can be present in the skip list data structure. This is done as follows:

$L = \lfloor -\ln(\text{uniform}(0,1)) \cdot C_{LM} \rfloor$

- `uniform(0,1)` generates a random number between 0 and 1.
- `floor()` rounds the result down to the nearest integer.
- $C_{LM}$ is a layer multiplier constant that adjusts the overlap between layers. Increasing this parameter leads to more overlap.

For instance, if a node has `L=2`, it must exist on `layer 2`, `layer 1`, and `layer 0`.

Also, say the layer multiplier ($C_{LM}$) was set to `0`. This would mean that `L=0` for all nodes:

As discussed above, `L` indicates the **maximum** layer at which a node can be present in the skip list data structure. If `L=0` for all nodes, the skip list will only have one layer.

Increasing this parameter leads to more overlap between layers and more layers, as shown in the plots below:

As depicted above:

- With $C_{LM}=0$, the skip list can only have `layer 0`, which is similar to the NSW search.
- With $C_{LM}=0.25$, we get one more layer, which has around 6-7 nodes.
- With $C_{LM}=1$, we get four layers.
- In all cases, `layer 0` always has all nodes.

The objective is to choose an optimal value for $C_{LM}$: we do not want too many layers with too much overlap, but at the same time, we also do not want only one layer (when $C_{LM}=0$), which would result in no speedup at all.

Now, let me explain how skip lists speed up the search process.

Let’s say we want to find the element `50` in this list.

If we were using a typical linked list, we would have started from the first element (`HEAD`) and scanned each node one by one to check whether it matches the query (`50`).

See how a skip list helps us optimize this search process.

We begin from the top layer (`layer 2`) and check the value of the next node in the same layer, which is `65`.

As `65>50` and it is a unidirectional linked list, we must go down one level.

In `layer 1`, we check the value of the next node in the same layer, which is `36`.

As `50>36`, it is wise to move to the node corresponding to the value `36`.

Now, again in `layer 1`, we check the value of the next node in the same layer, which is `65`.

Again, as `65>50` and it is a unidirectional linked list, we must go down one level.

We reach `layer 0`, which can be traversed the way we usually would.

Had we traversed the linked list without building a skip list, we would have taken `5` hops:

But with a skip list, we completed the same search in `3` hops instead:

That was pretty simple and elegant, wasn't it?
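The walkthrough above can be reproduced with a tiny hand-built skip list. Only the values `36`, `50`, and `65` come from the text; the other values and the exact layer layout are made up so that the hop counts match the example:

```python
# A tiny hand-built skip list mirroring the walkthrough above. Each layer is a
# sorted list; layer 0 holds every node, and the head value is on all layers.
layers = [
    [7, 14, 21, 36, 44, 50, 65, 80],  # layer 0: the complete linked list
    [7, 36, 65],                      # layer 1: skips several nodes per hop
    [7, 65],                          # layer 2: longest skips
]

def skiplist_find(target):
    current = layers[0][0]            # start at the head value
    hops = 0
    for layer in reversed(layers):    # begin at the top layer and work down
        i = layer.index(current)
        # move right while the next node on this layer does not overshoot the target
        while i + 1 < len(layer) and layer[i + 1] <= target:
            i += 1
            hops += 1
        current = layer[i]
        if current == target:
            return current, hops
    return None, hops                 # target is not present in the list

value, hops = skiplist_find(50)
plain_hops = layers[0].index(50)      # hops a plain linked-list scan would need
print(value, hops, plain_hops)        # 50 3 5
```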

While reducing the number of hops from `5` to `3` might not sound like a big improvement, it is important to note that typical vector databases have millions of nodes.

Thus, such improvements scale pretty quickly to provide run-time benefits.

Now that we understand how skip lists work, understanding the graph construction process of Hierarchical Navigable Small World is also pretty straightforward.

Consider that this is our current graph structure:

I understand we are starting in the middle of the construction process, but just bear with me for a while, as it will clear everything up.

Essentially, the above graph has three layers in total, and it is in the middle of construction. Also, as we go up, the number of nodes decreases, which is what happens ideally in skip lists.

Now, let’s say we wish to insert a new node (blue node in the image below), and its max level (determined by the probability distribution) is `L=1`. This means that this node will be present on `layer 1` and `layer 0`.

Now, our objective is to connect this new node to other nodes in the graph on `layer 0` and `layer 1`.

This is how we do it:

- We start from the topmost layer (`layer 2`) and select an entry point for this new node randomly:

- We explore the neighbors of this entry point and select the one that is nearest to the new node to be inserted.

- The nearest neighbor found for the blue node becomes the entry point for the next layer. Thus, we move to the corresponding node of the nearest neighbor in the next layer (`layer 1`):

- With that, we have arrived at the layer where this blue node must be inserted.

Here, note that had there still been more layers, we would have repeated the above process of finding the nearest neighbor of the entry point and moving one layer down until we had arrived at the layer of interest.

For instance, imagine that the max level value for this node was `L=0`. Thus, the blue node would have only existed on the bottom-most layer.

After moving to `layer 1` (as shown in the figure above), we haven't arrived at the layer of interest yet.

So we explore the neighborhood of the entry point in `layer 1` and find the nearest neighbor of the blue point again.

Now, the nearest neighbor found on `layer 1` becomes our entry point for `layer 0`.

Coming back to the situation where the max level value was `L=1`: we are currently at `layer 1`, where the entry point is marked in the image below, and we must insert the blue node on this layer:

To insert a node, here’s what we do:

- Explore the neighbors of the entry point in the current layer and connect the new node to the top `K` nearest neighbors. To determine the top `K` neighbors, a total of `efConstruction` (a hyperparameter) neighbors are greedily explored in this step. For instance, if `K=2`, then in the above diagram, we connect the blue node to the following nodes:

However, to determine these top `K=2` nodes, we may have explored, say, `efConstruction=3` neighbors instead (the purpose of doing this will become clear shortly), as shown below:

Now, we must insert the blue node at a layer that is below the max layer value `L` of the blue node.

In such a case, we don’t keep just one entry point to the next layer, like we did earlier, as shown below:

However, all `efConstruction` nodes explored in the above layer are considered as entry points for the next layer.

Once we enter the next layer, the process is repeated, wherein we connect the blue node to the top `K` neighbors by exploring all `efConstruction` entry nodes.

The layer-by-layer insertion process ends when the new node gets connected to nodes at `layer 0`.

💡

I have intentionally cut out a few minor details here because it will get too complicated to understand. Also, in any real-life situation, we would hardly have to implement this algorithm as it is already implemented in popular vector databases that use HNSW for indexing. The only thing that we must know is how HNSW works on a high level.
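One detail worth making concrete is how a node's max level is sampled. A common choice (an assumption here, since the article abstracts this away) is a geometric distribution, which yields the "exponentially fewer nodes as we go up" property:

```python
import random

def random_level(p=0.5, max_level=3):
    """Sample a node's max level: a node reaches layer L with
    probability p**L, so each layer holds roughly p times the nodes below it."""
    level = 0
    while level < max_level and random.random() < p:
        level += 1
    return level

random.seed(42)
levels = [random_level() for _ in range(10_000)]
share_on_layer1 = sum(l >= 1 for l in levels) / len(levels)  # roughly 0.5
```

With `p=0.5`, about half the nodes reach `layer 1`, a quarter reach `layer 2`, and so on, which matches the skip-list-like structure above.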

Consider that after all the nodes have been inserted, we get the following graph:

Let’s understand how the approximate nearest neighbor search would work.

Say we want to find the nearest neighbor of the yellow vector in the image below:

We begin our search with an entry point in the top layer (`layer 2`):

We explore the connected neighbors of `A` and see which is closest to the yellow node. In this layer, it is `C`.

The algorithm greedily explores the neighborhood of vertices in a layer. We consistently move towards the query vector during this process.

When no closer node to the query vector can be found in a layer, we move to the next layer while considering the nearest neighbor (`C`, in this case) as an entry point to the next layer:

The process of neighborhood exploration is repeated again.

We explore the neighbors of `C` and greedily move to the specific neighbor that is closest to the query vector:

Again, as no node exists in `layer 1` that is closer to the query vector, we move to the next layer while considering the nearest neighbor (`F`, in this case) as an entry point to the next layer.

But this time, we have reached `layer 0`. Thus, an approximate nearest neighbor will be returned in this case.

When we move to `layer 0` and start exploring its neighborhood, we notice that it has no neighbors that are closer to the query vector:

Thus, the vector corresponding to node `F` is returned as the approximate nearest neighbor, which, coincidentally, also happens to be the true nearest neighbor.

In the above search process, it only took **2 hops** (descending is not a hop) to return the nearest neighbor to the query vector.

Let’s see how many hops it would have taken us to find the nearest neighbor with NSW. For simplicity, let’s consider that the graph constructed by NSW is the one represented by `layer 0` of the HNSW graph:

We started with the node `A` as the entry point earlier, so let’s consider the same here as well.

We begin with the node `A`, explore its neighbors, and move to the node `E` as that is the closest to the query vector:

From node `E`, we move to node `B`, which is closer to the query vector than node `E`.

Next, we explore the neighbors of node `B`, and notice that node `I` is the closest to the query vector, so we hop onto that node now:

As node `I` is not connected to any node that is closer to the query vector than itself, the algorithm returns node `I` as the nearest neighbor.

What happened there?

Not only did the algorithm take more hops (`3`) to return a nearest neighbor, but it also returned a less optimal nearest neighbor.

HNSW, on the other hand, took fewer hops and returned a more accurate and optimal nearest neighbor.
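The NSW-style walk above is easy to express in code. The coordinates and edges below are hypothetical stand-ins for the article's `layer 0` graph, chosen so the greedy walk reproduces the same behavior: 3 hops ending at `I`, a local minimum, while the true nearest neighbor `F` sits behind a neighbor that is farther away than `I`:

```python
def greedy_search(graph, coords, entry, query):
    """Greedy best-first walk: from `entry`, hop to whichever neighbor is
    closest to `query`; stop when no neighbor improves on the current node."""
    def dist(a, b):  # squared Euclidean distance is enough for comparisons
        return sum((x - y) ** 2 for x, y in zip(a, b))
    current, hops = entry, 0
    while graph[current]:
        best = min(graph[current], key=lambda n: dist(coords[n], query))
        if dist(coords[best], query) >= dist(coords[current], query):
            break  # local minimum: no neighbor is closer than where we are
        current, hops = best, hops + 1
    return current, hops

# Hypothetical 2-D layout; F is the true nearest neighbor to the query
coords = {"A": (0, 0), "E": (2, 1), "B": (4, 2),
          "I": (5, 4), "H": (8, 3), "F": (6, 5)}
graph = {"A": ["E"], "E": ["A", "B"], "B": ["E", "I"],
         "I": ["B", "H"], "H": ["I", "F"], "F": ["H"]}

print(greedy_search(graph, coords, "A", (6, 6)))  # ('I', 3): stuck short of F
```

This is exactly the failure mode HNSW's upper layers mitigate: long-range links let the search land near the right region before the greedy walk begins.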

Perfect!

With that, you have learned five pretty common indexing strategies to index vectors in a vector database for efficient search.

At this point, one interesting thing to learn is how exactly Large Language Models (LLMs) take advantage of vector databases.

In my experience, the biggest conundrum many people face is the following:

Once we have trained our LLM, it will have some model weights for text generation. Where do vector databases fit in here?

And it is a pretty genuine query, in my opinion.

Let me explain how vector databases help LLMs be more accurate and reliable in what they produce.

To begin, we must understand that an LLM is deployed after learning from a static version of the corpus it was fed during training.

For instance, if the model was deployed after considering the data until `31st Jan 2024`, and we use it, say, a week after training, it will have no clue about what happened in those days.

Repeatedly training a new model (or adapting the latest version) every single day on new data is impractical and not cost-effective. In fact, LLMs can take weeks to train.

Also, what if we open-sourced the LLM and someone else wants to use it on their privately held dataset, which, of course, was not shown during training?

As expected, the LLM will have no clue about it.

But if you think about it, is it really our objective to train an LLM to know every single thing in the world?

**A BIGGG NOOOO!**

That’s not our objective.

Instead, it is more about helping the LLM learn the overall structure of the language, and how to understand and generate it.

So, once we have trained this model on a ridiculously large enough training corpus, it can be expected that the model will have a decent level of language understanding and generation capabilities.

Thus, if we could figure out a way for LLMs to look up new information they were not trained on and use it in text generation (**without training the model again**), that would be great!

One way could be to provide that information in the prompt itself.

In other words, if training or fine-tuning the model isn’t desired, we can provide all the necessary details in the prompt given to the LLM.

Unfortunately, this will only work for a small amount of information.

This is because LLMs are auto-regressive models.

💡

Auto-regressive models are those models that generate outputs one step at a time, where each step depends on the previous steps. In the case of LLMs, this means that the model generates text one word at a time, **based on the words it has already generated**.

Thus, as the LLM must consider all previous tokens, there is a token limit that its prompt practically cannot exceed.

Overall, this approach of providing everything in the prompt is not that promising because it limits the utility to a few thousand tokens, whereas in real life, additional information can have millions of tokens.

**This is where vector databases help.**

Instead of retraining the LLM every time new data emerges or changes, we can leverage vector databases to update the model's understanding of the world dynamically.

How?

It’s simple.

As discussed earlier in the article, vector databases help us store information in the form of vectors, where each vector captures semantic information about the piece of text being encoded.

Thus, we can maintain our available information in a vector database by encoding it into vectors using an embedding model.

When the LLM needs to access this information, it can query the vector database using a similarity search with the prompt vector.

More specifically, the similarity search will try to find contents in the vector database that are similar to the input query vector.

This is where indexing becomes important because our vector database can possibly have millions of vectors.

Theoretically, we can compare the input vector to every vector in the vector database.

But for practical utility, we must find the nearest neighbor as quickly as possible.

That is why indexing techniques, which we discussed earlier, become so important. They help us find the approximate nearest neighbor in almost real-time.

Moving on, once the approximate nearest neighbor gets retrieved, we gather the context from which those specific vectors were generated. This is possible because a vector database not only stores vectors but also the raw data that generated those vectors.

This search process retrieves context that is similar to the query vector, which represents the context or topic the LLM is interested in.

We can augment this retrieved content along with the actual prompt provided by the user and give it as input to the LLM.

Consequently, the LLM can easily incorporate this info while generating text because it now has the relevant details available in the prompt.
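The whole retrieve-then-augment loop fits in a few lines. The toy `embed` function below is a deliberately crude stand-in for a real embedding model, and all names here are illustrative; the point is only the flow: embed the query, fetch similar chunks, and prepend them to the prompt:

```python
import math

def embed(text):
    # Toy embedding: word counts over a tiny vocabulary (a real system
    # would use a model such as SentenceTransformers here)
    vocab = ("vector", "database", "llm", "weather", "cooking")
    words = text.lower().split()
    return [words.count(w) for w in vocab]

def cosine(a, b):
    num = sum(x * y for x, y in zip(a, b))
    den = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return num / den if den else 0.0

# The "vector database": each entry keeps the raw text alongside its vector
docs = [
    "vector database indexing makes search fast",
    "today it is sunny outside",
    "cooking pasta requires boiling water",
]
store = [(d, embed(d)) for d in docs]

def rag_prompt(question, k=1):
    # Retrieve the top-k most similar chunks, then augment the prompt
    q = embed(question)
    ranked = sorted(store, key=lambda item: cosine(q, item[1]), reverse=True)
    context = "\n".join(text for text, _ in ranked[:k])
    return f"Context:\n{context}\n\nQuestion: {question}"
```

Calling `rag_prompt("how does a vector database work")` returns a prompt whose context is the indexing document, not the unrelated ones, which is exactly the augmentation step described above.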

And...Congratulations!

You just learned **Retrieval-Augmented Generation (RAG)**. I am sure you must have heard this term many times by now, and what we discussed above is the entire idea behind RAG.

I intentionally did not mention RAG anywhere earlier to build the desired flow and avoid intimidating you with this term first.

In fact, even its name entirely justifies what we do with this technique:

- **Retrieval**: Accessing and retrieving information from a knowledge source, such as a database or memory.
- **Augmented**: Enhancing or enriching something, in this case, the text generation process, with additional information or context.
- **Generation**: The process of creating or producing something, in this context, generating text or language.

Another critical advantage of RAG is that it drastically helps the LLM reduce **hallucinations** in its responses. I am sure you must have heard of this term too somewhere.

Hallucinations happen when a language model generates information that is not grounded in reality or when it makes up things.

This can lead to the model generating incorrect or misleading information, which can be problematic in many applications.

With RAG, the language model can use the retrieved information (which is expected to be reliable) from the vector database to ensure that its responses are grounded in real-world knowledge and context, reducing the likelihood of hallucinations.

This makes the model's responses more accurate, reliable, and contextually relevant, improving its overall performance and utility.

This idea makes intuitive sense as well.

These days, there are tons of vector database providers, which can help us store and retrieve vector representations of our data efficiently.

- **Pinecone**: Pinecone is a managed vector database service that provides fast, scalable, and efficient storage and retrieval of vector data. It offers a range of features for building AI applications, such as similarity search and real-time analytics.
- **Weaviate**: Weaviate is an **open-source vector database** that is robust, scalable, cloud-native, and fast. With Weaviate, one can turn text, images, and more into a searchable vector database using state-of-the-art ML models.
- **Milvus**: Milvus is an open-source vector database built to power embedding similarity search and AI applications. Milvus makes unstructured data search more accessible and provides a consistent user experience regardless of the deployment environment.
- **Qdrant**: Qdrant is a vector similarity search engine and vector database. It provides a production-ready service with a convenient API to store, search, and manage points (vectors with an additional payload). Qdrant is tailored to extended filtering support, which makes it useful for all sorts of neural network or semantic-based matching, faceted search, and other applications.

Going ahead, let’s get into some practical details on how we can index vectors in a vector database and perform search operations over them.

For this demonstration, I will be using Pinecone as it’s possibly one of the easiest to begin with and understand. But many other providers are linked above, which you can explore if you wish to.

To begin, we first install a few dependencies, like Pinecone and Sentence transformers:
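If you are following along, the install step looks roughly like this (the package names are my assumption for the client versions current when this was written; the Pinecone package on PyPI has since been renamed to `pinecone`, so adjust if this fails):

```shell
pip install pinecone-client sentence-transformers
```
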

It is important to encode this text dataset into vector embeddings to store them in the vector database. For this purpose, we will leverage the `SentenceTransformers` library.

It provides pre-trained transformer-based architectures that can efficiently encode text into dense vector representations, often called embeddings.

The `SentenceTransformers` library offers various pre-trained architectures, such as `BERT`, `RoBERTa`, and `DistilBERT`, fine-tuned specifically for sentence embeddings.

These embeddings capture semantic similarities and relationships between textual inputs, making them suitable for downstream tasks like classification and clustering.

DistilBERT is a relatively smaller model, so we shall be using it in this demonstration.

Next, open a Jupyter Notebook and import the above libraries:

Moving on, we download and instantiate the DistilBERT sentence transformer model as follows:

To get started with Pinecone and create a vector database, we need a Pinecone API key.

To get this, head over to the Pinecone website and create an account here: https://app.pinecone.io/?sessionType=signup. We get to the following page after signing up:

Grab your API key from the left panel in the dashboard below:

Click on `API Keys` -> `Create API Key` -> `Enter API Key Name` -> `Create`.

Done!

Copy this API key (using the copy-to-clipboard button), go back to the Jupyter Notebook, and establish a connection to Pinecone using this API key, as follows:

In Pinecone, we store vector embeddings in indexes. The vectors in any index we create must share the same dimensionality and distance metric for measuring similarity.

We create an index using the `create_index()` method of the `Pinecone` class object created above.

On a side note, as we currently have no indexes, running the `list_indexes()` method returns an empty list in the `indexes` key of the dictionary:

Coming back to creating an index, we use the `create_index()` method of the `Pinecone` class object as follows:

Here’s a breakdown of this function call:

- `name`: The name of the index. This is a user-defined name that can be used to refer to the index later when performing operations on it.
- `dimension`: The dimensionality of the vectors that will be stored in the index. This should match the dimensionality of the vectors that will be inserted into the index. We have specified `768` here because that is the embedding dimension returned by the `SentenceTransformer` model.
- `metric`: The distance metric used to calculate the similarity between vectors. In this case, `euclidean` is used, which means that the Euclidean distance will be used as the similarity metric.
- `spec`: A `PodSpec` object that specifies the environment in which the index will be created. In this example, the index is created in a GCP (Google Cloud Platform) environment named `gcp-starter`.

Executing this method creates an index, which we can also see in the dashboard:

Now that we have created an index, we can push vector embeddings.

To do this, let’s create some text data and encode it using the `SentenceTransformer` model.

I have created some dummy data below:

We create embeddings for these sentences as follows:

This code snippet iterates over each sentence in the `data` list we defined earlier and encodes the text of each sentence into a vector using the downloaded sentence transformer model (`model`).

It then creates a dictionary `vector_info` containing the sentence ID (`id`) and the corresponding vector (`values`), and appends this dictionary to the `vector_data` list.
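For reference, the loop just described looks roughly like this. To keep the sketch self-contained, `fake_encode` is a hypothetical stand-in; in the article's setup you would call `model.encode(...)` instead:

```python
def fake_encode(text, dim=8):
    # Hypothetical stand-in for model.encode: a deterministic pseudo-embedding
    words = (text.split() + [""] * dim)[:dim]
    return [(hash(w) % 100) / 100 for w in words]

data = [
    {"id": "0", "text": "Vector databases store embeddings"},
    {"id": "1", "text": "LLMs generate text one token at a time"},
]

vector_data = []
for sentence in data:
    embedding = fake_encode(sentence["text"])           # encode the text
    vector_info = {"id": sentence["id"], "values": embedding}
    vector_data.append(vector_info)                     # collect for upsert
```
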

In practical instances, there can be multiple indexes under the same account, so we must create an `index` object that specifies the index we wish to add these embeddings to. This is done as follows:

Now that we have the embeddings and the index, we **upsert** these vectors.

Upsert is a database operation that combines the actions of **update** and **insert**. It inserts a new document into a collection if the document does not already exist, or updates an existing document if it does exist. Upsert is a common operation in databases, especially in NoSQL databases, where it is used to ensure that a document is either inserted or updated based on its existence in the collection.

Done!

We have added these vectors to the index. While the output does highlight this, we can double-check it by using the `describe_index_stats` operation to see if the current vector count matches the number of vectors we upserted:

Here's what each key in the returned dictionary represents:

- `dimension`: The dimensionality of the vectors stored in the index (`768`, in this case).
- `index_fullness`: A measure of how full the index is, typically indicating the percentage of slots in the index that are occupied.
- `namespaces`: A dictionary containing statistics for each namespace in the index. In this case, there is only one namespace (`''`) with a `vector_count` of `10`, indicating that there are `10` vectors in the index.
- `total_vector_count`: The total number of vectors in the index across all namespaces (`10`, in this case).

Now that we have stored the vectors in the above index, let’s run a similarity search to see the obtained results.

We can do this using the `query()` method of the `index` object we created earlier.

First, we define a search text and generate its embedding:

Next, we query this as follows:

This code snippet calls the `query` method on an index object, which performs a nearest neighbor search for a given query vector (`search_embedding`) and returns the top `3` matches.

Here's what each key in the returned dictionary represents:

- `matches`: A list of dictionaries, where each dictionary contains information about a matching vector. Each dictionary includes the `id` of the matching vector and the `score` indicating the similarity between the query vector and the matching vector. As we specified `euclidean` as our metric while creating this index, a lower score indicates a smaller distance, i.e., a closer match.
- `namespace`: The namespace of the index where the query was performed. In this case, the namespace is an empty string (`''`), indicating the default namespace.
- `usage`: A dictionary containing information about the usage of resources during the query operation. In this case, `read_units` indicates the number of read units consumed by the query operation, which is `5`. Note that we originally upserted `10` vectors to this index; read units measure resource consumption rather than mapping one-to-one to the number of vectors scanned.

From the above results, we notice that the top 3 neighbors of the `search_text` (`"Vector database are really helpful"`) are:

Awesome!

With this, we come to an end to this deep dive into vector databases.

To recap, we learned that vector databases are specialized databases designed to efficiently store and retrieve vector representations of data.

By organizing vectors into indexes, vector databases enable fast and accurate similarity searches, making them invaluable for tasks like recommendation systems and information retrieval.

Moreover, the Pinecone demo showcased how easy it is to create and query a vector index using Pinecone's service.

Before I end this article, there’s one important point I want to mention.

Just because vector databases sound cool, it does not mean that you have to adopt them in each and every place where you wish to find vector similarities.

It's so important to assess whether using a vector database is necessary for your specific use case.

For small-scale applications with a limited number of vectors, simpler solutions like NumPy arrays and doing an exhaustive search will suffice.

There’s no need to move to vector databases unless you see any benefits, such as latency improvement in your application, cost reduction, and more.

I hope you learned something new today!

I know we discussed many details in this deep dive, so if there’s any confusion, feel free to post them in the comments.

Or, if you wish to connect privately, feel free to initiate a chat here:

Thanks for reading!

If you loved reading this deep dive, I am sure you will learn a lot of practical skills from other rich deep dives too, like these:

**Become a full member (if you aren't already) so that you never miss an article:**

As you may already know, the primary objective of building classification models is to categorize data points into predefined classes or labels based on the input features.

More technically speaking, the objective is to learn a function $f$ that maps an input vector (`x`) to a label (`y`).

They can be further divided into two categories:

- **Probabilistic models**: These models output a probabilistic estimate for each class.
- **Direct labeling models**: These models directly predict the class label without providing probabilistic estimates.

We have already discussed this in the newsletter recently, so we won’t get into much detail again:

Moving on...

Talking specifically about neural networks (which are probabilistic models), the model is typically trained with the cross-entropy loss function, as shown below:
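Written out from the definitions that follow, the loss is:

$$J = -\sum_{i=1}^{N} \sum_{j=1}^{C} y_{ij} \log(p_{ij})$$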

where:

- $N$ is the number of samples in the dataset.
- $C$ is the number of classes.
- $y_{ij}$ is an indicator variable that equals $1$ if the $i^{th}$ sample belongs to class $j$, and $0$ otherwise.
- $p_{ij}$ is the predicted probability that the $i^{th}$ sample belongs to class $j$ according to the model.

If it is getting difficult to understand this loss function, just think of it as the sum of the log-loss (binary classification loss function) over every class.

While cross-entropy is undoubtedly one of the most used loss functions for training multiclass classification models, it is not entirely suitable in certain situations.

More specifically, in many real-world classification tasks, the class labels often possess a relative ordering between them.

For instance, consider an age detection task where the goal is to predict the age group of individuals based on facial features:

In such a scenario, the class labels typically represent age ranges or groups, such as `child`, `teenager`, `young adult`, `middle-aged`, and `senior`. These age groups inherently possess an ordered relationship, where `child` precedes `teenager`, `teenager` precedes `young adult`, and so on.

Traditional classification approaches, such as cross-entropy loss, treat each `age group` as a separate and independent category. Thus, they fail to capture the underlying ordinal relationships between the age groups.

Consequently, the model might struggle to differentiate between adjacent age groups, leading to suboptimal performance and classifier ranking inconsistencies.

By "ranking inconsistencies," we mean those situations where the predicted probabilities assigned to adjacent age groups do not align with their natural ordering.

For example, if the model predicts a lower probability for the `child` age group than for the `teenager` age group, despite the fact that `teenager` logically follows `child` in the age hierarchy, this would constitute a ranking inconsistency.

We could also interpret it this way: say the true label for an input sample is `young adult`. In that case, we would want our classifier to highlight that the input sample is "at least a child", "at least a teenager", and "at least a young adult".

Beyond that, it may start outputting low probabilities for other age groups, as depicted above.

However, these inconsistencies are largely observed when we use cross-entropy loss. They arise due to the lack of explicit consideration for the ordinal relationships between age groups in traditional classification approaches.

Since cross-entropy loss treats each age group as a separate category with no inherent order, the model may struggle to learn and generalize the correct progression of age.

As a result, the model may exhibit inconsistent ranking behavior, where it assigns higher probabilities to age groups that logically should have lower precedence according to the age hierarchy.

This inconsistency not only undermines the interpretability of the model but also compromises its predictive accuracy, especially in scenarios where precise age estimation is crucial.

Here, we must note that ordinal classification techniques are not limited to age but are applicable across a wide range of domains where class labels exhibit inherent ordering.

Here are some more use cases:

- **Product Reviews**: In sentiment analysis of product reviews, sentiment labels such as `excellent`, `good`, `average`, `poor`, and `terrible` represent an ordered ranking of the overall sentiment expressed in the reviews.
- **Economic Indicators**: In economic forecasting, indicators such as `strong growth`, `moderate growth`, `stagnation`, `recession`, and `depression` represent an ordered ranking of economic conditions.
- **Risk Assessment**: Risk assessment models may categorize risks into ordered levels such as `low risk`, `medium risk`, and `high risk`, based on the likelihood and impact of potential events.
- **Education Grading**: In educational assessment, students' performance levels are often categorized based on grades, such as `A`, `B`, `C`, `D`, and `F`. These grades represent an ordered ranking from highest to lowest performance.
- And many many more.

These examples illustrate how rank ordinal classification is prevalent across various domains.

Coming back...

The discussion so far indicates that we want our model to accurately classify data points into different categories while also understanding and respecting the natural order or hierarchy present in the labels, both during training and at inference time.

However, as discussed above, commonly used loss functions like multi-category cross-entropy do not explicitly capture this ordinal information.

As the name suggests, ordinal classification involves predicting labels on an ordinal scale.

More formally, the model is trained such that it learns a ranking rule that maps a data point $x$ to an ordered set $y$, where each element $y_i \in y$ represents a class or category, and the order of these elements reflects the ordinal relationship between them.

In ordinal classification, the focus shifts from simply assigning data points to discrete classes to understanding and respecting the relative order or hierarchy present in the classes.

As discussed above, this is particularly important in tasks where the classes exhibit a natural progression or ranking, such as age groups, severity levels, or performance categories.

The goal of ordinal classification is twofold:

- first, to accurately predict the class labels for each data point,
- and second, to ensure that these predictions adhere to the inherent order or ranking of the classes.

Achieving this requires specialized techniques and methodologies that go beyond traditional classification approaches.
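To make "specialized techniques" slightly more concrete: one popular family (e.g., CORAL-style extended binary classification; this is a sketch under that assumption, with hypothetical helper names) recodes each ordinal label into cumulative "is the sample ranked above class k?" binary targets, which a model can then learn with $K-1$ sigmoid outputs:

```python
AGE_GROUPS = ["child", "teenager", "young adult", "middle-aged", "senior"]

def ordinal_targets(label, classes=AGE_GROUPS):
    """Encode an ordinal label as K-1 cumulative binary targets:
    target[k] = 1 if the label is ranked strictly above class k."""
    rank = classes.index(label)
    return [1 if rank > k else 0 for k in range(len(classes) - 1)]

# "young adult" is above "child" and "teenager" but not above itself
# or "middle-aged":
print(ordinal_targets("young adult"))  # [1, 1, 0, 0]
```

Because the targets are cumulative by construction, predictions trained against them are encouraged to respect the "at least a child, at least a teenager, ..." structure described above.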

A common and fundamental way of training a logistic regression model, as taught in most lectures/blogs/tutorials, is using SGD.

Essentially, given an input $X$ and the model parameters $\theta$, the output probability ($\hat y$) is computed as follows:
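That is, the sigmoid of a linear combination of the inputs:

$$\hat{y} = \sigma(\theta^T X) = \frac{1}{1 + e^{-\theta^T X}}$$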

Next, we compute the loss function ($J$), which is log-loss:
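For $N$ training samples with labels $y_i$ and predictions $\hat{y}_i$:

$$J(\theta) = -\frac{1}{N} \sum_{i=1}^{N} \Big[ y_i \log \hat{y}_i + (1 - y_i) \log(1 - \hat{y}_i) \Big]$$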

👉

Why log loss? We covered it in this article: https://www.dailydoseofds.com/why-do-we-use-log-loss-to-train-logistic-regression.

The final step is to update the parameters $\theta$ using gradient descent as follows:
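In its general form, the update rule is:

$$\theta \leftarrow \theta - \alpha \, \frac{\partial J(\theta)}{\partial \theta}$$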

As depicted above, the weight update depends on the learning rate hyperparameter ($\alpha$), which we specify before training the model.

We execute the above steps (summarized again below) over and over for some epochs or until the parameter converges:

- **Step 1)** Initialize model parameters $\theta$.
- **Step 2)** Compute output probability $\hat y$ for all samples.
- **Step 3)** Compute the loss function $J(\theta)$, which is log-loss.
- **Step 4)** Update the parameters $\theta$ using gradient descent.
- Repeat steps 2-4 until convergence.
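The steps above can be sketched in a few lines of plain Python (the toy data and hyperparameters are illustrative; a real implementation would vectorize this with NumPy):

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def train_logreg_sgd(X, y, lr=0.5, epochs=500):
    theta = [0.0] * len(X[0])              # Step 1: initialize parameters
    for _ in range(epochs):
        for xi, yi in zip(X, y):
            # Step 2: output probability for this sample
            y_hat = sigmoid(sum(t * x for t, x in zip(theta, xi)))
            # Steps 3-4: the gradient of the log-loss is (y_hat - y) * x;
            # take one gradient-descent step scaled by the learning rate
            theta = [t - lr * (y_hat - yi) * x for t, x in zip(theta, xi)]
    return theta

# Tiny separable toy set; the leading 1 folds the bias into theta
X = [(1, 0.0), (1, 1.0), (1, 2.0), (1, 3.0)]
y = [0, 0, 1, 1]
theta = train_logreg_sgd(X, y)
```

Note how the learning rate `lr` appears explicitly in every update, which is exactly the hyperparameter whose absence from sklearn we are about to investigate.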

Simple, isn't it?

I am sure this is the method that you must also be thoroughly aware of.

However, if that is true, why don’t we see a learning rate hyperparameter $(\alpha)$ in the sklearn logistic regression implementation:

As depicted above, there is no learning rate parameter in this documentation.

However, we see a `max_iter` parameter that intuitively looks analogous to the epochs.

But how does that make sense?

We have epochs but no learning rate $\alpha$, so how do we even update the parameters of our model, as we do in SGD below?

**Are we missing something here?**

As it turns out, we are indeed missing something here.

More specifically, there are a few more ways to train a logistic regression model, but most of us are only aware of the above SGD procedure, which depends on the learning rate, and never consider the alternatives.

**However, the importance of these alternate mechanisms is entirely reflected by the fact that even sklearn, one of the most popular libraries of data science and machine learning, DOES NOT use SGD in its logistic regression implementation.**

Thus, in this article, I want to share the overlooked details of logistic regression and introduce you to one more way of training this model, which does not depend on the learning rate hyperparameter.

Let’s begin!

Before understanding the alternative mechanism for training logistic regression, it is immensely crucial to know how we model data while using logistic regression.

In other words, let’s understand how we frame its modeling mathematically.

📚

The blog ahead is a bit math-intensive. Yet, I have simplified it as much as possible. If you have any queries, feel free to comment and I'll help you out.

Essentially, whenever we model data using logistic regression, the model is instructed to maximize the likelihood of observing the given data $(X, y)$.

More formally, a model attempts to find a specific set of parameters $\theta$ (also called model weights), which maximizes the following function:

The above function $L$ is called the likelihood function, and in simple words, it says:

- maximize the likelihood of observing y
- given X
- when the prediction is parameterized by some parameters $\theta$ (also called weights)

When we begin modeling:

- We know $X$.
- We also know $y$.
- The only unknown is $\theta$, which we are trying to estimate.

Thus, the instructions given to the model are:

- Find the specific set of parameters $\theta$ that maximizes the likelihood of observing the data $(X, y)$.

This is commonly referred to as **maximum likelihood estimation (MLE)** in machine learning.

MLE is a method for estimating the parameters of a statistical model by maximizing the likelihood of the observed data.

It is a common approach for parameter estimation in various models, including linear regression, logistic regression, and many others.

The key idea behind MLE is to find the parameter values that make the observed data most probable.

The steps are simple and straightforward:

- **Define the likelihood function for the entire dataset:** Here, we typically assume that the observations are independent. Thus, the likelihood function for the entire dataset is the product of the individual likelihoods. Also, the likelihood function is parameterized by a set of parameters $\theta$, which we are trying to estimate.
- **Take the logarithm (the obtained function is called log-likelihood):** To simplify calculations and avoid numerical issues, it is common to take the logarithm of the likelihood function.
- **Maximize the log-likelihood:** Finally, the goal is to find the set of parameters $\theta$ that maximizes the log-likelihood function.
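As a concrete example of these three steps, consider estimating the bias of a coin from observed flips. For i.i.d. Bernoulli data, the MLE of the parameter $p$ works out to the sample mean, which we can verify numerically with a toy grid search (illustrative only; the flip data is made up):

```python
import numpy as np

# Observed coin flips (1 = heads, 0 = tails)
flips = np.array([1, 1, 0, 1, 0, 1, 1, 0, 1, 1])

def log_likelihood(p, data):
    # Steps 1 & 2) Likelihood of independent Bernoulli samples, in log form:
    # log L(p) = sum_i [ y_i log p + (1 - y_i) log(1 - p) ]
    return np.sum(data * np.log(p) + (1 - data) * np.log(1 - p))

# Step 3) Maximize over a grid of candidate values for p
candidates = np.linspace(0.01, 0.99, 99)
best_p = candidates[np.argmax([log_likelihood(p, flips) for p in candidates])]
# best_p matches the closed-form MLE: the fraction of heads
```

The same recipe, applied with logistic regression's output probability in place of a constant $p$, is exactly the derivation that follows.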

In fact, it’s the MLE that helps us derive the log-loss used in logistic regression.

We all know that in logistic regression, the model outputs the probability that a sample belongs to a specific class.

Let’s call it $\hat y$.

Assuming that you have a total of N **independent** samples $(X, y) = \{(x_{1}, y_{1}), (x_{2}, y_{2}), \dots, (x_{N}, y_{N})\}$, the likelihood estimation can be written as:

Essentially, we assume that all samples are independent.

Thus, the likelihood of observing the entire data is the same as the product of the likelihood of observing individual points.

Next, we should determine these individual likelihoods $L(y_{i}|x_{i};\theta)$ as a function of the output probability of logistic regression $\hat y_i$:

While training logistic regression, the model returns a continuous output $\hat y$, representing the probability that a sample belongs to a **specific** class.

In logistic regression, the “**specific class**” is the one we have assigned the label of $y_{i} = 1$.

In other words, it is important to understand that a logistic regression model, by its very nature, outputs the probability that a sample belongs to one of the two classes.

More specifically, it is the class with true label $y_{i} = 1$.

The higher the output, the higher the probability that the sample has a true label $y_{i} = 1$.

Thus, we can say that when the true label $y_{i} = 1$, the likelihood of observing that data point is the output of logistic regression, i.e., $\hat y$.

But how do we determine the likelihood when the true label $y_{i} = 0$?

For simplicity, consider the illustration below:

Assume that the label “Cat” is denoted as “Class 1” and the label “Dog” is denoted by “Class 0”.

Thus, all the logistic regression model outputs inherently denote the probability that an input is “Cat.”

But if we need the probability that the input is a “Dog,” we should take the complement of the output of the logistic regression model.

In other words, the likelihood when the true label $y_{i} = 0$ can be derived from the output of logistic regression $\hat y$.

Therefore, we get the following likelihood function for observing a specific data point $i$:

The likelihood of observing a sample with true label $y_{i} = 1$ is $\hat y_{i}$ (or the output of the model).

But the likelihood of observing a sample with true label $y_{i} = 0$ is $(1-\hat y_{i})$.

We can dissolve the piecewise notation above to get the following:

Let’s plug the likelihood function of individual data points back into the likelihood estimation for all samples:

We can simplify the product to a summation by taking the logarithm on both sides.

In practice, maximizing the log-likelihood function is often more convenient than the likelihood function.

Since the logarithm is a monotonically increasing function, maximizing the log-likelihood is equivalent to maximizing the likelihood.

On further simplification, we get the following:
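That is, combining the per-sample likelihoods and the logarithm from the previous step, the log-likelihood is:

$$\log L(\theta) = \sum_{i=1}^{N} \Big[ y_i \log \hat y_i + (1 - y_i) \log (1 - \hat y_i) \Big]$$

which agrees term-by-term with the per-sample likelihood $\hat y_i^{\,y_i}(1-\hat y_i)^{1-y_i}$ derived above.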

And to recall, $\hat y_i$ is the output probability by the logistic regression model:

The above derivation gave us the log-likelihood of the logistic regression model.

If we take negative (-) on both sides, it will give us the log loss, which can be conveniently minimized by gradient descent.
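We can verify this equivalence numerically: the negative log-likelihood, averaged over samples, matches sklearn's `log_loss`. A small sanity check, with made-up labels and probabilities:

```python
import numpy as np
from sklearn.metrics import log_loss

y_true = np.array([1, 0, 1, 1, 0])
y_prob = np.array([0.9, 0.2, 0.8, 0.6, 0.3])  # model output probabilities

# Log-likelihood: sum of y*log(y_hat) + (1-y)*log(1-y_hat)
ll = np.sum(y_true * np.log(y_prob) + (1 - y_true) * np.log(1 - y_prob))

# Log loss is the negative log-likelihood, averaged over samples
manual_log_loss = -ll / len(y_true)
sklearn_log_loss = log_loss(y_true, y_prob)
```

Both values coincide, confirming that minimizing log loss is the same as maximizing the log-likelihood.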

However, there is another way to manipulate the above log-likelihood formulation for more convenient optimization.

Let’s understand below.

In an earlier deep dive on model compression techniques (linked below), we understood various ways to reduce the model size, which is extremely useful for optimizing inference and deployment processes and improving operational metrics.

This is important because when it comes to deploying ML models in production (or user-facing) systems, the focus shifts from raw accuracy to considerations such as efficiency, speed, and resource consumption.

Thus, in most cases, when we deploy any model to production, the specific model that gets shipped to production is NOT solely determined based on performance.

Instead, we must consider several operational metrics that are not ML-related and are typically not considered during the prototyping phase of the model.

These include factors like:

- **Inference latency**: The time it takes for a model to process a single input and generate a prediction.
- **Throughput**: The number of inference requests a model can handle in a given time period.
- **Model size**: The amount of memory a model occupies when loaded for inference purposes.
- and more.

In that article, we talked about four techniques that help us reduce model size while almost preserving the model’s accuracy:

They attempt to strike a balance between model size and accuracy, making it relatively easier to deploy models in user-facing products.

What’s more, as the models are now considerably smaller, one can expect much faster inference runtime. This is desired because we can never expect end users to wait for, say, a minute for the model to run and generate predictions.

**Now, even if we have fairly compressed the model, there’s one more caveat that still exists, which can affect the model’s performance in production systems.**

Let’s understand this in more detail.

PyTorch is the go-to choice for researchers and practitioners for building deep learning models due to its flexibility, intuitive Pythonic API design, and ease of use.

It takes a programmer just three steps to create a deep learning model in PyTorch:

- First, we define a model class inherited from PyTorch’s `nn.Module` class.
- Moving on, we declare all the network components (layers, dropout, batch norm, etc.) in the `__init__()` method.
- Finally, we define the forward pass of the neural network in the `forward()` method.

That’s it!
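Putting the three steps together, a minimal network might look like this (an illustrative sketch; the layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        # Step 2) Declare all network components in __init__()
        super().__init__()
        self.fc1 = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(32, 2)

    def forward(self, x):
        # Step 3) Define the forward pass
        return self.fc2(self.relu(self.fc1(x)))

# Step 1 is the class definition itself; now instantiate and run it
model = Net()
out = model(torch.randn(4, 10))  # batch of 4 samples, 10 features each
```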

As we saw above, defining the network was so simple and elegant, wasn’t it?

Once we have defined the network, one can proceed with training the model by declaring the optimizer, loss function, etc., **without having to define the backward pass explicitly**.

However, when it comes to deploying these models in production systems, PyTorch's standard and well-adopted design encounters certain limitations, specific to scale and performance.

Let’s understand!

One significant constraint of PyTorch is its predominant reliance on Python.

While Python offers simplicity, versatility, and readability, it is well known for being slower than languages like C++ or Java.

More technically speaking, the Python-centric nature of PyTorch brings concerns related to the Global Interpreter Lock (GIL), a mechanism in CPython (the default Python interpreter) that hinders true parallelism in multi-threaded applications.

This limitation poses challenges in scenarios where low-latency and high-throughput requirements are crucial, such as real-time applications and services.

In fact, typical production systems demand model interoperability across various frameworks and systems.

It's possible that the server we intend to deploy our model on might be leveraging a language other than Python, like C++, Java, and more.

Thus, the models we build MUST BE portable to various environments which are designed to handle concurrent requests at scale.

However, the Python-centric nature of PyTorch can limit its integration with systems or platforms that require interoperability with languages beyond Python.

In other words, in scenarios where deployment involves a diverse technology stack, this restriction can become a hindrance.

This limitation can impact the model's ability to efficiently utilize hardware resources, further influencing factors like inference latency and throughput, which are immensely critical in business applications.

Historically, all PyTorch models were tightly coupled with the Python run-time.

This design choice reflected the framework's emphasis on dynamic computation graphs and ease of use for researchers and developers working on experimental projects.

More specifically, PyTorch's dynamic nature allowed for intuitive model building, easy debugging, and seamless integration with Python's scientific computing ecosystem.

This is also called the **eager mode** of PyTorch, which, as the name suggests, was specifically built for faster prototyping, training, and experimenting.

However, as the demand for deploying PyTorch models in production environments grew, the limitations of this design became more apparent.

The Python-centric nature of PyTorch, while advantageous during development, introduced challenges for production deployments where performance, scalability, and interoperability were paramount.

Of course, PyTorch inherently leveraged all sources of optimizations it possibly could like parallelism, integrating hardware accelerators, and more.

Nonetheless, its over-dependence on Python still left ample room for improvement, especially in scenarios demanding efficient deployment and execution of deep learning models at scale.

Of course, one solution might be to build deep learning models in a framework like PyTorch and then replicate the obtained model in another, environment-agnostic framework.

However, this approach of building models in one framework and then replicating them in another environment-agnostic framework introduces its own set of challenges and complexities.

First and foremost, it requires expertise in both frameworks, increasing the learning curve for developers and potentially slowing down the development process.

In fact, no matter how much we criticize Python for its slowness, every developer loves the Pythonic experience and its flexibility.

Moreover, translating models between frameworks may not always be straightforward. Each deep learning framework has its own unique syntax, conventions, and quirks, making the migration of models a non-trivial task.

The differences in how frameworks handle operations, memory management, and optimization techniques can lead to subtle discrepancies in the behavior of the model, potentially affecting its performance and accuracy.

In fact, any updates to the developed model would have to be extended again to yet another framework, creating redundancy and resulting in a loss of productivity.

In other words, maintaining consistency across different frameworks also becomes an ongoing challenge.

As models evolve and updates are made, ensuring that the replicated version in the environment-agnostic framework stays in sync with the original PyTorch model becomes a manual and error-prone process.

To address these limitations, PyTorch developed the **script mode**, which is specifically designed for production use cases.

PyTorch’s script mode has two components:
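The two components are typically identified as TorchScript, a statically-typed intermediate representation of PyTorch models, and the PyTorch JIT compiler that optimizes and executes it. In practice, a model enters script mode either via tracing (`torch.jit.trace`) or scripting (`torch.jit.script`). A minimal sketch of both, with an arbitrary toy model:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
model.eval()

example_input = torch.randn(1, 10)

# Tracing: run an example input through the model and record the operations
traced = torch.jit.trace(model, example_input)

# Scripting: compile the model's code directly, preserving control flow
scripted = torch.jit.script(model)

# Both produce Python-independent modules that can be serialized with
# .save() and later loaded from C++ via torch::jit::load
with torch.no_grad():
    out_eager = model(example_input)
    out_traced = traced(example_input)
```

Tracing is simpler but "bakes in" one execution path, so models with data-dependent control flow usually need scripting instead.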

In my experience, most ML projects lack a dedicated experimentation management/tracking system.

As the name suggests, this helps us track:

- **Model configuration** → critical for reproducibility.
- **Model performance** → critical for comparing different models.

…across all experiments.

Most data scientists and machine learning engineers develop entire models in Jupyter notebooks without having any well-defined and automated reproducibility and performance tracking protocols.

They heavily rely on inefficient and manual tracking systems — Sheets, Docs, etc., which get difficult to manage quickly.

**MLflow** stands out as a valuable tool for ML engineers, offering robust practices for ML pipelines.

It seamlessly integrates with various cloud services, which facilitates flexibility in usage — whether locally for an individual or remotely for a large ML engineering team.

This deep dive is a complete walkthrough guide on understanding how we can integrate MLflow in our existing machine learning projects, which lets us automate many redundant and manual tasks.

Let’s begin!

Before getting into the technical details, it’s pretty important to get into more details about understanding the motivation for using MLflow.

Thus, let’s spend some more time learning about the challenges with existing approaches, which are mostly manual.

One of the significant challenges in traditional ML modeling is the absence of proper version control practices.

Many ML practitioners struggle with managing different versions of their models, code, and data, leading to potential issues such as difficulty in reproducing results, tracking changes over time, and ensuring consistency across the development lifecycle.

In conventional ML workflows, tracking and managing parameters and data manually pose significant challenges. This manual process is error-prone, time-consuming, and lacks traceability.

ML practitioners often resort to spreadsheets or documents to record parameter configurations and data sources.

However, this approach not only hinders efficiency but also makes it challenging to maintain a comprehensive record of experiments, hindering the ability to reproduce and validate results.

Collaboration in traditional ML environments often faces bottlenecks due to the lack of centralized tools and standardized practices.

Team members mostly work in isolation and find it cumbersome to share models, experiments, and findings efficiently.

This lack of a unified platform for collaboration leads to siloed efforts, with each team member operating independently.

The absence of a streamlined collaboration process impedes knowledge sharing, slows down development cycles, and hinders the collective progress of the team.

Beyond version control, manual tracking, and collaboration bottlenecks, traditional ML modeling encounters various other challenges.

Inconsistent deployment processes, scalability issues, and the lack of standardized project structures contribute to the complexity of ML development.

Addressing these challenges is crucial for establishing efficient, scalable, and reproducible machine learning workflows, a topic we will explore further in the subsequent sections of this deep dive.

By now, I hope you would have understood the profound challenges in developing machine learning models:

- Version controlling data, code, and models is difficult.
- Tracking experiment configuration is quite challenging and manual.
- Effectively collaborating with team members and sharing results is tedious.

These challenges can be elegantly taken care of with standard MLOps practices, which help us build, train, deploy, and even automate various stages of our machine learning projects without much intervention.

One of the best tools in this respect is MLflow, which is entirely open-source.

As we will see ahead, MLflow provides plenty of functionalities that help machine learning teams effortlessly manage the **end-to-end** ML project lifecycle.

Being end-to-end means it includes everything we need to:

- Track experimentations
- Share code/model/data
- Reproduce results
- Deploy models
- Monitor performance
- Schedule updates, and more.

As of 24th Jan 2023, MLflow offers several key components.

**1) MLflow Tracking** for tracking experiments (code, data, model config, and results) and comparing them for model selection.

**2) MLflow Projects** for packaging code used in data science projects in a format that makes them reproducible on any platform.

**3) MLflow Models** for deploying machine learning models built in various libraries to diverse serving environments.

**4) MLflow Model Registry** for creating a dedicated system to manage, organize, version, and track ML models and their associated metadata.

As we will learn ahead, each of these components is specifically designed to solve a particular problem in a machine learning project lifecycle, which we also learned earlier.

Other than these four, MLflow recently added two more components:

**5) MLflow Deployments for LLMs** for streamlining the usage and management of various large language model (LLM) providers, such as OpenAI and Anthropic.

**6) MLflow LLM Evaluate** for evaluating LLMs and the prompts.

Don’t worry if you don’t understand them yet. We shall discuss them in detail in the upcoming sections.

MLflow is available on PyPI and can be installed as follows:
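The standard install command is:

```shell
pip install mlflow
```

As with any Python dependency, installing it inside a virtual environment keeps project dependencies isolated.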

Done!

Now, we can move to learning about the individual components of MLflow.

Machine learning model training is heavily driven by iterative experimentation. One may try various experiments by varying hyperparameters, testing new features, getting more data, using different algorithms, and many more.

During this iterative experimentation process, one inevitably ends with SO MANY combinations of experimental runs that it is almost impossible to remember how the best models were produced (*unless one is manually taking note of everything they did*).

In other words, tracing the best model to its exact configuration is quite challenging and tedious.

Why is it important, you may wonder?

Well, the ability to trace the best model to its exact configuration is crucial for several reasons, the primary reason being reproducibility.

Reproducibility is one of the critical aspects of building reliable machine learning. It is fundamental for collaboration among team members and for comparing models across different time points or environments.

Imagine this: something that works on one system but does not work on another reflects poor reproducibility practices.

Well-defined reproducibility practices, on the other hand, ensure that results can be replicated and validated by others, which improves the overall credibility of our work.

MLflow tracking enables reproducibility by recording and logging all the parameters, code versions, and dependencies used during model training.

This ensures that the entire experimentation process can be reconstructed precisely, allowing others to reproduce the results.

It provides an elegant UI, within which we can view the results of our model logging and compare them.

This UI can either be local or hosted on a remote server, which makes it quite easy to share experiments across different teams.

An experiment in MLflow is a named collection of runs, where each run represents a specific execution of a machine learning workflow or training process.

The concept of an experiment is designed to help organize and group together related runs, providing a structured way to manage and compare different iterations of your machine learning models.

When we start a new **experiment**, MLflow creates a dedicated space for it, allowing us to track and compare runs within that experiment easily.

And to elaborate further, a **run** is a single execution of a machine learning workflow within a specific experiment.

It encapsulates all the details of that particular execution, including the code, parameters, metrics, and artifacts produced during the run.

As we will see ahead, MLflow automatically logs various details of a run, such as the hyperparameters used, metrics calculated during training, and any output files or models generated.

The logged information is then accessible through the MLflow Tracking UI or programmatically through the MLflow API.

Before getting into the high-end technical details, let’s do a quick demo of setting up MLflow tracking in a simple data science project and understanding the MLflow UI.

What we shall be demonstrating ahead must be written in scripts and not in a Jupyter notebook.

For this demonstration, consider we have the following dummy classification dataset stored in a CSV:

Let’s assume our training and evaluation script is `train_model.py`. Let’s train a Random Forest model on this dataset.

First, we begin by specifying the imports:

Next, we specify some hyperparameters/configuration variables for our model:

As it is a classification task, let’s define a performance metric function (`performance`), which accepts the true labels and predicted labels as parameters and returns three performance metrics:

- F1 score
- Accuracy score
- Precision score

This is implemented below:
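A sketch of this function, leaning on sklearn's metric utilities (the original implementation may differ in detail):

```python
from sklearn.metrics import accuracy_score, f1_score, precision_score

def performance(y_true, y_pred):
    # Compute the three metrics from true vs. predicted labels
    f1 = f1_score(y_true, y_pred)
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    return f1, accuracy, precision
```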

Now, let’s proceed towards implementing the main guard of our model training script — `if __name__ == "__main__"`.

In the above code:

- We start by reading the CSV file.
- Moving on, we split the data into training and test sets based on the specified `test_size` parameter.
- Next, as the `data` DataFrame has all columns (including the `label`), we create it in `(X, y)` format for training the model.
- In this step, we instantiate a `RandomForestClassifier` model with the parameters specified earlier and train the model.
- After training, we generate the predictions on test data using the `predict()` method and evaluate them using the `performance()` method we defined earlier.
- Finally, we print the model details, such as the hyperparameters, `accuracy`, `F1`, and `precision`.

Running this script (with `python train_model.py`) gives the following output:

Moving on, in an attempt to improve the model’s performance, one may try changing the hyperparameters, as demonstrated below:

But tracking and comparing all the experiments this way quickly becomes a mess.

Next, let’s see how MLflow tracking can help us simplify model experiment tracking.

We begin by adding the necessary imports to our `train_model.py` script:

As we used sklearn in our experiment earlier, we imported the `mlflow.sklearn` module, which provides an API for logging and loading scikit-learn models.

As discussed earlier, MLflow uses the `experiment` and `run` hierarchy to track experiments.

We declare an experiment inside the main block using the `mlflow.set_experiment()` method as follows:

Our next objective is to declare a model `run`, which will be associated with the above experiment.

To do this, we create a context manager (using the `with` keyword) right before instantiating the `RandomForestClassifier` class.

Next, we indent the model training and performance metric code within that context manager, as demonstrated below:

Indenting the model training code inside the `mlflow.start_run()` context manager will allow us to record the model training and evaluation metadata.

We aren’t done yet, though.

Next, we must also specify the model training and evaluation metadata that we want MLflow to log.

This includes:

- Hyperparameters,
- Performance details,
- Data,
- The type of model trained,
- The trained model, etc.

We must log these details so that we can compare them with other experiment runs.

- To log hyperparameters, we can use the `mlflow.log_params()` method.
- To log performance metrics, we can use the `mlflow.log_metrics()` method.
- As we used sklearn, to log this model, we can use the `mlflow.sklearn.log_model()` method.

The `log_params()` and `log_metrics()` methods accept a dictionary as an argument, which maps parameter/metric names (strings) to their values, as demonstrated below:

💡

MLflow also provides `mlflow.log_param()` and `mlflow.log_metric()` methods (the name does not have an ‘s’ at the end). They accept two parameters: the first is a string, and the second is the metadata value.

In the `log_model()` method:

- The first argument is the scikit-learn model to be saved.
- The second argument is the relative artifact path (we shall look at this shortly).

With that, we have implemented a basic model tracking code.

That was easy, wasn’t it?

To recall, we did three things:

- First, we created an experiment using the `mlflow.set_experiment()` method. The name specified as a parameter to this method must be kept the same across all runs of the same experiment.
- Next, we created a context manager using `mlflow.start_run()` and indented our model training code within that context manager.
- Finally, we logged model metadata such as hyperparameters, metrics, and the model using the relevant logging methods.

Before we run this script, let’s check the directory structure:

Let’s execute the `train_model.py` script now:

If we check the directory structure now, we notice that MLflow creates a new folder, `mlruns`:

MLflow will store all our logging details in this directory.

💡

While the directory has been created locally, if needed, we can also specify the location of a remote server.
