You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

src_pass.patch 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. diff -Npur tvm/src/pass/make_api.cc tvm_new/src/pass/make_api.cc
  2. --- tvm/src/pass/make_api.cc 2019-12-14 15:11:37.626419432 +0800
  3. +++ tvm_new/src/pass/make_api.cc 2019-12-14 14:58:46.562493287 +0800
  4. @@ -20,6 +20,11 @@
  5. /*!
  6. * \file make_api.cc Build API function.
  7. */
  8. +
  9. +/*
  10. + * 2019.12.30 - Define a new function that pushes the data Var of each buffer node in api_args to args_real.
  11. + */
  12. +
  13. #include <tvm/ir_pass.h>
  14. #include <tvm/ir.h>
  15. #include <tvm/ir_visitor.h>
  16. @@ -40,6 +45,17 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr
  17. return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
  18. }
  19. +Array<Var> Param ( Array<NodeRef> api_args,Array<Var> args_real) {
  20. + int num_args = static_cast<int>(api_args.size());
  21. + for (int i = 0; i < num_args; i++) {
  22. + const BufferNode *v = api_args[i].as<BufferNode>();
  23. + if(v) {
  24. + args_real.push_back(v->data);
  25. + }
  26. + }
  27. + return args_real;
  28. +}
  29. +
  30. LoweredFunc MakeAPI(Stmt body,
  31. std::string name,
  32. Array<NodeRef> api_args,
  33. @@ -47,6 +63,8 @@ LoweredFunc MakeAPI(Stmt body,
  34. bool is_restricted) {
  35. const Stmt nop = Evaluate::make(0);
  36. int num_args = static_cast<int>(api_args.size());
  37. + Array<Var> args_real;
  38. + args_real = Param (api_args, args_real);
  39. CHECK_LE(num_unpacked_args, num_args);
  40. int num_packed_args = num_args - num_unpacked_args;
  41. // Data field definitions
  42. @@ -170,6 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
  43. NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
  44. n->name = name;
  45. n->args = args;
  46. + n->args_real = args_real;
  47. n->handle_data_type = binder.def_handle_dtype();
  48. n->is_packed_func = num_unpacked_args == 0;
  49. n->is_restricted = is_restricted;
  50. diff -Npur tvm/src/pass/split_host_device.cc tvm_new/src/pass/split_host_device.cc
  51. --- tvm/src/pass/split_host_device.cc 2019-12-14 15:11:37.626419432 +0800
  52. +++ tvm_new/src/pass/split_host_device.cc 2019-12-14 11:28:49.293979656 +0800
  53. @@ -21,6 +21,11 @@
  54. * \file split_host_device.cc
  55. * \brief Split device function from host.
  56. */
  57. +
  58. +/*
  59. + * 2019.12.30 - Add new implementations for the host/device splitter.
  60. + */
  61. +
  62. #include <tvm/ir.h>
  63. #include <tvm/lowered_func.h>
  64. #include <tvm/channel.h>
  65. @@ -38,6 +43,7 @@ class IRUseDefAnalysis : public IRMutato
  66. Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
  67. if (op->attr_key == attr::thread_extent) {
  68. IterVar iv = Downcast<IterVar>(op->node);
  69. + iv = IterVarNode::make(Range(0, op->value), iv->var, iv->iter_type, iv->thread_tag);
  70. CHECK_NE(iv->thread_tag.length(), 0U);
  71. // thread_extent can appear multiple times
  72. // use the first appearance as def.
  73. @@ -186,6 +192,7 @@ class HostDeviceSplitter : public IRMuta
  74. name_ = f->name;
  75. NodePtr<LoweredFuncNode> n =
  76. make_node<LoweredFuncNode>(*f.operator->());
  77. + args_real = n->args_real;
  78. n->body = this->Mutate(f->body);
  79. n->func_type = kHostFunc;
  80. Array<LoweredFunc> ret{LoweredFunc(n)};
  81. @@ -196,6 +203,7 @@ class HostDeviceSplitter : public IRMuta
  82. }
  83. private:
  84. + Array<Var> args_real;
  85. Stmt SplitDeviceFunc(Stmt body) {
  86. std::ostringstream os;
  87. os << name_ << "_kernel" << device_funcs_.size();
  88. @@ -223,6 +231,30 @@ class HostDeviceSplitter : public IRMuta
  89. n->args.push_back(v);
  90. }
  91. }
  92. +std::shared_ptr<LoweredFuncNode> na = std::make_shared<LoweredFuncNode>();
  93. + for (unsigned i = 0; i < (unsigned)args_real.size(); i++) {
  94. + bool match = false;
  95. + for (unsigned j = 0; j < (unsigned)n->args.size(); j++) {
  96. + if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) {
  97. + na->args.push_back(n->args[j]);
  98. + match = true;
  99. + break;
  100. + } else {
  101. + continue;
  102. + }
  103. + }
  104. +
  105. + if (!match) {
  106. + na->args.push_back(args_real[i]);
  107. + // mark handle data type.
  108. + for (auto kv : handle_data_type_) {
  109. + if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) {
  110. + n->handle_data_type.Set(args_real[i], kv.second);
  111. + }
  112. + }
  113. + }
  114. + }
  115. + n->args = na->args;
  116. LoweredFunc f_device(n);
  117. Array<Expr> call_args;
  118. call_args.push_back(StringImm::make(f_device->name));