"""Minimal ModelArts/Ascend "Hello World" training-job script.

Selects an NPU device when ``torch_npu`` is installed and reports an NPU
(otherwise CPU), writes the string ``Hello World`` to a local file, and
copies that file to an OBS location with MoXing.
"""
import argparse
import os

import torch

# Name of the result file written locally and uploaded to OBS.
RESULT_FILENAME = 'hellworld.pth'
# OBS *directory* used when --output is not given. RESULT_FILENAME is appended
# to it, so it must not already contain the filename (doing so previously
# produced ".../hellworld.pth/hellworld.pth").
DEFAULT_OBS_DIR = 'obs://nudt-cloudream2/cds/trainResult/model/'


def build_parser():
    """Create the command-line parser for this script.

    Returns:
        argparse.ArgumentParser accepting ``--dataset_input`` and ``--output``.
    """
    parser = argparse.ArgumentParser(description='Training script with output path option')
    parser.add_argument('--dataset_input', default=None, type=str,
                        help='dataset_input, where the dataset is stored.')
    parser.add_argument('--output', type=str, default='',
                        help='Output path for saving the result')
    return parser


def select_device():
    """Return ``torch.device('npu')`` when available, else ``torch.device('cpu')``.

    ``torch_npu`` is optional: when it cannot be imported, or it reports no
    NPU, the CPU device is returned. A message describing the choice is printed.
    """
    try:
        import torch_npu
    except ImportError:
        print("torch_npu module not found, using CPU for training.")
        return torch.device('cpu')
    if torch_npu.npu.is_available():
        print("Using NPU for training.")
        return torch.device('npu')
    print("NPU is not available, using CPU for training.")
    return torch.device('cpu')


def resolve_obs_path(output):
    """Return the full OBS object path for the result file.

    Args:
        output: OBS directory from ``--output``. An empty value (or None)
            selects DEFAULT_OBS_DIR. The value is always treated as a
            directory: a trailing '/' is added if missing, then
            RESULT_FILENAME is appended.

    Returns:
        The OBS path string the result file is copied to.
    """
    obs_dir = output or DEFAULT_OBS_DIR
    if not obs_dir.endswith('/'):
        obs_dir += '/'
    return obs_dir + RESULT_FILENAME


def upload_to_obs(local_path, obs_path):
    """Copy ``local_path`` to ``obs_path`` via MoXing, best-effort.

    Any failure is printed rather than raised, which matches the script's
    original behaviour. That includes MoXing itself being unavailable.
    """
    try:
        # Imported lazily so the helpers above stay importable outside ModelArts.
        import moxing as mox
        mox.file.copy(local_path, obs_path)
    except Exception as e:  # deliberate best-effort upload: report and continue
        print(f"Failed to save result to OBS: {e}")
    else:
        print(f"Result has been successfully saved to {obs_path}")


def main(argv=None):
    """Run the demo job: choose a device, write the result locally, upload it.

    Args:
        argv: Optional argument list for the parser; defaults to ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Called for its printed device report only; this demo does no real training.
    select_device()

    result = "Hello World"
    print(result)

    # Save the result to a local file before uploading it.
    with open(RESULT_FILENAME, 'w', encoding='utf-8') as f:
        f.write(result)

    upload_to_obs(RESULT_FILENAME, resolve_obs_path(args.output))


if __name__ == '__main__':
    main()