You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_config.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import argparse
  3. import copy
  4. import tempfile
  5. import unittest
  6. import json
  7. from modelscope.utils.config import Config, check_config
  8. obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
  9. class ConfigTest(unittest.TestCase):
  10. def test_json(self):
  11. config_file = 'configs/examples/configuration.json'
  12. cfg = Config.from_file(config_file)
  13. self.assertEqual(cfg.a, 1)
  14. self.assertEqual(cfg.b, obj['b'])
  15. def test_yaml(self):
  16. config_file = 'configs/examples/configuration.yaml'
  17. cfg = Config.from_file(config_file)
  18. self.assertEqual(cfg.a, 1)
  19. self.assertEqual(cfg.b, obj['b'])
  20. def test_py(self):
  21. config_file = 'configs/examples/configuration.py'
  22. cfg = Config.from_file(config_file)
  23. self.assertEqual(cfg.a, 1)
  24. self.assertEqual(cfg.b, obj['b'])
  25. def test_dump(self):
  26. config_file = 'configs/examples/configuration.py'
  27. cfg = Config.from_file(config_file)
  28. self.assertEqual(cfg.a, 1)
  29. self.assertEqual(cfg.b, obj['b'])
  30. pretty_text = 'a = 1\n'
  31. pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
  32. json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
  33. yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'
  34. with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
  35. self.assertEqual(pretty_text, cfg.dump())
  36. cfg.dump(ofile.name)
  37. with open(ofile.name, 'r') as infile:
  38. self.assertDictEqual(
  39. json.loads(json_str), json.loads(infile.read()))
  40. with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
  41. cfg.dump(ofile.name)
  42. with open(ofile.name, 'r') as infile:
  43. self.assertEqual(yaml_str, infile.read())
  44. def test_to_dict(self):
  45. config_file = 'configs/examples/configuration.json'
  46. cfg = Config.from_file(config_file)
  47. d = cfg.to_dict()
  48. print(d)
  49. self.assertTrue(isinstance(d, dict))
  50. def test_to_args(self):
  51. def parse_fn(args):
  52. parser = argparse.ArgumentParser(prog='PROG')
  53. parser.add_argument('--model-dir', default='')
  54. parser.add_argument('--lr', type=float, default=0.001)
  55. parser.add_argument('--optimizer', default='')
  56. parser.add_argument('--weight-decay', type=float, default=1e-7)
  57. parser.add_argument(
  58. '--save-checkpoint-epochs', type=int, default=30)
  59. return parser.parse_args(args)
  60. cfg = Config.from_file('configs/examples/plain_args.yaml')
  61. args = cfg.to_args(parse_fn)
  62. self.assertEqual(args.model_dir, 'path/to/model')
  63. self.assertAlmostEqual(args.lr, 0.01)
  64. self.assertAlmostEqual(args.weight_decay, 1e-6)
  65. self.assertEqual(args.optimizer, 'Adam')
  66. self.assertEqual(args.save_checkpoint_epochs, 20)
  67. def test_check_config(self):
  68. check_config('configs/cv/configuration.json')
  69. check_config('configs/nlp/sbert_sentence_similarity.json')
  70. def test_merge_from_dict(self):
  71. base_cfg = copy.deepcopy(obj)
  72. base_cfg.update({'dict_list': [dict(l1=1), dict(l2=2)]})
  73. cfg = Config(base_cfg)
  74. merge_dict = {
  75. 'a': 2,
  76. 'b.d': 'ee',
  77. 'b.c': [3, 3, 3],
  78. 'dict_list': {
  79. '0': dict(l1=3)
  80. },
  81. 'c': 'test'
  82. }
  83. cfg1 = copy.deepcopy(cfg)
  84. cfg1.merge_from_dict(merge_dict)
  85. self.assertDictEqual(
  86. cfg1._cfg_dict, {
  87. 'a': 2,
  88. 'b': {
  89. 'c': [3, 3, 3],
  90. 'd': 'ee'
  91. },
  92. 'dict_list': [dict(l1=3), dict(l2=2)],
  93. 'c': 'test'
  94. })
  95. cfg2 = copy.deepcopy(cfg)
  96. cfg2.merge_from_dict(merge_dict, force=False)
  97. self.assertDictEqual(
  98. cfg2._cfg_dict, {
  99. 'a': 1,
  100. 'b': {
  101. 'c': [1, 2, 3],
  102. 'd': 'dd'
  103. },
  104. 'dict_list': [dict(l1=1), dict(l2=2)],
  105. 'c': 'test'
  106. })
  107. def test_merge_from_dict_with_list(self):
  108. base_cfg = {
  109. 'a':
  110. 1,
  111. 'b': {
  112. 'c': [1, 2, 3],
  113. 'd': 'dd'
  114. },
  115. 'dict_list': [dict(type='l1', v=1),
  116. dict(type='l2', v=2)],
  117. 'dict_list2': [
  118. dict(
  119. type='l1',
  120. v=[dict(type='l1_1', v=1),
  121. dict(type='l1_2', v=2)]),
  122. dict(type='l2', v=2)
  123. ]
  124. }
  125. cfg = Config(base_cfg)
  126. merge_dict_for_list = {
  127. 'a':
  128. 2,
  129. 'b.c': [3, 3, 3],
  130. 'b.d':
  131. 'ee',
  132. 'dict_list': [dict(type='l1', v=8),
  133. dict(type='l3', v=8)],
  134. 'dict_list2': [
  135. dict(
  136. type='l1',
  137. v=[
  138. dict(type='l1_1', v=8),
  139. dict(type='l1_2', v=2),
  140. dict(type='l1_3', v=8),
  141. ]),
  142. dict(type='l2', v=8)
  143. ],
  144. 'c':
  145. 'test'
  146. }
  147. cfg1 = copy.deepcopy(cfg)
  148. cfg1.merge_from_dict(merge_dict_for_list, force=False)
  149. self.assertDictEqual(
  150. cfg1._cfg_dict, {
  151. 'a':
  152. 1,
  153. 'b': {
  154. 'c': [1, 2, 3],
  155. 'd': 'dd'
  156. },
  157. 'dict_list': [
  158. dict(type='l1', v=1),
  159. dict(type='l2', v=2),
  160. dict(type='l3', v=8)
  161. ],
  162. 'dict_list2': [
  163. dict(
  164. type='l1',
  165. v=[
  166. dict(type='l1_1', v=1),
  167. dict(type='l1_2', v=2),
  168. dict(type='l1_3', v=8),
  169. ]),
  170. dict(type='l2', v=2)
  171. ],
  172. 'c':
  173. 'test'
  174. })
  175. cfg2 = copy.deepcopy(cfg)
  176. cfg2.merge_from_dict(merge_dict_for_list, force=True)
  177. self.assertDictEqual(
  178. cfg2._cfg_dict, {
  179. 'a':
  180. 2,
  181. 'b': {
  182. 'c': [3, 3, 3],
  183. 'd': 'ee'
  184. },
  185. 'dict_list': [
  186. dict(type='l1', v=8),
  187. dict(type='l2', v=2),
  188. dict(type='l3', v=8)
  189. ],
  190. 'dict_list2': [
  191. dict(
  192. type='l1',
  193. v=[
  194. dict(type='l1_1', v=8),
  195. dict(type='l1_2', v=2),
  196. dict(type='l1_3', v=8),
  197. ]),
  198. dict(type='l2', v=8)
  199. ],
  200. 'c':
  201. 'test'
  202. })
  203. if __name__ == '__main__':
  204. unittest.main()