|
"""Minimal "Hello World" training-job example.

The script performs no real training. It:

1. parses the command-line options ``--dataset_input`` and ``--output``;
2. selects the NPU device when ``torch_npu`` is importable and reports an
   available NPU, otherwise falls back to the CPU;
3. prints ``Hello World`` and writes it to a local file;
4. copies that file to an OBS location with ``mox.file.copy``.
"""
import argparse
import os

import torch
import moxing as mox

# Name of the result file, used both for the local file and the OBS object.
RESULT_FILENAME = 'hellworld.pth'

# Default OBS *directory* for the result. It must not contain the file name:
# the file name is appended below, and including it here would produce
# ".../hellworld.pth/hellworld.pth".
DEFAULT_OBS_OUTPUT_DIR = 'obs://nudt-cloudream2/cds/trainResult/model/'

# Build the argument parser.
parser = argparse.ArgumentParser(description='Training script with output path option')
parser.add_argument('--dataset_input', default=None, type=str, help='dataset_input, where the dataset is stored.')
parser.add_argument('--output', type=str, default='', help='Output path for saving the result')
args = parser.parse_args()

# Check whether an NPU can be used.
# torch_npu is imported only here, inside the guarded block: an unconditional
# top-level import would raise ImportError before the CPU fallback could run.
try:
    import torch_npu
except ImportError:
    device = torch.device('cpu')
    print("torch_npu module not found, using CPU for training.")
else:
    if torch_npu.npu.is_available():
        device = torch.device('npu')
        print("Using NPU for training.")
    else:
        device = torch.device('cpu')
        print("NPU is not available, using CPU for training.")

# Simulate a trivial "training" run: no model is trained, this only
# demonstrates the job flow by emitting Hello World.
result = "Hello World"
print(result)

# Save the result to a local file.
local_result_path = RESULT_FILENAME
with open(local_result_path, 'w', encoding='utf-8') as f:
    f.write(result)

# Resolve the OBS destination. --output is treated as a directory; when it is
# empty the default directory is used instead.
obs_result_path = args.output if args.output else DEFAULT_OBS_OUTPUT_DIR
# Make sure the directory ends with a slash before appending the file name.
if not obs_result_path.endswith('/'):
    obs_result_path += '/'
obs_result_path += RESULT_FILENAME

# Upload is best-effort: a failure is reported on stdout but does not abort
# the script.
try:
    mox.file.copy(local_result_path, obs_result_path)
except Exception as e:
    print(f"Failed to save result to OBS: {e}")
else:
    print(f"Result has been successfully saved to {obs_result_path}")
|