You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_config.py 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import tempfile
import unittest

from modelscope.utils.config import Config
  6. obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
  7. class ConfigTest(unittest.TestCase):
  8. def test_json(self):
  9. config_file = 'configs/examples/configuration.json'
  10. cfg = Config.from_file(config_file)
  11. self.assertEqual(cfg.a, 1)
  12. self.assertEqual(cfg.b, obj['b'])
  13. def test_yaml(self):
  14. config_file = 'configs/examples/configuration.yaml'
  15. cfg = Config.from_file(config_file)
  16. self.assertEqual(cfg.a, 1)
  17. self.assertEqual(cfg.b, obj['b'])
  18. def test_py(self):
  19. config_file = 'configs/examples/configuration.py'
  20. cfg = Config.from_file(config_file)
  21. self.assertEqual(cfg.a, 1)
  22. self.assertEqual(cfg.b, obj['b'])
  23. def test_dump(self):
  24. config_file = 'configs/examples/configuration.py'
  25. cfg = Config.from_file(config_file)
  26. self.assertEqual(cfg.a, 1)
  27. self.assertEqual(cfg.b, obj['b'])
  28. pretty_text = 'a = 1\n'
  29. pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
  30. json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
  31. yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'
  32. with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
  33. self.assertEqual(pretty_text, cfg.dump())
  34. cfg.dump(ofile.name)
  35. with open(ofile.name, 'r') as infile:
  36. self.assertEqual(json_str, infile.read())
  37. with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
  38. cfg.dump(ofile.name)
  39. with open(ofile.name, 'r') as infile:
  40. self.assertEqual(yaml_str, infile.read())
  41. def test_to_dict(self):
  42. config_file = 'configs/examples/configuration.json'
  43. cfg = Config.from_file(config_file)
  44. d = cfg.to_dict()
  45. print(d)
  46. self.assertTrue(isinstance(d, dict))
  47. def test_to_args(self):
  48. def parse_fn(args):
  49. parser = argparse.ArgumentParser(prog='PROG')
  50. parser.add_argument('--model-dir', default='')
  51. parser.add_argument('--lr', type=float, default=0.001)
  52. parser.add_argument('--optimizer', default='')
  53. parser.add_argument('--weight-decay', type=float, default=1e-7)
  54. parser.add_argument(
  55. '--save-checkpoint-epochs', type=int, default=30)
  56. return parser.parse_args(args)
  57. cfg = Config.from_file('configs/examples/plain_args.yaml')
  58. args = cfg.to_args(parse_fn)
  59. self.assertEqual(args.model_dir, 'path/to/model')
  60. self.assertAlmostEqual(args.lr, 0.01)
  61. self.assertAlmostEqual(args.weight_decay, 1e-6)
  62. self.assertEqual(args.optimizer, 'Adam')
  63. self.assertEqual(args.save_checkpoint_epochs, 20)
  64. if __name__ == '__main__':
  65. unittest.main()

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展