diff --git a/consts/errorcode/error_code.go b/consts/errorcode/error_code.go index fc9e00c..984ca9c 100644 --- a/consts/errorcode/error_code.go +++ b/consts/errorcode/error_code.go @@ -4,6 +4,7 @@ const ( OK = "OK" OperationFailed = "OperationFailed" DataNotFound = "DataNotFound" + DataExists = "DataExists" BadArgument = "BadArgument" TaskNotFound = "TaskNotFound" ) diff --git a/main.go b/main.go index c509ea1..e154996 100644 --- a/main.go +++ b/main.go @@ -32,19 +32,12 @@ func test1(url string) { return } - partLen, err := strconv.ParseInt(os.Args[3], 10, 64) - if err != nil { - fmt.Println(err) - return - } - startTime := time.Now() obj, err := cli.Object().Download(cdsapi.ObjectDownload{ UserID: 1, ObjectID: 470790, Offset: 0, Length: &openLen, - PartSize: partLen, }) if err != nil { fmt.Println(err) @@ -75,7 +68,6 @@ func test2(url string) { UserID: 1, ObjectID: 27151, Offset: 0, - PartSize: 100000000, // Length: &openLen, }) diff --git a/pkgs/ioswitch/dag/node.go b/pkgs/ioswitch/dag/node.go index 45f7f25..60f8ea9 100644 --- a/pkgs/ioswitch/dag/node.go +++ b/pkgs/ioswitch/dag/node.go @@ -35,6 +35,11 @@ func (e *NodeEnv) ToEnvWorker(worker exec.WorkerInfo) { e.Worker = worker } +func (e *NodeEnv) CopyFrom(other *NodeEnv) { + e.Type = other.Type + e.Worker = other.Worker +} + func (e *NodeEnv) Equals(other *NodeEnv) bool { if e.Type != other.Type { return false @@ -239,6 +244,10 @@ func (s *ValueInputSlots) GetVarIDs() []exec.VarID { return ids } +func (s *ValueInputSlots) GetVarIDsStart(start int) []exec.VarID { + return s.GetVarIDsRanged(start, s.Len()) +} + func (s *ValueInputSlots) GetVarIDsRanged(start, end int) []exec.VarID { var ids []exec.VarID for i := start; i < end; i++ { @@ -280,11 +289,11 @@ func (s *StreamOutputSlots) Init(my Node, size int) { } // 在Slots末尾增加一个StreamVar,并返回它的索引 -func (s *StreamOutputSlots) AppendNew(my Node) StreamSlot { +func (s *StreamOutputSlots) AppendNew(my Node) StreamOutputSlot { v := my.Graph().NewStreamVar() v.Src = my 
s.Slots.Append(v) - return StreamSlot{Var: v, Index: s.Len() - 1} + return StreamOutputSlot{Node: my, Index: s.Len() - 1} } // 断开指定位置的输出流到指定节点的连接 @@ -355,11 +364,11 @@ func (s *ValueOutputSlots) Init(my Node, size int) { } // 在Slots末尾增加一个StreamVar,并返回它的索引 -func (s *ValueOutputSlots) AppendNew(my Node) ValueSlot { +func (s *ValueOutputSlots) AppendNew(my Node) ValueOutputSlot { v := my.Graph().NewValueVar() v.Src = my s.Slots.Append(v) - return ValueSlot{Var: v, Index: s.Len() - 1} + return ValueOutputSlot{Node: my, Index: s.Len() - 1} } // 断开指定位置的输出流到指定节点的连接 @@ -402,16 +411,6 @@ func (s *ValueOutputSlots) GetVarIDsRanged(start, end int) []exec.VarID { return ids } -type StreamSlot struct { - Var *StreamVar - Index int -} - -type ValueSlot struct { - Var *ValueVar - Index int -} - type NodeBase struct { env NodeEnv inputStreams StreamInputSlots @@ -448,3 +447,39 @@ func (n *NodeBase) InputValues() *ValueInputSlots { func (n *NodeBase) OutputValues() *ValueOutputSlots { return &n.outputValues } + +type StreamOutputSlot struct { + Node Node + Index int +} + +func (s StreamOutputSlot) Var() *StreamVar { + return s.Node.OutputStreams().Get(s.Index) +} + +type StreamInputSlot struct { + Node Node + Index int +} + +func (s StreamInputSlot) Var() *StreamVar { + return s.Node.InputStreams().Get(s.Index) +} + +type ValueOutputSlot struct { + Node Node + Index int +} + +func (s ValueOutputSlot) Var() *ValueVar { + return s.Node.OutputValues().Get(s.Index) +} + +type ValueInputSlot struct { + Node Node + Index int +} + +func (s ValueInputSlot) Var() *ValueVar { + return s.Node.InputValues().Get(s.Index) +} diff --git a/pkgs/ioswitch/dag/var.go b/pkgs/ioswitch/dag/var.go index 75e99f2..1a2f49c 100644 --- a/pkgs/ioswitch/dag/var.go +++ b/pkgs/ioswitch/dag/var.go @@ -5,7 +5,7 @@ import ( "gitlink.org.cn/cloudream/common/utils/lo2" ) -type Var2 interface { +type Var interface { GetVarID() exec.VarID } @@ -28,6 +28,11 @@ func (v *StreamVar) To(to Node, slotIdx int) { 
to.InputStreams().Slots.Set(slotIdx, v) } +func (v *StreamVar) ToSlot(slot StreamInputSlot) { + v.Dst.Add(slot.Node) + slot.Node.InputStreams().Slots.Set(slot.Index, v) +} + func (v *StreamVar) NotTo(node Node) { v.Dst.Remove(node) node.InputStreams().Slots.Clear(v) @@ -59,6 +64,11 @@ func (v *ValueVar) To(to Node, slotIdx int) { to.InputValues().Slots.Set(slotIdx, v) } +func (v *ValueVar) ToSlot(slot ValueInputSlot) { + v.Dst.Add(slot.Node) + slot.Node.InputValues().Slots.Set(slot.Index, v) +} + func (v *ValueVar) NotTo(node Node) { v.Dst.Remove(node) node.InputValues().Slots.Clear(v) diff --git a/pkgs/ioswitch/exec/utils.go b/pkgs/ioswitch/exec/utils.go index 09c48d3..64b00c9 100644 --- a/pkgs/ioswitch/exec/utils.go +++ b/pkgs/ioswitch/exec/utils.go @@ -79,3 +79,19 @@ func (r *Range) ClampLength(maxLen int64) { *r.Length = math2.Min(*r.Length, maxLen-r.Offset) } + +func (r *Range) Equals(other Range) bool { + if r.Offset != other.Offset { + return false + } + + if r.Length == nil && other.Length == nil { + return true + } + + if r.Length == nil || other.Length == nil { + return false + } + + return *r.Length == *other.Length +} diff --git a/pkgs/ioswitch/plan/ops/driver.go b/pkgs/ioswitch/plan/ops/driver.go index a15fbee..da3c62f 100644 --- a/pkgs/ioswitch/plan/ops/driver.go +++ b/pkgs/ioswitch/plan/ops/driver.go @@ -21,9 +21,9 @@ func (b *GraphNodeBuilder) NewFromDriver(handle *exec.DriverWriteStream) *FromDr return node } -func (t *FromDriverNode) Output() dag.StreamSlot { - return dag.StreamSlot{ - Var: t.OutputStreams().Get(0), +func (t *FromDriverNode) Output() dag.StreamOutputSlot { + return dag.StreamOutputSlot{ + Node: t, Index: 0, } } @@ -57,9 +57,9 @@ func (t *ToDriverNode) SetInput(v *dag.StreamVar) { v.To(t, 0) } -func (t *ToDriverNode) Input() dag.StreamSlot { - return dag.StreamSlot{ - Var: t.InputStreams().Get(0), +func (t *ToDriverNode) Input() dag.StreamInputSlot { + return dag.StreamInputSlot{ + Node: t, Index: 0, } } diff --git 
a/pkgs/ioswitch/plan/ops/sync.go b/pkgs/ioswitch/plan/ops/sync.go index a356bb1..c4ec53c 100644 --- a/pkgs/ioswitch/plan/ops/sync.go +++ b/pkgs/ioswitch/plan/ops/sync.go @@ -175,12 +175,12 @@ func (t *HoldUntilNode) SetSignal(s *dag.ValueVar) { func (t *HoldUntilNode) HoldStream(str *dag.StreamVar) *dag.StreamVar { str.To(t, t.InputStreams().EnlargeOne()) - return t.OutputStreams().AppendNew(t).Var + return t.OutputStreams().AppendNew(t).Var() } func (t *HoldUntilNode) HoldVar(v *dag.ValueVar) *dag.ValueVar { v.To(t, t.InputValues().EnlargeOne()) - return t.OutputValues().AppendNew(t).Var + return t.OutputValues().AppendNew(t).Var() } func (t *HoldUntilNode) GenerateOp() (exec.Op, error) { diff --git a/pkgs/trie/trie.go b/pkgs/trie/trie.go index 9fd84ba..829a49f 100644 --- a/pkgs/trie/trie.go +++ b/pkgs/trie/trie.go @@ -4,7 +4,17 @@ const ( WORD_ANY = 0 ) +type VisitCtrl int + +const ( + VisitContinue = 0 + VisitBreak = 1 + VisitSkip = 2 +) + type Node[T any] struct { + Word any + Parent *Node[T] WordNexts map[string]*Node[T] AnyNext *Node[T] Value T @@ -43,7 +53,10 @@ func (n *Node[T]) Create(word string) *Node[T] { node, ok := n.WordNexts[word] if !ok { - node = &Node[T]{} + node = &Node[T]{ + Word: word, + Parent: n, + } n.WordNexts[word] = node } @@ -52,16 +65,81 @@ func (n *Node[T]) Create(word string) *Node[T] { func (n *Node[T]) CreateAny() *Node[T] { if n.AnyNext == nil { - n.AnyNext = &Node[T]{} + n.AnyNext = &Node[T]{ + Word: WORD_ANY, + Parent: n, + } } return n.AnyNext } +func (n *Node[T]) IsEmpty() bool { + return len(n.WordNexts) == 0 && n.AnyNext == nil +} + +// 将自己从树中移除。如果cleanParent为true,则会一直向上清除所有没有子节点的节点 +func (n *Node[T]) RemoveSelf(cleanParent bool) { + if n.Parent == nil { + return + } + + if n.Word == WORD_ANY { + if n.Parent.AnyNext == n { + n.Parent.AnyNext = nil + } + } else if n.Parent.WordNexts != nil && n.Parent.WordNexts[n.Word.(string)] == n { + delete(n.Parent.WordNexts, n.Word.(string)) + } + + if cleanParent { + if 
n.Parent.IsEmpty() { + n.Parent.RemoveSelf(true) + } + } + + n.Parent = nil +} + +// 修改时需要注意允许在visitorFn中删除当前节点 +func (n *Node[T]) Iterate(visitorFn func(word string, node *Node[T], isWordNode bool) VisitCtrl) { + if n.WordNexts != nil { + for word, node := range n.WordNexts { + ret := visitorFn(word, node, true) + if ret == VisitBreak { + return + } + + if ret == VisitSkip { + continue + } + + node.Iterate(visitorFn) + } + } + + if n.AnyNext != nil { + ret := visitorFn("", n.AnyNext, false) + if ret == VisitBreak { + return + } + + if ret == VisitSkip { + return + } + + n.AnyNext.Iterate(visitorFn) + } +} + type Trie[T any] struct { Root Node[T] } +func NewTrie[T any]() *Trie[T] { + return &Trie[T]{} +} + func (t *Trie[T]) Walk(words []string, visitorFn func(word string, wordIndex int, node *Node[T], isWordNode bool)) bool { ptr := &t.Root @@ -109,3 +187,17 @@ func (t *Trie[T]) Create(words []any) *Node[T] { return ptr } + +func (t *Trie[T]) CreateWords(words []string) *Node[T] { + ptr := &t.Root + + for _, word := range words { + ptr = ptr.Create(word) + } + + return ptr +} + +func (n *Trie[T]) Iterate(visitorFn func(word string, node *Node[T], isWordNode bool) VisitCtrl) { + n.Root.Iterate(visitorFn) +} diff --git a/sdks/storage/cdsapi/object.go b/sdks/storage/cdsapi/object.go index f0743f6..32d7243 100644 --- a/sdks/storage/cdsapi/object.go +++ b/sdks/storage/cdsapi/object.go @@ -25,6 +25,43 @@ func (c *Client) Object() *ObjectService { } } +const ObjectListPath = "/object/list" + +type ObjectList struct { + UserID cdssdk.UserID `form:"userID" binding:"required"` + PackageID cdssdk.PackageID `form:"packageID" binding:"required"` + Path string `form:"path"` // 允许为空字符串 + IsPrefix bool `form:"isPrefix"` +} +type ObjectListResp struct { + Objects []cdssdk.Object `json:"objects"` +} + +func (c *ObjectService) List(req ObjectList) (*ObjectListResp, error) { + url, err := url.JoinPath(c.baseURL, ObjectListPath) + if err != nil { + return nil, err + } + + resp, err := 
http2.GetForm(url, http2.RequestParam{ + Query: req, + }) + if err != nil { + return nil, err + } + + jsonResp, err := ParseJSONResponse[response[ObjectListResp]](resp) + if err != nil { + return nil, err + } + + if jsonResp.Code == errorcode.OK { + return &jsonResp.Data, nil + } + + return nil, jsonResp.ToError() +} + const ObjectUploadPath = "/object/upload" type ObjectUpload struct { @@ -101,7 +138,6 @@ type ObjectDownload struct { ObjectID cdssdk.ObjectID `form:"objectID" json:"objectID" binding:"required"` Offset int64 `form:"offset" json:"offset,omitempty"` Length *int64 `form:"length" json:"length,omitempty"` - PartSize int64 `form:"partSize" json:"partSize,omitempty"` } type DownloadingObject struct { Path string @@ -143,6 +179,51 @@ func (c *ObjectService) Download(req ObjectDownload) (*DownloadingObject, error) }, nil } +const ObjectDownloadByPathPath = "/object/downloadByPath" + +type ObjectDownloadByPath struct { + UserID cdssdk.UserID `form:"userID" json:"userID" binding:"required"` + PackageID cdssdk.PackageID `form:"packageID" json:"packageID" binding:"required"` + Path string `form:"path" json:"path" binding:"required"` + Offset int64 `form:"offset" json:"offset,omitempty"` + Length *int64 `form:"length" json:"length,omitempty"` +} + +func (c *ObjectService) DownloadByPath(req ObjectDownloadByPath) (*DownloadingObject, error) { + url, err := url.JoinPath(c.baseURL, ObjectDownloadByPathPath) + if err != nil { + return nil, err + } + + resp, err := http2.GetJSON(url, http2.RequestParam{ + Query: req, + }) + if err != nil { + return nil, err + } + + contType := resp.Header.Get("Content-Type") + + if strings.Contains(contType, http2.ContentTypeJSON) { + var codeResp response[any] + if err := serder.JSONToObjectStream(resp.Body, &codeResp); err != nil { + return nil, fmt.Errorf("parsing response: %w", err) + } + + return nil, codeResp.ToError() + } + + _, params, err := mime.ParseMediaType(resp.Header.Get("Content-Disposition")) + if err != nil { + 
return nil, fmt.Errorf("parsing content disposition: %w", err) + } + + return &DownloadingObject{ + Path: params["filename"], + File: resp.Body, + }, nil +} + const ObjectUpdateInfoPath = "/object/updateInfo" type UpdatingObject struct { @@ -188,6 +269,42 @@ func (c *ObjectService) UpdateInfo(req ObjectUpdateInfo) (*ObjectUpdateInfoResp, return nil, jsonResp.ToError() } +const ObjectUpdateInfoByPathPath = "/object/updateInfoByPath" + +type ObjectUpdateInfoByPath struct { + UserID cdssdk.UserID `json:"userID" binding:"required"` + PackageID cdssdk.PackageID `json:"packageID" binding:"required"` + Path string `json:"path" binding:"required"` + UpdateTime time.Time `json:"updateTime" binding:"required"` +} + +type ObjectUpdateInfoByPathResp struct{} + +func (c *ObjectService) UpdateInfoByPath(req ObjectUpdateInfoByPath) (*ObjectUpdateInfoByPathResp, error) { + url, err := url.JoinPath(c.baseURL, ObjectUpdateInfoByPathPath) + if err != nil { + return nil, err + } + + resp, err := http2.PostJSON(url, http2.RequestParam{ + Body: req, + }) + if err != nil { + return nil, err + } + + jsonResp, err := ParseJSONResponse[response[ObjectUpdateInfoByPathResp]](resp) + if err != nil { + return nil, err + } + + if jsonResp.Code == errorcode.OK { + return &jsonResp.Data, nil + } + + return nil, jsonResp.ToError() +} + const ObjectMovePath = "/object/move" type MovingObject struct { @@ -269,6 +386,40 @@ func (c *ObjectService) Delete(req ObjectDelete) error { return jsonResp.ToError() } +const ObjectDeleteByPathPath = "/object/deleteByPath" + +type ObjectDeleteByPath struct { + UserID cdssdk.UserID `json:"userID" binding:"required"` + PackageID cdssdk.PackageID `json:"packageID" binding:"required"` + Path string `json:"path" binding:"required"` +} +type ObjectDeleteByPathResp struct{} + +func (c *ObjectService) DeleteByPath(req ObjectDeleteByPath) error { + url, err := url.JoinPath(c.baseURL, ObjectDeleteByPathPath) + if err != nil { + return err + } + + resp, err := 
http2.PostJSON(url, http2.RequestParam{ + Body: req, + }) + if err != nil { + return err + } + + jsonResp, err := ParseJSONResponse[response[ObjectDeleteByPathResp]](resp) + if err != nil { + return err + } + + if jsonResp.Code == errorcode.OK { + return nil + } + + return jsonResp.ToError() +} + const ObjectGetPackageObjectsPath = "/object/getPackageObjects" type ObjectGetPackageObjects struct { diff --git a/sdks/storage/cdsapi/storage_test.go b/sdks/storage/cdsapi/storage_test.go index cec219e..e12f984 100644 --- a/sdks/storage/cdsapi/storage_test.go +++ b/sdks/storage/cdsapi/storage_test.go @@ -125,6 +125,23 @@ func Test_Object(t *testing.T) { }) } +func Test_ObjectList(t *testing.T) { + Convey("路径查询", t, func() { + cli := NewClient(&Config{ + URL: "http://localhost:7890", + }) + + resp, err := cli.Object().List(ObjectList{ + UserID: 1, + PackageID: 10, + Path: "100x100K/zexema", + }) + So(err, ShouldBeNil) + fmt.Printf("\n") + fmt.Printf("%+v\n", resp.Objects[0]) + }) +} + func Test_Storage(t *testing.T) { Convey("上传后调度文件", t, func() { cli := NewClient(&Config{ diff --git a/sdks/storage/filehash.go b/sdks/storage/filehash.go new file mode 100644 index 0000000..9fecbaf --- /dev/null +++ b/sdks/storage/filehash.go @@ -0,0 +1,80 @@ +package cdssdk + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" +) + +// 文件的哈希值,格式:[前缀: 4个字符][哈希值: 64个字符] +// 前缀用于区分哈希值的类型: +// +// - "Full":完整文件的哈希值 +// +// - "Comp":将文件拆分成多个分片,每一个分片计算Hash之后再合并的哈希值 +// +// 哈希值:SHA256哈希值,全大写的16进制字符串格式 +type FileHash string + +const ( + FullHashPrefix = "Full" + CompositeHashPrefix = "Comp" +) + +func (h *FileHash) GetPrefix() string { + return string((*h)[:4]) +} + +func (h *FileHash) GetHash() string { + return string((*h)[4:]) +} + +func (h *FileHash) GetHashPrefix(len int) string { + return string((*h)[4 : 4+len]) +} + +func (h *FileHash) IsFullHash() bool { + return (*h)[:4] == FullHashPrefix +} + +func (h *FileHash) IsCompositeHash() bool { + return (*h)[:4] == 
CompositeHashPrefix +} + +func ParseHash(hashStr string) (FileHash, error) { + if len(hashStr) != 4+64 { + return "", fmt.Errorf("hash string length should be 4+64, but got %d", len(hashStr)) + } + + prefix := hashStr[:4] + hash := hashStr[4:] + if prefix != FullHashPrefix && prefix != CompositeHashPrefix { + return "", fmt.Errorf("invalid hash prefix: %s", prefix) + } + + if len(hash) != 64 { + return "", fmt.Errorf("invalid hash length: %d", len(hash)) + } + + for _, c := range hash { + if (c < '0' || c > '9') && (c < 'A' || c > 'F') { + return "", fmt.Errorf("invalid hash character: %c", c) + } + } + + return FileHash(hashStr), nil +} + +func NewFullHash(hash []byte) FileHash { + return FileHash(FullHashPrefix + strings.ToUpper(hex.EncodeToString(hash))) +} + +func CalculateCompositeHash(segmentHashes [][]byte) FileHash { + data := make([]byte, len(segmentHashes)*32) + for i, segmentHash := range segmentHashes { + copy(data[i*32:], segmentHash) + } + hash := sha256.Sum256(data) + return FileHash(CompositeHashPrefix + strings.ToUpper(hex.EncodeToString(hash[:]))) +} diff --git a/sdks/storage/models.go b/sdks/storage/models.go index 2b48c9c..fc340b8 100644 --- a/sdks/storage/models.go +++ b/sdks/storage/models.go @@ -7,6 +7,7 @@ import ( "github.com/samber/lo" "gitlink.org.cn/cloudream/common/pkgs/types" + "gitlink.org.cn/cloudream/common/utils/math2" "gitlink.org.cn/cloudream/common/utils/serder" ) @@ -28,9 +29,6 @@ type StorageID int64 type LocationID int64 -// 文件的SHA256哈希值,全大写的16进制字符串格式 -type FileHash string - /// TODO 将分散在各处的公共结构体定义集中到这里来 type Redundancy interface { @@ -170,18 +168,9 @@ type SegmentRedundancy struct { } func NewSegmentRedundancy(totalSize int64, segmentCount int) *SegmentRedundancy { - var segs []int64 - segLen := int64(0) - // 计算每一段的大小。大小不一定都相同,但总和应该等于总大小。 - for i := 0; i < segmentCount; i++ { - curLen := totalSize*int64(i+1)/int64(segmentCount) - segLen - segs = append(segs, curLen) - segLen += curLen - } - return &SegmentRedundancy{ Type: 
"segment", - Segments: segs, + Segments: math2.SplitN(totalSize, segmentCount), } } @@ -261,7 +250,7 @@ type Object struct { PackageID PackageID `json:"packageID" gorm:"column:PackageID; type:bigint; not null"` Path string `json:"path" gorm:"column:Path; type:varchar(1024); not null"` Size int64 `json:"size,string" gorm:"column:Size; type:bigint; not null"` - FileHash FileHash `json:"fileHash" gorm:"column:FileHash; type:char(64); not null"` + FileHash FileHash `json:"fileHash" gorm:"column:FileHash; type:char(68); not null"` Redundancy Redundancy `json:"redundancy" gorm:"column:Redundancy; type: json; serializer:union"` CreateTime time.Time `json:"createTime" gorm:"column:CreateTime; type:datetime; not null"` UpdateTime time.Time `json:"updateTime" gorm:"column:UpdateTime; type:datetime; not null"` diff --git a/sdks/storage/shard_storage.go b/sdks/storage/shard_storage.go index 1aeba6e..1285532 100644 --- a/sdks/storage/shard_storage.go +++ b/sdks/storage/shard_storage.go @@ -9,13 +9,14 @@ import ( // 分片存储服务的配置数据 type ShardStoreConfig interface { - GetType() string + GetShardStoreType() string // 输出调试用的字符串,不要包含敏感信息 String() string } var _ = serder.UseTypeUnionInternallyTagged(types.Ref(types.NewTypeUnion[ShardStoreConfig]( (*LocalShardStorage)(nil), + (*S3ShardStorage)(nil), )), "type") type LocalShardStorage struct { @@ -25,10 +26,24 @@ type LocalShardStorage struct { MaxSize int64 `json:"maxSize"` } -func (s *LocalShardStorage) GetType() string { +func (s *LocalShardStorage) GetShardStoreType() string { return "Local" } func (s *LocalShardStorage) String() string { return fmt.Sprintf("Local[root=%s, maxSize=%d]", s.Root, s.MaxSize) } + +type S3ShardStorage struct { + serder.Metadata `union:"S3"` + Type string `json:"type"` + Root string `json:"root"` +} + +func (s *S3ShardStorage) GetShardStoreType() string { + return "S3" +} + +func (s *S3ShardStorage) String() string { + return fmt.Sprintf("S3[root=%s]", s.Root) +} diff --git a/sdks/storage/storage.go 
b/sdks/storage/storage.go index cd20b86..0d1fcd8 100644 --- a/sdks/storage/storage.go +++ b/sdks/storage/storage.go @@ -12,8 +12,8 @@ type Storage struct { Name string `json:"name" gorm:"column:Name; type:varchar(256); not null"` // 完全管理此存储服务的Hub的ID MasterHub HubID `json:"masterHub" gorm:"column:MasterHub; type:bigint; not null"` - // 存储服务的地址,包含鉴权所需数据 - Address StorageAddress `json:"address" gorm:"column:Address; type:json; not null; serializer:union"` + // 存储服务的类型,包含地址信息以及鉴权所需数据 + Type StorageType `json:"type" gorm:"column:Type; type:json; not null; serializer:union"` // 分片存储服务的配置数据 ShardStore ShardStoreConfig `json:"shardStore" gorm:"column:ShardStore; type:json; serializer:union"` // 共享存储服务的配置数据 @@ -32,31 +32,35 @@ func (s *Storage) String() string { } // 存储服务地址 -type StorageAddress interface { - GetType() string +type StorageType interface { + GetStorageType() string // 输出调试用的字符串,不要包含敏感信息 String() string } -var _ = serder.UseTypeUnionInternallyTagged(types.Ref(types.NewTypeUnion[StorageAddress]( - (*LocalStorageAddress)(nil), +var _ = serder.UseTypeUnionInternallyTagged(types.Ref(types.NewTypeUnion[StorageType]( + (*LocalStorageType)(nil), + (*OBSType)(nil), + (*OSSType)(nil), + (*COSType)(nil), )), "type") -type LocalStorageAddress struct { +type LocalStorageType struct { serder.Metadata `union:"Local"` Type string `json:"type"` } -func (a *LocalStorageAddress) GetType() string { +func (a *LocalStorageType) GetStorageType() string { return "Local" } -func (a *LocalStorageAddress) String() string { +func (a *LocalStorageType) String() string { return "Local" } -type OSSAddress struct { - serder.Metadata `union:"Local"` +type OSSType struct { + serder.Metadata `union:"OSS"` + Type string `json:"type"` Region string `json:"region"` AK string `json:"accessKeyId"` SK string `json:"secretAccessKey"` @@ -64,16 +68,17 @@ type OSSAddress struct { Bucket string `json:"bucket"` } -func (a *OSSAddress) GetType() string { - return "OSSAddress" +func (a *OSSType) 
GetStorageType() string { + return "OSS" } -func (a *OSSAddress) String() string { - return "OSSAddress" +func (a *OSSType) String() string { + return "OSS" } -type OBSAddress struct { - serder.Metadata `union:"Local"` +type OBSType struct { + serder.Metadata `union:"OBS"` + Type string `json:"type"` Region string `json:"region"` AK string `json:"accessKeyId"` SK string `json:"secretAccessKey"` @@ -81,16 +86,17 @@ type OBSAddress struct { Bucket string `json:"bucket"` } -func (a *OBSAddress) GetType() string { - return "OBSAddress" +func (a *OBSType) GetStorageType() string { + return "OBS" } -func (a *OBSAddress) String() string { - return "OBSAddress" +func (a *OBSType) String() string { + return "OBS" } -type COSAddress struct { - serder.Metadata `union:"Local"` +type COSType struct { + serder.Metadata `union:"COS"` + Type string `json:"type"` Region string `json:"region"` AK string `json:"accessKeyId"` SK string `json:"secretAccessKey"` @@ -98,10 +104,10 @@ type COSAddress struct { Bucket string `json:"bucket"` } -func (a *COSAddress) GetType() string { - return "COSAddress" +func (a *COSType) GetStorageType() string { + return "COS" } -func (a *COSAddress) String() string { - return "COSAddress" +func (a *COSType) String() string { + return "COS" } diff --git a/sdks/storage/storage_feature.go b/sdks/storage/storage_feature.go index 3f66a90..0e150cb 100644 --- a/sdks/storage/storage_feature.go +++ b/sdks/storage/storage_feature.go @@ -7,26 +7,39 @@ import ( // 存储服务特性 type StorageFeature interface { - GetType() string + GetFeatureType() string // 输出调试用的字符串,不要包含敏感信息 String() string } var _ = serder.UseTypeUnionInternallyTagged(types.Ref(types.NewTypeUnion[StorageFeature]( + (*TempStore)(nil), (*BypassWriteFeature)(nil), (*MultipartUploadFeature)(nil), (*InternalServerlessCallFeature)(nil), )), "type") +type TempStore struct { + serder.Metadata `union:"TempStore"` + Type string `json:"type"` + TempRoot string `json:"tempRoot"` // 临时文件存放目录 +} + +func (f *TempStore) 
GetFeatureType() string { + return "TempStore" +} + +func (f *TempStore) String() string { + return "TempStore" +} + // 存储服务支持被非MasterHub直接上传文件 type BypassWriteFeature struct { serder.Metadata `union:"BypassWrite"` Type string `json:"type"` - // 存放上传文件的临时目录 - TempRoot string `json:"tempRoot"` } -func (f *BypassWriteFeature) GetType() string { +func (f *BypassWriteFeature) GetFeatureType() string { return "BypassWrite" } @@ -38,9 +51,12 @@ func (f *BypassWriteFeature) String() string { type MultipartUploadFeature struct { serder.Metadata `union:"MultipartUpload"` Type string `json:"type"` + TempDir string `json:"tempDir"` // 临时文件存放目录 + MinPartSize int64 `json:"minPartSize"` // 最小分段大小 + MaxPartSize int64 `json:"maxPartSize"` // 最大分段大小 } -func (f *MultipartUploadFeature) GetType() string { +func (f *MultipartUploadFeature) GetFeatureType() string { return "MultipartUpload" } @@ -55,7 +71,7 @@ type InternalServerlessCallFeature struct { CommandDir string `json:"commandDir"` // 存放命令文件的目录 } -func (f *InternalServerlessCallFeature) GetType() string { +func (f *InternalServerlessCallFeature) GetFeatureType() string { return "InternalServerlessCall" } diff --git a/utils/http2/http.go b/utils/http2/http.go index 338f95c..ab67e21 100644 --- a/utils/http2/http.go +++ b/utils/http2/http.go @@ -385,7 +385,7 @@ func PostMultiPart(url string, param MultiPartRequestParam) (*http.Response, err defer muWriter.Close() if param.Form != nil { - mp, err := objectToStringMap(param.Form) + mp, err := objectToStringMap(param.Form, "json") if err != nil { return fmt.Errorf("formValues object to map failed, err: %w", err) } @@ -477,7 +477,7 @@ func prepareQuery(req *http.Request, query any) error { mp, ok := query.(map[string]string) if !ok { var err error - if mp, err = objectToStringMap(query); err != nil { + if mp, err = objectToStringMap(query, "form"); err != nil { return fmt.Errorf("query object to map: %w", err) } } @@ -499,7 +499,7 @@ func prepareHeader(req *http.Request, header any) 
error { mp, ok := header.(map[string]string) if !ok { var err error - if mp, err = objectToStringMap(header); err != nil { + if mp, err = objectToStringMap(header, "json"); err != nil { return fmt.Errorf("header object to map: %w", err) } } @@ -543,7 +543,7 @@ func prepareFormBody(req *http.Request, body any) error { mp, ok := body.(map[string]string) if !ok { var err error - if mp, err = objectToStringMap(body); err != nil { + if mp, err = objectToStringMap(body, "json"); err != nil { return fmt.Errorf("body object to map: %w", err) } } @@ -577,10 +577,10 @@ func setValue(values ul.Values, key, value string) ul.Values { return values } -func objectToStringMap(obj any) (map[string]string, error) { +func objectToStringMap(obj any, tag string) (map[string]string, error) { anyMap := make(map[string]any) dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - TagName: "json", + TagName: tag, Result: &anyMap, WeaklyTypedInput: true, }) diff --git a/utils/http2/http_test.go b/utils/http2/http_test.go index 85cef5f..d5922eb 100644 --- a/utils/http2/http_test.go +++ b/utils/http2/http_test.go @@ -21,7 +21,7 @@ func Test_objectToStringMap(t *testing.T) { Omit: nil, } - mp, err := objectToStringMap(a) + mp, err := objectToStringMap(a, "json") So(err, ShouldBeNil) So(mp, ShouldResemble, map[string]string{ diff --git a/utils/io2/hash.go b/utils/io2/hash.go new file mode 100644 index 0000000..b3ae829 --- /dev/null +++ b/utils/io2/hash.go @@ -0,0 +1,54 @@ +package io2 + +import ( + "hash" + "io" +) + +type ReadHasher struct { + hasher hash.Hash + inner io.Reader +} + +func NewReadHasher(h hash.Hash, r io.Reader) *ReadHasher { + return &ReadHasher{ + hasher: h, + inner: r, + } +} + +func (h *ReadHasher) Read(p []byte) (n int, err error) { + n, err = h.inner.Read(p) + if n > 0 { + h.hasher.Write(p[:n]) + } + return +} + +func (h *ReadHasher) Sum() []byte { + return h.hasher.Sum(nil) +} + +type WriteHasher struct { + hasher hash.Hash + inner io.Writer +} + +func 
NewWriteHasher(h hash.Hash, w io.Writer) *WriteHasher { + return &WriteHasher{ + hasher: h, + inner: w, + } +} + +func (h *WriteHasher) Write(p []byte) (n int, err error) { + n, err = h.inner.Write(p) + if n > 0 { + h.hasher.Write(p[:n]) + } + return +} + +func (h *WriteHasher) Sum() []byte { + return h.hasher.Sum(nil) +} diff --git a/utils/io2/stats.go b/utils/io2/stats.go new file mode 100644 index 0000000..b6c7956 --- /dev/null +++ b/utils/io2/stats.go @@ -0,0 +1,22 @@ +package io2 + +import "io" + +type Counter struct { + inner io.Reader + count int64 +} + +func (c *Counter) Read(buf []byte) (n int, err error) { + n, err = c.inner.Read(buf) + c.count += int64(n) + return +} + +func (c *Counter) Count() int64 { + return c.count +} + +func NewCounter(inner io.Reader) *Counter { + return &Counter{inner: inner, count: 0} +} diff --git a/utils/math2/math.go b/utils/math2/math.go index ee65faa..7a8508b 100644 --- a/utils/math2/math.go +++ b/utils/math2/math.go @@ -45,3 +45,30 @@ func Clamp[T constraints.Integer](v, min, max T) T { return v } + +// 将一个整数切分成小于maxValue的整数列表,尽量均匀 +func SplitLessThan[T constraints.Integer](v T, maxValue T) []T { + cnt := int(CeilDiv(v, maxValue)) + result := make([]T, cnt) + last := int64(0) + for i := 0; i < cnt; i++ { + cur := int64(v) * int64(i+1) / int64(cnt) + result[i] = T(cur - last) + last = cur + } + + return result +} + +// 将一个整数切分成n个整数,尽量均匀 +func SplitN[T constraints.Integer](v T, n int) []T { + result := make([]T, n) + last := int64(0) + for i := 0; i < n; i++ { + cur := int64(v) * int64(i+1) / int64(n) + result[i] = T(cur - last) + last = cur + } + + return result +} diff --git a/utils/math2/math_test.go b/utils/math2/math_test.go new file mode 100644 index 0000000..7e75164 --- /dev/null +++ b/utils/math2/math_test.go @@ -0,0 +1,57 @@ +package math2 + +import ( + "testing" + + . 
"github.com/smartystreets/goconvey/convey" +) + +func Test_SplitLessThan(t *testing.T) { + checker := func(t *testing.T, arr []int, total int, maxValue int) { + t.Logf("arr: %v, total: %d, maxValue: %d", arr, total, maxValue) + + sum := 0 + for _, v := range arr { + sum += v + + if v > maxValue { + t.Errorf("value should be less than %d", maxValue) + } + } + + if sum != total { + t.Errorf("sum should be %d", total) + } + } + + Convey("测试", t, func() { + checker(t, SplitLessThan(9, 9), 9, 9) + checker(t, SplitLessThan(9, 3), 9, 3) + checker(t, SplitLessThan(10, 3), 10, 3) + checker(t, SplitLessThan(11, 3), 11, 3) + checker(t, SplitLessThan(12, 3), 12, 3) + }) +} + +func Test_SplitN(t *testing.T) { + checker := func(t *testing.T, arr []int, total int) { + t.Logf("arr: %v, total: %d", arr, total) + + sum := 0 + for _, v := range arr { + sum += v + } + + if sum != total { + t.Errorf("sum should be %d", total) + } + } + + Convey("测试", t, func() { + checker(t, SplitN(9, 9), 9) + checker(t, SplitN(9, 3), 9) + checker(t, SplitN(10, 3), 10) + checker(t, SplitN(11, 3), 11) + checker(t, SplitN(12, 3), 12) + }) +}