You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_config.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import copy
import os
import tempfile
import unittest

from modelscope.utils.config import Config, check_config
  7. obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
  8. class ConfigTest(unittest.TestCase):
  9. def test_json(self):
  10. config_file = 'configs/examples/configuration.json'
  11. cfg = Config.from_file(config_file)
  12. self.assertEqual(cfg.a, 1)
  13. self.assertEqual(cfg.b, obj['b'])
  14. def test_yaml(self):
  15. config_file = 'configs/examples/configuration.yaml'
  16. cfg = Config.from_file(config_file)
  17. self.assertEqual(cfg.a, 1)
  18. self.assertEqual(cfg.b, obj['b'])
  19. def test_py(self):
  20. config_file = 'configs/examples/configuration.py'
  21. cfg = Config.from_file(config_file)
  22. self.assertEqual(cfg.a, 1)
  23. self.assertEqual(cfg.b, obj['b'])
  24. def test_dump(self):
  25. config_file = 'configs/examples/configuration.py'
  26. cfg = Config.from_file(config_file)
  27. self.assertEqual(cfg.a, 1)
  28. self.assertEqual(cfg.b, obj['b'])
  29. pretty_text = 'a = 1\n'
  30. pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
  31. json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
  32. yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'
  33. with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
  34. self.assertEqual(pretty_text, cfg.dump())
  35. cfg.dump(ofile.name)
  36. with open(ofile.name, 'r') as infile:
  37. self.assertEqual(json_str, infile.read())
  38. with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
  39. cfg.dump(ofile.name)
  40. with open(ofile.name, 'r') as infile:
  41. self.assertEqual(yaml_str, infile.read())
  42. def test_to_dict(self):
  43. config_file = 'configs/examples/configuration.json'
  44. cfg = Config.from_file(config_file)
  45. d = cfg.to_dict()
  46. print(d)
  47. self.assertTrue(isinstance(d, dict))
  48. def test_to_args(self):
  49. def parse_fn(args):
  50. parser = argparse.ArgumentParser(prog='PROG')
  51. parser.add_argument('--model-dir', default='')
  52. parser.add_argument('--lr', type=float, default=0.001)
  53. parser.add_argument('--optimizer', default='')
  54. parser.add_argument('--weight-decay', type=float, default=1e-7)
  55. parser.add_argument(
  56. '--save-checkpoint-epochs', type=int, default=30)
  57. return parser.parse_args(args)
  58. cfg = Config.from_file('configs/examples/plain_args.yaml')
  59. args = cfg.to_args(parse_fn)
  60. self.assertEqual(args.model_dir, 'path/to/model')
  61. self.assertAlmostEqual(args.lr, 0.01)
  62. self.assertAlmostEqual(args.weight_decay, 1e-6)
  63. self.assertEqual(args.optimizer, 'Adam')
  64. self.assertEqual(args.save_checkpoint_epochs, 20)
  65. def test_check_config(self):
  66. check_config('configs/cv/configuration.json')
  67. check_config('configs/nlp/sbert_sentence_similarity.json')
  68. def test_merge_from_dict(self):
  69. base_cfg = copy.deepcopy(obj)
  70. base_cfg.update({'dict_list': [dict(l1=1), dict(l2=2)]})
  71. cfg = Config(base_cfg)
  72. merge_dict = {
  73. 'a': 2,
  74. 'b.d': 'ee',
  75. 'b.c': [3, 3, 3],
  76. 'dict_list': {
  77. '0': dict(l1=3)
  78. },
  79. 'c': 'test'
  80. }
  81. cfg1 = copy.deepcopy(cfg)
  82. cfg1.merge_from_dict(merge_dict)
  83. self.assertDictEqual(
  84. cfg1._cfg_dict, {
  85. 'a': 2,
  86. 'b': {
  87. 'c': [3, 3, 3],
  88. 'd': 'ee'
  89. },
  90. 'dict_list': [dict(l1=3), dict(l2=2)],
  91. 'c': 'test'
  92. })
  93. cfg2 = copy.deepcopy(cfg)
  94. cfg2.merge_from_dict(merge_dict, force=False)
  95. self.assertDictEqual(
  96. cfg2._cfg_dict, {
  97. 'a': 1,
  98. 'b': {
  99. 'c': [1, 2, 3],
  100. 'd': 'dd'
  101. },
  102. 'dict_list': [dict(l1=1), dict(l2=2)],
  103. 'c': 'test'
  104. })
  105. def test_merge_from_dict_with_list(self):
  106. base_cfg = {
  107. 'a':
  108. 1,
  109. 'b': {
  110. 'c': [1, 2, 3],
  111. 'd': 'dd'
  112. },
  113. 'dict_list': [dict(type='l1', v=1),
  114. dict(type='l2', v=2)],
  115. 'dict_list2': [
  116. dict(
  117. type='l1',
  118. v=[dict(type='l1_1', v=1),
  119. dict(type='l1_2', v=2)]),
  120. dict(type='l2', v=2)
  121. ]
  122. }
  123. cfg = Config(base_cfg)
  124. merge_dict_for_list = {
  125. 'a':
  126. 2,
  127. 'b.c': [3, 3, 3],
  128. 'b.d':
  129. 'ee',
  130. 'dict_list': [dict(type='l1', v=8),
  131. dict(type='l3', v=8)],
  132. 'dict_list2': [
  133. dict(
  134. type='l1',
  135. v=[
  136. dict(type='l1_1', v=8),
  137. dict(type='l1_2', v=2),
  138. dict(type='l1_3', v=8),
  139. ]),
  140. dict(type='l2', v=8)
  141. ],
  142. 'c':
  143. 'test'
  144. }
  145. cfg1 = copy.deepcopy(cfg)
  146. cfg1.merge_from_dict(merge_dict_for_list, force=False)
  147. self.assertDictEqual(
  148. cfg1._cfg_dict, {
  149. 'a':
  150. 1,
  151. 'b': {
  152. 'c': [1, 2, 3],
  153. 'd': 'dd'
  154. },
  155. 'dict_list': [
  156. dict(type='l1', v=1),
  157. dict(type='l2', v=2),
  158. dict(type='l3', v=8)
  159. ],
  160. 'dict_list2': [
  161. dict(
  162. type='l1',
  163. v=[
  164. dict(type='l1_1', v=1),
  165. dict(type='l1_2', v=2),
  166. dict(type='l1_3', v=8),
  167. ]),
  168. dict(type='l2', v=2)
  169. ],
  170. 'c':
  171. 'test'
  172. })
  173. cfg2 = copy.deepcopy(cfg)
  174. cfg2.merge_from_dict(merge_dict_for_list, force=True)
  175. self.assertDictEqual(
  176. cfg2._cfg_dict, {
  177. 'a':
  178. 2,
  179. 'b': {
  180. 'c': [3, 3, 3],
  181. 'd': 'ee'
  182. },
  183. 'dict_list': [
  184. dict(type='l1', v=8),
  185. dict(type='l2', v=2),
  186. dict(type='l3', v=8)
  187. ],
  188. 'dict_list2': [
  189. dict(
  190. type='l1',
  191. v=[
  192. dict(type='l1_1', v=8),
  193. dict(type='l1_2', v=2),
  194. dict(type='l1_3', v=8),
  195. ]),
  196. dict(type='l2', v=8)
  197. ],
  198. 'c':
  199. 'test'
  200. })
  201. if __name__ == '__main__':
  202. unittest.main()