|
- '''
- 由于a100的适配性问题,使用训练环境前请使用平台的含有cuda11以上的推荐镜像在调试环境中调试自己的代码,
- 本示例的镜像地址是dockerhub.pcl.ac.cn:5000/user-images/openi:cuda111_python37_pytorch191,并
- 提交镜像,再切到训练环境训练已跑通的代码。
- 在训练环境中,上传的数据集会自动放在/dataset目录下,模型下载路径默认在/model下,请将模型输出位置指定到/model,
- 启智平台界面会提供/model目录下的文件下载。
- '''
-
- import torchvision
- from torch.autograd import Variable
- import torch
- import argparse
-
- # Training settings
- parser = argparse.ArgumentParser(description='Resnet50 Example')
- #数据集位置放在/dataset下
- parser.add_argument('--traindata', default="/dataset/train" ,help='path to train dataset')
- parser.add_argument('--testdata', default="/dataset/test" ,help='path to test dataset')
- parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
- parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')
-
- if __name__ == '__main__':
- input_name = ['input']
- output_name = ['output']
- input = Variable(torch.randn(1, 3, 224, 224)).cuda()
- model = torchvision.models.resnet50(pretrained=True).cuda()
-
- #模型输出位置放在/model下
- torch.save(model, '/model/resnet50.pth')
-
|