diff --git a/common/pkgs/ioswitch2/ops2/ec.go b/common/pkgs/ioswitch2/ops2/ec.go index 2e81063..5f3a1c4 100644 --- a/common/pkgs/ioswitch2/ops2/ec.go +++ b/common/pkgs/ioswitch2/ops2/ec.go @@ -10,6 +10,7 @@ import ( "gitlink.org.cn/cloudream/common/pkgs/ioswitch/utils" cdssdk "gitlink.org.cn/cloudream/common/sdks/storage" "gitlink.org.cn/cloudream/common/utils/io2" + "gitlink.org.cn/cloudream/common/utils/math2" "gitlink.org.cn/cloudream/common/utils/sync2" "gitlink.org.cn/cloudream/storage/common/pkgs/ec" ) @@ -45,20 +46,35 @@ func (o *ECMultiply) Execute(ctx *exec.ExecContext, e *exec.Executor) error { outputWrs[i] = wr } - fut := future.NewSetVoid() - go func() { - mul := ec.GaloisMultiplier().BuildGalois() + inputChunks := make([][]byte, len(o.Inputs)) + for i := range o.Inputs { + inputChunks[i] = make([]byte, math2.Min(o.ChunkSize, 64*1024)) + } - inputChunks := make([][]byte, len(o.Inputs)) - for i := range o.Inputs { - inputChunks[i] = make([]byte, o.ChunkSize) - } + // 输出用两个缓冲轮换 + outputBufPool := sync2.NewBucketPool[[][]byte]() + for i := 0; i < 2; i++ { outputChunks := make([][]byte, len(o.Outputs)) for i := range o.Outputs { - outputChunks[i] = make([]byte, o.ChunkSize) + outputChunks[i] = make([]byte, math2.Min(o.ChunkSize, 64*1024)) } + outputBufPool.PutEmpty(outputChunks) + } + + fut := future.NewSetVoid() + go func() { + mul := ec.GaloisMultiplier().BuildGalois() + defer outputBufPool.WakeUpAll() + + readLens := math2.SplitLessThan(o.ChunkSize, 64*1024) + readLenIdx := 0 for { + curReadLen := readLens[readLenIdx] + for i := range inputChunks { + inputChunks[i] = inputChunks[i][:curReadLen] + } + err := sync2.ParallelDo(inputs, func(s *exec.StreamValue, i int) error { _, err := io.ReadFull(s.Stream, inputChunks[i]) return err @@ -72,12 +88,34 @@ func (o *ECMultiply) Execute(ctx *exec.ExecContext, e *exec.Executor) error { return } - err = mul.Multiply(o.Coef, inputChunks, outputChunks) + outputBuf, ok := outputBufPool.GetEmpty() + if !ok { + return + } + for i := range outputBuf { + outputBuf[i] = outputBuf[i][:curReadLen] + } + + err = mul.Multiply(o.Coef, inputChunks, outputBuf) if err != nil { fut.SetError(err) return } + outputBufPool.PutFilled(outputBuf) + readLenIdx = (readLenIdx + 1) % len(readLens) + } + }() + + go func() { + defer outputBufPool.WakeUpAll() + + for { + outputChunks, ok := outputBufPool.GetFilled() + if !ok { + return + } + for i := range o.Outputs { err := io2.WriteAll(outputWrs[i], outputChunks[i]) if err != nil { @@ -85,6 +123,8 @@ func (o *ECMultiply) Execute(ctx *exec.ExecContext, e *exec.Executor) error { return } } + + outputBufPool.PutEmpty(outputChunks) } }()