Training a simple MNIST model on YottaLabs with SkyPilot
Step 1 — Create working directory and training script
mkdir ./test
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Training hyperparameters.
batch_size = 128  # samples per mini-batch handed to the data loaders
epochs = 3  # number of passes over the training set
learning_rate = 1e-3  # step size given to the optimizer

# Run on the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")
# Preprocessing: convert each image to a tensor, then normalize its single
# channel with mean 0.1307 and standard deviation 0.3081.
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize((0.1307,), (0.3081,))
transform = transforms.Compose([to_tensor, normalize])

# Both MNIST splits live under ./data, are requested with download=True,
# and share the same preprocessing.
mnist_kwargs = dict(root="./data", download=True, transform=transform)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

# Only the training batches are shuffled; the batch size and worker count
# are the same for both loaders.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
# Fully connected classifier: flatten the input to 28 * 28 features, apply
# two hidden Linear + ReLU stages (256 then 128 units), and finish with a
# 10-way output layer.
layers = [
    nn.Flatten(),
    nn.Linear(28 * 28, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
]
model = nn.Sequential(*layers)
model = model.to(device)

# Cross-entropy loss on the model outputs, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Train for `epochs` passes over the training set. After each pass, evaluate
# on the test set and print that epoch's average training loss and test
# accuracy.
for epoch in range(epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Clear gradients left over from the previous step before
        # computing new ones.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss averages over the batch by default, so weight the
        # value by the batch size to accumulate a per-sample sum.
        running_loss += loss.item() * images.size(0)

    avg_loss = running_loss / len(train_loader.dataset)

    # ---- Evaluation pass ----
    model.eval()
    correct = 0
    total = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # The predicted class is the index of the largest output score.
            predictions = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()

    accuracy = 100.0 * correct / total
    print(f"Epoch {epoch + 1}/{epochs} - loss: {avg_loss:.4f} - test accuracy: {accuracy:.2f}%")
# Make sure the output directory exists, then write the model's state dict
# (not the whole module object) to the checkpoint file.
output_dir = "outputs"
os.makedirs(output_dir, exist_ok=True)
state = model.state_dict()
torch.save(state, "outputs/mnist_model.pt")
print("Saved checkpoint to outputs/mnist_model.pt")
Step 2 — Define the SkyPilot task
Step 3 — Launch the training job

Step 4 — Understand what this job is doing
Step 5 — Save and inspect the result
Step 6 — Clean up the cluster
Last updated
Was this helpful?