/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_FUNCTIONAL_OPS_H_
#define GE_FUNCTIONAL_OPS_H_

#include "graph/operator_reg.h"
#include "graph/operator.h"

namespace ge {
/**
 * @brief Computes the gradient function of the subgraph "f" by backpropagation.
 */
REG_OP(SymbolicGradient)
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(f)
    .OP_END_FACTORY_REG(SymbolicGradient)

/**
 * @brief Runs the subgraph "f" on the remote device indicated by "target".
 */
REG_OP(RemoteCall)
    .INPUT(target, DT_STRING)
    .DYNAMIC_INPUT(args, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(f)
    .OP_END_FACTORY_REG(RemoteCall)

/**
 * @brief Internal conditional: executes "then_branch" if "cond" is true,
 * otherwise executes "else_branch".
 */
REG_OP(_If)
    .INPUT(cond, TensorType::ALL())
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(then_branch)
    .GRAPH(else_branch)
    .OP_END_FACTORY_REG(_If)

/**
 * @brief Stateless conditional: executes "then_branch" if "cond" is true,
 * otherwise executes "else_branch".
 */
REG_OP(StatelessIf)
    .INPUT(cond, TensorType::ALL())
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(then_branch)
    .GRAPH(else_branch)
    .OP_END_FACTORY_REG(StatelessIf)

/**
 * @brief Conditional: executes "then_branch" if "cond" is true,
 * otherwise executes "else_branch".
 */
REG_OP(If)
    .INPUT(cond, TensorType::ALL())
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(then_branch)
    .GRAPH(else_branch)
    .OP_END_FACTORY_REG(If)

/**
 * @brief Multi-way branch: executes the subgraph selected by "branch_index"
 * from the "branches" graph list.
 */
REG_OP(Case)
    .INPUT(branch_index, DT_INT32)
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .DYNAMIC_GRAPH(branches)
    .OP_END_FACTORY_REG(Case)

/**
 * @brief Internal loop: repeatedly executes "body" while "cond" evaluates to true.
 */
REG_OP(_While)
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(cond)
    .GRAPH(body)
    .OP_END_FACTORY_REG(_While)

/**
 * @brief Loop: repeatedly executes "body" while "cond" evaluates to true.
 */
REG_OP(While)
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(cond)
    .GRAPH(body)
    .ATTR(parallel_iterations, Int, 10)
    .OP_END_FACTORY_REG(While)

/**
 * @brief Stateless loop: repeatedly executes "body" while "cond" evaluates to true.
 */
REG_OP(StatelessWhile)
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(cond)
    .GRAPH(body)
    .ATTR(parallel_iterations, Int, 10)
    .OP_END_FACTORY_REG(StatelessWhile)

/**
 * @brief Counted loop: executes "body" for each value from "start" to "limit"
 * (exclusive) with stride "delta".
 */
REG_OP(For)
    .INPUT(start, DT_INT32)
    .INPUT(limit, DT_INT32)
    .INPUT(delta, DT_INT32)
    .DYNAMIC_INPUT(input, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(body)
    .OP_END_FACTORY_REG(For)

/**
 * @brief Calls the subgraph "f" with "args" as inputs.
 */
REG_OP(PartitionedCall)
    .DYNAMIC_INPUT(args, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(f)
    .ATTR(config, String, "")
    .ATTR(config_proto, String, "")
    .ATTR(executor_type, String, "")
    .OP_END_FACTORY_REG(PartitionedCall)

/**
 * @brief Calls the (possibly stateful) subgraph "f" with "args" as inputs.
 */
REG_OP(StatefulPartitionedCall)
    .DYNAMIC_INPUT(args, TensorType::ALL())
    .DYNAMIC_OUTPUT(output, TensorType::ALL())
    .GRAPH(f)
    .ATTR(config, String, "")
    .ATTR(config_proto, String, "")
    .ATTR(executor_type, String, "")
    .OP_END_FACTORY_REG(StatefulPartitionedCall)

/**
 * @brief Produces a placeholder tensor of the given "shape"; used where a real
 * parameter value is not required.
 */
REG_OP(FakeParam)
    .OUTPUT(output, TensorType::ALL())
    .ATTR(shape, ListInt, {})
    .OP_END_FACTORY_REG(FakeParam)
}  // namespace ge

#endif  // GE_FUNCTIONAL_OPS_H_