diff --git a/pkgs/mq/message.go b/pkgs/mq/message.go index 8177982..36a7ad3 100644 --- a/pkgs/mq/message.go +++ b/pkgs/mq/message.go @@ -1,16 +1,9 @@ package mq import ( - "bytes" - "fmt" - "reflect" - "strings" - "unsafe" - - jsoniter "github.com/json-iterator/go" - "github.com/modern-go/reflect2" "gitlink.org.cn/cloudream/common/pkgs/types" myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/serder" ) const ( @@ -86,298 +79,20 @@ func MakeHeartbeatMessage() Message { return msg } -type TypeUnionWithTypeName struct { - Union types.TypeUnion - TypeNameToType map[string]myreflect.Type -} - -func (u *TypeUnionWithTypeName) Register(typ myreflect.Type) { - u.Union.ElementTypes = append(msgBodyTypeUnion.Union.ElementTypes, typ) - u.TypeNameToType[makeFullTypeName(typ)] = typ -} - -var msgBodyTypeUnion *TypeUnionWithTypeName +var msgBodyTypeUnion = serder.UseTypeUnionExternallyTagged(types.Ref(types.NewTypeUnion[MessageBody]())) // 所有新定义的Message都需要在init中调用此函数 func RegisterMessage[T MessageBody]() { - msgBodyTypeUnion.Register(myreflect.TypeOf[T]()) -} - -// 在序列化结构体中包含的UnionType类型字段时,会将字段值的实际类型保存在序列化后的结果中。 -// 在反序列化时,会根据类型信息重建原本的字段值。 -// 注:TypeUnion.UnionType必须是一个interface -func RegisterUnionType(union types.TypeUnion) *TypeUnionWithTypeName { - myUnion := &TypeUnionWithTypeName{ - Union: union, - TypeNameToType: make(map[string]reflect.Type), - } - - for _, typ := range union.ElementTypes { - myUnion.TypeNameToType[makeFullTypeName(typ)] = typ - } - - if union.UnionType.NumMethod() == 0 { - registerForEFace(myUnion) - } else { - registerForIFace(myUnion) - } - - return myUnion -} - -// 无方法的interface类型 -func registerForEFace(myUnion *TypeUnionWithTypeName) { - jsoniter.RegisterTypeEncoderFunc(myUnion.Union.UnionType.String(), - func(ptr unsafe.Pointer, stream *jsoniter.Stream) { - // 无方法的interface底层数据结构都是eface类型,所以可以直接转*any - val := *(*any)(ptr) - if val != nil { - stream.WriteArrayStart() - - valType := myreflect.TypeOfValue(val) - if !myUnion.Union.Include(valType) { - stream.Error = fmt.Errorf("type %v is not in union %v", valType, myUnion.Union.UnionType) - return - } - - stream.WriteString(makeFullTypeName(valType)) - stream.WriteRaw(",") - stream.WriteVal(val) - stream.WriteArrayEnd() - } else { - stream.WriteNil() - } - }, - func(p unsafe.Pointer) bool { - return false - }) - - jsoniter.RegisterTypeDecoderFunc(myUnion.Union.UnionType.String(), - func(ptr unsafe.Pointer, iter *jsoniter.Iterator) { - // 无方法的interface底层都是eface结构体,所以可以直接转*any - vp := (*any)(ptr) - - nextTkType := iter.WhatIsNext() - if nextTkType == jsoniter.NilValue { - iter.ReadNil() - *vp = nil - - } else if nextTkType == jsoniter.ArrayValue { - iter.ReadArray() - typeStr := iter.ReadString() - iter.ReadArray() - - typ, ok := myUnion.TypeNameToType[typeStr] - if !ok { - iter.ReportError("decode UnionType", fmt.Sprintf("unknow type string %s under %v", typeStr, myUnion.Union.UnionType)) - return - } - - // 如果目标类型已经是个指针类型*T,那么在New的时候就需要使用T, - // 否则New出来的是会是**T,这将导致后续的反序列化出问题 - if typ.Kind() == reflect.Pointer { - val := reflect.New(typ.Elem()) - iter.ReadVal(val.Interface()) - *vp = val.Interface() - - } else { - val := reflect.New(typ) - iter.ReadVal(val.Interface()) - *vp = val.Elem().Interface() - } - - iter.ReadArray() - } else { - iter.ReportError("decode UnionType", fmt.Sprintf("unknow next token type %v", nextTkType)) - return - } - }) -} - -// 有方法的interface类型 -func registerForIFace(myUnion *TypeUnionWithTypeName) { - jsoniter.RegisterTypeEncoderFunc(myUnion.Union.UnionType.String(), - func(ptr unsafe.Pointer, stream *jsoniter.Stream) { - // 有方法的interface底层都是iface结构体,可以将其转成eface,转换后不损失类型信息 - val := reflect2.IFaceToEFace(ptr) - if val != nil { - stream.WriteArrayStart() - - valType := myreflect.TypeOfValue(val) - if !myUnion.Union.Include(valType) { - stream.Error = fmt.Errorf("type %v is not in union %v", valType, myUnion.Union.UnionType) - return - } - - stream.WriteString(makeFullTypeName(valType)) - stream.WriteRaw(",") - stream.WriteVal(val) - stream.WriteArrayEnd() - } else { - stream.WriteNil() - } - }, - func(p unsafe.Pointer) bool { - return false - }) - - jsoniter.RegisterTypeDecoderFunc(myUnion.Union.UnionType.String(), - func(ptr unsafe.Pointer, iter *jsoniter.Iterator) { - - nextTkType := iter.WhatIsNext() - if nextTkType == jsoniter.NilValue { - iter.ReadNil() - - } else if nextTkType == jsoniter.ArrayValue { - iter.ReadArray() - typeStr := iter.ReadString() - iter.ReadArray() - - typ, ok := myUnion.TypeNameToType[typeStr] - if !ok { - iter.ReportError("decode UnionType", fmt.Sprintf("unknow type string %s under %v", typeStr, myUnion.Union.UnionType)) - return - } - - // 如果目标类型已经是个指针类型*T,那么在New的时候就需要使用T, - // 否则New出来的是会是**T,这将导致后续的反序列化出问题 - if typ.Kind() == reflect.Pointer { - val := reflect.New(typ.Elem()) - iter.ReadVal(val.Interface()) - - retVal := reflect.NewAt(myUnion.Union.UnionType, ptr) - retVal.Elem().Set(val) - - } else { - val := reflect.New(typ) - iter.ReadVal(val.Interface()) - - retVal := reflect.NewAt(myUnion.Union.UnionType, ptr) - retVal.Elem().Set(val.Elem()) - } - - iter.ReadArray() - } else { - iter.ReportError("decode UnionType", fmt.Sprintf("unknow next token type %v", nextTkType)) - return - } - }) -} - -func makeFullTypeName(typ myreflect.Type) string { - refs := 0 - - realType := typ - for realType.Kind() == reflect.Pointer { - refs++ - realType = realType.Elem() - } - - return fmt.Sprintf("%s%s.%s", strings.Repeat("*", refs), realType.PkgPath(), realType.Name()) -} - -/* -// 如果对一个类型T调用了此函数,那么在序列化结构体中包含的T类型字段时, -// 会将字段值的实际类型保存在序列化后的结果中 -// 在反序列化时,会根据类型信息重建原本的字段值。 -// -// 只会处理types指定的类型。 -func RegisterTypeSet[T any](types ...myreflect.Type) *serder.UnionTypeInfo { - eleTypes := serder.NewTypeNameResolver(true) - set := serder.UnionTypeInfo{ - UnionType: myreflect.TypeOf[T](), - ElementTypes: eleTypes, - } - - for _, t := range types { - eleTypes.Register(t) + err := msgBodyTypeUnion.Add(myreflect.TypeOf[T]()) + if err != nil { + panic(err) } - - TODO 暂时保留这一段代码,如果RegisterUnionType中的非泛型版本出了问题,则重新使用这一部分的代码 - unionTypes[set.UnionType] = set - - jsoniter.RegisterTypeEncoderFunc(myreflect.TypeOf[T]().String(), - func(ptr unsafe.Pointer, stream *jsoniter.Stream) { - val := *((*T)(ptr)) - var ifVal any = val - - if ifVal != nil { - stream.WriteArrayStart() - typeStr, err := set.ElementTypes.TypeToString(myreflect.TypeOfValue(val)) - if err != nil { - stream.Error = err - return - } - stream.WriteString(typeStr) - stream.WriteRaw(",") - stream.WriteVal(val) - stream.WriteArrayEnd() - } else { - stream.WriteNil() - } - }, - func(p unsafe.Pointer) bool { - return false - }) - - jsoniter.RegisterTypeDecoderFunc(myreflect.TypeOf[T]().String(), - func(ptr unsafe.Pointer, iter *jsoniter.Iterator) { - vp := (*T)(ptr) - - nextTkType := iter.WhatIsNext() - if nextTkType == jsoniter.NilValue { - iter.ReadNil() - var zero T - *vp = zero - } else if nextTkType == jsoniter.ArrayValue { - iter.ReadArray() - typeStr := iter.ReadString() - iter.ReadArray() - - typ, err := set.ElementTypes.StringToType(typeStr) - if err != nil { - iter.ReportError("get type from string", err.Error()) - return - } - - val := reflect.New(typ) - iter.ReadVal(val.Interface()) - *vp = val.Elem().Interface().(T) - - iter.ReadArray() - } else { - iter.ReportError("parse TypeSet field", fmt.Sprintf("unknow next token type %v", nextTkType)) - return - } - }) - RegisterUnionType(serder.NewTypeUnion[T]("", serder.NewTypeNameResolver(true))) - return &set } -*/ func Serialize(msg Message) ([]byte, error) { - buf := bytes.NewBuffer(nil) - enc := jsoniter.NewEncoder(buf) - err := enc.Encode(msg) - if err != nil { - return nil, err - } - - return buf.Bytes(), nil + return serder.ObjectToJSONEx(msg) } func Deserialize(data []byte) (*Message, error) { - dec := jsoniter.NewDecoder(bytes.NewBuffer(data)) - - var msg Message - err := dec.Decode(&msg) - if err != nil { - return nil, err - } - - return &msg, nil -} - -func init() { - msgBodyTypeUnion = RegisterUnionType(types.NewTypeUnion[MessageBody]()) + return serder.JSONToObjectEx[*Message](data) } diff --git a/pkgs/mq/message_test.go b/pkgs/mq/message_test.go index 631d9c5..b2482fa 100644 --- a/pkgs/mq/message_test.go +++ b/pkgs/mq/message_test.go @@ -3,7 +3,6 @@ package mq import ( "bytes" "fmt" - "reflect" "testing" "unsafe" @@ -11,6 +10,7 @@ import ( . "github.com/smartystreets/goconvey/convey" "gitlink.org.cn/cloudream/common/pkgs/types" myreflect "gitlink.org.cn/cloudream/common/utils/reflect" + "gitlink.org.cn/cloudream/common/utils/serder" ) func TestMessage(t *testing.T) { @@ -146,7 +146,7 @@ func TestMessage(t *testing.T) { Value MyTypeUnion } RegisterMessage[*Body]() - RegisterUnionType(types.NewTypeUnion[MyTypeUnion]((*EleType1)(nil))) + serder.UseTypeUnionExternallyTagged(types.Ref(types.NewTypeUnion[MyTypeUnion]((*EleType1)(nil)))) msg := MakeAppDataMessage(&Body{Value: &EleType1{ Value: 1, @@ -174,7 +174,7 @@ func TestMessage(t *testing.T) { Value MyTypeUnion } RegisterMessage[*Body]() - RegisterUnionType(types.NewTypeUnion[MyTypeUnion]((*EleType1)(nil))) + serder.UseTypeUnionExternallyTagged(types.Ref(types.NewTypeUnion[MyTypeUnion]((*EleType1)(nil)))) msg := MakeAppDataMessage(&Body{Value: &EleType1{ Value: 1, @@ -200,7 +200,7 @@ func TestMessage(t *testing.T) { Value MyTypeUnion } RegisterMessage[*Body]() - RegisterUnionType(types.NewTypeUnion[MyTypeUnion]()) + serder.UseTypeUnionExternallyTagged(types.Ref(types.NewTypeUnion[MyTypeUnion]())) msg := MakeAppDataMessage(&Body{Value: nil}) data, err := Serialize(msg) @@ -220,26 +220,10 @@ func TestMessage(t *testing.T) { Value MyTypeUnion } RegisterMessage[*Body]() - RegisterUnionType(types.NewTypeUnion[MyTypeUnion]()) + serder.UseTypeUnionExternallyTagged(types.Ref(types.NewTypeUnion[MyTypeUnion]())) msg := MakeAppDataMessage(&Body{Value: &struct{}{}}) _, err := Serialize(msg) So(err, ShouldNotBeNil) }) - -} - -func TestMakeFullTypeName(t *testing.T) { - Convey("指针类型", t, func() { - type St struct{} - - typeName := makeFullTypeName(myreflect.TypeOf[St]()) - So(typeName, ShouldEqual, "gitlink.org.cn/cloudream/common/pkgs/mq.St") - - typeName = makeFullTypeName(reflect.PointerTo(myreflect.TypeOf[St]())) - So(typeName, ShouldEqual, "*gitlink.org.cn/cloudream/common/pkgs/mq.St") - - typeName = makeFullTypeName(reflect.PointerTo(reflect.PointerTo(myreflect.TypeOf[St]()))) - So(typeName, ShouldEqual, "**gitlink.org.cn/cloudream/common/pkgs/mq.St") - }) } diff --git a/pkgs/types/types.go b/pkgs/types/types.go new file mode 100644 index 0000000..d86e228 --- /dev/null +++ b/pkgs/types/types.go @@ -0,0 +1,5 @@ +package types + +func Ref[T any](val T) *T { + return &val +} diff --git a/pkgs/types/union.go b/pkgs/types/union.go index 10eeffd..ec4ecce 100644 --- a/pkgs/types/union.go +++ b/pkgs/types/union.go @@ -1,6 +1,7 @@ package types import ( + "fmt" "reflect" myreflect "gitlink.org.cn/cloudream/common/utils/reflect" @@ -37,6 +38,11 @@ func (u *TypeUnion) Include(typ myreflect.Type) bool { return false } -func (u *TypeUnion) Add(typ myreflect.Type) { +func (u *TypeUnion) Add(typ myreflect.Type) error { + if !typ.AssignableTo(u.UnionType) { + return fmt.Errorf("type is not assignable to union type") + } + u.ElementTypes = append(u.ElementTypes, typ) + return nil } diff --git a/sdks/scheduler/models.go b/sdks/scheduler/models.go index 65259ba..29a2594 100644 --- a/sdks/scheduler/models.go +++ b/sdks/scheduler/models.go @@ -1,7 +1,6 @@ package schsdk import ( - "gitlink.org.cn/cloudream/common/pkgs/mq" "gitlink.org.cn/cloudream/common/pkgs/types" cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" "gitlink.org.cn/cloudream/common/utils/serder" @@ -38,8 +37,7 @@ var JobInfoTypeUnion = types.NewTypeUnion[JobInfo]( (*NormalJobInfo)(nil), (*ResourceJobInfo)(nil), ) -var _ = serder.RegisterNewTaggedTypeUnion(JobInfoTypeUnion, "Type", "type") -var _ = mq.RegisterUnionType(JobInfoTypeUnion) +var _ = serder.UseTypeUnionInternallyTagged(&JobInfoTypeUnion, "type") type JobInfoBase struct { LocalJobID string `json:"localJobID"` @@ -50,16 +48,18 @@ func (i *JobInfoBase) GetLocalJobID() string { } type NormalJobInfo struct { + serder.Metadata `union:"Normal"` JobInfoBase - Type string `json:"type" union:"Normal"` + Type string `json:"type"` Files JobFilesInfo `json:"files"` Runtime JobRuntimeInfo `json:"runtime"` Resources JobResourcesInfo `json:"resources"` } type ResourceJobInfo struct { + serder.Metadata `union:"Resource"` JobInfoBase - Type string `json:"type" union:"Resource"` + Type string `json:"type"` BucketID int64 `json:"bucketID"` Redundancy cdssdk.TypedRedundancyInfo `json:"redundancy"` TargetLocalJobID string `json:"targetLocalJobID"` @@ -81,34 +81,37 @@ var FileInfoTypeUnion = types.NewTypeUnion[JobFileInfo]( (*ResourceJobFileInfo)(nil), (*ImageJobFileInfo)(nil), ) -var _ = serder.RegisterNewTaggedTypeUnion(FileInfoTypeUnion, "Type", "type") -var _ = mq.RegisterUnionType(FileInfoTypeUnion) +var _ = serder.UseTypeUnionInternallyTagged(&FileInfoTypeUnion, "type") type JobFileInfoBase struct{} func (i *JobFileInfoBase) Noop() {} type PackageJobFileInfo struct { + serder.Metadata `union:"Package"` JobFileInfoBase - Type string `json:"type" union:"Package"` + Type string `json:"type"` PackageID int64 `json:"packageID"` } type LocalJobFileInfo struct { + serder.Metadata `union:"LocalFile"` JobFileInfoBase - Type string `json:"type" union:"LocalFile"` + Type string `json:"type"` LocalPath string `json:"localPath"` } type ResourceJobFileInfo struct { + serder.Metadata `union:"Resource"` JobFileInfoBase - Type string `json:"type" union:"Resource"` + Type string `json:"type"` ResourceLocalJobID string `json:"resourceLocalJobID"` } type ImageJobFileInfo struct { + serder.Metadata `union:"Image"` JobFileInfoBase - Type string `json:"type" union:"Image"` + Type string `json:"type"` ImageID ImageID `json:"imageID"` } @@ -133,21 +136,6 @@ type JobResourcesInfo struct { Memory int64 `json:"memory"` } -func JobSetInfoFromJSON(data []byte) (*JobSetInfo, error) { - mp := make(map[string]any) - if err := serder.JSONToObject(data, &mp); err != nil { - return nil, err - } - - var ret JobSetInfo - err := serder.MapToObject(mp, &ret) - if err != nil { - return nil, err - } - - return &ret, nil -} - type JobSetFilesUploadScheme struct { LocalFileSchemes []LocalFileUploadScheme `json:"localFileUploadSchemes"` } diff --git a/sdks/unifyops/models.go b/sdks/unifyops/models.go index 655738c..6fe3d7b 100644 --- a/sdks/unifyops/models.go +++ b/sdks/unifyops/models.go @@ -1,7 +1,6 @@ package uopsdk import ( - "gitlink.org.cn/cloudream/common/pkgs/mq" "gitlink.org.cn/cloudream/common/pkgs/types" "gitlink.org.cn/cloudream/common/utils/serder" ) @@ -37,8 +36,7 @@ var ResourceDataTypeUnion = types.NewTypeUnion[ResourceData]( (*StorageResourceData)(nil), (*MemoryResourceData)(nil), ) -var _ = serder.RegisterNewTaggedTypeUnion(ResourceDataTypeUnion, "Name", "name") -var _ = mq.RegisterUnionType(ResourceDataTypeUnion) +var _ = serder.UseTypeUnionInternallyTagged(&ResourceDataTypeUnion, "name") type ResourceDataBase struct{} @@ -50,8 +48,9 @@ type UnitValue[T any] struct { } type CPUResourceData struct { + serder.Metadata `union:"CPU"` ResourceDataBase - Name ResourceType `json:"name" union:"CPU"` + Name ResourceType `json:"name"` Total UnitValue[int64] `json:"total"` Available UnitValue[int64] `json:"available"` } @@ -65,8 +64,9 @@ func NewCPUResourceData(total UnitValue[int64], available UnitValue[int64]) *CPU } type NPUResourceData struct { + serder.Metadata `union:"NPU"` ResourceDataBase - Name ResourceType `json:"name" union:"NPU"` + Name ResourceType `json:"name"` Total UnitValue[int64] `json:"total"` Available UnitValue[int64] `json:"available"` } @@ -80,8 +80,9 @@ func NewNPUResourceData(total UnitValue[int64], available UnitValue[int64]) *NPU } type GPUResourceData struct { + serder.Metadata `union:"GPU"` ResourceDataBase - Name ResourceType `json:"name" union:"GPU"` + Name ResourceType `json:"name"` Total UnitValue[int64] `json:"total"` Available UnitValue[int64] `json:"available"` } @@ -95,8 +96,9 @@ func NewGPUResourceData(total UnitValue[int64], available UnitValue[int64]) *GPU } type MLUResourceData struct { + serder.Metadata `union:"MLU"` ResourceDataBase - Name ResourceType `json:"name" union:"MLU"` + Name ResourceType `json:"name"` Total UnitValue[int64] `json:"total"` Available UnitValue[int64] `json:"available"` } @@ -110,8 +112,9 @@ func NewMLUResourceData(total UnitValue[int64], available UnitValue[int64]) *MLU } type StorageResourceData struct { + serder.Metadata `union:"STORAGE"` ResourceDataBase - Name ResourceType `json:"name" union:"STORAGE"` + Name ResourceType `json:"name"` Total UnitValue[float64] `json:"total"` Available UnitValue[float64] `json:"available"` } @@ -125,8 +128,9 @@ func NewStorageResourceData(total UnitValue[float64], available UnitValue[float6 } type MemoryResourceData struct { + serder.Metadata `union:"MEMORY"` ResourceDataBase - Name ResourceType `json:"name" union:"MEMORY"` + Name ResourceType `json:"name"` Total UnitValue[float64] `json:"total"` Available UnitValue[float64] `json:"available"` } diff --git a/utils/serder/serder.go b/utils/serder/serder.go index 1d46d0e..81b36b5 100644 --- a/utils/serder/serder.go +++ b/utils/serder/serder.go @@ -1,19 +1,64 @@ package serder import ( + "bytes" "encoding/json" "fmt" "io" "reflect" "strings" - "gitlink.org.cn/cloudream/common/pkgs/types" + jsoniter "github.com/json-iterator/go" ) +var unionHandler = UnionHandler{ + internallyTagged: make(map[reflect.Type]*TypeUnionInternallyTagged), + externallyTagged: make(map[reflect.Type]*TypeUnionExternallyTagged), +} + +var defaultAPI = func() jsoniter.API { + api := jsoniter.Config{ + EscapeHTML: true, + }.Froze() + + api.RegisterExtension(&unionHandler) + return api +}() + +// 将对象转为JSON字符串。支持TypeUnion。 +func ObjectToJSONEx[T any](obj T) ([]byte, error) { + buf := new(bytes.Buffer) + enc := defaultAPI.NewEncoder(buf) + // 这里使用&obj而直接不使用obj的原因是,Encode的形参类型为any, + // 如果T是一个interface类型,将obj传递进去后,内部拿到的类型将会是obj的实际类型, + // 使用&obj,那么内部拿到的将会是*T类型,通过一层一层解引用查找Encoder时,能找到T对应的TypeUnion + err := enc.Encode(&obj) + + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// 将JSON字符串转为对象。支持TypeUnion。 +func JSONToObjectEx[T any](data []byte) (T, error) { + var ret T + dec := defaultAPI.NewDecoder(bytes.NewReader(data)) + err := dec.Decode(&ret) + if err != nil { + return ret, err + } + + return ret, nil +} + +// 将对象转为JSON字符串。如果需要支持解析TypeUnion类型,则使用"Ex"结尾的同名函数。 func ObjectToJSON(obj any) ([]byte, error) { return json.Marshal(obj) } +// 将对象转为JSON字符串。如果需要支持解析TypeUnion类型,则使用"Ex"结尾的同名函数。 func ObjectToJSONStream(obj any) io.ReadCloser { pr, pw := io.Pipe() enc := json.NewEncoder(pw) @@ -30,10 +75,12 @@ func ObjectToJSONStream(obj any) io.ReadCloser { return pr } +// 将JSON字符串转为对象。如果需要支持解析TypeUnion类型,则使用"Ex"结尾的同名函数。 func JSONToObject(data []byte, obj any) error { return json.Unmarshal(data, obj) } +// 将JSON字符串转为对象。如果需要支持解析TypeUnion类型,则使用"Ex"结尾的同名函数。 func JSONToObjectStream(str io.Reader, obj any) error { dec := json.NewDecoder(str) err := dec.Decode(obj) @@ -49,86 +96,21 @@ type TypeResolver interface { StringToType(typeStr string) (reflect.Type, error) } -var registeredTaggedTypeUnions []*TaggedUnionType - -type TaggedUnionType struct { - Union types.TypeUnion - StrcutTagField string - JSONTagField string - TagToType map[string]reflect.Type -} - -// 根据指定的字段的值来区分不同的类型。值可以通过在字段增加“union”Tag来指定。如果没有指定,则使用类型名。 -func NewTaggedTypeUnion(union types.TypeUnion, structTagField string, jsonTagField string) *TaggedUnionType { - tagToType := make(map[string]reflect.Type) - - for _, typ := range union.ElementTypes { - if structTagField == "" { - tagToType[typ.Name()] = typ - continue - } - - // 如果ElementType是一个指向结构体的指针,那么就遍历结构体的字段(解引用) - structType := typ - for structType.Kind() == reflect.Pointer { - structType = structType.Elem() - } - - field, ok := structType.FieldByName(structTagField) - if !ok { - tagToType[typ.Name()] = typ - continue - } - - tag := field.Tag.Get("union") - if tag == "" { - tagToType[typ.Name()] = typ - continue - } - - tagToType[tag] = typ - } - - return &TaggedUnionType{ - Union: union, - StrcutTagField: structTagField, - JSONTagField: jsonTagField, - TagToType: tagToType, - } -} - -// 注册一个TaggedTypeUnion -func RegisterTaggedTypeUnion(union *TaggedUnionType) *TaggedUnionType { - registeredTaggedTypeUnions = append(registeredTaggedTypeUnions, union) - return union -} - -// 创建并注册一个TaggedTypeUnion -func RegisterNewTaggedTypeUnion(union types.TypeUnion, structTagField string, jsonTagField string) *TaggedUnionType { - taggedUnion := NewTaggedTypeUnion(union, structTagField, jsonTagField) - RegisterTaggedTypeUnion(taggedUnion) - return taggedUnion -} - type MapToObjectOption struct { - UnionTypes []*TaggedUnionType // 转换过程中遇到这些类型时,会依据指定的字段的值,来决定转换后的实际类型 - NoRegisteredUnionTypes bool // 是否不使用全局注册的UnionType + NoRegisteredUnionTypes bool // 是否不使用全局注册的UnionType } +// TODO 使用这个函数来处理TypeUnion的地方都可以直接使用Ex系列的函数 func MapToObject(m map[string]any, obj any, opt ...MapToObjectOption) error { var op MapToObjectOption if len(opt) > 0 { op = opt[0] } - unionTypeMapping := make(map[reflect.Type]*TaggedUnionType) - - for _, u := range op.UnionTypes { - unionTypeMapping[u.Union.UnionType] = u - } + unionTypeMapping := make(map[reflect.Type]*TypeUnionInternallyTagged) if !op.NoRegisteredUnionTypes { - for _, u := range registeredTaggedTypeUnions { + for _, u := range unionHandler.internallyTagged { unionTypeMapping[u.Union.UnionType] = u } } @@ -142,14 +124,14 @@ func MapToObject(m map[string]any, obj any, opt ...MapToObjectOption) error { } mp := from.Interface().(map[string]any) - tag, ok := mp[info.JSONTagField] + tag, ok := mp[info.TagField] if !ok { - return nil, fmt.Errorf("converting to %v: no tag field %s in map", toType, info.JSONTagField) + return nil, fmt.Errorf("converting to %v: no tag field %s in map", toType, info.TagField) } tagStr, ok := tag.(string) if !ok { - return nil, fmt.Errorf("converting to %v: tag field %s value is %v, which is not a string", toType, info.JSONTagField, tag) + return nil, fmt.Errorf("converting to %v: tag field %s value is %v, which is not a string", toType, info.TagField, tag) } eleType, ok := info.TagToType[tagStr] diff --git a/utils/serder/serder_test.go b/utils/serder/serder_test.go index ba3bf6d..e2c4b8c 100644 --- a/utils/serder/serder_test.go +++ b/utils/serder/serder_test.go @@ -363,13 +363,15 @@ func Test_MapToObject(t *testing.T) { type UnionType interface{} type EleType1 struct { - Type EleType `json:"type" union:"1"` - Value1 string `json:"value1"` + Metadata `union:"1"` + Type EleType `json:"type"` + Value1 string `json:"value1"` } type EleType2 struct { - Type EleType `json:"type" union:"2"` - Value2 int `json:"value2"` + Metadata `union:"2"` + Type EleType `json:"type"` + Value2 int `json:"value2"` } type St struct { @@ -390,17 +392,12 @@ func Test_MapToObject(t *testing.T) { } var ret St - err := MapToObject(mp, &ret, MapToObjectOption{ - UnionTypes: []*TaggedUnionType{ - NewTaggedTypeUnion(types.NewTypeUnion[UnionType]( - (*EleType1)(nil), - (*EleType2)(nil), - ), - "Type", - "type", - ), - }, - }) + union := types.NewTypeUnion[UnionType]( + (*EleType1)(nil), + (*EleType2)(nil), + ) + UseTypeUnionInternallyTagged(&union, "type") + err := MapToObject(mp, &ret) So(err, ShouldBeNil) @@ -414,13 +411,15 @@ func Test_MapToObject(t *testing.T) { type UnionType interface{} type EleType1 struct { - Type string `json:"type" union:"1"` - Value1 string `json:"value1"` + Metadata `union:"1"` + Type string `json:"type"` + Value1 string `json:"value1"` } type EleType2 struct { - Type string `json:"type" union:"2"` - Value2 int `json:"value2"` + Metadata `union:"2"` + Type string `json:"type"` + Value2 int `json:"value2"` } mp := map[string]any{ @@ -429,17 +428,12 @@ func Test_MapToObject(t *testing.T) { } var ret UnionType - err := MapToObject(mp, &ret, MapToObjectOption{ - UnionTypes: []*TaggedUnionType{ - NewTaggedTypeUnion(types.NewTypeUnion[UnionType]( - (*EleType1)(nil), - (*EleType2)(nil), - ), - "Type", - "type", - ), - }, - }) + union := types.NewTypeUnion[UnionType]( + (*EleType1)(nil), + (*EleType2)(nil), + ) + UseTypeUnionInternallyTagged(&union, "type") + err := MapToObject(mp, &ret) So(err, ShouldBeNil) @@ -464,7 +458,27 @@ func Test_MapToObject(t *testing.T) { So(string(ret.Str), ShouldEqual, "1") }) } -func Test_JSON(t *testing.T) { + +type Base interface { + Noop() +} +type St1 struct { + Metadata `union:"St1"` + Type string + Val string +} + +func (*St1) Noop() {} + +type St2 struct { + Metadata `union:"St2"` + Type string + Val int +} + +func (St2) Noop() {} + +func Test_ObjectToJSON2(t *testing.T) { Convey("NewType", t, func() { type Str string @@ -484,4 +498,154 @@ func Test_JSON(t *testing.T) { So(err, ShouldBeNil) So(string(ret.Str), ShouldEqual, "1") }) + + Convey("UnionType ExternallyTagged", t, func() { + type Base interface{} + type St1 struct { + Val string + } + type St2 struct { + Val int64 + } + type Outter struct { + B []Base + } + + union := types.NewTypeUnion[Base](St1{}, &St2{}) + UseTypeUnionExternallyTagged(&union) + + val := Outter{B: []Base{St1{Val: "asd"}, &St2{Val: 123}}} + data, err := ObjectToJSONEx(val) + So(err, ShouldBeNil) + + ret, err := JSONToObjectEx[Outter](data) + So(err, ShouldBeNil) + So(ret, ShouldResemble, val) + }) + + Convey("UnionType InternallyTagged", t, func() { + type Base interface{} + type St1 struct { + Metadata `union:"St1"` + Type string + Val string + } + type St2 struct { + Metadata `union:"St2"` + Type string + Val int64 + } + type Outter struct { + B []Base + } + + union := types.NewTypeUnion[Base](St1{}, &St2{}) + UseTypeUnionInternallyTagged(&union, "Type") + + val := Outter{B: []Base{St1{Val: "asd", Type: "St1"}, &St2{Val: 123, Type: "St2"}}} + data, err := ObjectToJSONEx(val) + So(err, ShouldBeNil) + + ret, err := JSONToObjectEx[Outter](data) + So(err, ShouldBeNil) + So(ret, ShouldResemble, val) + }) + + Convey("实参类型和目标类型本身就是UnionType ExternallyTagged", t, func() { + type Base interface{} + type St1 struct { + Val string + } + union := types.NewTypeUnion[Base](St1{}) + UseTypeUnionExternallyTagged(&union) + + var val Base = St1{Val: "asd"} + data, err := ObjectToJSONEx(val) + So(err, ShouldBeNil) + + ret, err := JSONToObjectEx[Base](data) + So(err, ShouldBeNil) + So(ret, ShouldResemble, val) + }) + + Convey("实参类型和目标类型本身就是UnionType InternallyTagged", t, func() { + type Base interface{} + type St1 struct { + Metadata `union:"St1"` + Type string + Val string + } + union := types.NewTypeUnion[Base](St1{}) + UseTypeUnionInternallyTagged(&union, "Type") + + var val Base = St1{Val: "asd", Type: "St1"} + data, err := ObjectToJSONEx(val) + So(err, ShouldBeNil) + + ret, err := JSONToObjectEx[Base](data) + So(err, ShouldBeNil) + So(ret, ShouldResemble, val) + }) + + Convey("UnionType带有函数 ExternallyTagged", t, func() { + union := types.NewTypeUnion[Base](&St1{}, St2{}) + UseTypeUnionExternallyTagged(&union) + + var val []Base = []Base{ + &St1{Val: "asd", Type: "St1"}, + St2{Val: 123, Type: "St2"}, + } + data, err := ObjectToJSONEx(val) + So(err, ShouldBeNil) + + ret, err := JSONToObjectEx[[]Base](data) + So(err, ShouldBeNil) + So(ret, ShouldResemble, val) + }) + + Convey("UnionType带有函数 InternallyTagged", t, func() { + union := types.NewTypeUnion[Base](&St1{}, St2{}) + UseTypeUnionInternallyTagged(&union, "Type") + + var val []Base = []Base{ + &St1{Val: "asd", Type: "St1"}, + St2{Val: 123, Type: "St2"}, + } + data, err := ObjectToJSONEx(val) + So(err, ShouldBeNil) + + ret, err := JSONToObjectEx[[]Base](data) + So(err, ShouldBeNil) + So(ret, ShouldResemble, val) + }) + + Convey("UnionType,但实际值为nil ExternallyTagged", t, func() { + union := types.NewTypeUnion[Base](&St1{}, St2{}) + UseTypeUnionExternallyTagged(&union) + + var val []Base = []Base{ + nil, + } + data, err := ObjectToJSONEx(val) + So(err, ShouldBeNil) + + ret, err := JSONToObjectEx[[]Base](data) + So(err, ShouldBeNil) + So(ret, ShouldResemble, val) + }) + + Convey("UnionType,但实际值为nil InternallyTagged", t, func() { + union := types.NewTypeUnion[Base](&St1{}, St2{}) + UseTypeUnionInternallyTagged(&union, "Type") + + var val []Base = []Base{ + nil, + } + data, err := ObjectToJSONEx(val) + So(err, ShouldBeNil) + + ret, err := JSONToObjectEx[[]Base](data) + So(err, ShouldBeNil) + So(ret, ShouldResemble, val) + }) } diff --git a/utils/serder/types.go b/utils/serder/types.go index e6517a9..6c93b71 100644 --- a/utils/serder/types.go +++ b/utils/serder/types.go @@ -6,6 +6,8 @@ import ( "time" ) +type Metadata struct{} + type TimestampSecond time.Time func (t *TimestampSecond) MarshalJSON() ([]byte, error) { diff --git a/utils/serder/union_handler.go b/utils/serder/union_handler.go new file mode 100644 index 0000000..bd4df5b --- /dev/null +++ b/utils/serder/union_handler.go @@ -0,0 +1,336 @@ +package serder + +import ( + "fmt" + "reflect" + "unsafe" + + jsoniter "github.com/json-iterator/go" + "github.com/modern-go/reflect2" + "gitlink.org.cn/cloudream/common/pkgs/types" + + myreflect "gitlink.org.cn/cloudream/common/utils/reflect" +) + +type TypeUnionExternallyTagged struct { + Union *types.TypeUnion + TypeNameToType map[string]reflect.Type +} + +// 遇到TypeUnion的基类(UnionType)的字段时,将其实际值的类型信息也编码到JSON中,反序列化时也会解析出类型信息,还原出真实的类型。 +// Externally Tagged的格式是:{ "类型名": {...对象内容...} } +// +// 可以通过内嵌Metadata结构体,并在它身上增加"union"Tag来指定类型名称,如果没有指定,则默认使用系统类型名(包括包路径)。 +func UseTypeUnionExternallyTagged(union *types.TypeUnion) *TypeUnionExternallyTagged { + eu := &TypeUnionExternallyTagged{ + Union: union, + TypeNameToType: make(map[string]reflect.Type), + } + + for _, eleType := range union.ElementTypes { + eu.Add(eleType) + } + + unionHandler.externallyTagged[union.UnionType] = eu + + return eu +} + +func (u *TypeUnionExternallyTagged) Add(typ reflect.Type) error { + err := u.Union.Add(typ) + if err != nil { + return nil + } + + u.TypeNameToType[makeDerefFullTypeName(typ)] = typ + return nil +} + +type TypeUnionInternallyTagged struct { + Union *types.TypeUnion + TagField string + TagToType map[string]reflect.Type +} + +// 遇到TypeUnion的基类(UnionType)的字段时,将其实际值的类型信息也编码到JSON中,反序列化时也会解析出类型信息,还原出真实的类型。 +// Internally Tagged的格式是:{ "类型字段": "类型名", ...对象内容...},JSON中的类型字段名需要指定。 +// 注:对象定义需要包含类型字段,而且在序列化之前需要手动赋值,目前不支持自动设置。 +// +// 可以通过内嵌Metadata结构体,并在它身上增加"union"Tag来指定类型名称,如果没有指定,则默认使用系统类型名(包括包路径)。 +func UseTypeUnionInternallyTagged(union *types.TypeUnion, tagField string) *TypeUnionInternallyTagged { + iu := &TypeUnionInternallyTagged{ + Union: union, + TagField: tagField, + TagToType: make(map[string]reflect.Type), + } + + for _, eleType := range union.ElementTypes { + iu.Add(eleType) + } + + unionHandler.internallyTagged[union.UnionType] = iu + return iu +} + +func (u *TypeUnionInternallyTagged) Add(typ reflect.Type) error { + err := u.Union.Add(typ) + if err != nil { + return nil + } + + // 解引用直到得到结构体类型 + structType := typ + for structType.Kind() == reflect.Pointer { + structType = structType.Elem() + } + + // 要求内嵌Metadata结构体,那么结构体中的字段名就会是Metadata, + field, ok := structType.FieldByName(myreflect.TypeNameOf[Metadata]()) + if !ok { + u.TagToType[makeDerefFullTypeName(structType)] = typ + return nil + } + + // 为防同名,检查类型是不是也是Metadata + if field.Type != myreflect.TypeOf[Metadata]() { + u.TagToType[makeDerefFullTypeName(structType)] = typ + return nil + } + + tag := field.Tag.Get("union") + if tag == "" { + u.TagToType[makeDerefFullTypeName(structType)] = typ + return nil + } + + u.TagToType[tag] = typ + return nil +} + +type UnionHandler struct { + internallyTagged map[reflect.Type]*TypeUnionInternallyTagged + externallyTagged map[reflect.Type]*TypeUnionExternallyTagged +} + +func (h *UnionHandler) UpdateStructDescriptor(structDescriptor *jsoniter.StructDescriptor) { + +} + +func (h *UnionHandler) CreateMapKeyDecoder(typ reflect2.Type) jsoniter.ValDecoder { + return nil +} + +func (h *UnionHandler) CreateMapKeyEncoder(typ reflect2.Type) jsoniter.ValEncoder { + return nil +} + +func (h *UnionHandler) CreateDecoder(typ reflect2.Type) jsoniter.ValDecoder { + typ1 := typ.Type1() + if it, ok := h.internallyTagged[typ1]; ok { + return &InternallyTaggedDecoder{ + union: it, + } + } + + if et, ok := h.externallyTagged[typ1]; ok { + return &ExternallyTaggedDecoder{ + union: et, + } + } + + return nil +} + +func (h *UnionHandler) CreateEncoder(typ reflect2.Type) jsoniter.ValEncoder { + typ1 := typ.Type1() + if it, ok := h.internallyTagged[typ1]; ok { + return &InternallyTaggedEncoder{ + union: it, + } + } + + if et, ok := h.externallyTagged[typ1]; ok { + return &ExternallyTaggedEncoder{ + union: et, + } + } + return nil +} + +func (h *UnionHandler) DecorateDecoder(typ reflect2.Type, decoder jsoniter.ValDecoder) jsoniter.ValDecoder { + return decoder +} + +func (h *UnionHandler) DecorateEncoder(typ reflect2.Type, encoder jsoniter.ValEncoder) jsoniter.ValEncoder { + return encoder +} + +// 以下Encoder/Decoder都是在传入类型/目标类型是TypeUnion的基类(UnionType)时使用 +type InternallyTaggedEncoder struct { + union *TypeUnionInternallyTagged +} + +func (e *InternallyTaggedEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + +func (e *InternallyTaggedEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { + var val any + + if e.union.Union.UnionType.NumMethod() == 0 { + // 无方法的interface底层都是eface结构体,所以可以直接转*any + val = *(*any)(ptr) + } else { + // 有方法的interface底层都是iface结构体,可以将其转成eface,转换后不损失类型信息 + val = reflect2.IFaceToEFace(ptr) + } + + // 可以考虑检查一下Type字段有没有赋值,没有赋值则将其赋值为union Tag指定的值 + stream.WriteVal(val) +} + +type InternallyTaggedDecoder struct { + union *TypeUnionInternallyTagged +} + +func (e *InternallyTaggedDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { + nextTokenKind := iter.WhatIsNext() + if nextTokenKind == jsoniter.NilValue { + iter.Skip() + return + } + + raw := iter.ReadAny() + if raw.LastError() != nil { + iter.ReportError("decode TaggedUnionType", "getting object raw:"+raw.LastError().Error()) + return + } + + tagField := raw.Get(e.union.TagField) + if tagField.LastError() != nil { + iter.ReportError("decode TaggedUnionType", "getting type tag field:"+tagField.LastError().Error()) + return + } + + typeTag := tagField.ToString() + if typeTag == "" { + iter.ReportError("decode TaggedUnionType", "type tag is empty") + return + } + + typ, ok := e.union.TagToType[typeTag] + if !ok { + iter.ReportError("decode TaggedUnionType", fmt.Sprintf("unknow type tag %s in union %s", typeTag, e.union.Union.UnionType.Name())) + return + } + + // 如果目标类型已经是个指针类型*T,那么在New的时候就需要使用T, + // 否则New出来的是会是**T,这将导致后续的反序列化出问题 + if typ.Kind() == reflect.Pointer { + val := reflect.New(typ.Elem()) + raw.ToVal(val.Interface()) + + retVal := reflect.NewAt(e.union.Union.UnionType, ptr) + retVal.Elem().Set(val) + + } else { + val := reflect.New(typ) + raw.ToVal(val.Interface()) + + retVal := reflect.NewAt(e.union.Union.UnionType, ptr) + retVal.Elem().Set(val.Elem()) + } +} + +type ExternallyTaggedEncoder struct { + union *TypeUnionExternallyTagged +} + +func (e *ExternallyTaggedEncoder) IsEmpty(ptr unsafe.Pointer) bool { + return false +} + +func (e *ExternallyTaggedEncoder) Encode(ptr unsafe.Pointer, stream *jsoniter.Stream) { + var val any + + if e.union.Union.UnionType.NumMethod() == 0 { + // 无方法的interface底层都是eface结构体,所以可以直接转*any + val = *(*any)(ptr) + } else { + // 有方法的interface底层都是iface结构体,可以将其转成eface,转换后不损失类型信息 + val = reflect2.IFaceToEFace(ptr) + } + + if val == nil { + stream.WriteNil() + return + } + + stream.WriteObjectStart() + valType := myreflect.TypeOfValue(val) + if !e.union.Union.Include(valType) { + stream.Error = fmt.Errorf("type %v is not in union %v", valType, e.union.Union.UnionType) + return + } + stream.WriteObjectField(makeDerefFullTypeName(valType)) + stream.WriteVal(val) + stream.WriteObjectEnd() +} + +type ExternallyTaggedDecoder struct { + union *TypeUnionExternallyTagged +} + +func (e *ExternallyTaggedDecoder) Decode(ptr unsafe.Pointer, iter *jsoniter.Iterator) { + nextTkType := iter.WhatIsNext() + + if nextTkType == jsoniter.NilValue { + iter.Skip() + return + } + + if nextTkType != jsoniter.ObjectValue { + iter.ReportError("decode UnionType", fmt.Sprintf("unknow next token type %v", nextTkType)) + return + } + + typeStr := iter.ReadObject() + if typeStr == "" { + iter.ReportError("decode UnionType", "type string is empty") + } + + typ, ok := e.union.TypeNameToType[typeStr] + if !ok { + iter.ReportError("decode UnionType", fmt.Sprintf("unknow type string %s in union %v", typeStr, e.union.Union.UnionType)) + return + } + + // 如果目标类型已经是个指针类型*T,那么在New的时候就需要使用T, + // 否则New出来的是会是**T,这将导致后续的反序列化出问题 + if typ.Kind() == reflect.Pointer { + val := reflect.New(typ.Elem()) + iter.ReadVal(val.Interface()) + + retVal := reflect.NewAt(e.union.Union.UnionType, ptr) + retVal.Elem().Set(val) + + } else { + val := reflect.New(typ) + iter.ReadVal(val.Interface()) + + retVal := reflect.NewAt(e.union.Union.UnionType, ptr) + retVal.Elem().Set(val.Elem()) + } + + if iter.ReadObject() != "" { + iter.ReportError("decode UnionType", "there should be only one fields in the json object") + } +} + +func makeDerefFullTypeName(typ reflect.Type) string { + realType := typ + for realType.Kind() == reflect.Pointer { + realType = realType.Elem() + } + return fmt.Sprintf("%s.%s", realType.PkgPath(), realType.Name()) +}