|
- package mq
-
- import (
- "database/sql"
- "fmt"
-
- "github.com/jmoiron/sqlx"
- "gitlink.org.cn/cloudream/common/consts/errorcode"
- "gitlink.org.cn/cloudream/common/pkgs/logger"
- "gitlink.org.cn/cloudream/common/pkgs/mq"
- cdssdk "gitlink.org.cn/cloudream/common/sdks/storage"
- coormq "gitlink.org.cn/cloudream/storage/common/pkgs/mq/coordinator"
- )
-
- func (svc *Service) GetUserNodes(msg *coormq.GetUserNodes) (*coormq.GetUserNodesResp, *mq.CodeMessage) {
- nodes, err := svc.db.Node().GetUserNodes(svc.db.SQLCtx(), msg.UserID)
- if err != nil {
- logger.WithField("UserID", msg.UserID).
- Warnf("query user nodes failed, err: %s", err.Error())
- return nil, mq.Failed(errorcode.OperationFailed, "query user nodes failed")
- }
-
- return mq.ReplyOK(coormq.NewGetUserNodesResp(nodes))
- }
-
- func (svc *Service) GetNodes(msg *coormq.GetNodes) (*coormq.GetNodesResp, *mq.CodeMessage) {
- var nodes []cdssdk.Node
-
- if msg.NodeIDs == nil {
- var err error
- nodes, err = svc.db2.Node().GetAllNodes(svc.db2.DefCtx())
- if err != nil {
- logger.Warnf("getting all nodes: %s", err.Error())
- return nil, mq.Failed(errorcode.OperationFailed, "get all node failed")
- }
-
- } else {
- // 可以不用事务
- for _, id := range msg.NodeIDs {
- node, err := svc.db2.Node().GetByID(svc.db2.DefCtx(), id)
- if err != nil {
- logger.WithField("NodeID", id).
- Warnf("query node failed, err: %s", err.Error())
- return nil, mq.Failed(errorcode.OperationFailed, "query node failed")
- }
-
- nodes = append(nodes, node)
- }
- }
-
- return mq.ReplyOK(coormq.NewGetNodesResp(nodes))
- }
-
- func (svc *Service) GetNodeConnectivities(msg *coormq.GetNodeConnectivities) (*coormq.GetNodeConnectivitiesResp, *mq.CodeMessage) {
- cons, err := svc.db.NodeConnectivity().BatchGetByFromNode(svc.db.SQLCtx(), msg.NodeIDs)
- if err != nil {
- logger.Warnf("batch get node connectivities by from node: %s", err.Error())
- return nil, mq.Failed(errorcode.OperationFailed, "batch get node connectivities by from node failed")
- }
-
- return mq.ReplyOK(coormq.RespGetNodeConnectivities(cons))
- }
-
- func (svc *Service) UpdateNodeConnectivities(msg *coormq.UpdateNodeConnectivities) (*coormq.UpdateNodeConnectivitiesResp, *mq.CodeMessage) {
- err := svc.db.DoTx(sql.LevelSerializable, func(tx *sqlx.Tx) error {
- // 只有发起节点和目的节点都存在,才能插入这条记录到数据库
- allNodes, err := svc.db.Node().GetAllNodes(tx)
- if err != nil {
- return fmt.Errorf("getting all nodes: %w", err)
- }
-
- allNodeID := make(map[cdssdk.NodeID]bool)
- for _, node := range allNodes {
- allNodeID[node.NodeID] = true
- }
-
- var avaiCons []cdssdk.NodeConnectivity
- for _, con := range msg.Connectivities {
- if allNodeID[con.FromNodeID] && allNodeID[con.ToNodeID] {
- avaiCons = append(avaiCons, con)
- }
- }
-
- err = svc.db.NodeConnectivity().BatchUpdateOrCreate(tx, avaiCons)
- if err != nil {
- return fmt.Errorf("batch update or create node connectivities: %s", err)
- }
-
- return nil
- })
- if err != nil {
- logger.Warn(err.Error())
- return nil, mq.Failed(errorcode.OperationFailed, err.Error())
- }
-
- return mq.ReplyOK(coormq.RespUpdateNodeConnectivities())
- }
|