| @@ -32,7 +32,7 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const size_t kConcatV2InputNum = 3; | const size_t kConcatV2InputNum = 3; | ||||
| const int kSupportEmptyTensorRank = 1; | const int kSupportEmptyTensorRank = 1; | ||||
| const std::set<DataType> concatv2_supported_type = {DT_INT32, DT_FLOAT}; | |||||
| const std::set<DataType> concatv2_supported_type = {DT_INT32, DT_FLOAT, DT_INT64}; | |||||
| template <typename T> | template <typename T> | ||||
| void GetOutputData(std::vector<T> &y_data, int64_t loop, size_t &input_size, | void GetOutputData(std::vector<T> &y_data, int64_t loop, size_t &input_size, | ||||
| @@ -88,6 +88,7 @@ Status ConcatV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const vector<ge: | |||||
| std::vector<int32_t> y_data_int32_t; | std::vector<int32_t> y_data_int32_t; | ||||
| std::vector<float> y_data_float; | std::vector<float> y_data_float; | ||||
| std::vector<int64_t> y_data_int64_t; | |||||
| // Index 0 can always gets a GeTensorDesc object from any OpDescPtr. | // Index 0 can always gets a GeTensorDesc object from any OpDescPtr. | ||||
| auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0); | auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0); | ||||
| @@ -106,6 +107,7 @@ Status ConcatV2Kernel::Compute(const ge::OpDescPtr op_desc_ptr, const vector<ge: | |||||
| switch (data_type) { | switch (data_type) { | ||||
| SET_OUTPUT(DT_INT32, int32_t) | SET_OUTPUT(DT_INT32, int32_t) | ||||
| SET_OUTPUT(DT_FLOAT, float) | SET_OUTPUT(DT_FLOAT, float) | ||||
| SET_OUTPUT(DT_INT64, int64_t) | |||||
| default: | default: | ||||
| break; | break; | ||||
| } | } | ||||