The previous finetune code raised an error when the examples remaining at the end of the dataset did not fill the specified batch size; this is now fixed. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10524066 (master)
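For reviewers unfamiliar with the failure mode, here is a minimal standalone sketch (not part of this change) of the bug and the fix. It assumes each query is grouped with one positive passage followed by `neg_sample` negatives, which is why the target label is index 0; the batch sizes below are made up for illustration.

```python
import torch

neg_sample = 4                 # negatives per query (value used in the test below)
group = neg_sample + 1         # 1 positive + neg_sample negatives per query

# A full batch holds 4 queries; the last batch of the epoch holds only 3.
for num_queries in (4, 3):
    logits = torch.randn(num_queries * group, 1)   # one relevance score per passage

    # Old behaviour: logits.view(train_batch_size, -1) with train_batch_size=4
    # raises "shape '[4, -1]' is invalid for input of size 15" on the short batch,
    # and the pre-registered fixed-size target_label buffer no longer matches either.

    # New behaviour: infer the number of queries from the data itself.
    scores = logits.view(-1, neg_sample + 1)                    # (num_queries, group)
    target = torch.zeros(scores.size(0), dtype=torch.long,
                         device=scores.device)                  # positive is index 0
    loss = torch.nn.CrossEntropyLoss()(scores, target)
    print(num_queries, tuple(scores.shape), loss.item())
```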
@@ -18,14 +18,12 @@ logger = logging.get_logger(__name__)
 @MODELS.register_module(Tasks.text_ranking, module_name=Models.bert)
 class BertForTextRanking(BertForSequenceClassification):
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, *args, **kwargs):
         super().__init__(config)
-        self.train_batch_size = kwargs.get('train_batch_size', 4)
+        neg_sample = kwargs.get('neg_sample', 8)
+        self.neg_sample = neg_sample
         setattr(self, self.base_model_prefix,
                 BertModel(self.config, add_pooling_layer=True))
-        self.register_buffer(
-            'target_label',
-            torch.zeros(self.train_batch_size, dtype=torch.long))
     def forward(self,
                 input_ids=None,
@@ -55,9 +53,12 @@ class BertForTextRanking(BertForSequenceClassification):
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
         if self.base_model.training:
-            scores = logits.view(self.train_batch_size, -1)
+            scores = logits.view(-1, self.neg_sample + 1)
+            batch_size = scores.size(0)
             loss_fct = torch.nn.CrossEntropyLoss()
-            loss = loss_fct(scores, self.target_label)
+            target_label = torch.zeros(
+                batch_size, dtype=torch.long, device=scores.device)
+            loss = loss_fct(scores, target_label)
         return AttentionTextClassificationModelOutput(
             loss=loss,
             logits=logits,
@@ -78,9 +79,11 @@ class BertForTextRanking(BertForSequenceClassification):
         Returns:
             The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
         """
         num_labels = kwargs.get('num_labels', 1)
+        neg_sample = kwargs.get('neg_sample', 4)
         model_args = {} if num_labels is None else {'num_labels': num_labels}
+        if neg_sample is not None:
+            model_args['neg_sample'] = neg_sample
         model_dir = kwargs.get('model_dir')
         model = super(Model, cls).from_pretrained(
@@ -39,8 +39,7 @@ class TextRankingDataset(TorchTaskDataset):
                                                    ['title', 'text'])
         self.qid_field = self.dataset_config.get('qid_field', 'query_id')
         if mode == ModeKeys.TRAIN:
-            train_config = kwargs.get('train', {})
-            self.neg_samples = train_config.get('neg_samples', 4)
+            self.neg_samples = self.dataset_config.get('neg_sample', 4)
         super().__init__(datasets, mode, preprocessor, **kwargs)
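Note that with this change the number of negatives is no longer read from a separate `cfg['train']['neg_samples']` entry but from the train split's own dataset config, and it should match the `neg_sample` passed to the model so that `logits.view(-1, self.neg_sample + 1)` agrees with how the dataset groups passages. A sketch of the train dataset config as the test below sets it up (keys copied from that test, other settings omitted):

```python
# Train-split dataset config after this change; 'neg_sample' is the key
# TextRankingDataset now reads (defaulting to 4 when absent).
train_dataset_config = {
    'type': 'bert',
    'pos_sequence': 'positive_passages',
    'neg_sequence': 'negative_passages',
    'text_fileds': ['title', 'text'],   # key spelled as in the existing config
    'qid_field': 'query_id',
    'neg_sample': 4,
}
```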
@@ -63,6 +63,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
     def test_finetune_msmarco(self):
         def cfg_modify_fn(cfg):
+            neg_sample = 4
             cfg.task = 'text-ranking'
             cfg['preprocessor'] = {'type': 'text-ranking'}
             cfg.train.optimizer.lr = 2e-5
@@ -73,7 +74,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'pos_sequence': 'positive_passages',
                     'neg_sequence': 'negative_passages',
                     'text_fileds': ['title', 'text'],
-                    'qid_field': 'query_id'
+                    'qid_field': 'query_id',
+                    'neg_sample': neg_sample
                 },
                 'val': {
                     'type': 'bert',
@@ -84,7 +86,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'qid_field': 'query_id'
                 },
             }
-            cfg['train']['neg_samples'] = 4
             cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
             cfg.train.max_epochs = 1
             cfg.train.train_batch_size = 4
@@ -96,6 +97,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'by_epoch': False
                 }
             }
+            cfg.model['neg_sample'] = 4
             cfg.train.hooks = [{
                 'type': 'CheckpointHook',
                 'interval': 1
@@ -151,7 +153,6 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
                     'qid_field': 'query_id'
                 },
             }
-            cfg['train']['neg_samples'] = 4
             cfg['evaluation']['dataloader']['batch_size_per_gpu'] = 30
             cfg.train.max_epochs = 1
             cfg.train.train_batch_size = 4
@@ -180,9 +181,8 @@ class TestFinetuneSequenceClassification(unittest.TestCase):
         # load dataset
         ds = MsDataset.load('dureader-retrieval-ranking', 'zyznull')
-        train_ds = ds['train'].to_hf_dataset()
+        train_ds = ds['train'].to_hf_dataset().shard(1000, index=0)
         dev_ds = ds['dev'].to_hf_dataset()
         model_id = 'damo/nlp_rom_passage-ranking_chinese-base'
         self.finetune(
             model_id=model_id,
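One more note on the test change: `Dataset.shard(num_shards, index)` is a standard Hugging Face `datasets` method, so `.shard(1000, index=0)` keeps roughly 1/1000th of the training split and lets the finetune smoke test finish quickly. A toy illustration (the dataset here is made up):

```python
from datasets import Dataset

# Hypothetical toy dataset standing in for the dureader-retrieval-ranking train split.
ds = Dataset.from_dict({'query_id': list(range(10_000))})

subset = ds.shard(num_shards=1000, index=0)   # keep shard 0 of 1000
print(len(ds), len(subset))                   # 10000 10
```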