zhangzhicheng.zzc yingda.chen 3 years ago
parent
commit
06abae4dc6
2 changed files with 9 additions and 2 deletions
  1. +1
    -2
      modelscope/preprocessors/nlp/token_classification_preprocessor.py
  2. +8
    -0
      tests/pipelines/test_named_entity_recognition.py

+ 1
- 2
modelscope/preprocessors/nlp/token_classification_preprocessor.py View File

@@ -140,8 +140,7 @@ class TokenClassificationPreprocessor(NLPTokenizerPreprocessorBase):
label_mask.append(1)
offset_mapping.append(encodings['offset_mapping'][i])
else:
encodings = self.tokenizer(
text, add_special_tokens=False, **self.tokenize_kwargs)
encodings = self.tokenizer(text, **self.tokenize_kwargs)
input_ids = encodings['input_ids']
label_mask, offset_mapping = self.get_label_mask_and_offset_mapping(
text)


+ 8
- 0
tests/pipelines/test_named_entity_recognition.py View File

@@ -19,9 +19,11 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
self.task = Tasks.named_entity_recognition
self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'

english_model_id = 'damo/nlp_raner_named-entity-recognition_english-large-ecom'
tcrf_model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
lcrf_model_id = 'damo/nlp_lstm_named-entity-recognition_chinese-news'
sentence = '这与温岭市新河镇的一个神秘的传说有关。'
sentence_en = 'pizza shovel'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_tcrf_by_direct_model_download(self):
@@ -89,6 +91,12 @@ class NamedEntityRecognitionTest(unittest.TestCase, DemoCompatibilityCheck):
task=Tasks.named_entity_recognition, model=self.lcrf_model_id)
print(pipeline_ins(input=self.sentence))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_english_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.named_entity_recognition, model=self.english_model_id)
print(pipeline_ins(input='pizza shovel'))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.named_entity_recognition)


Loading…
Cancel
Save