You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train_step_wrap.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. train step wrap
  17. """
  18. import mindspore.nn as nn
  19. from mindspore import ParameterTuple
  20. from mindspore.ops import composite as C
  21. class TrainStepWrap(nn.Cell):
  22. """
  23. TrainStepWrap definition
  24. """
  25. def __init__(self, network):
  26. super(TrainStepWrap, self).__init__()
  27. self.network = network
  28. self.network.set_train()
  29. self.weights = ParameterTuple(network.trainable_params())
  30. self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
  31. self.hyper_map = C.HyperMap()
  32. self.grad = C.GradOperation(get_by_list=True)
  33. def construct(self, x, label):
  34. weights = self.weights
  35. grads = self.grad(self.network, weights)(x, label)
  36. return self.optimizer(grads)
  37. class NetWithLossClass(nn.Cell):
  38. """
  39. NetWithLossClass definition
  40. """
  41. def __init__(self, network):
  42. super(NetWithLossClass, self).__init__(auto_prefix=False)
  43. self.loss = nn.SoftmaxCrossEntropyWithLogits()
  44. self.network = network
  45. def construct(self, x, label):
  46. predict = self.network(x)
  47. return self.loss(predict, label)
  48. def train_step_with_loss_warp(network):
  49. return TrainStepWrap(NetWithLossClass(network))
  50. class TrainStepWrap2(nn.Cell):
  51. """
  52. TrainStepWrap2 definition
  53. """
  54. def __init__(self, network, sens):
  55. super(TrainStepWrap2, self).__init__()
  56. self.network = network
  57. self.network.set_train()
  58. self.weights = ParameterTuple(network.get_parameters())
  59. self.optimizer = nn.Momentum(self.weights, 0.1, 0.9)
  60. self.hyper_map = C.HyperMap()
  61. self.grad = C.GradOperation(get_by_list=True, sens_param=True)
  62. self.sens = sens
  63. def construct(self, x):
  64. weights = self.weights
  65. grads = self.grad(self.network, weights)(x, self.sens)
  66. return self.optimizer(grads)
  67. def train_step_with_sens(network, sens):
  68. return TrainStepWrap2(network, sens)
  69. class TrainStepWrapWithoutOpt(nn.Cell):
  70. """
  71. TrainStepWrapWithoutOpt definition
  72. """
  73. def __init__(self, network):
  74. super(TrainStepWrapWithoutOpt, self).__init__()
  75. self.network = network
  76. self.weights = ParameterTuple(network.trainable_params())
  77. self.grad = C.GradOperation(get_by_list=True)
  78. def construct(self, x, label):
  79. grads = self.grad(self.network, self.weights)(x, label)
  80. return grads
  81. def train_step_without_opt(network):
  82. return TrainStepWrapWithoutOpt(NetWithLossClass(network))