From a59532e284884f57c871cd02fe8a6fde8f5d2a19 Mon Sep 17 00:00:00 2001 From: Sydonian <794346190@qq.com> Date: Tue, 27 May 2025 10:06:21 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkgs/ioswitch/dag/node.go | 30 +++++++++++++++++++++++++++--- pkgs/ioswitch/dag/var.go | 5 +++++ pkgs/ioswitch/plan/compile.go | 4 ++-- pkgs/ioswitch/plan/ops/store.go | 2 +- pkgs/ioswitch/plan/ops/sync.go | 2 +- 5 files changed, 36 insertions(+), 7 deletions(-) diff --git a/pkgs/ioswitch/dag/node.go b/pkgs/ioswitch/dag/node.go index d513267..df6e801 100644 --- a/pkgs/ioswitch/dag/node.go +++ b/pkgs/ioswitch/dag/node.go @@ -20,24 +20,28 @@ type NodeEnv struct { Pinned bool // 如果为true,则不应该改变这个节点的执行环境 } -func (e *NodeEnv) ToEnvUnknown() { +func (e *NodeEnv) ToEnvUnknown(pinned bool) { e.Type = EnvUnknown e.Worker = nil + e.Pinned = pinned } -func (e *NodeEnv) ToEnvDriver() { +func (e *NodeEnv) ToEnvDriver(pinned bool) { e.Type = EnvDriver e.Worker = nil + e.Pinned = pinned } -func (e *NodeEnv) ToEnvWorker(worker exec.WorkerInfo) { +func (e *NodeEnv) ToEnvWorker(worker exec.WorkerInfo, pinned bool) { e.Type = EnvWorker e.Worker = worker + e.Pinned = pinned } func (e *NodeEnv) CopyFrom(other *NodeEnv) { e.Type = other.Type e.Worker = other.Worker + e.Pinned = other.Pinned } func (e *NodeEnv) Equals(other *NodeEnv) bool { @@ -461,6 +465,16 @@ func (s StreamOutputSlot) ToSlot(slot StreamInputSlot) { s.Var().To(slot.Node, slot.Index) } +// 查询所有输出的连接的输入槽位 +func (s StreamOutputSlot) ListDstSlots() []StreamInputSlot { + slots := make([]StreamInputSlot, s.Var().Dst.Len()) + myVar := s.Var() + for i, dst := range s.Var().Dst { + slots[i] = StreamInputSlot{Node: dst, Index: dst.InputStreams().IndexOf(myVar)} + } + return slots +} + type StreamInputSlot struct { Node Node Index int @@ -483,6 +497,16 @@ func (s ValueOutputSlot) ToSlot(slot ValueInputSlot) { s.Var().To(slot.Node, slot.Index) } +// 查询所有输出的连接的输入槽位 +func (s ValueOutputSlot) ListDstSlots() []ValueInputSlot { + slots := make([]ValueInputSlot, s.Var().Dst.Len()) + myVar := s.Var() + for i, dst := range s.Var().Dst { + slots[i] = ValueInputSlot{Node: dst, Index: dst.InputValues().IndexOf(myVar)} + } + return slots +} + type ValueInputSlot struct { Node Node Index int diff --git a/pkgs/ioswitch/dag/var.go b/pkgs/ioswitch/dag/var.go index 5d5989e..217def7 100644 --- a/pkgs/ioswitch/dag/var.go +++ b/pkgs/ioswitch/dag/var.go @@ -64,6 +64,11 @@ func (v *ValueVar) To(to Node, slotIdx int) { to.InputValues().Slots.Set(slotIdx, v) } +func (v *ValueVar) ToSlot(slot ValueInputSlot) { + v.Dst.Add(slot.Node) + slot.Node.InputValues().Slots.Set(slot.Index, v) +} + func (v *ValueVar) NotTo(node Node) { v.Dst.Remove(node) node.InputValues().Slots.Clear(v) diff --git a/pkgs/ioswitch/plan/compile.go b/pkgs/ioswitch/plan/compile.go index 6e75abe..bcdf20c 100644 --- a/pkgs/ioswitch/plan/compile.go +++ b/pkgs/ioswitch/plan/compile.go @@ -42,7 +42,7 @@ func generateSend(graph *ops.GraphNodeBuilder) { dstNode := out.Dst.Get(0) getNode := graph.NewGetStream(node.Env().Worker) - getNode.Env().ToEnvDriver() + getNode.Env().ToEnvDriver(true) // // 同时需要对此变量生成HoldUntil指令,避免Plan结束时Get指令还未到达 holdNode := graph.NewHoldUntil() @@ -86,7 +86,7 @@ func generateSend(graph *ops.GraphNodeBuilder) { // // 如果是要送到Driver,则只能由Driver主动去拉取 dstNode := out.Dst.Get(0) getNode := graph.NewGetValue(node.Env().Worker) - getNode.Env().ToEnvDriver() + getNode.Env().ToEnvDriver(true) // // 同时需要对此变量生成HoldUntil指令,避免Plan结束时Get指令还未到达 holdNode := graph.NewHoldUntil() diff --git a/pkgs/ioswitch/plan/ops/store.go b/pkgs/ioswitch/plan/ops/store.go index 4353dc7..caf45e2 100644 --- a/pkgs/ioswitch/plan/ops/store.go +++ b/pkgs/ioswitch/plan/ops/store.go @@ -23,7 +23,7 @@ func (o *Store) Execute(ctx *exec.ExecContext, e *exec.Executor) error { } func (o *Store) String() string { - return fmt.Sprintf("Store %v: %v", o.Key, o.Var) + return fmt.Sprintf("Store %v as \"%v\"", o.Var, o.Key) } type StoreConst struct { diff --git a/pkgs/ioswitch/plan/ops/sync.go b/pkgs/ioswitch/plan/ops/sync.go index c4ec53c..3e096cb 100644 --- a/pkgs/ioswitch/plan/ops/sync.go +++ b/pkgs/ioswitch/plan/ops/sync.go @@ -113,7 +113,7 @@ func (w *HoldUntil) Execute(ctx *exec.ExecContext, e *exec.Executor) error { } func (w *HoldUntil) String() string { - return fmt.Sprintf("HoldUntil Waits: %v, (%v) -> (%v)", utils.FormatVarIDs(w.Waits), utils.FormatVarIDs(w.Holds), utils.FormatVarIDs(w.Emits)) + return fmt.Sprintf("HoldUntil(waits=%v): %v -> %v", utils.FormatVarIDs(w.Waits), utils.FormatVarIDs(w.Holds), utils.FormatVarIDs(w.Emits)) } type HangUntil struct {