Browse Source

update rdrop

master
行嗔 2 years ago
parent
commit
cc8b78eac8
3 changed files with 5 additions and 6 deletions
  1. +1
    -1
      modelscope/trainers/multi_modal/ofa/ofa_trainer.py
  2. +1
    -1
      modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
  3. +3
    -4
      tests/trainers/test_ofa_trainer.py

+ 1
- 1
modelscope/trainers/multi_modal/ofa/ofa_trainer.py View File

@@ -131,7 +131,7 @@ class OFATrainer(EpochBasedTrainer):
model.train()
# model_outputs = model.forward(inputs)
loss, sample_size, logging_output = self.criterion(model, inputs)
train_outputs = {'loss': loss / 100}
train_outputs = {'loss': loss}
# add model output info to log
if 'log_vars' not in train_outputs:
default_keys_pattern = ['loss']


+ 1
- 1
modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py View File

@@ -144,7 +144,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
sample_size = (
sample['target'].size(0) if self.sentence_avg else ntokens)
logging_output = {
'loss': loss.data / 100,
'loss': loss.data,
'nll_loss': nll_loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],


+ 3
- 4
tests/trainers/test_ofa_trainer.py View File

@@ -1,7 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import os.path as osp
import shutil
import unittest

@@ -98,8 +96,9 @@ class TestOfaTrainer(unittest.TestCase):
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()

self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, 'output')))
self.assertIn(
ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
shutil.rmtree(WORKSPACE)




Loading…
Cancel
Save