You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

train_and_test.py 3.1 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ test_training """
  15. import os
  16. from mindspore import Model, context
  17. from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
  18. from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
  19. from src.callbacks import LossCallBack, EvalCallBack
  20. from src.datasets import create_dataset
  21. from src.metrics import AUCMetric
  22. from src.config import WideDeepConfig
  23. context.set_context(mode=context.GRAPH_MODE, device_target="Davinci")
  24. def get_WideDeep_net(config):
  25. WideDeep_net = WideDeepModel(config)
  26. loss_net = NetWithLossClass(WideDeep_net, config)
  27. train_net = TrainStepWrap(loss_net)
  28. eval_net = PredictWithSigmoid(WideDeep_net)
  29. return train_net, eval_net
  30. class ModelBuilder():
  31. """
  32. ModelBuilder
  33. """
  34. def __init__(self):
  35. pass
  36. def get_hook(self):
  37. pass
  38. def get_train_hook(self):
  39. hooks = []
  40. callback = LossCallBack()
  41. hooks.append(callback)
  42. if int(os.getenv('DEVICE_ID')) == 0:
  43. pass
  44. return hooks
  45. def get_net(self, config):
  46. return get_WideDeep_net(config)
  47. def test_train_eval(config):
  48. """
  49. test_train_eval
  50. """
  51. data_path = config.data_path
  52. batch_size = config.batch_size
  53. epochs = config.epochs
  54. ds_train = create_dataset(data_path, train_mode=True, epochs=epochs, batch_size=batch_size)
  55. ds_eval = create_dataset(data_path, train_mode=False, epochs=epochs + 1, batch_size=batch_size)
  56. print("ds_train.size: {}".format(ds_train.get_dataset_size()))
  57. print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
  58. net_builder = ModelBuilder()
  59. train_net, eval_net = net_builder.get_net(config)
  60. train_net.set_train()
  61. auc_metric = AUCMetric()
  62. model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
  63. eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
  64. callback = LossCallBack(config=config)
  65. ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(), keep_checkpoint_max=5)
  66. ckpoint_cb = ModelCheckpoint(prefix='widedeep_train', directory=config.ckpt_path, config=ckptconfig)
  67. out = model.eval(ds_eval)
  68. print("=====" * 5 + "model.eval() initialized: {}".format(out))
  69. model.train(epochs, ds_train, callbacks=[eval_callback, callback, ckpoint_cb])
  70. if __name__ == "__main__":
  71. wide_deep_config = WideDeepConfig()
  72. wide_deep_config.argparse_init()
  73. test_train_eval(wide_deep_config)