| @@ -2,6 +2,7 @@ package mq | |||||
| import ( | import ( | ||||
| "fmt" | "fmt" | ||||
| "gitlink.org.cn/cloudream/common/utils/sync2" | |||||
| "net" | "net" | ||||
| "time" | "time" | ||||
| @@ -72,12 +73,18 @@ type RabbitMQServer struct { | |||||
| connection *amqp.Connection | connection *amqp.Connection | ||||
| channel *amqp.Channel | channel *amqp.Channel | ||||
| closed chan any | closed chan any | ||||
| config RabbitMQParam | |||||
| OnMessage MessageHandlerFn | OnMessage MessageHandlerFn | ||||
| OnError func(err error) | OnError func(err error) | ||||
| } | } | ||||
| func NewRabbitMQServer(url string, queueName string, onMessage MessageHandlerFn) (*RabbitMQServer, error) { | |||||
| type RabbitMQParam struct { | |||||
| RetryNum int `json:"retryNum"` | |||||
| RetryInterval int `json:"retryInterval"` | |||||
| } | |||||
| func NewRabbitMQServer(url string, queueName string, onMessage MessageHandlerFn, cfg RabbitMQParam) (*RabbitMQServer, error) { | |||||
| config := amqp.Config{ | config := amqp.Config{ | ||||
| Dial: func(network, addr string) (net.Conn, error) { | Dial: func(network, addr string) (net.Conn, error) { | ||||
| return net.DialTimeout(network, addr, 60*time.Second) // 设置连接超时时间为 60 秒 | return net.DialTimeout(network, addr, 60*time.Second) // 设置连接超时时间为 60 秒 | ||||
| @@ -102,12 +109,55 @@ func NewRabbitMQServer(url string, queueName string, onMessage MessageHandlerFn) | |||||
| queueName: queueName, | queueName: queueName, | ||||
| closed: make(chan any), | closed: make(chan any), | ||||
| OnMessage: onMessage, | OnMessage: onMessage, | ||||
| config: cfg, | |||||
| } | } | ||||
| return srv, nil | return srv, nil | ||||
| } | } | ||||
| func (s *RabbitMQServer) Serve() error { | |||||
| type RabbitMQLogEvent interface{} | |||||
| func (s *RabbitMQServer) Start() *sync2.UnboundChannel[RabbitMQLogEvent] { | |||||
| ch := sync2.NewUnboundChannel[RabbitMQLogEvent]() | |||||
| channel := s.openChannel(ch) | |||||
| if channel == nil { | |||||
| ch.Send(1) | |||||
| return ch | |||||
| } | |||||
| retryNum := 0 | |||||
| for { | |||||
| select { | |||||
| case rawReq, ok := <-channel: | |||||
| if !ok { | |||||
| if retryNum > s.config.RetryNum { | |||||
| ch.Send(fmt.Errorf("maximum number of retries exceeded")) | |||||
| ch.Send(1) | |||||
| return ch | |||||
| } | |||||
| retryNum++ | |||||
| time.Sleep(time.Duration(s.config.RetryInterval) * time.Millisecond) | |||||
| channel = s.openChannel(ch) | |||||
| } | |||||
| reqMsg, err := Deserialize(rawReq.Body) | |||||
| if err != nil { | |||||
| ch.Send(NewDeserializeError(err)) | |||||
| continue | |||||
| } | |||||
| go s.handleMessage(ch, reqMsg, rawReq) | |||||
| case <-s.closed: | |||||
| return nil | |||||
| } | |||||
| } | |||||
| } | |||||
| func (s *RabbitMQServer) openChannel(ch *sync2.UnboundChannel[RabbitMQLogEvent]) <-chan amqp.Delivery { | |||||
| _, err := s.channel.QueueDeclare( | _, err := s.channel.QueueDeclare( | ||||
| s.queueName, | s.queueName, | ||||
| false, | false, | ||||
| @@ -117,7 +167,8 @@ func (s *RabbitMQServer) Serve() error { | |||||
| nil, | nil, | ||||
| ) | ) | ||||
| if err != nil { | if err != nil { | ||||
| return fmt.Errorf("declare queue failed, err: %w", err) | |||||
| ch.Send(fmt.Errorf("declare queue failed, err: %w", err)) | |||||
| return nil | |||||
| } | } | ||||
| channel, err := s.channel.Consume( | channel, err := s.channel.Consume( | ||||
| @@ -131,32 +182,14 @@ func (s *RabbitMQServer) Serve() error { | |||||
| ) | ) | ||||
| if err != nil { | if err != nil { | ||||
| return fmt.Errorf("open consume channel failed, err: %w", err) | |||||
| ch.Send(fmt.Errorf("get rabbitmq channel failed, err: %w", err)) | |||||
| return nil | |||||
| } | } | ||||
| for { | |||||
| select { | |||||
| case rawReq, ok := <-channel: | |||||
| if !ok { | |||||
| s.onError(NewDeserializeError(fmt.Errorf("channel is closed"))) | |||||
| return NewReceiveMessageError(fmt.Errorf("channel is closed")) | |||||
| } | |||||
| reqMsg, err := Deserialize(rawReq.Body) | |||||
| if err != nil { | |||||
| s.onError(NewDeserializeError(err)) | |||||
| continue | |||||
| } | |||||
| go s.handleMessage(reqMsg, rawReq) | |||||
| case <-s.closed: | |||||
| return nil | |||||
| } | |||||
| } | |||||
| return channel | |||||
| } | } | ||||
| func (s *RabbitMQServer) handleMessage(reqMsg *Message, rawReq amqp.Delivery) { | |||||
| func (s *RabbitMQServer) handleMessage(ch *sync2.UnboundChannel[RabbitMQLogEvent], reqMsg *Message, rawReq amqp.Delivery) { | |||||
| replyed := make(chan bool) | replyed := make(chan bool) | ||||
| defer close(replyed) | defer close(replyed) | ||||
| @@ -167,7 +200,7 @@ func (s *RabbitMQServer) handleMessage(reqMsg *Message, rawReq amqp.Delivery) { | |||||
| reply, err := s.OnMessage(reqMsg) | reply, err := s.OnMessage(reqMsg) | ||||
| if err != nil { | if err != nil { | ||||
| s.onError(NewDispatchError(err)) | |||||
| ch.Send(NewDispatchError(err)) | |||||
| return | return | ||||
| } | } | ||||
| @@ -175,7 +208,7 @@ func (s *RabbitMQServer) handleMessage(reqMsg *Message, rawReq amqp.Delivery) { | |||||
| reply.SetRequestID(reqMsg.GetRequestID()) | reply.SetRequestID(reqMsg.GetRequestID()) | ||||
| err := s.replyToClient(*reply, &rawReq) | err := s.replyToClient(*reply, &rawReq) | ||||
| if err != nil { | if err != nil { | ||||
| s.onError(NewReplyError(err)) | |||||
| ch.Send(NewReplyError(err)) | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||