
create the latest nlp_translation_finetune

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9662707
master
xiangpeng.wxp, yingda.chen · 3 years ago
parent commit d36002d6b9
4 changed files with 11 additions and 61 deletions:
  1. +1  -1   modelscope/models/nlp/csanmt_for_translation.py
  2. +8  -56  modelscope/pipelines/nlp/translation_pipeline.py
  3. +1  -3   modelscope/trainers/nlp/csanmt_translation_trainer.py
  4. +1  -1   tests/pipelines/test_csanmt_translation.py

+1 -1  modelscope/models/nlp/csanmt_for_translation.py

@@ -21,7 +21,7 @@ class CsanmtForTranslation(Model):
             params (dict): the model configuration.
         """
         super().__init__(model_dir, *args, **kwargs)
-        self.params = kwargs['params']
+        self.params = kwargs

     def __call__(self,
                  input: Dict[str, Tensor],
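
With this one-liner the constructor keeps every remaining keyword argument as the model configuration, instead of expecting a single nested params dict. A reduced, self-contained sketch of the new contract (class body trimmed to the relevant line; the hyperparameter values are made up for illustration):

    # Reduced sketch of the changed constructor; not the full model class.
    class CsanmtForTranslationSketch:
        def __init__(self, model_dir, *args, **kwargs):
            self.params = kwargs  # was: self.params = kwargs['params']

    params = {'hidden_size': 512, 'num_heads': 8}  # illustrative values
    model = CsanmtForTranslationSketch('/path/to/model', **params)
    assert model.params == params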


+8 -56  modelscope/pipelines/nlp/translation_pipeline.py

@@ -1,5 +1,4 @@
 import os.path as osp
-from threading import Lock
 from typing import Any, Dict

 import numpy as np
@@ -10,7 +9,7 @@ from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Pipeline
 from modelscope.pipelines.builder import PIPELINES
 from modelscope.utils.config import Config
-from modelscope.utils.constant import Frameworks, ModelFile, Tasks
+from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.logger import get_logger

 if tf.__version__ >= '2.0':
@@ -27,25 +26,22 @@ __all__ = ['TranslationPipeline']
 class TranslationPipeline(Pipeline):

     def __init__(self, model: str, **kwargs):
-        tf.reset_default_graph()
-        self.framework = Frameworks.tf
-        self.device_name = 'cpu'
-
         super().__init__(model=model)
         model = self.model.model_dir
+        tf.reset_default_graph()
+
         model_path = osp.join(
             osp.join(model, ModelFile.TF_CHECKPOINT_FOLDER), 'ckpt-0')

         self.cfg = Config.from_file(osp.join(model, ModelFile.CONFIGURATION))

-        self.params = {}
-        self._override_params_from_file()
-
-        self._src_vocab_path = osp.join(model, self.params['vocab_src'])
+        self._src_vocab_path = osp.join(
+            model, self.cfg['dataset']['src_vocab']['file'])
         self._src_vocab = dict([
             (w.strip(), i) for i, w in enumerate(open(self._src_vocab_path))
         ])
-        self._trg_vocab_path = osp.join(model, self.params['vocab_trg'])
+        self._trg_vocab_path = osp.join(
+            model, self.cfg['dataset']['trg_vocab']['file'])
         self._trg_rvocab = dict([
             (i, w.strip()) for i, w in enumerate(open(self._trg_vocab_path))
         ])
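
The vocabulary file names are now read straight from the model's configuration.json instead of the intermediate params dict. A minimal sketch of the lookup this relies on, using an illustrative configuration fragment (the file names are placeholders, not the real ones):

    import os.path as osp

    # Illustrative configuration fragment; the key layout matches the diff.
    cfg = {
        'dataset': {
            'src_vocab': {'file': 'src_vocab.txt'},
            'trg_vocab': {'file': 'trg_vocab.txt'},
        }
    }

    model_dir = '/path/to/model'  # placeholder
    src_vocab_path = osp.join(model_dir, cfg['dataset']['src_vocab']['file'])
    trg_vocab_path = osp.join(model_dir, cfg['dataset']['trg_vocab']['file'])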
@@ -59,7 +55,6 @@ class TranslationPipeline(Pipeline):
         self.output = {}

         # model
-        self.model = CsanmtForTranslation(model_path, params=self.params)
         output = self.model(self.input_wids)
         self.output.update(output)
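
The hand-built CsanmtForTranslation instance is dropped because super().__init__(model=model) already leaves a loaded model in self.model, so constructing it again here would have built the translation graph twice. Roughly, the contract the pipeline now relies on looks like this (a sketch, not the actual base-class source; load_model stands in for whatever hub/local resolution the real Pipeline performs):

    def load_model(model_id_or_dir):
        """Placeholder for the model resolution done by the real base class."""
        ...

    class PipelineSketch:
        def __init__(self, model):
            # A string is treated as a model id or directory and loaded;
            # an already-constructed model object is used as-is.
            self.model = load_model(model) if isinstance(model, str) else model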

@@ -69,53 +64,10 @@
             model_loader = tf.train.Saver(tf.global_variables())
             model_loader.restore(sess, model_path)

-    def _override_params_from_file(self):
-
-        # model
-        self.params['hidden_size'] = self.cfg['model']['hidden_size']
-        self.params['filter_size'] = self.cfg['model']['filter_size']
-        self.params['num_heads'] = self.cfg['model']['num_heads']
-        self.params['num_encoder_layers'] = self.cfg['model'][
-            'num_encoder_layers']
-        self.params['num_decoder_layers'] = self.cfg['model'][
-            'num_decoder_layers']
-        self.params['layer_preproc'] = self.cfg['model']['layer_preproc']
-        self.params['layer_postproc'] = self.cfg['model']['layer_postproc']
-        self.params['shared_embedding_and_softmax_weights'] = self.cfg[
-            'model']['shared_embedding_and_softmax_weights']
-        self.params['shared_source_target_embedding'] = self.cfg['model'][
-            'shared_source_target_embedding']
-        self.params['initializer_scale'] = self.cfg['model'][
-            'initializer_scale']
-        self.params['position_info_type'] = self.cfg['model'][
-            'position_info_type']
-        self.params['max_relative_dis'] = self.cfg['model']['max_relative_dis']
-        self.params['num_semantic_encoder_layers'] = self.cfg['model'][
-            'num_semantic_encoder_layers']
-        self.params['src_vocab_size'] = self.cfg['model']['src_vocab_size']
-        self.params['trg_vocab_size'] = self.cfg['model']['trg_vocab_size']
-        self.params['attention_dropout'] = 0.0
-        self.params['residual_dropout'] = 0.0
-        self.params['relu_dropout'] = 0.0
-
-        # dataset
-        self.params['vocab_src'] = self.cfg['dataset']['src_vocab']['file']
-        self.params['vocab_trg'] = self.cfg['dataset']['trg_vocab']['file']
-
-        # train
-        self.params['train_max_len'] = self.cfg['train']['train_max_len']
-        self.params['confidence'] = self.cfg['train']['confidence']
-
-        # evaluation
-        self.params['beam_size'] = self.cfg['evaluation']['beam_size']
-        self.params['lp_rate'] = self.cfg['evaluation']['lp_rate']
-        self.params['max_decoded_trg_len'] = self.cfg['evaluation'][
-            'max_decoded_trg_len']
-
     def preprocess(self, input: str) -> Dict[str, Any]:
         input_ids = np.array([[
             self._src_vocab[w]
-            if w in self._src_vocab else self.params['src_vocab_size']
+            if w in self._src_vocab else self.cfg['model']['src_vocab_size']
             for w in input.strip().split()
         ]])
         result = {'input_ids': input_ids}
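
With _override_params_from_file gone, preprocess reads the vocabulary size from the config directly; the out-of-vocabulary behavior is unchanged: any token missing from the source vocabulary maps to the index src_vocab_size, one past the last real id. A small self-contained illustration with a toy vocabulary:

    import numpy as np

    # Toy source vocabulary; the real pipeline enumerates the lines of the
    # vocab file to assign ids.
    src_vocab = {'声明': 0, '补充': 1, '说': 2}
    src_vocab_size = len(src_vocab)  # cfg['model']['src_vocab_size'] in the diff

    sentence = '声明 补充 说 未知词'
    input_ids = np.array([[
        src_vocab[w] if w in src_vocab else src_vocab_size
        for w in sentence.strip().split()
    ]])
    print(input_ids)  # [[0 1 2 3]] -- the unseen token falls back to id 3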


+1 -3  modelscope/trainers/nlp/csanmt_translation_trainer.py

@@ -47,7 +47,7 @@ class CsanmtTranslationTrainer(BaseTrainer):

         self.global_step = tf.train.create_global_step()

-        self.model = CsanmtForTranslation(self.model_path, params=self.params)
+        self.model = CsanmtForTranslation(self.model_path, **self.params)
         output = self.model(input=self.source_wids, label=self.target_wids)
         self.output.update(output)
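
This call-site change must land together with the constructor change in csanmt_for_translation.py above: now that the model stores kwargs wholesale, passing params=self.params would leave the configuration nested one level too deep. A toy demonstration of the difference:

    def capture_kwargs(model_path, *args, **kwargs):
        """Stands in for the new constructor body (self.params = kwargs)."""
        return kwargs

    params = {'hidden_size': 512}  # illustrative value
    assert capture_kwargs('ckpt', **params) == {'hidden_size': 512}     # new call
    assert capture_kwargs('ckpt', params=params) == {'params': params}  # old call nests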

@@ -319,6 +319,4 @@ def get_pretrained_variables_map(checkpoint_file_path, ignore_scope=None):
             if var_shape == saved_shapes[saved_var_name]:
                 restore_vars.append(curr_var)
                 restore_map[saved_var_name] = curr_var
-                tf.logging.info('Restore paramter %s from %s ...' %
-                                (saved_var_name, checkpoint_file_path))
     return restore_map
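
The two deleted lines only logged each matched variable; get_pretrained_variables_map still returns a {saved_name: current_variable} dict restricted to shape-compatible matches. For context, a sketch of how such a map is typically consumed with the TF1-style API this file uses (the variable name and shape are made up):

    import tensorflow as tf

    if tf.__version__ >= '2.0':
        tf = tf.compat.v1  # these files target the TF1 graph-mode API
        tf.disable_eager_execution()

    # Tiny graph demonstrating a filtered restore via a name -> variable
    # map, the structure get_pretrained_variables_map returns.
    w = tf.get_variable('encoder/w', shape=[2, 2])
    restore_map = {'encoder/w': w}  # as if matched against a checkpoint
    saver = tf.train.Saver(var_list=restore_map)
    # saver.restore(sess, checkpoint_file_path)  # restores only mapped variables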

+1 -1  tests/pipelines/test_csanmt_translation.py

@@ -10,7 +10,7 @@ class TranslationTest(unittest.TestCase):
     model_id = 'damo/nlp_csanmt_translation_zh2en'
     inputs = '声明 补充 说 , 沃伦 的 同事 都 深感 震惊 , 并且 希望 他 能够 投@@ 案@@ 自@@ 首 。'

-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_model_name(self):
         pipeline_ins = pipeline(task=Tasks.translation, model=self.model_id)
         print(pipeline_ins(input=self.inputs))
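
Lowering the skip threshold from test_level() >= 2 to >= 0 means this test now runs at every test level instead of only in the most exhaustive runs. The entry point it exercises is the standard pipeline call (the result is presumably keyed by OutputKeys.TRANSLATION, which the pipeline file imports):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    pipeline_ins = pipeline(
        task=Tasks.translation, model='damo/nlp_csanmt_translation_zh2en')
    print(pipeline_ins(input='声明 补充 说 , ...'))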

