diff --git a/common/models/models.go b/common/models/models.go index 0031093..a86afbe 100644 --- a/common/models/models.go +++ b/common/models/models.go @@ -10,7 +10,7 @@ type ObjectBlock struct { ObjectID cdssdk.ObjectID `gorm:"column:ObjectID; primaryKey; type:bigint" json:"objectID"` Index int `gorm:"column:Index; primaryKey; type:int" json:"index"` StorageID cdssdk.StorageID `gorm:"column:StorageID; primaryKey; type:bigint" json:"storageID"` // 这个块应该在哪个节点上 - FileHash cdssdk.FileHash `gorm:"column:FileHash; type:char(64); not null" json:"fileHash"` + FileHash cdssdk.FileHash `gorm:"column:FileHash; type:char(68); not null" json:"fileHash"` } func (ObjectBlock) TableName() string { diff --git a/common/pkgs/db2/model/model.go b/common/pkgs/db2/model/model.go index 187cb37..27badf9 100644 --- a/common/pkgs/db2/model/model.go +++ b/common/pkgs/db2/model/model.go @@ -57,7 +57,7 @@ type HubConnectivity = cdssdk.HubConnectivity type ObjectBlock = stgmod.ObjectBlock type Cache struct { - FileHash cdssdk.FileHash `gorm:"column:FileHash; primaryKey; type: char(64)" json:"fileHash"` + FileHash cdssdk.FileHash `gorm:"column:FileHash; primaryKey; type: char(68)" json:"fileHash"` StorageID cdssdk.StorageID `gorm:"column:StorageID; primaryKey; type: bigint" json:"storageID"` CreateTime time.Time `gorm:"column:CreateTime; type:datetime" json:"createTime"` Priority int `gorm:"column:Priority; type:int" json:"priority"` diff --git a/common/pkgs/storage/local/multipart_upload.go b/common/pkgs/storage/local/multipart_upload.go index 1f87041..50eeae7 100644 --- a/common/pkgs/storage/local/multipart_upload.go +++ b/common/pkgs/storage/local/multipart_upload.go @@ -3,13 +3,11 @@ package local import ( "context" "crypto/sha256" - "encoding/hex" "fmt" "hash" "io" "os" "path/filepath" - "strings" cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" "gitlink.org.cn/cloudream/common/utils/io2" @@ -68,7 +66,7 @@ func (i *MultipartInitiator) JoinParts(ctx context.Context, parts []types.Upload return types.BypassFileInfo{ TempFilePath: joined.Name(), Size: size, - FileHash: cdssdk.FileHash(strings.ToUpper(hex.EncodeToString(h))), + FileHash: cdssdk.NewFullHash(h), }, nil } diff --git a/common/pkgs/storage/local/shard_store.go b/common/pkgs/storage/local/shard_store.go index 4ec0059..6afa855 100644 --- a/common/pkgs/storage/local/shard_store.go +++ b/common/pkgs/storage/local/shard_store.go @@ -2,14 +2,12 @@ package local import ( "crypto/sha256" - "encoding/hex" "errors" "fmt" "io" "io/fs" "os" "path/filepath" - "strings" "sync" "time" @@ -177,7 +175,7 @@ func (s *ShardStore) writeTempFile(file *os.File, stream io.Reader) (int64, cdss } h := hasher.Sum(nil) - return size, cdssdk.FileHash(strings.ToUpper(hex.EncodeToString(h))), nil + return size, cdssdk.NewFullHash(h), nil } func (s *ShardStore) onCreateFinished(tempFilePath string, size int64, hash cdssdk.FileHash) (types.FileInfo, error) { @@ -243,12 +241,7 @@ func (s *ShardStore) Open(opt types.OpenOption) (io.ReadCloser, error) { s.lock.Lock() defer s.lock.Unlock() - fileName := string(opt.FileHash) - if len(fileName) < 2 { - return nil, fmt.Errorf("invalid file name") - } - - filePath := s.getFilePathFromHash(cdssdk.FileHash(fileName)) + filePath := s.getFilePathFromHash(opt.FileHash) file, err := os.Open(filePath) if err != nil { return nil, err @@ -306,10 +299,14 @@ func (s *ShardStore) ListAll() ([]types.FileInfo, error) { if err != nil { return err } - // TODO 简单检查一下文件名是否合法 + + fileHash, err := cdssdk.ParseHash(filepath.Base(info.Name())) + if err != nil { + return nil + } infos = append(infos, types.FileInfo{ - Hash: cdssdk.FileHash(filepath.Base(info.Name())), + Hash: fileHash, Size: info.Size(), Description: filepath.Join(blockDir, path), }) @@ -348,7 +345,11 @@ func (s *ShardStore) GC(avaiables []cdssdk.FileHash) error { return err } - fileHash := cdssdk.FileHash(filepath.Base(info.Name())) + fileHash, err := cdssdk.ParseHash(filepath.Base(info.Name())) + if err != nil { + return nil + } + if !avais[fileHash] { err = os.Remove(path) if err != nil { @@ -378,10 +379,6 @@ func (s *ShardStore) Stats() types.Stats { } func (s *ShardStore) BypassUploaded(info types.BypassFileInfo) error { - if info.FileHash == "" { - return fmt.Errorf("empty file hash is not allowed by this shard store") - } - s.lock.Lock() defer s.lock.Unlock() @@ -418,9 +415,9 @@ func (s *ShardStore) getLogger() logger.Logger { } func (s *ShardStore) getFileDirFromHash(hash cdssdk.FileHash) string { - return filepath.Join(s.absRoot, BlocksDir, string(hash)[:2]) + return filepath.Join(s.absRoot, BlocksDir, hash.GetHashPrefix(2)) } func (s *ShardStore) getFilePathFromHash(hash cdssdk.FileHash) string { - return filepath.Join(s.absRoot, BlocksDir, string(hash)[:2], string(hash)) + return filepath.Join(s.absRoot, BlocksDir, hash.GetHashPrefix(2), string(hash)) } diff --git a/common/pkgs/storage/s3/multipart_upload.go b/common/pkgs/storage/s3/multipart_upload.go index e4c99aa..e6e371d 100644 --- a/common/pkgs/storage/s3/multipart_upload.go +++ b/common/pkgs/storage/s3/multipart_upload.go @@ -2,6 +2,7 @@ package s3 import ( "context" + "crypto/sha256" "io" "path/filepath" @@ -9,7 +10,9 @@ import ( "github.com/aws/aws-sdk-go-v2/service/s3" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" + "gitlink.org.cn/cloudream/common/utils/io2" "gitlink.org.cn/cloudream/common/utils/os2" + "gitlink.org.cn/cloudream/common/utils/sort2" "gitlink.org.cn/cloudream/storage/common/pkgs/storage/types" ) @@ -45,6 +48,10 @@ func (i *MultipartInitiator) Initiate(ctx context.Context) (types.MultipartInitS } func (i *MultipartInitiator) JoinParts(ctx context.Context, parts []types.UploadedPartInfo) (types.BypassFileInfo, error) { + parts = sort2.Sort(parts, func(l, r types.UploadedPartInfo) int { + return l.PartNumber - r.PartNumber + }) + s3Parts := make([]s3types.CompletedPart, len(parts)) for i, part := range parts { s3Parts[i] = s3types.CompletedPart{ @@ -52,8 +59,12 @@ func (i *MultipartInitiator) JoinParts(ctx context.Context, parts []types.Upload PartNumber: aws.Int32(int32(part.PartNumber)), } } + partHashes := make([][]byte, len(parts)) + for i, part := range parts { + partHashes[i] = part.PartHash + } - compResp, err := i.cli.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ + _, err := i.cli.CompleteMultipartUpload(ctx, &s3.CompleteMultipartUploadInput{ Bucket: aws.String(i.bucket), Key: aws.String(i.tempFilePath), UploadId: aws.String(i.uploadID), @@ -73,17 +84,7 @@ func (i *MultipartInitiator) JoinParts(ctx context.Context, parts []types.Upload return types.BypassFileInfo{}, err } - var hash cdssdk.FileHash - // if compResp.ChecksumSHA256 == nil { - // hash = "4D142C458F2399175232D5636235B09A84664D60869E925EB20FFBE931045BDD" - // } else { - // } - // TODO2 这里其实是单独上传的每一个分片的SHA256按顺序组成一个新字符串后,再计算得到的SHA256,不是完整文件的SHA256。 - // 这种Hash考虑使用特殊的格式来区分 - hash, err = DecodeBase64Hash(*compResp.ChecksumSHA256) - if err != nil { - return types.BypassFileInfo{}, err - } + hash := cdssdk.CalculateCompositeHash(partHashes) return types.BypassFileInfo{ TempFilePath: i.tempFilePath, @@ -117,12 +118,13 @@ type MultipartUploader struct { } func (u *MultipartUploader) UploadPart(ctx context.Context, init types.MultipartInitState, partSize int64, partNumber int, stream io.Reader) (types.UploadedPartInfo, error) { + hashStr := io2.NewReadHasher(sha256.New(), stream) resp, err := u.cli.UploadPart(ctx, &s3.UploadPartInput{ Bucket: aws.String(init.Bucket), Key: aws.String(init.Key), UploadId: aws.String(init.UploadID), PartNumber: aws.Int32(int32(partNumber)), - Body: stream, + Body: hashStr, }) if err != nil { return types.UploadedPartInfo{}, err @@ -131,6 +133,7 @@ func (u *MultipartUploader) UploadPart(ctx context.Context, init types.Multipart return types.UploadedPartInfo{ ETag: *resp.ETag, PartNumber: partNumber, + PartHash: hashStr.Sum(), }, nil } diff --git a/common/pkgs/storage/s3/s3.go b/common/pkgs/storage/s3/s3.go index e7e8f13..aed8506 100644 --- a/common/pkgs/storage/s3/s3.go +++ b/common/pkgs/storage/s3/s3.go @@ -37,7 +37,10 @@ func createService(detail stgmod.StorageDetail) (types.StorageService, error) { return nil, err } - store, err := NewShardStore(svc, cli, bkt, *cfg) + store, err := NewShardStore(svc, cli, bkt, *cfg, ShardStoreOption{ + // 目前对接的存储服务都不支持从上传接口直接获取到Sha256 + UseAWSSha256: false, + }) if err != nil { return nil, err } diff --git a/common/pkgs/storage/s3/shard_store.go b/common/pkgs/storage/s3/shard_store.go index 38aa271..c802c65 100644 --- a/common/pkgs/storage/s3/shard_store.go +++ b/common/pkgs/storage/s3/shard_store.go @@ -2,6 +2,7 @@ package s3 import ( "context" + "crypto/sha256" "errors" "fmt" "io" @@ -24,22 +25,28 @@ const ( BlocksDir = "blocks" ) +type ShardStoreOption struct { + UseAWSSha256 bool // 能否直接使用AWS提供的SHA256校验,如果不行,则使用本地计算。默认使用本地计算。 +} + type ShardStore struct { svc *Service cli *s3.Client bucket string cfg cdssdk.S3ShardStorage + opt ShardStoreOption lock sync.Mutex workingTempFiles map[string]bool done chan any } -func NewShardStore(svc *Service, cli *s3.Client, bkt string, cfg cdssdk.S3ShardStorage) (*ShardStore, error) { +func NewShardStore(svc *Service, cli *s3.Client, bkt string, cfg cdssdk.S3ShardStorage, opt ShardStoreOption) (*ShardStore, error) { return &ShardStore{ svc: svc, cli: cli, bucket: bkt, cfg: cfg, + opt: opt, workingTempFiles: make(map[string]bool), done: make(chan any, 1), }, nil @@ -135,6 +142,14 @@ func (s *ShardStore) Stop() { } func (s *ShardStore) Create(stream io.Reader) (types.FileInfo, error) { + if s.opt.UseAWSSha256 { + return s.createWithAwsSha256(stream) + } else { + return s.createWithCalcSha256(stream) + } +} + +func (s *ShardStore) createWithAwsSha256(stream io.Reader) (types.FileInfo, error) { log := s.getLogger() key, fileName := s.createTempFile() @@ -170,7 +185,34 @@ func (s *ShardStore) Create(stream io.Reader) (types.FileInfo, error) { return types.FileInfo{}, fmt.Errorf("decode SHA256 checksum: %v", err) } - return s.onCreateFinished(key, counter.Count(), hash) + return s.onCreateFinished(key, counter.Count(), cdssdk.NewFullHash(hash)) +} + +func (s *ShardStore) createWithCalcSha256(stream io.Reader) (types.FileInfo, error) { + log := s.getLogger() + + key, fileName := s.createTempFile() + + hashStr := io2.NewReadHasher(sha256.New(), stream) + counter := io2.NewCounter(hashStr) + + _, err := s.cli.PutObject(context.TODO(), &s3.PutObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(key), + Body: counter, + ChecksumAlgorithm: s3types.ChecksumAlgorithmSha256, + }) + if err != nil { + log.Warnf("uploading file %v: %v", key, err) + + s.lock.Lock() + defer s.lock.Unlock() + + delete(s.workingTempFiles, fileName) + return types.FileInfo{}, err + } + + return s.onCreateFinished(key, counter.Count(), cdssdk.NewFullHash(hashStr.Sum())) } func (s *ShardStore) createTempFile() (string, string) { @@ -238,12 +280,7 @@ func (s *ShardStore) Open(opt types.OpenOption) (io.ReadCloser, error) { s.lock.Lock() defer s.lock.Unlock() - fileName := string(opt.FileHash) - if len(fileName) < 2 { - return nil, fmt.Errorf("invalid file name") - } - - filePath := s.getFilePathFromHash(cdssdk.FileHash(fileName)) + filePath := s.getFilePathFromHash(opt.FileHash) rngStr := fmt.Sprintf("bytes=%d-", opt.Offset) if opt.Length >= 0 { @@ -307,12 +344,14 @@ func (s *ShardStore) ListAll() ([]types.FileInfo, error) { for _, obj := range resp.Contents { key := BaseKey(*obj.Key) - if len(key) != 64 { + + fileHash, err := cdssdk.ParseHash(key) + if err != nil { continue } infos = append(infos, types.FileInfo{ - Hash: cdssdk.FileHash(key), + Hash: fileHash, Size: *obj.Size, Description: *obj.Key, }) @@ -355,11 +394,12 @@ func (s *ShardStore) GC(avaiables []cdssdk.FileHash) error { for _, obj := range resp.Contents { key := BaseKey(*obj.Key) - if len(key) != 64 { + fileHash, err := cdssdk.ParseHash(key) + if err != nil { continue } - if !avais[cdssdk.FileHash(key)] { + if !avais[fileHash] { deletes = append(deletes, s3types.ObjectIdentifier{ Key: obj.Key, }) @@ -441,9 +481,9 @@ func (s *ShardStore) getLogger() logger.Logger { } func (s *ShardStore) getFileDirFromHash(hash cdssdk.FileHash) string { - return JoinKey(s.cfg.Root, BlocksDir, string(hash)[:2]) + return JoinKey(s.cfg.Root, BlocksDir, hash.GetHashPrefix(2)) } func (s *ShardStore) getFilePathFromHash(hash cdssdk.FileHash) string { - return JoinKey(s.cfg.Root, BlocksDir, string(hash)[:2], string(hash)) + return JoinKey(s.cfg.Root, BlocksDir, hash.GetHashPrefix(2), string(hash)) } diff --git a/common/pkgs/storage/s3/utils.go b/common/pkgs/storage/s3/utils.go index 8d21dc9..f17b2ac 100644 --- a/common/pkgs/storage/s3/utils.go +++ b/common/pkgs/storage/s3/utils.go @@ -4,8 +4,6 @@ import ( "encoding/base64" "fmt" "strings" - - cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" ) func JoinKey(comps ...string) string { @@ -27,15 +25,15 @@ func BaseKey(key string) string { return key[strings.LastIndex(key, "/")+1:] } -func DecodeBase64Hash(hash string) (cdssdk.FileHash, error) { +func DecodeBase64Hash(hash string) ([]byte, error) { hashBytes := make([]byte, 32) n, err := base64.RawStdEncoding.Decode(hashBytes, []byte(hash)) if err != nil { - return "", err + return nil, err } if n != 32 { - return "", fmt.Errorf("invalid hash length: %d", n) + return nil, fmt.Errorf("invalid hash length: %d", n) } - return cdssdk.FileHash(strings.ToUpper(string(hashBytes))), nil + return hashBytes, nil } diff --git a/common/pkgs/storage/types/s3_client.go b/common/pkgs/storage/types/s3_client.go index 03fb622..9514d24 100644 --- a/common/pkgs/storage/types/s3_client.go +++ b/common/pkgs/storage/types/s3_client.go @@ -21,14 +21,15 @@ type MultipartUploader interface { Close() } -// TODO 重构成一个接口,支持不同的类型的分片有不同内容的实现 +// TODO 可以考虑重构成一个接口,支持不同的类型的分片有不同内容的实现 type MultipartInitState struct { UploadID string - Bucket string // TODO 临时使用 - Key string // TODO 临时使用 + Bucket string + Key string } type UploadedPartInfo struct { - PartNumber int ETag string + PartNumber int + PartHash []byte }