You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dialog_intent_prediction.py 2.1 kB

3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
3 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.hub.snapshot_download import snapshot_download
  4. from modelscope.models import Model
  5. from modelscope.models.nlp import DialogIntentModel
  6. from modelscope.pipelines import DialogIntentPredictionPipeline, pipeline
  7. from modelscope.preprocessors import DialogIntentPredictionPreprocessor
  8. from modelscope.utils.constant import Tasks
  9. class DialogIntentPredictionTest(unittest.TestCase):
  10. model_id = 'damo/nlp_space_dialog-intent-prediction'
  11. test_case = [
  12. 'How do I locate my card?',
  13. 'I still have not received my new card, I ordered over a week ago.'
  14. ]
  15. @unittest.skip('test with snapshot_download')
  16. def test_run(self):
  17. cache_path = snapshot_download(self.model_id)
  18. preprocessor = DialogIntentPredictionPreprocessor(model_dir=cache_path)
  19. model = DialogIntentModel(
  20. model_dir=cache_path,
  21. text_field=preprocessor.text_field,
  22. config=preprocessor.config)
  23. pipelines = [
  24. DialogIntentPredictionPipeline(
  25. model=model, preprocessor=preprocessor),
  26. pipeline(
  27. task=Tasks.dialog_intent_prediction,
  28. model=model,
  29. preprocessor=preprocessor)
  30. ]
  31. for my_pipeline, item in list(zip(pipelines, self.test_case)):
  32. print(my_pipeline(item))
  33. def test_run_with_model_from_modelhub(self):
  34. model = Model.from_pretrained(self.model_id)
  35. preprocessor = DialogIntentPredictionPreprocessor(
  36. model_dir=model.model_dir)
  37. pipelines = [
  38. DialogIntentPredictionPipeline(
  39. model=model, preprocessor=preprocessor),
  40. pipeline(
  41. task=Tasks.dialog_intent_prediction,
  42. model=model,
  43. preprocessor=preprocessor)
  44. ]
  45. for my_pipeline, item in list(zip(pipelines, self.test_case)):
  46. print(my_pipeline(item))
  47. if __name__ == '__main__':
  48. unittest.main()