
[to #42322933] add gpt3 base finetune

Add finetune code for small and medium GPT-3 models on a single machine with a single GPU.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9899004
master
hemu.zp · yingda.chen · 3 years ago
commit 20a935d406
5 changed files with 101 additions and 49 deletions
  1. +14 -8   modelscope/models/nlp/gpt3/modeling_gpt3.py
  2. +6  -4   modelscope/models/nlp/palm_v2/modeling_palm.py
  3. +4  -3   modelscope/models/nlp/palm_v2/palm_for_text_generation.py
  4. +12 -7   modelscope/preprocessors/nlp.py
  5. +65 -27  tests/trainers/test_finetune_text_generation.py
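
For reference, a minimal sketch of the single-machine, single-GPU finetune flow this commit enables, pieced together from the new GPT-3 test case below; the toy dataset and work_dir are illustrative placeholders, not part of the commit:

# Minimal finetune sketch mirroring tests/trainers/test_finetune_text_generation.py;
# the one-sentence dataset and work_dir are placeholders.
from datasets import Dataset

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer

train_dataset = MsDataset(
    Dataset.from_dict({'src_txt': ['This is a training sentence.']}))

kwargs = dict(
    model='damo/nlp_gpt3_text-generation_chinese-base',
    train_dataset=train_dataset,
    eval_dataset=train_dataset,
    max_epochs=3,
    work_dir='./gpt3_finetune_work_dir')

trainer = build_trainer(name=Trainers.nlp_base_trainer, default_args=kwargs)
trainer.train()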

+14 -8  modelscope/models/nlp/gpt3/modeling_gpt3.py

@@ -16,9 +16,10 @@ import math
 import os
 from typing import Optional, Union
 
+import addict
 import torch
-from addict import Dict
-from torch.nn import Dropout, Embedding, LayerNorm, Linear, Module, Softmax
+from torch.nn import (CrossEntropyLoss, Dropout, Embedding, LayerNorm, Linear,
+                      Module, Softmax)
 from torch.nn import functional as F
 from transformers.modeling_utils import PreTrainedModel

@@ -308,20 +309,25 @@ class GPT3Model(PreTrainedModel):
                 input_ids,
                 attention_mask=None,
                 position_ids=None,
+                labels=None,
                 **kwargs):
         seq_length = input_ids.size(1)
         if attention_mask is None:
-            attention_mask = torch.tril(
-                torch.ones((1, seq_length, seq_length),
-                           dtype=torch.long,
-                           device=input_ids.device))
+            attention_mask = torch.tril(
+                torch.ones((1, 1, seq_length, seq_length),
+                           dtype=torch.long,
+                           device=input_ids.device))
         if position_ids is None:
             position_ids = torch.arange(
                 seq_length, dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
 
         logits = self.language_model(input_ids, attention_mask, position_ids)
-        return Dict(logits=logits)
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(
+                logits.view(-1, self.config.vocab_size), labels.view(-1))
+        return addict.Dict(loss=loss, logits=logits)
 
     @classmethod
     def from_pretrained(
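
The reworked forward above only computes a loss when labels are supplied; otherwise loss stays None and only logits are returned. A toy illustration of the cross-entropy step, with a made-up vocab_size and shapes (not read from any config):

# Toy illustration of the loss term added to GPT3Model.forward.
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 1000                           # assumed value for the example
logits = torch.randn(2, 8, vocab_size)      # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 8))

loss = CrossEntropyLoss()(logits.view(-1, vocab_size), labels.view(-1))
print(loss.item())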


+6 -4  modelscope/models/nlp/palm_v2/modeling_palm.py

@@ -6,6 +6,7 @@ import subprocess
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
+import addict
 import json
 import numpy as np
 import torch
@@ -726,10 +727,11 @@ class PalmForConditionalGeneration(PalmPreTrainedModel):
             self.palm.vocab_size,
             config.label_smoothing)
 
-    def forward(self, src, tgt, mask_src):
-        output = self.palm(src, tgt, mask_src)[0]
-        loss = self.loss(tgt, output)
-        return loss
+    def forward(self, input_ids, attention_mask, labels):
+        output = self.palm(
+            src=input_ids, tgt=labels, mask_src=attention_mask)[0]
+        loss = self.loss(labels, output)
+        return addict.Dict(loss=loss)
 
 
 class Translator(nn.Module):
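
Both forward paths now wrap their outputs in addict.Dict rather than returning a bare tensor or plain dict. A one-liner showing the behavior that matters here: it acts like a dict but also supports attribute access, so downstream code can read output['loss'] or output.loss interchangeably.

from addict import Dict

output = Dict(loss=0.42)
print(output['loss'], output.loss)  # 0.42 0.42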


+4 -3  modelscope/models/nlp/palm_v2/palm_for_text_generation.py

@@ -63,14 +63,15 @@ class PalmForTextGeneration(TorchModel):
             }
         """
         if self.training:
-            return {'loss': self.model(**input)}
+            return self.model(**input)
         else:
-            outputs = self.generator(input['src'], input['mask_src'])
+            outputs = self.generator(input['input_ids'],
+                                     input['attention_mask'])
             preds = outputs['predictions']
             pred_ids_list = [
                 pred_batch[0].cpu().numpy().tolist() for pred_batch in preds
            ]
-            tgt_ids_list = input['tgt'].cpu().numpy().tolist()
+            tgt_ids_list = input['labels'].cpu().numpy().tolist()
             return {
                 'preds': self._evaluate_postprocess(pred_ids_list),
                 'tgts': self._evaluate_postprocess(tgt_ids_list)
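
A toy sketch of the same train/eval dispatch this wrapper follows (not ModelScope code; backbone, generate and decode_fn are assumed names): training passes the backbone's loss dict through unchanged, evaluation decodes ids into preds/tgts for metric computation.

import torch


class ToyGenerationWrapper(torch.nn.Module):

    def __init__(self, backbone, decode_fn):
        super().__init__()
        self.backbone = backbone    # assumed: returns a dict with 'loss' during training
        self.decode_fn = decode_fn  # assumed: maps token ids back to text

    def forward(self, inputs):
        if self.training:
            return self.backbone(**inputs)
        pred_ids = self.backbone.generate(inputs['input_ids'])
        return {
            'preds': [self.decode_fn(ids) for ids in pred_ids],
            'tgts': [self.decode_fn(ids) for ids in inputs['labels']],
        }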


+12 -7  modelscope/preprocessors/nlp.py

@@ -368,15 +368,20 @@ class TextGenerationPreprocessor(NLPTokenizerPreprocessorBase):
     def __call__(self, data: Union[Dict, str]) -> Dict[str, Any]:
         if self._mode == ModeKeys.INFERENCE:
             return super().__call__(data)
-        src_txt = data['src_txt']
-        tgt_txt = data['tgt_txt']
-        src_rst = super().__call__(src_txt)
-        tgt_rst = super().__call__(tgt_txt)
+        src_rst = super().__call__(data['src_txt'])
+        src_input_ids = src_rst['input_ids']
+        src_attention_mask = src_rst['attention_mask']
+        if 'tgt_txt' in data:
+            labels = super().__call__(data['tgt_txt'])['input_ids']
+        else:
+            labels = src_input_ids[1:]
+            src_input_ids = src_input_ids[:-1]
+            src_attention_mask = src_attention_mask[:-1]
 
         return {
-            'src': src_rst['input_ids'],
-            'tgt': tgt_rst['input_ids'],
-            'mask_src': src_rst['attention_mask']
+            'input_ids': src_input_ids,
+            'attention_mask': src_attention_mask,
+            'labels': labels,
         }
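
A toy walk-through of the new training-mode output: when tgt_txt is present, its token ids become the labels; with src_txt alone, labels are the input ids shifted by one position, as in standard causal-LM finetuning. The token ids below are made up:

# Made-up token ids illustrating the src-only branch of the preprocessor.
input_ids = [101, 17, 42, 99, 102]
attention_mask = [1, 1, 1, 1, 1]

labels = input_ids[1:]                # [17, 42, 99, 102]
input_ids = input_ids[:-1]            # [101, 17, 42, 99]
attention_mask = attention_mask[:-1]  # [1, 1, 1, 1]

print({'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels})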




+65 -27  tests/trainers/test_text_generation_trainer.py → tests/trainers/test_finetune_text_generation.py

@@ -6,14 +6,14 @@ import unittest
 
 from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import Trainers
-from modelscope.models.nlp.palm_v2 import PalmForTextGeneration
+from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import ModelFile
 from modelscope.utils.test_utils import test_level
 
 
-class TestTextGenerationTrainer(unittest.TestCase):
+class TestFinetuneTextGeneration(unittest.TestCase):
 
     def setUp(self):
         print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
@@ -21,40 +21,41 @@ class TestTextGenerationTrainer(unittest.TestCase):
         if not os.path.exists(self.tmp_dir):
             os.makedirs(self.tmp_dir)
-
-        self.model_id = 'damo/nlp_palm2.0_text-generation_english-base'
-
         # todo: Replace below scripts with MsDataset.load when the formal dataset service is ready
         from datasets import Dataset
-        dataset_dict = {
+
+        src_dataset_dict = {
             'src_txt': [
                 'This is test sentence1-1', 'This is test sentence2-1',
                 'This is test sentence3-1'
-            ],
+            ]
+        }
+        src_tgt_dataset_dict = {
+            'src_txt':
+            src_dataset_dict['src_txt'],
             'tgt_txt': [
                 'This is test sentence1-2', 'This is test sentence2-2',
                 'This is test sentence3-2'
             ]
         }
-        dataset = Dataset.from_dict(dataset_dict)
 
-        class MsDatasetDummy(MsDataset):
+        self.src_dataset = MsDataset(Dataset.from_dict(src_dataset_dict))
+        self.src_tgt_dataset = MsDataset(
+            Dataset.from_dict(src_tgt_dataset_dict))
 
-            def __len__(self):
-                return len(self._hf_ds)
-
-        self.dataset = MsDatasetDummy(dataset)
+        self.max_epochs = 3
 
     def tearDown(self):
         shutil.rmtree(self.tmp_dir)
         super().tearDown()
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_trainer(self):
+    def test_trainer_with_palm(self):
 
         kwargs = dict(
-            model=self.model_id,
-            train_dataset=self.dataset,
-            eval_dataset=self.dataset,
+            model='damo/nlp_palm2.0_text-generation_english-base',
+            train_dataset=self.src_tgt_dataset,
+            eval_dataset=self.src_tgt_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)
 
         trainer = build_trainer(
@@ -62,30 +63,67 @@ class TestTextGenerationTrainer(unittest.TestCase):
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(3):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    def test_trainer_with_model_and_args(self):
-        tmp_dir = tempfile.TemporaryDirectory().name
-        if not os.path.exists(tmp_dir):
-            os.makedirs(tmp_dir)
+    def test_trainer_with_palm_with_model_and_args(self):
 
-        cache_path = snapshot_download(self.model_id)
+        cache_path = snapshot_download(
+            'damo/nlp_palm2.0_text-generation_english-base')
         model = PalmForTextGeneration.from_pretrained(cache_path)
         kwargs = dict(
             cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
             model=model,
-            train_dataset=self.dataset,
-            eval_dataset=self.dataset,
-            max_epochs=2,
+            train_dataset=self.src_tgt_dataset,
+            eval_dataset=self.src_tgt_dataset,
+            max_epochs=self.max_epochs,
             work_dir=self.tmp_dir)
 
         trainer = build_trainer(default_args=kwargs)
         trainer.train()
         results_files = os.listdir(self.tmp_dir)
         self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(self.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_trainer_with_gpt3(self):
+
+        kwargs = dict(
+            model='damo/nlp_gpt3_text-generation_chinese-base',
+            train_dataset=self.src_dataset,
+            eval_dataset=self.src_dataset,
+            max_epochs=self.max_epochs,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(
+            name=Trainers.nlp_base_trainer, default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
+        for i in range(self.max_epochs):
+            self.assertIn(f'epoch_{i+1}.pth', results_files)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_trainer_with_gpt3_with_model_and_args(self):
+
+        cache_path = snapshot_download(
+            'damo/nlp_gpt3_text-generation_chinese-base')
+        model = GPT3ForTextGeneration.from_pretrained(cache_path)
+        kwargs = dict(
+            cfg_file=os.path.join(cache_path, ModelFile.CONFIGURATION),
+            model=model,
+            train_dataset=self.src_dataset,
+            eval_dataset=self.src_dataset,
+            max_epochs=self.max_epochs,
+            work_dir=self.tmp_dir)
+
+        trainer = build_trainer(default_args=kwargs)
+        trainer.train()
+        results_files = os.listdir(self.tmp_dir)
+        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
-        for i in range(2):
+        for i in range(self.max_epochs):
             self.assertIn(f'epoch_{i+1}.pth', results_files)
 
     @unittest.skip
