PyTorch Deep Learning Framework Basics: Visualizing and Debugging Model Training

喜欢花科技君 2025-03-09 02:06:02
Visualizing Training with TensorBoard

Installation and setup:

```bash
pip install tensorboard
```

Import SummaryWriter and initialize it in your code:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/experiment_1")  # directory where the logs are saved
```
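TensorBoard treats each subdirectory under the directory passed to `--logdir` as a separate run. Giving every experiment its own log directory, as `runs/experiment_1` does here, therefore lets you compare experiments side by side. The minimal sketch below assumes two hypothetical runs that differ only in learning rate. The run names and the loss values are placeholders, not real training output.

```python
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter


def log_two_runs(root):
    """Log placeholder curves for two hypothetical runs, each in its own subdirectory of `root`."""
    for run_name, lr in [("lr_0.1", 0.1), ("lr_0.01", 0.01)]:
        # Using the writer as a context manager calls close() on exit, which flushes events to disk
        with SummaryWriter(log_dir=os.path.join(root, run_name)) as writer:
            for step in range(5):
                fake_loss = 1.0 / (1.0 + lr * step)  # placeholder values, not real training
                writer.add_scalar("Loss/train", fake_loss, step)
    return sorted(os.listdir(root))


root = tempfile.mkdtemp()  # in practice this would be "runs/"
print(log_two_runs(root))  # ['lr_0.01', 'lr_0.1']
```

Running `tensorboard --logdir` on that root directory then shows both `Loss/train` curves on the same chart, one per run.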

Logging metrics:

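Scalars (loss, accuracy, etc.):

```python
for epoch in range(epochs):
    train_loss = ...  # training loss for this epoch
    val_acc = ...     # validation accuracy for this epoch
    writer.add_scalar("Loss/train", train_loss, epoch)
    writer.add_scalar("Accuracy/val", val_acc, epoch)
```

Learning rate (used together with a scheduler), logged inside the same epoch loop:

```python
writer.add_scalar("Learning Rate", optimizer.param_groups[0]["lr"], epoch)
```

Parameter distributions (histograms):

```python
for name, param in model.named_parameters():
    writer.add_histogram(name, param.detach().cpu(), epoch)  # one histogram per parameter tensor
```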
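The fragments above can be combined into a single training loop. This is a sketch under stated assumptions: `train_and_log` is a hypothetical helper, and it trains a toy linear-regression model on synthetic data whose hyperparameters exist only to produce some curves. It logs the training loss, the learning rate from a `StepLR` scheduler, and one histogram per parameter. The loop reads the learning rate before calling `scheduler.step()`, so each logged value is the rate actually used in that epoch.

```python
import tempfile

import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter


def train_and_log(log_dir, epochs=6):
    """Train a toy linear model and log loss, learning rate and parameter histograms."""
    torch.manual_seed(0)
    # Hypothetical synthetic regression data: y = x @ w_true
    x = torch.randn(128, 3)
    y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)
    loss_fn = nn.MSELoss()

    with SummaryWriter(log_dir) as writer:
        for epoch in range(epochs):
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()

            writer.add_scalar("Loss/train", loss.item(), epoch)
            # Read the LR before scheduler.step() so the logged value is the one used this epoch
            writer.add_scalar("Learning Rate", optimizer.param_groups[0]["lr"], epoch)
            for name, param in model.named_parameters():
                writer.add_histogram(name, param.detach(), epoch)
            scheduler.step()


log_dir = tempfile.mkdtemp()  # e.g. "runs/toy_example" to open it in TensorBoard
train_and_log(log_dir)
```

With `step_size=2` and `gamma=0.5`, the logged learning rate is 0.1, 0.1, 0.05, 0.05, 0.025, 0.025. The histogram tab shows `weight` and `bias`, the parameter names of `nn.Linear`.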

Launching TensorBoard:

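```bash
tensorboard --logdir=runs/
```

Then open http://localhost:6006 (TensorBoard's default port) in a browser.

Debugging Tools and Techniques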

Detecting NaN/Inf anomalies:

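```python
import torch

# Option 1: turn anomaly detection on globally
torch.autograd.set_detect_anomaly(True)

# Option 2: turn it on only for a block of code
with torch.autograd.detect_anomaly():
    loss.backward()  # raises an error if any backward function returns NaN
```

Anomaly detection slows training considerably, so use it only while debugging. It checks the outputs of backward functions for NaN, not for Inf. Run the forward pass inside the block as well if you want the error to include a traceback of the forward operation that caused it.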
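To see what anomaly mode reports, the sketch below provokes a NaN gradient on purpose with `sqrt(-1)`. Because anomaly mode does not look for Inf, the sketch pairs it with an explicit `torch.isfinite` check. Both helper names, `backward_with_anomaly_check` and `assert_finite`, are hypothetical and exist only for illustration.

```python
import torch


def backward_with_anomaly_check():
    """Run a backward pass that yields NaN under anomaly mode; return the error text (or None)."""
    x = torch.tensor([-1.0], requires_grad=True)
    with torch.autograd.detect_anomaly():
        y = torch.sqrt(x)  # sqrt(-1) is already NaN in the forward pass, which anomaly mode does not flag
        loss = y.sum()
        try:
            loss.backward()  # the sqrt backward returns NaN, so anomaly mode raises here
        except RuntimeError as err:
            return str(err)
    return None


def assert_finite(name, tensor):
    """Anomaly mode only checks for NaN, so test for Inf (and NaN) explicitly."""
    if not torch.isfinite(tensor).all():
        raise ValueError(f"{name} contains NaN or Inf values")


print(backward_with_anomaly_check())
```

Besides the error message, PyTorch also emits a warning containing the traceback of the forward call (here the `torch.sqrt` line). That traceback is usually the most useful part when you are locating the source of a NaN.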

Gradient checks:

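Register hooks to monitor gradients:

```python
def make_gradient_hook(name):
    def gradient_hook(grad):
        if torch.isnan(grad).any():
            print(f"NaN gradient detected in {name}!")
    return gradient_hook

for name, param in model.named_parameters():
    param.register_hook(make_gradient_hook(name))
```

Track the gradient norm:

```python
# Call after loss.backward() and before optimizer.step()
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
writer.add_scalar("Gradient Norm", total_norm, epoch)
```

clip_grad_norm_ returns the total norm of all gradients, computed before clipping, but it also clips the gradients in place to max_norm. Pass max_norm=float("inf") if you only want to measure the norm.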
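The two techniques can be combined into one check. The sketch below is hypothetical: `check_gradients` is an illustrative helper, not a PyTorch API. It registers a hook on every parameter to collect the names of those whose gradients are not finite, and it removes the hooks after one backward pass. It then measures the total gradient norm with `max_norm=float("inf")`, so the gradients are left unchanged.

```python
import torch
from torch import nn


def check_gradients(model, loss):
    """Backpropagate `loss`, collect parameters whose gradients are not finite, return the total gradient norm."""
    bad = []
    handles = []
    for name, param in model.named_parameters():
        # The default argument captures this parameter's name; returning None keeps the gradient as-is
        def hook(grad, name=name):
            if not torch.isfinite(grad).all():
                bad.append(name)
        handles.append(param.register_hook(hook))
    loss.backward()
    for handle in handles:
        handle.remove()  # detach the hooks so later backward passes are not monitored
    # With max_norm=inf the clip coefficient is clamped to 1, so the gradients are not modified
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
    return bad, total_norm.item()


torch.manual_seed(0)
model = nn.Linear(3, 2)
loss = model(torch.randn(4, 3)).pow(2).sum()
print(check_gradients(model, loss))
```

Feeding a NaN input through an `nn.Linear` layer, for example, puts `"weight"` in the returned list. The bias gradient stays finite, because it does not depend on the input values.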

Debugging intermediate values:

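Set a breakpoint with pdb:

```python
import pdb; pdb.set_trace()  # insert where you want to inspect state; breakpoint() does the same in Python 3.7+
```

Save intermediate results for later inspection:

```python
torch.save({"tensor": intermediate_tensor}, "debug_tensor.pt")
# Reload later with torch.load("debug_tensor.pt")["tensor"]
```

Data and Model Validation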

Checking data loading:

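```python
import torchvision

# Visualize one batch of data (image example)
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)  # tile the (N, C, H, W) batch into a single (C, H, W) image
writer.add_image("Sample Images", grid, 0)
```

If the images were normalized (for example with a mean/std transform), pass normalize=True to make_grid so the values are rescaled to [0, 1] before logging.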

Visualizing the model structure:

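```python
dummy_input = torch.randn(1, 3, 224, 224)  # example input with the shape the model expects
writer.add_graph(model, dummy_input)
```

Overfitting/Underfitting Analysis

Monitor the training and validation losses:
Overfitting: training loss keeps falling while validation loss rises.
Underfitting: both losses stay high and plateau.
Adjustment strategies: to counter overfitting, add data augmentation or regularization (Dropout, weight decay), or simplify the model. Underfitting usually calls for the opposite, such as a larger model or longer training.

Learning Rate Scheduling and Logging

```python
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)  # multiply the LR by 0.1 every 30 epochs
for epoch in range(epochs):
    train(...)
    # Log the LR used in this epoch before the scheduler updates it
    writer.add_scalar("Learning Rate", optimizer.param_groups[0]["lr"], epoch)
    scheduler.step()
```

Advanced Tool: PyTorch Lightning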

Lightning simplifies logging and the debugging workflow:

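```python
import pytorch_lightning as pl
import torch

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = ...  # your network (a torch.nn.Module)

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.log("train_loss", loss)  # sent to the Trainer's logger (TensorBoard by default when installed)
        return loss

    def configure_optimizers(self):
        # Required by the Trainer; Adam with lr=1e-3 is only an illustrative choice
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = LitModel()
trainer = pl.Trainer()
trainer.fit(model, train_loader)
```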
