| @@ -1,6 +1,7 @@ | |||
| package dag | |||
| import ( | |||
| "github.com/samber/lo" | |||
| "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" | |||
| "gitlink.org.cn/cloudream/common/utils/lo2" | |||
| ) | |||
| @@ -74,6 +75,10 @@ func (s *VarSlots) Set(idx int, val *Var) *Var { | |||
| return old | |||
| } | |||
| func (s *VarSlots) IndexOf(v *Var) int { | |||
| return lo.IndexOf(*s, v) | |||
| } | |||
| func (s *VarSlots) Append(val *Var) int { | |||
| *s = append(*s, val) | |||
| return s.Len() - 1 | |||
| @@ -139,6 +144,14 @@ func (s *InputSlots) GetVarIDsRanged(start, end int) []exec.VarID { | |||
| return ids | |||
| } | |||
| func (s *InputSlots) ClearInput(v *Var) { | |||
| for i, v2 := range s.RawArray() { | |||
| if v2 == v { | |||
| s.Set(i, nil) | |||
| } | |||
| } | |||
| } | |||
| type OutputSlots struct { | |||
| VarSlots | |||
| } | |||
| @@ -149,18 +162,12 @@ func (s *OutputSlots) Setup(my Node, v *Var, slotIdx int) { | |||
| } | |||
| s.Set(slotIdx, v) | |||
| *v.From() = EndPoint{ | |||
| Node: my, | |||
| SlotIndex: slotIdx, | |||
| } | |||
| v.src = my | |||
| } | |||
| func (s *OutputSlots) SetupNew(my Node, v *Var) { | |||
| s.Append(v) | |||
| *v.From() = EndPoint{ | |||
| Node: my, | |||
| SlotIndex: s.Len() - 1, | |||
| } | |||
| v.src = my | |||
| } | |||
| func (s *OutputSlots) GetVarIDs() []exec.VarID { | |||
| @@ -1,6 +1,7 @@ | |||
| package dag | |||
| import ( | |||
| "github.com/samber/lo" | |||
| "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" | |||
| "gitlink.org.cn/cloudream/common/utils/lo2" | |||
| ) | |||
| @@ -10,90 +11,98 @@ type EndPoint struct { | |||
| SlotIndex int // 所连接的Node的Output或Input数组的索引 | |||
| } | |||
| type EndPointSlots []EndPoint | |||
| type DstList []Node | |||
| func (s *EndPointSlots) Len() int { | |||
| func (s *DstList) Len() int { | |||
| return len(*s) | |||
| } | |||
| func (s *EndPointSlots) Get(idx int) *EndPoint { | |||
| return &(*s)[idx] | |||
| func (s *DstList) Get(idx int) Node { | |||
| return (*s)[idx] | |||
| } | |||
| func (s *EndPointSlots) Add(ed EndPoint) int { | |||
| (*s) = append((*s), ed) | |||
| func (s *DstList) Add(n Node) int { | |||
| (*s) = append((*s), n) | |||
| return len(*s) - 1 | |||
| } | |||
| func (s *EndPointSlots) Remove(ed EndPoint) { | |||
| func (s *DstList) Remove(n Node) { | |||
| for i, e := range *s { | |||
| if e == ed { | |||
| if e == n { | |||
| (*s) = lo2.RemoveAt((*s), i) | |||
| return | |||
| } | |||
| } | |||
| } | |||
| func (s *EndPointSlots) RemoveAt(idx int) { | |||
| func (s *DstList) RemoveAt(idx int) { | |||
| lo2.RemoveAt((*s), idx) | |||
| } | |||
| func (s *EndPointSlots) Resize(size int) { | |||
| func (s *DstList) Resize(size int) { | |||
| if s.Len() < size { | |||
| (*s) = append((*s), make([]EndPoint, size-s.Len())...) | |||
| (*s) = append((*s), make([]Node, size-s.Len())...) | |||
| } else if s.Len() > size { | |||
| (*s) = (*s)[:size] | |||
| } | |||
| } | |||
| func (s *EndPointSlots) RawArray() []EndPoint { | |||
| func (s *DstList) RawArray() []Node { | |||
| return *s | |||
| } | |||
| type Var struct { | |||
| VarID exec.VarID | |||
| from EndPoint | |||
| to EndPointSlots | |||
| src Node | |||
| dst DstList | |||
| } | |||
| func (v *Var) From() *EndPoint { | |||
| return &v.from | |||
| func (v *Var) From() Node { | |||
| return v.src | |||
| } | |||
| func (v *Var) To() *EndPointSlots { | |||
| return &v.to | |||
| func (v *Var) To() *DstList { | |||
| return &v.dst | |||
| } | |||
| func (v *Var) StreamIndexOfFrom() int { | |||
| return lo.IndexOf(v.src.OutputStreams().RawArray(), v) | |||
| } | |||
| func (v *Var) ValueIndexOfFrom() int { | |||
| return lo.IndexOf(v.src.InputValues().RawArray(), v) | |||
| } | |||
| func (v *Var) ValueTo(to Node, slotIdx int) { | |||
| v.To().Add(EndPoint{Node: to, SlotIndex: slotIdx}) | |||
| v.To().Add(to) | |||
| to.InputValues().Set(slotIdx, v) | |||
| } | |||
| func (v *Var) ValueNotTo(node Node, slotIdx int) { | |||
| v.to.Remove(EndPoint{Node: node, SlotIndex: slotIdx}) | |||
| v.dst.Remove(node) | |||
| node.InputValues().Set(slotIdx, nil) | |||
| } | |||
| func (v *Var) StreamTo(to Node, slotIdx int) { | |||
| v.To().Add(EndPoint{Node: to, SlotIndex: slotIdx}) | |||
| v.To().Add(to) | |||
| to.InputStreams().Set(slotIdx, v) | |||
| } | |||
| func (v *Var) StreamNotTo(node Node, slotIdx int) { | |||
| v.to.Remove(EndPoint{Node: node, SlotIndex: slotIdx}) | |||
| v.dst.Remove(node) | |||
| node.InputStreams().Set(slotIdx, nil) | |||
| } | |||
| func (v *Var) NoInputAllValue() { | |||
| for _, ed := range v.to { | |||
| ed.Node.InputValues().Set(ed.SlotIndex, nil) | |||
| for _, n := range v.dst { | |||
| n.InputValues().ClearInput(v) | |||
| } | |||
| v.to = nil | |||
| v.dst = nil | |||
| } | |||
| func (v *Var) NoInputAllStream() { | |||
| for _, ed := range v.to { | |||
| ed.Node.InputStreams().Set(ed.SlotIndex, nil) | |||
| for _, n := range v.dst { | |||
| n.InputStreams().ClearInput(v) | |||
| } | |||
| v.to = nil | |||
| v.dst = nil | |||
| } | |||
| @@ -31,38 +31,41 @@ func generateSend(graph *ops.GraphNodeBuilder) { | |||
| for i := 0; i < node.OutputStreams().Len(); i++ { | |||
| out := node.OutputStreams().Get(i) | |||
| to := out.To().Get(0) | |||
| if to.Node.Env().Equals(node.Env()) { | |||
| if to.Env().Equals(node.Env()) { | |||
| continue | |||
| } | |||
| switch to.Node.Env().Type { | |||
| switch to.Env().Type { | |||
| case dag.EnvDriver: | |||
| // // 如果是要送到Driver,则只能由Driver主动去拉取 | |||
| dstNode := out.To().Get(0) | |||
| getNode := graph.NewGetStream(node.Env().Worker) | |||
| getNode.Env().ToEnvDriver() | |||
| // // 同时需要对此变量生成HoldUntil指令,避免Plan结束时Get指令还未到达 | |||
| holdType := graph.NewHoldUntil() //dag.NewNode(graph, &ops.HoldUntilNode{}, nil) | |||
| *holdType.Env() = *node.Env() | |||
| holdNode := graph.NewHoldUntil() | |||
| *holdNode.Env() = *node.Env() | |||
| // 将Get指令的信号送到Hold指令 | |||
| holdType.SetSignal(getNode.SignalVar()) | |||
| holdNode.SetSignal(getNode.SignalVar()) | |||
| out.To().RemoveAt(0) | |||
| // 将源节点的输出送到Hold指令,将Hold指令的输出送到Get指令 | |||
| getNode.Get(holdType.HoldStream(out)). | |||
| getNode.Get(holdNode.HoldStream(out)). | |||
| // 将Get指令的输出送到目的地 | |||
| StreamTo(to.Node, to.SlotIndex) | |||
| StreamTo(to, dstNode.InputStreams().IndexOf(out)) | |||
| case dag.EnvWorker: | |||
| // 如果是要送到Agent,则可以直接发送 | |||
| n := graph.NewSendStream(to.Node.Env().Worker) | |||
| dstNode := out.To().Get(0) | |||
| n := graph.NewSendStream(to.Env().Worker) | |||
| *n.Env() = *node.Env() | |||
| out.To().RemoveAt(0) | |||
| n.Send(out).StreamTo(to.Node, to.SlotIndex) | |||
| n.Send(out).StreamTo(to, dstNode.InputStreams().IndexOf(out)) | |||
| } | |||
| } | |||
| @@ -74,13 +77,14 @@ func generateSend(graph *ops.GraphNodeBuilder) { | |||
| } | |||
| to := out.To().Get(0) | |||
| if to.Node.Env().Equals(node.Env()) { | |||
| if to.Env().Equals(node.Env()) { | |||
| continue | |||
| } | |||
| switch to.Node.Env().Type { | |||
| switch to.Env().Type { | |||
| case dag.EnvDriver: | |||
| // // 如果是要送到Driver,则只能由Driver主动去拉取 | |||
| dstNode := out.To().Get(0) | |||
| getNode := graph.NewGetValue(node.Env().Worker) | |||
| getNode.Env().ToEnvDriver() | |||
| @@ -96,16 +100,17 @@ func generateSend(graph *ops.GraphNodeBuilder) { | |||
| // 将源节点的输出送到Hold指令,将Hold指令的输出送到Get指令 | |||
| getNode.Get(holdNode.HoldVar(out)). | |||
| // 将Get指令的输出送到目的地 | |||
| ValueTo(to.Node, to.SlotIndex) | |||
| ValueTo(to, dstNode.InputValues().IndexOf(out)) | |||
| case dag.EnvWorker: | |||
| // 如果是要送到Agent,则可以直接发送 | |||
| t := graph.NewSendValue(to.Node.Env().Worker) | |||
| dstNode := out.To().Get(0) | |||
| t := graph.NewSendValue(to.Env().Worker) | |||
| *t.Env() = *node.Env() | |||
| out.To().RemoveAt(0) | |||
| t.Send(out).ValueTo(to.Node, to.SlotIndex) | |||
| t.Send(out).ValueTo(to, dstNode.InputValues().IndexOf(out)) | |||
| } | |||
| } | |||