@@ -1233,6 +1233,27 @@ FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNod
   return ret_graph;
 }
+
+FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
+  // select indexed item
+  // args: tuple of items, index
+  const std::string op_name = "TupleGetItemTensor";
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  AbstractBasePtrList branches = branches_abs->elements();
+  if (!branches.empty() && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
+    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
+    ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
+    AnfNodePtr functions = ret_graph->add_parameter();
+    auto index = ret_graph->add_parameter();
+    ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
+    return ret_graph;
+  }
+  MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support indexing " << branches_abs->ToString() << ".";
+}
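For intuition, the graph generated above does nothing but feed its two parameters into a SwitchLayer node, which selects one entry of the function tuple at run time. A minimal Python sketch of that selection semantics (`switch_layer_model` is an illustrative stand-in, not a MindSpore API):

```python
# Illustrative model of the graph built by GenerateFuncGraph above: the
# output node computes switch_layer(index, functions), i.e. it selects
# functions[index] and returns it as a callable.
# switch_layer_model is a hypothetical name, not a real MindSpore API.
def switch_layer_model(index, functions):
    if not 0 <= index < len(functions):
        raise IndexError("index %d out of range for %d branches" % (index, len(functions)))
    return functions[index]


# Usage: select the second branch, then apply it.
branches = (lambda x: x + 1, lambda x: x * 2)
assert switch_layer_model(1, branches)(3) == 6
```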
 REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                          (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                            .def(py::init<std::string &>());
@@ -1247,5 +1268,11 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
                          (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
                            .def(py::init<std::string &>());
                        }));
+
+REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
+                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
+                           *m, "TupleGetItemTensor_")
+                           .def(py::init<std::string &>());
+                       }));
 }  // namespace prim
 }  // namespace mindspore
@@ -210,6 +210,18 @@ class TensorSlice : public MetaFuncGraph {
   FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
 };
 using TensorSlicePtr = std::shared_ptr<TensorSlice>;
+
+class TupleGetItemTensor : public MetaFuncGraph {
+ public:
+  explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
+  ~TupleGetItemTensor() override = default;
+  MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
+  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
+  friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
+    return lhs.name_ == rhs.name_;
+  }
+};
+using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
 }  // namespace prim
 }  // namespace mindspore
@@ -129,22 +129,27 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
 AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
   // Inputs: index, branches
-  if (args_spec_list.size() != 2) {
-    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
-                      << args_spec_list.size() << ".";
-  }
-  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
+  const std::string op_name = primitive->name();
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
   AbstractBasePtrList branches = branches_abs->elements();
   const size_t maximum_layer_num = 1000;
   if (branches.size() < 1 || branches.size() > maximum_layer_num) {
-    MS_EXCEPTION(ValueError) << "SwitchLayer support at least 1 and at most " << maximum_layer_num << " but got "
+    MS_EXCEPTION(ValueError) << op_name << " supports at least 1 and at most " << maximum_layer_num << " but got "
                              << branches.size() << " branches.";
   }
-  MS_EXCEPTION_IF_NULL(branches[0]);
+  for (size_t i = 0; i < branches.size(); i++) {
+    MS_EXCEPTION_IF_NULL(branches[i]);
+    if (!branches[i]->isa<AbstractFunction>()) {
+      MS_LOG(EXCEPTION) << op_name << " requires that the second arg be a tuple of functions, but got "
+                        << branches[i]->ToString() << " as the " << i << "th element.";
+    }
+  }
+
   auto b = branches[0];
   for (size_t i = 1; i < branches.size(); i++) {
-    MS_EXCEPTION_IF_NULL(branches[i]);
     b = b->Join(branches[i]);
   }
   return b;
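To summarize the validation above in plain Python: exactly two inputs, the first a tensor, the second a non-empty tuple of at most 1000 functions, and the result is the join of the branches' abstract values. A hedged sketch (`infer_switch_layer_model` is an illustrative helper, not MindSpore code; the join is reduced to returning the first branch):

```python
# Hedged sketch of the checks InferImplSwitchLayer performs above.
# infer_switch_layer_model is a hypothetical helper, not MindSpore code.
MAXIMUM_LAYER_NUM = 1000


def infer_switch_layer_model(index, branches):
    # The C++ code also verifies that `index` is an AbstractTensor;
    # that check is elided in this sketch.
    if not isinstance(branches, tuple):
        raise TypeError("the second arg must be a tuple of functions")
    if not 1 <= len(branches) <= MAXIMUM_LAYER_NUM:
        raise ValueError("SwitchLayer supports 1..%d branches, got %d"
                         % (MAXIMUM_LAYER_NUM, len(branches)))
    for i, branch in enumerate(branches):
        if not callable(branch):
            raise TypeError("element %d of the branch tuple is not a function" % i)
    # The real implementation joins the abstract values of all branches;
    # returning the first branch stands in for that join here.
    return branches[0]
```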
@@ -18,13 +18,13 @@
 """Basic composite operations."""
 
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
-    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_
+    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function
 from .. import functional as F
 from .. import operations as P
 
-__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_]
+__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
 
 
 def add_flags(fn, **flags):
@@ -72,6 +72,28 @@ _tensor_slice = _TensorSlice('tensor_slice')
 """_tensor_slice is a metafuncgraph object which will slice a tensor."""
+
+
+class _TupleGetItemTensor(base.TupleGetItemTensor_):
+    """
+    Getting an item of a tuple by tensor index.
+
+    Inputs:
+        data (tuple): A tuple of items.
+        index (Tensor): The index as a scalar Tensor.
+
+    Outputs:
+        Type, the same as the element type of data.
+    """
+
+    def __init__(self, name):
+        base.TupleGetItemTensor_.__init__(self, name)
+
+    def __call__(self, *args):
+        pass
+
+
+_tuple_get_item_tensor = _TupleGetItemTensor('tuple_get_item_tensor')
+"""_tuple_get_item_tensor is a metafuncgraph object which will select the indexed item."""
| @getitem.register("Tuple", "Number") | |||
| def _tuple_getitem_by_number(data, number_index): | |||
| """ | |||
@@ -102,6 +124,21 @@ def _tuple_getitem_by_slice(data, slice_index):
     return _tuple_slice(data, slice_index)
| @getitem.register("Tuple", "Tensor") | |||
| def _tuple_getitem_by_tensor(data, tensor_index): | |||
| """ | |||
| Getting item out of tuple by tensor index. | |||
| Inputs: | |||
| data (tuple): A tuple of items to index. | |||
| tensor_index (Tensor): Index to select item. | |||
| Outputs: | |||
| Type, is same as the element type of data. | |||
| """ | |||
| return _tuple_get_item_tensor(data, tensor_index) | |||
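This registration is what makes `tuple[Tensor]` work inside `construct`. A minimal usage sketch (`PickNet` is a hypothetical cell for illustration, not part of this change; the tests at the end of this diff exercise the same path with Parameters and gradients):

```python
# Illustrative end-to-end use of tuple-by-Tensor indexing. PickNet is a
# hypothetical cell; the getitem call in construct dispatches on
# ("Tuple", "Tensor") and lands in _tuple_get_item_tensor above.
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor


class PickNet(nn.Cell):
    def __init__(self):
        super(PickNet, self).__init__()
        self.branches = (nn.ReLU(), nn.Sigmoid())

    def construct(self, index, x):
        # Selects self.branches[index] and applies it to x.
        return self.branches[index](x)


net = PickNet()
out = net(Tensor(0), Tensor(np.ones([2, 2], dtype=np.float32)))
```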
| @getitem.register("List", "Number") | |||
| def _list_getitem_by_number(data, number_index): | |||
| """ | |||
@@ -387,7 +387,38 @@ def test_switch_layer():
             ret = F.switch_layer(index, self.layers)(x) * self.z3
             return ret
 
+    index = Tensor(0)
     net = SwitchLayerCell()
-    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+
+
+def test_index_to_switch_layer():
+    class Layer1(nn.Cell):
+        def __init__(self):
+            super(Layer1, self).__init__()
+            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+
+        def construct(self, x):
+            return x * self.z1
+
+    class Layer2(nn.Cell):
+        def __init__(self):
+            super(Layer2, self).__init__()
+            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+
+        def construct(self, x):
+            return x * self.z2
+
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (Layer1(), Layer2())
+            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+
+        def construct(self, index, x):
+            ret = self.layers[index](x) * self.z3
+            return ret
+
+    index = Tensor(0)
+    net = SwitchLayerCell()
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))