You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_faq_question_answering.py 3.3 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. import numpy as np
  4. from modelscope.hub.api import HubApi
  5. from modelscope.hub.snapshot_download import snapshot_download
  6. from modelscope.models import Model
  7. from modelscope.models.nlp import SbertForFaqQuestionAnswering
  8. from modelscope.pipelines import pipeline
  9. from modelscope.pipelines.nlp import FaqQuestionAnsweringPipeline
  10. from modelscope.preprocessors import FaqQuestionAnsweringPreprocessor
  11. from modelscope.utils.constant import Tasks
  12. from modelscope.utils.test_utils import test_level
  13. class FaqQuestionAnsweringTest(unittest.TestCase):
  14. model_id = 'damo/nlp_structbert_faq-question-answering_chinese-base'
  15. param = {
  16. 'query_set': ['如何使用优惠券', '在哪里领券', '在哪里领券'],
  17. 'support_set': [{
  18. 'text': '卖品代金券怎么用',
  19. 'label': '6527856'
  20. }, {
  21. 'text': '怎么使用优惠券',
  22. 'label': '6527856'
  23. }, {
  24. 'text': '这个可以一起领吗',
  25. 'label': '1000012000'
  26. }, {
  27. 'text': '付款时送的优惠券哪里领',
  28. 'label': '1000012000'
  29. }, {
  30. 'text': '购物等级怎么长',
  31. 'label': '13421097'
  32. }, {
  33. 'text': '购物等级二心',
  34. 'label': '13421097'
  35. }]
  36. }
  37. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  38. def test_run_with_direct_file_download(self):
  39. cache_path = snapshot_download(self.model_id)
  40. preprocessor = FaqQuestionAnsweringPreprocessor(cache_path)
  41. model = SbertForFaqQuestionAnswering(cache_path)
  42. model.load_checkpoint(cache_path)
  43. pipeline_ins = FaqQuestionAnsweringPipeline(
  44. model, preprocessor=preprocessor)
  45. result = pipeline_ins(self.param)
  46. print(result)
  47. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  48. def test_run_with_model_from_modelhub(self):
  49. model = Model.from_pretrained(self.model_id)
  50. preprocessor = FaqQuestionAnsweringPreprocessor(model.model_dir)
  51. pipeline_ins = pipeline(
  52. task=Tasks.faq_question_answering,
  53. model=model,
  54. preprocessor=preprocessor)
  55. result = pipeline_ins(self.param)
  56. print(result)
  57. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  58. def test_run_with_model_name(self):
  59. pipeline_ins = pipeline(
  60. task=Tasks.faq_question_answering, model=self.model_id)
  61. result = pipeline_ins(self.param)
  62. print(result)
  63. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  64. def test_run_with_default_model(self):
  65. pipeline_ins = pipeline(task=Tasks.faq_question_answering)
  66. print(pipeline_ins(self.param, max_seq_length=20))
  67. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  68. def test_sentence_embedding(self):
  69. pipeline_ins = pipeline(task=Tasks.faq_question_answering)
  70. sentence_vec = pipeline_ins.get_sentence_embedding(
  71. ['今天星期六', '明天星期几明天星期几'])
  72. print(np.shape(sentence_vec))
  73. if __name__ == '__main__':
  74. unittest.main()