You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_sentiment_classification.py 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from maas_hub.snapshot_download import snapshot_download
  4. from modelscope.models import Model
  5. from modelscope.models.nlp import SbertForSentimentClassification
  6. from modelscope.pipelines import SentimentClassificationPipeline, pipeline
  7. from modelscope.preprocessors import SentimentClassificationPreprocessor
  8. from modelscope.utils.constant import Tasks
  9. class SentimentClassificationTest(unittest.TestCase):
  10. model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base'
  11. sentence1 = '启动的时候很大声音,然后就会听到1.2秒的卡察的声音,类似齿轮摩擦的声音'
  12. def test_run_from_local(self):
  13. cache_path = snapshot_download(self.model_id)
  14. tokenizer = SentimentClassificationPreprocessor(cache_path)
  15. model = SbertForSentimentClassification(
  16. cache_path, tokenizer=tokenizer)
  17. pipeline1 = SentimentClassificationPipeline(
  18. model, preprocessor=tokenizer)
  19. pipeline2 = pipeline(
  20. Tasks.sentiment_classification,
  21. model=model,
  22. preprocessor=tokenizer)
  23. print(f'sentence1: {self.sentence1}\n'
  24. f'pipeline1:{pipeline1(input=self.sentence1)}')
  25. print()
  26. print(f'sentence1: {self.sentence1}\n'
  27. f'pipeline1: {pipeline2(input=self.sentence1)}')
  28. def test_run_with_model_from_modelhub(self):
  29. model = Model.from_pretrained(self.model_id)
  30. tokenizer = SentimentClassificationPreprocessor(model.model_dir)
  31. pipeline_ins = pipeline(
  32. task=Tasks.sentiment_classification,
  33. model=model,
  34. preprocessor=tokenizer)
  35. print(pipeline_ins(input=self.sentence1))
  36. def test_run_with_model_name(self):
  37. pipeline_ins = pipeline(
  38. task=Tasks.sentiment_classification, model=self.model_id)
  39. print(pipeline_ins(input=self.sentence1))
  40. def test_run_with_default_model(self):
  41. pipeline_ins = pipeline(task=Tasks.sentiment_classification)
  42. print(pipeline_ins(input=self.sentence1))
  43. if __name__ == '__main__':
  44. unittest.main()