You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helloworld.py 1.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import torch
  2. import moxing as mox
  3. import os
  4. import argparse
  5. import torch_npu # 确保安装了 torch_npu
  6. # 创建参数解析器
  7. parser = argparse.ArgumentParser(description='Training script with output path option')
  8. parser.add_argument('--dataset_input', default=None, type=str, help='dataset_input, where the dataset is stored.')
  9. parser.add_argument('--output', type=str, default='', help='Output path for saving the result')
  10. args = parser.parse_args()
  11. # 检查 NPU 是否可用
  12. try:
  13. import torch_npu
  14. if torch_npu.npu.is_available():
  15. device = torch.device('npu')
  16. print("Using NPU for training.")
  17. else:
  18. device = torch.device('cpu')
  19. print("NPU is not available, using CPU for training.")
  20. except ImportError:
  21. device = torch.device('cpu')
  22. print("torch_npu module not found, using CPU for training.")
  23. # 模拟一个简单的训练过程
  24. # 这里其实不涉及真正的模型训练,只是作为示例流程展示
  25. # 输出 Hello World
  26. result = "Hello World"
  27. print(result)
  28. # 将结果保存到本地文件
  29. local_result_path = 'hellworld.pth'
  30. with open(local_result_path, 'w') as f:
  31. f.write(result)
  32. # 根据参数确定 OBS 路径
  33. obs_result_path = args.output if args.output else 'obs://nudt-cloudream2/cds/trainResult/model/hellworld.pth'
  34. # 如果传入的路径不是以斜杠结尾,则添加斜杠
  35. if obs_result_path and not obs_result_path.endswith('/'):
  36. obs_result_path += '/'
  37. obs_result_path += 'hellworld.pth'
  38. try:
  39. mox.file.copy(local_result_path, obs_result_path)
  40. print(f"Result has been successfully saved to {obs_result_path}")
  41. except Exception as e:
  42. print(f"Failed to save result to OBS: {e}")

No Description