diff --git a/common/pkgs/db2/storage.go b/common/pkgs/db2/storage.go index 2ede6b8..187d886 100644 --- a/common/pkgs/db2/storage.go +++ b/common/pkgs/db2/storage.go @@ -3,8 +3,11 @@ package db2 import ( "fmt" + "gitlink.org.cn/cloudream/common/pkgs/logger" cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" + stgmod "gitlink.org.cn/cloudream/storage/common/models" "gitlink.org.cn/cloudream/storage/common/pkgs/db2/model" + "gorm.io/gorm" ) type StorageDB struct { @@ -82,3 +85,67 @@ func (db *StorageDB) GetHubStorages(ctx SQLContext, hubID cdssdk.NodeID) ([]mode err := ctx.Table("Storage").Select("Storage.*").Find(&stgs, "MasterHub = ?", hubID).Error return stgs, err } + +func (db *StorageDB) FillDetails(ctx SQLContext, details []stgmod.StorageDetail) error { + stgsMp := make(map[cdssdk.StorageID]*stgmod.StorageDetail) + stgIDs := make([]cdssdk.StorageID, 0, len(details)) + var masterHubIDs []cdssdk.NodeID + for _, d := range details { + d2 := d + stgsMp[d.Storage.StorageID] = &d2 + stgIDs = append(stgIDs, d.Storage.StorageID) + masterHubIDs = append(masterHubIDs, d.Storage.MasterHub) + } + + // 获取监护Hub信息 + masterHubs, err := db.Node().BatchGetByID(ctx, masterHubIDs) + if err != nil && err != gorm.ErrRecordNotFound { + return fmt.Errorf("getting master hub: %w", err) + } + masterHubMap := make(map[cdssdk.NodeID]cdssdk.Node) + for _, hub := range masterHubs { + masterHubMap[hub.NodeID] = hub + } + for _, stg := range stgsMp { + if stg.Storage.MasterHub != 0 { + hub, ok := masterHubMap[stg.Storage.MasterHub] + if !ok { + logger.Warnf("master hub %v of storage %v not found, this storage will not be add to result", stg.Storage.MasterHub, stg.Storage) + delete(stgsMp, stg.Storage.StorageID) + continue + } + + stg.MasterHub = &hub + } + } + + // 获取分片存储 + shards, err := db.ShardStorage().BatchGetByStorageIDs(ctx, stgIDs) + if err != nil && err != gorm.ErrRecordNotFound { + return fmt.Errorf("getting shard storage: %w", err) + } + for _, shard := range shards { + stg := stgsMp[shard.StorageID] + if stg == nil { + continue + } + + stg.Shard = &shard + } + + // 获取共享存储的相关信息 + shareds, err := db.SharedStorage().BatchGetByStorageIDs(ctx, stgIDs) + if err != nil && err != gorm.ErrRecordNotFound { + return fmt.Errorf("getting shared storage: %w", err) + } + for _, shared := range shareds { + stg := stgsMp[shared.StorageID] + if stg == nil { + continue + } + + stg.Shared = &shared + } + + return nil +} diff --git a/common/pkgs/mq/coordinator/storage.go b/common/pkgs/mq/coordinator/storage.go index c2cbd50..e1a3a97 100644 --- a/common/pkgs/mq/coordinator/storage.go +++ b/common/pkgs/mq/coordinator/storage.go @@ -69,6 +69,19 @@ func RespGetStorageDetails(stgs []*stgmod.StorageDetail) *GetStorageDetailsResp Storages: stgs, } } + +func (r *GetStorageDetailsResp) ToMap() map[cdssdk.StorageID]stgmod.StorageDetail { + m := make(map[cdssdk.StorageID]stgmod.StorageDetail) + for _, stg := range r.Storages { + if stg == nil { + continue + } + + m[stg.Storage.StorageID] = *stg + } + return m +} + func (client *Client) GetStorageDetails(msg *GetStorageDetails) (*GetStorageDetailsResp, error) { return mq.Request(Service.GetStorageDetails, client.rabbitCli, msg) } diff --git a/coordinator/internal/mq/storage.go b/coordinator/internal/mq/storage.go index 7499f03..f903e5d 100644 --- a/coordinator/internal/mq/storage.go +++ b/coordinator/internal/mq/storage.go @@ -4,7 +4,6 @@ import ( "database/sql" "fmt" - "github.com/samber/lo" "gitlink.org.cn/cloudream/common/consts/errorcode" "gitlink.org.cn/cloudream/common/pkgs/logger" cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" @@ -34,62 +33,17 @@ func (svc *Service) GetStorageDetails(msg *coormq.GetStorageDetails) (*coormq.Ge if err != nil && err != gorm.ErrRecordNotFound { return fmt.Errorf("getting storage: %w", err) } - var masterHubIDs []cdssdk.NodeID - for _, stg := range stgs { - stgsMp[stg.StorageID] = &stgmod.StorageDetail{ - Storage: stg, - } - masterHubIDs = append(masterHubIDs, stg.MasterHub) - } - - // 获取监护Hub信息 - masterHubs, err := svc.db2.Node().BatchGetByID(tx, masterHubIDs) - if err != nil && err != gorm.ErrRecordNotFound { - return fmt.Errorf("getting master hub: %w", err) - } - masterHubMap := make(map[cdssdk.NodeID]cdssdk.Node) - for _, hub := range masterHubs { - masterHubMap[hub.NodeID] = hub - } - for _, stg := range stgsMp { - if stg.Storage.MasterHub != 0 { - hub, ok := masterHubMap[stg.Storage.MasterHub] - if !ok { - logger.Warnf("master hub %v of storage %v not found, this storage will not be add to result", stg.Storage.MasterHub, stg.Storage) - delete(stgsMp, stg.Storage.StorageID) - continue - } - - stg.MasterHub = &hub - } - } - // 获取分片存储 - shards, err := svc.db2.ShardStorage().BatchGetByStorageIDs(tx, msg.StorageIDs) - if err != nil && err != gorm.ErrRecordNotFound { - return fmt.Errorf("getting shard storage: %w", err) - } - for _, shard := range shards { - stg := stgsMp[shard.StorageID] - if stg == nil { - continue + details := make([]stgmod.StorageDetail, len(stgs)) + for i, stg := range stgs { + details[i] = stgmod.StorageDetail{ + Storage: stg, } - - stg.Shard = &shard - } - - // 获取共享存储的相关信息 - shareds, err := svc.db2.SharedStorage().BatchGetByStorageIDs(tx, msg.StorageIDs) - if err != nil && err != gorm.ErrRecordNotFound { - return fmt.Errorf("getting shared storage: %w", err) + stgsMp[stg.StorageID] = &details[i] } - for _, shared := range shareds { - stg := stgsMp[shared.StorageID] - if stg == nil { - continue - } - - stg.Shared = &shared + err = svc.db2.Storage().FillDetails(tx, details) + if err != nil { + return err } return nil @@ -109,81 +63,27 @@ func (svc *Service) GetStorageDetails(msg *coormq.GetStorageDetails) (*coormq.Ge } func (svc *Service) GetUserStorageDetails(msg *coormq.GetUserStorageDetails) (*coormq.GetUserStorageDetailsResp, *mq.CodeMessage) { - stgsMp := make(map[cdssdk.StorageID]*stgmod.StorageDetail) + var ret []stgmod.StorageDetail svc.db2.DoTx(func(tx db2.SQLContext) error { stgs, err := svc.db2.Storage().GetUserStorages(tx, msg.UserID) if err != nil && err != gorm.ErrRecordNotFound { return fmt.Errorf("getting user storages: %w", err) } - var masterHubIDs []cdssdk.NodeID + for _, stg := range stgs { - stgsMp[stg.StorageID] = &stgmod.StorageDetail{ + ret = append(ret, stgmod.StorageDetail{ Storage: stg, - } - masterHubIDs = append(masterHubIDs, stg.MasterHub) - } - - // 监护Hub的信息 - masterHubs, err := svc.db2.Node().BatchGetByID(tx, masterHubIDs) - if err != nil && err != gorm.ErrRecordNotFound { - return fmt.Errorf("getting master hub: %w", err) - } - masterHubMap := make(map[cdssdk.NodeID]cdssdk.Node) - for _, hub := range masterHubs { - masterHubMap[hub.NodeID] = hub - } - for _, stg := range stgsMp { - if stg.Storage.MasterHub != 0 { - hub, ok := masterHubMap[stg.Storage.MasterHub] - if !ok { - logger.Warnf("master hub %v of storage %v not found, this storage will not be add to result", stg.Storage.MasterHub, stg.Storage) - delete(stgsMp, stg.Storage.StorageID) - continue - } - - stg.MasterHub = &hub - } - } - - stgIDs := lo.Map(stgs, func(stg cdssdk.Storage, i int) cdssdk.StorageID { return stg.StorageID }) - - // 获取分片存储信息 - shards, err := svc.db2.ShardStorage().BatchGetByStorageIDs(tx, stgIDs) - if err != nil && err != gorm.ErrRecordNotFound { - return fmt.Errorf("getting shard storage: %w", err) + }) } - for _, shard := range shards { - stg := stgsMp[shard.StorageID] - if stg == nil { - continue - } - - stg.Shard = &shard - } - - // 获取共享存储的相关信息 - shareds, err := svc.db2.SharedStorage().BatchGetByStorageIDs(tx, stgIDs) - if err != nil && err != gorm.ErrRecordNotFound { - return fmt.Errorf("getting shared storage: %w", err) - } - for _, shared := range shareds { - stg := stgsMp[shared.StorageID] - if stg == nil { - continue - } - - stg.Shared = &shared + err = svc.db2.Storage().FillDetails(tx, ret) + if err != nil { + return err } return nil }) - var ret []stgmod.StorageDetail - for _, id := range stgsMp { - ret = append(ret, *id) - } - return mq.ReplyOK(coormq.RespGetUserStorageDetails(ret)) }