| @@ -26,6 +26,7 @@ cmake-build-debug | |||||
| *_pb2.py | *_pb2.py | ||||
| *.pb.h | *.pb.h | ||||
| *.pb.cc | *.pb.cc | ||||
| *.pb | |||||
| # Object files | # Object files | ||||
| *.o | *.o | ||||
| @@ -86,7 +86,7 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { | |||||
| } | } | ||||
| bool converted = parse::ConvertData(obj, &converted_ret); | bool converted = parse::ConvertData(obj, &converted_ret); | ||||
| if (!converted) { | if (!converted) { | ||||
| MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(obj)); | |||||
| MS_LOG(EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(obj)); | |||||
| } | } | ||||
| (void)this->AddAttr(attr_name, converted_ret); | (void)this->AddAttr(attr_name, converted_ret); | ||||
| } | } | ||||
| @@ -345,14 +345,14 @@ abstract::AbstractBasePtr Tensor::ToAbstract() { | |||||
| std::string Tensor::GetShapeAndDataTypeInfo() const { | std::string Tensor::GetShapeAndDataTypeInfo() const { | ||||
| std::ostringstream buf; | std::ostringstream buf; | ||||
| buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString(); | |||||
| buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString(); | |||||
| return buf.str(); | return buf.str(); | ||||
| } | } | ||||
| std::string Tensor::ToString() const { | std::string Tensor::ToString() const { | ||||
| const int small_tensor_size = 30; | const int small_tensor_size = 30; | ||||
| std::ostringstream buf; | std::ostringstream buf; | ||||
| buf << "Tensor \nshape:[" << shape() << "]" << this->Dtype()->ToString(); | |||||
| buf << "Tensor shape:[" << shape() << "]" << this->Dtype()->ToString(); | |||||
| // only print small tensor | // only print small tensor | ||||
| if (DataSize() < small_tensor_size) { | if (DataSize() < small_tensor_size) { | ||||
| buf << "val:" << std::string(py::str(data())); | buf << "val:" << std::string(py::str(data())); | ||||
| @@ -234,7 +234,11 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo | |||||
| current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); | current_fg->debug_info()->set_deco_location(GetLocation(deco_list)); | ||||
| } | } | ||||
| bool set_flag = ast_->UpdateFuncGraphFlags(current_fg); | |||||
| bool set_flag = UpdateFuncGraphFlags(ast_->function(), current_fg); | |||||
| if (ast_->obj() != ast_->function()) { | |||||
| set_flag = set_flag && UpdateFuncGraphFlags(ast_->obj(), current_fg); | |||||
| } | |||||
| if (!set_flag) { | if (!set_flag) { | ||||
| MS_LOG(ERROR) << "Set flags failed"; | MS_LOG(ERROR) << "Set flags failed"; | ||||
| return nullptr; | return nullptr; | ||||
| @@ -1436,17 +1440,17 @@ bool ParseAst::IsClassMember(const py::object &node) { | |||||
| return ret.cast<bool>(); | return ret.cast<bool>(); | ||||
| } | } | ||||
| bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) { | |||||
| bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { | |||||
| if (func_graph == nullptr) { | if (func_graph == nullptr) { | ||||
| MS_LOG(ERROR) << "FuncGraph is null"; | MS_LOG(ERROR) << "FuncGraph is null"; | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (!py::hasattr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG)) { | |||||
| if (!py::hasattr(obj, PYTHON_EXTERN_MINDSPORE_FLAG)) { | |||||
| MS_LOG(DEBUG) << "No flags"; | MS_LOG(DEBUG) << "No flags"; | ||||
| return true; | return true; | ||||
| } | } | ||||
| py::dict flags = python_adapter::GetPyObjAttr(obj_, PYTHON_EXTERN_MINDSPORE_FLAG); | |||||
| py::dict flags = python_adapter::GetPyObjAttr(obj, PYTHON_EXTERN_MINDSPORE_FLAG); | |||||
| for (auto &item : flags) { | for (auto &item : flags) { | ||||
| if (!py::isinstance<py::str>(item.first)) { | if (!py::isinstance<py::str>(item.first)) { | ||||
| MS_LOG(ERROR) << "Type error in flags dict convert"; | MS_LOG(ERROR) << "Type error in flags dict convert"; | ||||
| @@ -1466,7 +1470,6 @@ bool ParseAst::UpdateFuncGraphFlags(const FuncGraphPtr &func_graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -327,9 +327,6 @@ class ParseAst { | |||||
| bool IsClassMember(const py::object &node); | bool IsClassMember(const py::object &node); | ||||
| // update the graph flags | |||||
| bool UpdateFuncGraphFlags(const FuncGraphPtr &func_graph); | |||||
| private: | private: | ||||
| // save obj,eg: class instance or function | // save obj,eg: class instance or function | ||||
| py::object obj_; | py::object obj_; | ||||
| @@ -350,6 +347,9 @@ class ParseAst { | |||||
| int function_line_offset_; | int function_line_offset_; | ||||
| }; | }; | ||||
| // update the graph flags | |||||
| bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph); | |||||
| AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m); | ||||
| } // namespace parse | } // namespace parse | ||||
| @@ -284,7 +284,6 @@ class ClipByNorm(Cell): | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=True) | self.reduce_sum = P.ReduceSum(keep_dims=True) | ||||
| self.select_ = P.Select() | self.select_ = P.Select() | ||||
| self.greater_ = P.Greater() | self.greater_ = P.Greater() | ||||
| self.axis = () | |||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| self.zero = Tensor(np.array([0.0]).astype(np.float32)) | self.zero = Tensor(np.array([0.0]).astype(np.float32)) | ||||
| self.sqrt = P.Sqrt() | self.sqrt = P.Sqrt() | ||||
| @@ -299,7 +298,7 @@ class ClipByNorm(Cell): | |||||
| def construct(self, x, clip_norm): | def construct(self, x, clip_norm): | ||||
| """add ms_function decorator for pynative mode""" | """add ms_function decorator for pynative mode""" | ||||
| mul_x = F.square(x) | mul_x = F.square(x) | ||||
| l2sum = self.cast(self.reduce_sum(mul_x, self.axis), mstype.float32) | |||||
| l2sum = self.cast(self.reduce_sum(mul_x), mstype.float32) | |||||
| cond = self.greater_(l2sum, self.zero) | cond = self.greater_(l2sum, self.zero) | ||||
| ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0) | ones_ = self.fill(self.dtype(cond), self.shape(cond), 1.0) | ||||
| @@ -234,8 +234,8 @@ class TrainOneStepWithLossScaleCell(Cell): | |||||
| if scale_update_cell: | if scale_update_cell: | ||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | ||||
| name="loss_scale") | name="loss_scale") | ||||
| self.add_flags(has_effect=True) | |||||
| @C.add_flags(has_effect=True) | |||||
| def construct(self, data, label, sens=None): | def construct(self, data, label, sens=None): | ||||
| weights = self.weights | weights = self.weights | ||||
| loss = self.network(data, label) | loss = self.network(data, label) | ||||
| @@ -30,16 +30,16 @@ from ...common.parameter import Parameter | |||||
| __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] | __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_] | ||||
| def add_flags(fn, **flags): | |||||
| def add_flags(fn=None, **flags): | |||||
| """ | """ | ||||
| An interface to add flag for a function. | |||||
| An decorator to add flag for a function. | |||||
| Note: | Note: | ||||
| Only supports bool value. | Only supports bool value. | ||||
| Args: | Args: | ||||
| fn (Function): Function or cell to add flag. | |||||
| flags (bool): Flags use kwargs. | |||||
| fn (Function): Function or cell to add flag. Default: None. | |||||
| flags (dict): Flags use kwargs. Default: None. | |||||
| Returns: | Returns: | ||||
| Function, the fn added flags. | Function, the fn added flags. | ||||
| @@ -47,11 +47,17 @@ def add_flags(fn, **flags): | |||||
| Examples: | Examples: | ||||
| >>> add_flags(net, predit=True) | >>> add_flags(net, predit=True) | ||||
| """ | """ | ||||
| # need set the attr and access on c++ | |||||
| if not hasattr(fn, "_mindspore_flags"): | |||||
| fn._mindspore_flags = {} | |||||
| fn._mindspore_flags.update({**flags}) | |||||
| return fn | |||||
| def deco(fn): | |||||
| # need set the attr and access on c++ | |||||
| if not hasattr(fn, "_mindspore_flags"): | |||||
| fn._mindspore_flags = {} | |||||
| fn._mindspore_flags.update({**flags}) | |||||
| return fn | |||||
| ret = deco | |||||
| if fn is not None: | |||||
| ret = deco(fn) | |||||
| return ret | |||||
| def core(fn=None, **flags): | def core(fn=None, **flags): | ||||
| @@ -277,8 +277,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| if scale_update_cell: | if scale_update_cell: | ||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | ||||
| name="loss_scale") | name="loss_scale") | ||||
| self.add_flags(has_effect=True) | |||||
| @C.add_flags(has_effect=True) | |||||
| def construct(self, | def construct(self, | ||||
| source_eos_ids, | source_eos_ids, | ||||
| source_eos_mask, | source_eos_mask, | ||||
| @@ -132,9 +132,9 @@ class GetNextSentenceOutput(nn.Cell): | |||||
| def __init__(self, config): | def __init__(self, config): | ||||
| super(GetNextSentenceOutput, self).__init__() | super(GetNextSentenceOutput, self).__init__() | ||||
| self.log_softmax = _selected_ops.LogSoftmax() | self.log_softmax = _selected_ops.LogSoftmax() | ||||
| self.weight_init = TruncatedNormal(config.initializer_range) | |||||
| weight_init = TruncatedNormal(config.initializer_range) | |||||
| self.dense = nn.Dense(config.hidden_size, 2, | self.dense = nn.Dense(config.hidden_size, 2, | ||||
| weight_init=self.weight_init, has_bias=True).to_float(config.compute_type) | |||||
| weight_init=weight_init, has_bias=True).to_float(config.compute_type) | |||||
| self.dtype = config.dtype | self.dtype = config.dtype | ||||
| self.cast = P.Cast() | self.cast = P.Cast() | ||||
| @@ -321,7 +321,6 @@ class BertTrainOneStepCell(nn.Cell): | |||||
| if self.reducer_flag: | if self.reducer_flag: | ||||
| # apply grad reducer on grads | # apply grad reducer on grads | ||||
| grads = self.grad_reducer(grads) | grads = self.grad_reducer(grads) | ||||
| succ = self.optimizer(grads) | succ = self.optimizer(grads) | ||||
| return F.depend(loss, succ) | return F.depend(loss, succ) | ||||
| @@ -380,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| if scale_update_cell: | if scale_update_cell: | ||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | ||||
| name="loss_scale") | name="loss_scale") | ||||
| self.add_flags(has_effect=True) | |||||
| @C.add_flags(has_effect=True) | |||||
| def construct(self, | def construct(self, | ||||
| input_ids, | input_ids, | ||||
| input_mask, | input_mask, | ||||
| @@ -17,6 +17,7 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from mindspore.ops.composite import add_flags | |||||
| from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \ | from .backbone.resnet_deeplab import _conv_bn_relu, resnet50_dl, _deep_conv_bn_relu, \ | ||||
| DepthwiseConv2dNative, SpaceToBatch, BatchToSpace | DepthwiseConv2dNative, SpaceToBatch, BatchToSpace | ||||
| @@ -121,6 +122,7 @@ class ASPP(nn.Cell): | |||||
| self.feature_shape = feature_shape | self.feature_shape = feature_shape | ||||
| self.concat = P.Concat(axis=1) | self.concat = P.Concat(axis=1) | ||||
| @add_flags(loop_can_unroll=True) | |||||
| def construct(self, x, scale_index=0): | def construct(self, x, scale_index=0): | ||||
| aspp0 = self.aspp0(x) | aspp0 = self.aspp0(x) | ||||
| aspp1 = self.global_poolings[scale_index](x) | aspp1 = self.global_poolings[scale_index](x) | ||||
| @@ -276,7 +278,7 @@ class SingleDeepLabV3(nn.Cell): | |||||
| atrous_rates=atrous_rates, | atrous_rates=atrous_rates, | ||||
| output_stride=output_stride, | output_stride=output_stride, | ||||
| fine_tune_batch_norm=fine_tune_batch_norm) | fine_tune_batch_norm=fine_tune_batch_norm) | ||||
| self.aspp.add_flags(loop_can_unroll=True) | |||||
| atrous_rates_len = 0 | atrous_rates_len = 0 | ||||
| if atrous_rates is not None: | if atrous_rates is not None: | ||||
| atrous_rates_len = len(atrous_rates) | atrous_rates_len = len(atrous_rates) | ||||
| @@ -379,8 +379,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): | |||||
| if scale_update_cell: | if scale_update_cell: | ||||
| self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), | ||||
| name="loss_scale") | name="loss_scale") | ||||
| self.add_flags(has_effect=True) | |||||
| @C.add_flags(has_effect=True) | |||||
| def construct(self, | def construct(self, | ||||
| input_ids, | input_ids, | ||||
| input_mask, | input_mask, | ||||
| @@ -133,8 +133,8 @@ def test_keep_order_io_effect_exception_return_dtype(): | |||||
| self.dtype = P.DType() | self.dtype = P.DType() | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.neg = P.Neg() | self.neg = P.Neg() | ||||
| self.add_flags(has_effect=True) | |||||
| @C.add_flags(has_effect=True) | |||||
| def construct(self, x): | def construct(self, x): | ||||
| init = self.alloc_status() | init = self.alloc_status() | ||||
| self.clear_status(init) | self.clear_status(init) | ||||
| @@ -268,8 +268,8 @@ class NpuFloatNet(nn.Cell): | |||||
| self.reduce_sum = P.ReduceSum(keep_dims=True) | self.reduce_sum = P.ReduceSum(keep_dims=True) | ||||
| self.sub = P.Sub() | self.sub = P.Sub() | ||||
| self.neg = P.Neg() | self.neg = P.Neg() | ||||
| self.add_flags(has_effect=True) | |||||
| @C.add_flags(has_effect=True) | |||||
| def construct(self, x): | def construct(self, x): | ||||
| init = self.alloc_status() | init = self.alloc_status() | ||||
| self.clear_status(init) | self.clear_status(init) | ||||
| @@ -14,13 +14,13 @@ | |||||
| # ============================================================================ | # ============================================================================ | ||||
| """ test_lenet_model """ | """ test_lenet_model """ | ||||
| import numpy as np | import numpy as np | ||||
| import pytest | |||||
| import mindspore.nn as nn | import mindspore.nn as nn | ||||
| from mindspore.common.tensor import Tensor | from mindspore.common.tensor import Tensor | ||||
| from mindspore.nn import WithGradCell, WithLossCell | from mindspore.nn import WithGradCell, WithLossCell | ||||
| from mindspore.nn.optim import Momentum | from mindspore.nn.optim import Momentum | ||||
| from mindspore.ops import operations as P | from mindspore.ops import operations as P | ||||
| from ....ut_filter import non_graph_engine | |||||
| class LeNet5(nn.Cell): | class LeNet5(nn.Cell): | ||||
| @@ -47,7 +47,7 @@ class LeNet5(nn.Cell): | |||||
| return x | return x | ||||
| @non_graph_engine | |||||
| @pytest.mark.skip(reason="need ge backend") | |||||
| def test_lenet_pynative_train_net(): | def test_lenet_pynative_train_net(): | ||||
| """ test_lenet_pynative_train_net """ | """ test_lenet_pynative_train_net """ | ||||
| data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) | data = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01) | ||||