diff --git a/pkgs/ioswitch/dag/node.go b/pkgs/ioswitch/dag/node.go index 6298b46..6f24072 100644 --- a/pkgs/ioswitch/dag/node.go +++ b/pkgs/ioswitch/dag/node.go @@ -23,6 +23,7 @@ const ( type NodeEnv struct { Type NodeEnvType Worker exec.WorkerInfo + Pinned bool // 如果为true,则不应该改变这个节点的执行环境 } func (e *NodeEnv) ToEnvUnknown() { diff --git a/pkgs/ioswitch/dag/var.go b/pkgs/ioswitch/dag/var.go index 48d890b..f10ea82 100644 --- a/pkgs/ioswitch/dag/var.go +++ b/pkgs/ioswitch/dag/var.go @@ -84,7 +84,8 @@ func NodeDeclareInputStream(node *Node, cnt int) { type ValueVarType int const ( - StringValueVar ValueVarType = iota + UnknownValueVar ValueVarType = iota + StringValueVar SignalValueVar ) @@ -103,9 +104,10 @@ func (v *ValueVar) To(to *Node, slotIdx int) int { return len(v.Toes) - 1 } -func NodeNewOutputValue(node *Node, props any) *ValueVar { +func NodeNewOutputValue(node *Node, typ ValueVarType, props any) *ValueVar { val := &ValueVar{ ID: node.Graph.genVarID(), + Type: typ, From: EndPoint{Node: node, SlotIndex: len(node.OutputStreams)}, Props: props, } diff --git a/pkgs/ioswitch/exec/exec.go b/pkgs/ioswitch/exec/exec.go index ec16299..873ad59 100644 --- a/pkgs/ioswitch/exec/exec.go +++ b/pkgs/ioswitch/exec/exec.go @@ -19,6 +19,7 @@ var opUnion = serder.UseTypeUnionExternallyTagged(types.Ref(types.NewTypeUnion[O type Op interface { Execute(ctx context.Context, e *Executor) error + String() string } func UseOp[T Op]() { diff --git a/pkgs/ioswitch/exec/plan_builder.go b/pkgs/ioswitch/exec/plan_builder.go index 24b0525..0e1e44a 100644 --- a/pkgs/ioswitch/exec/plan_builder.go +++ b/pkgs/ioswitch/exec/plan_builder.go @@ -2,6 +2,7 @@ package exec import ( "context" + "strings" "gitlink.org.cn/cloudream/common/pkgs/future" "gitlink.org.cn/cloudream/common/utils/lo2" @@ -97,6 +98,29 @@ func (b *PlanBuilder) Execute() *Driver { return &exec } +func (b *PlanBuilder) String() string { + sb := strings.Builder{} + sb.WriteString("Driver:\n") + for _, op := range b.DriverPlan.Ops { + sb.WriteString(op.String()) + sb.WriteRune('\n') + } + sb.WriteRune('\n') + + for _, w := range b.WorkerPlans { + sb.WriteString("Worker(") + sb.WriteString(w.Worker.String()) + sb.WriteString("):\n") + for _, op := range w.Ops { + sb.WriteString(op.String()) + sb.WriteRune('\n') + } + sb.WriteRune('\n') + } + + return sb.String() +} + type WorkerPlanBuilder struct { Worker WorkerInfo Ops []Op diff --git a/pkgs/ioswitch/plan/generate.go b/pkgs/ioswitch/plan/generate.go index 6a7362b..2e94477 100644 --- a/pkgs/ioswitch/plan/generate.go +++ b/pkgs/ioswitch/plan/generate.go @@ -1,10 +1,11 @@ package plan import ( + "fmt" + "gitlink.org.cn/cloudream/common/pkgs/ioswitch/dag" "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" "gitlink.org.cn/cloudream/common/pkgs/ioswitch/plan/ops" - "gitlink.org.cn/cloudream/storage/common/pkgs/ioswitch" ) func Generate(graph *dag.Graph, planBld *exec.PlanBuilder) error { @@ -15,6 +16,19 @@ func Generate(graph *dag.Graph, planBld *exec.PlanBuilder) error { // 生成Send指令 func generateSend(graph *dag.Graph) { graph.Walk(func(node *dag.Node) bool { + switch node.Type.(type) { + case *ops.SendStreamType: + return true + case *ops.SendVarType: + return true + case *ops.GetStreamType: + return true + case *ops.GetVaType: + return true + case *ops.HoldUntilType: + return true + } + for _, out := range node.OutputStreams { to := out.Toes[0] if to.Node.Env.Equals(node.Env) { @@ -24,30 +38,34 @@ func generateSend(graph *dag.Graph) { switch to.Node.Env.Type { case dag.EnvDriver: // // 如果是要送到Driver,则只能由Driver主动去拉取 - getNode := graph.NewNode(&ops.GetStreamType{}, ioswitch.NodeProps{}) + getNode, getType := dag.NewNode(graph, &ops.GetStreamType{ + FromWorker: node.Env.Worker, + }, nil) getNode.Env.ToEnvDriver() // // 同时需要对此变量生成HoldUntil指令,避免Plan结束时Get指令还未到达 - holdNode := graph.NewNode(&ops.HoldUntilType{}, ioswitch.NodeProps{}) + holdNode, holdType := dag.NewNode(graph, &ops.HoldUntilType{}, nil) holdNode.Env = node.Env // 将Get指令的信号送到Hold指令 - getNode.OutputValues[0].To(holdNode, 0) - // 将Get指令的输出送到目的地 - getNode.OutputStreams[0].To(to.Node, to.SlotIndex) + holdType.Signal(holdNode, getType.SignalVar(getNode)) + out.Toes = nil - // 将源节点的输出送到Hold指令 - out.To(holdNode, 0) - // 将Hold指令的输出送到Get指令 - holdNode.OutputStreams[0].To(getNode, 0) + + // 将源节点的输出送到Hold指令,将Hold指令的输出送到Get指令 + getType.Get(getNode, holdType.HoldStream(holdNode, out)). + // 将Get指令的输出送到目的地 + To(to.Node, to.SlotIndex) case dag.EnvWorker: // 如果是要送到Agent,则可以直接发送 - n := graph.NewNode(&ops.SendStreamType{}, ioswitch.NodeProps{}) + n, t := dag.NewNode(graph, &ops.SendStreamType{ + ToWorker: to.Node.Env.Worker, + }, nil) n.Env = node.Env - n.OutputStreams[0].To(to.Node, to.SlotIndex) + out.Toes = nil - out.To(n, 0) + t.Send(n, out).To(to.Node, to.SlotIndex) } } @@ -60,30 +78,34 @@ func generateSend(graph *dag.Graph) { switch to.Node.Env.Type { case dag.EnvDriver: // // 如果是要送到Driver,则只能由Driver主动去拉取 - getNode := graph.NewNode(&ops.GetVaType{}, ioswitch.NodeProps{}) + getNode, getType := dag.NewNode(graph, &ops.GetVaType{ + FromWorker: node.Env.Worker, + }, nil) getNode.Env.ToEnvDriver() // // 同时需要对此变量生成HoldUntil指令,避免Plan结束时Get指令还未到达 - holdNode := graph.NewNode(&ops.HoldUntilType{}, ioswitch.NodeProps{}) + holdNode, holdType := dag.NewNode(graph, &ops.HoldUntilType{}, nil) holdNode.Env = node.Env // 将Get指令的信号送到Hold指令 - getNode.OutputValues[0].To(holdNode, 0) - // 将Get指令的输出送到目的地 - getNode.OutputValues[1].To(to.Node, to.SlotIndex) + holdType.Signal(holdNode, getType.SignalVar(getNode)) + out.Toes = nil - // 将源节点的输出送到Hold指令 - out.To(holdNode, 0) - // 将Hold指令的输出送到Get指令 - holdNode.OutputValues[0].To(getNode, 0) + + // 将源节点的输出送到Hold指令,将Hold指令的输出送到Get指令 + getType.Get(getNode, holdType.HoldVar(holdNode, out)). + // 将Get指令的输出送到目的地 + To(to.Node, to.SlotIndex) case dag.EnvWorker: // 如果是要送到Agent,则可以直接发送 - n := graph.NewNode(&ops.SendVarType{}, ioswitch.NodeProps{}) + n, t := dag.NewNode(graph, &ops.SendVarType{ + ToWorker: to.Node.Env.Worker, + }, nil) n.Env = node.Env - n.OutputValues[0].To(to.Node, to.SlotIndex) + out.Toes = nil - out.To(n, 0) + t.Send(n, out).To(to.Node, to.SlotIndex) } } @@ -121,6 +143,9 @@ func buildPlan(graph *dag.Graph, blder *exec.PlanBuilder) error { out.Var = blder.NewStringVar() case dag.SignalValueVar: out.Var = blder.NewSignalVar() + default: + retErr = fmt.Errorf("unsupported value var type: %v", out.Type) + return false } } @@ -134,6 +159,9 @@ func buildPlan(graph *dag.Graph, blder *exec.PlanBuilder) error { in.Var = blder.NewStringVar() case dag.SignalValueVar: in.Var = blder.NewSignalVar() + default: + retErr = fmt.Errorf("unsupported value var type: %v", in.Type) + return false } } @@ -143,6 +171,11 @@ func buildPlan(graph *dag.Graph, blder *exec.PlanBuilder) error { return false } + // TODO 当前ToDriver,FromDriver不会生成Op,所以这里需要判断一下 + if op == nil { + return true + } + switch node.Env.Type { case dag.EnvDriver: blder.AtDriver().AddOp(op) diff --git a/pkgs/ioswitch/plan/ops/driver.go b/pkgs/ioswitch/plan/ops/driver.go index 214f333..7de285e 100644 --- a/pkgs/ioswitch/plan/ops/driver.go +++ b/pkgs/ioswitch/plan/ops/driver.go @@ -5,6 +5,7 @@ import ( "gitlink.org.cn/cloudream/common/pkgs/ioswitch/dag" "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" + "gitlink.org.cn/cloudream/storage/common/pkgs/ioswitch2" ) type FromDriverType struct { @@ -12,7 +13,7 @@ type FromDriverType struct { } func (t *FromDriverType) InitNode(node *dag.Node) { - dag.NodeNewOutputStream(node, nil) + dag.NodeNewOutputStream(node, &ioswitch2.VarProps{}) } func (t *FromDriverType) GenerateOp(op *dag.Node) (exec.Op, error) { diff --git a/pkgs/ioswitch/plan/ops/drop.go b/pkgs/ioswitch/plan/ops/drop.go index 22f2d5d..b1cbf2e 100644 --- a/pkgs/ioswitch/plan/ops/drop.go +++ b/pkgs/ioswitch/plan/ops/drop.go @@ -22,6 +22,7 @@ func (o *DropStream) Execute(ctx context.Context, e *exec.Executor) error { if err != nil { return err } + defer o.Input.Stream.Close() for { buf := make([]byte, 1024*8) @@ -35,6 +36,10 @@ func (o *DropStream) Execute(ctx context.Context, e *exec.Executor) error { } } +func (o *DropStream) String() string { + return fmt.Sprintf("DropStream %v", o.Input.ID) +} + type DropType struct{} func (t *DropType) InitNode(node *dag.Node) { diff --git a/pkgs/ioswitch/plan/ops/send.go b/pkgs/ioswitch/plan/ops/send.go index 148938a..08f44fa 100644 --- a/pkgs/ioswitch/plan/ops/send.go +++ b/pkgs/ioswitch/plan/ops/send.go @@ -49,6 +49,10 @@ func (o *SendStream) Execute(ctx context.Context, e *exec.Executor) error { return nil } +func (o *SendStream) String() string { + return fmt.Sprintf("SendStream %v->%v@%v", o.Input.ID, o.Send.ID, o.Worker) +} + type GetStream struct { Signal *exec.SignalVar `json:"signal"` Target *exec.StreamVar `json:"target"` @@ -80,6 +84,10 @@ func (o *GetStream) Execute(ctx context.Context, e *exec.Executor) error { return fut.Wait(ctx) } +func (o *GetStream) String() string { + return fmt.Sprintf("GetStream %v(S:%v)<-%v@%v", o.Output.ID, o.Signal.ID, o.Target.ID, o.Worker) +} + type SendVar struct { Input exec.Var `json:"input"` Send exec.Var `json:"send"` @@ -109,6 +117,10 @@ func (o *SendVar) Execute(ctx context.Context, e *exec.Executor) error { return nil } +func (o *SendVar) String() string { + return fmt.Sprintf("SendVar %v->%v@%v", o.Input.GetID(), o.Send.GetID(), o.Worker) +} + type GetVar struct { Signal *exec.SignalVar `json:"signal"` Target exec.Var `json:"target"` @@ -135,10 +147,19 @@ func (o *GetVar) Execute(ctx context.Context, e *exec.Executor) error { return nil } +func (o *GetVar) String() string { + return fmt.Sprintf("GetVar %v(S:%v)<-%v@%v", o.Output.GetID(), o.Signal.ID, o.Target.GetID(), o.Worker) +} + type SendStreamType struct { ToWorker exec.WorkerInfo } +func (t *SendStreamType) Send(n *dag.Node, v *dag.StreamVar) *dag.StreamVar { + v.To(n, 0) + return n.OutputStreams[0] +} + func (t *SendStreamType) InitNode(node *dag.Node) { dag.NodeDeclareInputStream(node, 1) dag.NodeNewOutputStream(node, nil) @@ -160,9 +181,15 @@ type SendVarType struct { ToWorker exec.WorkerInfo } +func (t *SendVarType) Send(n *dag.Node, v *dag.ValueVar) *dag.ValueVar { + v.To(n, 0) + n.OutputValues[0].Type = v.Type + return n.OutputValues[0] +} + func (t *SendVarType) InitNode(node *dag.Node) { dag.NodeDeclareInputValue(node, 1) - dag.NodeNewOutputValue(node, nil) + dag.NodeNewOutputValue(node, 0, nil) } func (t *SendVarType) GenerateOp(op *dag.Node) (exec.Op, error) { @@ -181,9 +208,18 @@ type GetStreamType struct { FromWorker exec.WorkerInfo } +func (t *GetStreamType) Get(n *dag.Node, v *dag.StreamVar) *dag.StreamVar { + v.To(n, 0) + return n.OutputStreams[0] +} + +func (t *GetStreamType) SignalVar(n *dag.Node) *dag.ValueVar { + return n.OutputValues[0] +} + func (t *GetStreamType) InitNode(node *dag.Node) { dag.NodeDeclareInputStream(node, 1) - dag.NodeNewOutputValue(node, nil) + dag.NodeNewOutputValue(node, dag.SignalValueVar, nil) dag.NodeNewOutputStream(node, nil) } @@ -204,10 +240,20 @@ type GetVaType struct { FromWorker exec.WorkerInfo } +func (t *GetVaType) Get(n *dag.Node, v *dag.ValueVar) *dag.ValueVar { + v.To(n, 0) + n.OutputValues[1].Type = v.Type + return n.OutputValues[1] +} + +func (t *GetVaType) SignalVar(n *dag.Node) *dag.ValueVar { + return n.OutputValues[0] +} + func (t *GetVaType) InitNode(node *dag.Node) { dag.NodeDeclareInputValue(node, 1) - dag.NodeNewOutputValue(node, nil) - dag.NodeNewOutputValue(node, nil) + dag.NodeNewOutputValue(node, dag.SignalValueVar, nil) + dag.NodeNewOutputValue(node, 0, nil) } func (t *GetVaType) GenerateOp(op *dag.Node) (exec.Op, error) { diff --git a/pkgs/ioswitch/plan/ops/store.go b/pkgs/ioswitch/plan/ops/store.go index 08886bd..ad48ab5 100644 --- a/pkgs/ioswitch/plan/ops/store.go +++ b/pkgs/ioswitch/plan/ops/store.go @@ -29,10 +29,18 @@ func (o *Store) Execute(ctx context.Context, e *exec.Executor) error { return nil } +func (o *Store) String() string { + return fmt.Sprintf("Store %v: %v", o.Key, o.Var.GetID()) +} + type StoreType struct { StoreKey string } +func (t *StoreType) Store(node *dag.Node, v *dag.ValueVar) { + v.To(node, 0) +} + func (t *StoreType) InitNode(node *dag.Node) { dag.NodeDeclareInputValue(node, 1) } diff --git a/pkgs/ioswitch/plan/ops/sync.go b/pkgs/ioswitch/plan/ops/sync.go index 26a98ab..4689e70 100644 --- a/pkgs/ioswitch/plan/ops/sync.go +++ b/pkgs/ioswitch/plan/ops/sync.go @@ -8,6 +8,7 @@ import ( "gitlink.org.cn/cloudream/common/pkgs/future" "gitlink.org.cn/cloudream/common/pkgs/ioswitch/dag" "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" + "gitlink.org.cn/cloudream/common/pkgs/ioswitch/utils" ) func init() { @@ -36,6 +37,10 @@ func (o *OnStreamBegin) Execute(ctx context.Context, e *exec.Executor) error { return nil } +func (o *OnStreamBegin) String() string { + return fmt.Sprintf("OnStreamBegin %v->%v S:%v", o.Raw.ID, o.New.ID, o.Signal.ID) +} + type OnStreamEnd struct { Raw *exec.StreamVar `json:"raw"` New *exec.StreamVar `json:"new"` @@ -85,6 +90,10 @@ func (o *OnStreamEnd) Execute(ctx context.Context, e *exec.Executor) error { return nil } +func (o *OnStreamEnd) String() string { + return fmt.Sprintf("OnStreamEnd %v->%v S:%v", o.Raw.ID, o.New.ID, o.Signal.ID) +} + type HoldUntil struct { Waits []*exec.SignalVar `json:"waits"` Holds []exec.Var `json:"holds"` @@ -113,6 +122,10 @@ func (w *HoldUntil) Execute(ctx context.Context, e *exec.Executor) error { return nil } +func (w *HoldUntil) String() string { + return fmt.Sprintf("HoldUntil Waits: %v, (%v) -> (%v)", utils.FormatVarIDs(w.Waits), utils.FormatVarIDs(w.Holds), utils.FormatVarIDs(w.Emits)) +} + type HangUntil struct { Waits []*exec.SignalVar `json:"waits"` Op exec.Op `json:"op"` @@ -127,6 +140,10 @@ func (h *HangUntil) Execute(ctx context.Context, e *exec.Executor) error { return h.Op.Execute(ctx, e) } +func (h *HangUntil) String() string { + return "HangUntil" +} + type Broadcast struct { Source *exec.SignalVar `json:"source"` Targets []*exec.SignalVar `json:"targets"` @@ -142,6 +159,10 @@ func (b *Broadcast) Execute(ctx context.Context, e *exec.Executor) error { return nil } +func (b *Broadcast) String() string { + return "Broadcast" +} + type HoldUntilType struct { } @@ -167,6 +188,24 @@ func (t *HoldUntilType) GenerateOp(op *dag.Node) (exec.Op, error) { return o, nil } +func (t *HoldUntilType) Signal(n *dag.Node, s *dag.ValueVar) { + s.To(n, 0) +} + +func (t *HoldUntilType) HoldStream(n *dag.Node, str *dag.StreamVar) *dag.StreamVar { + n.InputStreams = append(n.InputStreams, nil) + str.To(n, len(n.InputStreams)-1) + + return dag.NodeNewOutputStream(n, nil) +} + +func (t *HoldUntilType) HoldVar(n *dag.Node, v *dag.ValueVar) *dag.ValueVar { + n.InputValues = append(n.InputValues, nil) + v.To(n, len(n.InputValues)-1) + + return dag.NodeNewOutputValue(n, v.Type, nil) +} + func (t *HoldUntilType) String(node *dag.Node) string { return fmt.Sprintf("HoldUntil[]%v%v", formatStreamIO(node), formatValueIO(node)) } diff --git a/pkgs/ioswitch/plan/ops/var.go b/pkgs/ioswitch/plan/ops/var.go index d2a7fe6..cb6aec0 100644 --- a/pkgs/ioswitch/plan/ops/var.go +++ b/pkgs/ioswitch/plan/ops/var.go @@ -18,3 +18,7 @@ func (o *ConstVar) Execute(ctx context.Context, e *exec.Executor) error { e.PutVars(o.Var) return nil } + +func (o *ConstVar) String() string { + return "ConstVar" +} diff --git a/pkgs/ioswitch/utils/utils.go b/pkgs/ioswitch/utils/utils.go new file mode 100644 index 0000000..2db1ff6 --- /dev/null +++ b/pkgs/ioswitch/utils/utils.go @@ -0,0 +1,19 @@ +package utils + +import ( + "fmt" + "strings" + + "gitlink.org.cn/cloudream/common/pkgs/ioswitch/exec" +) + +func FormatVarIDs[T exec.Var](arr []T) string { + sb := strings.Builder{} + for i, v := range arr { + sb.WriteString(fmt.Sprintf("%v", v.GetID())) + if i < len(arr)-1 { + sb.WriteString(",") + } + } + return sb.String() +} diff --git a/sdks/storage/models.go b/sdks/storage/models.go index 9a79cd6..7645a59 100644 --- a/sdks/storage/models.go +++ b/sdks/storage/models.go @@ -37,6 +37,7 @@ var RedundancyUnion = serder.UseTypeUnionInternallyTagged(types.Ref(types.NewTyp (*NoneRedundancy)(nil), (*RepRedundancy)(nil), (*ECRedundancy)(nil), + (*LRCRedundancy)(nil), )), "type") type NoneRedundancy struct { @@ -93,6 +94,67 @@ func (b *ECRedundancy) Value() (driver.Value, error) { return serder.ObjectToJSONEx[Redundancy](b) } +var DefaultLRCRedundancy = *NewLRCRedundancy(2, 4, []int{2}, 1024*1024*5) + +type LRCRedundancy struct { + serder.Metadata `union:"lrc"` + Type string `json:"type"` + K int `json:"k"` + N int `json:"n"` + Groups []int `json:"groups"` + ChunkSize int `json:"chunkSize"` +} + +func NewLRCRedundancy(k int, n int, groups []int, chunkSize int) *LRCRedundancy { + return &LRCRedundancy{ + Type: "lrc", + K: k, + N: n, + Groups: groups, + ChunkSize: chunkSize, + } +} +func (b *LRCRedundancy) Value() (driver.Value, error) { + return serder.ObjectToJSONEx[Redundancy](b) +} + +// 判断指定块属于哪个组。如果都不属于,则返回-1。 +func (b *LRCRedundancy) FindGroup(idx int) int { + if idx >= b.N-len(b.Groups) { + return idx - (b.N - len(b.Groups)) + } + + for i, group := range b.Groups { + if idx < group { + return i + } + idx -= group + } + + return -1 +} + +// M = N - len(Groups),即数据块+校验块的总数,不包括组校验块。 +func (b *LRCRedundancy) M() int { + return b.N - len(b.Groups) +} + +func (b *LRCRedundancy) GetGroupElements(grp int) []int { + var idxes []int + + grpStart := 0 + for i := 0; i < grp; i++ { + grpStart += b.Groups[i] + } + + for i := 0; i < b.Groups[grp]; i++ { + idxes = append(idxes, grpStart+i) + } + + idxes = append(idxes, b.N-len(b.Groups)+grp) + return idxes +} + const ( PackageStateNormal = "Normal" PackageStateDeleted = "Deleted" diff --git a/utils/lo2/lo.go b/utils/lo2/lo.go index 69e6150..056f7e7 100644 --- a/utils/lo2/lo.go +++ b/utils/lo2/lo.go @@ -11,6 +11,12 @@ func Remove[T comparable](arr []T, item T) []T { return RemoveAt(arr, index) } +func RemoveAll[T comparable](arr []T, item T) []T { + return lo.Filter(arr, func(i T, idx int) bool { + return i != item + }) +} + func RemoveAt[T any](arr []T, index int) []T { if index >= len(arr) { return arr diff --git a/utils/serder/json/config.go b/utils/serder/json/config.go new file mode 100644 index 0000000..f1479fe --- /dev/null +++ b/utils/serder/json/config.go @@ -0,0 +1,72 @@ +package json + +import ( + "reflect" + + jsoniter "github.com/json-iterator/go" + "gitlink.org.cn/cloudream/common/pkgs/types" +) + +type Config struct { + unionHandler *UnionHandler + exts []jsoniter.Extension +} + +func New() *Config { + return &Config{ + unionHandler: &UnionHandler{ + internallyTagged: make(map[reflect.Type]*anyTypeUnionInternallyTagged), + externallyTagged: make(map[reflect.Type]*anyTypeUnionExternallyTagged), + }, + } +} + +func (c *Config) UseUnionInternallyTagged(u *types.AnyTypeUnion, tagField string) *Config { + iu := &anyTypeUnionInternallyTagged{ + Union: u, + TagField: tagField, + TagToType: make(map[string]reflect.Type), + } + + for _, eleType := range u.ElementTypes { + iu.Add(eleType) + } + + c.unionHandler.internallyTagged[u.UnionType] = iu + return c +} + +func (c *Config) UseUnionExternallyTagged(u *types.AnyTypeUnion) *Config { + eu := &anyTypeUnionExternallyTagged{ + Union: u, + TypeNameToType: make(map[string]reflect.Type), + } + + for _, eleType := range u.ElementTypes { + eu.Add(eleType) + } + + c.unionHandler.externallyTagged[u.UnionType] = eu + return c +} + +func (c *Config) UseExtension(ext jsoniter.Extension) *Config { + c.exts = append(c.exts, ext) + return c +} + +func (c *Config) Build() Serder { + cfg := jsoniter.Config{} + api := cfg.Froze() + + api.RegisterExtension(c.unionHandler) + + for _, ext := range c.exts { + api.RegisterExtension(ext) + } + + return Serder{ + cfg: *c, + api: api, + } +} diff --git a/utils/serder/json/json.go b/utils/serder/json/json.go new file mode 100644 index 0000000..efb8b1c --- /dev/null +++ b/utils/serder/json/json.go @@ -0,0 +1,29 @@ +package json + +import ( + "bytes" + + jsoniter "github.com/json-iterator/go" +) + +type Serder struct { + cfg Config + api jsoniter.API +} + +func (s *Serder) Encode(obj any) ([]byte, error) { + buf := new(bytes.Buffer) + + enc := s.api.NewEncoder(buf) + err := enc.Encode(obj) + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +func (s *Serder) Decode(data []byte, obj any) error { + dec := s.api.NewDecoder(bytes.NewReader(data)) + return dec.Decode(&obj) +} diff --git a/utils/serder/json/union_handler.go b/utils/serder/json/union_handler.go new file mode 100644 index 0000000..44bca21 --- /dev/null +++ b/utils/serder/json/union_handler.go @@ -0,0 +1,318 @@ +package json + +import ( + "fmt" + "reflect" + "unsafe" + + jsoniter "github.com/json-iterator/go" + "github.com/modern-go/reflect2" + "gitlink.org.cn/cloudream/common/pkgs/types" + stypes "gitlink.org.cn/cloudream/common/utils/serder/types" + + ref2 "gitlink.org.cn/cloudream/common/utils/reflect2" +) + +type anyTypeUnionExternallyTagged struct { + Union *types.AnyTypeUnion + TypeNameToType map[string]reflect.Type +} + +func (u *anyTypeUnionExternallyTagged) Add(typ reflect.Type) error { + err := u.Union.Add(typ) + if err != nil { + return nil + } + + u.TypeNameToType[makeDerefFullTypeName(typ)] = typ + return nil +} + +type TypeUnionExternallyTagged[T any] struct { + anyTypeUnionExternallyTagged + TUnion *types.TypeUnion[T] +} + +func (u *TypeUnionExternallyTagged[T]) AddT(nilValue T) error { + u.Add(reflect.TypeOf(nilValue)) + return nil +} + +type anyTypeUnionInternallyTagged struct { + Union *types.AnyTypeUnion + TagField string + TagToType map[string]reflect.Type +} + +func (u *anyTypeUnionInternallyTagged) Add(typ reflect.Type) error { + err := u.Union.Add(typ) + if err != nil { + return nil + } + + // 解引用直到得到结构体类型 + structType := typ + for structType.Kind() == reflect.Pointer { + structType = structType.Elem() + } + + // 要求内嵌Metadata结构体,那么结构体中的字段名就会是Metadata, + field, ok := structType.FieldByName(ref2.TypeNameOf[stypes.Metadata]()) + if !ok { + u.TagToType[makeDerefFullTypeName(structType)] = typ + return nil + } + + // 为防同名,检查类型是不是也是Metadata + if field.Type != ref2.TypeOf[stypes.Metadata]() { + u.TagToType[makeDerefFullTypeName(structType)] = typ + return nil + } + + tag := field.Tag.Get("union") + if tag == "" { + u.TagToType[makeDerefFullTypeName(structType)] = typ + return nil + } + + u.TagToType[tag] = typ + return nil +} + +type TypeUnionInternallyTagged[T any] struct { + anyTypeUnionInternallyTagged + TUnion *types.TypeUnion[T] +} + +func (u *TypeUnionInternallyTagged[T]) AddT(nilValue T) error { + u.Add(reflect.TypeOf(nilValue)) + return nil +} + +type UnionHandler struct { + internallyTagged map[reflect.Type]*anyTypeUnionInternallyTagged + externallyTagged map[reflect.Type]*anyTypeUnionExternallyTagged +} + +func (h *UnionHandler) UpdateStructDescriptor(structDescriptor *jsoniter.StructDescriptor) { + +} + +func (h *UnionHandler) CreateMapKeyDecoder(typ reflect2.Type) jsoniter.ValDecoder { + return nil +} + +func (h *UnionHandler) CreateMapKeyEncoder(typ reflect2.Type) jsoniter.ValEncoder { + return nil +} + +func (h *UnionHandler) CreateDecoder(typ reflect2.Type) jsoniter.ValDecoder { + typ1 := typ.Type1() + if it, ok := h.internallyTagged[typ1]; ok { + return &InternallyTaggedDecoder{ + union: it, + } + } + + if et, ok := h.externallyTagged[typ1]; ok { + return &ExternallyTaggedDecoder{ + union: et, + } + } + + return nil +} + +func (h *UnionHandler) CreateEncoder(typ reflect2.Type) jsoniter.ValEncoder { + typ1 := typ.Type1() + if it, ok := h.internallyTagged[typ1]; ok { + return &InternallyTaggedEncoder{ + union: it, + } + } + + if et, ok := h.externallyTagged[typ1]; ok { + return &ExternallyTaggedEncoder{ + union: et, + } + } + return nil +} + +func (h *UnionHandler) DecorateDecoder(typ reflect2.Type, decoder jsoniter.ValDecoder) jsoniter.ValDecoder { + return decoder +} + +func (h *UnionHandler) DecorateEncoder(typ reflect2.Type, encoder jsoniter.ValEncoder) jsoniter.ValEncoder { + return encoder +} + +// 以下Encoder/Decoder都是在传入类型/目标类型是TypeUnion的基类(UnionType)时使用 +type InternallyTaggedEncoder struct { + union *anyTypeUnionInternallyTagged +} + +func (e *InternallyTaggedEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + +func (e *InternallyTaggedEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { + var val any + + if e.union.Union.UnionType.NumMethod() == 0 { + // 无方法的interface底层都是eface结构体,所以可以直接转*any + val = *(*any)(ptr) + } else { + // 有方法的interface底层都是iface结构体,可以将其转成eface,转换后不损失类型信息 + val = reflect2.IFaceToEFace(ptr) + } + + // 可以考虑检查一下Type字段有没有赋值,没有赋值则将其赋值为union Tag指定的值 + stream.WriteVal(val) +} + +type InternallyTaggedDecoder struct { + union *anyTypeUnionInternallyTagged +} + +func (e *InternallyTaggedDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { + nextTokenKind := iter.WhatIsNext() + if nextTokenKind == jsoniter.NilValue { + iter.Skip() + return + } + + raw := iter.ReadAny() + if raw.LastError() != nil { + iter.ReportError("decode TaggedUnionType", "getting object raw:"+raw.LastError().Error()) + return + } + + tagField := raw.Get(e.union.TagField) + if tagField.LastError() != nil { + iter.ReportError("decode TaggedUnionType", "getting type tag field:"+tagField.LastError().Error()) + return + } + + typeTag := tagField.ToString() + if typeTag == "" { + iter.ReportError("decode TaggedUnionType", "type tag is empty") + return + } + + typ, ok := e.union.TagToType[typeTag] + if !ok { + iter.ReportError("decode TaggedUnionType", fmt.Sprintf("unknow type tag %s in union %s", typeTag, e.union.Union.UnionType.Name())) + return + } + + // 如果目标类型已经是个指针类型*T,那么在New的时候就需要使用T, + // 否则New出来的是会是**T,这将导致后续的反序列化出问题 + if typ.Kind() == reflect.Pointer { + val := reflect.New(typ.Elem()) + raw.ToVal(val.Interface()) + + retVal := reflect.NewAt(e.union.Union.UnionType, ptr) + retVal.Elem().Set(val) + + } else { + val := reflect.New(typ) + raw.ToVal(val.Interface()) + + retVal := reflect.NewAt(e.union.Union.UnionType, ptr) + retVal.Elem().Set(val.Elem()) + } +} + +type ExternallyTaggedEncoder struct { + union *anyTypeUnionExternallyTagged +} + +func (e *ExternallyTaggedEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + +func (e *ExternallyTaggedEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { + var val any + + if e.union.Union.UnionType.NumMethod() == 0 { + // 无方法的interface底层都是eface结构体,所以可以直接转*any + val = *(*any)(ptr) + } else { + // 有方法的interface底层都是iface结构体,可以将其转成eface,转换后不损失类型信息 + val = reflect2.IFaceToEFace(ptr) + } + + if val == nil { + stream.WriteNil() + return + } + + stream.WriteObjectStart() + valType := ref2.TypeOfValue(val) + if !e.union.Union.Include(valType) { + stream.Error = fmt.Errorf("type %v is not in union %v", valType, e.union.Union.UnionType) + return + } + stream.WriteObjectField(makeDerefFullTypeName(valType)) + stream.WriteVal(val) + stream.WriteObjectEnd() +} + +type ExternallyTaggedDecoder struct { + union *anyTypeUnionExternallyTagged +} + +func (e *ExternallyTaggedDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { + nextTkType := iter.WhatIsNext() + + if nextTkType == jsoniter.NilValue { + iter.Skip() + return + } + + if nextTkType != jsoniter.ObjectValue { + iter.ReportError("decode UnionType", fmt.Sprintf("unknow next token type %v", nextTkType)) + return + } + + typeStr := iter.ReadObject() + if typeStr == "" { + iter.ReportError("decode UnionType", "type string is empty") + } + + typ, ok := e.union.TypeNameToType[typeStr] + if !ok { + iter.ReportError("decode UnionType", fmt.Sprintf("unknow type string %s in union %v", typeStr, e.union.Union.UnionType)) + return + } + + // 如果目标类型已经是个指针类型*T,那么在New的时候就需要使用T, + // 否则New出来的是会是**T,这将导致后续的反序列化出问题 + if typ.Kind() == reflect.Pointer { + val := reflect.New(typ.Elem()) + iter.ReadVal(val.Interface()) + + retVal := reflect.NewAt(e.union.Union.UnionType, ptr) + retVal.Elem().Set(val) + + } else { + val := reflect.New(typ) + iter.ReadVal(val.Interface()) + + retVal := reflect.NewAt(e.union.Union.UnionType, ptr) + retVal.Elem().Set(val.Elem()) + } + + if iter.ReadObject() != "" { + iter.ReportError("decode UnionType", "there should be only one fields in the json object") + } +} + +func makeDerefFullTypeName(typ reflect.Type) string { + realType := typ + for realType.Kind() == reflect.Pointer { + realType = realType.Elem() + } + return fmt.Sprintf("%s.%s", realType.PkgPath(), realType.Name()) +} diff --git a/utils/serder/types.go b/utils/serder/types.go index 6c93b71..c411250 100644 --- a/utils/serder/types.go +++ b/utils/serder/types.go @@ -4,9 +4,11 @@ import ( "fmt" "strconv" "time" + + "gitlink.org.cn/cloudream/common/utils/serder/types" ) -type Metadata struct{} +type Metadata = types.Metadata type TimestampSecond time.Time diff --git a/utils/serder/types/types.go b/utils/serder/types/types.go new file mode 100644 index 0000000..0046926 --- /dev/null +++ b/utils/serder/types/types.go @@ -0,0 +1,3 @@ +package types + +type Metadata struct{}