1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import wandb
# Start a W&B run. Hyperparameters live in wandb.config so the data
# pipeline and optimizer below read them from a single source of truth.
run_config = {
    "architecture": "CNN",
    "epochs": 5,
    "batch_size": 64,
    "learning_rate": 0.001,
    "optimizer": "Adam",
}
wandb.init(project="mnist-demo", name="my-first-run", config=run_config)
# Standard MNIST preprocessing: convert to a tensor, then normalize with
# the dataset's channel mean (0.1307) and std (0.3081).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
# MNIST train/test splits; the train split is downloaded on first use.
train_dataset = datasets.MNIST(
    './data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
# Batch size comes from wandb.config; only the training stream is shuffled.
train_loader = DataLoader(
    train_dataset, batch_size=wandb.config.batch_size, shuffle=True
)
test_loader = DataLoader(
    test_dataset, batch_size=wandb.config.batch_size, shuffle=False
)
class SimpleCNN(nn.Module):
    """Two-conv-layer CNN for 28x28 single-channel MNIST digits.

    conv1/conv2 shrink 28x28 -> 26x26 -> 24x24; a 2x2 max-pool halves
    that to 12x12, so the flattened feature vector has
    64 * 12 * 12 = 9216 elements, matching fc1's declared input size.
    forward() returns raw class logits (N, 10) — no softmax — which is
    what the script's nn.CrossEntropyLoss expects.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        # NOTE(review): the original forward body was garbled in the
        # source. This reconstruction follows the canonical PyTorch MNIST
        # example and is pinned by the layer shapes declared in __init__
        # (flatten size 9216 == 64 * 12 * 12).
        x = torch.relu(self.conv1(x))   # (N, 32, 26, 26)
        x = torch.relu(self.conv2(x))   # (N, 64, 24, 24)
        x = torch.max_pool2d(x, 2)      # (N, 64, 12, 12)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)         # (N, 9216)
        x = torch.relu(self.fc1(x))
        x = self.dropout2(x)
        output = self.fc2(x)            # (N, 10) raw logits
        return output
# Model, loss, and optimizer. CrossEntropyLoss consumes the raw logits
# that SimpleCNN.forward produces.
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)

# Stream gradients and parameter histograms to W&B every 10 steps.
wandb.watch(model, log="all", log_freq=10)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(wandb.config.epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # Standard supervised train step (the original statements were
        # garbled in the source): forward, loss, backward, update.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Running statistics consumed by the batch-level log below.
        running_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()
        total += target.size(0)
        if batch_idx % 100 == 0:
            wandb.log({
                "batch_loss": loss.item(),
                "batch_accuracy": 100. * correct / total,
                "batch_idx": batch_idx,
            })
    # Snapshot train accuracy BEFORE `correct` is reused by the eval
    # pass. The original logged test `correct` over train `total`,
    # producing a meaningless train_accuracy.
    train_accuracy = 100. * correct / total

    # ---- evaluation pass ----
    model.eval()
    test_loss = 0
    correct = 0
    sample_images = None
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            if sample_images is None:
                # Keep a few test digits for visual inspection in W&B.
                sample_images = data[:8].cpu()
    test_loss /= len(test_loader)
    accuracy = 100. * correct / len(test_loader.dataset)

    wandb.log({
        "epoch": epoch + 1,
        "train_loss": running_loss / len(train_loader),
        "train_accuracy": train_accuracy,
        "test_loss": test_loss,
        "test_accuracy": accuracy,
    })
    # The original referenced an undefined name `images`; log the cached
    # test-batch samples instead.
    wandb.log({"sample_images": [wandb.Image(img) for img in sample_images]})
# Log a tiny hard-coded predictions table as a wandb.Table demo.
demo_rows = [[1, "cat", "cat"], [2, "dog", "cat"]]
table = wandb.Table(columns=["id", "prediction", "label"], data=demo_rows)
wandb.log({"predictions": table})
# Demo chart. The original wrapped a matplotlib figure in wandb.Plotly,
# which requires a *plotly* Figure object and raises on matplotlib input;
# wandb.Image accepts a matplotlib figure directly.
import matplotlib.pyplot as plt

plt.plot([1, 2, 3], [1, 4, 9])
wandb.log({"chart": wandb.Image(plt.gcf())})

# `epoch`, `running_loss`, and `accuracy` leak out of the training loop
# above, so this prints the final epoch's numbers.
print(f'Epoch {epoch+1}: Train Loss: {running_loss/len(train_loader):.4f}, '
      f'Test Acc: {accuracy:.2f}%')
# Persist the trained weights locally and upload the file to the W&B run.
checkpoint_path = "mnist_model.pth"
torch.save(model.state_dict(), checkpoint_path)
wandb.save(checkpoint_path)

# Mark the run as finished so W&B flushes all pending logs.
wandb.finish()
|