diff --git a/pkgs/mq/client.go b/pkgs/mq/client.go index cc1ffbb..0fb02e8 100644 --- a/pkgs/mq/client.go +++ b/pkgs/mq/client.go @@ -302,15 +302,17 @@ func (c *RabbitMQClient) Close() error { } // 发送消息并等待回应。因为无法自动推断出TResp的类型,所以将其放在第一个手工填写,之后的TBody可以自动推断出来 -func Request[TResp any, TReq any](cli *RabbitMQClient, req TReq, opts ...RequestOption) (*TResp, error) { +func Request[TSvc any, TReq MessageBody, TResp MessageBody](_ func(svc TSvc, msg TReq) (TResp, *CodeMessage), cli *RabbitMQClient, req TReq, opts ...RequestOption) (TResp, error) { + var defRet TResp + resp, err := cli.Request(MakeAppDataMessage(req), opts...) if err != nil { - return nil, fmt.Errorf("requesting: %w", err) + return defRet, fmt.Errorf("requesting: %w", err) } errCode, errMsg := resp.GetCodeMessage() if errCode != errorcode.OK { - return nil, &CodeMessageError{ + return defRet, &CodeMessageError{ code: errCode, message: errMsg, } @@ -318,16 +320,16 @@ func Request[TResp any, TReq any](cli *RabbitMQClient, req TReq, opts ...Request respBody, ok := resp.Body.(TResp) if !ok { - return nil, fmt.Errorf("expect a %s body, but got %s", + return defRet, fmt.Errorf("expect a %s body, but got %s", myreflect.ElemTypeOf[TResp]().Name(), myreflect.TypeOfValue(resp.Body).Name()) } - return &respBody, nil + return respBody, nil } // 发送消息,不等待回应 -func Send[TReq any](cli *RabbitMQClient, msg TReq, opts ...SendOption) error { +func Send[TSvc any, TReq MessageBody](_ func(svc TSvc, msg TReq), cli *RabbitMQClient, msg TReq, opts ...SendOption) error { req := MakeAppDataMessage(msg) err := cli.Send(req, opts...) diff --git a/pkgs/mq/message.go b/pkgs/mq/message.go index 35d9a16..9ac16f6 100644 --- a/pkgs/mq/message.go +++ b/pkgs/mq/message.go @@ -7,6 +7,7 @@ import ( "unsafe" jsoniter "github.com/json-iterator/go" + "github.com/modern-go/reflect2" "gitlink.org.cn/cloudream/common/pkgs/types" myreflect "gitlink.org.cn/cloudream/common/utils/reflect" ) @@ -16,14 +17,26 @@ const ( MessageTypeHeartbeat = "Heartbeat" ) +type MessageBody interface { + // 此方法无任何作用,仅用于避免MessageBody是一个空interface,从而导致任何类型的值都可以赋值给它 + // 与下方的MessageBodyBase配合使用: + // IsMessageBody只让实现了此接口的类型能赋值给它,内嵌MessageBodyBase让类型必须是个指针类型, + // 这就确保了Message.Body必是某个类型的指针类型,避免序列化、反序列化过程出错 + IsMessageBody() +} + +// 这个结构体无任何字段,但实现了IsMessageBody,每种MessageBody都要内嵌这个结构体 +type MessageBodyBase struct{} + +// 此处的receiver是指针 +func (b *MessageBodyBase) IsMessageBody() {} + type Message struct { Type string `json:"type"` Headers map[string]any `json:"headers"` Body MessageBody `json:"body"` } -type MessageBody interface{} - func (m *Message) GetRequestID() string { reqID, _ := m.Headers["requestID"].(string) return reqID @@ -91,6 +104,7 @@ func RegisterMessage[T any]() { // 在序列化结构体中包含的UnionType类型字段时,会将字段值的实际类型保存在序列化后的结果中。 // 在反序列化时,会根据类型信息重建原本的字段值。 +// 注:TypeUnion.UnionType必须是一个interface func RegisterUnionType(union types.TypeUnion) *TypeUnionWithTypeName { myUnion := &TypeUnionWithTypeName{ Union: union, @@ -101,16 +115,27 @@ func RegisterUnionType(union types.TypeUnion) *TypeUnionWithTypeName { myUnion.TypeNameToType[makeFullTypeName(typ)] = typ } - jsoniter.RegisterTypeEncoderFunc(union.UnionType.String(), + if union.UnionType.NumMethod() == 0 { + registerForEFace(myUnion) + } else { + registerForIFace(myUnion) + } + + return myUnion +} + +// 无方法的interface类型 +func registerForEFace(myUnion *TypeUnionWithTypeName) { + jsoniter.RegisterTypeEncoderFunc(myUnion.Union.UnionType.String(), func(ptr unsafe.Pointer, stream *jsoniter.Stream) { - // 此处无法变成*UnionType,只能强转为*any + // 无方法的interface底层数据结构都是eface类型,所以可以直接转*any val := *(*any)(ptr) if val != nil { stream.WriteArrayStart() - valType := myreflect.TypeOfValue(val) + valType := myreflect.TypeOfValue(val).Elem() if !myUnion.Union.Include(valType) { - stream.Error = fmt.Errorf("type %v is not in union %v", valType, union.UnionType) + stream.Error = fmt.Errorf("type %v is not in union %v", valType, myUnion.Union.UnionType) return } @@ -126,9 +151,9 @@ func RegisterUnionType(union types.TypeUnion) *TypeUnionWithTypeName { return false }) - jsoniter.RegisterTypeDecoderFunc(union.UnionType.String(), + jsoniter.RegisterTypeDecoderFunc(myUnion.Union.UnionType.String(), func(ptr unsafe.Pointer, iter *jsoniter.Iterator) { - // 此处无法变成*UnionType,只能强转为*any + // 无方法的interface底层都是eface结构体,所以可以直接转*any vp := (*any)(ptr) nextTkType := iter.WhatIsNext() @@ -143,13 +168,13 @@ func RegisterUnionType(union types.TypeUnion) *TypeUnionWithTypeName { typ, ok := myUnion.TypeNameToType[typeStr] if !ok { - iter.ReportError("decode UnionType", fmt.Sprintf("unknow type string %s under %v", typeStr, union.UnionType)) + iter.ReportError("decode UnionType", fmt.Sprintf("unknow type string %s under %v", typeStr, myUnion.Union.UnionType)) return } val := reflect.New(typ) iter.ReadVal(val.Interface()) - *vp = val.Elem().Interface() + *vp = val.Interface() iter.ReadArray() } else { @@ -157,8 +182,66 @@ func RegisterUnionType(union types.TypeUnion) *TypeUnionWithTypeName { return } }) +} - return myUnion +// 有方法的interface类型 +func registerForIFace(myUnion *TypeUnionWithTypeName) { + jsoniter.RegisterTypeEncoderFunc(myUnion.Union.UnionType.String(), + func(ptr unsafe.Pointer, stream *jsoniter.Stream) { + // 有方法的interface底层都是iface结构体,可以将其转成eface,转换后不损失类型信息 + val := reflect2.IFaceToEFace(ptr) + if val != nil { + stream.WriteArrayStart() + + // 此处肯定是指针类型,见MessageBody上的注释的分析 + valType := myreflect.TypeOfValue(val).Elem() + if !myUnion.Union.Include(valType) { + stream.Error = fmt.Errorf("type %v is not in union %v", valType, myUnion.Union.UnionType) + return + } + + stream.WriteString(makeFullTypeName(valType)) + stream.WriteRaw(",") + stream.WriteVal(val) + stream.WriteArrayEnd() + } else { + stream.WriteNil() + } + }, + func(p unsafe.Pointer) bool { + return false + }) + + jsoniter.RegisterTypeDecoderFunc(myUnion.Union.UnionType.String(), + func(ptr unsafe.Pointer, iter *jsoniter.Iterator) { + + nextTkType := iter.WhatIsNext() + if nextTkType == jsoniter.NilValue { + iter.ReadNil() + + } else if nextTkType == jsoniter.ArrayValue { + iter.ReadArray() + typeStr := iter.ReadString() + iter.ReadArray() + + typ, ok := myUnion.TypeNameToType[typeStr] + if !ok { + iter.ReportError("decode UnionType", fmt.Sprintf("unknow type string %s under %v", typeStr, myUnion.Union.UnionType)) + return + } + + val := reflect.New(typ) + iter.ReadVal(val.Interface()) + + retVal := reflect.NewAt(myUnion.Union.UnionType, ptr) + retVal.Elem().Set(val) + + iter.ReadArray() + } else { + iter.ReportError("decode UnionType", fmt.Sprintf("unknow next token type %v", nextTkType)) + return + } + }) } func makeFullTypeName(typ myreflect.Type) string { diff --git a/pkgs/mq/message_dispatcher.go b/pkgs/mq/message_dispatcher.go index 66a6bcf..c6f9b0d 100644 --- a/pkgs/mq/message_dispatcher.go +++ b/pkgs/mq/message_dispatcher.go @@ -33,18 +33,12 @@ func (h *MessageDispatcher) Handle(svcBase any, msg *Message) (*Message, error) } // 将Service中的一个接口函数作为指定类型消息的处理函数 -func AddServiceFn[TSvc any, TReq any, TResp any](dispatcher *MessageDispatcher, svcFn func(svc TSvc, msg *TReq) (*TResp, *CodeMessage)) { +func AddServiceFn[TSvc any, TReq MessageBody, TResp MessageBody](dispatcher *MessageDispatcher, svcFn func(svc TSvc, msg TReq) (TResp, *CodeMessage)) { dispatcher.Add(myreflect.TypeOf[TReq](), func(svcBase any, reqMsg *Message) (*Message, error) { reqMsgBody := reqMsg.Body.(TReq) - ret, codeMsg := svcFn(svcBase.(TSvc), &reqMsgBody) - - var body MessageBody - if ret != nil { - body = *ret - } - - respMsg := MakeAppDataMessage(body) + ret, codeMsg := svcFn(svcBase.(TSvc), reqMsgBody) + respMsg := MakeAppDataMessage(ret) respMsg.SetCodeMessage(codeMsg.Code, codeMsg.Message) return &respMsg, nil @@ -52,11 +46,11 @@ func AddServiceFn[TSvc any, TReq any, TResp any](dispatcher *MessageDispatcher, } // 将Service中的一个*没有返回值的*接口函数作为指定类型消息的处理函数 -func AddNoRespServiceFn[TSvc any, TReq any](dispatcher *MessageDispatcher, svcFn func(svc TSvc, msg *TReq)) { +func AddNoRespServiceFn[TSvc any, TReq MessageBody](dispatcher *MessageDispatcher, svcFn func(svc TSvc, msg TReq)) { dispatcher.Add(myreflect.TypeOf[TReq](), func(svcBase any, reqMsg *Message) (*Message, error) { reqMsgBody := reqMsg.Body.(TReq) - svcFn(svcBase.(TSvc), &reqMsgBody) + svcFn(svcBase.(TSvc), reqMsgBody) return nil, nil }) diff --git a/pkgs/mq/message_test.go b/pkgs/mq/message_test.go index f04ee77..78511ee 100644 --- a/pkgs/mq/message_test.go +++ b/pkgs/mq/message_test.go @@ -98,18 +98,19 @@ func TestMessage(t *testing.T) { Convey("body中包含nil数组", t, func() { type Body struct { + MessageBodyBase NilArr []string } RegisterMessage[Body]() - msg := MakeAppDataMessage(Body{}) + msg := MakeAppDataMessage(&Body{}) data, err := Serialize(msg) So(err, ShouldBeNil) retMsg, err := Deserialize(data) So(err, ShouldBeNil) - So(retMsg.Body.(Body).NilArr, ShouldBeNil) + So(retMsg.Body.(*Body).NilArr, ShouldBeNil) }) Convey("body中包含匿名结构体", t, func() { @@ -117,52 +118,112 @@ func TestMessage(t *testing.T) { Value string `json:"value"` } type Body struct { + MessageBodyBase Emb } RegisterMessage[Body]() - msg := MakeAppDataMessage(Body{Emb: Emb{Value: "test"}}) + msg := MakeAppDataMessage(&Body{Emb: Emb{Value: "test"}}) data, err := Serialize(msg) So(err, ShouldBeNil) retMsg, err := Deserialize(data) So(err, ShouldBeNil) + So(retMsg, ShouldNotBeNil) - So(retMsg.Body.(Body).Value, ShouldEqual, "test") + So(retMsg.Body.(*Body).Value, ShouldEqual, "test") }) - Convey("使用TypeSet类型,但字段值为nil", t, func() { - type MyTypeSet interface { + Convey("无方法的TypeUnino", t, func() { + type MyTypeUnion interface{} + type EleType1 struct { + Value int + } + + type Body struct { + MessageBodyBase + Value MyTypeUnion + } + RegisterMessage[Body]() + RegisterUnionType(types.NewTypeUnion[MyTypeUnion](myreflect.TypeOf[EleType1]())) + + msg := MakeAppDataMessage(&Body{Value: &EleType1{ + Value: 1, + }}) + data, err := Serialize(msg) + So(err, ShouldBeNil) + + retMsg, err := Deserialize(data) + So(err, ShouldBeNil) + + So(retMsg.Body.(*Body).Value, ShouldResemble, &EleType1{Value: 1}) + }) + + Convey("有方法的TypeUnino", t, func() { + type MyTypeUnion interface { + MessageBody + } + type EleType1 struct { + MessageBodyBase + Value int + } + + type Body struct { + MessageBodyBase + Value MyTypeUnion + } + RegisterMessage[Body]() + RegisterUnionType(types.NewTypeUnion[MyTypeUnion](myreflect.TypeOf[EleType1]())) + + msg := MakeAppDataMessage(&Body{Value: &EleType1{ + Value: 1, + }}) + data, err := Serialize(msg) + So(err, ShouldBeNil) + + retMsg, err := Deserialize(data) + So(err, ShouldBeNil) + + So(retMsg.Body.(*Body).Value, ShouldNotBeNil) + + So(retMsg.Body.(*Body).Value, ShouldResemble, &EleType1{Value: 1}) + }) + + Convey("使用TypeUnion类型,但字段值为nil", t, func() { + type MyTypeUnion interface { Test() } type Body struct { - Value MyTypeSet + MessageBodyBase + Value MyTypeUnion } RegisterMessage[Body]() - RegisterUnionType(types.NewTypeUnion[MyTypeSet]()) + RegisterUnionType(types.NewTypeUnion[MyTypeUnion]()) - msg := MakeAppDataMessage(Body{Value: nil}) + msg := MakeAppDataMessage(&Body{Value: nil}) data, err := Serialize(msg) So(err, ShouldBeNil) retMsg, err := Deserialize(data) So(err, ShouldBeNil) - So(retMsg.Body.(Body).Value, ShouldBeNil) + So(retMsg.Body.(*Body).Value, ShouldBeNil) }) - Convey("字段实际类型不在TypeSet范围内", t, func() { - type MyTypeSet interface{} + Convey("字段实际类型不在TypeUnion范围内", t, func() { + type MyTypeUnion interface{} type Body struct { - Value MyTypeSet + MessageBodyBase + Value MyTypeUnion } RegisterMessage[Body]() - RegisterUnionType(types.NewTypeUnion[MyTypeSet]()) + RegisterUnionType(types.NewTypeUnion[MyTypeUnion]()) - msg := MakeAppDataMessage(Body{Value: struct{}{}}) + msg := MakeAppDataMessage(&Body{Value: &struct{}{}}) _, err := Serialize(msg) So(err, ShouldNotBeNil) }) + } diff --git a/pkgs/mq/mq_test.go b/pkgs/mq/mq_test.go index 7c52c3d..f818601 100644 --- a/pkgs/mq/mq_test.go +++ b/pkgs/mq/mq_test.go @@ -11,6 +11,7 @@ import ( func Test_ServerClient(t *testing.T) { Convey("心跳", t, func() { type Msg struct { + MessageBodyBase Data int64 } RegisterMessage[Msg]() @@ -21,7 +22,7 @@ func Test_ServerClient(t *testing.T) { svr, err := NewRabbitMQServer(rabbitURL, testQueue, func(msg *Message) (*Message, error) { <-time.After(time.Second * 10) - reply := MakeAppDataMessage(Msg{Data: 1}) + reply := MakeAppDataMessage(&Msg{Data: 1}) return &reply, nil }) So(err, ShouldBeNil) @@ -31,18 +32,18 @@ func Test_ServerClient(t *testing.T) { cli, err := NewRabbitMQClient(rabbitURL, testQueue, "") So(err, ShouldBeNil) - _, err = cli.Request(MakeAppDataMessage(Msg{}), RequestOption{ + _, err = cli.Request(MakeAppDataMessage(&Msg{}), RequestOption{ Timeout: time.Second * 2, }) So(err, ShouldEqual, ErrWaitResponseTimeout) - reply, err := cli.Request(MakeAppDataMessage(Msg{}), RequestOption{ + reply, err := cli.Request(MakeAppDataMessage(&Msg{}), RequestOption{ Timeout: time.Second * 2, KeepAlive: true, }) So(err, ShouldBeNil) - msgReply, ok := reply.Body.(Msg) + msgReply, ok := reply.Body.(*Msg) So(ok, ShouldBeTrue) So(msgReply.Data, ShouldEqual, 1) diff --git a/pkgs/mq/response.go b/pkgs/mq/response.go index a50861d..e7b5b45 100644 --- a/pkgs/mq/response.go +++ b/pkgs/mq/response.go @@ -31,15 +31,19 @@ func Failed(errCode string, msg string) *CodeMessage { } } -func ReplyFailed[T any](errCode string, msg string) (*T, *CodeMessage) { - return nil, &CodeMessage{ +/* +// 在支持从调用上下文推导类型之前,不使用这个函数 +func ReplyFailed[T MessageBody](errCode string, msg string) (T, *CodeMessage) { + var defRet T + return defRet, &CodeMessage{ Code: errCode, Message: msg, } } +*/ -func ReplyOK[T any](val T) (*T, *CodeMessage) { - return &val, &CodeMessage{ +func ReplyOK[T MessageBody](val T) (T, *CodeMessage) { + return val, &CodeMessage{ Code: errorcode.OK, Message: "", }