[to #42322933] solve memory error for translation finetune

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10713843

    * [to #42322933] solve memory error for translation finetune
master · xiangpeng.wxp, yingda.chen · 2 years ago · parent commit d6ea41fb70
2 changed files with 50 additions and 35 deletions:
  1. modelscope/trainers/nlp/csanmt_translation_trainer.py (+41, -32)
  2. tests/trainers/test_translation_trainer.py (+9, -3)
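
Why this likely fixes the memory error: the old train() loop called input_fn() once per training step, and each call builds a brand-new tf.data pipeline (dataset ops, iterator, vocabulary tables) into the TensorFlow graph, so the graph grows without bound until the process exhausts memory. The new loop builds one pipeline per epoch, re-runs the iterator initializer, and drains batches until tf.errors.OutOfRangeError marks the end of the epoch; the pipeline is still rebuilt once per epoch, but that happens a handful of times rather than once per step. A minimal sketch of that drain pattern (toy dataset and TF1-style session; the names below are illustrative, not the trainer's real ones):

    import tensorflow as tf  # TF1-style API, matching the trainer's tf.logging usage

    # Toy stand-in for the pipeline that input_fn builds from the corpus files.
    dataset = tf.data.Dataset.from_tensor_slices(list(range(10))).batch(4)
    iterator = dataset.make_initializable_iterator()
    next_batch = iterator.get_next()  # graph ops are created once, up front

    num_of_epochs = 3  # stands in for cfg['train']['num_of_epochs']
    with tf.Session() as sess:
        for epoch in range(1, num_of_epochs + 1):
            sess.run(iterator.initializer)  # rewind the data; adds no new graph nodes
            try:
                while True:  # drain every batch belonging to this epoch
                    batch = sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                print('epoch %d end!' % epoch)  # same signal the trainer logs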

modelscope/trainers/nlp/csanmt_translation_trainer.py (+41, -32)

@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import os.path as osp
+import time
 from typing import Dict, Optional
 
 import tensorflow as tf
@@ -122,8 +123,7 @@ class CsanmtTranslationTrainer(BaseTrainer):
         self.params['scale_l1'] = self.cfg['train']['scale_l1']
         self.params['scale_l2'] = self.cfg['train']['scale_l2']
         self.params['train_max_len'] = self.cfg['train']['train_max_len']
-        self.params['max_training_steps'] = self.cfg['train'][
-            'max_training_steps']
+        self.params['num_of_epochs'] = self.cfg['train']['num_of_epochs']
         self.params['save_checkpoints_steps'] = self.cfg['train'][
             'save_checkpoints_steps']
         self.params['num_of_samples'] = self.cfg['train']['num_of_samples']
@@ -144,14 +144,15 @@ class CsanmtTranslationTrainer(BaseTrainer):
         vocab_src = osp.join(self.model_dir, self.params['vocab_src'])
         vocab_trg = osp.join(self.model_dir, self.params['vocab_trg'])
 
         epoch = 0
         iteration = 0
 
         with self._session.as_default() as tf_session:
             while True:
-                iteration += 1
-                if iteration >= self.params['max_training_steps']:
+                epoch += 1
+                if epoch >= self.params['num_of_epochs']:
                     break
+                tf.logging.info('%s: Epoch %i' % (__name__, epoch))
                 train_input_fn = input_fn(
                     train_src,
                     train_trg,
@@ -160,36 +161,44 @@ class CsanmtTranslationTrainer(BaseTrainer):
                     batch_size_words=self.params['train_batch_size_words'],
                     max_len=self.params['train_max_len'],
                     num_gpus=self.params['num_gpus']
-                    if self.params['num_gpus'] > 0 else 1,
+                    if self.params['num_gpus'] > 1 else 1,
                     is_train=True,
                     session=tf_session,
-                    iteration=iteration)
+                    epoch=epoch)
 
                 features, labels = train_input_fn
 
-                features_batch, labels_batch = tf_session.run(
-                    [features, labels])
-
-                feed_dict = {
-                    self.source_wids: features_batch,
-                    self.target_wids: labels_batch
-                }
-                sess_outputs = self._session.run(
-                    self.output, feed_dict=feed_dict)
-                loss_step = sess_outputs['loss']
-                logger.info('Iteration: {}, step loss: {:.6f}'.format(
-                    iteration, loss_step))
-
-                if iteration % self.params['save_checkpoints_steps'] == 0:
-                    tf.logging.info('%s: Saving model on step: %d.' %
-                                    (__name__, iteration))
-                    ck_path = self.model_dir + 'model.ckpt'
-                    self.model_saver.save(
-                        tf_session,
-                        ck_path,
-                        global_step=tf.train.get_global_step())
-
-        tf.logging.info('%s: NMT training completed at time: %s.')
+                try:
+                    while True:
+                        features_batch, labels_batch = tf_session.run(
+                            [features, labels])
+                        iteration += 1
+                        feed_dict = {
+                            self.source_wids: features_batch,
+                            self.target_wids: labels_batch
+                        }
+                        sess_outputs = self._session.run(
+                            self.output, feed_dict=feed_dict)
+                        loss_step = sess_outputs['loss']
+                        logger.info('Iteration: {}, step loss: {:.6f}'.format(
+                            iteration, loss_step))
+
+                        if iteration % self.params[
+                                'save_checkpoints_steps'] == 0:
+                            tf.logging.info('%s: Saving model on step: %d.' %
+                                            (__name__, iteration))
+                            ck_path = self.model_dir + 'model.ckpt'
+                            self.model_saver.save(
+                                tf_session,
+                                ck_path,
+                                global_step=tf.train.get_global_step())
+
+                except tf.errors.OutOfRangeError:
+                    tf.logging.info('epoch %d end!' % (epoch))
+
+        tf.logging.info(
+            '%s: NMT training completed at time: %s.' %
+            (__name__, time.asctime(time.localtime(time.time()))))
 
     def evaluate(self,
                  checkpoint_path: Optional[str] = None,
@@ -222,7 +231,7 @@ def input_fn(src_file,
              num_gpus=1,
              is_train=True,
              session=None,
-             iteration=None):
+             epoch=None):
     src_vocab = tf.lookup.StaticVocabularyTable(
         tf.lookup.TextFileInitializer(
            src_vocab_file,
@@ -291,7 +300,7 @@ def input_fn(src_file,
 
     if is_train:
         session.run(iterator.initializer)
-        if iteration == 1:
+        if epoch == 1:
             session.run(tf.tables_initializer())
     return features, labels
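
Within the drain loop the trainer checkpoints every save_checkpoints_steps iterations via self.model_saver. A minimal sketch of that periodic-save pattern, assuming model_saver is a tf.train.Saver (the toy variable and path below are illustrative):

    import os
    import tensorflow as tf

    w = tf.Variable(0.0, name='w')    # toy weight so the Saver has state to save
    train_op = tf.assign_add(w, 1.0)  # stand-in for one optimization step
    saver = tf.train.Saver(max_to_keep=5)

    save_checkpoints_steps = 2  # stands in for cfg['train']['save_checkpoints_steps']
    os.makedirs('/tmp/ckpt', exist_ok=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for iteration in range(1, 7):
            sess.run(train_op)
            if iteration % save_checkpoints_steps == 0:
                # global_step suffixes the files, e.g. model.ckpt-2.index
                saver.save(sess, '/tmp/ckpt/model.ckpt', global_step=iteration)

One detail worth noting: ck_path = self.model_dir + 'model.ckpt' in the diff only lands inside model_dir if that string ends with a path separator; osp.join(self.model_dir, 'model.ckpt') would be the safer spelling.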


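The final hunk guards session.run(tf.tables_initializer()) with epoch == 1 because the vocabulary lookup tables only need initializing once, even though the dataset iterator is re-initialized at every epoch. A rough sketch of how those lookup tables behave (the vocab file here is hypothetical; the real paths come from the model configuration):

    import tensorflow as tf

    # Hypothetical vocab file: one token per line; the line number becomes the id.
    with open('/tmp/vocab.src', 'w') as f:
        f.write('hello\nworld\n')

    src_vocab = tf.lookup.StaticVocabularyTable(
        tf.lookup.TextFileInitializer(
            '/tmp/vocab.src',
            tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
            tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER),
        num_oov_buckets=1)  # out-of-vocabulary tokens hash into one extra bucket

    ids = src_vocab.lookup(tf.constant(['hello', 'mystery']))
    with tf.Session() as sess:
        sess.run(tf.tables_initializer())  # must run once before any lookup
        print(sess.run(ids))               # [0 2]: 'mystery' falls into the OOV bucket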

tests/trainers/test_translation_trainer.py (+9, -3)

@@ -6,11 +6,17 @@ from modelscope.utils.test_utils import test_level
 
 
 class TranslationTest(unittest.TestCase):
-    model_id = 'damo/nlp_csanmt_translation_zh2en'
 
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_run_with_model_name(self):
-        trainer = CsanmtTranslationTrainer(model=self.model_id)
+    def test_run_with_model_name_for_en2zh(self):
+        model_id = 'damo/nlp_csanmt_translation_en2zh'
+        trainer = CsanmtTranslationTrainer(model=model_id)
         trainer.train()
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_name_for_en2fr(self):
+        model_id = 'damo/nlp_csanmt_translation_en2fr'
+        trainer = CsanmtTranslationTrainer(model=model_id)
+        trainer.train()
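
If the repository's tests directories are importable as packages, either finetune smoke test can be run on its own with the standard unittest runner, e.g. `python -m unittest tests.trainers.test_translation_trainer.TranslationTest.test_run_with_model_name_for_en2zh` from the repo root; both tests presumably fetch their model from the ModelScope hub before training.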
