diff --git a/modelscope/pipelines/nlp/translation_pipeline.py b/modelscope/pipelines/nlp/translation_pipeline.py index b9b74ce4..e4893577 100644 --- a/modelscope/pipelines/nlp/translation_pipeline.py +++ b/modelscope/pipelines/nlp/translation_pipeline.py @@ -1,8 +1,11 @@ import os.path as osp from typing import Any, Dict +import jieba import numpy as np import tensorflow as tf +from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer +from subword_nmt import apply_bpe from modelscope.metainfo import Pipelines from modelscope.models.base import Model @@ -59,6 +62,21 @@ class TranslationPipeline(Pipeline): dtype=tf.int64, shape=[None, None], name='input_wids') self.output = {} + # preprocess + self._src_lang = self.cfg['preprocessor']['src_lang'] + self._tgt_lang = self.cfg['preprocessor']['tgt_lang'] + self._src_bpe_path = osp.join( + model, self.cfg['preprocessor']['src_bpe']['file']) + + if self._src_lang == 'zh': + self._tok = jieba + else: + self._punct_normalizer = MosesPunctNormalizer(lang=self._src_lang) + self._tok = MosesTokenizer(lang=self._src_lang) + self._detok = MosesDetokenizer(lang=self._tgt_lang) + + self._bpe = apply_bpe.BPE(open(self._src_bpe_path)) + # model output = self.model(self.input_wids) self.output.update(output) @@ -70,10 +88,19 @@ class TranslationPipeline(Pipeline): model_loader.restore(sess, model_path) def preprocess(self, input: str) -> Dict[str, Any]: + if self._src_lang == 'zh': + input_tok = self._tok.cut(input) + input_tok = ' '.join(list(input_tok)) + else: + input = self._punct_normalizer.normalize(input) + input_tok = self._tok.tokenize( + input, return_str=True, aggressive_dash_splits=True) + + input_bpe = self._bpe.process_line(input_tok) input_ids = np.array([[ self._src_vocab[w] if w in self._src_vocab else self.cfg['model']['src_vocab_size'] - for w in input.strip().split() + for w in input_bpe.strip().split() ]]) result = {'input_ids': input_ids} return result @@ -92,5 +119,6 @@ class TranslationPipeline(Pipeline): self._trg_rvocab[wid] if wid in self._trg_rvocab else '' for wid in wids ]).replace('@@ ', '').replace('@@', '') + translation_out = self._detok.detokenize(translation_out.split()) result = {OutputKeys.TRANSLATION: translation_out} return result diff --git a/modelscope/trainers/nlp/csanmt_translation_trainer.py b/modelscope/trainers/nlp/csanmt_translation_trainer.py index 067c1d83..62ae91a8 100644 --- a/modelscope/trainers/nlp/csanmt_translation_trainer.py +++ b/modelscope/trainers/nlp/csanmt_translation_trainer.py @@ -241,8 +241,10 @@ def input_fn(src_file, trg_dataset = tf.data.TextLineDataset(trg_file) src_trg_dataset = tf.data.Dataset.zip((src_dataset, trg_dataset)) src_trg_dataset = src_trg_dataset.map( - lambda src, trg: - (tf.string_split([src]).values, tf.string_split([trg]).values), + lambda src, trg: (tf.string_split([src]), tf.string_split([trg])), + num_parallel_calls=10).prefetch(1000000) + src_trg_dataset = src_trg_dataset.map( + lambda src, trg: (src.values, trg.values), num_parallel_calls=10).prefetch(1000000) src_trg_dataset = src_trg_dataset.map( lambda src, trg: (src_vocab.lookup(src), trg_vocab.lookup(trg)), diff --git a/requirements/nlp.txt b/requirements/nlp.txt index 6bd56aff..ada4fc50 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -1,11 +1,14 @@ en_core_web_sm>=2.3.5 fairseq>=0.10.2 +jieba>=0.42.1 pai-easynlp # rough-score was just recently updated from 0.0.4 to 0.0.7 # which introduced compatability issues that are being investigated rouge_score<=0.0.4 +sacremoses>=0.0.41 seqeval spacy>=2.3.5 +subword_nmt>=0.3.8 text2sql_lgesql tokenizers transformers>=4.12.0 diff --git a/tests/pipelines/test_csanmt_translation.py b/tests/pipelines/test_csanmt_translation.py index c852b1ff..bb6022ec 100644 --- a/tests/pipelines/test_csanmt_translation.py +++ b/tests/pipelines/test_csanmt_translation.py @@ -7,18 +7,26 @@ from modelscope.utils.test_utils import test_level class TranslationTest(unittest.TestCase): - model_id = 'damo/nlp_csanmt_translation_zh2en' - inputs = '声明 补充 说 , 沃伦 的 同事 都 深感 震惊 , 并且 希望 他 能够 投@@ 案@@ 自@@ 首 。' @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') - def test_run_with_model_name(self): - pipeline_ins = pipeline(task=Tasks.translation, model=self.model_id) - print(pipeline_ins(input=self.inputs)) + def test_run_with_model_name_for_zh2en(self): + model_id = 'damo/nlp_csanmt_translation_zh2en' + inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' + pipeline_ins = pipeline(task=Tasks.translation, model=model_id) + print(pipeline_ins(input=inputs)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_name_for_en2zh(self): + model_id = 'damo/nlp_csanmt_translation_en2zh' + inputs = 'Elon Musk, co-founder and chief executive officer of Tesla Motors.' + pipeline_ins = pipeline(task=Tasks.translation, model=model_id) + print(pipeline_ins(input=inputs)) @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') def test_run_with_default_model(self): + inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' pipeline_ins = pipeline(task=Tasks.translation) - print(pipeline_ins(input=self.inputs)) + print(pipeline_ins(input=inputs)) if __name__ == '__main__':