|
|
@@ -1,8 +1,11 @@ |
|
|
|
import os.path as osp |
|
|
|
from typing import Any, Dict |
|
|
|
|
|
|
|
import jieba |
|
|
|
import numpy as np |
|
|
|
import tensorflow as tf |
|
|
|
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer |
|
|
|
from subword_nmt import apply_bpe |
|
|
|
|
|
|
|
from modelscope.metainfo import Pipelines |
|
|
|
from modelscope.models.base import Model |
|
|
@@ -59,6 +62,21 @@ class TranslationPipeline(Pipeline): |
|
|
|
dtype=tf.int64, shape=[None, None], name='input_wids') |
|
|
|
self.output = {} |
|
|
|
|
|
|
|
# preprocess |
|
|
|
self._src_lang = self.cfg['preprocessor']['src_lang'] |
|
|
|
self._tgt_lang = self.cfg['preprocessor']['tgt_lang'] |
|
|
|
self._src_bpe_path = osp.join( |
|
|
|
model, self.cfg['preprocessor']['src_bpe']['file']) |
|
|
|
|
|
|
|
if self._src_lang == 'zh': |
|
|
|
self._tok = jieba |
|
|
|
else: |
|
|
|
self._punct_normalizer = MosesPunctNormalizer(lang=self._src_lang) |
|
|
|
self._tok = MosesTokenizer(lang=self._src_lang) |
|
|
|
self._detok = MosesDetokenizer(lang=self._tgt_lang) |
|
|
|
|
|
|
|
self._bpe = apply_bpe.BPE(open(self._src_bpe_path)) |
|
|
|
|
|
|
|
# model |
|
|
|
output = self.model(self.input_wids) |
|
|
|
self.output.update(output) |
|
|
@@ -70,10 +88,19 @@ class TranslationPipeline(Pipeline): |
|
|
|
model_loader.restore(sess, model_path) |
|
|
|
|
|
|
|
def preprocess(self, input: str) -> Dict[str, Any]: |
|
|
|
if self._src_lang == 'zh': |
|
|
|
input_tok = self._tok.cut(input) |
|
|
|
input_tok = ' '.join(list(input_tok)) |
|
|
|
else: |
|
|
|
input = self._punct_normalizer.normalize(input) |
|
|
|
input_tok = self._tok.tokenize( |
|
|
|
input, return_str=True, aggressive_dash_splits=True) |
|
|
|
|
|
|
|
input_bpe = self._bpe.process_line(input_tok) |
|
|
|
input_ids = np.array([[ |
|
|
|
self._src_vocab[w] |
|
|
|
if w in self._src_vocab else self.cfg['model']['src_vocab_size'] |
|
|
|
for w in input.strip().split() |
|
|
|
for w in input_bpe.strip().split() |
|
|
|
]]) |
|
|
|
result = {'input_ids': input_ids} |
|
|
|
return result |
|
|
@@ -92,5 +119,6 @@ class TranslationPipeline(Pipeline): |
|
|
|
self._trg_rvocab[wid] if wid in self._trg_rvocab else '<unk>' |
|
|
|
for wid in wids |
|
|
|
]).replace('@@ ', '').replace('@@', '') |
|
|
|
translation_out = self._detok.detokenize(translation_out.split()) |
|
|
|
result = {OutputKeys.TRANSLATION: translation_out} |
|
|
|
return result |