|
|
@@ -1,7 +1,5 @@ |
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates. |
|
|
|
import glob |
|
|
|
import os |
|
|
|
import os.path as osp |
|
|
|
import shutil |
|
|
|
import unittest |
|
|
|
|
|
|
@@ -98,8 +96,9 @@ class TestOfaTrainer(unittest.TestCase): |
|
|
|
trainer = build_trainer(name=Trainers.ofa, default_args=args) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE, |
|
|
|
os.listdir(os.path.join(WORKSPACE, 'output'))) |
|
|
|
self.assertIn( |
|
|
|
ModelFile.TORCH_MODEL_BIN_FILE, |
|
|
|
os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR))) |
|
|
|
shutil.rmtree(WORKSPACE) |
|
|
|
|
|
|
|
|
|
|
|