You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

callback.py 2.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Mock the MindSpore mindspore/train/callback.py."""
  16. import os
  17. class RunContext:
  18. """Mock the RunContext class."""
  19. def __init__(self, original_args=None):
  20. self._original_args = original_args
  21. self._stop_requested = False
  22. def original_args(self):
  23. """Mock original_args."""
  24. return self._original_args
  25. def stop_requested(self):
  26. """Mock stop_requested method."""
  27. return self._stop_requested
  28. class Callback:
  29. """Mock the Callback class."""
  30. def __init__(self):
  31. pass
  32. def begin(self, run_context):
  33. """Called once before network training."""
  34. def epoch_begin(self, run_context):
  35. """Called before each epoch begin."""
  36. class _ListCallback(Callback):
  37. """Mock the _ListCallabck class."""
  38. def __init__(self, callbacks):
  39. super(_ListCallback, self).__init__()
  40. self._callbacks = callbacks
  41. class ModelCheckpoint(Callback):
  42. """Mock the ModelCheckpoint class."""
  43. def __init__(self, prefix='CKP', directory=None, config=None):
  44. super(ModelCheckpoint, self).__init__()
  45. self._prefix = prefix
  46. self._directory = directory
  47. self._config = config
  48. self._latest_ckpt_file_name = os.path.join(directory, prefix + 'test_model.ckpt')
  49. @property
  50. def model_file_name(self):
  51. """Get the file name of model."""
  52. return self._model_file_name
  53. @property
  54. def latest_ckpt_file_name(self):
  55. """Get the latest file name fo checkpoint."""
  56. return self._latest_ckpt_file_name
  57. class SummaryStep(Callback):
  58. """Mock the SummaryStep class."""
  59. def __init__(self, summary, flush_step=10):
  60. super(SummaryStep, self).__init__()
  61. self._sumamry = summary
  62. self._flush_step = flush_step
  63. self.summary_file_name = summary.full_file_name