Browse Source

消息队列公共代码移动到common库

pull/1/head
Sydonian 2 years ago
parent
commit
5a23e63394
7 changed files with 858 additions and 4 deletions
  1. +7
    -4
      go.mod
  2. +15
    -0
      go.sum
  3. +266
    -0
      pkg/mq/client.go
  4. +167
    -0
      pkg/mq/message.go
  5. +165
    -0
      pkg/mq/message_test.go
  6. +46
    -0
      pkg/mq/response.go
  7. +192
    -0
      pkg/mq/server.go

+ 7
- 4
go.mod View File

@@ -6,14 +6,18 @@ require (
github.com/antonfisher/nested-logrus-formatter v1.3.1
github.com/beevik/etree v1.2.0
github.com/go-ping/ping v1.1.0
github.com/google/uuid v1.3.0
github.com/hashicorp/go-multierror v1.1.1
github.com/imdario/mergo v0.3.15
github.com/ipfs/go-ipfs-api v0.6.0
github.com/json-iterator/go v1.1.12
github.com/magefile/mage v1.15.0
github.com/mitchellh/mapstructure v1.5.0
github.com/otiai10/copy v1.12.0
github.com/samber/lo v1.36.0
github.com/sirupsen/logrus v1.9.2
github.com/smartystreets/goconvey v1.8.0
github.com/streadway/amqp v1.1.0
github.com/zyedidia/generic v1.2.1
go.etcd.io/etcd/client/v3 v3.5.9
golang.org/x/exp v0.0.0-20230519143937-03e91628a987
@@ -27,8 +31,8 @@ require (
github.com/decred/dcrd/dcrec/secp256k1/v4 v4.1.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/ipfs/boxo v0.8.0 // indirect
github.com/ipfs/go-cid v0.4.0 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
@@ -38,6 +42,8 @@ require (
github.com/libp2p/go-libp2p v0.26.3 // indirect
github.com/minio/sha256-simd v1.0.0 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/mr-tron/base58 v1.2.0 // indirect
github.com/multiformats/go-base32 v0.1.0 // indirect
github.com/multiformats/go-base36 v0.2.0 // indirect
@@ -65,6 +71,3 @@ require (
google.golang.org/protobuf v1.30.0 // indirect
lukechampine.com/blake3 v1.1.7 // indirect
)

// 运行go mod tidy时需要将下面几行取消注释
//replace gitlink.org.cn/cloudream/proto => ../proto

+ 15
- 0
go.sum View File

@@ -27,11 +27,17 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/imdario/mergo v0.3.15 h1:M8XP7IuFNsqUx6VPK2P9OSmsYsI/YFaGil0uD21V3dM=
github.com/imdario/mergo v0.3.15/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY=
github.com/ipfs/boxo v0.8.0 h1:UdjAJmHzQHo/j3g3b1bAcAXCj/GM6iTwvSlBDvPBNBs=
@@ -40,6 +46,8 @@ github.com/ipfs/go-cid v0.4.0 h1:a4pdZq0sx6ZSxbCizebnKiMCx/xI/aBBFlB73IgH4rA=
github.com/ipfs/go-cid v0.4.0/go.mod h1:uQHwDeX4c6CtyrFwdqyhpNcxVewur1M7l7fNU7LKwZk=
github.com/ipfs/go-ipfs-api v0.6.0 h1:JARgG0VTbjyVhO5ZfesnbXv9wTcMvoKRBLF1SzJqzmg=
github.com/ipfs/go-ipfs-api v0.6.0/go.mod h1:iDC2VMwN9LUpQV/GzEeZ2zNqd8NUdRmWcFM+K/6odf0=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
@@ -62,6 +70,10 @@ github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/mr-tron/base58 v1.2.0 h1:T/HDJBh4ZCPbU39/+c3rRvE0uKBQlU27+QI8LJ4t64o=
github.com/mr-tron/base58 v1.2.0/go.mod h1:BinMc/sQntlIE1frQmRFPUoPA1Zkr8VRgBdjWI2mNwc=
github.com/multiformats/go-base32 v0.1.0 h1:pVx9xoSPqEIQG8o+UbAe7DNi51oej1NtK+aGkbLYxPE=
@@ -96,7 +108,10 @@ github.com/smartystreets/goconvey v1.8.0 h1:Oi49ha/2MURE0WexF052Z0m+BNSGirfjg5RL
github.com/smartystreets/goconvey v1.8.0/go.mod h1:EdX8jtrTIj26jmjCOVNMVSIYAtgexqXKHOXW2Dx9JLg=
github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI=
github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/streadway/amqp v1.1.0 h1:py12iX8XSyI7aN/3dUT8DFIDJazNJsVJdxNVEpnQTZM=
github.com/streadway/amqp v1.1.0/go.mod h1:WYSrTEYHOXHd0nwFeUXAe2G2hRnQT+deZJJf88uS9Bg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/thoas/go-funk v0.9.1 h1:O549iLZqPpTUQ10ykd26sZhzD+rmR5pWhuElrhbC20M=


+ 266
- 0
pkg/mq/client.go View File

@@ -0,0 +1,266 @@
package mq

import (
"fmt"
"sync"
"time"

"github.com/hashicorp/go-multierror"
"github.com/streadway/amqp"
"gitlink.org.cn/cloudream/common/consts/errorcode"
"gitlink.org.cn/cloudream/common/pkg/future"
"gitlink.org.cn/cloudream/common/pkg/logger"
myreflect "gitlink.org.cn/cloudream/common/utils/reflect"
)

const (
DIRECT_REPLY_TO = "amq.rabbitmq.reply-to"
)

type CodeMessageError struct {
code string
message string
}

func (e *CodeMessageError) Error() string {
return fmt.Sprintf("code: %s, message: %s", e.code, e.message)
}

type SendOption struct {
// 等待响应的超时时间,为0代表不设置超时时间
Timeout time.Duration
}

type RequestOption struct {
// 等待响应的超时时间,为0代表不设置超时时间
Timeout time.Duration
}

type RabbitMQClient struct {
connection *amqp.Connection
channel *amqp.Channel
exchange string
key string

requests map[string]*future.SetValueFuture[*Message]
requestsLock sync.Mutex

closed chan any
}

func NewRabbitMQClient(url string, key string, exchange string) (*RabbitMQClient, error) {
connection, err := amqp.Dial(url)
if err != nil {
return nil, fmt.Errorf("connecting to %s: %w", url, err)
}

channel, err := connection.Channel()
if err != nil {
connection.Close()
return nil, fmt.Errorf("openning channel on connection: %w", err)
}

cli := &RabbitMQClient{
connection: connection,
channel: channel,
exchange: exchange,
key: key,
requests: make(map[string]*future.SetValueFuture[*Message]),
closed: make(chan any),
}

// NOTE! 经测试发现,必须在Publish之前调用Consume进行消息接收,否则Consume会返回错误
// 因此这段代码不能移动到serve函数中,必须放在这里,保证顺序
recvChan, err := channel.Consume(
// 一个特殊队列,服务端的回复消息都会发送到这个队列里
DIRECT_REPLY_TO,
"",
true,
true,
false,
false,
nil,
)
if err != nil {
channel.Close()
connection.Close()
return nil, fmt.Errorf("openning consume channel: %w", err)
}

go func() {
err := cli.serve(recvChan)
if err != nil {
// TODO 错误处理
logger.Std.Warnf("rabbitmq client serving: %s", err.Error())
}
}()

return cli, nil
}

func (cli *RabbitMQClient) Request(req Message, opts ...RequestOption) (*Message, error) {
opt := RequestOption{Timeout: time.Second * 15}
if len(opts) > 0 {
opt = opts[0]
}

reqID := req.MakeRequestID()
fut := future.NewSetValue[*Message]()

cli.requestsLock.Lock()
cli.requests[reqID] = fut
cli.requestsLock.Unlock()

// 启动超时定时器
if opt.Timeout != 0 {
go func() {
<-time.After(opt.Timeout)
cli.requestsLock.Lock()
// 由于只会在requestsLock.Lock()之后修改fut的状态,所以Complete的判断是可信的
if !fut.IsComplete() {
fut.SetError(fmt.Errorf("wait response timeout"))
}
delete(cli.requests, reqID)
cli.requestsLock.Unlock()
}()
}

err := cli.Send(req, SendOption{
Timeout: opt.Timeout,
})
if err != nil {
cli.requestsLock.Lock()
delete(cli.requests, reqID)
cli.requestsLock.Unlock()

return nil, fmt.Errorf("sending message: %w", err)
}

resp, err := fut.WaitValue()
if err != nil {
return nil, fmt.Errorf("requesting: %w", err)
}

return resp, nil
}

func (c *RabbitMQClient) Send(msg Message, opts ...SendOption) error {
opt := SendOption{}
if len(opts) > 0 {
opt = opts[0]
}

data, err := Serialize(msg)
if err != nil {
return fmt.Errorf("serialize message failed: %w", err)
}

expiration := ""
if opt.Timeout > 0 {
if opt.Timeout < time.Millisecond {
expiration = "1"
} else {
expiration = fmt.Sprintf("%d", opt.Timeout.Milliseconds()+1)
}
}

err = c.channel.Publish(c.exchange, c.key, false, false, amqp.Publishing{
ContentType: "text/plain",
Body: data,
// 设置了此字段后rabbitmq会建立一个临时且私有的队列,服务端的回复消息都是送到此队列中
ReplyTo: DIRECT_REPLY_TO,
Expiration: expiration,
})

if err != nil {
return fmt.Errorf("publishing data: %w", err)
}

return nil
}

func (c *RabbitMQClient) serve(recvChan <-chan amqp.Delivery) error {
for {
select {
case rawMsg, ok := <-recvChan:
if !ok {
return fmt.Errorf("receive channel closed")
}

msg, err := Deserialize(rawMsg.Body)
if err != nil {
// TODO 记录日志
logger.Std.Warnf("deserializing message body: %s", err.Error())
continue
}

reqID := msg.GetRequestID()
if reqID != "" {
c.requestsLock.Lock()
if req, ok := c.requests[reqID]; ok {
req.SetValue(msg)
delete(c.requests, reqID)
}
c.requestsLock.Unlock()
}

case <-c.closed:
return nil
}
}
}

func (c *RabbitMQClient) Close() error {
var retErr error

close(c.closed)

err := c.channel.Close()
if err != nil {
multierror.Append(retErr, fmt.Errorf("closing channel: %w", err))
}

err = c.connection.Close()
if err != nil {
multierror.Append(retErr, fmt.Errorf("closing connection: %w", err))
}

return retErr
}

// 发送消息并等待回应。因为无法自动推断出TResp的类型,所以将其放在第一个手工填写,之后的TBody可以自动推断出来
func Request[TResp any, TReq any](cli *RabbitMQClient, req TReq, opts ...RequestOption) (*TResp, error) {
resp, err := cli.Request(MakeMessage(req), opts...)
if err != nil {
return nil, fmt.Errorf("requesting: %w", err)
}

errCode, errMsg := resp.GetCodeMessage()
if errCode != errorcode.OK {
return nil, &CodeMessageError{
code: errCode,
message: errMsg,
}
}

respBody, ok := resp.Body.(TResp)
if !ok {
return nil, fmt.Errorf("expect a %s body, but got %s",
myreflect.ElemTypeOf[TResp]().Name(),
myreflect.TypeOfValue(resp.Body).Name())
}

return &respBody, nil
}

// 发送消息,不等待回应
func Send[TReq any](cli *RabbitMQClient, msg TReq, opts ...SendOption) error {
req := MakeMessage(msg)

err := cli.Send(req, opts...)
if err != nil {
return fmt.Errorf("sending: %w", err)
}

return nil
}

+ 167
- 0
pkg/mq/message.go View File

@@ -0,0 +1,167 @@
package mq

import (
"bytes"
"fmt"
"reflect"
"unsafe"

"github.com/google/uuid"
jsoniter "github.com/json-iterator/go"
myreflect "gitlink.org.cn/cloudream/common/utils/reflect"
"gitlink.org.cn/cloudream/common/utils/serder"
)

type Message struct {
Headers map[string]string `json:"headers"`
Body MessageBodyTypes `json:"body"`
}

type MessageBodyTypes interface{}

func (m *Message) GetRequestID() string {
return m.Headers["requestID"]
}

func (m *Message) SetRequestID(id string) {
m.Headers["requestID"] = id
}

func (m *Message) MakeRequestID() string {
id := uuid.NewString()
m.Headers["requestID"] = id
return id
}

func (m *Message) SetCodeMessage(code string, msg string) {
m.Headers["responseCode"] = code
m.Headers["responseMessage"] = msg
}

func (m *Message) GetCodeMessage() (string, string) {
return m.Headers["responseCode"], m.Headers["responseMessage"]
}

func MakeMessage(body MessageBodyTypes) Message {
msg := Message{
Headers: make(map[string]string),
Body: body,
}

return msg
}

type typeSet struct {
TopType myreflect.Type
ElementTypes serder.TypeNameResolver
}

var typeSets map[myreflect.Type]typeSet = make(map[reflect.Type]typeSet)
var messageTypeSet *typeSet

// 所有新定义的Message都需要在init中调用此函数
func RegisterMessage[T any]() {
messageTypeSet.ElementTypes.Register(myreflect.TypeOf[T]())
}

// 如果对一个类型T调用了此函数,那么在序列化结构体中包含的T类型字段时,
// 会将字段值的实际类型保存在序列化后的结果中(作为一个字段@type),
// 在反序列化时,会根据类型信息重建原本的字段值。
//
// 只会处理types指定的类型。
func RegisterTypeSet[T any](types ...myreflect.Type) *typeSet {
set := typeSet{
TopType: myreflect.TypeOf[T](),
ElementTypes: serder.NewTypeNameResolver(true),
}

for _, t := range types {
set.ElementTypes.Register(t)
}

typeSets[set.TopType] = set

jsoniter.RegisterTypeEncoderFunc(myreflect.TypeOf[T]().String(),
func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
val := *((*T)(ptr))
var ifVal any = val

if ifVal != nil {
stream.WriteArrayStart()
typeStr, err := set.ElementTypes.TypeToString(myreflect.TypeOfValue(val))
if err != nil {
stream.Error = err
return
}
stream.WriteString(typeStr)
stream.WriteRaw(",")
stream.WriteVal(val)
stream.WriteArrayEnd()
} else {
stream.WriteNil()
}
},
func(p unsafe.Pointer) bool {
return false
})

jsoniter.RegisterTypeDecoderFunc(myreflect.TypeOf[T]().String(),
func(ptr unsafe.Pointer, iter *jsoniter.Iterator) {
vp := (*T)(ptr)

nextTkType := iter.WhatIsNext()
if nextTkType == jsoniter.NilValue {
iter.ReadNil()
var zero T
*vp = zero
} else if nextTkType == jsoniter.ArrayValue {
iter.ReadArray()
typeStr := iter.ReadString()
iter.ReadArray()

typ, err := set.ElementTypes.StringToType(typeStr)
if err != nil {
iter.ReportError("get type from string", err.Error())
return
}

val := reflect.New(typ)
iter.ReadVal(val.Interface())
*vp = val.Elem().Interface().(T)

iter.ReadArray()
} else {
iter.ReportError("parse TypeSet field", fmt.Sprintf("unknow next token type %v", nextTkType))
return
}
})

return &set
}

func Serialize(msg Message) ([]byte, error) {
buf := bytes.NewBuffer(nil)
enc := jsoniter.NewEncoder(buf)
err := enc.Encode(msg)
if err != nil {
return nil, err
}

return buf.Bytes(), nil
}

func Deserialize(data []byte) (*Message, error) {
dec := jsoniter.NewDecoder(bytes.NewBuffer(data))

var msg Message
err := dec.Decode(&msg)
if err != nil {
return nil, err
}

return &msg, nil
}

func init() {
messageTypeSet = RegisterTypeSet[MessageBodyTypes]()
}

+ 165
- 0
pkg/mq/message_test.go View File

@@ -0,0 +1,165 @@
package mq

import (
"bytes"
"fmt"
"testing"
"unsafe"

jsoniter "github.com/json-iterator/go"
. "github.com/smartystreets/goconvey/convey"
myreflect "gitlink.org.cn/cloudream/common/utils/reflect"
)

func TestMessage(t *testing.T) {
Convey("测试jsoniter", t, func() {

type MyAny interface{}

type Struct1 struct {
Value string
}

type Struct2 struct {
Value string
}

type Top struct {
A1 Struct1
A2 MyAny
Nil MyAny
}

top := Top{
A1: Struct1{
Value: "s1",
},
A2: Struct2{
Value: "s2",
},
Nil: nil,
}

jsoniter.RegisterTypeEncoderFunc(myreflect.TypeOf[MyAny]().String(),
func(ptr unsafe.Pointer, stream *jsoniter.Stream) {
val := *((*MyAny)(ptr))

stream.WriteArrayStart()
if val != nil {
stream.WriteString(myreflect.TypeOfValue(val).String())
stream.WriteRaw(",")
stream.WriteVal(val)
}
stream.WriteArrayEnd()
},
func(p unsafe.Pointer) bool {
return false
})

jsoniter.RegisterTypeDecoderFunc(myreflect.TypeOf[MyAny]().String(),
func(ptr unsafe.Pointer, iter *jsoniter.Iterator) {
vp := (*MyAny)(ptr)

nextTkType := iter.WhatIsNext()
if nextTkType == jsoniter.NilValue {
*vp = nil
} else if nextTkType == jsoniter.ArrayValue {
iter.ReadArray()
typ := iter.ReadString()
iter.ReadArray()

if typ == "message.Struct1" {
var st Struct1
iter.ReadVal(&st)
*vp = st
} else if typ == "message.Struct2" {
var st Struct2
iter.ReadVal(&st)
*vp = st
}

iter.ReadArray()
}
})

buf := bytes.NewBuffer(nil)
enc := jsoniter.NewEncoder(buf)
err := enc.Encode(top)
So(err, ShouldBeNil)

dec := jsoniter.NewDecoder(buf)
var newTop Top
dec.Decode(&newTop)

fmt.Printf("%s\n", buf.String())
fmt.Printf("%#+v", newTop)
})

Convey("body中包含nil数组", t, func() {
type Body struct {
NilArr []string
}
RegisterMessage[Body]()

msg := MakeMessage(Body{})
data, err := Serialize(msg)
So(err, ShouldBeNil)

retMsg, err := Deserialize(data)
So(err, ShouldBeNil)

So(retMsg.Body.(Body).NilArr, ShouldBeNil)
})

Convey("body中包含匿名结构体", t, func() {
type Emb struct {
Value string `json:"value"`
}
type Body struct {
Emb
}
RegisterMessage[Body]()

msg := MakeMessage(Body{Emb: Emb{Value: "test"}})
data, err := Serialize(msg)
So(err, ShouldBeNil)

retMsg, err := Deserialize(data)
So(err, ShouldBeNil)

So(retMsg.Body.(Body).Value, ShouldEqual, "test")
})

Convey("使用TypeSet类型,但字段值为nil", t, func() {
type MyTypeSet interface{}

type Body struct {
Value MyTypeSet
}
RegisterMessage[Body]()
RegisterTypeSet[MyTypeSet]()

msg := MakeMessage(Body{Value: nil})
data, err := Serialize(msg)
So(err, ShouldBeNil)

retMsg, err := Deserialize(data)
So(err, ShouldBeNil)

So(retMsg.Body.(Body).Value, ShouldBeNil)
})

Convey("字段实际类型不在TypeSet范围内", t, func() {
type MyTypeSet interface{}

type Body struct {
Value MyTypeSet
}
RegisterMessage[Body]()
RegisterTypeSet[MyTypeSet]()

msg := MakeMessage(Body{Value: struct{}{}})
_, err := Serialize(msg)
So(err, ShouldNotBeNil)
})
}

+ 46
- 0
pkg/mq/response.go View File

@@ -0,0 +1,46 @@
package mq

import (
"gitlink.org.cn/cloudream/common/consts/errorcode"
)

type CodeMessage struct {
Code string `json:"code"`
Message string `json:"message"`
}

func (msg *CodeMessage) IsOK() bool {
return msg.Code == errorcode.OK
}

func (msg *CodeMessage) IsFailed() bool {
return !msg.IsOK()
}

func OK() *CodeMessage {
return &CodeMessage{
Code: errorcode.OK,
Message: "",
}
}

func Failed(errCode string, msg string) *CodeMessage {
return &CodeMessage{
Code: errCode,
Message: msg,
}
}

func ReplyFailed[T any](errCode string, msg string) (*T, *CodeMessage) {
return nil, &CodeMessage{
Code: errCode,
Message: msg,
}
}

func ReplyOK[T any](val T) (*T, *CodeMessage) {
return &val, &CodeMessage{
Code: errorcode.OK,
Message: "",
}
}

+ 192
- 0
pkg/mq/server.go View File

@@ -0,0 +1,192 @@
package mq

import (
"fmt"

"github.com/streadway/amqp"
)

type ReceiveMessageError struct {
err error
}

func (err ReceiveMessageError) Error() string {
return fmt.Sprintf("receive message error: %s", err.err.Error())
}

func NewReceiveMessageError(err error) ReceiveMessageError {
return ReceiveMessageError{
err: err,
}
}

type DeserializeError struct {
err error
}

func (err DeserializeError) Error() string {
return fmt.Sprintf("deserialize error: %s", err.err.Error())
}

func NewDeserializeError(err error) DeserializeError {
return DeserializeError{
err: err,
}
}

type DispatchError struct {
err error
}

func (err DispatchError) Error() string {
return fmt.Sprintf("dispatch error: %s", err.err.Error())
}

func NewDispatchError(err error) DispatchError {
return DispatchError{
err: err,
}
}

type ReplyError struct {
err error
}

func (err ReplyError) Error() string {
return fmt.Sprintf("replay to client : %s", err.err.Error())
}

func NewReplyError(err error) ReplyError {
return ReplyError{
err: err,
}
}

// 处理消息。会将第一个返回值作为响应回复给客户端,如果为nil,则不回复。
type MessageHandlerFn func(msg *Message) (*Message, error)

type RabbitMQServer struct {
queueName string
connection *amqp.Connection
channel *amqp.Channel
closed chan any

OnMessage MessageHandlerFn
OnError func(err error)
}

func NewRabbitMQServer(url string, queueName string, onMessage MessageHandlerFn) (*RabbitMQServer, error) {
connection, err := amqp.Dial(url)
if err != nil {
return nil, fmt.Errorf("connecting to %s: %w", url, err)
}

channel, err := connection.Channel()
if err != nil {
connection.Close()
return nil, fmt.Errorf("openning channel on connection: %w", err)
}

srv := &RabbitMQServer{
connection: connection,
channel: channel,
queueName: queueName,
closed: make(chan any),
OnMessage: onMessage,
}

return srv, nil
}

func (s *RabbitMQServer) Serve() error {
_, err := s.channel.QueueDeclare(
s.queueName,
false,
true,
false,
false,
nil,
)
if err != nil {
return fmt.Errorf("declare queue failed, err: %w", err)
}

channel, err := s.channel.Consume(
s.queueName,
"",
true,
false,
true,
false,
nil,
)

if err != nil {
return fmt.Errorf("open consume channel failed, err: %w", err)
}

for {
select {
case rawReq, ok := <-channel:
if !ok {
if s.OnError != nil {
s.OnError(NewReceiveMessageError(fmt.Errorf("channel is closed")))
}
return NewReceiveMessageError(fmt.Errorf("channel is closed"))
}

reqMsg, err := Deserialize(rawReq.Body)
if err != nil {
if s.OnError != nil {
s.OnError(NewDeserializeError(err))
}
break
}

reply, err := s.OnMessage(reqMsg)
if err != nil {
if s.OnError != nil {
s.OnError(NewDispatchError(err))
}
continue
}

if reply != nil {
reply.SetRequestID(reqMsg.GetRequestID())
err := s.replyClientMessage(*reply, &rawReq)
if err != nil {
if s.OnError != nil {
s.OnError(NewReplyError(err))
}
}
}

case <-s.closed:
return nil
}
}
}

func (s *RabbitMQServer) Close() {
close(s.closed)
}

// replyClientMessage 回复客户端的消息,需要用到客户端发来的消息中的字段来判断回到哪个队列
func (s *RabbitMQServer) replyClientMessage(reply Message, reqMsg *amqp.Delivery) error {
msgData, err := Serialize(reply)
if err != nil {
return fmt.Errorf("serialize message failed: %w", err)
}

return s.channel.Publish(
"",
reqMsg.ReplyTo,
false,
false,
amqp.Publishing{
ContentType: "text/plain",
Body: msgData,
Expiration: "30000", // 响应消息的超时时间默认30秒
},
)
}

Loading…
Cancel
Save