Browse Source

更新 'gpu_new/train.py'

test_v20221116
wjtest1215 1 year ago
parent
commit
fa1e620b44
1 changed files with 5 additions and 0 deletions
  1. +5
    -0
      gpu_new/train.py

+ 5
- 0
gpu_new/train.py View File

@@ -39,6 +39,11 @@ parser.add_argument('--testdata', default="/dataset/test" ,help='path to test da
parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train') parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train')
parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch') parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch')


# 参数声明
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
optimizer = SGD(model.parameters(), lr=1e-1)

if __name__ == '__main__': if __name__ == '__main__':
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
#log output #log output


Loading…
Cancel
Save