Fine-Tuning BERT for Text Classification (w/ Example Code)

28.01k views, 3686 words
Shaw Talebi
Get exclusive access to AI resources and project ideas: https://the-data-entrepreneurs.kit.com/shaw ...
Video Transcript:
Massive Transformer models like GPT-4o, Llama, and Claude are the current state of the art in AI. However, many of the problems we care about do not require a 100-billion-parameter model; there are countless problems we can solve with smaller language models. In this video, I'm going to share one example of this by fine-tuning BERT to classify phishing URLs. I'll start by covering some key concepts, then walk through example Python code. And if you're new here, welcome! I'm Shaw, and I make videos about data science and entrepreneurship. If you enjoy this content, please consider clicking the subscribe button; that's a great no-cost way you can support me in all the videos that I make.

The title of this video has a few pieces of jargon: fine-tuning, BERT, and text classification. Let's talk about each of these terms one by one, starting with fine-tuning. This is when we adapt a pre-trained model to a particular task through additional training. Using an analogy, we can think of the pre-trained model as a slab of marble that we can refine and chisel into a statue, which in this analogy would be our fine-tuned model.

Typically, the pre-trained model in this approach to machine learning is a self-supervised model, meaning that the labels used to train it aren't manually defined but are instead derived from the structure of the data itself. For example, with text data we can use the inherent structure of text to do word prediction. We could take a sentence like "adapting a pre-trained model to a particular task" and simply train the model to predict each next token in turn: "adapt" would be used to predict "ing", "adapting" would be used to predict "a", "adapting a" would be used to predict "pre", "adapting a pre" would be used to predict "train", and so on. The key upside of this is that humans do not need to annotate or label the training data. On the other hand, the fine-tuned model will typically be trained in a supervised way; in other words, the targets are manually defined by humans.

The reason we might want to develop a model this way is that, since the pre-trained model doesn't require any manual annotation of training data, we can train on much larger datasets; we no longer have that human bottleneck. For example, BERT, which is the model we're going to focus on in this video, was trained on about 3 billion words. In other words, we have billions of examples that this model can learn from. This is in contrast to a typical fine-tuning task, which might consist of a few thousand examples, so we're talking about a six-orders-of-magnitude difference in the number of examples.

Today, this is the prevailing paradigm for developing state-of-the-art machine learning models. There is an initial step of pre-training, which typically uses some kind of self-supervised learning and allows the model to learn a tremendous amount of knowledge from a massive training dataset. Then the pre-trained model can be refined to be a bit more helpful through fine-tuning. Another benefit of developing models like this is that it enables the democratization of AI. Developing massive pre-trained models like Llama, Mistral, Claude, or GPT-4o requires tremendous resources that no individual or typical research group could pull off. However, splitting the model development process into two steps, first pre-training and then fine-tuning, lets the pre-training be done by specialized research labs such as OpenAI, Google, Meta, Anthropic, Mistral, and so on, who can then make their pre-trained foundation models publicly available for fine-tuning.

Another nuance here is that we aren't restricted to just one iteration of fine-tuning. Really, any fine-tuned model can serve as the starting place for additional fine-tuning, so we might take this fine-tuned model and do some additional fine-tuning to make it even more specialized for a particular use case. This is exactly the story of ChatGPT, which was developed through three phases: it started with unsupervised pre-training, then there was a supervised fine-tuning step, and then there was a final step of reinforcement learning, which refined the model even further. That's all I'm going to say about fine-tuning in this video, but if you want to learn more, I do a deep dive in a previous video of this series.

The second piece of jargon in the title of this video is BERT. This paradigm of pre-training and fine-tuning got super popular when OpenAI released ChatGPT and shared that they built it through this three-step process of pre-training, fine-tuning, and then additional fine-tuning. Even so, the approach has been around since at least 2015, and one of the early models that really exploited this idea was Google's BERT, which was released in 2018. BERT is a language model developed specifically with fine-tuning in mind; in other words, the researchers at Google created BERT so that other researchers and individuals could repurpose the model by fine-tuning it on additional tasks.

To do this, BERT was trained on two tasks. The first task is masked language modeling. What that means is, if we have the sentence "the cat [blank] on the mat", masked language modeling consists of using the context before and after this masked (hidden) word in order to predict it. So this input sequence can be passed to BERT, and BERT will predict what the masked word is. This is in contrast to models like GPT, which don't do masked language modeling but rather causal language modeling, where the training task is next-word or next-token prediction. The benefit of masked language modeling versus causal language modeling is that the model can use context from both sides of the sentence: it can use both the text before and the text after the masked word to make a prediction. Intuitively, this additional context gives the model more information for making predictions.

But that's not the only thing BERT was trained on. The second task it was trained to do is next sentence prediction. What this looks like is: given two sentences, A and B (here sentence A is "BERT is conceptually simple and empirically powerful" and sentence B is "It obtains new state-of-the-art results on 11 NLP tasks"), BERT takes sentence pairs like this and outputs a label of IsNext or NotNext. This sentence pair is taken from the original BERT paper [1], and so this is indeed the next sentence. Alternatively, we could have a pair of sentences that are not a match: we might have "BERT is conceptually simple and empirically powerful" followed by "The cat sat on the mat", which is not the next sentence and would receive a different label. The intuition behind adding next sentence prediction on top of masked language modeling is that understanding the relationship between two sentences, like sentence A and sentence B, is important for downstream tasks like question answering or sentence similarity. Training BERT on these two different tasks allows it to be effectively fine-tuned on a wide range of tasks.

The last thing we're going to talk about before getting into the example code is text classification, which consists of assigning a label to text sequences. Actually, the next sentence prediction task we saw on the previous slide is an example of text classification, but there are countless other examples where text classification might be handy. One is spam detection: given an input sequence of text, such as an incoming email, we could develop a model to predict whether that email is spam or not spam. Similarly, if we have incoming IT tickets, we could develop a text classification model to take those text sequences and categorize the IT tickets into the appropriate tiers. Finally, one could use a text classification model to do sentiment analysis on customer reviews; in other words, to count the number of happy customers and the number of unhappy customers.

With all the jargon covered (fine-tuning, BERT, and text classification), let's see a concrete example. Here, I will fine-tune BERT to identify phishing URLs. This is similar to the example I shared in my fine-tuning video from last year. However, soon after I posted that video, someone pointed out an issue in that example that I had missed, and if we look at the training metrics, it's actually pretty easy to spot. The training loss is decreasing as expected. If we roughly look at the accuracy, it improves from the early epochs to the last epochs, but it's a bit shaky: the best-performing epoch was actually number three, and the accuracy seems to wiggle around for the remainder of training. The most obvious red flag, though, is the validation loss, which is monotonically increasing during training. That's the opposite of what we want to happen and a clear sign of overfitting. This wasn't something I had caught, so I'm glad someone pointed it out on the Medium article. Here, I'm going to do another example where we don't see this overfitting.

For this example, we'll be using Python and the Hugging Face ecosystem. We'll import the datasets library from Hugging Face, which gives us a data structure for loading our training, testing, and validation data. We'll import a few things from the Transformers library: the AutoTokenizer class, the AutoModelForSequenceClassification class, the TrainingArguments class, and the Trainer. We'll import the evaluate library from Hugging Face, which has some handy model evaluation functions. We'll also import NumPy, which will be used to compute some of these metrics. Then we'll import a data collator from the Transformers library, which we'll talk about in a little bit.

With our imports in place, we'll load our dataset. This is something I've made freely available on the Hugging Face Hub. It's a dataset of 3,000 examples: 70% are used for training, 15% for testing, and the final 15% for independent validation.

With our data loaded in, we can load the pre-trained model. Here we're going to be using Google's BERT, more specifically bert-base-uncased, which has about 110 million parameters. By today's standards that's tiny, but at one point this was a pretty big model. We'll load the tokenizer for this model, which takes in arbitrary text sequences and converts them into integer sequences based on what the BERT model is expecting. Finally, we can load the BERT model with a binary classification head on top of it. All we have to do is use the AutoModelForSequenceClassification class and its from_pretrained() method, passing in the model path defined earlier, the number of labels (that is, the number of classes), and a mapping between integer class IDs and label names: 0 corresponds to a safe URL and 1 corresponds to an unsafe URL.

With our model loaded and the classification head on top, let's set the trainable parameters. By default, when we load the model this way, all of its parameters are initialized as trainable, so all 110 million parameters plus the parameters in the binary classification head are ready to be trained. However, if you're just running this example on your laptop like I did, that's going to be pretty computationally expensive and potentially unnecessary. To help with that, we're going to freeze most of the parameters in the model, which is pretty easy to do. Here I have a for loop that goes through the base model and freezes all the base model parameters. To break this down a little: the model object has a base_model attribute, and that base model has a named_parameters() method, which returns tuples of each parameter's name and the parameter itself. We go through all the base model parameters and set their requires_grad property to False. After running this, all the base model parameters are frozen, and only the parameters in the classification head we added on top are trainable.

Another name for training a model like this is transfer learning, where we leave the base model parameters frozen and only train a classification head added on top. However, this might result in a pretty rigid model, because we can only refine the parameters in that classification head. So one thing we can do is unfreeze the base model's final layer, the pooling layer. One way to do this is to loop back through the base model parameters and, if we see the word "pooler" in a parameter's name, unfreeze that parameter. The result of all this is that every model parameter is frozen except the last four parameter tensors: the weights and biases of the pooling layer and of the classification head. This keeps the computational cost of fine-tuning down while still giving us a fair amount of flexibility. Of course, you can freeze or unfreeze any of the model parameters you like, and this would be a fun thing to experiment with on your own.

With our model ready to be trained, next we need to preprocess our data. This is going to be pretty simple. We'll define a function called preprocess_function, which takes in an arbitrary sequence of text and tokenizes it; that is, it translates a string of words into a sequence of integers according to the BERT tokenizer. Additionally, I add the truncation flag, which ensures that none of the input URLs are too long: by default, this truncates input sequences to the model's maximum length, which for BERT is 512 tokens. With that function defined, we can apply this preprocessing step to all the datasets in the dataset dictionary we loaded earlier. This tokenizes the training, testing, and validation datasets and returns them as our tokenized data.

Another thing we can initialize at this point is a data collator. We took care of input sequences that might be too long, but when we train a model, every example in a batch needs to be the same length. Even though we won't have any examples longer than 512 tokens, we will have examples shorter than that. To ensure all the samples in a batch have a uniform length, we create this data collator, which automatically pads them for us during training.

Okay, the final step before we actually train the model is to define evaluation metrics. These are the metrics that will be printed during the training process. Here I'll use two evaluation metrics, accuracy and AUC (area under the ROC curve), both loaded from the evaluate library. Then I'll define a function called compute_metrics, which computes the accuracy and AUC score for any set of examples. The input of this function is a tuple consisting of predictions and labels. The predictions are logits, unnormalized scores that can be any real number, while the labels are the ground truth, so they'll be either 0 or 1. To convert the logits into probabilities, meaning numbers between 0 and 1 that sum to 1, we apply the softmax function. This computes probabilities for both cases: the URL is safe and the URL is not safe. We only look at the probabilities for the URL being not safe and use them to compute the AUC score: we pass in these positive-class probabilities along with the ground truth and round the result to three decimal places. That gives us our AUC score.

Then we can predict the most probable class. Each prediction consists of two numbers: a logit corresponding to a safe URL and a logit corresponding to an unsafe URL. We just take the argmax, which returns which element is larger, and use that to compute the accuracy: we compare the predicted class with the ground truth and round that number to three decimal places. Finally, we return a dictionary consisting of the accuracy and AUC score. While this may have been a lot of hoops to jump through, it will be nice during training, because at each epoch our trainer will print out the accuracy and AUC score for us.

Now we're ready to start training the model. We'll define our training hyperparameters: a learning rate of 2 × 10⁻⁴, a batch size of 8, and 10 epochs. We'll put all of these into a TrainingArguments object. We set the output directory, which I named bert-phishing-classifier_teacher; we'll talk about the reason for "teacher" near the end. We set the learning rate defined above, the per-device training and evaluation batch size, which is just 8, the number of epochs, which is 10, and then the logging, evaluation, and save strategies. The logging strategy sets how often the trainer prints the training loss, and the evaluation strategy sets how often it prints the evaluation metrics we defined earlier, AUC and accuracy. With the save strategy, we can have the model saved at every epoch, so that if something goes wrong we can refer back to the latest save. Finally, we tell the trainer to load the best model at the end: if the tenth epoch isn't the best model, and maybe it was the eighth, we'll use that one instead of the last one.

With all the training arguments set, we are ready to fine-tune our model. We pass everything into our Trainer: our model, the training arguments from before, our training dataset, our evaluation dataset, the tokenizer, the data collator, and the compute_metrics function. Then we simply run trainer.train(). This took about 15 minutes to run on my laptop; I didn't use a GPU or anything. Here are the results. We can see the training loss is steadily decreasing. The validation loss is a little rocky at times, but it eventually goes down. Then we have our accuracy and AUC: the accuracy is, for the most part, increasing, and the AUC is steadily increasing. This is what we want to see. We don't want our validation loss to be monotonically increasing; we want it to decrease with training.

But we can go one step further. These metrics were computed on the training data and the testing data, so let's also look at the independent validation dataset. This validation data was not used for training the model parameters or for tuning the hyperparameters, so it gives us a fair evaluation of the model. To apply the model to the validation data, we generate predictions for it, extract the logits and the labels from the predictions, and pass these into the compute_metrics function defined earlier. Here are the results: the accuracy and AUC score are comparable to what we saw for the testing dataset on the previous slide, with accuracy about 0.89 and AUC about 0.95. If we go back, we see accuracy 0.87, AUC 0.
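Here is a minimal sketch of the self-supervised next-token setup from the fine-tuning discussion. Every prefix of a tokenized sentence becomes an input, and the token that follows it becomes the label, so no human annotation is needed. The token split below ("adapt" + "ing", "pre" + "train" + "ed") is a hypothetical toy tokenization chosen to mirror the spoken example. It is not the output of any real tokenizer, and next_token_pairs is an illustrative helper name.

```python
def next_token_pairs(tokens):
    """Build (context, target) pairs for causal (next-token) language modeling.

    Each prefix of the sequence is paired with the token that follows it,
    so the labels come from the text itself rather than from human annotators.
    """
    return [(tokens[:i], tokens[i]) for i in range(1, len(tokens))]


# Hypothetical toy tokenization (not a real tokenizer's output).
tokens = ["adapt", "ing", "a", "pre", "train", "ed", "model"]

for context, target in next_token_pairs(tokens):
    print(context, "->", target)
```

A sequence of n tokens yields n - 1 training examples, which is why unlabeled text scales to billions of examples so easily.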
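Masked language modeling, as described in the transcript, can be illustrated without a model. Hide one token, keep the context on both sides as the input, and use the hidden token as the label. The "[MASK]" string matches BERT's mask token. Two things here are simplifying assumptions: the whitespace tokenization and the choice of a single masked position. Real BERT pre-training instead masks a random subset of roughly 15% of tokens. The mask_token helper is hypothetical.

```python
MASK = "[MASK]"  # BERT's mask token string


def mask_token(tokens, position):
    """Return (masked_input, label) for one masked-language-modeling example.

    The model sees context on BOTH sides of the masked position, unlike
    causal language modeling, which only sees the tokens to the left.
    """
    masked = list(tokens)  # copy so the original token list is not modified
    label = masked[position]
    masked[position] = MASK
    return masked, label


tokens = "the cat sat on the mat".split()  # simple whitespace tokenization
masked_input, label = mask_token(tokens, 2)
print(" ".join(masked_input), "| label:", label)
# the cat [MASK] on the mat | label: sat
```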
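Next sentence prediction examples can be generated from raw text in the same self-supervised way. This sketch assumes a tiny in-memory corpus and a fixed random seed. Each sentence is paired with its true successor half the time (IsNext) and with some other sentence otherwise (NotNext). The 50/50 proportion follows the BERT paper, while make_nsp_pairs is a hypothetical helper, not part of any library.

```python
import random


def make_nsp_pairs(sentences, seed=0):
    """Build (sentence_a, sentence_b, label) examples for next sentence prediction.

    With probability 0.5, sentence_b is the sentence that actually follows
    sentence_a ("IsNext"); otherwise it is a random sentence that is neither
    sentence_a nor its true successor ("NotNext").
    """
    if len(sentences) < 3:
        raise ValueError("need at least 3 sentences to draw NotNext examples")
    rng = random.Random(seed)  # fixed seed so the sketch is reproducible
    pairs = []
    for i in range(len(sentences) - 1):
        sentence_a = sentences[i]
        if rng.random() < 0.5:
            pairs.append((sentence_a, sentences[i + 1], "IsNext"))
        else:
            # Any sentence except sentence_a itself and its true successor.
            candidates = [s for j, s in enumerate(sentences) if j not in (i, i + 1)]
            pairs.append((sentence_a, rng.choice(candidates), "NotNext"))
    return pairs


# Toy corpus built from the example sentences in the video.
corpus = [
    "BERT is conceptually simple and empirically powerful.",
    "It obtains new state-of-the-art results on 11 NLP tasks.",
    "The cat sat on the mat.",
]

for sentence_a, sentence_b, label in make_nsp_pairs(corpus):
    print(f"{label}: {sentence_a!r} -> {sentence_b!r}")
```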
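The dataset split described in the video (3,000 examples: 70% training, 15% testing, 15% validation) works out to 2,100 / 450 / 450 examples. The stdlib-only sketch below shows one way to produce such a split. It assumes a simple shuffled index split with a fixed seed, which is not necessarily how the published dataset was actually split. The split_indices name is hypothetical.

```python
import random


def split_indices(n, train_frac=0.70, test_frac=0.15, seed=42):
    """Shuffle range(n) and cut it into train / test / validation index lists.

    Whatever remains after the train and test slices (about 15% here)
    becomes the independent validation set.
    """
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    n_train = round(n * train_frac)
    n_test = round(n * test_frac)
    train = indices[:n_train]
    test = indices[n_train:n_train + n_test]
    validation = indices[n_train + n_test:]
    return train, test, validation


train, test, validation = split_indices(3000)
print(len(train), len(test), len(validation))  # 2100 450 450
```

The three index lists are disjoint, so no example used for training or hyperparameter selection can leak into the validation evaluation.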
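The freeze-then-unfreeze step in the transcript comes down to string matching on parameter names. In the video this runs over model.base_model.named_parameters() and sets requires_grad. The sketch below reproduces only the selection rule in plain Python so it can run without PyTorch. The parameter names are an illustrative subset in the style of Hugging Face's BERT implementation, not an exhaustive list, and trainable_names is a hypothetical helper.

```python
def trainable_names(base_param_names, head_param_names, unfreeze_keyword="pooler"):
    """Return the set of parameter names left trainable.

    Rule described in the video: freeze every base-model parameter, then
    unfreeze any whose name contains `unfreeze_keyword`. The classification
    head is not part of the base model, so its parameters stay trainable.
    """
    trainable = {name for name in base_param_names if unfreeze_keyword in name}
    trainable.update(head_param_names)
    return trainable


# Illustrative subset of BERT-style parameter names (the real model has many more).
base_names = [
    "embeddings.word_embeddings.weight",
    "encoder.layer.0.attention.self.query.weight",
    "encoder.layer.11.output.dense.bias",
    "pooler.dense.weight",
    "pooler.dense.bias",
]
head_names = ["classifier.weight", "classifier.bias"]

print(sorted(trainable_names(base_names, head_names)))
# ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']

# With PyTorch and a loaded Hugging Face model, the same rule would be applied as:
# for name, param in model.base_model.named_parameters():
#     param.requires_grad = "pooler" in name
```

The four surviving names are the "last four" trainable tensors mentioned in the transcript: the weights and biases of the pooling layer and the classification head.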
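Truncation and batch padding can be sketched on plain lists of token IDs. Truncation caps every sequence at a maximum length (512 for BERT). A padding collator pads each batch only to its longest member and builds an attention mask so the model ignores the padding. This approximates what the tokenizer's truncation flag and a padding data collator do. The pad ID of 0 matches the [PAD] entry in bert-base-uncased's vocabulary, but the token IDs themselves are made up. The real tokenizer also reserves room for special tokens such as [CLS] and [SEP], which this sketch ignores.

```python
PAD_ID = 0  # [PAD] token ID in bert-base-uncased's vocabulary


def truncate(ids, max_length=512):
    """Cap a token-ID sequence at max_length (special tokens ignored here)."""
    return ids[:max_length]


def pad_batch(batch, pad_id=PAD_ID):
    """Pad every sequence to the length of the longest one in the batch.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens
    and 0 for padding, so the model can ignore the padded positions.
    """
    longest = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch]
    return input_ids, attention_mask


# Made-up token IDs for three tokenized URLs; max_length=6 only keeps the demo small.
raw = ([7, 8, 9], [4, 5, 6, 7, 8, 9, 10, 11], [3])
batch = [truncate(ids, max_length=6) for ids in raw]
input_ids, attention_mask = pad_batch(batch)
print(input_ids)       # [[7, 8, 9, 0, 0, 0], [4, 5, 6, 7, 8, 9], [3, 0, 0, 0, 0, 0]]
print(attention_mask)  # [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 1], [1, 0, 0, 0, 0, 0]]
```

Padding per batch, rather than always to 512, keeps batches of short URLs small, which matters when training on a laptop CPU.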
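The compute_metrics logic described in the video can be sketched with NumPy alone. The steps are: softmax over the logits, positive-class probability for AUC, argmax for accuracy, and rounding both to three decimals. Note that the logits are unbounded real numbers; softmax is what maps them into probabilities. The video loads accuracy and ROC AUC from the evaluate library. Here, as a swapped-in assumption so the example runs offline, ROC AUC is computed directly from its definition: the probability that a random positive outscores a random negative, with ties counted as half. The logits, labels, and dictionary keys are made up.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)


def roc_auc(scores, labels):
    """ROC AUC = P(random positive scores higher than random negative); ties count 0.5."""
    pos = scores[labels == 1]
    neg = scores[labels == 0]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return (greater + 0.5 * ties) / (len(pos) * len(neg))


def compute_metrics(eval_pred):
    """Return accuracy and AUC from a (logits, labels) tuple, rounded to 3 decimals."""
    logits, labels = eval_pred
    probabilities = softmax(logits)
    positive_class_probs = probabilities[:, 1]  # P(URL is unsafe)
    auc = float(np.round(roc_auc(positive_class_probs, labels), 3))
    predicted_classes = np.argmax(logits, axis=1)  # index of the larger logit
    acc = float(np.round(np.mean(predicted_classes == labels), 3))
    return {"accuracy": acc, "auc": auc}


# Made-up logits for four URLs: column 0 = safe, column 1 = unsafe.
logits = np.array([[2.0, -1.0], [0.5, 1.5], [-0.3, 0.2], [1.0, 0.8]])
labels = np.array([0, 1, 0, 1])
print(compute_metrics((logits, labels)))  # {'accuracy': 0.5, 'auc': 0.75}
```

For the independent validation set, the video passes the logits and labels from the trainer's predictions into this same kind of function. With Hugging Face's Trainer, trainer.predict(...) returns an object whose predictions and label_ids fields hold those arrays.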
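The transcript discusses two validation-loss patterns: a monotonically increasing loss as a red flag for overfitting, and keeping the best epoch rather than the last one. Both can be expressed as small checks. This is a plain-Python sketch with made-up loss values, and the helper names are hypothetical. In the Trainer, the second behavior is configured through TrainingArguments with load_best_model_at_end. When no metric_for_best_model is given, that option selects the checkpoint with the lowest evaluation loss.

```python
def is_monotonically_increasing(values):
    """True if every value is strictly larger than the one before it."""
    return all(later > earlier for earlier, later in zip(values, values[1:]))


def best_epoch(val_losses):
    """1-based epoch with the lowest validation loss (ties go to the earliest)."""
    return min(range(len(val_losses)), key=lambda i: val_losses[i]) + 1


# Made-up per-epoch validation losses, for illustration only.
overfitting_run = [0.41, 0.44, 0.49, 0.55, 0.62]
healthy_run = [0.45, 0.38, 0.40, 0.33, 0.31, 0.32]

print(is_monotonically_increasing(overfitting_run))  # True: red flag
print(is_monotonically_increasing(healthy_run))      # False
print(best_epoch(healthy_run))                       # 5
```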