
[to #42322933] Fix finetune text ranking bugs

Previously, the finetune code raised an error when the last batch of the dataset was smaller than the specified batch size; this is now fixed (a minimal reproduction sketch follows the changed-file list below).
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10524066

Branch: master
Authors: zhangyanzhao.zyz, yingda.chen · 2 years ago
Parent commit: 781fe49d63
3 changed files with 17 additions and 15 deletions
  1. modelscope/models/nlp/bert/text_ranking.py  (+11 / -8)
  2. modelscope/msdatasets/task_datasets/text_ranking_dataset.py  (+1 / -2)
  3. tests/trainers/test_finetune_text_ranking.py  (+5 / -5)
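
For context, here is a minimal sketch of the failure mode and of why the new reshaping works. The tensor sizes below are illustrative, not taken from the repo: with a fixed train_batch_size, the view() call breaks on any trailing batch that is smaller, while grouping by neg_sample + 1 lets the batch size be inferred from the data.

import torch

neg_sample = 4                      # negatives per query
group = neg_sample + 1              # 1 positive + neg_sample negatives per query
train_batch_size = 4                # queries per full batch

# Trailing batch of the epoch: only 3 queries left instead of 4.
logits = torch.randn(3 * group, 1)

# Old behaviour: reshape with a fixed batch size -> RuntimeError on the short batch.
try:
    scores = logits.view(train_batch_size, -1)
except RuntimeError as err:
    print('old reshape fails:', err)

# New behaviour: the group size is known, so the batch size is inferred instead.
scores = logits.view(-1, group)                          # shape (3, 5)
target = torch.zeros(scores.size(0), dtype=torch.long)   # positive passage is index 0
loss = torch.nn.CrossEntropyLoss()(scores, target)
print(scores.shape, loss.item())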

modelscope/models/nlp/bert/text_ranking.py  (+11 / -8)

@@ -18,14 +18,12 @@ logger = logging.get_logger(__name__)
 @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert)
 class BertForTextRanking(BertForSequenceClassification):
 
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         super().__init__(config)
-        self.train_batch_size = kwargs.get('train_batch_size', 4)
+        neg_sample = kwargs.get('neg_sample', 8)
+        self.neg_sample = neg_sample
         setattr(self, self.base_model_prefix,
                 BertModel(self.config, add_pooling_layer=True))
-        self.register_buffer(
-            'target_label',
-            torch.zeros(self.train_batch_size, dtype=torch.long))
 
     def forward(self,
                 input_ids=None,
@@ -55,9 +53,12 @@ class BertForTextRanking(BertForSequenceClassification):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         if self.base_model.training:
-            scores = logits.view(self.train_batch_size, -1)
+            scores = logits.view(-1, self.neg_sample + 1)
+            batch_size = scores.size(0)
             loss_fct = torch.nn.CrossEntropyLoss()
-            loss = loss_fct(scores, self.target_label)
+            target_label = torch.zeros(
+                batch_size, dtype=torch.long, device=scores.device)
+            loss = loss_fct(scores, target_label)
         return AttentionTextClassificationModelOutput(
             loss=loss,
             logits=logits,
@@ -78,9 +79,11 @@ class BertForTextRanking(BertForSequenceClassification):
         Returns:
             The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
         """
-
         num_labels = kwargs.get('num_labels', 1)
+        neg_sample = kwargs.get('neg_sample', 4)
         model_args = {} if num_labels is None else {'num_labels': num_labels}
+        if neg_sample is not None:
+            model_args['neg_sample'] = neg_sample
 
         model_dir = kwargs.get('model_dir')
         model = super(Model, cls).from_pretrained(
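
Two effects of the change to forward() are worth spelling out: the loss no longer assumes a full batch, and the target labels are built per batch on the same device as the scores instead of living in a fixed-size registered buffer. A minimal sketch of the new target construction (shapes and device choice are illustrative):

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# scores: one row per query in the current batch, neg_sample + 1 columns.
scores = torch.randn(3, 5, device=device)

# A target of all zeros assumes the positive passage comes first in each
# group of neg_sample + 1 candidates; the labels are created per batch,
# on the same device as the scores, so no fixed-size buffer can go stale.
target_label = torch.zeros(scores.size(0), dtype=torch.long, device=scores.device)
loss = torch.nn.CrossEntropyLoss()(scores, target_label)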


modelscope/msdatasets/task_datasets/text_ranking_dataset.py  (+1 / -2)

@@ -39,8 +39,7 @@ class TextRankingDataset(TorchTaskDataset):
                                                     ['title', 'text'])
         self.qid_field = self.dataset_config.get('qid_field', 'query_id')
         if mode == ModeKeys.TRAIN:
-            train_config = kwargs.get('train', {})
-            self.neg_samples = train_config.get('neg_samples', 4)
+            self.neg_samples = self.dataset_config.get('neg_sample', 4)
 
         super().__init__(datasets, mode, preprocessor, **kwargs)
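
After this change, the number of sampled negatives comes from the same dataset_config dict as the other field names, under the key 'neg_sample', instead of from a separate train kwarg. A hedged sketch of the relevant config shape, limited to the fields visible in this diff and in the test below (the rest of the block is elided):

# Hypothetical dataset_config for the TRAIN split, mirroring the test file below.
dataset_config = {
    'pos_sequence': 'positive_passages',
    'neg_sequence': 'negative_passages',
    'text_fileds': ['title', 'text'],   # key spelled as in the repo
    'qid_field': 'query_id',
    'neg_sample': 4,                    # previously read from kwargs['train']['neg_samples']
}

# What TextRankingDataset now does in TRAIN mode:
neg_samples = dataset_config.get('neg_sample', 4)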



tests/trainers/test_finetune_text_ranking.py  (+5 / -5)

@@ -63,6 +63,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
     def test_finetune_msmarco(self):
 
         def cfg_modify_fn(cfg):
+            neg_sample = 4
             cfg.task = 'text-ranking'
             cfg['preprocessor'] = {'type': 'text-ranking'}
             cfg.train.optimizer.lr = 2e-5
@@ -73,7 +74,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'pos_sequence': 'positive_passages',
                     'neg_sequence': 'negative_passages',
                     'text_fileds': ['title', 'text'],
-                    'qid_field': 'query_id'
+                    'qid_field': 'query_id',
+                    'neg_sample': neg_sample
                 },
                 'val': {
                     'type': 'bert',
@@ -84,7 +86,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'qid_field': 'query_id'
                 },
             }
-            cfg['train']['neg_samples'] = 4
             cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
             cfg.train.max_epochs = 1
             cfg.train.train_batch_size = 4
@@ -96,6 +97,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'by_epoch': False
                 }
             }
+            cfg.model['neg_sample'] = 4
             cfg.train.hooks = [{
                 'type': 'CheckpointHook',
                 'interval': 1
@@ -151,7 +153,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'qid_field': 'query_id'
                 },
             }
-            cfg['train']['neg_samples'] = 4
             cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
             cfg.train.max_epochs = 1
             cfg.train.train_batch_size = 4
@@ -180,9 +181,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):

         # load dataset
         ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull')
-        train_ds = ds['train'].to_hf_dataset()
+        train_ds = ds['train'].to_hf_dataset().shard(1000, index=0)
         dev_ds = ds['dev'].to_hf_dataset()
 
         model_id = 'damo/nlp_rom_passage-ranking_chinese-base'
         self.finetune(
             model_id=model_id,
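
Note that the test now sets the same value (4) both on the dataset config ('neg_sample') and on cfg.model['neg_sample'], so the grouping produced by TextRankingDataset matches the reshaping done in BertForTextRanking.forward(). For reference, the data-loading half of the updated test, as it appears in the diff, with a comment on the shard() call:

from modelscope.msdatasets import MsDataset

# Load the DuReader retrieval-ranking dataset from the ModelScope hub.
ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull')

# shard(1000, index=0) keeps only 1 of 1000 shards (~0.1% of the training data),
# which keeps the finetune smoke test fast; evaluation still uses the full dev split.
train_ds = ds['train'].to_hf_dataset().shard(1000, index=0)
dev_ds = ds['dev'].to_hf_dataset()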

