| @@ -2,6 +2,7 @@ package mq | |||
| import ( | |||
| "fmt" | |||
| "gitlink.org.cn/cloudream/common/utils/sync2" | |||
| "net" | |||
| "time" | |||
| @@ -72,12 +73,18 @@ type RabbitMQServer struct { | |||
| connection *amqp.Connection | |||
| channel *amqp.Channel | |||
| closed chan any | |||
| config RabbitMQParam | |||
| OnMessage MessageHandlerFn | |||
| OnError func(err error) | |||
| } | |||
| func NewRabbitMQServer(url string, queueName string, onMessage MessageHandlerFn) (*RabbitMQServer, error) { | |||
| type RabbitMQParam struct { | |||
| RetryNum int `json:"retryNum"` | |||
| RetryInterval int `json:"retryInterval"` | |||
| } | |||
| func NewRabbitMQServer(url string, queueName string, onMessage MessageHandlerFn, cfg RabbitMQParam) (*RabbitMQServer, error) { | |||
| config := amqp.Config{ | |||
| Dial: func(network, addr string) (net.Conn, error) { | |||
| return net.DialTimeout(network, addr, 60*time.Second) // 设置连接超时时间为 60 秒 | |||
| @@ -102,12 +109,55 @@ func NewRabbitMQServer(url string, queueName string, onMessage MessageHandlerFn) | |||
| queueName: queueName, | |||
| closed: make(chan any), | |||
| OnMessage: onMessage, | |||
| config: cfg, | |||
| } | |||
| return srv, nil | |||
| } | |||
| func (s *RabbitMQServer) Serve() error { | |||
| type RabbitMQLogEvent interface{} | |||
| func (s *RabbitMQServer) Start() *sync2.UnboundChannel[RabbitMQLogEvent] { | |||
| ch := sync2.NewUnboundChannel[RabbitMQLogEvent]() | |||
| channel := s.openChannel(ch) | |||
| if channel == nil { | |||
| ch.Send(1) | |||
| return ch | |||
| } | |||
| retryNum := 0 | |||
| for { | |||
| select { | |||
| case rawReq, ok := <-channel: | |||
| if !ok { | |||
| if retryNum > s.config.RetryNum { | |||
| ch.Send(fmt.Errorf("maximum number of retries exceeded")) | |||
| ch.Send(1) | |||
| return ch | |||
| } | |||
| retryNum++ | |||
| time.Sleep(time.Duration(s.config.RetryInterval) * time.Millisecond) | |||
| channel = s.openChannel(ch) | |||
| } | |||
| reqMsg, err := Deserialize(rawReq.Body) | |||
| if err != nil { | |||
| ch.Send(NewDeserializeError(err)) | |||
| continue | |||
| } | |||
| go s.handleMessage(ch, reqMsg, rawReq) | |||
| case <-s.closed: | |||
| return nil | |||
| } | |||
| } | |||
| } | |||
| func (s *RabbitMQServer) openChannel(ch *sync2.UnboundChannel[RabbitMQLogEvent]) <-chan amqp.Delivery { | |||
| _, err := s.channel.QueueDeclare( | |||
| s.queueName, | |||
| false, | |||
| @@ -117,7 +167,8 @@ func (s *RabbitMQServer) Serve() error { | |||
| nil, | |||
| ) | |||
| if err != nil { | |||
| return fmt.Errorf("declare queue failed, err: %w", err) | |||
| ch.Send(fmt.Errorf("declare queue failed, err: %w", err)) | |||
| return nil | |||
| } | |||
| channel, err := s.channel.Consume( | |||
| @@ -131,32 +182,14 @@ func (s *RabbitMQServer) Serve() error { | |||
| ) | |||
| if err != nil { | |||
| return fmt.Errorf("open consume channel failed, err: %w", err) | |||
| ch.Send(fmt.Errorf("get rabbitmq channel failed, err: %w", err)) | |||
| return nil | |||
| } | |||
| for { | |||
| select { | |||
| case rawReq, ok := <-channel: | |||
| if !ok { | |||
| s.onError(NewDeserializeError(fmt.Errorf("channel is closed"))) | |||
| return NewReceiveMessageError(fmt.Errorf("channel is closed")) | |||
| } | |||
| reqMsg, err := Deserialize(rawReq.Body) | |||
| if err != nil { | |||
| s.onError(NewDeserializeError(err)) | |||
| continue | |||
| } | |||
| go s.handleMessage(reqMsg, rawReq) | |||
| case <-s.closed: | |||
| return nil | |||
| } | |||
| } | |||
| return channel | |||
| } | |||
| func (s *RabbitMQServer) handleMessage(reqMsg *Message, rawReq amqp.Delivery) { | |||
| func (s *RabbitMQServer) handleMessage(ch *sync2.UnboundChannel[RabbitMQLogEvent], reqMsg *Message, rawReq amqp.Delivery) { | |||
| replyed := make(chan bool) | |||
| defer close(replyed) | |||
| @@ -167,7 +200,7 @@ func (s *RabbitMQServer) handleMessage(reqMsg *Message, rawReq amqp.Delivery) { | |||
| reply, err := s.OnMessage(reqMsg) | |||
| if err != nil { | |||
| s.onError(NewDispatchError(err)) | |||
| ch.Send(NewDispatchError(err)) | |||
| return | |||
| } | |||
| @@ -175,7 +208,7 @@ func (s *RabbitMQServer) handleMessage(reqMsg *Message, rawReq amqp.Delivery) { | |||
| reply.SetRequestID(reqMsg.GetRequestID()) | |||
| err := s.replyToClient(*reply, &rawReq) | |||
| if err != nil { | |||
| s.onError(NewReplyError(err)) | |||
| ch.Send(NewReplyError(err)) | |||
| } | |||
| } | |||
| } | |||