|
- diff -Npur tvm/src/pass/make_api.cc tvm_new/src/pass/make_api.cc
- --- tvm/src/pass/make_api.cc 2019-12-14 15:11:37.626419432 +0800
- +++ tvm_new/src/pass/make_api.cc 2019-12-14 14:58:46.562493287 +0800
- @@ -20,6 +20,11 @@
- /*!
- * \file make_api.cc Build API function.
- */
- +
- +/*
- + * 2019.12.30 - Define a new function that pushes the data var of each buffer node in api_args onto args_real.
- + */
- +
- #include <tvm/ir_pass.h>
- #include <tvm/ir.h>
- #include <tvm/ir_visitor.h>
- @@ -40,6 +45,17 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr
- return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
- }
-
- +// Appends the data Var of every BufferNode entry in api_args to args_real,
- +// in argument order, and returns the extended array.
- +Array<Var> Param(Array<NodeRef> api_args, Array<Var> args_real) {
- +  for (const NodeRef& arg : api_args) {
- +    const BufferNode* buffer = arg.as<BufferNode>();
- +    if (buffer == nullptr) continue;  // entry is not a BufferNode: skip it
- +    args_real.push_back(buffer->data);
- +  }
- +  return args_real;
- +}
- +
- LoweredFunc MakeAPI(Stmt body,
- std::string name,
- Array<NodeRef> api_args,
- @@ -47,6 +63,8 @@ LoweredFunc MakeAPI(Stmt body,
- bool is_restricted) {
- const Stmt nop = Evaluate::make(0);
- int num_args = static_cast<int>(api_args.size());
- + Array<Var> args_real;
- + args_real = Param (api_args, args_real);
- CHECK_LE(num_unpacked_args, num_args);
- int num_packed_args = num_args - num_unpacked_args;
- // Data field definitions
- @@ -170,6 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
- NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
- n->name = name;
- n->args = args;
- + n->args_real = args_real;
- n->handle_data_type = binder.def_handle_dtype();
- n->is_packed_func = num_unpacked_args == 0;
- n->is_restricted = is_restricted;
- diff -Npur tvm/src/pass/split_host_device.cc tvm_new/src/pass/split_host_device.cc
- --- tvm/src/pass/split_host_device.cc 2019-12-14 15:11:37.626419432 +0800
- +++ tvm_new/src/pass/split_host_device.cc 2019-12-14 11:28:49.293979656 +0800
- @@ -21,6 +21,11 @@
- * \file split_host_device.cc
- * \brief Split device function from host.
- */
- +
- +/*
- + * 2019.12.30 - Add a new implementation for the host/device splitter.
- + */
- +
- #include <tvm/ir.h>
- #include <tvm/lowered_func.h>
- #include <tvm/channel.h>
- @@ -38,6 +43,7 @@ class IRUseDefAnalysis : public IRMutato
- Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
- if (op->attr_key == attr::thread_extent) {
- IterVar iv = Downcast<IterVar>(op->node);
- + iv = IterVarNode::make(Range(0, op->value), iv->var, iv->iter_type, iv->thread_tag);
- CHECK_NE(iv->thread_tag.length(), 0U);
- // thread_extent can appear multiple times
- // use the first appearance as def.
- @@ -186,6 +192,7 @@ class HostDeviceSplitter : public IRMuta
- name_ = f->name;
- NodePtr<LoweredFuncNode> n =
- make_node<LoweredFuncNode>(*f.operator->());
- + args_real = n->args_real;
- n->body = this->Mutate(f->body);
- n->func_type = kHostFunc;
- Array<LoweredFunc> ret{LoweredFunc(n)};
- @@ -196,6 +203,7 @@ class HostDeviceSplitter : public IRMuta
- }
-
- private:
- + Array<Var> args_real;
- Stmt SplitDeviceFunc(Stmt body) {
- std::ostringstream os;
- os << name_ << "_kernel" << device_funcs_.size();
- @@ -223,6 +231,30 @@ class HostDeviceSplitter : public IRMuta
- n->args.push_back(v);
- }
- }
- +std::shared_ptr<LoweredFuncNode> na = std::make_shared<LoweredFuncNode>();
- + for (unsigned i = 0; i < (unsigned)args_real.size(); i++) {
- + bool match = false;
- + for (unsigned j = 0; j < (unsigned)n->args.size(); j++) {
- + if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) {
- + na->args.push_back(n->args[j]);
- + match = true;
- + break;
- + } else {
- + continue;
- + }
- + }
- +
- + if (!match) {
- + na->args.push_back(args_real[i]);
- + // mark handle data type.
- + for (auto kv : handle_data_type_) {
- + if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) {
- + n->handle_data_type.Set(args_real[i], kv.second);
- + }
- + }
- + }
- + }
- + n->args = na->args;
- LoweredFunc f_device(n);
- Array<Expr> call_args;
- call_args.push_back(StringImm::make(f_device->name));
|