Weights & Biases (wandb) 介绍与使用指南

1. 什么是 wandb?

Weights & Biases (wandb) 是一个用于机器学习实验跟踪、模型可视化和协作的工具。它帮助研究人员和工程师:

  • 记录实验的超参数、指标、代码版本、硬件信息等
  • 实时可视化训练过程(损失、准确率等)
  • 对比不同实验的结果
  • 共享和协作分析实验
  • 自动化超参数搜索(sweep)

wandb 提供云端面板(dashboard)和本地存储两种方式,支持 PyTorch、TensorFlow、Keras、JAX 等主流框架,并可与各种环境(本地、Colab、集群)无缝集成。

2. 为什么选择 wandb?

  • 轻量级:只需几行代码即可开始跟踪
  • 可视化强大:自动生成图表、表格、图像、媒体等
  • 协作便捷:团队共享实验链接,在线讨论
  • 与 ML 生态深度集成:支持 Hugging Face、PyTorch Lightning、Keras 等
  • 免费学术版:个人和小团队有充足的免费额度

3. 安装与初始化

安装

1
pip install wandb

登录

首次使用时需要登录(注册账号):

1
wandb login

或者直接在代码中登录:

1
2
import wandb
wandb.login(key='你的API key')

初始化实验

使用 wandb.init() 创建一个新的运行(run),可以指定项目名称、运行名称、实体等。

1
2
3
4
5
6
7
8
9
10
run = wandb.init(
project="my-project", # 项目名称
name="experiment-1", # 运行名称(可选)
entity="username", # 用户名或团队名(可选)
config={ # 超参数配置
"learning_rate": 0.001,
"batch_size": 32,
"epochs": 100,
}
)

注意wandb.init() 会返回一个 run 对象,也可以全局使用 wandb.log 等函数,只要已经初始化。

4. 基础功能

4.1 记录指标

使用 wandb.log() 记录标量指标(如损失、准确率),支持自动聚合和可视化。

1
2
3
4
5
6
7
8
9
10
for epoch in range(epochs):
# 训练代码
train_loss = ...
val_acc = ...

wandb.log({
"train/loss": train_loss,
"val/accuracy": val_acc,
"epoch": epoch
})
  • 支持字典形式,键可以是任意字符串,可以用 / 分组
  • 默认每个 step 记录一次,可以手动指定 step(如 wandb.log({"loss": loss}, step=global_step)
  • wandb 会自动绘制曲线图、直方图等

4.2 超参数配置

通过 wandb.config 存储超参数,方便对比。

1
2
3
config = wandb.config
config.learning_rate = 0.001
config.batch_size = 32

也可以在 wandb.init 时通过 config 参数传入字典。配置会显示在面板的 “Config” 部分。

4.3 结束运行

1
wandb.finish()   # 确保所有数据上传完毕

通常脚本结束后会自动调用,但显式调用可以避免意外。

5. 常用高级功能

5.1 记录图像、表格、媒体

wandb 支持记录多种媒体类型,便于可视化模型输出。

1
2
3
4
5
6
7
8
9
10
# 图像
wandb.log({"samples": [wandb.Image(img) for img in images]})

# 表格
table = wandb.Table(columns=["id", "pred", "label"])
table.add_data(1, 0.9, 1)
wandb.log({"predictions": table})

# 音频、视频、3D 点云等
wandb.log({"audio": wandb.Audio(audio_array, sample_rate=16000)})

5.2 模型保存与版本管理

使用 wandb.save() 将文件上传到云端,可保存模型权重、日志等。

1
2
3
4
5
6
# 保存模型
torch.save(model.state_dict(), "model.pth")
wandb.save("model.pth")

# 或保存整个文件夹
wandb.save("logs/*")

也可以在 wandb.init 中设置 save_code=True 自动保存当前脚本。

5.3 自定义可视化(图表)

通过 wandb.plot 可以创建自定义图表,如混淆矩阵、PR 曲线等。

1
2
3
4
5
6
7
8
9
10
11
# 混淆矩阵
wandb.log({"conf_mat": wandb.plot.confusion_matrix(
probs=preds,
y_true=labels,
class_names=["cat", "dog"]
)})

# 自定义折线图
data = [[x, y] for x, y in zip(xs, ys)]
table = wandb.Table(data=data, columns=["x", "y"])
wandb.log({"my_plot": wandb.plot.line(table, "x", "y", title="My Plot")})

5.4 系统指标监控

wandb 自动记录 GPU 使用率、CPU 使用率、内存占用等系统信息,无需额外代码。

5.5 项目面板(Dashboard)

登录 wandb 网站后,可以看到所有实验的运行列表,可以按参数筛选、对比多个运行的指标曲线,还可以添加注释、标签等。

6. 进阶功能:超参数搜索 (Sweep)

wandb sweep 提供自动化超参数优化功能,支持随机搜索、网格搜索、贝叶斯优化等。

6.1 定义 sweep 配置

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# sweep.yaml
program: train.py
method: bayes
metric:
name: val_acc
goal: maximize
parameters:
learning_rate:
min: 0.0001
max: 0.1
batch_size:
values: [16, 32, 64]
dropout:
values: [0.2, 0.3, 0.5]

6.2 初始化 sweep

1
wandb sweep sweep.yaml

会输出一个 sweep_id。

6.3 启动 agent

1
wandb agent username/project/sweep_id

agent 会自动运行多个实验,每次使用一组超参数,并将结果记录到 wandb。

也可以在代码中创建 sweep:

1
2
3
4
5
6
7
8
9
10
sweep_config = {
"method": "bayes",
"metric": {"name": "val_acc", "goal": "maximize"},
"parameters": {
"learning_rate": {"min": 0.0001, "max": 0.1},
"batch_size": {"values": [16, 32, 64]},
}
}
sweep_id = wandb.sweep(sweep_config, project="my-project")
wandb.agent(sweep_id, function=train_function)

其中 train_function 是训练脚本的入口函数,它会自动获取 wandb.config 中的超参数。

7. 与 PyTorch 集成的最佳实践

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import wandb
import torch
from torch import nn, optim

# 初始化 wandb
wandb.init(project="my-project", config={
"lr": 0.001,
"batch_size": 32,
"epochs": 10
})
config = wandb.config

# 定义模型、数据等
model = ...
optimizer = optim.Adam(model.parameters(), lr=config.lr)

# 训练循环
for epoch in range(config.epochs):
# 训练
train_loss = ...
# 验证
val_acc = ...

# 记录指标
wandb.log({
"train/loss": train_loss,
"val/accuracy": val_acc,
"epoch": epoch
})

# 保存模型
torch.save(model.state_dict(), "model.pth")
wandb.save("model.pth")

wandb.finish()

8. 最佳实践与技巧

  • 分组命名:使用 / 组织指标,如 train/lossval/loss,面板会自动分组。
  • 使用 wandb.init(reinit=True):在多个实验循环中重复初始化(如 Jupyter Notebook 中)。
  • 合理设置 step:如果记录非连续 step,手动传入 step 以保证图表对齐。
  • 记录模型权重直方图:使用 wandb.log({"gradients": wandb.Histogram(grad)}) 了解参数分布。
  • 使用 wandb.define_metric:自定义 x 轴(如 epoch、iteration),更灵活。
  • 利用 Tags:为运行添加标签(如 baseline, attention),便于筛选。
  • 版本控制集成:wandb 会记录 Git 提交信息,方便复现。