Get Started in Deep Learning with PyTorch 2.8.0
🤫Before we get started:
To verify your PyTorch installation, open a Jupyter notebook (or a Python interpreter) and run the following code:

```python
import torch

x = torch.rand(5, 3)
print(x)
```

The output should be something similar to:
```
tensor([[0.3380, 0.3845, 0.3217],
        [0.8337, 0.9050, 0.2650],
        [0.2979, 0.7141, 0.9069],
        [0.1449, 0.1132, 0.1375],
        [0.4675, 0.3947, 0.1426]])
```

Let's Get Started!
This tutorial is going to guide you through a complete MNIST handwritten digit recognition project using PyTorch, from data loading to model training and testing🤗
1. Environment Setup
First, ensure you have the necessary Python packages installed:
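If you don't have them yet, a typical install looks like this (package names are the usual ones for this kind of project; adjust for your CUDA setup):

```shell
pip install torch torchvision matplotlib
```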
2. Import Libraries
3. Data Loading and Preprocessing
The MNIST dataset contains 60,000 training images and 10,000 test images. Each image is a 28x28 pixel grayscale handwritten digit (0-9).
So, for starters, let's clarify the concepts of "training set" and "test set". Usually, we divide a classification dataset into two parts: a training set to help our machine learn the hidden patterns, and a test set to check whether it actually "learns" instead of simply memorizing the data - a failure mode machine learning researchers call "overfitting".
In our MNIST project:
🔧Training set: 60,000 handwritten digit images - the model learns from these
📝Test set: 10,000 handwritten digit images - the model has never seen these before
The golden rule: Never let your model peek at the test set during training! Otherwise, you're essentially letting a student see the exam questions while studying - the grades won't reflect true understanding.
What's happening here? The transforms.ToTensor() does two things:
1️⃣Converts the image from a PIL Image or NumPy array into a PyTorch tensor (the format, or "language", PyTorch understands).
2️⃣Automatically scales pixel values from [0, 255] to [0, 1] - this normalization helps the model train more effectively.
✔️Why do we use DataLoader?
Instead of feeding all 60,000 images at once (which would overwhelm your computer's memory), we use batches! A batch is just a small group of images processed together. Here we use batches of 32 images - think of it as studying 32 flashcards at a time instead of trying to memorize all 60,000 at once.
❓Also, think about this question: Why shuffle=True for training but False for testing?
Shuffling the training data prevents the model from learning the order of examples rather than the actual patterns. For testing, order doesn't matter since we're just evaluating - no learning happens.
4. Data Exploration
Before training, it's important to understand the structure and content of the data. As the old saying goes: "garbage in, garbage out" - understanding your data is crucial for successful machine learning.
This visualization helps you catch potential issues early - are the images rotated correctly? Are they actually digits? Is the quality good enough?
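The exploration code might look like this (a sketch assuming the `train_loader` from step 3; showing the first 8 images of a batch is an arbitrary choice):

```python
import matplotlib.pyplot as plt

images, labels = next(iter(train_loader))   # grab one batch of 32
print(images.shape)                         # torch.Size([32, 1, 28, 28])
print(labels.shape)                         # torch.Size([32])

fig, axes = plt.subplots(1, 8, figsize=(12, 2))
for ax, img, label in zip(axes, images, labels):
    ax.imshow(img.squeeze(), cmap='gray')   # drop the channel dim for plotting
    ax.set_title(int(label))
    ax.axis('off')
plt.show()
```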
Expected Output: torch.Size([32, 1, 28, 28]) for the image batch and torch.Size([32]) for the labels.
Wow, wait!
Are these numbers and brackets making your head spin?
Here's the explanation:
1️⃣32: Batch size - we're processing 32 images at once
2️⃣1: Number of channels - grayscale images have 1 channel (RGB images would have 3)
3️⃣28, 28: Height and width of each image in pixels
4️⃣Labels are just a 1D array of 32 numbers, each indicating which digit (0-9) the corresponding image represents
5. Build CNN Model
Now comes the fun part - building our neural network! We'll use a Convolutional Neural Network (CNN), which is particularly good at recognizing visual patterns.
Why CNN for images? Traditional neural networks treat each pixel independently, but CNNs are smart enough to recognize that nearby pixels form patterns (like edges, curves, and eventually whole digits). It's like how you recognize a face by seeing eyes, nose, and mouth in specific spatial arrangements, not just as random dots.
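A model matching the description below might look like this (a sketch: the class name `CNN` and the hidden size of 128 are assumptions; `self.conv`, `self.d1`, `self.d2`, the 3x3 kernel, and the 32 filters follow the explanation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 32, kernel_size=3)   # 1 input channel -> 32 filters
        self.d1 = nn.Linear(26 * 26 * 32, 128)        # 21,632 features -> 128
        self.d2 = nn.Linear(128, 10)                  # 128 -> 10 digit classes

    def forward(self, x):
        x = F.relu(self.conv(x))         # (batch, 32, 26, 26)
        x = x.view(x.size(0), -1)        # flatten to (batch, 21632)
        x = F.relu(self.d1(x))
        logits = self.d2(x)
        return F.softmax(logits, dim=1)  # probabilities over the 10 digits

model = CNN()
```

One caveat: `nn.CrossEntropyLoss` (used in the training step) expects raw logits and applies log-softmax internally, so in practice you would often return `logits` from `forward` and apply softmax only when you want probabilities to display.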
Let's break down what each component does:
Convolutional Layer (self.conv): This is like a pattern detector. It slides a 3x3 window across the image, looking for 32 different patterns (edges, curves, corners, etc.). Each of these 32 "filters" learns to detect a different feature.

ReLU Activation: Stands for "Rectified Linear Unit" - it's a simple function that helps the network learn non-linear patterns. Without it, the network could only learn straight-line relationships, which isn't useful for complex images.

Flatten (view): After convolution, we have a 3D structure (32 feature maps of size 26x26). We need to flatten this into a 1D array before feeding it to regular neural network layers.

Fully Connected Layers (self.d1, self.d2): These layers combine all the patterns detected by the convolutional layer to make the final decision about which digit it is.

Softmax: Converts the final layer's outputs into probabilities that sum to 1. For example: [0.1, 0.05, 0.7, 0.05, ...] means 70% confidence it's a "2", 10% confidence it's a "0", etc.
Where does 26×26 come from?
The convolutional layer shrinks the image slightly. Here's the formula:

H_out = (H_in - Kernel + 2 × Padding) / Stride + 1
In our case:
Input: 28 × 28
Kernel: 3 × 3
Stride (default): 1
Padding (default): 0
Calculation: H_out = (28 - 3) / 1 + 1 = 26
So each of our 32 feature maps is 26 × 26 pixels, giving us 26 × 26 × 32 = 21,632 features to feed into the first fully connected layer.
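6. Check the Model Output

You can verify the shape arithmetic with a standalone sanity check (fresh layers and random input; the real model behaves identically shape-wise):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv2d(1, 32, kernel_size=3)   # same shape as self.conv
x = torch.randn(32, 1, 28, 28)           # a fake batch of 32 "images"
feat = F.relu(conv(x))
print(feat.shape)                        # torch.Size([32, 32, 26, 26])

flat = feat.view(feat.size(0), -1)       # flatten the feature maps
print(flat.shape)                        # torch.Size([32, 21632])

d1 = nn.Linear(26 * 26 * 32, 128)        # same shapes as self.d1, self.d2
d2 = nn.Linear(128, 10)
out = F.softmax(d2(F.relu(d1(flat))), dim=1)
print(out.shape)                         # torch.Size([32, 10])
```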
Expected output: torch.Size([32, 10]) - for each of the 32 images in the batch, we get 10 probabilities (one for each digit 0-9).
7. Train the Model
Training is where the magic happens - this is where the model actually learns from the data!
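Matching the description below, the setup might look like this (a sketch assuming the `model` from step 5; `lr=0.001` as discussed):

```python
import torch.nn as nn
import torch.optim as optim

criterion = nn.CrossEntropyLoss()                      # measures how wrong we are
optimizer = optim.Adam(model.parameters(), lr=0.001)   # decides how to get less wrong
```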
What are these components?
Loss Function (CrossEntropyLoss): Measures how wrong the model's predictions are. Lower loss = better predictions. It's like a grading system that tells the model how badly it messed up.
Optimizer (Adam): Decides how to adjust the model's weights to reduce the loss. Think of it as a GPS that guides the model toward better performance. Adam is popular because it adapts the learning speed automatically - taking bigger steps when far from the goal, smaller steps when getting close.
Learning Rate (lr=0.001): Controls how big each adjustment step is. Too large, and the model overshoots the optimal solution; too small, and training takes forever.
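Putting these pieces together, the training loop might look like this (a sketch assuming `model`, `train_loader`, `criterion`, and `optimizer` from the previous steps; 5 epochs is an arbitrary choice):

```python
for epoch in range(5):
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in train_loader:
        optimizer.zero_grad()               # reset gradients from the last batch
        outputs = model(images)             # 1. forward pass
        loss = criterion(outputs, labels)   # 2. calculate loss
        loss.backward()                     # 3. backward pass
        optimizer.step()                    # 4. update weights

        running_loss += loss.item()
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    print(f"Epoch {epoch + 1}: loss = {running_loss / len(train_loader):.4f}, "
          f"accuracy = {100 * correct / total:.2f}%")
```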
Understanding the training loop:
An epoch is one complete pass through the entire training dataset. We typically need multiple epochs because the model learns gradually - like reading a textbook multiple times to fully understand it.
For each batch of images:
Forward Pass: Feed images through the model to get predictions
Calculate Loss: Compare predictions to true labels to see how wrong we are
Backward Pass: Calculate how each weight contributed to the error (using calculus!)
Update Weights: Adjust weights to reduce the error
Why optimizer.zero_grad()? PyTorch accumulates gradients by default. If we don't reset them to zero, gradients from previous batches would interfere with the current batch - like trying to navigate with directions from your last trip still on the GPS.
Expected Output:
Notice how the loss decreases and accuracy increases with each epoch - the model is learning! The improvements get smaller over time because the easy patterns are learned first.
8. Test the Model
Now for the moment of truth - let's see how well our model performs on data it has never seen before!
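A sketch of the evaluation loop (assumes the `model` and `test_loader` from the earlier steps):

```python
model.eval()                      # switch to evaluation mode
correct, total = 0, 0
with torch.no_grad():             # no gradients needed: we're not learning here
    for images, labels in test_loader:
        outputs = model(images)                  # forward pass only
        predicted = outputs.argmax(dim=1)        # most probable digit per image
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Test accuracy: {100 * correct / total:.2f}%")
```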
Key differences from training:
model.eval(): Tells the model we're evaluating, not training. Some layers (like Dropout, which we don't have here) behave differently during evaluation.

torch.no_grad(): Disables gradient calculation. Since we're not updating weights during testing, we don't need gradients - this saves memory and speeds things up.

No backward pass: We only do forward propagation to get predictions. No learning happens here!
Expected Output:
How to interpret the results:
98.45% test accuracy: Our model correctly identifies about 98 out of every 100 handwritten digits it's never seen before. Pretty impressive!
Compare with training accuracy (98.91%): The test accuracy is slightly lower than training accuracy, which is normal and expected. A small gap like this indicates healthy generalization.
Red flag scenarios:
Training: 99%, Test: 65% → Severe overfitting (the model memorized instead of learned)
Training: 65%, Test: 64% → Underfitting (the model didn't learn enough)
Training: 98%, Test: 98.5% → Suspicious (test shouldn't be better; might indicate data leakage)
Happy learning! 🎉