* nlp translation preprocess branch * pull the latest master Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9920445 master
@@ -1,8 +1,11 @@ | |||||
import os.path as osp | import os.path as osp | ||||
from typing import Any, Dict | from typing import Any, Dict | ||||
import jieba | |||||
import numpy as np | import numpy as np | ||||
import tensorflow as tf | import tensorflow as tf | ||||
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer | |||||
from subword_nmt import apply_bpe | |||||
from modelscope.metainfo import Pipelines | from modelscope.metainfo import Pipelines | ||||
from modelscope.models.base import Model | from modelscope.models.base import Model | ||||
@@ -59,6 +62,21 @@ class TranslationPipeline(Pipeline): | |||||
dtype=tf.int64, shape=[None, None], name='input_wids') | dtype=tf.int64, shape=[None, None], name='input_wids') | ||||
self.output = {} | self.output = {} | ||||
# preprocess | |||||
self._src_lang = self.cfg['preprocessor']['src_lang'] | |||||
self._tgt_lang = self.cfg['preprocessor']['tgt_lang'] | |||||
self._src_bpe_path = osp.join( | |||||
model, self.cfg['preprocessor']['src_bpe']['file']) | |||||
if self._src_lang == 'zh': | |||||
self._tok = jieba | |||||
else: | |||||
self._punct_normalizer = MosesPunctNormalizer(lang=self._src_lang) | |||||
self._tok = MosesTokenizer(lang=self._src_lang) | |||||
self._detok = MosesDetokenizer(lang=self._tgt_lang) | |||||
self._bpe = apply_bpe.BPE(open(self._src_bpe_path)) | |||||
# model | # model | ||||
output = self.model(self.input_wids) | output = self.model(self.input_wids) | ||||
self.output.update(output) | self.output.update(output) | ||||
@@ -70,10 +88,19 @@ class TranslationPipeline(Pipeline): | |||||
model_loader.restore(sess, model_path) | model_loader.restore(sess, model_path) | ||||
def preprocess(self, input: str) -> Dict[str, Any]: | def preprocess(self, input: str) -> Dict[str, Any]: | ||||
if self._src_lang == 'zh': | |||||
input_tok = self._tok.cut(input) | |||||
input_tok = ' '.join(list(input_tok)) | |||||
else: | |||||
input = self._punct_normalizer.normalize(input) | |||||
input_tok = self._tok.tokenize( | |||||
input, return_str=True, aggressive_dash_splits=True) | |||||
input_bpe = self._bpe.process_line(input_tok) | |||||
input_ids = np.array([[ | input_ids = np.array([[ | ||||
self._src_vocab[w] | self._src_vocab[w] | ||||
if w in self._src_vocab else self.cfg['model']['src_vocab_size'] | if w in self._src_vocab else self.cfg['model']['src_vocab_size'] | ||||
for w in input.strip().split() | |||||
for w in input_bpe.strip().split() | |||||
]]) | ]]) | ||||
result = {'input_ids': input_ids} | result = {'input_ids': input_ids} | ||||
return result | return result | ||||
@@ -92,5 +119,6 @@ class TranslationPipeline(Pipeline): | |||||
self._trg_rvocab[wid] if wid in self._trg_rvocab else '<unk>' | self._trg_rvocab[wid] if wid in self._trg_rvocab else '<unk>' | ||||
for wid in wids | for wid in wids | ||||
]).replace('@@ ', '').replace('@@', '') | ]).replace('@@ ', '').replace('@@', '') | ||||
translation_out = self._detok.detokenize(translation_out.split()) | |||||
result = {OutputKeys.TRANSLATION: translation_out} | result = {OutputKeys.TRANSLATION: translation_out} | ||||
return result | return result |
@@ -241,8 +241,10 @@ def input_fn(src_file, | |||||
trg_dataset = tf.data.TextLineDataset(trg_file) | trg_dataset = tf.data.TextLineDataset(trg_file) | ||||
src_trg_dataset = tf.data.Dataset.zip((src_dataset, trg_dataset)) | src_trg_dataset = tf.data.Dataset.zip((src_dataset, trg_dataset)) | ||||
src_trg_dataset = src_trg_dataset.map( | src_trg_dataset = src_trg_dataset.map( | ||||
lambda src, trg: | |||||
(tf.string_split([src]).values, tf.string_split([trg]).values), | |||||
lambda src, trg: (tf.string_split([src]), tf.string_split([trg])), | |||||
num_parallel_calls=10).prefetch(1000000) | |||||
src_trg_dataset = src_trg_dataset.map( | |||||
lambda src, trg: (src.values, trg.values), | |||||
num_parallel_calls=10).prefetch(1000000) | num_parallel_calls=10).prefetch(1000000) | ||||
src_trg_dataset = src_trg_dataset.map( | src_trg_dataset = src_trg_dataset.map( | ||||
lambda src, trg: (src_vocab.lookup(src), trg_vocab.lookup(trg)), | lambda src, trg: (src_vocab.lookup(src), trg_vocab.lookup(trg)), | ||||
@@ -1,11 +1,14 @@ | |||||
en_core_web_sm>=2.3.5 | en_core_web_sm>=2.3.5 | ||||
fairseq>=0.10.2 | fairseq>=0.10.2 | ||||
jieba>=0.42.1 | |||||
pai-easynlp | pai-easynlp | ||||
# rouge_score was just recently updated from 0.0.4 to 0.0.7 | # rouge_score was just recently updated from 0.0.4 to 0.0.7 | ||||
# which introduced compatibility issues that are being investigated | # which introduced compatibility issues that are being investigated | ||||
rouge_score<=0.0.4 | rouge_score<=0.0.4 | ||||
sacremoses>=0.0.41 | |||||
seqeval | seqeval | ||||
spacy>=2.3.5 | spacy>=2.3.5 | ||||
subword_nmt>=0.3.8 | |||||
text2sql_lgesql | text2sql_lgesql | ||||
tokenizers | tokenizers | ||||
transformers>=4.12.0 | transformers>=4.12.0 |
@@ -7,18 +7,26 @@ from modelscope.utils.test_utils import test_level | |||||
class TranslationTest(unittest.TestCase): | class TranslationTest(unittest.TestCase): | ||||
model_id = 'damo/nlp_csanmt_translation_zh2en' | |||||
inputs = '声明 补充 说 , 沃伦 的 同事 都 深感 震惊 , 并且 希望 他 能够 投@@ 案@@ 自@@ 首 。' | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
def test_run_with_model_name(self): | |||||
pipeline_ins = pipeline(task=Tasks.translation, model=self.model_id) | |||||
print(pipeline_ins(input=self.inputs)) | |||||
def test_run_with_model_name_for_zh2en(self): | |||||
model_id = 'damo/nlp_csanmt_translation_zh2en' | |||||
inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' | |||||
pipeline_ins = pipeline(task=Tasks.translation, model=model_id) | |||||
print(pipeline_ins(input=inputs)) | |||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
def test_run_with_model_name_for_en2zh(self): | |||||
model_id = 'damo/nlp_csanmt_translation_en2zh' | |||||
inputs = 'Elon Musk, co-founder and chief executive officer of Tesla Motors.' | |||||
pipeline_ins = pipeline(task=Tasks.translation, model=model_id) | |||||
print(pipeline_ins(input=inputs)) | |||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | ||||
def test_run_with_default_model(self): | def test_run_with_default_model(self): | ||||
inputs = '声明补充说,沃伦的同事都深感震惊,并且希望他能够投案自首。' | |||||
pipeline_ins = pipeline(task=Tasks.translation) | pipeline_ins = pipeline(task=Tasks.translation) | ||||
print(pipeline_ins(input=self.inputs)) | |||||
print(pipeline_ins(input=inputs)) | |||||
if __name__ == '__main__': | if __name__ == '__main__': | ||||