|
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import argparse
- import copy
- import tempfile
- import unittest
-
- import json
-
- from modelscope.utils.config import Config, check_config
-
# Reference content that the tests compare the example configuration
# fixtures against, and that the merge tests deep-copy as a starting point.
obj = {'a': 1, 'b': {'c': [1, 2, 3], 'd': 'dd'}}
-
-
class ConfigTest(unittest.TestCase):
    """Unit tests for ``Config`` loading, dumping, conversion and merging.

    The file-based tests read their fixtures through relative paths
    (``configs/...``), so they depend on the working directory containing
    that ``configs`` tree.
    """

    def _load_example(self, config_file):
        """Load ``config_file`` and check it against the module-level ``obj``.

        Args:
            config_file: Path handed to ``Config.from_file``.

        Returns:
            The loaded ``Config`` so callers can run further checks on it.
        """
        cfg = Config.from_file(config_file)
        self.assertEqual(cfg.a, 1)
        self.assertEqual(cfg.b, obj['b'])
        return cfg

    def test_json(self):
        """The JSON example fixture loads with the expected content."""
        self._load_example('configs/examples/configuration.json')

    def test_yaml(self):
        """The YAML example fixture loads with the expected content."""
        self._load_example('configs/examples/configuration.yaml')

    def test_py(self):
        """The Python example fixture loads with the expected content."""
        self._load_example('configs/examples/configuration.py')

    def test_dump(self):
        """Dump a config as text, JSON and YAML and compare each result."""
        # Function-scope import: ``os`` is only needed here to build paths.
        import os

        cfg = self._load_example('configs/examples/configuration.py')
        pretty_text = 'a = 1\n'
        pretty_text += "b = dict(c=[1, 2, 3], d='dd')\n"

        json_str = '{"a": 1, "b": {"c": [1, 2, 3], "d": "dd"}}'
        # NOTE(review): the nested entries below are indented by a single
        # space; confirm this matches what ``Config.dump`` writes for YAML.
        yaml_str = 'a: 1\nb:\n c:\n - 1\n - 2\n - 3\n d: dd\n'

        # Without a target file, ``dump`` returns the text representation.
        self.assertEqual(pretty_text, cfg.dump())

        # Dump into a temporary directory instead of re-opening a still-open
        # ``NamedTemporaryFile`` by name, which is not possible on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            json_path = os.path.join(tmp_dir, 'configuration.json')
            cfg.dump(json_path)
            with open(json_path, 'r') as infile:
                # Compare parsed objects so JSON formatting is irrelevant.
                self.assertDictEqual(
                    json.loads(json_str), json.loads(infile.read()))

            yaml_path = os.path.join(tmp_dir, 'configuration.yaml')
            cfg.dump(yaml_path)
            with open(yaml_path, 'r') as infile:
                self.assertEqual(yaml_str, infile.read())

    def test_to_dict(self):
        """``to_dict`` returns a ``dict`` instance."""
        cfg = Config.from_file('configs/examples/configuration.json')
        self.assertIsInstance(cfg.to_dict(), dict)

    def test_to_args(self):
        """``to_args`` feeds the config through a caller-supplied parser."""

        def parse_fn(args):
            """Parse ``args`` with an ``argparse`` parser of known options."""
            parser = argparse.ArgumentParser(prog='PROG')
            parser.add_argument('--model-dir', default='')
            parser.add_argument('--lr', type=float, default=0.001)
            parser.add_argument('--optimizer', default='')
            parser.add_argument('--weight-decay', type=float, default=1e-7)
            parser.add_argument(
                '--save-checkpoint-epochs', type=int, default=30)
            return parser.parse_args(args)

        cfg = Config.from_file('configs/examples/plain_args.yaml')
        args = cfg.to_args(parse_fn)

        # Every expected value differs from the parser default above, so the
        # assertions only pass if the values came from the config file.
        self.assertEqual(args.model_dir, 'path/to/model')
        self.assertAlmostEqual(args.lr, 0.01)
        self.assertAlmostEqual(args.weight_decay, 1e-6)
        self.assertEqual(args.optimizer, 'Adam')
        self.assertEqual(args.save_checkpoint_epochs, 20)

    def test_check_config(self):
        """``check_config`` runs on two fixtures; passes if nothing raises."""
        check_config('configs/cv/configuration.json')
        check_config('configs/nlp/sbert_sentence_similarity.json')

    def test_merge_from_dict(self):
        """Merge dotted keys and an index-keyed list entry into a config."""
        base_cfg = copy.deepcopy(obj)
        base_cfg.update({'dict_list': [dict(l1=1), dict(l2=2)]})

        cfg = Config(base_cfg)

        merge_dict = {
            'a': 2,
            'b.d': 'ee',
            'b.c': [3, 3, 3],
            'dict_list': {
                '0': dict(l1=3)
            },
            'c': 'test'
        }

        # Default call: existing values are replaced and 'c' is added.
        cfg1 = copy.deepcopy(cfg)
        cfg1.merge_from_dict(merge_dict)
        self.assertDictEqual(
            cfg1._cfg_dict, {
                'a': 2,
                'b': {
                    'c': [3, 3, 3],
                    'd': 'ee'
                },
                'dict_list': [dict(l1=3), dict(l2=2)],
                'c': 'test'
            })

        # force=False: existing values are kept, only the new 'c' is added.
        cfg2 = copy.deepcopy(cfg)
        cfg2.merge_from_dict(merge_dict, force=False)
        self.assertDictEqual(
            cfg2._cfg_dict, {
                'a': 1,
                'b': {
                    'c': [1, 2, 3],
                    'd': 'dd'
                },
                'dict_list': [dict(l1=1), dict(l2=2)],
                'c': 'test'
            })

    def test_merge_from_dict_with_list(self):
        """Merge lists of dicts, including a nested list, into a config.

        NOTE(review): judging by the expected results, list entries appear
        to be paired up through their ``type`` value; confirm against the
        ``merge_from_dict`` implementation.
        """
        base_cfg = {
            'a': 1,
            'b': {
                'c': [1, 2, 3],
                'd': 'dd'
            },
            'dict_list': [dict(type='l1', v=1),
                          dict(type='l2', v=2)],
            'dict_list2': [
                dict(
                    type='l1',
                    v=[dict(type='l1_1', v=1),
                       dict(type='l1_2', v=2)]),
                dict(type='l2', v=2)
            ]
        }
        cfg = Config(base_cfg)

        merge_dict_for_list = {
            'a': 2,
            'b.c': [3, 3, 3],
            'b.d': 'ee',
            'dict_list': [dict(type='l1', v=8),
                          dict(type='l3', v=8)],
            'dict_list2': [
                dict(
                    type='l1',
                    v=[
                        dict(type='l1_1', v=8),
                        dict(type='l1_2', v=2),
                        dict(type='l1_3', v=8),
                    ]),
                dict(type='l2', v=8)
            ],
            'c': 'test'
        }

        # force=False: existing values stay; unseen entries are appended.
        cfg1 = copy.deepcopy(cfg)
        cfg1.merge_from_dict(merge_dict_for_list, force=False)
        self.assertDictEqual(
            cfg1._cfg_dict, {
                'a': 1,
                'b': {
                    'c': [1, 2, 3],
                    'd': 'dd'
                },
                'dict_list': [
                    dict(type='l1', v=1),
                    dict(type='l2', v=2),
                    dict(type='l3', v=8)
                ],
                'dict_list2': [
                    dict(
                        type='l1',
                        v=[
                            dict(type='l1_1', v=1),
                            dict(type='l1_2', v=2),
                            dict(type='l1_3', v=8),
                        ]),
                    dict(type='l2', v=2)
                ],
                'c': 'test'
            })

        # force=True: entries present in both take the merged value, entries
        # absent from the merge dict (here 'l2' in dict_list) are untouched.
        cfg2 = copy.deepcopy(cfg)
        cfg2.merge_from_dict(merge_dict_for_list, force=True)
        self.assertDictEqual(
            cfg2._cfg_dict, {
                'a': 2,
                'b': {
                    'c': [3, 3, 3],
                    'd': 'ee'
                },
                'dict_list': [
                    dict(type='l1', v=8),
                    dict(type='l2', v=2),
                    dict(type='l3', v=8)
                ],
                'dict_list2': [
                    dict(
                        type='l1',
                        v=[
                            dict(type='l1_1', v=8),
                            dict(type='l1_2', v=2),
                            dict(type='l1_3', v=8),
                        ]),
                    dict(type='l2', v=8)
                ],
                'c': 'test'
            })
-
-
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|