@@ -6,14 +6,14 @@ import unittest

 from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.metainfo import Trainers
-from modelscope.models.nlp.palm_v2 import PalmForTextGeneration
+from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import ModelFile
 from modelscope.utils.test_utils import test_level


-class TestTextGenerationTrainer(unittest.TestCase):
+class TestFinetuneTextGeneration(unittest.TestCase):

     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
@@ -21,40 +21,41 @@ class TestTextGenerationTrainer(unittest.TestCase):
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)

-        self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'
-
         # todo: Replace below scripts with MsDataset.load when the formal dataset service is ready
         from datasets import Dataset
-        dataset_dict = {
+        src_dataset_dict = {
             'src_txt': [
                 'This is test sentence1-1', 'This is test sentence2-1',
                 'This is test sentence3-1'
-            ],
+            ]
+        }
+        src_tgt_dataset_dict = {
+            'src_txt':
+            src_dataset_dict['src_txt'],
             'tgt_txt': [
                 'This is test sentence1-2', 'This is test sentence2-2',
                 'This is test sentence3-2'
             ]
         }
-        dataset = Dataset.from_dict(dataset_dict)
-
-        class MsDatasetDummy(MsDataset):
+        self.src_dataset = MsDataset(Dataset.from_dict(src_dataset_dict))
+        self.src_tgt_dataset = MsDataset(
+            Dataset.from_dict(src_tgt_dataset_dict))

-            def __len__(self):
-                return len(self._hf_ds)
-
-        self.dataset = MsDatasetDummy(dataset)
+        self.max_epochs = 3

     def tearDown(self):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_trainer(self):
+    def test_trainer_with_palm(self):

         kwargs = dict(
-            model=self.model_id,
-            train_dataset=self.dataset,
-            eval_dataset=self.dataset,
+            model='damo/nlp_palm2.0_text-generation_english-base',
+            train_dataset=self.src_tgt_dataset,
+            eval_dataset=self.src_tgt_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)

         trainer = build_trainer(
@@ -62,30 +63,67 @@ class TestTextGenerationTrainer(unittest.TestCase):
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(3):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_trainer_with_model_and_args(self):
-        tmp_dir = tempfile.TemporaryDirectory().name
-        if not os.path.exists(tmp_dir):
-            os.makedirs(tmp_dir)
+    def test_trainer_with_palm_with_model_and_args(self):

-        cache_path = snapshot_download(self.model_id)
+        cache_path = snapshot_download(
+            'damo/nlp_palm2.0_text-generation_english-base')
         model = PalmForTextGeneration.from_pretrained(cache_path)
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
             model=model,
-            train_dataset=self.dataset,
-            eval_dataset=self.dataset,
-            max_epochs=2,
+            train_dataset=self.src_tgt_dataset,
+            eval_dataset=self.src_tgt_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(self.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_with_gpt3(self):
+
+        kwargs = dict(
+            model='damo/nlp_gpt3_text-generation_chinese-base',
+            train_dataset=self.src_dataset,
+            eval_dataset=self.src_dataset,
+            max_epochs=self.max_epochs,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(self.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_gpt3_with_model_and_args(self):
+
+        cache_path = snapshot_download(
+            'damo/nlp_gpt3_text-generation_chinese-base')
+        model = GPT3ForTextGeneration.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.src_dataset,
+            eval_dataset=self.src_dataset,
+            max_epochs=self.max_epochs,
+            work_dir=self.tmp_dir)

         trainer = build_trainer(default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(2):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)

     @unittest.skip
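
All of the new tests exercise the same finetuning pattern: wrap a toy datasets.Dataset in MsDataset, hand build_trainer either a hub model ID or a downloaded checkpoint plus its configuration.json, train for max_epochs, and assert that one epoch_N.pth checkpoint per epoch (plus a timestamped .log.json) lands in work_dir. For reference, a minimal standalone sketch of that pattern outside unittest, assembled only from calls that appear in the diff above; it assumes a working modelscope install with network access to the model hub, and the exact APIs may differ across ModelScope versions:

import os
import tempfile

from datasets import Dataset

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

# Toy source/target pairs, mirroring the fixtures built in setUp() above.
train_dataset = MsDataset(
    Dataset.from_dict({
        'src_txt': ['This is test sentence1-1'],
        'tgt_txt': ['This is test sentence1-2'],
    }))

work_dir = tempfile.mkdtemp()
trainer = build_trainer(
    name=Trainers.nlp_base_trainer,
    default_args=dict(
        model='damo/nlp_palm2.0_text-generation_english-base',
        train_dataset=train_dataset,
        eval_dataset=train_dataset,
        max_epochs=3,
        work_dir=work_dir))
trainer.train()

# work_dir should now hold epoch_1.pth .. epoch_3.pth and a timestamped
# .log.json file, which is exactly what the tests assert on.
print(sorted(os.listdir(work_dir)))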