diff --git a/modelscope/trainers/nlp/csanmt_translation_trainer.py b/modelscope/trainers/nlp/csanmt_translation_trainer.py
index c93599c7..08a3a351 100644
--- a/modelscope/trainers/nlp/csanmt_translation_trainer.py
+++ b/modelscope/trainers/nlp/csanmt_translation_trainer.py
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import os.path as osp
+import time
 from typing import Dict, Optional
 
 import tensorflow as tf
@@ -122,8 +123,7 @@ class CsanmtTranslationTrainer(BaseTrainer):
         self.params['scale_l1'] = self.cfg['train']['scale_l1']
         self.params['scale_l2'] = self.cfg['train']['scale_l2']
         self.params['train_max_len'] = self.cfg['train']['train_max_len']
-        self.params['max_training_steps'] = self.cfg['train'][
-            'max_training_steps']
+        self.params['num_of_epochs'] = self.cfg['train']['num_of_epochs']
         self.params['save_checkpoints_steps'] = self.cfg['train'][
             'save_checkpoints_steps']
         self.params['num_of_samples'] = self.cfg['train']['num_of_samples']
@@ -144,14 +144,15 @@ class CsanmtTranslationTrainer(BaseTrainer):
         vocab_src = osp.join(self.model_dir, self.params['vocab_src'])
         vocab_trg = osp.join(self.model_dir, self.params['vocab_trg'])
 
+        epoch = 0
         iteration = 0
 
         with self._session.as_default() as tf_session:
             while True:
-                iteration += 1
-                if iteration >= self.params['max_training_steps']:
+                epoch += 1
+                if epoch >= self.params['num_of_epochs']:
                     break
-
+                tf.logging.info('%s: Epoch %i' % (__name__, epoch))
                 train_input_fn = input_fn(
                     train_src,
                     train_trg,
@@ -160,36 +161,44 @@ class CsanmtTranslationTrainer(BaseTrainer):
                     batch_size_words=self.params['train_batch_size_words'],
                     max_len=self.params['train_max_len'],
                     num_gpus=self.params['num_gpus']
-                    if self.params['num_gpus'] > 0 else 1,
+                    if self.params['num_gpus'] > 1 else 1,
                     is_train=True,
                     session=tf_session,
-                    iteration=iteration)
+                    epoch=epoch)
 
                 features, labels = train_input_fn
 
-                features_batch, labels_batch = tf_session.run(
-                    [features, labels])
-
-                feed_dict = {
-                    self.source_wids: features_batch,
-                    self.target_wids: labels_batch
-                }
-                sess_outputs = self._session.run(
-                    self.output, feed_dict=feed_dict)
-                loss_step = sess_outputs['loss']
-                logger.info('Iteration: {}, step loss: {:.6f}'.format(
-                    iteration, loss_step))
-
-                if iteration % self.params['save_checkpoints_steps'] == 0:
-                    tf.logging.info('%s: Saving model on step: %d.' %
-                                    (__name__, iteration))
-                    ck_path = self.model_dir + 'model.ckpt'
-                    self.model_saver.save(
-                        tf_session,
-                        ck_path,
-                        global_step=tf.train.get_global_step())
-
-            tf.logging.info('%s: NMT training completed at time: %s.')
+                try:
+                    while True:
+                        features_batch, labels_batch = tf_session.run(
+                            [features, labels])
+                        iteration += 1
+                        feed_dict = {
+                            self.source_wids: features_batch,
+                            self.target_wids: labels_batch
+                        }
+                        sess_outputs = self._session.run(
+                            self.output, feed_dict=feed_dict)
+                        loss_step = sess_outputs['loss']
+                        logger.info('Iteration: {}, step loss: {:.6f}'.format(
+                            iteration, loss_step))
+
+                        if iteration % self.params[
+                                'save_checkpoints_steps'] == 0:
+                            tf.logging.info('%s: Saving model on step: %d.' %
+                                            (__name__, iteration))
+                            ck_path = self.model_dir + 'model.ckpt'
+                            self.model_saver.save(
+                                tf_session,
+                                ck_path,
+                                global_step=tf.train.get_global_step())
+
+                except tf.errors.OutOfRangeError:
+                    tf.logging.info('epoch %d end!' % (epoch))
+
+            tf.logging.info(
+                '%s: NMT training completed at time: %s.' %
+                (__name__, time.asctime(time.localtime(time.time()))))
 
     def evaluate(self,
                  checkpoint_path: Optional[str] = None,
@@ -222,7 +231,7 @@ def input_fn(src_file,
              num_gpus=1,
              is_train=True,
              session=None,
-             iteration=None):
+             epoch=None):
     src_vocab = tf.lookup.StaticVocabularyTable(
         tf.lookup.TextFileInitializer(
             src_vocab_file,
@@ -291,7 +300,7 @@ def input_fn(src_file,
 
     if is_train:
         session.run(iterator.initializer)
-        if iteration == 1:
+        if epoch == 1:
             session.run(tf.tables_initializer())
 
     return features, labels
diff --git a/tests/trainers/test_translation_trainer.py b/tests/trainers/test_translation_trainer.py
index 71bed241..7be23145 100644
--- a/tests/trainers/test_translation_trainer.py
+++ b/tests/trainers/test_translation_trainer.py
@@ -6,11 +6,17 @@ from modelscope.utils.test_utils import test_level
 
 
 class TranslationTest(unittest.TestCase):
-    model_id = 'damo/nlp_csanmt_translation_zh2en'
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_model_name(self):
-        trainer = CsanmtTranslationTrainer(model=self.model_id)
+    def test_run_with_model_name_for_en2zh(self):
+        model_id = 'damo/nlp_csanmt_translation_en2zh'
+        trainer = CsanmtTranslationTrainer(model=model_id)
+        trainer.train()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_en2fr(self):
+        model_id = 'damo/nlp_csanmt_translation_en2fr'
+        trainer = CsanmtTranslationTrainer(model=model_id)
         trainer.train()
 
 