You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_text_rkme.py 2.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import os
  2. import json
  3. import string
  4. import random
  5. import unittest
  6. import tempfile
  7. from learnware.specification import RKMETextSpecification
  8. from learnware.specification import generate_stat_spec
  9. class TestTextRKME(unittest.TestCase):
  10. @staticmethod
  11. def generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
  12. text_list = []
  13. for i in range(num):
  14. length = random.randint(min_len, max_len)
  15. if text_type == "en":
  16. characters = string.ascii_letters + string.digits + string.punctuation
  17. result_str = "".join(random.choice(characters) for i in range(length))
  18. text_list.append(result_str)
  19. elif text_type == "zh":
  20. result_str = "".join(chr(random.randint(0x4E00, 0x9FFF)) for i in range(length))
  21. text_list.append(result_str)
  22. else:
  23. raise ValueError("Type should be en or zh")
  24. return text_list
  25. @staticmethod
  26. def _test_text_rkme(X):
  27. rkme = generate_stat_spec(type="text", X=X)
  28. with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
  29. rkme_path = os.path.join(tempdir, "rkme.json")
  30. rkme.save(rkme_path)
  31. with open(rkme_path, "r") as f:
  32. data = json.load(f)
  33. assert data["type"] == "RKMETextSpecification"
  34. rkme2 = RKMETextSpecification()
  35. rkme2.load(rkme_path)
  36. assert rkme2.type == "RKMETextSpecification"
  37. return rkme2.get_z().shape[1]
  38. def test_text_rkme(self):
  39. dim1 = self._test_text_rkme(self.generate_random_text_list(3000, "en"))
  40. dim2 = self._test_text_rkme(self.generate_random_text_list(100, "en"))
  41. dim3 = self._test_text_rkme(self.generate_random_text_list(50, "zh"))
  42. dim4 = self._test_text_rkme(self.generate_random_text_list(5000, "zh"))
  43. dim5 = self._test_text_rkme(self.generate_random_text_list(1, "zh"))
  44. assert dim1 == dim2 and dim2 == dim3 and dim3 == dim4 and dim4 == dim5
  45. if __name__ == "__main__":
  46. unittest.main()