You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_rkme.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import os
  2. import json
  3. import string
  4. import random
  5. import torch
  6. import unittest
  7. import tempfile
  8. import numpy as np
  9. from learnware.specification import RKMETableSpecification, RKMEImageSpecification, RKMETextSpecification
  10. from learnware.specification import generate_stat_spec
  class TestRKME(unittest.TestCase):
      """Save/load round-trip tests for the table, image and text RKME specifications."""

      def setUp(self):
          # Seed every RNG the test inputs are drawn from, so a failing case can be reproduced exactly.
          random.seed(0)
          np.random.seed(0)
          torch.manual_seed(0)

      def _assert_save_load_roundtrip(self, spec, spec_cls, type_name):
          """Save *spec* to a temporary JSON file and reload it into a fresh *spec_cls* instance.

          Asserts that the saved JSON's "type" field and the reloaded object's ``type``
          attribute both equal *type_name*.

          Args:
              spec: a specification object exposing ``save(path)``.
              spec_cls: the specification class used to reload the file via ``load(path)``.
              type_name: the expected specification type string.

          Returns:
              The reloaded specification. The temporary directory has already been removed
              when this returns, so callers should only use state that ``load`` kept in memory.
          """
          with tempfile.TemporaryDirectory(prefix="learnware_") as tempdir:
              rkme_path = os.path.join(tempdir, "rkme.json")
              spec.save(rkme_path)

              with open(rkme_path, "r") as f:
                  data = json.load(f)
              self.assertEqual(data["type"], type_name)

              reloaded = spec_cls()
              reloaded.load(rkme_path)
              self.assertEqual(reloaded.type, type_name)
          return reloaded

      @staticmethod
      def _generate_random_text_list(num, text_type="en", min_len=10, max_len=1000):
          """Return *num* random strings with lengths drawn uniformly from [min_len, max_len].

          "en" strings are built from ASCII letters, digits and punctuation; "zh" strings from
          code points U+4E00..U+9FFF (the CJK Unified Ideographs block).

          Raises:
              ValueError: if *text_type* is neither "en" nor "zh" (checked even when num == 0).
          """
          if text_type == "en":
              alphabet = string.ascii_letters + string.digits + string.punctuation

              def random_char():
                  return random.choice(alphabet)

          elif text_type == "zh":

              def random_char():
                  return chr(random.randint(0x4E00, 0x9FFF))

          else:
              raise ValueError("Type should be en or zh")

          return [
              "".join(random_char() for _ in range(random.randint(min_len, max_len)))
              for _ in range(num)
          ]

      def test_rkme(self):
          """Table RKME: build from uniform random matrices of several shapes and round-trip it."""
          for shape in ((5000, 200), (10000, 100), (5, 20), (1, 50), (100, 150)):
              with self.subTest(shape=shape):
                  X = np.random.uniform(-10000, 10000, size=shape)
                  rkme = generate_stat_spec(type="table", X=X)
                  self._assert_save_load_roundtrip(rkme, RKMETableSpecification, "RKMETableSpecification")

      def test_image_rkme(self):
          """Image RKME: build from numpy and torch batches (raw 0..255 and /255-scaled) and round-trip it."""
          # randint's upper bound is exclusive, so 256 is needed for the value 255 to ever be drawn.
          # Batches are produced by factories so only one of them is held in memory at a time.
          cases = (
              ("numpy (2000, 3, 32, 32)", lambda: np.random.randint(0, 256, size=(2000, 3, 32, 32))),
              ("numpy (100, 1, 128, 128)", lambda: np.random.randint(0, 256, size=(100, 1, 128, 128))),
              ("numpy (50, 3, 128, 128) / 255", lambda: np.random.randint(0, 256, size=(50, 3, 128, 128)) / 255),
              ("torch (2000, 3, 32, 32)", lambda: torch.randint(0, 256, (2000, 3, 32, 32))),
              ("torch (20, 3, 128, 128)", lambda: torch.randint(0, 256, (20, 3, 128, 128))),
              ("torch (1, 1, 128, 128) / 255", lambda: torch.randint(0, 256, (1, 1, 128, 128)) / 255),
          )
          for label, make_images in cases:
              with self.subTest(case=label):
                  image_rkme = generate_stat_spec(type="image", X=make_images(), steps=10)
                  self._assert_save_load_roundtrip(image_rkme, RKMEImageSpecification, "RKMEImageSpecification")

      def test_text_rkme(self):
          """Text RKME: build from random English/Chinese corpora of several sizes and round-trip it."""
          dims = []
          for num, text_type in ((3000, "en"), (100, "en"), (50, "zh"), (5000, "zh"), (1, "zh")):
              with self.subTest(num=num, text_type=text_type):
                  X = self._generate_random_text_list(num, text_type)
                  rkme = generate_stat_spec(type="text", X=X)
                  reloaded = self._assert_save_load_roundtrip(rkme, RKMETextSpecification, "RKMETextSpecification")
                  dims.append(reloaded.get_z().shape[1])
          # All reloaded text specs must share the same get_z() width, whatever the corpus size or language.
          self.assertEqual(len(set(dims)), 1, f"text specs produced differing get_z() widths: {dims}")
  if __name__ == "__main__":
      # Executing this file directly hands control to unittest's command-line runner,
      # which discovers and runs every TestCase defined in this module.
      unittest.main()