You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_run_config.py 2.2 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test_run_config """
  16. import pytest
  17. from mindspore.train.callback import CheckpointConfig
  18. def test_init():
  19. """ test_init """
  20. save_checkpoint_steps = 1
  21. keep_checkpoint_max = 5
  22. config = CheckpointConfig(save_checkpoint_steps,
  23. keep_checkpoint_max)
  24. assert config.save_checkpoint_steps == save_checkpoint_steps
  25. assert config.keep_checkpoint_max == keep_checkpoint_max
  26. policy = config.get_checkpoint_policy()
  27. assert policy['keep_checkpoint_max'] == keep_checkpoint_max
  28. def test_arguments_values():
  29. """ test_arguments_values """
  30. config = CheckpointConfig()
  31. assert config.save_checkpoint_steps == 1
  32. assert config.save_checkpoint_seconds is None
  33. assert config.keep_checkpoint_max == 5
  34. assert config.keep_checkpoint_per_n_minutes is None
  35. with pytest.raises(TypeError):
  36. CheckpointConfig(save_checkpoint_steps='abc')
  37. with pytest.raises(TypeError):
  38. CheckpointConfig(save_checkpoint_seconds='abc')
  39. with pytest.raises(TypeError):
  40. CheckpointConfig(keep_checkpoint_max='abc')
  41. with pytest.raises(TypeError):
  42. CheckpointConfig(keep_checkpoint_per_n_minutes='abc')
  43. with pytest.raises(ValueError):
  44. CheckpointConfig(save_checkpoint_steps=-1)
  45. with pytest.raises(ValueError):
  46. CheckpointConfig(save_checkpoint_seconds=-1)
  47. with pytest.raises(ValueError):
  48. CheckpointConfig(keep_checkpoint_max=-1)
  49. with pytest.raises(ValueError):
  50. CheckpointConfig(keep_checkpoint_per_n_minutes=-1)