| @@ -259,8 +259,16 @@ ShapeFuture::ShapeFuture(NodeState *src_node, | |||||
| NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) | NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) | ||||
| : node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) { | : node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) { | ||||
| this->op_desc_ = node_item.node->GetOpDesc(); | this->op_desc_ = node_item.node->GetOpDesc(); | ||||
| } | |||||
| Status NodeState::Init(int group, const shared_ptr<FrameState> &frame_state) { | |||||
| GE_CHECK_NOTNULL(frame_state); | |||||
| group_ = group; | |||||
| frame_state_ = frame_state; | |||||
| auto unique_task_context = TaskContext::Create(this, subgraph_context_); | auto unique_task_context = TaskContext::Create(this, subgraph_context_); | ||||
| GE_CHECK_NOTNULL(unique_task_context); | |||||
| task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | ||||
| return SUCCESS; | |||||
| } | } | ||||
| Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { | ||||
| @@ -350,6 +358,7 @@ void NodeState::ResetContext(uint64_t iteration) { | |||||
| switch_index_ = -1; | switch_index_ = -1; | ||||
| subgraph_context_->ResetContext(node_item_->node); | subgraph_context_->ResetContext(node_item_->node); | ||||
| auto unique_task_context = TaskContext::Create(this, subgraph_context_); | auto unique_task_context = TaskContext::Create(this, subgraph_context_); | ||||
| GE_CHECK_NOTNULL_JUST_RETURN(unique_task_context); | |||||
| task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | task_context_ = std::shared_ptr<TaskContext>(unique_task_context.release()); | ||||
| data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | data_scheduled_ = static_cast<uint32_t>(node_item_->root_data_.size()); | ||||
| @@ -100,6 +100,8 @@ struct NodeState { | |||||
| NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); | NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); | ||||
| ~NodeState() = default; | ~NodeState() = default; | ||||
| Status Init(int group, const shared_ptr<FrameState> &frame_state); | |||||
| OpDesc *GetOpDesc() const { | OpDesc *GetOpDesc() const { | ||||
| return op_desc_.get(); | return op_desc_.get(); | ||||
| } | } | ||||
| @@ -152,18 +154,10 @@ struct NodeState { | |||||
| return merge_index_; | return merge_index_; | ||||
| } | } | ||||
| void SetGroup(int group) { | |||||
| group_ = group; | |||||
| } | |||||
| int GetGroup() const { | int GetGroup() const { | ||||
| return group_; | return group_; | ||||
| } | } | ||||
| void SetFrameState(const shared_ptr<FrameState> &frame_state) { | |||||
| frame_state_ = frame_state; | |||||
| } | |||||
| const shared_ptr<NodeTask> &GetKernelTask() const { | const shared_ptr<NodeTask> &GetKernelTask() const { | ||||
| return kernel_task_; | return kernel_task_; | ||||
| } | } | ||||
| @@ -79,20 +79,31 @@ NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return CreateNodeState(node_item); | |||||
| } | |||||
| NodeStatePtr SubgraphContext::CreateNodeState(const NodeItem *node_item) { | |||||
| GELOGD("[%s] lock for write", node_item->NodeName().c_str()); | GELOGD("[%s] lock for write", node_item->NodeName().c_str()); | ||||
| if (mmRWLockWRLock(&rw_lock_) != EN_OK) { | if (mmRWLockWRLock(&rw_lock_) != EN_OK) { | ||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); | REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); | ||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); | GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| auto &node_state = node_states_[node_item]; | auto &node_state = node_states_[node_item]; | ||||
| if (node_state == nullptr) { | |||||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||||
| node_state->SetFrameState(GetOrCreateFrameState(*node_item)); | |||||
| node_state->SetGroup(group_); | |||||
| (void)guard; | |||||
| } | |||||
| do { | |||||
| if (node_state == nullptr) { | |||||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | |||||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||||
| if (node_state == nullptr || node_state->Init(group_, GetOrCreateFrameState(*node_item)) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "[Create][NodeState] failed for[%s].", node_item->NodeName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Create NodeState failed for %s.", node_item->NodeName().c_str()); | |||||
| break; | |||||
| } | |||||
| (void)guard; | |||||
| } | |||||
| } while (0); | |||||
| GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | ||||
| if (mmWRLockUnLock(&rw_lock_) != EN_OK) { | if (mmWRLockUnLock(&rw_lock_) != EN_OK) { | ||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); | REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); | ||||
| @@ -51,6 +51,7 @@ class SubgraphContext { | |||||
| void NodeDone(const NodePtr &node); | void NodeDone(const NodePtr &node); | ||||
| private: | private: | ||||
| NodeStatePtr CreateNodeState(const NodeItem *node_item); | |||||
| FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | FrameStatePtr GetOrCreateFrameState(const NodeItem &node_item); // no lock | ||||
| friend class TaskContext; | friend class TaskContext; | ||||
| const GraphItem *graph_item_; | const GraphItem *graph_item_; | ||||