diff -Npur tvm/src/pass/make_api.cc tvm_new/src/pass/make_api.cc --- tvm/src/pass/make_api.cc 2019-12-14 15:11:37.626419432 +0800 +++ tvm_new/src/pass/make_api.cc 2019-12-14 14:58:46.562493287 +0800 @@ -20,6 +20,11 @@ /*! * \file make_api.cc Build API function. */ + +/* + * 2019.12.30 - Define a new function, Param, that appends the data field of each buffer node found in api_args to args_real; MakeAPI stores the result in the lowered function's args_real. + */ + #include #include #include @@ -40,6 +45,17 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0)); } +Array Param ( Array api_args,Array args_real) { + int num_args = static_cast(api_args.size()); + for (int i = 0; i < num_args; i++) { + const BufferNode *v = api_args[i].as(); + if(v) { + args_real.push_back(v->data); + } + } + return args_real; +} + LoweredFunc MakeAPI(Stmt body, std::string name, Array api_args, @@ -47,6 +63,8 @@ LoweredFunc MakeAPI(Stmt body, bool is_restricted) { const Stmt nop = Evaluate::make(0); int num_args = static_cast(api_args.size()); + Array args_real; + args_real = Param (api_args, args_real); CHECK_LE(num_unpacked_args, num_args); int num_packed_args = num_args - num_unpacked_args; // Data field definitions @@ -170,6 +188,7 @@ LoweredFunc MakeAPI(Stmt body, NodePtr n = make_node(); n->name = name; n->args = args; + n->args_real = args_real; n->handle_data_type = binder.def_handle_dtype(); n->is_packed_func = num_unpacked_args == 0; n->is_restricted = is_restricted; diff -Npur tvm/src/pass/split_host_device.cc tvm_new/src/pass/split_host_device.cc --- tvm/src/pass/split_host_device.cc 2019-12-14 15:11:37.626419432 +0800 +++ tvm_new/src/pass/split_host_device.cc 2019-12-14 11:28:49.293979656 +0800 @@ -21,6 +21,11 @@ * \file split_host_device.cc * \brief Split device function from host. */ + +/* + * 2019.12.30 - Add a new implementation for the host/device splitter: the device function's argument list is rebuilt in args_real order, matching entries by name_hint. NOTE(review): existing arguments whose name_hint matches no entry of args_real appear to be dropped by the final n->args assignment - confirm this is intended.
+ */ + #include #include #include @@ -38,6 +43,7 @@ class IRUseDefAnalysis : public IRMutato Stmt Mutate_(const AttrStmt *op, const Stmt& s) final { if (op->attr_key == attr::thread_extent) { IterVar iv = Downcast(op->node); + iv = IterVarNode::make(Range(0, op->value), iv->var, iv->iter_type, iv->thread_tag); CHECK_NE(iv->thread_tag.length(), 0U); // thread_extent can appear multiple times // use the first appearance as def. @@ -186,6 +192,7 @@ class HostDeviceSplitter : public IRMuta name_ = f->name; NodePtr n = make_node(*f.operator->()); + args_real = n->args_real; n->body = this->Mutate(f->body); n->func_type = kHostFunc; Array ret{LoweredFunc(n)}; @@ -196,6 +203,7 @@ class HostDeviceSplitter : public IRMuta } private: + Array args_real; Stmt SplitDeviceFunc(Stmt body) { std::ostringstream os; os << name_ << "_kernel" << device_funcs_.size(); @@ -223,6 +231,30 @@ class HostDeviceSplitter : public IRMuta n->args.push_back(v); } } +std::shared_ptr na = std::make_shared(); + for (unsigned i = 0; i < (unsigned)args_real.size(); i++) { + bool match = false; + for (unsigned j = 0; j < (unsigned)n->args.size(); j++) { + if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) { + na->args.push_back(n->args[j]); + match = true; + break; + } else { + continue; + } + } + + if (!match) { + na->args.push_back(args_real[i]); + // Mark the handle data type of the newly added argument; entries of handle_data_type_ are matched by name_hint. + for (auto kv : handle_data_type_) { + if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) { + n->handle_data_type.Set(args_real[i], kv.second); + } + } + } + } + n->args = na->args; LoweredFunc f_device(n); Array call_args; call_args.push_back(StringImm::make(f_device->name));