What is wandb (a.k.a. W&B)?

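W&B (Weights & Biases) is a platform for tracking and visualizing machine learning experiments. Put simply, it is a more polished TensorBoard. It looks better and is more interactive, and a few lines of code are enough for it to record your training runs, so you don't have to write any visualization code yourself.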

Core features include:

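  • 📊 Experiment tracking: records hyperparameters, loss curves, accuracy, and any other metrics you log
  • 🔍 Result comparison: compare several runs on the same chart
  • 💻 System monitoring: automatically records utilization, memory usage, temperature, and more for every GPU, with no extra configuration
  • 📁 Model management: save and version model checkpoints
  • 🤝 Team collaboration: share experiment results with your teammates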

Sign up and log in

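  1. Go to wandb.ai and sign up for a free account
  2. After logging in, get your API key (you can find it on the settings page)
  3. Run the login command in a terminal: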
```bash
wandb login
```

You will be prompted for your API key; just paste it in. Alternatively, you can set an environment variable:

```bash
export WANDB_API_KEY=<your-api-key>
```
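If a script may run on machines where nobody has run `wandb login`, it can help to fail early with a clear message when no key is available. The sketch below assumes the key is supplied only through the `WANDB_API_KEY` environment variable; the helper name `require_wandb_key` is hypothetical, and credentials saved by `wandb login` (by default in `~/.netrc`) are deliberately not checked.

```python
import os


def require_wandb_key(env=None):
    """Return the W&B API key from the environment, or raise a clear error.

    Hypothetical helper: it only looks at WANDB_API_KEY and ignores
    credentials saved by `wandb login`.
    """
    env = os.environ if env is None else env
    key = env.get("WANDB_API_KEY", "").strip()
    if not key:
        raise RuntimeError(
            "WANDB_API_KEY is not set; run `wandb login` or export the key first."
        )
    return key


if __name__ == "__main__":
    try:
        print("W&B key found, length:", len(require_wandb_key()))
    except RuntimeError as err:
        print(err)
```

Passing a dict as `env` makes the check easy to exercise without touching the real environment.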

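Add code to your training script

A few extra lines are enough for W&B to record your training run. The MNIST example below initializes a run, logs metrics, images, a table, and a chart, and uploads the trained model: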

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import wandb

# ========== 1. Initialize W&B ==========
wandb.init(
    project="mnist-demo",   # project name
    name="my-first-run",    # name of this run (optional)
    config={                # hyperparameters to record
        "architecture": "CNN",
        "epochs": 5,
        "batch_size": 64,
        "learning_rate": 0.001,
        "optimizer": "Adam",
    },
)

# ========== 2. Load the data ==========
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=wandb.config.batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=wandb.config.batch_size, shuffle=False)

# ========== 3. Define a simple CNN ==========
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)  # 64 channels * 12 * 12 after pooling
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))   # 1x28x28 -> 32x26x26
        x = F.relu(self.conv2(x))   # -> 64x24x24
        x = F.max_pool2d(x, 2)      # -> 64x12x12
        x = self.dropout1(x)
        x = torch.flatten(x, 1)     # -> 9216 features per sample
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        output = self.fc2(x)        # raw logits; CrossEntropyLoss applies log-softmax itself
        return output

model = SimpleCNN()
optimizer = optim.Adam(model.parameters(), lr=wandb.config.learning_rate)
criterion = nn.CrossEntropyLoss()

# Let W&B track histograms of model parameters and gradients (every 10 batches)
wandb.watch(model, log="all", log_freq=10)

# ========== 4. Training loop ==========
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(wandb.config.epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        # Forward pass, backward pass, parameter update
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate training statistics for this epoch
        running_loss += loss.item()
        correct += (output.argmax(dim=1) == target).sum().item()
        total += target.size(0)

        # Log every 100 batches (batch_accuracy is the running accuracy within this epoch)
        if batch_idx % 100 == 0:
            wandb.log({
                "batch_loss": loss.item(),
                "batch_accuracy": 100. * correct / total,
                "batch_idx": batch_idx,
            })

    train_accuracy = 100. * correct / total

    # Evaluate after each epoch, using a separate counter so the training accuracy is not overwritten
    model.eval()
    test_loss = 0.0
    test_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            test_correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader)
    accuracy = 100. * test_correct / len(test_loader.dataset)

    # ========== 5. Log metrics ==========
    wandb.log({
        "epoch": epoch + 1,
        "train_loss": running_loss / len(train_loader),
        "train_accuracy": train_accuracy,
        "test_loss": test_loss,
        "test_accuracy": accuracy,
    })

    # Log images: the first 8 samples of the last test batch, captioned with prediction and label
    images = data[:8].cpu()
    preds = output[:8].argmax(dim=1).cpu().tolist()
    labels = target[:8].cpu().tolist()
    wandb.log({"sample_images": [
        wandb.Image(img, caption=f"pred: {p}, label: {t}")
        for img, p, t in zip(images, preds, labels)
    ]})

    # Log a table of predictions vs. labels for the same samples
    table = wandb.Table(columns=["id", "prediction", "label"],
                        data=[[i, p, t] for i, (p, t) in enumerate(zip(preds, labels))])
    wandb.log({"predictions": table})

    # Log a matplotlib figure (as an image, so plotly is not required)
    fig, ax = plt.subplots()
    ax.plot([1, 2, 3], [1, 4, 9])
    wandb.log({"chart": wandb.Image(fig)})
    plt.close(fig)  # a fresh figure each epoch, so lines don't pile up

    print(f'Epoch {epoch+1}: Train Loss: {running_loss/len(train_loader):.4f}, '
          f'Test Acc: {accuracy:.2f}%')

# ========== 6. Save the model ==========
torch.save(model.state_dict(), "mnist_model.pth")
wandb.save("mnist_model.pth")  # upload the checkpoint file to this W&B run

# ========== 7. Finish the run ==========
wandb.finish()
```
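The `9216` in `nn.Linear(9216, 128)` follows from the layer shapes, assuming 28×28 MNIST inputs and a 2×2 max-pooling step after the second convolution (the only arrangement that makes the numbers work out). The pure-Python sketch below redoes that arithmetic with the standard output-size formula for a convolution or pooling layer, floor((size + 2·padding − kernel) / stride) + 1; the function names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d/MaxPool2d layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def fc1_in_features(image_size=28, channels=64):
    """Number of inputs fc1 needs for the CNN in the training script."""
    s = conv_out(image_size, kernel=3)   # conv1: 28 -> 26
    s = conv_out(s, kernel=3)            # conv2: 26 -> 24
    s = conv_out(s, kernel=2, stride=2)  # 2x2 max-pool: 24 -> 12
    return channels * s * s              # flatten: 64 * 12 * 12


print(fc1_in_features())  # 9216
```

If you change the input size, `fc1` has to change with it: for 32×32 images, `fc1_in_features(32)` gives 64 · 14 · 14 = 12544.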

Run the script and the terminal will print a link (something like https://wandb.ai/<your-username>/mnist-demo/runs/xxx). Open it to watch the training curves update live in your browser!
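The epoch-level accuracy curves on that page are only meaningful if correct predictions and sample counts are summed over the whole epoch, with separate counters for training and evaluation (reusing one `correct` variable for both would mix the two). Below is a minimal pure-Python sketch of that bookkeeping; `EpochAccuracy` is a hypothetical helper, not part of W&B.

```python
class EpochAccuracy:
    """Accumulates correct/total counts across the batches of one epoch."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, predictions, labels):
        # Count matches in this batch and add the batch size to the total
        self.correct += sum(int(p == y) for p, y in zip(predictions, labels))
        self.total += len(labels)

    def percent(self):
        return 100.0 * self.correct / self.total if self.total else 0.0


train_acc, test_acc = EpochAccuracy(), EpochAccuracy()
train_acc.update([1, 2, 3], [1, 2, 0])  # 2 of 3 correct
train_acc.update([4], [4])              # 1 of 1 correct
test_acc.update([7, 7], [7, 1])         # 1 of 2 correct
print(train_acc.percent(), test_acc.percent())  # 75.0 50.0
```

Averaging the two per-batch training accuracies (66.7% and 100%) would instead give 83.3%, overweighting the smaller batch; summing counts avoids that.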

Hyperparameter search (Sweeps)

W&B has built-in hyperparameter search that supports three methods: grid, random, and Bayesian.
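To make the difference between grid and random search concrete, here is a pure-Python sketch over a tiny, hypothetical search space. It only illustrates the idea (grid tries every combination, random draws a fixed number of combinations) and is not W&B's implementation; Bayesian search, which chooses the next configuration based on the results of earlier runs, is not shown.

```python
import itertools
import random

# Hypothetical search space with two discrete hyperparameters
SPACE = {
    "learning_rate": [0.1, 0.01, 0.001],
    "batch_size": [32, 64],
}


def grid_search(space):
    """Yield every combination of values (3 * 2 = 6 configs here)."""
    keys = list(space)
    for combo in itertools.product(*(space[k] for k in keys)):
        yield dict(zip(keys, combo))


def random_search(space, n_trials, seed=0):
    """Yield n_trials configs, each value drawn independently at random."""
    rng = random.Random(seed)
    for _ in range(n_trials):
        yield {k: rng.choice(v) for k, v in space.items()}


if __name__ == "__main__":
    print(len(list(grid_search(SPACE))))  # 6
    for cfg in random_search(SPACE, n_trials=3):
        print(cfg)
```

Grid search grows multiplicatively with every added parameter, which is why random search is often the more practical choice for larger spaces.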

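Running a sweep takes two commands: `wandb sweep` registers the search described in `sweep.yaml` and prints a sweep ID, and `wandb agent` then launches runs for that sweep:

```bash
# Register the sweep defined in sweep.yaml (prints the sweep ID)
wandb sweep sweep.yaml
# Start an agent that runs trials for that sweep
wandb agent your-entity/your-project/sweep-id
```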
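The commands above need a `sweep.yaml` describing the search. The sketch below writes one from a Python dict, under several assumptions: the script name `train.py`, the metric name `test_accuracy` (it must match a key the training script passes to `wandb.log`), and the parameter ranges are illustrative; `method` can be `grid`, `random`, or `bayes`. The `to_yaml` helper is deliberately minimal and hypothetical; it only handles nested dicts, flat lists, and plain scalars like those here, so use PyYAML's `yaml.safe_dump` for anything more complex.

```python
# Illustrative sweep configuration; names and ranges are assumptions
sweep_config = {
    "program": "train.py",  # hypothetical training script the agent runs
    "method": "random",     # or "grid" / "bayes"
    "metric": {"name": "test_accuracy", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"distribution": "log_uniform_values", "min": 0.0001, "max": 0.01},
        "batch_size": {"values": [32, 64, 128]},
        "epochs": {"value": 5},
    },
}


def to_yaml(obj, indent=0):
    """Minimal YAML writer for nested dicts, flat lists, and plain scalars."""
    pad = "  " * indent
    lines = []
    for key, value in obj.items():
        if isinstance(value, dict):
            lines.append(f"{pad}{key}:")
            lines.append(to_yaml(value, indent + 1))
        elif isinstance(value, list):
            lines.append(f"{pad}{key}: [{', '.join(str(v) for v in value)}]")
        else:
            lines.append(f"{pad}{key}: {value}")
    return "\n".join(lines)


if __name__ == "__main__":
    text = to_yaml(sweep_config)
    with open("sweep.yaml", "w", encoding="utf-8") as f:
        f.write(text + "\n")
    print(text)
```

From Python, `wandb.sweep(sweep=sweep_config, project="mnist-demo")` should accept the same dict and return a sweep ID for `wandb.agent`. When an agent launches a run, the values it picks appear in `wandb.config` and take precedence over the defaults passed to `wandb.init(config=...)`.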

Offline mode

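By default, data is stored both locally (in the ./wandb folder) and in the cloud.

If you don't want to sync to the cloud (for example, while testing), enable offline mode: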

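```python
import os

# Must be set before wandb.init(); "offline" is the current name for the older "dryrun" mode.
# Equivalent: wandb.init(mode="offline")
os.environ["WANDB_MODE"] = "offline"
```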