From 20a935d4065e6b05dcd5e688108cbe337519e95c Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Fri, 26 Aug 2022 14:54:45 +0800
Subject: [PATCH] [to #42322933] add gpt3 base finetune
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add finetune code for small and medium gpt3 models on a single machine
with a single GPU.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9899004
---
 modelscope/models/nlp/gpt3/modeling_gpt3.py   | 22 +++--
 .../models/nlp/palm_v2/modeling_palm.py       | 10 +-
 .../nlp/palm_v2/palm_for_text_generation.py   |  7 +-
 modelscope/preprocessors/nlp.py               | 19 ++-
 ...er.py => test_finetune_text_generation.py} | 92 +++++++++++++------
 5 files changed, 101 insertions(+), 49 deletions(-)
 rename tests/trainers/{test_text_generation_trainer.py => test_finetune_text_generation.py} (56%)

diff --git a/modelscope/models/nlp/gpt3/modeling_gpt3.py b/modelscope/models/nlp/gpt3/modeling_gpt3.py
index f7024713..4e30f697 100644
--- a/modelscope/models/nlp/gpt3/modeling_gpt3.py
+++ b/modelscope/models/nlp/gpt3/modeling_gpt3.py
@@ -16,9 +16,10 @@ import math
 import os
 from typing import Optional, Union
 
+import addict
 import torch
-from addict import Dict
-from torch.nn import Dropout, Embedding, LayerNorm, Linear, Module, Softmax
+from torch.nn import (CrossEntropyLoss, Dropout, Embedding, LayerNorm, Linear,
+                      Module, Softmax)
 from torch.nn import functional as F
 from transformers.modeling_utils import PreTrainedModel
 
@@ -308,20 +309,25 @@ class GPT3Model(PreTrainedModel):
                 input_ids,
                 attention_mask=None,
                 position_ids=None,
+                labels=None,
                 **kwargs):
         seq_length = input_ids.size(1)
-        if attention_mask is None:
-            attention_mask = torch.tril(
-                torch.ones((1, seq_length, seq_length),
-                           dtype=torch.long,
-                           device=input_ids.device))
+        attention_mask = torch.tril(
+            torch.ones((1, 1, seq_length, seq_length),
+                       dtype=torch.long,
+                       device=input_ids.device))
         if position_ids is None:
             position_ids = torch.arange(
                 seq_length, dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
 
         logits = self.language_model(input_ids, attention_mask, position_ids)
-        return Dict(logits=logits)
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(
+                logits.view(-1, self.config.vocab_size), labels.view(-1))
+        return addict.Dict(loss=loss, logits=logits)
 
     @classmethod
     def from_pretrained(
diff --git a/modelscope/models/nlp/palm_v2/modeling_palm.py b/modelscope/models/nlp/palm_v2/modeling_palm.py
index 1cbf4f58..ff6fd732 100644
--- a/modelscope/models/nlp/palm_v2/modeling_palm.py
+++ b/modelscope/models/nlp/palm_v2/modeling_palm.py
@@ -6,6 +6,7 @@ import subprocess
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
+import addict
 import json
 import numpy as np
 import torch
@@ -726,10 +727,11 @@ class PalmForConditionalGeneration(PalmPreTrainedModel):
             self.palm.vocab_size,
             config.label_smoothing)
 
-    def forward(self, src, tgt, mask_src):
-        output = self.palm(src, tgt, mask_src)[0]
-        loss = self.loss(tgt, output)
-        return loss
+    def forward(self, input_ids, attention_mask, labels):
+        output = self.palm(
+            src=input_ids, tgt=labels, mask_src=attention_mask)[0]
+        loss = self.loss(labels, output)
+        return addict.Dict(loss=loss)
 
 
 class Translator(nn.Module):
diff --git a/modelscope/models/nlp/palm_v2/palm_for_text_generation.py b/modelscope/models/nlp/palm_v2/palm_for_text_generation.py
index 98aa56c7..ae92427e 100644
--- a/modelscope/models/nlp/palm_v2/palm_for_text_generation.py
+++ b/modelscope/models/nlp/palm_v2/palm_for_text_generation.py
@@ -63,14 +63,15 @@ class PalmForTextGeneration(TorchModel):
             }
         """
         if self.training:
-            return {'loss': self.model(**input)}
+            return self.model(**input)
         else:
-            outputs = self.generator(input['src'], input['mask_src'])
+            outputs = self.generator(input['input_ids'],
+                                     input['attention_mask'])
             preds = outputs['predictions']
             pred_ids_list = [
                 pred_batch[0].cpu().numpy().tolist() for pred_batch in preds
             ]
-            tgt_ids_list = input['tgt'].cpu().numpy().tolist()
+            tgt_ids_list = input['labels'].cpu().numpy().tolist()
             return {
                 'preds': self._evaluate_postprocess(pred_ids_list),
                 'tgts': self._evaluate_postprocess(tgt_ids_list)
diff --git a/modelscope/preprocessors/nlp.py b/modelscope/preprocessors/nlp.py
index 094cbfe2..345d3711 100644
--- a/modelscope/preprocessors/nlp.py
+++ b/modelscope/preprocessors/nlp.py
@@ -368,15 +368,20 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
 
     def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]:
         if self._mode == ModeKeys.INFERENCE:
             return super().__call__(data)
-        src_txt = data['src_txt']
-        tgt_txt = data['tgt_txt']
-        src_rst = super().__call__(src_txt)
-        tgt_rst = super().__call__(tgt_txt)
+        src_rst = super().__call__(data['src_txt'])
+        src_input_ids = src_rst['input_ids']
+        src_attention_mask = src_rst['attention_mask']
+        if 'tgt_txt' in data:
+            labels = super().__call__(data['tgt_txt'])['input_ids']
+        else:
+            labels = src_input_ids[1:]
+            src_input_ids = src_input_ids[:-1]
+            src_attention_mask = src_attention_mask[:-1]
         return {
-            'src': src_rst['input_ids'],
-            'tgt': tgt_rst['input_ids'],
-            'mask_src': src_rst['attention_mask']
+            'input_ids': src_input_ids,
+            'attention_mask': src_attention_mask,
+            'labels': labels,
         }
diff --git a/tests/trainers/test_text_generation_trainer.py b/tests/trainers/test_finetune_text_generation.py
similarity index 56%
rename from tests/trainers/test_text_generation_trainer.py
rename to tests/trainers/test_finetune_text_generation.py
index a60bc903..8cdfdf01 100644
--- a/tests/trainers/test_text_generation_trainer.py
+++ b/tests/trainers/test_finetune_text_generation.py
@@ -6,14 +6,14 @@ import unittest
 
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Trainers
-from modelscope.models.nlp.palm_v2 import PalmForTextGeneration
+from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import ModelFile
 from modelscope.utils.test_utils import test_level
 
 
-class TestTextGenerationTrainer(unittest.TestCase):
+class TestFinetuneTextGeneration(unittest.TestCase):
 
     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
@@ -21,40 +21,41 @@
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
 
-        self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'
-
-        # todo: Replace below scripts with MsDataset.load when the formal dataset service is ready
         from datasets import Dataset
-        dataset_dict = {
+
+        src_dataset_dict = {
             'src_txt': [
                 'This is test sentence1-1', 'This is test sentence2-1',
                 'This is test sentence3-1'
-            ],
+            ]
+        }
+        src_tgt_dataset_dict = {
+            'src_txt':
+            src_dataset_dict['src_txt'],
             'tgt_txt': [
                 'This is test sentence1-2', 'This is test sentence2-2',
                 'This is test sentence3-2'
             ]
         }
-        dataset = Dataset.from_dict(dataset_dict)
 
-        class MsDatasetDummy(MsDataset):
+        self.src_dataset = MsDataset(Dataset.from_dict(src_dataset_dict))
+        self.src_tgt_dataset = MsDataset(
+            Dataset.from_dict(src_tgt_dataset_dict))
 
-            def __len__(self):
-                return len(self._hf_ds)
-
-        self.dataset = MsDatasetDummy(dataset)
+        self.max_epochs = 3
 
     def tearDown(self):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_trainer(self):
+    def test_trainer_with_palm(self):
         kwargs = dict(
-            model=self.model_id,
-            train_dataset=self.dataset,
-            eval_dataset=self.dataset,
+            model='damo/nlp_palm2.0_text-generation_english-base',
+            train_dataset=self.src_tgt_dataset,
+            eval_dataset=self.src_tgt_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)
 
         trainer = build_trainer(
@@ -62,30 +63,67 @@ class TestTextGenerationTrainer(unittest.TestCase):
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(3):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_trainer_with_model_and_args(self):
-        tmp_dir = tempfile.TemporaryDirectory().name
-        if not os.path.exists(tmp_dir):
-            os.makedirs(tmp_dir)
+    def test_trainer_with_palm_with_model_and_args(self):
 
-        cache_path = snapshot_download(self.model_id)
+        cache_path = snapshot_download(
+            'damo/nlp_palm2.0_text-generation_english-base')
         model = PalmForTextGeneration.from_pretrained(cache_path)
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
             model=model,
-            train_dataset=self.dataset,
-            eval_dataset=self.dataset,
-            max_epochs=2,
+            train_dataset=self.src_tgt_dataset,
+            eval_dataset=self.src_tgt_dataset,
+            max_epochs=self.max_epochs,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(self.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_with_gpt3(self):
+
+        kwargs = dict(
+            model='damo/nlp_gpt3_text-generation_chinese-base',
+            train_dataset=self.src_dataset,
+            eval_dataset=self.src_dataset,
+            max_epochs=self.max_epochs,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(self.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_gpt3_with_model_and_args(self):
+
+        cache_path = snapshot_download(
+            'damo/nlp_gpt3_text-generation_chinese-base')
+        model = GPT3ForTextGeneration.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.src_dataset,
+            eval_dataset=self.src_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)
 
         trainer = build_trainer(default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(2):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skip
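
Illustrative sketch (not part of the patch; token ids, vocab size, and the random logits are made-up placeholders): it shows how the shifted-label scheme in TextGenerationPreprocessor lines up with the new loss computation added to GPT3Model.forward when no tgt_txt is provided.

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 100  # placeholder; the real value comes from the model config
# Tokenized source text, i.e. src_rst['input_ids'] from the preprocessor,
# here with an assumed batch dimension of 1.
src_input_ids = torch.tensor([[5, 17, 23, 42, 8, 2]])

# Without 'tgt_txt', labels are the inputs shifted left by one position,
# so the token at position i is trained to predict the token at i + 1.
labels = src_input_ids[:, 1:]
input_ids = src_input_ids[:, :-1]

# Stand-in for self.language_model(...): random logits of shape
# (batch, seq_len, vocab_size).
logits = torch.randn(input_ids.size(0), input_ids.size(1), vocab_size)

# Same reduction as the patched forward: flatten batch and sequence
# dimensions before the token-level cross entropy.
loss = CrossEntropyLoss()(logits.view(-1, vocab_size), labels.reshape(-1))
print(loss.item())

Because the one-token shift is done in the preprocessor rather than inside the model, the flattened logits and labels stay index-aligned, which is why forward can apply CrossEntropyLoss directly without shifting again.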