|
|
|
@@ -103,6 +103,9 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec |
|
|
|
// simulating benchmark: session_->CompileGraph() -> scheduler.Schedule() -> BuildKernels() |
|
|
|
MS_LOG(DEBUG) << "create OpenCLKernel"; |
|
|
|
kernel::KernelKey key{kernel::kGPU, kernel_inputs.front()->data_type(), primitive_type}; |
|
|
|
if (key.data_type == kNumberTypeFloat32 && fp16_enable) { |
|
|
|
key.data_type = kNumberTypeFloat16; |
|
|
|
} |
|
|
|
auto creator = KernelRegistry::GetInstance()->GetCreator(key); |
|
|
|
if (creator == nullptr) { |
|
|
|
std::cerr << "can't get registry function for: " << schema::EnumNamePrimitiveType(primitive_type) |
|
|
|
@@ -117,7 +120,7 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, const std::vec |
|
|
|
FAIL(); |
|
|
|
} |
|
|
|
kernel->set_name(schema::EnumNamesPrimitiveType()[primitive_type]); |
|
|
|
|
|
|
|
kernel->set_desc(key); |
|
|
|
// simulating benchmark: session_->CompileGraph() -> scheduler.Schedule() -> ConstructSubGraphs() |
|
|
|
MS_LOG(DEBUG) << "create SubGraph"; |
|
|
|
std::vector<LiteKernel *> kernels{kernel}; |
|
|
|
@@ -246,6 +249,9 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std |
|
|
|
// simulating benchmark: session_->CompileGraph() -> scheduler.Schedule() -> BuildKernels() |
|
|
|
MS_LOG(DEBUG) << "create OpenCLKernel"; |
|
|
|
kernel::KernelKey key{kernel::kGPU, kernel_inputs.front()->data_type(), primitive_type}; |
|
|
|
if (key.data_type == kNumberTypeFloat32 && fp16_enable) { |
|
|
|
key.data_type = kNumberTypeFloat16; |
|
|
|
} |
|
|
|
auto creator = KernelRegistry::GetInstance()->GetCreator(key); |
|
|
|
if (creator == nullptr) { |
|
|
|
std::cerr << "can't get registry function for: " << schema::EnumNamePrimitiveType(primitive_type) |
|
|
|
@@ -260,7 +266,7 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std |
|
|
|
FAIL(); |
|
|
|
} |
|
|
|
kernel->set_name(schema::EnumNamesPrimitiveType()[primitive_type]); |
|
|
|
|
|
|
|
kernel->set_desc(key); |
|
|
|
// simulating benchmark: session_->CompileGraph() -> scheduler.Schedule() -> ConstructSubGraphs() |
|
|
|
MS_LOG(DEBUG) << "create SubGraph"; |
|
|
|
std::vector<LiteKernel *> kernels{kernel}; |
|
|
|
|