From 1f716ea452bb684eca999c4fd287b9b2f8b0df3a Mon Sep 17 00:00:00 2001 From: Sydonian <794346190@qq.com> Date: Fri, 24 Nov 2023 16:43:37 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0io=E7=9B=B8=E5=85=B3=E7=9A=84?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkgs/distlock/internal/acquire_actor.go | 34 +-- pkgs/distlock/internal/release_actor.go | 29 +-- sdks/storage/models.go | 16 +- utils/io/chunked_split.go | 79 +++++++ utils/io/chunked_split_test.go | 298 ++++++++++++++++++++++++ utils/io/io.go | 13 ++ utils/io/io_test.go | 117 ++++++++++ utils/io/join.go | 107 +++++++++ utils/io/length.go | 65 ++++++ utils/io/zero.go | 25 ++ 10 files changed, 748 insertions(+), 35 deletions(-) create mode 100644 utils/io/chunked_split.go create mode 100644 utils/io/chunked_split_test.go create mode 100644 utils/io/io_test.go create mode 100644 utils/io/join.go create mode 100644 utils/io/length.go create mode 100644 utils/io/zero.go diff --git a/pkgs/distlock/internal/acquire_actor.go b/pkgs/distlock/internal/acquire_actor.go index a26a62a..78589b8 100644 --- a/pkgs/distlock/internal/acquire_actor.go +++ b/pkgs/distlock/internal/acquire_actor.go @@ -41,7 +41,7 @@ func NewAcquireActor(cfg *Config, etcdCli *clientv3.Client) *AcquireActor { cfg: cfg, etcdCli: etcdCli, isMaintenance: true, - doAcquiringChan: make(chan any), + doAcquiringChan: make(chan any, 1), } } @@ -139,27 +139,34 @@ func (a *AcquireActor) ResetState(serviceID string) { func (a *AcquireActor) Serve() { for { + // 离开了select块之后doAcquiringChan的buf就会空出来, + // 如果之后成功提交了一个锁请求,那么WatchEtcd会收到事件,然后调用此Actor的回调再次设置doAcquiringChan。 + // 因此无论多少个锁请求同时提交,或者是在doAcquiring期间提交,都不会因为某一个成功了,其他的就连试都不试就开始等待。 + // 如果没有一个锁请求提交成功,那自然是已经尝试过所有锁请求了,此时等待新事件到来后F再来尝试也是合理的。 select { case <-a.doAcquiringChan: - err := a.doAcquiring() - if err != nil { - logger.Std.Debugf("doing acquiring: %s", err.Error()) - } + } + + // 如果没有锁请求,那么就不需要进行加锁操作 + a.lock.Lock() + if len(a.acquirings) == 0 { + a.lock.Unlock() + continue + } + a.lock.Unlock() + + err := a.doAcquiring() + if err != nil { + logger.Std.Debugf("doing acquiring: %s", err.Error()) } } } +// 返回true代表成功提交了一个锁 func (a *AcquireActor) doAcquiring() error { + // TODO 配置等待时间 ctx := context.Background() - // 先看一眼,如果没有需要请求的锁,就不用走后面的流程了 - a.lock.Lock() - if len(a.acquirings) == 0 { - a.lock.Unlock() - return nil - } - a.lock.Unlock() - // 在获取全局锁的时候不用锁Actor,只有获取成功了,才加锁 // TODO 根据不同的错误设置不同的错误类型,方便上层进行后续处理 unlock, err := acquireEtcdRequestDataLock(ctx, a.etcdCli, a.cfg.EtcdLockLeaseTimeSec) @@ -174,7 +181,6 @@ func (a *AcquireActor) doAcquiring() error { } // 等待本地状态同步到最新 - // TODO 配置等待时间 err = a.providersActor.WaitLocalIndexTo(ctx, index) if err != nil { return err diff --git a/pkgs/distlock/internal/release_actor.go b/pkgs/distlock/internal/release_actor.go index 72e9145..511fa20 100644 --- a/pkgs/distlock/internal/release_actor.go +++ b/pkgs/distlock/internal/release_actor.go @@ -36,7 +36,7 @@ func NewReleaseActor(cfg *Config, etcdCli *clientv3.Client) *ReleaseActor { etcdCli: etcdCli, isMaintenance: true, releasingLockRequestIDs: make(map[string]bool), - doReleasingChan: make(chan any), + doReleasingChan: make(chan any, 1), } } @@ -120,12 +120,23 @@ func (a *ReleaseActor) OnLockRequestEvent(event LockRequestEvent) { func (a *ReleaseActor) Serve() { for { + // 与Acquire不同,解锁操作不需要进行互斥判断,而且能一次性解锁多个, + // 所以此处也能保证新提交的解锁请求都会被尝试后再进入等待。 select { case <-a.doReleasingChan: - err := a.doReleasing() - if err != nil { - logger.Std.Debugf("doing releasing: %s", err.Error()) - } + } + + // 先看一眼,如果没有需要释放的锁,就重新进入等待状态 + a.lock.Lock() + if len(a.releasingLockRequestIDs) == 0 { + a.lock.Unlock() + continue + } + a.lock.Unlock() + + err := a.doReleasing() + if err != nil { + logger.Std.Debugf("doing releasing: %s", err.Error()) } } } @@ -134,14 +145,6 @@ func (a *ReleaseActor) doReleasing() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - // 先看一眼,如果没有需要释放的锁,就不用走后面的流程了 - a.lock.Lock() - if len(a.releasingLockRequestIDs) == 0 { - a.lock.Unlock() - return nil - } - a.lock.Unlock() - // 在获取全局锁的时候不用锁Actor,只有获取成功了,才加锁 // TODO 根据不同的错误设置不同的错误类型,方便上层进行后续处理 unlock, err := acquireEtcdRequestDataLock(ctx, a.etcdCli, a.cfg.EtcdLockLeaseTimeSec) diff --git a/sdks/storage/models.go b/sdks/storage/models.go index 5e35809..b0600a9 100644 --- a/sdks/storage/models.go +++ b/sdks/storage/models.go @@ -35,14 +35,14 @@ func NewRepRedundancyInfo(repCount int) RepRedundancyInfo { } type ECRedundancyInfo struct { - ECName string `json:"ecName"` - PacketSize int64 `json:"packetSize"` + ECName string `json:"ecName"` + ChunkSize int64 `json:"chunkSize"` } -func NewECRedundancyInfo(ecName string, packetSize int64) ECRedundancyInfo { +func NewECRedundancyInfo(ecName string, chunkSize int64) ECRedundancyInfo { return ECRedundancyInfo{ - ECName: ecName, - PacketSize: packetSize, + ECName: ecName, + ChunkSize: chunkSize, } } @@ -74,12 +74,12 @@ func NewTypedRepRedundancyInfo(repCount int) TypedRedundancyInfo { } } -func NewTypedECRedundancyInfo(ecName string, packetSize int64) TypedRedundancyInfo { +func NewTypedECRedundancyInfo(ecName string, chunkSize int64) TypedRedundancyInfo { return TypedRedundancyInfo{ Type: RedundancyRep, Info: ECRedundancyInfo{ - ECName: ecName, - PacketSize: packetSize, + ECName: ecName, + ChunkSize: chunkSize, }, } } diff --git a/utils/io/chunked_split.go b/utils/io/chunked_split.go new file mode 100644 index 0000000..97c68d4 --- /dev/null +++ b/utils/io/chunked_split.go @@ -0,0 +1,79 @@ +package io + +import ( + "fmt" + "io" +) + +type ChunkedSplitOption struct { + // 如果流的长度不是chunkSize * streamCount的整数倍,启用此参数后,会在输出流里填充0直到满足长度 + FillZeros bool +} + +// 每次读取一个chunkSize大小的数据,然后轮流写入到返回的流中。注:读取不同流的操作必须在不同的goroutine中进行,或者按顺序读取,每次精确读取一个chunkSize大小 +func ChunkedSplit(stream io.Reader, chunkSize int64, streamCount int, opts ...ChunkedSplitOption) []io.ReadCloser { + var opt ChunkedSplitOption + if len(opts) > 0 { + opt = opts[0] + } + + buf := make([]byte, chunkSize) + prs := make([]io.ReadCloser, streamCount) + pws := make([]*io.PipeWriter, streamCount) + for i := 0; i < streamCount; i++ { + pr, pw := io.Pipe() + prs[i] = pr + pws[i] = pw + } + + go func() { + var closeErr error + eof := false + loop: + for { + for i := 0; i < streamCount; i++ { + var rd int = 0 + if !eof { + var err error + rd, err = io.ReadFull(stream, buf) + if err == io.ErrUnexpectedEOF || err == io.EOF { + eof = true + } else if err != nil { + closeErr = err + break loop + } + } + + // 如果rd为0,那么肯定是eof,如果此时正好是在一轮读取的第一次,那么就直接退出整个读取,避免填充不必要的0 + if rd == 0 && i == 0 { + break + } + + if opt.FillZeros { + Zero(buf[rd:]) + err := WriteAll(pws[i], buf) + if err != nil { + closeErr = fmt.Errorf("writing to one of the output streams: %w", err) + break loop + } + } else { + err := WriteAll(pws[i], buf[:rd]) + if err != nil { + closeErr = fmt.Errorf("writing to one of the output streams: %w", err) + break loop + } + } + } + + if eof { + break + } + } + + for _, pw := range pws { + pw.CloseWithError(closeErr) + } + }() + + return prs +} diff --git a/utils/io/chunked_split_test.go b/utils/io/chunked_split_test.go new file mode 100644 index 0000000..0e35574 --- /dev/null +++ b/utils/io/chunked_split_test.go @@ -0,0 +1,298 @@ +package io + +import ( + "bytes" + "io" + "sync" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func Test_RoundRobin(t *testing.T) { + Convey("数据长度为chunkSize的整数倍", t, func() { + input := []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, 6, 7, 8, 9, + } + + outputs := ChunkedSplit(bytes.NewReader(input), 3, 3) + + wg := sync.WaitGroup{} + wg.Add(3) + + o1 := make([]byte, 10) + var e1 error + var rd1 int + go func() { + rd1, e1 = io.ReadFull(outputs[0], o1) + wg.Done() + }() + + o2 := make([]byte, 10) + var e2 error + var rd2 int + go func() { + rd2, e2 = io.ReadFull(outputs[1], o2) + wg.Done() + }() + + o3 := make([]byte, 10) + var e3 error + var rd3 int + go func() { + rd3, e3 = io.ReadFull(outputs[2], o3) + wg.Done() + }() + + wg.Wait() + + So(e1, ShouldEqual, io.ErrUnexpectedEOF) + So(o1[:rd1], ShouldResemble, []byte{1, 2, 3, 1, 2, 3}) + + So(e2, ShouldEqual, io.ErrUnexpectedEOF) + So(o2[:rd2], ShouldResemble, []byte{4, 5, 6, 4, 5, 6}) + + So(e3, ShouldEqual, io.ErrUnexpectedEOF) + So(o3[:rd3], ShouldResemble, []byte{7, 8, 9, 7, 8, 9}) + }) + + Convey("数据长度比chunkSize的整数倍少小于chunkSize的数量,且不填充0", t, func() { + input := []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, 6, 7, + } + + outputs := ChunkedSplit(bytes.NewReader(input), 3, 3) + + wg := sync.WaitGroup{} + wg.Add(3) + + o1 := make([]byte, 10) + var e1 error + var rd1 int + go func() { + rd1, e1 = io.ReadFull(outputs[0], o1) + wg.Done() + }() + + o2 := make([]byte, 10) + var e2 error + var rd2 int + go func() { + rd2, e2 = io.ReadFull(outputs[1], o2) + wg.Done() + }() + + o3 := make([]byte, 10) + var e3 error + var rd3 int + go func() { + rd3, e3 = io.ReadFull(outputs[2], o3) + wg.Done() + }() + + wg.Wait() + + So(e1, ShouldEqual, io.ErrUnexpectedEOF) + So(o1[:rd1], ShouldResemble, []byte{1, 2, 3, 1, 2, 3}) + + So(e2, ShouldEqual, io.ErrUnexpectedEOF) + So(o2[:rd2], ShouldResemble, []byte{4, 5, 6, 4, 5, 6}) + + So(e3, ShouldEqual, io.ErrUnexpectedEOF) + So(o3[:rd3], ShouldResemble, []byte{7, 8, 9, 7}) + }) + + Convey("数据长度比chunkSize的整数倍少多于chunkSize的数量,且不填充0", t, func() { + input := []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, + } + + outputs := ChunkedSplit(bytes.NewReader(input), 3, 3) + + wg := sync.WaitGroup{} + wg.Add(3) + + o1 := make([]byte, 10) + var e1 error + var rd1 int + go func() { + rd1, e1 = io.ReadFull(outputs[0], o1) + wg.Done() + }() + + o2 := make([]byte, 10) + var e2 error + var rd2 int + go func() { + rd2, e2 = io.ReadFull(outputs[1], o2) + wg.Done() + }() + + o3 := make([]byte, 10) + var e3 error + var rd3 int + go func() { + rd3, e3 = io.ReadFull(outputs[2], o3) + wg.Done() + }() + + wg.Wait() + + So(e1, ShouldEqual, io.ErrUnexpectedEOF) + So(o1[:rd1], ShouldResemble, []byte{1, 2, 3, 1, 2, 3}) + + So(e2, ShouldEqual, io.ErrUnexpectedEOF) + So(o2[:rd2], ShouldResemble, []byte{4, 5, 6, 4, 5}) + + So(e3, ShouldEqual, io.ErrUnexpectedEOF) + So(o3[:rd3], ShouldResemble, []byte{7, 8, 9}) + }) + + Convey("数据长度是chunkSize的整数倍,且填充0", t, func() { + input := []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, + } + + outputs := ChunkedSplit(bytes.NewReader(input), 3, 3, ChunkedSplitOption{ + FillZeros: true, + }) + + wg := sync.WaitGroup{} + wg.Add(3) + + o1 := make([]byte, 10) + var e1 error + var rd1 int + go func() { + rd1, e1 = io.ReadFull(outputs[0], o1) + wg.Done() + }() + + o2 := make([]byte, 10) + var e2 error + var rd2 int + go func() { + rd2, e2 = io.ReadFull(outputs[1], o2) + wg.Done() + }() + + o3 := make([]byte, 10) + var e3 error + var rd3 int + go func() { + rd3, e3 = io.ReadFull(outputs[2], o3) + wg.Done() + }() + + wg.Wait() + + So(e1, ShouldEqual, io.ErrUnexpectedEOF) + So(o1[:rd1], ShouldResemble, []byte{1, 2, 3}) + + So(e2, ShouldEqual, io.ErrUnexpectedEOF) + So(o2[:rd2], ShouldResemble, []byte{4, 5, 6}) + + So(e3, ShouldEqual, io.ErrUnexpectedEOF) + So(o3[:rd3], ShouldResemble, []byte{7, 8, 9}) + }) + + Convey("数据长度比chunkSize的整数倍少小于chunkSize的数量,但是填充0", t, func() { + input := []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, 6, 7, + } + + outputs := ChunkedSplit(bytes.NewReader(input), 3, 3, ChunkedSplitOption{ + FillZeros: true, + }) + wg := sync.WaitGroup{} + wg.Add(3) + + o1 := make([]byte, 10) + var e1 error + var rd1 int + go func() { + rd1, e1 = io.ReadFull(outputs[0], o1) + wg.Done() + }() + + o2 := make([]byte, 10) + var e2 error + var rd2 int + go func() { + rd2, e2 = io.ReadFull(outputs[1], o2) + wg.Done() + }() + + o3 := make([]byte, 10) + var e3 error + var rd3 int + go func() { + rd3, e3 = io.ReadFull(outputs[2], o3) + wg.Done() + }() + + wg.Wait() + + So(e1, ShouldEqual, io.ErrUnexpectedEOF) + So(o1[:rd1], ShouldResemble, []byte{1, 2, 3, 1, 2, 3}) + + So(e2, ShouldEqual, io.ErrUnexpectedEOF) + So(o2[:rd2], ShouldResemble, []byte{4, 5, 6, 4, 5, 6}) + + So(e3, ShouldEqual, io.ErrUnexpectedEOF) + So(o3[:rd3], ShouldResemble, []byte{7, 8, 9, 7, 0, 0}) + }) + + Convey("数据长度比chunkSize的整数倍少多于chunkSize的数量,但是填充0", t, func() { + input := []byte{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, + } + + outputs := ChunkedSplit(bytes.NewReader(input), 3, 3, ChunkedSplitOption{ + FillZeros: true, + }) + + wg := sync.WaitGroup{} + wg.Add(3) + + o1 := make([]byte, 10) + var e1 error + var rd1 int + go func() { + rd1, e1 = io.ReadFull(outputs[0], o1) + wg.Done() + }() + + o2 := make([]byte, 10) + var e2 error + var rd2 int + go func() { + rd2, e2 = io.ReadFull(outputs[1], o2) + wg.Done() + }() + + o3 := make([]byte, 10) + var e3 error + var rd3 int + go func() { + rd3, e3 = io.ReadFull(outputs[2], o3) + wg.Done() + }() + + wg.Wait() + + So(e1, ShouldEqual, io.ErrUnexpectedEOF) + So(o1[:rd1], ShouldResemble, []byte{1, 2, 3, 1, 2, 0}) + + So(e2, ShouldEqual, io.ErrUnexpectedEOF) + So(o2[:rd2], ShouldResemble, []byte{4, 5, 6, 0, 0, 0}) + + So(e3, ShouldEqual, io.ErrUnexpectedEOF) + So(o3[:rd3], ShouldResemble, []byte{7, 8, 9, 0, 0, 0}) + }) +} diff --git a/utils/io/io.go b/utils/io/io.go index e5ee9c7..af8c639 100644 --- a/utils/io/io.go +++ b/utils/io/io.go @@ -112,3 +112,16 @@ func Lazy(open func() (io.ReadCloser, error)) *LazyReadCloser { open: open, } } + +func ToReaders(strs []io.ReadCloser) ([]io.Reader, func()) { + var readers []io.Reader + for _, s := range strs { + readers = append(readers, s) + } + + return readers, func() { + for _, s := range strs { + s.Close() + } + } +} diff --git a/utils/io/io_test.go b/utils/io/io_test.go new file mode 100644 index 0000000..3066aa9 --- /dev/null +++ b/utils/io/io_test.go @@ -0,0 +1,117 @@ +package io + +import ( + "bytes" + "io" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func Test_Join(t *testing.T) { + Convey("连接多个流", t, func() { + str := Join([]io.Reader{ + bytes.NewReader([]byte{1, 2, 3}), + bytes.NewReader([]byte{4}), + bytes.NewReader([]byte{5, 6, 7, 8}), + }) + + buf := make([]byte, 9) + rd, err := io.ReadFull(str, buf) + + So(err, ShouldEqual, io.ErrUnexpectedEOF) + So(buf[:rd], ShouldResemble, []byte{1, 2, 3, 4, 5, 6, 7, 8}) + }) + + Convey("分块式连接多个流,每个流长度相等", t, func() { + str := ChunkedJoin([]io.Reader{ + bytes.NewReader([]byte{1, 2, 3}), + bytes.NewReader([]byte{4, 5, 6}), + bytes.NewReader([]byte{7, 8, 9}), + }, 3) + + buf := make([]byte, 10) + rd, err := io.ReadFull(str, buf) + + So(err, ShouldEqual, io.ErrUnexpectedEOF) + So(buf[:rd], ShouldResemble, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}) + }) + + Convey("分块式连接多个流,流长度不相等,但都是chunkSize的整数倍", t, func() { + str := ChunkedJoin([]io.Reader{ + bytes.NewReader([]byte{1, 2, 3}), + bytes.NewReader([]byte{4, 5, 6, 7, 8, 9, 10, 11, 12}), + bytes.NewReader([]byte{}), + bytes.NewReader([]byte{13, 14, 15}), + }, 3) + + buf := make([]byte, 100) + rd, err := io.ReadFull(str, buf) + + So(err, ShouldEqual, io.ErrUnexpectedEOF) + So(buf[:rd], ShouldResemble, []byte{1, 2, 3, 4, 5, 6, 13, 14, 15, 7, 8, 9, 10, 11, 12}) + }) + + Convey("分块式连接多个流,流长度不相等,且不一定是chunkSize的整数倍", t, func() { + str := ChunkedJoin([]io.Reader{ + bytes.NewReader([]byte{1, 2, 3}), + bytes.NewReader([]byte{4, 5, 6, 7, 8}), + bytes.NewReader([]byte{9}), + }, 3) + + buf := make([]byte, 10) + rd, err := io.ReadFull(str, buf) + + So(err, ShouldEqual, io.ErrUnexpectedEOF) + So(buf[:rd], ShouldResemble, []byte{1, 2, 3, 4, 5, 6, 9, 7, 8}) + }) +} + +func Test_Length(t *testing.T) { + Convey("非强制,长度刚好", t, func() { + str := Length(bytes.NewReader([]byte{1, 2, 3}), 3) + buf := make([]byte, 9) + rd, err := io.ReadFull(str, buf) + So(err, ShouldEqual, io.ErrUnexpectedEOF) + So(buf[:rd], ShouldResemble, []byte{1, 2, 3}) + }) + + Convey("非强制,长度小于设定", t, func() { + str := Length(bytes.NewReader([]byte{1, 2}), 3) + + buf := make([]byte, 2) + rd, err := io.ReadFull(str, buf) + if err == nil { + var rd2 int + rd2, err = io.ReadFull(str, buf) + So(rd2, ShouldEqual, 0) + } + So(err, ShouldEqual, io.EOF) + So(buf[:rd], ShouldResemble, []byte{1, 2}) + }) + + Convey("非强制,长度大于设定", t, func() { + str := Length(bytes.NewReader([]byte{1, 2, 3, 4}), 3) + + buf := make([]byte, 3) + rd, err := io.ReadFull(str, buf) + if err == nil { + var rd2 int + rd2, err = io.ReadFull(str, buf) + So(rd2, ShouldEqual, 0) + } + So(err, ShouldEqual, io.EOF) + So(buf[:rd], ShouldResemble, []byte{1, 2, 3}) + }) + + Convey("强制,长度小于设定", t, func() { + str := MustLength(bytes.NewReader([]byte{1, 2}), 3) + + buf := make([]byte, 2) + _, err := io.ReadFull(str, buf) + if err == nil { + _, err = io.ReadFull(str, buf) + } + So(err, ShouldEqual, io.ErrUnexpectedEOF) + }) +} diff --git a/utils/io/join.go b/utils/io/join.go new file mode 100644 index 0000000..09c294e --- /dev/null +++ b/utils/io/join.go @@ -0,0 +1,107 @@ +package io + +import ( + "io" + + "gitlink.org.cn/cloudream/common/utils/lo" + "gitlink.org.cn/cloudream/common/utils/math" +) + +func Join(strs []io.Reader) io.ReadCloser { + + pr, pw := io.Pipe() + + go func() { + var closeErr error + + buf := make([]byte, 4096) + outer: + for _, str := range strs { + for { + bufLen := len(buf) + if bufLen == 0 { + break outer + } + + rd, err := str.Read(buf[:bufLen]) + if err != nil { + if err != io.EOF { + closeErr = err + break outer + } + + err = WriteAll(pw, buf[:rd]) + if err != nil { + closeErr = err + break outer + } + + break + } + + err = WriteAll(pw, buf[:rd]) + if err != nil { + closeErr = err + break outer + } + } + } + + pw.CloseWithError(closeErr) + }() + + return pr +} + +type chunkedJoin struct { + inputs []io.Reader + chunkSize int + currentInput int + currentRead int + err error +} + +func (s *chunkedJoin) Read(buf []byte) (int, error) { + if s.err != nil { + return 0, s.err + } + if len(s.inputs) == 0 { + return 0, io.EOF + } + + bufLen := math.Min(math.Min(s.chunkSize, len(buf)), s.chunkSize-s.currentRead) + rd, err := s.inputs[s.currentInput].Read(buf[:bufLen]) + if err == nil { + s.currentRead += rd + if s.currentRead == s.chunkSize { + s.currentInput = (s.currentInput + 1) % len(s.inputs) + s.currentRead = 0 + } + return rd, nil + } + + if err == io.EOF { + s.inputs = lo.RemoveAt(s.inputs, s.currentInput) + // 此处不需要+1 + if len(s.inputs) > 0 { + s.currentInput = s.currentInput % len(s.inputs) + s.currentRead = 0 + } + return rd, nil + } + + s.err = err + return rd, err +} + +func (s *chunkedJoin) Close() error { + s.err = io.ErrClosedPipe + return nil +} + +func ChunkedJoin(inputs []io.Reader, chunkSize int) io.ReadCloser { + return &chunkedJoin{ + inputs: inputs, + chunkSize: chunkSize, + } +} diff --git a/utils/io/length.go b/utils/io/length.go new file mode 100644 index 0000000..d86781c --- /dev/null +++ b/utils/io/length.go @@ -0,0 +1,65 @@ +package io + +import ( + "io" + + "gitlink.org.cn/cloudream/common/utils/math" +) + +type lengthStream struct { + src io.Reader + length int64 + readLength int64 + must bool + err error +} + +func (s *lengthStream) Read(buf []byte) (int, error) { + if s.err != nil { + return 0, s.err + } + + bufLen := math.Min(s.length-s.readLength, int64(len(buf))) + rd, err := s.src.Read(buf[:bufLen]) + if err == nil { + s.readLength += int64(rd) + if s.readLength == s.length { + s.err = io.EOF + } + return rd, nil + } + + if err == io.EOF { + s.readLength += int64(rd) + if s.readLength < s.length && s.must { + s.err = io.ErrUnexpectedEOF + return rd, io.ErrUnexpectedEOF + } + + s.err = io.EOF + return rd, io.EOF + } + + s.err = err + return 0, err +} + +func (s *lengthStream) Close() error { + s.err = io.ErrClosedPipe + return nil +} + +func Length(str io.Reader, length int64) io.ReadCloser { + return &lengthStream{ + src: str, + length: length, + } +} + +func MustLength(str io.Reader, length int64) io.ReadCloser { + return &lengthStream{ + src: str, + length: length, + must: true, + } +} diff --git a/utils/io/zero.go b/utils/io/zero.go new file mode 100644 index 0000000..ddd3a3f --- /dev/null +++ b/utils/io/zero.go @@ -0,0 +1,25 @@ +package io + +import "io" + +var zeros zeroStream + +type zeroStream struct{} + +func (s *zeroStream) Read(buf []byte) (int, error) { + for i := range buf { + buf[i] = 0 + } + + return len(buf), nil +} + +func Zeros() io.Reader { + return &zeros +} + +func Zero(buf []byte) { + for i := range buf { + buf[i] = 0 + } +}