|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Mock the MindSpore mindspore/train/callback.py."""
- import os
-
-
- class RunContext:
- """Mock the RunContext class."""
-
- def __init__(self, original_args=None):
- self._original_args = original_args
- self._stop_requested = False
-
- def original_args(self):
- """Mock original_args."""
- return self._original_args
-
- def stop_requested(self):
- """Mock stop_requested method."""
- return self._stop_requested
-
-
- class Callback:
- """Mock the Callback class."""
-
- def __init__(self):
- pass
-
- def begin(self, run_context):
- """Called once before network training."""
-
- def epoch_begin(self, run_context):
- """Called before each epoch begin."""
-
-
- class _ListCallback(Callback):
- """Mock the _ListCallabck class."""
-
- def __init__(self, callbacks):
- super(_ListCallback, self).__init__()
- self._callbacks = callbacks
-
-
- class ModelCheckpoint(Callback):
- """Mock the ModelCheckpoint class."""
-
- def __init__(self, prefix='CKP', directory=None, config=None):
- super(ModelCheckpoint, self).__init__()
- self._prefix = prefix
- self._directory = directory
- self._config = config
- self._latest_ckpt_file_name = os.path.join(directory, prefix + 'test_model.ckpt')
-
- @property
- def model_file_name(self):
- """Get the file name of model."""
- return self._model_file_name
-
- @property
- def latest_ckpt_file_name(self):
- """Get the latest file name fo checkpoint."""
- return self._latest_ckpt_file_name
-
-
- class SummaryStep(Callback):
- """Mock the SummaryStep class."""
-
- def __init__(self, summary, flush_step=10):
- super(SummaryStep, self).__init__()
- self._sumamry = summary
- self._flush_step = flush_step
- self.summary_file_name = summary.full_file_name
|