From: @zhangxiaokun9
Reviewed-by: @xchu42
Tag: v1.5.1
@@ -14,14 +14,14 @@
  * limitations under the License.
  */
 #include "ge_local_engine/engine/host_cpu_engine.h"
-#include "graph/common/omg_util.h"
 #include "graph/utils/op_desc_utils.h"
 #include "graph/utils/tensor_adapter.h"
+#include "graph/utils/node_utils.h"
+#include "graph/utils/type_utils.h"
 #include "register/op_kernel_registry.h"
 #include "register/host_cpu_context.h"
 #include "common/ge/ge_util.h"
 #include "common/ge/plugin_manager.h"
-#include "graph/utils/type_utils.h"
 #include "common/fp16_t.h"
 #include "common/math/math_util.h"
@@ -123,10 +123,7 @@ bool HostCpuEngine::CheckSupported(const string &op_type) {
 }
 
 Status HostCpuEngine::FindOpKernel(const ge::NodePtr &node, std::unique_ptr<HostCpuOp> &op_kernel) {
-  std::string op_type;
-  auto status = GetOriginalType(node, op_type);
-  GE_CHK_BOOL_EXEC_NOLOG(status == SUCCESS, return status);
+  const std::string op_type = NodeUtils::GetNodeType(node);
   auto kernel = OpKernelRegistry::GetInstance().CreateHostCpuOp(op_type);
   if (kernel == nullptr) {
     GELOGD("Op of type %s is not supported by host cpu engine", op_type.c_str());
@@ -1378,7 +1378,9 @@ Status ModelManager::LoadCustAicpuSo(const OpDescPtr &op_desc, const string &so_
 Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
   GELOGD("Aicpu kernel launch task in, kernel name %s.", kernel_name.c_str());
   std::lock_guard<std::mutex> lock(cust_aicpu_mutex_);
-  if (cust_aicpu_so_.size() == 0) return SUCCESS;
+  if (cust_aicpu_so_.empty()) {
+    return SUCCESS;
+  }
   // get current context
   rtContext_t rt_cur_ctx = nullptr;
   auto rt_error = rtCtxGetCurrent(&rt_cur_ctx);
@@ -1394,9 +1396,19 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
     return SUCCESS;
   }
 
+  rtStream_t stream = nullptr;
   vector<void *> allocated_mem;
+  std::function<void()> callback = [&]() {
+    for (auto mem : allocated_mem) {
+      GE_CHK_RT(rtFree(mem));
+    }
+    if (stream != nullptr) {
+      GE_CHK_RT(rtStreamDestroy(stream));
+    }
+  };
+  GE_MAKE_GUARD(release, callback);
+
   rtError_t status;
-  rtStream_t stream = nullptr;
   vector<CustAicpuSoBuf> v_cust_so;
   void *args = nullptr;
@@ -1471,13 +1483,6 @@ Status ModelManager::LaunchKernelCustAicpuSo(const string &kernel_name) {
     GELOGE(RT_FAILED, "[Call][RtStreamSynchronize] fail, ret = 0x%X", status);
     return RT_ERROR_TO_GE_STATUS(status);
   }
-  std::function<void()> callback = [&]() {
-    for (auto mem : allocated_mem) {
-      GE_CHK_RT(rtFree(mem));
-    }
-    GE_CHK_RT(rtStreamDestroy(stream));
-  };
-  GE_MAKE_GUARD(release, callback);
   GELOGI("Cpu kernel launch task success.");
   return SUCCESS;
 }
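
The two hunks above move the cleanup guard ahead of the first allocation and add a null check before destroying the stream, so `rtFree`/`rtStreamDestroy` now also run when an intermediate step returns early; previously the guard was only installed after `rtStreamSynchronize` succeeded, leaking the buffers and the stream on every earlier error path. A self-contained sketch of the underlying scope-guard pattern in standard C++ (`GE_MAKE_GUARD` is the project's macro; the class below is a hypothetical stand-in):

```cpp
#include <cstdlib>
#include <functional>
#include <iostream>
#include <vector>

// Minimal scope guard: runs its callback when the enclosing scope exits,
// on the success path and on every early return alike.
class ScopeGuard {
 public:
  explicit ScopeGuard(std::function<void()> cb) : cb_(std::move(cb)) {}
  ~ScopeGuard() { if (cb_) cb_(); }
 private:
  std::function<void()> cb_;
};

bool DoWork(bool fail_early) {
  std::vector<void *> allocated;
  // Install the guard BEFORE acquiring resources, as the patch does.
  ScopeGuard release([&]() {
    for (void *p : allocated) {
      std::free(p);
      std::cout << "freed\n";
    }
  });
  allocated.push_back(std::malloc(64));
  if (fail_early) {
    return false;  // the guard still frees the first allocation
  }
  allocated.push_back(std::malloc(128));
  return true;
}

int main() {
  DoWork(true);   // prints "freed" once
  DoWork(false);  // prints "freed" twice
  return 0;
}
```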
@@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() {
     auto cluster = MakeShared<Cluster>(rank++, type, node, this);
     REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed.");
     node_2_cluster_[node] = cluster;
-    if (cluster->IsUnknownShape()) {
-      ordered_cluster_.push_back(cluster);
-    }
     int64_t group_index = -1;
     if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
@@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() {
   return SUCCESS;
 }
 
-Status DynamicShapePartitioner::TopologicalSortClusters() {
+Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) {
   ordered_cluster_.clear();
   // BFS topological sort clusters for known shape cluster
   std::queue<ClusterPtr> ready_clusters;
@@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() {
     auto cluster = ready_clusters.front();
     ready_clusters.pop();
     cluster->UpdateRank(rank++);
-    if (cluster->IsKnownShape() || cluster->IsInputNode()) {
+    if (ordered_filter == nullptr || ordered_filter(cluster)) {
       ordered_cluster_.push_back(cluster);
     }
     for (const auto &out_cluster : cluster->Outputs()) {
@@ -378,7 +375,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
       continue;
     }
-    bool is_unknown_cluster = cluster->IsUnknownShape();
     for (++rit; rit != control_cluster.rend(); ++rit) {
       const auto &cluster_from = *rit;
       if (all_merged_clusters.count(cluster_from) > 0) {
@@ -395,11 +391,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
        }
      }
    }
-
-    if (!is_unknown_cluster && cluster->IsUnknownShape()) {
-      GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str());
-      ordered_cluster_.push_back(cluster);
-    }
   }
 }
@@ -475,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() {
   }
 }
 
 Status DynamicShapePartitioner::MergeClusters() {
+  const auto filter_known = [](const ClusterPtr &cluster) {
+    return cluster->IsKnownShape() || cluster->IsInputNode();
+  };
+  const auto filter_unknown = [](const ClusterPtr &cluster) {
+    return cluster->IsUnknownShape();
+  };
+
   MergeClustersControlFlow();
+  REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown),
+                  "[TopologicalSort][Clusters] after merge control flow clusters failed.");
   MergeClustersUnknownShape();
-  REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
+  REQUIRE_SUCCESS(TopologicalSortClusters(filter_known),
+                  "[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
   MergeClustersKnownShape();
   MergeClustersInputData();
   return SUCCESS;
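
The hardcoded `IsKnownShape() || IsInputNode()` test inside the sort becomes a caller-supplied predicate, so the same BFS topological sort can collect unknown-shape clusters after the control-flow merge and known-shape clusters after the unknown-shape merge. A self-contained sketch of a filtered Kahn-style topological sort over a toy integer DAG (not the partitioner's real types):

```cpp
#include <cstdio>
#include <functional>
#include <map>
#include <queue>
#include <vector>

using Filter = std::function<bool(int)>;

// Kahn-style BFS topological sort; only nodes accepted by `filter`
// (or every node when filter is null) are appended to the result,
// mirroring TopologicalSortClusters(ordered_filter).
std::vector<int> TopoSortFiltered(const std::map<int, std::vector<int>> &adj,
                                  std::map<int, int> in_degree, const Filter &filter) {
  std::queue<int> ready;
  for (const auto &entry : in_degree) {
    if (entry.second == 0) ready.push(entry.first);
  }
  std::vector<int> ordered;
  while (!ready.empty()) {
    int node = ready.front();
    ready.pop();
    if (filter == nullptr || filter(node)) {
      ordered.push_back(node);
    }
    auto it = adj.find(node);
    if (it == adj.end()) continue;
    for (int next : it->second) {
      if (--in_degree[next] == 0) ready.push(next);
    }
  }
  return ordered;
}

int main() {
  std::map<int, std::vector<int>> adj{{1, {2, 3}}, {2, {4}}, {3, {4}}};
  std::map<int, int> in_degree{{1, 0}, {2, 1}, {3, 1}, {4, 2}};
  auto even_only = [](int n) { return n % 2 == 0; };
  for (int n : TopoSortFiltered(adj, in_degree, even_only)) {
    std::printf("%d ", n);  // prints "2 4": the even nodes, in topological order
  }
  std::printf("\n");
  return 0;
}
```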
@@ -111,6 +111,8 @@ class DynamicShapePartitioner {
   Status Partition();
 
+  using OrderedFilter = std::function<bool(const std::shared_ptr<Cluster> &cluster)>;
+
  private:
   Status PartitionImpl();
   // Collect nodes that satisfy the unknowshape rules:
@@ -138,7 +140,7 @@ class DynamicShapePartitioner {
   // Merge clusters step3
   void MergeClustersInputData();
   // Topological sort clusters after merge unknown shape clusters.
-  Status TopologicalSortClusters();
+  Status TopologicalSortClusters(const OrderedFilter &ordered_filter);
   // Deduplicate merged clusters
   void PruneUniqueClusters();
   // Establish the input-output anchors for each partition of the cluster and record links to other clusters
@@ -265,10 +265,6 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
     return SUCCESS;
   }
 
-  if (node->GetType() == MEMCPYASYNC) {  // Convert MemcpyAsync to Identity.
-    node->GetOpDesc()->SetType(IDENTITY);
-  }
-
   std::unique_ptr<NodeItem> new_node;
   GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName());
   GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor));
@@ -15,9 +15,7 @@
  */
 
 #include "hybrid/model/node_item.h"
-#include <sstream>
-#include "framework/common/debug/log.h"
-#include "graph/common/omg_util.h"
 #include "graph/compute_graph.h"
 #include "graph/debug/ge_attr_define.h"
 #include "hybrid/executor/worker/shape_inference_engine.h"
@@ -98,8 +96,7 @@ Status ParseFusedSubgraph(NodeItem &node_item) {
     GE_CHECK_NOTNULL(node);
     auto op_desc = node->GetOpDesc();
     GE_CHECK_NOTNULL(op_desc);
-    std::string node_type;
-    GE_CHK_STATUS_RET(GetOriginalType(node, node_type));
+    const std::string node_type = NodeUtils::GetNodeType(node);
     if (node_type == DATA) {
       GE_CHK_GRAPH_STATUS_RET(ParseInputMapping(*node, *op_desc, *fused_subgraph));
     } else if (node_type == kNodeTypeRetVal) {
@@ -409,8 +406,8 @@ void NodeItem::SetDataSend(NodeItem *node_item, int anchor_index) {
 
 void NodeItem::SetCtrlSend(NodeItem *node_item, uint32_t switch_index) {
   if (switch_index < switch_groups_.size()) {
-    std::vector<const NodeItem *> &switch_group = switch_groups_[switch_index];
-    switch_group.emplace_back(node_item);
+    auto &switch_group = switch_groups_[switch_index];
+    switch_group.emplace(node_item);
   } else {
     ctrl_send_.insert(node_item);
   }
@@ -433,8 +430,8 @@ void NodeItem::SetMergeCtrl(NodeItem *node_item, uint32_t merge_index) {
   }
 
   // this is StreamMerge node, node_item is StreamActive node.
-  std::vector<const NodeItem *> &switch_group = switch_groups_[merge_index];
-  switch_group.emplace_back(node_item);
+  auto &switch_group = switch_groups_[merge_index];
+  switch_group.emplace(node_item);
 
   node_item->ctrl_send_.emplace(this);
   GELOGI("Node[%s] will control node[%s]", node_item->NodeName().c_str(), NodeName().c_str());
@@ -155,7 +155,7 @@ struct NodeItem {
   std::map<const NodeItem *, int> data_recv_;  // Recv data notify from
   std::set<const NodeItem *> ctrl_send_;  // Send ctrl notify to
   std::set<const NodeItem *> ctrl_recv_;  // Recv ctrl notify from
-  std::vector<std::vector<const NodeItem *>> switch_groups_;  // Send ctrl notify to
+  std::vector<std::set<const NodeItem *>> switch_groups_;  // Send ctrl notify to
   std::shared_ptr<NodeTask> kernel_task;
   std::unique_ptr<FusedSubgraph> fused_subgraph;
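
Switching `switch_groups_` from nested vectors to `std::vector<std::set<const NodeItem *>>` (and `emplace_back` to `emplace`) makes registering the same controlling node twice a no-op rather than a duplicate control edge, matching the `ctrl_send_`/`ctrl_recv_` sets above. A tiny illustration of the difference:

```cpp
#include <iostream>
#include <set>
#include <vector>

int main() {
  std::vector<int> vec_group;
  std::set<int> set_group;
  for (int reg = 0; reg < 2; ++reg) {  // the same node registered twice
    vec_group.emplace_back(42);        // duplicate entry kept
    set_group.emplace(42);             // second emplace is a no-op
  }
  std::cout << vec_group.size() << ' ' << set_group.size() << '\n';  // prints "2 1"
  return 0;
}
```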
@@ -342,6 +342,7 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
       GE_CHK_RT_RET(rtEventDestroy(evt));
     }
     GELOGI("rdma callback success.");
+    return SUCCESS;
   };
 
   HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback);
@@ -17,13 +17,9 @@
 
 #include "hybrid/node_executor/rts/rts_node_executor.h"
 #include "hybrid/node_executor/rts/rts_task_factory.h"
-#include "framework/common/debug/log.h"
 #include "common/ge/ge_util.h"
-#include "framework/common/types.h"
-#include "graph/common/omg_util.h"
 #include "graph/utils/tensor_utils.h"
 #include "hybrid/model/hybrid_model.h"
-#include "runtime/rt.h"
 
 namespace ge {
 namespace hybrid {
@@ -33,6 +29,7 @@ REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask);
 REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask);
 REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask);
 REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask);
+REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, IdentityNodeTask);
 
 Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) {
   auto input_desc = context.MutableInputDesc(index);
@@ -133,8 +130,7 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function<
 Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
   GE_CHECK_NOTNULL(node);
   GELOGD("[%s] Load for local task.", node->GetName().c_str());
-  std::string node_type;
-  GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed.");
+  const std::string node_type = NodeUtils::GetNodeType(node);
   RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type);
   if (rts_task == nullptr) {
     GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str());
@@ -43,7 +43,6 @@ namespace hybrid {
 REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask);
 REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask);
 REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask);
-REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask);
 REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask);
 REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask);
@@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio
   return SUCCESS;
 }
 
-Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
-  GELOGD("[%s] Start to execute.", task_context.GetNodeName());
-  auto input_desc = task_context.MutableInputDesc(0);
-  GE_CHECK_NOTNULL(input_desc);
-  int64_t copy_size = 0;
-  GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size));
-  // copy_size would not be negative since GetTensorSizeInBytes returned successfully.
-  if (copy_size > 0) {
-    const auto in_v = task_context.MutableInput(0);
-    const auto out_v = task_context.MutableOutput(0);
-    GE_CHECK_NOTNULL(in_v);
-    GE_CHECK_NOTNULL(out_v);
-    GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(),
-           in_v->GetSize(), out_v->GetSize(), copy_size);
-    GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size,
-                                RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream()));
-  } else {
-    GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size);
-  }
-
-  if (done_callback) {
-    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
-  }
-
-  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
-  return SUCCESS;
-}
-
 Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
   GELOGD("[%s] Start to execute.", task_context.GetNodeName());
   const auto in_x = task_context.GetInput(0);  // x
@@ -60,11 +60,6 @@ class StreamMergeNodeTask : public RtsNodeTask {
   Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
 };
 
-class MemcpyAsyncNodeTask : public RtsNodeTask {
- public:
-  Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
-};
-
 class PassThroughNodeTask : public RtsNodeTask {
  public:
   Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
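
With `MemcpyAsyncNodeTask` gone, `MEMCPYASYNC` is re-registered to `IdentityNodeTask` (see the `REGISTER_RTS_TASK_CREATOR` hunk earlier), which already performs the same device-to-device tensor copy, so one registry entry replaces a near-duplicate task class. A self-contained sketch of this many-types-to-one-creator factory pattern (toy names, not the GE factory):

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

struct Task {
  virtual ~Task() = default;
  virtual void Run() = 0;
};

struct IdentityTask : Task {
  void Run() override { std::cout << "copy input -> output\n"; }
};

// One creator can serve several op types, so near-duplicate task
// classes can be deleted and their type names re-pointed here.
std::map<std::string, std::function<std::unique_ptr<Task>()>> &Registry() {
  static std::map<std::string, std::function<std::unique_ptr<Task>()>> reg;
  return reg;
}

int main() {
  auto make_identity = [] { return std::make_unique<IdentityTask>(); };
  Registry()["Identity"] = make_identity;
  Registry()["MemcpyAsync"] = make_identity;  // same task class, second type
  Registry()["MemcpyAsync"]()->Run();
  return 0;
}
```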
@@ -438,4 +438,22 @@ TEST_F(UtestModelManagerModelManager, test_data_input_tensor) {
   auto ret = mm.DataInputTensor(model_id, inputs);
   EXPECT_EQ(PARAM_INVALID, ret);  // HybridDavinciModel::impl_ is null.
 }
+
+TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) {
+  ModelManager mm;
+  // cust_aicpu_so_ is empty, so the launch returns SUCCESS immediately.
+  EXPECT_EQ(mm.LaunchKernelCustAicpuSo("empty_cust_aicpu"), SUCCESS);
+
+  // Registered cust-aicpu kernels are removed after launching deleteCustOp.
+  uintptr_t resource_id = 1;  // for rtCtxGetCurrent stub
+  std::vector<char> kernel_bin(256);
+  auto &cust_resource_001 = mm.cust_aicpu_so_[resource_id];
+  auto tbe_kernel = std::shared_ptr<OpKernelBin>(new OpKernelBin("deleteCustOp", std::move(kernel_bin)));
+  cust_resource_001["deleteCustOp"] = tbe_kernel;
+  EXPECT_FALSE(mm.cust_aicpu_so_.empty());
+  EXPECT_EQ(mm.LaunchKernelCustAicpuSo("deleteCustOp"), SUCCESS);
+  EXPECT_TRUE(mm.cust_aicpu_so_.empty());
+}
 }  // namespace ge