You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_nlp.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.preprocessors import build_preprocessor, nlp
  4. from modelscope.utils.constant import Fields, InputFields
  5. from modelscope.utils.logger import get_logger
  # Module-level logger obtained from modelscope's ``get_logger`` helper.
  logger = get_logger()
  class NLPPreprocessorTest(unittest.TestCase):
      """Tests for NLP preprocessors created through ``build_preprocessor``.

      Each test builds a preprocessor from a config dict, feeds it the same
      sample sentence and compares the produced fields against hard-coded
      expected values.
      """

      # Sample sentence shared by every test in this class.
      SENTENCE = ('Do not meddle in the affairs of wizards, '
                  'for they are subtle and quick to anger.')

      def _build_token_cls_preprocessor(self, model_dir):
          """Return a ``token-cls-tokenizer`` preprocessor for ``model_dir``.

          Args:
              model_dir: Value passed as ``model_dir`` in the preprocessor
                  config (e.g. ``'bert-base-cased'``).

          Returns:
              The object returned by ``build_preprocessor`` for the NLP field.
          """
          cfg = dict(
              type='token-cls-tokenizer',
              model_dir=model_dir,
              label2id={
                  'O': 0,
                  'B': 1,
                  'I': 2
              })
          return build_preprocessor(cfg, Fields.nlp)

      def test_tokenize(self):
          """The ``Tokenize`` preprocessor yields the expected encoded fields."""
          cfg = dict(type='Tokenize', tokenizer_name='bert-base-cased')
          preprocessor = build_preprocessor(cfg, Fields.nlp)
          # Named ``sample`` rather than ``input`` to avoid shadowing the builtin.
          sample = {InputFields.text: self.SENTENCE}

          output = preprocessor(sample)

          # assertIn reports both the key and the container on failure.
          self.assertIn(InputFields.text, output)
          self.assertEqual(output['input_ids'], [
              101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678, 1116,
              117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470, 119, 102
          ])
          self.assertEqual(
              output['token_type_ids'],
              [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
          self.assertEqual(
              output['attention_mask'],
              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

      def test_token_classification_tokenize(self):
          """The token-classification preprocessor matches expected outputs.

          Runs one sub-test per tokenizer (``bert-base-cased`` and
          ``xlm-roberta-base``) so a failure in one does not hide the other.
          """
          with self.subTest(tokenizer_type='bert'):
              preprocessor = self._build_token_cls_preprocessor(
                  'bert-base-cased')
              output = preprocessor(self.SENTENCE)
              self.assertIn(InputFields.text, output)
              self.assertEqual(output['input_ids'].tolist()[0], [
                  101, 2091, 1136, 1143, 13002, 1107, 1103, 5707, 1104, 16678,
                  1116, 117, 1111, 1152, 1132, 11515, 1105, 3613, 1106, 4470,
                  119, 102
              ])
              self.assertEqual(output['attention_mask'].tolist()[0], [
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1
              ])
              self.assertEqual(output['label_mask'].tolist()[0], [
                  False, True, True, True, False, True, True, True, True, True,
                  False, True, True, True, True, True, True, True, True, True,
                  True, False
              ])
              self.assertEqual(output['offset_mapping'], [(0, 2), (3, 6),
                                                          (7, 13), (14, 16),
                                                          (17, 20), (21, 28),
                                                          (29, 31), (32, 39),
                                                          (39, 40), (41, 44),
                                                          (45, 49), (50, 53),
                                                          (54, 60), (61, 64),
                                                          (65, 70), (71, 73),
                                                          (74, 79), (79, 80)])

          with self.subTest(tokenizer_type='roberta'):
              preprocessor = self._build_token_cls_preprocessor(
                  'xlm-roberta-base')
              output = preprocessor(self.SENTENCE)
              self.assertIn(InputFields.text, output)
              self.assertEqual(output['input_ids'].tolist()[0], [
                  0, 984, 959, 128, 19298, 23, 70, 103086, 7, 111, 6, 44239,
                  99397, 4, 100, 1836, 621, 1614, 17991, 136, 63773, 47, 348, 56,
                  5, 2
              ])
              self.assertEqual(output['attention_mask'].tolist()[0], [
                  1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                  1, 1, 1, 1, 1
              ])
              self.assertEqual(output['label_mask'].tolist()[0], [
                  False, True, True, True, False, True, True, True, False, True,
                  True, False, False, False, True, True, True, True, False, True,
                  True, True, True, False, False, False
              ])
              self.assertEqual(output['offset_mapping'], [(0, 2), (3, 6),
                                                          (7, 13), (14, 16),
                                                          (17, 20), (21, 28),
                                                          (29, 31), (32, 40),
                                                          (41, 44), (45, 49),
                                                          (50, 53), (54, 60),
                                                          (61, 64), (65, 70),
                                                          (71, 73), (74, 80)])
  if __name__ == '__main__':
      # Run the test suite when this file is executed as a script.
      unittest.main()