You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_config.py 3.0 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os.path as osp
import tempfile
import unittest
from pathlib import Path

from modelscope.fileio import dump, load
from modelscope.utils.config import Config
  9. obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
  10. class ConfigTest(unittest.TestCase):
  11. def test_json(self):
  12. config_file = 'configs/examples/configuration.json'
  13. cfg = Config.from_file(config_file)
  14. self.assertEqual(cfg.a, 1)
  15. self.assertEqual(cfg.b, obj['b'])
  16. def test_yaml(self):
  17. config_file = 'configs/examples/configuration.yaml'
  18. cfg = Config.from_file(config_file)
  19. self.assertEqual(cfg.a, 1)
  20. self.assertEqual(cfg.b, obj['b'])
  21. def test_py(self):
  22. config_file = 'configs/examples/configuration.py'
  23. cfg = Config.from_file(config_file)
  24. self.assertEqual(cfg.a, 1)
  25. self.assertEqual(cfg.b, obj['b'])
  26. def test_dump(self):
  27. config_file = 'configs/examples/configuration.py'
  28. cfg = Config.from_file(config_file)
  29. self.assertEqual(cfg.a, 1)
  30. self.assertEqual(cfg.b, obj['b'])
  31. pretty_text = 'a = 1\n'
  32. pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"
  33. json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
  34. yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'
  35. with tempfile.NamedTemporaryFile(suffix='.json') as ofile:
  36. self.assertEqual(pretty_text, cfg.dump())
  37. cfg.dump(ofile.name)
  38. with open(ofile.name, 'r') as infile:
  39. self.assertEqual(json_str, infile.read())
  40. with tempfile.NamedTemporaryFile(suffix='.yaml') as ofile:
  41. cfg.dump(ofile.name)
  42. with open(ofile.name, 'r') as infile:
  43. self.assertEqual(yaml_str, infile.read())
  44. def test_to_dict(self):
  45. config_file = 'configs/examples/configuration.json'
  46. cfg = Config.from_file(config_file)
  47. d = cfg.to_dict()
  48. print(d)
  49. self.assertTrue(isinstance(d, dict))
  50. def test_to_args(self):
  51. def parse_fn(args):
  52. parser = argparse.ArgumentParser(prog='PROG')
  53. parser.add_argument('--model-dir', default='')
  54. parser.add_argument('--lr', type=float, default=0.001)
  55. parser.add_argument('--optimizer', default='')
  56. parser.add_argument('--weight-decay', type=float, default=1e-7)
  57. parser.add_argument(
  58. '--save-checkpoint-epochs', type=int, default=30)
  59. return parser.parse_args(args)
  60. cfg = Config.from_file('configs/examples/plain_args.yaml')
  61. args = cfg.to_args(parse_fn)
  62. self.assertEqual(args.model_dir, 'path/to/model')
  63. self.assertAlmostEqual(args.lr, 0.01)
  64. self.assertAlmostEqual(args.weight_decay, 1e-6)
  65. self.assertEqual(args.optimizer, 'Adam')
  66. self.assertEqual(args.save_checkpoint_epochs, 20)
  67. if __name__ == '__main__':
  68. unittest.main()

致力于通过开放的社区合作,开源AI模型以及相关创新技术,推动基于模型即服务的生态繁荣发展