# Hydra and OmegaConf: A Beginner's Tutorial
## 1. Why Do We Need Hydra?

### The Problem

You are training a deep learning model, and every time you tweak a parameter you have to:

- edit the code by hand,
- remember a pile of command-line arguments,
- and keep the configurations of different experiments from turning into a mess.

### Hydra's Solution

Put all parameters in YAML configuration files. You can change any of them at run time, and Hydra automatically saves the configuration of every run.
## 2. Up and Running in 5 Minutes

### 1. Installation

```shell
pip install hydra-core omegaconf
```

### 2. The Simplest Example

Create `config.yaml`:

```yaml
model: resnet18
batch_size: 32
learning_rate: 0.001
epochs: 10
```

Create `train.py`:

```python
import hydra
from omegaconf import DictConfig

@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig):
    print(f"model: {cfg.model}")
    print(f"batch size: {cfg.batch_size}")
    print(f"learning rate: {cfg.learning_rate}")
    print(f"epochs: {cfg.epochs}")

if __name__ == "__main__":
    main()
```

Run it:

```shell
python train.py
```

Output:

```text
model: resnet18
batch size: 32
learning rate: 0.001
epochs: 10
```

That's it! You now know the basics of Hydra.
## 3. OmegaConf Basics

OmegaConf is the library Hydra uses to read and manipulate configuration; the `cfg` object you receive is an OmegaConf `DictConfig`.

### 1. Accessing Values (Two Styles)

```python
# Attribute style
cfg.model
cfg.batch_size

# Dictionary style
cfg["model"]
cfg["train"]["learning_rate"]
```

### 2. Checking Whether a Key Exists

```python
if "model" in cfg:
    print(cfg.model)

# Fall back to a default when the key is missing
model_name = cfg.get("model", "default_model")
```

### 3. Printing the Full Configuration

```python
from omegaconf import OmegaConf

# Pretty YAML dump
print(OmegaConf.to_yaml(cfg))

# Raw repr
print(cfg)
```
## 4. Organizing Larger Configurations

As the configuration grows, split it across files:

### Project Layout

```text
my_project/
├── config/
│   ├── config.yaml      # main config
│   ├── model/
│   │   ├── resnet.yaml
│   │   └── vit.yaml
│   └── data/
│       ├── cifar10.yaml
│       └── imagenet.yaml
└── train.py
```

### Main Config: config/config.yaml

```yaml
defaults:
  - model: resnet
  - data: cifar10
  - _self_

train:
  batch_size: 32
  learning_rate: 0.001
  epochs: 100
  device: cuda

logging:
  save_dir: ./logs
  log_interval: 10
```

### Model Config: config/model/resnet.yaml

```yaml
name: resnet18
pretrained: true
num_classes: 10
```

### Data Config: config/data/cifar10.yaml

```yaml
name: cifar10
root: ./data
download: true
batch_size: 64
```

### Usage

```python
@hydra.main(config_path="config", config_name="config")
def main(cfg: DictConfig):
    print(f"model name: {cfg.model.name}")
    print(f"pretrained: {cfg.model.pretrained}")
    print(f"dataset: {cfg.data.name}")
    print(f"batch size: {cfg.train.batch_size}")
    print(f"learning rate: {cfg.train.learning_rate}")
```

### Command-Line Overrides

```shell
# Swap an entire config group
python train.py model=vit

# Override individual values
python train.py train.batch_size=128 train.learning_rate=0.0001

# Combine a group swap with value overrides
python train.py data=imagenet data.batch_size=128
```
## 5. OmegaConf Interpolation (Avoiding Repetition)

The problem: the same path appears in many places.

```yaml
data_root: /home/user/data
train_path: /home/user/data/train
val_path: /home/user/data/val
```

The solution: variable interpolation with `${...}`.

```yaml
data_root: /home/user/data
train_path: ${data_root}/train
val_path: ${data_root}/val
test_path: ${data_root}/test

# Interpolations can also reference nested keys
paths:
  root: /home/user/data
  train: ${paths.root}/train
  checkpoint: ${paths.root}/checkpoints/${model.name}
```
## 6. Instantiating Objects from Config (`_target_`)

This is one of Hydra's most powerful features: creating PyTorch models, optimizers, and other objects directly from configuration.

### Example: Instantiating an Optimizer

```yaml
optimizer:
  _target_: torch.optim.Adam
  lr: 0.001
  weight_decay: 1e-5

scheduler:
  _target_: torch.optim.lr_scheduler.StepLR
  step_size: 30
  gamma: 0.1
```

```python
import hydra
import torch

@hydra.main(config_path=".", config_name="config")
def main(cfg):
    model = torch.nn.Linear(784, 10)

    # Extra keyword arguments are merged with those from the config
    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)

    print(f"optimizer type: {type(optimizer)}")
    print(f"learning rate: {optimizer.param_groups[0]['lr']}")
```

### Switching Optimizers Is Just a Config Change

```yaml
optimizer:
  _target_: torch.optim.SGD
  lr: 0.01
  momentum: 0.9
```
## 7. Structured Configs (Type Safety)

Define the configuration with Python dataclasses and get IDE autocompletion and type checking. Register the schema with Hydra's `ConfigStore` so `@hydra.main` can find it by name (note: nested dataclass defaults must use `field(default_factory=...)`):

```python
from dataclasses import dataclass, field

import hydra
from hydra.core.config_store import ConfigStore

@dataclass
class ModelConfig:
    name: str = "resnet18"
    pretrained: bool = True
    num_classes: int = 10

@dataclass
class TrainConfig:
    batch_size: int = 32
    learning_rate: float = 0.001
    epochs: int = 100

@dataclass
class Config:
    model: ModelConfig = field(default_factory=ModelConfig)
    train: TrainConfig = field(default_factory=TrainConfig)
    seed: int = 42

# Register the schema so Hydra can resolve config_name="config"
cs = ConfigStore.instance()
cs.store(name="config", node=Config)

@hydra.main(config_path=None, config_name="config")
def main(cfg: Config):
    print(cfg.model.name)
    print(cfg.train.batch_size)

if __name__ == "__main__":
    main()
```
## 8. Practical Tips

### 1. Automatic Config Snapshots

Hydra saves the configuration of every run to the output directory:

```text
outputs/
└── 2024-01-15/
    └── 10-30-45/
        └── .hydra/
            ├── config.yaml   # the fully resolved config
            └── hydra.yaml    # Hydra's own config
```
### 2. Fixing the Random Seed

```python
import lightning as L

@hydra.main(...)
def main(cfg):
    if cfg.get("seed"):
        L.seed_everything(cfg.seed)
```
### 3. Debug Mode

```shell
# Load an alternative primary config, e.g. config/debug.yaml
python train.py --config-name debug

# Or shrink the run from the command line
python train.py train.epochs=1 train.batch_size=2
```
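`--config-name debug` tells Hydra to load `debug.yaml` from the config directory instead of `config.yaml`. That file is just another config; a hypothetical sketch that builds on the main config and shrinks the run:

```yaml
# config/debug.yaml (hypothetical)
defaults:
  - config
  - _self_

train:
  epochs: 1
  batch_size: 2
  device: cpu
```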
## 9. FAQ

Q1: Why don't my config changes take effect?

A: Make sure you read values from `cfg` inside the decorated function. Hydra composes the configuration when the `@hydra.main` decorator runs, so values read at module import time will not reflect your changes.

Q2: How do I access nested config values?

A: With dots: `cfg.model.encoder.hidden_size`

Q3: What is the override precedence?

A: command line > config file > defaults

Q4: How do I see the final, effective configuration?

A: Add this to your code:

```python
print(OmegaConf.to_yaml(cfg))
```
## 10. A Complete Example: Training on MNIST

### config.yaml

```yaml
defaults:
  - _self_

model:
  _target_: torch.nn.Sequential
  _args_:
    - _target_: torch.nn.Linear
      in_features: 784
      out_features: 128
    - _target_: torch.nn.ReLU
    - _target_: torch.nn.Linear
      in_features: 128
      out_features: 10

optimizer:
  _target_: torch.optim.Adam
  lr: 0.001

train:
  batch_size: 64
  epochs: 5
  device: cuda

data:
  root: ./data
  batch_size: 64
```

### train_mnist.py

```python
import hydra
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig):
    print("Config:")
    print(OmegaConf.to_yaml(cfg))

    # Flatten each 28x28 image into a 784-dim vector
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1))
    ])

    train_dataset = datasets.MNIST(
        root=cfg.data.root,
        train=True,
        download=True,
        transform=transform
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True
    )

    # Build the model and optimizer from the config
    model = hydra.utils.instantiate(cfg.model)
    device = cfg.train.device if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    criterion = nn.CrossEntropyLoss()

    for epoch in range(cfg.train.epochs):
        total_loss = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}/{cfg.train.epochs}, "
                      f"Batch {batch_idx}/{len(train_loader)}, "
                      f"Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")

if __name__ == "__main__":
    main()
```

Run it:

```shell
# Default configuration
python train_mnist.py

# Override hyperparameters
python train_mnist.py train.batch_size=128 optimizer.lr=0.0005

# Swap the optimizer (momentum is not in the Adam config, so adding it needs a leading +)
python train_mnist.py optimizer._target_=torch.optim.SGD optimizer.lr=0.01 +optimizer.momentum=0.9
```
## Summary

| Feature | Example |
|---|---|
| Load config | `@hydra.main(config_path=".", config_name="config")` |
| Access values | `cfg.model.name` |
| Print config | `print(OmegaConf.to_yaml(cfg))` |
| Command-line override | `python train.py train.batch_size=128` |
| Variable interpolation | `train_path: ${data_root}/train` |
| Instantiate objects | `hydra.utils.instantiate(cfg.optimizer, params=...)` |
| Structured configs | `@dataclass class Config:` |