|
|
|
@@ -236,14 +236,8 @@ const char *const kFieldInnerPro = "inner_product_param"; |
|
|
|
const char *const kFieldDim = "dim"; |
|
|
|
const char *const kFieldBiasTerm = "bias_term"; |
|
|
|
const char *const kDevNull = "/dev/null"; |
|
|
|
const std::string kMessage = "message"; |
|
|
|
const std::string kLayerParameter = "LayerParameter"; |
|
|
|
const std::string kCloseBrace = "}"; |
|
|
|
const std::string kOptional = "optional"; |
|
|
|
const std::string kRepeated = "repeated"; |
|
|
|
const std::string kRequired = "required"; |
|
|
|
const std::string kCustom = "custom"; |
|
|
|
const std::string kBuiltin = "built-in"; |
|
|
|
const char *const kCustom = "custom"; |
|
|
|
const char *const kBuiltin = "built-in"; |
|
|
|
std::vector<std::string> kAddTensorIrSkipNodes = {ge::parser::DATA, ge::parser::YOLODETECTIONOUTPUT, |
|
|
|
ge::parser::NETOUTPUT}; |
|
|
|
const std::set<std::string> kCustomProtoLayerCommonField = {"name", "type"}; |
|
|
|
@@ -284,104 +278,104 @@ const set<string> CaffeWeightsParser::skiped_layer_type_ = {"Split", "SoftmaxW |
|
|
|
"Dropout", "MultiLabelLMDB", "Python", "AnnotatedData"}; |
|
|
|
|
|
|
|
Status CaffeModelParser::ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag) const { |
|
|
|
if (proto_message.input_size() > 0) { |
|
|
|
GELOGI("This net exsit input."); |
|
|
|
if (proto_message.input_size() <= 0) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
GELOGI("This net exsit input."); |
|
|
|
if (proto_message.input_dim_size() > 0) { |
|
|
|
if (proto_message.input_shape_size() > 0) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E11001"); |
|
|
|
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (proto_message.input_dim_size() > 0) { |
|
|
|
if (proto_message.input_shape_size() > 0) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E11001"); |
|
|
|
GELOGE(FAILED, "[Check][Size]input_dim and input_shape can not both exist!"); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
const int32_t input_dim_size = proto_message.input_dim_size(); |
|
|
|
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || |
|
|
|
((input_dim_size % proto_message.input_size()) != 0)); |
|
|
|
if (is_input_invalid) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, |
|
|
|
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); |
|
|
|
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", |
|
|
|
input_dim_size, proto_message.input_size()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
const int32_t input_dim_size = proto_message.input_dim_size(); |
|
|
|
const bool is_input_invalid = (((input_dim_size / proto_message.input_size()) != parser::DIM_DEFAULT_SIZE) || |
|
|
|
((input_dim_size % proto_message.input_size()) != 0)); |
|
|
|
if (is_input_invalid) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E11003", {"input_dim_size", "input_size"}, |
|
|
|
{std::to_string(input_dim_size), std::to_string(proto_message.input_size())}); |
|
|
|
GELOGE(FAILED, "[Check][Size]Model input_dim size[%d] is not 4 times of input size[%d].", |
|
|
|
input_dim_size, proto_message.input_size()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
for (int i = 0; i < proto_message.input_size(); i++) { |
|
|
|
domi::caffe::LayerParameter *layer = proto_message.add_layer(); |
|
|
|
GE_CHECK_NOTNULL(layer); |
|
|
|
layer->set_name(proto_message.input(i)); |
|
|
|
layer->set_type(ge::parser::INPUT_TYPE); |
|
|
|
layer->add_top(proto_message.input(i)); |
|
|
|
|
|
|
|
for (int i = 0; i < proto_message.input_size(); i++) { |
|
|
|
domi::caffe::LayerParameter *layer = proto_message.add_layer(); |
|
|
|
GE_CHECK_NOTNULL(layer); |
|
|
|
layer->set_name(proto_message.input(i)); |
|
|
|
layer->set_type(ge::parser::INPUT_TYPE); |
|
|
|
layer->add_top(proto_message.input(i)); |
|
|
|
|
|
|
|
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); |
|
|
|
GE_CHECK_NOTNULL(input_param); |
|
|
|
domi::caffe::BlobShape *shape = input_param->add_shape(); |
|
|
|
GE_CHECK_NOTNULL(shape); |
|
|
|
|
|
|
|
for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { |
|
|
|
// Can guarantee that it will not cross the border |
|
|
|
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); |
|
|
|
} |
|
|
|
input_data_flag = true; |
|
|
|
} |
|
|
|
} else if (proto_message.input_shape_size() > 0) { |
|
|
|
if (proto_message.input_shape_size() != proto_message.input_size()) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, |
|
|
|
{std::to_string(proto_message.input_shape_size()), |
|
|
|
std::to_string(proto_message.input_size())}); |
|
|
|
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", |
|
|
|
proto_message.input_shape_size(), proto_message.input_size()); |
|
|
|
return FAILED; |
|
|
|
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); |
|
|
|
GE_CHECK_NOTNULL(input_param); |
|
|
|
domi::caffe::BlobShape *shape = input_param->add_shape(); |
|
|
|
GE_CHECK_NOTNULL(shape); |
|
|
|
|
|
|
|
for (int j = 0; j < parser::DIM_DEFAULT_SIZE; j++) { |
|
|
|
// Can guarantee that it will not cross the border |
|
|
|
shape->add_dim(static_cast<int64_t>(proto_message.input_dim(j + i * parser::DIM_DEFAULT_SIZE))); |
|
|
|
} |
|
|
|
input_data_flag = true; |
|
|
|
} |
|
|
|
} else if (proto_message.input_shape_size() > 0) { |
|
|
|
if (proto_message.input_shape_size() != proto_message.input_size()) { |
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E11004", {"input_shape_size", "input_size"}, |
|
|
|
{std::to_string(proto_message.input_shape_size()), |
|
|
|
std::to_string(proto_message.input_size())}); |
|
|
|
GELOGE(FAILED, "[Check][Size]caffe net input_shape size(%d) is not equal input size(%d).", |
|
|
|
proto_message.input_shape_size(), proto_message.input_size()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
for (int i = 0; i < proto_message.input_size(); i++) { |
|
|
|
int dim_size = proto_message.input_shape(i).dim_size(); |
|
|
|
for (int i = 0; i < proto_message.input_size(); i++) { |
|
|
|
int dim_size = proto_message.input_shape(i).dim_size(); |
|
|
|
|
|
|
|
domi::caffe::LayerParameter *layer = proto_message.add_layer(); |
|
|
|
GE_CHECK_NOTNULL(layer); |
|
|
|
layer->set_name(proto_message.input(i)); |
|
|
|
layer->set_type(ge::parser::INPUT_TYPE); |
|
|
|
layer->add_top(proto_message.input(i)); |
|
|
|
domi::caffe::LayerParameter *layer = proto_message.add_layer(); |
|
|
|
GE_CHECK_NOTNULL(layer); |
|
|
|
layer->set_name(proto_message.input(i)); |
|
|
|
layer->set_type(ge::parser::INPUT_TYPE); |
|
|
|
layer->add_top(proto_message.input(i)); |
|
|
|
|
|
|
|
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); |
|
|
|
GE_CHECK_NOTNULL(input_param); |
|
|
|
domi::caffe::BlobShape *shape = input_param->add_shape(); |
|
|
|
GE_CHECK_NOTNULL(shape); |
|
|
|
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); |
|
|
|
GE_CHECK_NOTNULL(input_param); |
|
|
|
domi::caffe::BlobShape *shape = input_param->add_shape(); |
|
|
|
GE_CHECK_NOTNULL(shape); |
|
|
|
|
|
|
|
for (int j = 0; j < dim_size; j++) { |
|
|
|
// Can guarantee that it will not cross the border |
|
|
|
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j))); |
|
|
|
} |
|
|
|
input_data_flag = true; |
|
|
|
for (int j = 0; j < dim_size; j++) { |
|
|
|
// Can guarantee that it will not cross the border |
|
|
|
shape->add_dim(static_cast<int64_t>(proto_message.input_shape(i).dim(j))); |
|
|
|
} |
|
|
|
} else { |
|
|
|
const ge::ParserContext &ctx = ge::GetParserContext(); |
|
|
|
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; |
|
|
|
for (int i = 0; i < proto_message.input_size(); i++) { |
|
|
|
string name = proto_message.input(i); |
|
|
|
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input |
|
|
|
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name})); |
|
|
|
GELOGE(FAILED, "[Find][Dim]Model has no input shape."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
std::vector<int64_t> dims = input_dims.at(name); |
|
|
|
size_t dim_size = dims.size(); |
|
|
|
|
|
|
|
domi::caffe::LayerParameter *layer = proto_message.add_layer(); |
|
|
|
GE_CHECK_NOTNULL(layer); |
|
|
|
layer->set_name(name); |
|
|
|
layer->set_type(ge::parser::INPUT_TYPE); |
|
|
|
layer->add_top(proto_message.input(i)); |
|
|
|
|
|
|
|
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); |
|
|
|
GE_CHECK_NOTNULL(input_param); |
|
|
|
domi::caffe::BlobShape *shape = input_param->add_shape(); |
|
|
|
GE_CHECK_NOTNULL(shape); |
|
|
|
|
|
|
|
for (size_t j = 0; j < dim_size; j++) { |
|
|
|
shape->add_dim(dims.at(j)); |
|
|
|
} |
|
|
|
input_data_flag = true; |
|
|
|
input_data_flag = true; |
|
|
|
} |
|
|
|
} else { |
|
|
|
const ge::ParserContext &ctx = ge::GetParserContext(); |
|
|
|
std::map<std::string, std::vector<int64_t>> input_dims = ctx.input_dims; |
|
|
|
for (int i = 0; i < proto_message.input_size(); i++) { |
|
|
|
string name = proto_message.input(i); |
|
|
|
if (input_dims.count(name) == 0) { // Input defined by model does not exist in input of external input |
|
|
|
REPORT_INPUT_ERROR("E11005", std::vector<std::string>({"input"}), std::vector<std::string>({name})); |
|
|
|
GELOGE(FAILED, "[Find][Dim]Model has no input shape."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
std::vector<int64_t> dims = input_dims.at(name); |
|
|
|
size_t dim_size = dims.size(); |
|
|
|
|
|
|
|
domi::caffe::LayerParameter *layer = proto_message.add_layer(); |
|
|
|
GE_CHECK_NOTNULL(layer); |
|
|
|
layer->set_name(name); |
|
|
|
layer->set_type(ge::parser::INPUT_TYPE); |
|
|
|
layer->add_top(proto_message.input(i)); |
|
|
|
|
|
|
|
domi::caffe::InputParameter *input_param = layer->mutable_input_param(); |
|
|
|
GE_CHECK_NOTNULL(input_param); |
|
|
|
domi::caffe::BlobShape *shape = input_param->add_shape(); |
|
|
|
GE_CHECK_NOTNULL(shape); |
|
|
|
|
|
|
|
for (size_t j = 0; j < dim_size; j++) { |
|
|
|
shape->add_dim(dims.at(j)); |
|
|
|
} |
|
|
|
input_data_flag = true; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
@@ -423,7 +417,7 @@ Status CaffeModelParser::ParseNetModelByCustomProto(const char *model_path, cons |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (ParseLayerParameter(layer_descriptor, message, operators) != SUCCESS) { |
|
|
|
if (ParseLayerParameter(*layer_descriptor, *message, operators) != SUCCESS) { |
|
|
|
delete message; |
|
|
|
GELOGE(FAILED, "[Parse][LayerParameter] failed, model path:%s.", model_path); |
|
|
|
return FAILED; |
|
|
|
@@ -536,18 +530,18 @@ Status CaffeModelParser::ReadCaffeModelFromText(const char *model_path, google:: |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, |
|
|
|
const google::protobuf::Message *message, |
|
|
|
Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, |
|
|
|
const google::protobuf::Message &message, |
|
|
|
vector<ge::Operator> &operators) const { |
|
|
|
auto field_name = layer_descriptor->FindFieldByName(kFieldName); |
|
|
|
auto field_name = layer_descriptor.FindFieldByName(kFieldName); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); |
|
|
|
auto field_type = layer_descriptor->FindFieldByName(kFieldType); |
|
|
|
auto field_type = layer_descriptor.FindFieldByName(kFieldType); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); |
|
|
|
|
|
|
|
const google::protobuf::Reflection *reflection = message->GetReflection(); |
|
|
|
const google::protobuf::Reflection *reflection = message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
vector<const google::protobuf::FieldDescriptor *> field_desc; |
|
|
|
reflection->ListFields(*message, &field_desc); |
|
|
|
reflection->ListFields(message, &field_desc); |
|
|
|
for (auto &field : field_desc) { |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field, "Get FieldDescriptor failed in google::protobuf::Message"); |
|
|
|
// Only care about layers |
|
|
|
@@ -561,10 +555,10 @@ Status CaffeModelParser::ParseLayerParameter(const google::protobuf::Descriptor |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
int field_size = reflection->FieldSize(*message, field); |
|
|
|
int field_size = reflection->FieldSize(message, field); |
|
|
|
GELOGI("Total Layer num of model file is %d", field_size); |
|
|
|
for (int i = 0; i < field_size; ++i) { |
|
|
|
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); |
|
|
|
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); |
|
|
|
const google::protobuf::Reflection *layer_reflection = layer_message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
GE_CHECK_NOTNULL(layer_reflection); |
|
|
|
@@ -1316,7 +1310,8 @@ Status CaffeModelParser::ParseFromMemory(const char *data, uint32_t size, ge::Co |
|
|
|
layer_name_map[layer.name()]++; |
|
|
|
// Set the name in proto and layer |
|
|
|
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); |
|
|
|
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) |
|
|
|
duplicate_name_layer->set_name(new_name); |
|
|
|
layer.set_name(new_name);) |
|
|
|
|
|
|
|
// Insert the new operator name, the number of times of duplicate name is recorded as 1 |
|
|
|
layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); |
|
|
|
@@ -1539,7 +1534,8 @@ Status CaffeModelParser::Parse(const char *model_path, ge::ComputeGraphPtr &grap |
|
|
|
layer_name_map[layer.name()]++; |
|
|
|
// Set the name in proto and layer |
|
|
|
domi::caffe::LayerParameter *duplicate_name_layer = proto_message.mutable_layer(layer_index); |
|
|
|
duplicate_name_layer->set_name(new_name); layer.set_name(new_name);) |
|
|
|
duplicate_name_layer->set_name(new_name); |
|
|
|
layer.set_name(new_name);) |
|
|
|
|
|
|
|
// Insert the new operator name, the number of times of duplicate name is recorded as 1 |
|
|
|
layer_name_map.insert(std::make_pair(layer.name(), kNumOne)); |
|
|
|
@@ -1832,13 +1828,13 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (CheckLayersSize(message) != SUCCESS) { |
|
|
|
if (CheckLayersSize(*message) != SUCCESS) { |
|
|
|
delete message; |
|
|
|
message = nullptr; |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (ParseLayerParameter(layer_descriptor, message, graph) != SUCCESS) { |
|
|
|
if (ParseLayerParameter(*layer_descriptor, *message, graph) != SUCCESS) { |
|
|
|
delete message; |
|
|
|
message = nullptr; |
|
|
|
REPORT_CALL_ERROR("E19999", "ParseLayerParameter failed failed from weight file:%s.", weight_path); |
|
|
|
@@ -1852,18 +1848,18 @@ Status CaffeWeightsParser::ParseWeightByFusionProto(const char *weight_path, con |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor, |
|
|
|
const google::protobuf::Message *message, |
|
|
|
Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descriptor &layer_descriptor, |
|
|
|
const google::protobuf::Message &message, |
|
|
|
ge::ComputeGraphPtr &graph) { |
|
|
|
auto field_name = layer_descriptor->FindFieldByName(kFieldName); |
|
|
|
auto field_name = layer_descriptor.FindFieldByName(kFieldName); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_name, "Does not find name in google::protobuf::Descriptor"); |
|
|
|
auto field_type = layer_descriptor->FindFieldByName(kFieldType); |
|
|
|
auto field_type = layer_descriptor.FindFieldByName(kFieldType); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(field_type, "Does not find type in google::protobuf::Descriptor"); |
|
|
|
|
|
|
|
const google::protobuf::Reflection *reflection = message->GetReflection(); |
|
|
|
const google::protobuf::Reflection *reflection = message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
vector<const google::protobuf::FieldDescriptor *> field_desc; |
|
|
|
reflection->ListFields(*message, &field_desc); |
|
|
|
reflection->ListFields(message, &field_desc); |
|
|
|
|
|
|
|
NetParameter tmp_net; |
|
|
|
for (auto &field : field_desc) { |
|
|
|
@@ -1880,13 +1876,13 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
int field_size = reflection->FieldSize(*message, field); |
|
|
|
int field_size = reflection->FieldSize(message, field); |
|
|
|
GELOGI("Total Layer num of model file is %d", field_size); |
|
|
|
for (int i = 0; i < field_size; ++i) { |
|
|
|
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(*message, field, i); |
|
|
|
const google::protobuf::Message &layer_message = reflection->GetRepeatedMessage(message, field, i); |
|
|
|
|
|
|
|
LayerParameter *layer = tmp_net.add_layer(); |
|
|
|
if (ConvertLayerProto(&layer_message, layer) != SUCCESS) { |
|
|
|
if (ConvertLayerProto(layer_message, layer) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Invoke][ConvertLayerProto] Convert message to layer proto failed."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -1907,16 +1903,16 @@ Status CaffeWeightsParser::ParseLayerParameter(const google::protobuf::Descripto |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *message, |
|
|
|
Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message &message, |
|
|
|
google::protobuf::Message *layer) { |
|
|
|
const google::protobuf::Reflection *layer_reflection = message->GetReflection(); |
|
|
|
const google::protobuf::Reflection *layer_reflection = message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(layer_reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
vector<const google::protobuf::FieldDescriptor *> field_desc; |
|
|
|
layer_reflection->ListFields(*message, &field_desc); |
|
|
|
layer_reflection->ListFields(message, &field_desc); |
|
|
|
|
|
|
|
for (auto &field : field_desc) { |
|
|
|
GE_CHECK_NOTNULL(field); |
|
|
|
if (ParseLayerField(layer_reflection, message, field, layer) != SUCCESS) { |
|
|
|
if (ParseLayerField(*layer_reflection, message, *field, layer) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Invoke][ParseLayerField] Parse field %s failed.", field->name().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -1924,114 +1920,114 @@ Status CaffeWeightsParser::ConvertLayerProto(const google::protobuf::Message *me |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *reflection, |
|
|
|
const google::protobuf::Message *message, |
|
|
|
const google::protobuf::FieldDescriptor *field, |
|
|
|
Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection &reflection, |
|
|
|
const google::protobuf::Message &message, |
|
|
|
const google::protobuf::FieldDescriptor &field, |
|
|
|
google::protobuf::Message *layer) const { |
|
|
|
GELOGD("Start to parse field: %s.", field->name().c_str()); |
|
|
|
GELOGD("Start to parse field: %s.", field.name().c_str()); |
|
|
|
domi::caffe::LayerParameter *layer_proto = PtrToPtr<google::protobuf::Message, domi::caffe::LayerParameter>(layer); |
|
|
|
string filed_name = field->name(); |
|
|
|
#define CASE_FIELD_NAME(kName, method) \ |
|
|
|
string filed_name = field.name(); |
|
|
|
#define CASE_FIELD_NAME(kName, method, inner_message, field_ptr) \ |
|
|
|
if (filed_name == kField##kName) { \ |
|
|
|
string value = reflection->GetString(*message, field); \ |
|
|
|
string value = reflection.GetString(inner_message, field_ptr); \ |
|
|
|
GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ |
|
|
|
layer_proto->set_##method(value); \ |
|
|
|
return SUCCESS; \ |
|
|
|
} |
|
|
|
CASE_FIELD_NAME(Name, name); |
|
|
|
CASE_FIELD_NAME(Type, type); |
|
|
|
CASE_FIELD_NAME(Name, name, message, &field); |
|
|
|
CASE_FIELD_NAME(Type, type, message, &field); |
|
|
|
#undef CASE_FIELD_NAME |
|
|
|
#define CASE_FIELD_NAME_REPEATED(kName, method) \ |
|
|
|
if (filed_name == kField##kName) { \ |
|
|
|
int field_size = reflection->FieldSize(*message, field); \ |
|
|
|
for (int i = 0; i < field_size; ++i) { \ |
|
|
|
auto value = reflection->GetRepeatedString(*message, field, i); \ |
|
|
|
layer_proto->add_##method(value); \ |
|
|
|
} \ |
|
|
|
return SUCCESS; \ |
|
|
|
} |
|
|
|
CASE_FIELD_NAME_REPEATED(Bottom, bottom); |
|
|
|
CASE_FIELD_NAME_REPEATED(Top, top); |
|
|
|
#define CASE_FIELD_NAME_REPEATED(kName, method, inner_message, field_ptr) \ |
|
|
|
if (filed_name == kField##kName) { \ |
|
|
|
int field_size = reflection.FieldSize(inner_message, field_ptr); \ |
|
|
|
for (int i = 0; i < field_size; ++i) { \ |
|
|
|
auto value = reflection.GetRepeatedString(inner_message, field_ptr, i); \ |
|
|
|
layer_proto->add_##method(value); \ |
|
|
|
} \ |
|
|
|
return SUCCESS; \ |
|
|
|
} |
|
|
|
CASE_FIELD_NAME_REPEATED(Bottom, bottom, message, &field); |
|
|
|
CASE_FIELD_NAME_REPEATED(Top, top, message, &field); |
|
|
|
#undef CASE_FIELD_NAME_REPEATED |
|
|
|
if (filed_name == kFieldBlobs) { |
|
|
|
int field_size = reflection->FieldSize(*message, field); |
|
|
|
int field_size = reflection.FieldSize(message, &field); |
|
|
|
for (int i = 0; i < field_size; ++i) { |
|
|
|
domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); |
|
|
|
const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); |
|
|
|
if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str()); |
|
|
|
const google::protobuf::Message &sub_message = reflection.GetRepeatedMessage(message, &field, i); |
|
|
|
if (ConvertBlobsProto(sub_message, item_message) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field.name().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
if (filed_name == kFieldConvParam) { |
|
|
|
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); |
|
|
|
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); |
|
|
|
ConvolutionParameter *conv_param = layer_proto->mutable_convolution_param(); |
|
|
|
ConvertConvParamProto(&sub_message, conv_param); |
|
|
|
ConvertConvParamProto(sub_message, conv_param); |
|
|
|
} |
|
|
|
if (filed_name == kFieldInnerPro) { |
|
|
|
const google::protobuf::Message &sub_message = reflection->GetMessage(*message, field); |
|
|
|
const google::protobuf::Message &sub_message = reflection.GetMessage(message, &field); |
|
|
|
InnerProductParameter *inner_product = layer_proto->mutable_inner_product_param(); |
|
|
|
ConvertInnerProdcutProto(&sub_message, inner_product); |
|
|
|
ConvertInnerProdcutProto(sub_message, inner_product); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message *message, |
|
|
|
Status CaffeWeightsParser::ConvertBlobsProto(const google::protobuf::Message &message, |
|
|
|
google::protobuf::Message *blobs) const { |
|
|
|
const google::protobuf::Reflection *blobs_reflection = message->GetReflection(); |
|
|
|
const google::protobuf::Reflection *blobs_reflection = message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(blobs_reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
vector<const google::protobuf::FieldDescriptor *> field_desc; |
|
|
|
blobs_reflection->ListFields(*message, &field_desc); |
|
|
|
blobs_reflection->ListFields(message, &field_desc); |
|
|
|
domi::caffe::BlobProto *blobs_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobProto>(blobs); |
|
|
|
|
|
|
|
for (auto &field : field_desc) { |
|
|
|
GE_CHECK_NOTNULL(field); |
|
|
|
string feild_name = field->name(); |
|
|
|
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name) \ |
|
|
|
if (feild_name == #kName) { \ |
|
|
|
int field_size = blobs_reflection->FieldSize(*message, field); \ |
|
|
|
for (int i = 0; i < field_size; ++i) { \ |
|
|
|
valuetype value = blobs_reflection->GetRepeated##method(*message, field, i); \ |
|
|
|
blobs_proto->add_##name(value); \ |
|
|
|
} \ |
|
|
|
continue; \ |
|
|
|
} |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data); |
|
|
|
#define CASE_BLOBS_FIELD_NAME_REPEATED(kName, method, valuetype, name, inner_message, inner_field) \ |
|
|
|
if (feild_name == #kName) { \ |
|
|
|
int field_size = blobs_reflection->FieldSize(inner_message, inner_field); \ |
|
|
|
for (int i = 0; i < field_size; ++i) { \ |
|
|
|
valuetype value = blobs_reflection->GetRepeated##method(inner_message, inner_field, i); \ |
|
|
|
blobs_proto->add_##name(value); \ |
|
|
|
} \ |
|
|
|
continue; \ |
|
|
|
} |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(data, Float, float, data, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(diff, Float, float, diff, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(double_data, Double, double, double_data, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(double_diff, Double, double, double_diff, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(int32_data, Int32, int32_t, int32_data, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME_REPEATED(uint64_data, UInt64, uint64_t, uint64_data, message, field); |
|
|
|
#undef CASE_BLOBS_FIELD_NAME_REPEATED |
|
|
|
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name) \ |
|
|
|
if (feild_name == #kName) { \ |
|
|
|
valuetype value = blobs_reflection->Get##method(*message, field); \ |
|
|
|
blobs_proto->set_##name(value); \ |
|
|
|
continue; \ |
|
|
|
} |
|
|
|
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data); |
|
|
|
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num); |
|
|
|
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels); |
|
|
|
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height); |
|
|
|
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width); |
|
|
|
#define CASE_BLOBS_FIELD_NAME(kName, method, valuetype, name, inner_message, inner_field) \ |
|
|
|
if (feild_name == #kName) { \ |
|
|
|
valuetype value = blobs_reflection->Get##method(inner_message, inner_field); \ |
|
|
|
blobs_proto->set_##name(value); \ |
|
|
|
continue; \ |
|
|
|
} |
|
|
|
CASE_BLOBS_FIELD_NAME(int8_data, String, string, int8_data, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME(num, Int32, int32_t, num, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME(channels, Int32, int32_t, channels, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME(height, Int32, int32_t, height, message, field); |
|
|
|
CASE_BLOBS_FIELD_NAME(width, Int32, int32_t, width, message, field); |
|
|
|
#undef CASE_BLOBS_FIELD_NAME |
|
|
|
if (feild_name == kFieldShape) { |
|
|
|
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(*message, field); |
|
|
|
const google::protobuf::Message &sub_message = blobs_reflection->GetMessage(message, field); |
|
|
|
domi::caffe::BlobShape *blob_shape = blobs_proto->mutable_shape(); |
|
|
|
ConvertBlobShapeProto(&sub_message, blob_shape); |
|
|
|
ConvertBlobShapeProto(sub_message, blob_shape); |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message *message, |
|
|
|
Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message &message, |
|
|
|
google::protobuf::Message *dest_message) const { |
|
|
|
const google::protobuf::Reflection *reflection = message->GetReflection(); |
|
|
|
const google::protobuf::Reflection *reflection = message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
vector<const google::protobuf::FieldDescriptor *> field_desc; |
|
|
|
reflection->ListFields(*message, &field_desc); |
|
|
|
reflection->ListFields(message, &field_desc); |
|
|
|
|
|
|
|
domi::caffe::BlobShape *shape_proto = PtrToPtr<google::protobuf::Message, domi::caffe::BlobShape>(dest_message); |
|
|
|
|
|
|
|
@@ -2039,21 +2035,21 @@ Status CaffeWeightsParser::ConvertBlobShapeProto(const google::protobuf::Message |
|
|
|
if (field->name() != kFieldDim) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
int field_size = reflection->FieldSize(*message, field); |
|
|
|
int field_size = reflection->FieldSize(message, field); |
|
|
|
for (int i = 0; i < field_size; ++i) { |
|
|
|
int64_t value = reflection->GetRepeatedInt64(*message, field, i); |
|
|
|
int64_t value = reflection->GetRepeatedInt64(message, field, i); |
|
|
|
shape_proto->add_dim(value); |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message *message, |
|
|
|
Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message &message, |
|
|
|
google::protobuf::Message *dest_message) const { |
|
|
|
const google::protobuf::Reflection *reflection = message->GetReflection(); |
|
|
|
const google::protobuf::Reflection *reflection = message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
vector<const google::protobuf::FieldDescriptor *> field_desc; |
|
|
|
reflection->ListFields(*message, &field_desc); |
|
|
|
reflection->ListFields(message, &field_desc); |
|
|
|
|
|
|
|
domi::caffe::ConvolutionParameter *conv_param_proto = |
|
|
|
PtrToPtr<google::protobuf::Message, domi::caffe::ConvolutionParameter>(dest_message); |
|
|
|
@@ -2062,18 +2058,18 @@ Status CaffeWeightsParser::ConvertConvParamProto(const google::protobuf::Message |
|
|
|
if (field->name() != kFieldBiasTerm) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
bool value = reflection->GetBool(*message, field); |
|
|
|
bool value = reflection->GetBool(message, field); |
|
|
|
conv_param_proto->set_bias_term(value); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message *message, |
|
|
|
Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Message &message, |
|
|
|
google::protobuf::Message *dest_message) const { |
|
|
|
const google::protobuf::Reflection *reflection = message->GetReflection(); |
|
|
|
const google::protobuf::Reflection *reflection = message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
vector<const google::protobuf::FieldDescriptor *> field_desc; |
|
|
|
reflection->ListFields(*message, &field_desc); |
|
|
|
reflection->ListFields(message, &field_desc); |
|
|
|
|
|
|
|
domi::caffe::InnerProductParameter *inner_product_proto = |
|
|
|
PtrToPtr<google::protobuf::Message, domi::caffe::InnerProductParameter>(dest_message); |
|
|
|
@@ -2082,17 +2078,17 @@ Status CaffeWeightsParser::ConvertInnerProdcutProto(const google::protobuf::Mess |
|
|
|
if (field->name() != kFieldBiasTerm) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
bool value = reflection->GetBool(*message, field); |
|
|
|
bool value = reflection->GetBool(message, field); |
|
|
|
inner_product_proto->set_bias_term(value); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *message) const { |
|
|
|
const google::protobuf::Reflection *reflection = message->GetReflection(); |
|
|
|
Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message &message) const { |
|
|
|
const google::protobuf::Reflection *reflection = message.GetReflection(); |
|
|
|
CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(reflection, "Get Reflection failed in google::protobuf::Message"); |
|
|
|
vector<const google::protobuf::FieldDescriptor *> field_desc; |
|
|
|
reflection->ListFields(*message, &field_desc); |
|
|
|
reflection->ListFields(message, &field_desc); |
|
|
|
|
|
|
|
int num_layer = 0; |
|
|
|
int num_layers = 0; |
|
|
|
@@ -2110,7 +2106,7 @@ Status CaffeWeightsParser::CheckLayersSize(const google::protobuf::Message *mess |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
int field_size = reflection->FieldSize(*message, field); |
|
|
|
int field_size = reflection->FieldSize(message, field); |
|
|
|
if (field->name() == kLayerName) { |
|
|
|
num_layer = field_size; |
|
|
|
} else { |
|
|
|
|