You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

helloworld_mlu.py 5.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # 1. 导入依赖库(Cambricon-PyTorch 镜像已预装 torch 和 torch_mlu)
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torch.utils.data import TensorDataset, DataLoader
  6. # 2. 验证 MLU 环境(关键:确认 MLU 设备可用)
  7. def check_mlu_env():
  8. print("=" * 50)
  9. print("【MLU 环境验证】")
  10. # 检查 torch_mlu 是否导入成功
  11. try:
  12. import torch_mlu
  13. print(f"✅ torch_mlu 版本: {torch_mlu.__version__}")
  14. except ImportError:
  15. raise ImportError("❌ 未找到 torch_mlu,请确认使用 Cambricon-PyTorch 镜像")
  16. # 检查 MLU 设备是否可用
  17. if torch_mlu.is_available():
  18. mlu_device_count = torch_mlu.device_count()
  19. print(f"✅ MLU 设备数量: {mlu_device_count}")
  20. print(f"✅ 当前使用 MLU 设备: {torch_mlu.get_device_name(0)}") # 默认使用第 0 张 MLU 卡
  21. return torch.device("mlu:0") # 返回 MLU 设备对象
  22. else:
  23. raise RuntimeError("❌ MLU 设备不可用,请确认集群已挂载 MLU 卡且驱动正常")
  24. # 3. 定义极简训练模型(单线性层,模拟分类/回归任务)
  25. class HelloMLUModel(nn.Module):
  26. def __init__(self, input_dim=10, output_dim=1):
  27. """
  28. 输入:input_dim 维特征(模拟样本特征)
  29. 输出:output_dim 维结果(模拟分类/回归输出)
  30. """
  31. super(HelloMLUModel, self).__init__()
  32. self.linear = nn.Linear(input_dim, output_dim) # 核心计算层
  33. self.relu = nn.ReLU() # 激活函数(增加非线性)
  34. def forward(self, x):
  35. """前向传播:定义数据流向"""
  36. out = self.linear(x)
  37. out = self.relu(out)
  38. return out
  39. # 4. 生成模拟训练数据(随机特征 + 随机标签,用于测试流程)
  40. def generate_sim_data(sample_num=100, input_dim=10):
  41. """
  42. 生成模拟数据集:
  43. - 特征:sample_num 个样本,每个样本 input_dim 维,服从正态分布
  44. - 标签:sample_num 个标签,服从正态分布(模拟回归任务)
  45. """
  46. features = torch.randn(sample_num, input_dim) # 特征 shape: (100, 10)
  47. labels = torch.randn(sample_num, 1) # 标签 shape: (100, 1)
  48. # 封装为 TensorDataset(便于 DataLoader 加载)
  49. dataset = TensorDataset(features, labels)
  50. # 构建 DataLoader(批量加载数据,模拟真实训练的数据读取流程)
  51. dataloader = DataLoader(dataset, batch_size=10, shuffle=True) # 批大小 10,打乱数据
  52. return dataloader
  53. # 5. 核心训练流程(适配 MLU 设备)
  54. def train_hello_mlu(model, dataloader, device, epochs=5):
  55. print("\n" + "=" * 50)
  56. print(f"【开始 MLU 训练】共 {epochs} 轮")
  57. print("=" * 50)
  58. # 定义损失函数(均方误差,适合回归任务)
  59. criterion = nn.MSELoss()
  60. # 定义优化器(随机梯度下降,更新模型参数)
  61. optimizer = optim.SGD(model.parameters(), lr=0.01) # 学习率 0.01
  62. # 将模型迁移到 MLU 设备
  63. model.to(device)
  64. # 训练循环(每轮遍历所有数据)
  65. for epoch in range(epochs):
  66. model.train() # 开启训练模式(影响 Dropout、BN 等层,此处无但规范保留)
  67. total_loss = 0.0 # 统计每轮总损失
  68. # 批量加载数据并训练
  69. for batch_idx, (batch_features, batch_labels) in enumerate(dataloader):
  70. # 1. 将数据迁移到 MLU 设备(关键:确保计算在 MLU 上执行)
  71. batch_features = batch_features.to(device)
  72. batch_labels = batch_labels.to(device)
  73. # 2. 梯度清零(避免上一轮梯度累积)
  74. optimizer.zero_grad()
  75. # 3. 前向传播:模型预测
  76. outputs = model(batch_features)
  77. # 4. 计算损失
  78. loss = criterion(outputs, batch_labels)
  79. # 5. 反向传播:计算梯度
  80. loss.backward()
  81. # 6. 优化器更新参数
  82. optimizer.step()
  83. # 累加损失(用于打印日志)
  84. total_loss += loss.item() * batch_features.size(0)
  85. # 计算每轮平均损失
  86. avg_loss = total_loss / len(dataloader.dataset)
  87. # 打印每轮训练结果(Hello World 级别的输出反馈)
  88. print(f"Epoch [{epoch + 1}/{epochs}] | 平均损失: {avg_loss:.6f} | 设备: {device}")
  89. print("\n" + "=" * 50)
  90. print("【MLU 训练完成】✅ Hello MLU Training!")
  91. print("=" * 50)
  92. # 6. 主函数(串联所有流程)
  93. if __name__ == "__main__":
  94. # 步骤1:验证 MLU 环境并获取设备
  95. mlu_device = check_mlu_env()
  96. # 步骤2:初始化模型(输入维度 10,输出维度 1)
  97. model = HelloMLUModel(input_dim=10, output_dim=1)
  98. print(f"\n✅ 模型初始化完成: {model}")
  99. # 步骤3:生成模拟训练数据
  100. train_dataloader = generate_sim_data(sample_num=100, input_dim=10)
  101. print(f"✅ 模拟数据生成完成: 共 {len(train_dataloader.dataset)} 个样本,{len(train_dataloader)} 个批次")
  102. # 步骤4:启动 MLU 训练
  103. train_hello_mlu(model, train_dataloader, mlu_device, epochs=5)

No Description