Browse Source

[to #42322933]support multi tasks-- will be failed, since configuration has not changed yet

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10492024
master
zhangzhicheng.zzc yingda.chen 2 years ago
parent
commit
2a87dee561
5 changed files with 14 additions and 5 deletions
  1. +2
    -0
      modelscope/models/nlp/heads/infromation_extraction_head.py
  2. +2
    -0
      modelscope/models/nlp/task_models/information_extraction.py
  3. +3
    -0
      modelscope/pipelines/builder.py
  4. +2
    -0
      modelscope/pipelines/nlp/information_extraction_pipeline.py
  5. +5
    -5
      tests/pipelines/test_relation_extraction.py

+ 2
- 0
modelscope/models/nlp/heads/infromation_extraction_head.py View File

@@ -10,6 +10,8 @@ from modelscope.utils.constant import Tasks

@HEADS.register_module(
Tasks.information_extraction, module_name=Heads.information_extraction)
@HEADS.register_module(
Tasks.relation_extraction, module_name=Heads.information_extraction)
class InformationExtractionHead(TorchHead):

def __init__(self, **kwargs):


+ 2
- 0
modelscope/models/nlp/task_models/information_extraction.py View File

@@ -16,6 +16,8 @@ __all__ = ['InformationExtractionModel']
@MODELS.register_module(
Tasks.information_extraction,
module_name=TaskModels.information_extraction)
@MODELS.register_module(
Tasks.relation_extraction, module_name=TaskModels.information_extraction)
class InformationExtractionModel(SingleBackboneTaskModelBase):

def __init__(self, model_dir: str, *args, **kwargs):


+ 3
- 0
modelscope/pipelines/builder.py View File

@@ -31,6 +31,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.named_entity_recognition:
(Pipelines.named_entity_recognition,
'damo/nlp_raner_named-entity-recognition_chinese-base-news'),
Tasks.relation_extraction:
(Pipelines.relation_extraction,
'damo/nlp_bert_relation-extraction_chinese-base'),
Tasks.information_extraction:
(Pipelines.relation_extraction,
'damo/nlp_bert_relation-extraction_chinese-base'),


+ 2
- 0
modelscope/pipelines/nlp/information_extraction_pipeline.py View File

@@ -17,6 +17,8 @@ __all__ = ['InformationExtractionPipeline']

@PIPELINES.register_module(
Tasks.information_extraction, module_name=Pipelines.relation_extraction)
@PIPELINES.register_module(
Tasks.relation_extraction, module_name=Pipelines.relation_extraction)
class InformationExtractionPipeline(Pipeline):

def __init__(self,


+ 5
- 5
tests/pipelines/test_relation_extraction.py View File

@@ -15,7 +15,7 @@ from modelscope.utils.test_utils import test_level
class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = Tasks.information_extraction
self.task = Tasks.relation_extraction
self.model_id = 'damo/nlp_bert_relation-extraction_chinese-base'

sentence = '高捷,祖籍江苏,本科毕业于东南大学'
@@ -28,7 +28,7 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck):
pipeline1 = InformationExtractionPipeline(
model, preprocessor=tokenizer)
pipeline2 = pipeline(
Tasks.information_extraction, model=model, preprocessor=tokenizer)
Tasks.relation_extraction, model=model, preprocessor=tokenizer)
print(f'sentence: {self.sentence}\n'
f'pipeline1:{pipeline1(input=self.sentence)}')
print()
@@ -39,7 +39,7 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck):
model = Model.from_pretrained(self.model_id)
tokenizer = RelationExtractionPreprocessor(model.model_dir)
pipeline_ins = pipeline(
task=Tasks.information_extraction,
task=Tasks.relation_extraction,
model=model,
preprocessor=tokenizer)
print(pipeline_ins(input=self.sentence))
@@ -47,12 +47,12 @@ class RelationExtractionTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_name(self):
pipeline_ins = pipeline(
task=Tasks.information_extraction, model=self.model_id)
task=Tasks.relation_extraction, model=self.model_id)
print(pipeline_ins(input=self.sentence))

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.information_extraction)
pipeline_ins = pipeline(task=Tasks.relation_extraction)
print(pipeline_ins(input=self.sentence))

@unittest.skip('demo compatibility test is only enabled on a needed-basis')


Loading…
Cancel
Save