diff --git a/utils/io2/ring.go b/utils/io2/ring.go new file mode 100644 index 0000000..cb0c1c8 --- /dev/null +++ b/utils/io2/ring.go @@ -0,0 +1,117 @@ +package io2 + +import ( + "io" + "sync" +) + +type RingBuffer2 struct { + buf []byte + src io.ReadCloser + err error + isReading bool + writePos int // 指向下一次写入的位置,应该是一个空位 + readPos int // 执行下一次读取的位置,应该是有效数据 + waitReading *sync.Cond + waitComsuming *sync.Cond + UpstreamName string + DownstreamName string +} + +func RingBuffer(src io.ReadCloser, size int) io.ReadCloser { + lk := &sync.Mutex{} + return &RingBuffer2{ + buf: make([]byte, size), + src: src, + waitReading: sync.NewCond(lk), + waitComsuming: sync.NewCond(lk), + } +} + +func (r *RingBuffer2) Read(p []byte) (n int, err error) { + r.waitReading.L.Lock() + if !r.isReading { + go r.reading() + r.isReading = true + } + + for r.writePos == r.readPos { + if r.err != nil { + r.waitReading.L.Unlock() + return 0, r.err + } + + // startTime := time.Now() + r.waitReading.Wait() + // fmt.Printf("%s wait data for %v\n", r.DownstreamName, time.Since(startTime)) + } + writePos := r.writePos + readPos := r.readPos + r.waitReading.L.Unlock() + + if readPos < writePos { + n = copy(p, r.buf[readPos:writePos]) + } else { + n = copy(p, r.buf[readPos:]) + } + + r.waitComsuming.L.Lock() + r.readPos = (r.readPos + n) % len(r.buf) + r.waitComsuming.L.Unlock() + r.waitComsuming.Broadcast() + + err = nil + return +} + +func (r *RingBuffer2) Close() error { + r.src.Close() + r.waitComsuming.Broadcast() + r.waitReading.Broadcast() + return nil +} + +func (r *RingBuffer2) reading() { + defer r.src.Close() + + for { + r.waitComsuming.L.Lock() + // writePos不能和readPos重合,因为无法区分缓冲区是已经满了,还是完全是空的 + // 所以writePos最多能到readPos的前一格 + for r.writePos+1 == r.readPos { + r.waitComsuming.Wait() + + if r.err != nil { + return + } + } + writePos := r.writePos + readPos := r.readPos + r.waitComsuming.L.Unlock() + + var n int + var err error + if readPos <= writePos { + // 同上理,写入数据的时候如果readPos为0,则它的前一格是底层缓冲区的最后一格 + // 那就不能写入到这一格 + if readPos == 0 { + n, err = r.src.Read(r.buf[writePos : len(r.buf)-1]) + } else { + n, err = r.src.Read(r.buf[writePos:]) + } + } else if readPos > writePos { + n, err = r.src.Read(r.buf[writePos:readPos]) + } + + // 无论成功还是失败,都发送一下信号通知读取端 + r.waitReading.L.Lock() + r.err = err + r.writePos = (r.writePos + n) % len(r.buf) + r.waitReading.L.Unlock() + r.waitReading.Broadcast() + + if err != nil { + break + } + } +} diff --git a/utils/io2/ring_test.go b/utils/io2/ring_test.go new file mode 100644 index 0000000..3c50fd4 --- /dev/null +++ b/utils/io2/ring_test.go @@ -0,0 +1,110 @@ +package io2 + +import ( + "bytes" + "io" + "testing" + + . "github.com/smartystreets/goconvey/convey" + "gitlink.org.cn/cloudream/common/utils/sync2" +) + +type syncReader struct { + data [][]byte + curDataIndex int + nextData int + counter *sync2.CounterCond +} + +func (r *syncReader) Read(p []byte) (n int, err error) { + if r.nextData >= len(r.data) { + return 0, io.EOF + } + + if r.data[r.nextData] == nil { + r.counter.Wait() + r.nextData++ + } + + n = copy(p, r.data[r.nextData][r.curDataIndex:]) + r.curDataIndex += n + if r.curDataIndex == len(r.data[r.nextData]) { + r.curDataIndex = 0 + r.nextData++ + } + return n, nil +} + +func (r *syncReader) Close() error { + return nil +} + +func Test_RingBuffer(t *testing.T) { + Convey("写满读满", t, func() { + b := RingBuffer(io.NopCloser(bytes.NewBuffer([]byte{1, 2, 3})), 4) + + ret := make([]byte, 3) + n, err := b.Read(ret) + So(err, ShouldEqual, nil) + So(n, ShouldEqual, 3) + So(ret, ShouldResemble, []byte{1, 2, 3}) + }) + + Convey("1+3+1", t, func() { + sy := sync2.NewCounterCond(0) + + b := RingBuffer(&syncReader{ + data: [][]byte{ + {1}, + nil, + {2, 3, 4, 5}, + }, + counter: sy, + }, 4) + + ret := make([]byte, 3) + n, err := b.Read(ret) + So(err, ShouldEqual, nil) + So(n, ShouldEqual, 1) + So(ret[:n], ShouldResemble, []byte{1}) + + sy.Release() + + n, err = b.Read(ret) + So(err, ShouldEqual, nil) + So(n, ShouldEqual, 3) + So(ret[:n], ShouldResemble, []byte{2, 3, 4}) + + n, err = b.Read(ret) + So(err, ShouldEqual, nil) + So(n, ShouldEqual, 1) + So(ret[:n], ShouldResemble, []byte{5}) + }) + + Convey("3+1+2", t, func() { + sy := sync2.NewCounterCond(0) + + b := RingBuffer(&syncReader{ + data: [][]byte{ + {1, 2, 3, 4, 5, 6}, + }, + counter: sy, + }, 4) + + ret := make([]byte, 3) + n, err := b.Read(ret) + So(err, ShouldEqual, nil) + So(n, ShouldEqual, 3) + So(ret[:n], ShouldResemble, []byte{1, 2, 3}) + + n, err = b.Read(ret) + So(err, ShouldEqual, nil) + So(n, ShouldEqual, 1) + So(ret[:n], ShouldResemble, []byte{4}) + + n, err = b.Read(ret) + So(err, ShouldEqual, nil) + So(n, ShouldEqual, 2) + So(ret[:n], ShouldResemble, []byte{5, 6}) + }) +} diff --git a/utils/sync2/sync2.go b/utils/sync2/sync2.go index 8555a3b..bb346fe 100644 --- a/utils/sync2/sync2.go +++ b/utils/sync2/sync2.go @@ -27,3 +27,27 @@ func ParallelDo[T any](args []T, fn func(val T, index int) error) error { return err } + +func ParallelDoMap[K comparable, V any](args map[K]V, fn func(k K, v V) error) error { + lock := sync.Mutex{} + var err error + + var wg sync.WaitGroup + wg.Add(len(args)) + for k, v := range args { + go func(k K, v V) { + defer wg.Done() + + if e := fn(k, v); e != nil { + lock.Lock() + if err == nil { + err = e + } + lock.Unlock() + } + }(k, v) + } + wg.Wait() + + return err +}