|
|
|
@@ -330,6 +330,10 @@ class OnnxExporter { |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimReLU6(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
void ExportPrimDepthwiseConv2d(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *graph_proto); |
|
|
|
|
|
|
|
void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *graph_proto); |
|
|
|
@@ -711,6 +715,115 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CN |
|
|
|
node_proto->add_input(input_slope); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimReLU6(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_op_type("Clip"); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
node_proto->add_input(input_x); |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); |
|
|
|
attr_proto->set_name("min"); |
|
|
|
attr_proto->set_f(0.f); |
|
|
|
attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_FLOAT); |
|
|
|
attr_proto->set_name("max"); |
|
|
|
attr_proto->set_f(6.f); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportPrimDepthwiseConv2d(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, |
|
|
|
onnx::GraphProto *const graph_proto) { |
|
|
|
auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); |
|
|
|
auto input_w = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); |
|
|
|
auto x_shape = dyn_cast<abstract::Shape>(node->input(1)->Shape()); |
|
|
|
auto w_shape = dyn_cast<abstract::Shape>(node->input(2)->Shape()); |
|
|
|
MS_EXCEPTION_IF_NULL(x_shape); |
|
|
|
MS_EXCEPTION_IF_NULL(w_shape); |
|
|
|
if (x_shape->shape().size() != 4 || w_shape->shape().size() != 4) { |
|
|
|
MS_LOG(EXCEPTION) << "DepthwiseConv2d input shape should be 4d."; |
|
|
|
} |
|
|
|
if (w_shape->shape()[0] != 1 && w_shape->shape()[1] != 1) { |
|
|
|
MS_LOG(EXCEPTION) << "DepthwiseConv2d weight shape[0] != 1 and shape[1] != 1, cannot reshape"; |
|
|
|
} |
|
|
|
// create w_shape constant node |
|
|
|
auto node_idx = AllocateNodeIndex(); |
|
|
|
onnx::NodeProto *node_proto = graph_proto->add_node(); |
|
|
|
std::string name_w_shape = std::to_string(node_idx); |
|
|
|
node_proto->add_output(name_w_shape); |
|
|
|
node_proto->set_op_type("Constant"); |
|
|
|
// create Value Tensor |
|
|
|
onnx::AttributeProto *attr_proto = node_proto->add_attribute(); |
|
|
|
attr_proto->set_name("value"); |
|
|
|
attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); |
|
|
|
onnx::TensorProto *tensor_proto = attr_proto->mutable_t(); |
|
|
|
tensor_proto->add_dims(static_cast<::google::protobuf::int64>(w_shape->shape().size())); |
|
|
|
tensor_proto->set_data_type(onnx::TensorProto_DataType_INT64); |
|
|
|
// reshape |
|
|
|
tensor_proto->add_int64_data(w_shape->shape()[1]); |
|
|
|
tensor_proto->add_int64_data(w_shape->shape()[0]); |
|
|
|
tensor_proto->add_int64_data(w_shape->shape()[2]); |
|
|
|
tensor_proto->add_int64_data(w_shape->shape()[3]); |
|
|
|
|
|
|
|
// add reshape node |
|
|
|
node_idx = AllocateNodeIndex(); |
|
|
|
node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_op_type(prim::kPrimReshape->name()); |
|
|
|
node_proto->add_input(input_w); |
|
|
|
node_proto->add_input(name_w_shape); |
|
|
|
input_w = std::to_string(node_idx); |
|
|
|
node_proto->add_output(input_w); |
|
|
|
|
|
|
|
// add conv node |
|
|
|
node_idx = AllocateNodeIndex(); |
|
|
|
(*node_map_ptr)[node] = node_idx; |
|
|
|
node_proto = graph_proto->add_node(); |
|
|
|
node_proto->set_op_type("Conv"); |
|
|
|
node_proto->add_input(input_x); |
|
|
|
node_proto->add_input(input_w); |
|
|
|
node_proto->add_output(std::to_string(node_idx)); |
|
|
|
// set attributes |
|
|
|
AnfNodePtr op = node->input(0); |
|
|
|
auto op_value = dyn_cast<ValueNode>(op); |
|
|
|
auto prim = dyn_cast<Primitive>(op_value->value()); |
|
|
|
// set dilations |
|
|
|
onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); |
|
|
|
onnx_attr_proto->set_name("dilations"); |
|
|
|
SetAttrTupleValueToProto<2>(prim->GetAttr("dilation"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, |
|
|
|
prim); |
|
|
|
// set group |
|
|
|
onnx_attr_proto = node_proto->add_attribute(); |
|
|
|
onnx_attr_proto->set_name("group"); |
|
|
|
onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); |
|
|
|
onnx_attr_proto->set_i(x_shape->shape()[1]); |
|
|
|
// set kernel_shape |
|
|
|
onnx_attr_proto = node_proto->add_attribute(); |
|
|
|
onnx_attr_proto->set_name("kernel_shape"); |
|
|
|
SetAttrTupleValueToProto<0>(prim->GetAttr("kernel_size"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, |
|
|
|
prim); |
|
|
|
|
|
|
|
// set pad |
|
|
|
onnx_attr_proto = node_proto->add_attribute(); |
|
|
|
auto attr_value = GetValue<std::string>(prim->GetAttr("pad_mode")); |
|
|
|
onnx_attr_proto->set_name("auto_pad"); |
|
|
|
onnx_attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); |
|
|
|
if (attr_value == "valid") { |
|
|
|
onnx_attr_proto->set_s("VALID"); |
|
|
|
} else if (attr_value == "same") { |
|
|
|
onnx_attr_proto->set_s("SAME_UPPER"); |
|
|
|
} else { |
|
|
|
onnx_attr_proto->set_name("pads"); |
|
|
|
SetAttrTupleValueToProto(prim->GetAttr("pads"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); |
|
|
|
} |
|
|
|
// set strides |
|
|
|
onnx_attr_proto = node_proto->add_attribute(); |
|
|
|
onnx_attr_proto->set_name("strides"); |
|
|
|
SetAttrTupleValueToProto<2>(prim->GetAttr("stride"), onnx::AttributeProto_AttributeType_INTS, onnx_attr_proto, prim); |
|
|
|
} |
|
|
|
|
|
|
|
void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, |
|
|
|
std::map<AnfNodePtr, size_t> *node_map_ptr, onnx::GraphProto *const graph_proto) { |
|
|
|
// Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert |
|
|
|
@@ -732,6 +845,16 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n |
|
|
|
return ExportPrimPReLU(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
// MindSpore ReLU6(x) --> ONNX Clip[min=0.f, max=6.f](x) |
|
|
|
if (node->IsApply(std::make_shared<Primitive>("ReLU6"))) { |
|
|
|
return ExportPrimReLU6(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
// MindSpore DepthwiseConv2dNative --> ONNX Conv(x, reshape(w)) |
|
|
|
if (node->IsApply(std::make_shared<Primitive>("DepthwiseConv2dNative"))) { |
|
|
|
return ExportPrimDepthwiseConv2d(func_graph, node, node_map_ptr, graph_proto); |
|
|
|
} |
|
|
|
|
|
|
|
auto inputs = node->inputs(); |
|
|
|
if (inputs.size() < 1) { |
|
|
|
MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; |
|
|
|
|