From fa1e620b44106599e81d090a80f03d140ae3db40 Mon Sep 17 00:00:00 2001 From: wjtest1215 Date: Wed, 26 Oct 2022 17:17:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'gpu=5Fnew/train.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gpu_new/train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gpu_new/train.py b/gpu_new/train.py index c6e3dfb..50da2a5 100755 --- a/gpu_new/train.py +++ b/gpu_new/train.py @@ -39,6 +39,11 @@ parser.add_argument('--testdata', default="/dataset/test" ,help='path to test da parser.add_argument('--epoch_size', type=int, default=1, help='how much epoch to train') parser.add_argument('--batch_size', type=int, default=256, help='how much batch_size in epoch') +# 参数声明 +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +model = Model().to(device) +optimizer = SGD(model.parameters(), lr=1e-1) + if __name__ == '__main__': args, unknown = parser.parse_known_args() #log output