|
|
|
@@ -50,17 +50,11 @@ class TestTextGenerationTrainer(unittest.TestCase): |
|
|
|
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') |
|
|
|
def test_trainer(self): |
|
|
|
|
|
|
|
def cfg_modify_fn(cfg): |
|
|
|
cfg.preprocessor.type = 'text-gen-tokenizer' |
|
|
|
return cfg |
|
|
|
|
|
|
|
kwargs = dict( |
|
|
|
model=self.model_id, |
|
|
|
train_dataset=self.dataset, |
|
|
|
eval_dataset=self.dataset, |
|
|
|
work_dir=self.tmp_dir, |
|
|
|
cfg_modify_fn=cfg_modify_fn, |
|
|
|
model_revision='beta') |
|
|
|
work_dir=self.tmp_dir) |
|
|
|
|
|
|
|
trainer = build_trainer( |
|
|
|
name='NlpEpochBasedTrainer', default_args=kwargs) |
|
|
|
@@ -76,7 +70,7 @@ class TestTextGenerationTrainer(unittest.TestCase): |
|
|
|
if not os.path.exists(tmp_dir): |
|
|
|
os.makedirs(tmp_dir) |
|
|
|
|
|
|
|
cache_path = snapshot_download(self.model_id, revision='beta') |
|
|
|
cache_path = snapshot_download(self.model_id) |
|
|
|
model = PalmForTextGeneration.from_pretrained(cache_path) |
|
|
|
kwargs = dict( |
|
|
|
cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION), |
|
|
|
|