```diff
@@ -1,6 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import os.path as osp
+import time
 from typing import Dict, Optional
 
 import tensorflow as tf
```
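The added `time` import supports the fixed completion log at the end of `train()` (last hunk of that method below); the old log line passed no arguments, so its `%s` placeholders were emitted literally. A minimal sketch of what the corrected line evaluates to, with `'csanmt_trainer'` standing in for `__name__`:

```python
import time

# time.asctime(time.localtime(time.time())) formats the current local time,
# e.g. 'Mon Jun  6 14:00:00 2022'. 'csanmt_trainer' is a stand-in for
# __name__ in the real log call.
print('%s: NMT training completed at time: %s.' %
      ('csanmt_trainer', time.asctime(time.localtime(time.time()))))
```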
```diff
@@ -122,8 +123,7 @@ class CsanmtTranslationTrainer(BaseTrainer):
         self.params['scale_l1'] = self.cfg['train']['scale_l1']
         self.params['scale_l2'] = self.cfg['train']['scale_l2']
         self.params['train_max_len'] = self.cfg['train']['train_max_len']
-        self.params['max_training_steps'] = self.cfg['train'][
-            'max_training_steps']
+        self.params['num_of_epochs'] = self.cfg['train']['num_of_epochs']
         self.params['save_checkpoints_steps'] = self.cfg['train'][
             'save_checkpoints_steps']
         self.params['num_of_samples'] = self.cfg['train']['num_of_samples']
```
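With this change, the `train` section of the model configuration needs a `num_of_epochs` entry in place of `max_training_steps`. A hypothetical excerpt, with key names taken from the reads above and purely illustrative values:

```python
# Hypothetical cfg['train'] excerpt; only the key names are grounded in the
# trainer code, the values are illustrative.
cfg_train = {
    'scale_l1': 0.0,
    'scale_l2': 0.0,
    'train_max_len': 128,
    'num_of_epochs': 10,  # new stop criterion, replacing 'max_training_steps'
    'save_checkpoints_steps': 1000,
    'num_of_samples': 4,
}
```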
```diff
@@ -144,14 +144,15 @@ class CsanmtTranslationTrainer(BaseTrainer):
         vocab_src = osp.join(self.model_dir, self.params['vocab_src'])
         vocab_trg = osp.join(self.model_dir, self.params['vocab_trg'])
 
+        epoch = 0
         iteration = 0
 
         with self._session.as_default() as tf_session:
             while True:
-                iteration += 1
-                if iteration >= self.params['max_training_steps']:
+                epoch += 1
+                if epoch >= self.params['num_of_epochs']:
                     break
-
+                tf.logging.info('%s: Epoch %i' % (__name__, epoch))
                 train_input_fn = input_fn(
                     train_src,
                     train_trg,
```
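The training driver now counts epochs rather than steps: every pass rebuilds the input pipeline through `input_fn` and logs the epoch number, while the step counter `iteration` keeps running across epochs for loss logging and checkpointing. A runnable pure-Python skeleton of the new control flow, with a stub in place of the TF pipeline:

```python
# Skeleton of the reworked loop. build_inputs() is a stub standing in for
# input_fn(), which in the real trainer re-initializes the dataset iterator.
num_of_epochs = 3  # illustrative


def build_inputs(epoch):
    return [('src_batch', 'trg_batch')] * 2  # two fake batches per pass


epoch, iteration = 0, 0
while True:
    epoch += 1
    # As written in the hunk, '>=' breaks before the final pass, so only
    # num_of_epochs - 1 epochs actually run.
    if epoch >= num_of_epochs:
        break
    print('Epoch %i' % epoch)
    for features, labels in build_inputs(epoch):
        iteration += 1  # the step counter carries across epochs
```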
```diff
@@ -160,36 +161,44 @@ class CsanmtTranslationTrainer(BaseTrainer):
                     batch_size_words=self.params['train_batch_size_words'],
                     max_len=self.params['train_max_len'],
                     num_gpus=self.params['num_gpus']
-                    if self.params['num_gpus'] > 0 else 1,
+                    if self.params['num_gpus'] > 1 else 1,
                     is_train=True,
                     session=tf_session,
-                    iteration=iteration)
+                    epoch=epoch)
 
                 features, labels = train_input_fn
 
-                features_batch, labels_batch = tf_session.run(
-                    [features, labels])
-
-                feed_dict = {
-                    self.source_wids: features_batch,
-                    self.target_wids: labels_batch
-                }
-                sess_outputs = self._session.run(
-                    self.output, feed_dict=feed_dict)
-                loss_step = sess_outputs['loss']
-                logger.info('Iteration: {}, step loss: {:.6f}'.format(
-                    iteration, loss_step))
-
-                if iteration % self.params['save_checkpoints_steps'] == 0:
-                    tf.logging.info('%s: Saving model on step: %d.' %
-                                    (__name__, iteration))
-                    ck_path = self.model_dir + 'model.ckpt'
-                    self.model_saver.save(
-                        tf_session,
-                        ck_path,
-                        global_step=tf.train.get_global_step())
-
-        tf.logging.info('%s: NMT training completed at time: %s.')
+                try:
+                    while True:
+                        features_batch, labels_batch = tf_session.run(
+                            [features, labels])
+                        iteration += 1
+                        feed_dict = {
+                            self.source_wids: features_batch,
+                            self.target_wids: labels_batch
+                        }
+                        sess_outputs = self._session.run(
+                            self.output, feed_dict=feed_dict)
+                        loss_step = sess_outputs['loss']
+                        logger.info('Iteration: {}, step loss: {:.6f}'.format(
+                            iteration, loss_step))
+
+                        if iteration % self.params[
+                                'save_checkpoints_steps'] == 0:
+                            tf.logging.info('%s: Saving model on step: %d.' %
+                                            (__name__, iteration))
+                            ck_path = self.model_dir + 'model.ckpt'
+                            self.model_saver.save(
+                                tf_session,
+                                ck_path,
+                                global_step=tf.train.get_global_step())
+
+                except tf.errors.OutOfRangeError:
+                    tf.logging.info('epoch %d end!' % (epoch))
+
+        tf.logging.info(
+            '%s: NMT training completed at time: %s.' %
+            (__name__, time.asctime(time.localtime(time.time()))))
 
     def evaluate(self,
                  checkpoint_path: Optional[str] = None,
```
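The new inner loop is the standard TF1 pattern for draining an initializable iterator: call `get_next()` under `Session.run` until the pipeline raises `tf.errors.OutOfRangeError`, which signals the end of one full pass over the training files. A minimal standalone sketch, written against `tf.compat.v1` so it also runs on TF2 installs:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

dataset = tf.compat.v1.data.Dataset.range(5).batch(2)
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
next_batch = iterator.get_next()

with tf.compat.v1.Session() as sess:
    for epoch in range(1, 3):
        sess.run(iterator.initializer)  # restart the pass for each epoch
        try:
            while True:
                batch = sess.run(next_batch)  # one training step per batch
        except tf.errors.OutOfRangeError:
            print('epoch %d end!' % epoch)  # mirrors the log in the hunk
```

One detail worth flagging in the hunk itself: `ck_path = self.model_dir + 'model.ckpt'` only lands inside `model_dir` if that string ends with a path separator; `osp.join(self.model_dir, 'model.ckpt')` would be the safer spelling.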
```diff
@@ -222,7 +231,7 @@ def input_fn(src_file,
              num_gpus=1,
              is_train=True,
              session=None,
-             iteration=None):
+             epoch=None):
     src_vocab = tf.lookup.StaticVocabularyTable(
         tf.lookup.TextFileInitializer(
             src_vocab_file,
```
```diff
@@ -291,7 +300,7 @@ def input_fn(src_file,
 
     if is_train:
         session.run(iterator.initializer)
-        if iteration == 1:
+        if epoch == 1:
             session.run(tf.tables_initializer())
     return features, labels
 
```
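The `iteration` to `epoch` rename is what keeps the `== 1` guard meaningful: `input_fn` is now called once per epoch, the iterator must be re-initialized on every call, but the vocabulary lookup tables are graph state and only need `tf.tables_initializer()` run once, on the first epoch. A self-contained sketch of the table mechanics; the vocab file and OOV bucket count here are fabricated for illustration:

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Fabricated vocab file: one token per line, id = line number.
with open('vocab.src.txt', 'w') as f:
    f.write('hello\nworld\n')

src_vocab = tf.lookup.StaticVocabularyTable(
    tf.lookup.TextFileInitializer(
        'vocab.src.txt',
        tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
        tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER),
    num_oov_buckets=1)
ids = src_vocab.lookup(tf.constant(['hello', 'unknown']))

with tf.compat.v1.Session() as sess:
    # The table is graph state: initializing it once (epoch == 1 in the
    # trainer) is enough, unlike the iterator, which is re-initialized
    # every epoch.
    sess.run(tf.compat.v1.tables_initializer())
    print(sess.run(ids))  # -> [0 2]: 'unknown' hashes into the OOV bucket
```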