/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef TENSOR_ASSIGN_H_ #define TENSOR_ASSIGN_H_ #include "graph/ge_tensor.h" #include "proto/tensorflow/tensor.pb.h" namespace domi { using GeTensorPtr = std::shared_ptr; using Status = uint32_t; using domi::tensorflow::TensorProto; using google::protobuf::int32; using google::protobuf::int64; class TensorAssign { public: static Status SetGeTensor(const TensorProto &tensor, GeTensorPtr &weight); static Status SetGeTensorDataType(int64_t dataType, GeTensorPtr &weight); static ge::DataType ConvertTensorflowDataType(uint32_t tf_data_type); private: static bool CheckBoolVal(tensorflow::DataType data_type); static bool CheckHalfVal(tensorflow::DataType data_type); static bool CheckFloatVal(tensorflow::DataType data_type); static bool CheckDoubleVal(tensorflow::DataType data_type); static bool CheckComplex64Val(tensorflow::DataType data_type); static bool CheckComplex128Val(tensorflow::DataType data_type); static bool CheckStringVal(tensorflow::DataType data_type); static bool CheckByte(tensorflow::DataType data_type); static bool CheckDoubleByte(tensorflow::DataType data_type); static bool CheckSignedFourByte(tensorflow::DataType data_type); static bool CheckUnsignedFourByte(tensorflow::DataType data_type); static bool CheckSignedEightByte(tensorflow::DataType data_type); static bool CheckUnsignedEightByte(tensorflow::DataType data_type); static Status GetDoubleByteVal(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, GeTensorPtr &weight); static Status GetByteVal(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, GeTensorPtr &weight); static Status GetStringVal(int32_t val_size, const google::protobuf::RepeatedPtrField &val_vector, int count, GeTensorPtr &weight); static void SetGeTensorWeightData(const TensorProto &tensor, int32_t val_size, int count, GeTensorPtr &weight); static void SetWeightData(tensorflow::DataType data_type, int count, const std::string &tensor_content, GeTensorPtr &weight); template static Status GetVal(int32_t val_size, const google::protobuf::RepeatedField &val_vector, int count, GeTensorPtr &weight) { bool zerosLike = (count != val_size && val_size == 1); T *addr = new (std::nothrow) T[count](); GE_CHECK_NOTNULL(addr); int minCount = (count > val_size) ? val_size : count; if (!zerosLike) { for (int32_t i = 0; i < minCount; i++) { *(addr + i) = val_vector.Get(i); } for (int32_t i = minCount; i < count; i++) { *(addr + i) = val_vector.Get(minCount - 1); } } else { for (int32_t i = 0; i < count; i++) { *(addr + i) = val_vector.Get(0); } } (void)weight->SetData(reinterpret_cast(addr), count * sizeof(T)); GE_DELETE_NEW_ARRAY(addr); return SUCCESS; } }; } // namespace domi #endif // TENSOR_ASSIGN_H_