You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_fill_mask_ponet.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.metainfo import Pipelines
  4. from modelscope.pipelines import pipeline
  5. from modelscope.utils.constant import Tasks
  6. from modelscope.utils.test_utils import test_level
  7. class FillMaskPonetTest(unittest.TestCase):
  8. model_id_ponet = {
  9. 'zh': 'damo/nlp_ponet_fill-mask_chinese-base',
  10. 'en': 'damo/nlp_ponet_fill-mask_english-base'
  11. }
  12. ori_texts = {
  13. 'zh':
  14. '段誉轻挥折扇,摇了摇头,说道:“你师父是你的师父,你师父可不是我的师父。'
  15. '你师父差得动你,你师父可差不动我。',
  16. 'en':
  17. 'Everything in what you call reality is really just a reflection of your '
  18. 'consciousness. Your whole universe is just a mirror reflection of your story.'
  19. }
  20. test_inputs = {
  21. 'zh':
  22. '段誉轻[MASK]折扇,摇了摇[MASK],[MASK]道:“你师父是你的[MASK][MASK],你'
  23. '师父可不是[MASK]的师父。你师父差得动你,你师父可[MASK]不动我。',
  24. 'en':
  25. 'Everything in [MASK] you call reality is really [MASK] a reflection of your '
  26. '[MASK]. Your [MASK] universe is just a mirror [MASK] of your story.'
  27. }
  28. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  29. def test_run_with_ponet_model(self):
  30. for language in ['zh', 'en']:
  31. ori_text = self.ori_texts[language]
  32. test_input = self.test_inputs[language]
  33. pipeline_ins = pipeline(
  34. task=Tasks.fill_mask, model=self.model_id_ponet[language])
  35. print(f'\nori_text: {ori_text}\ninput: {test_input}\npipeline: '
  36. f'{pipeline_ins(test_input)}\n')
  37. if __name__ == '__main__':
  38. unittest.main()