|
- package rpc
-
- import (
- "fmt"
- "io"
-
- "gitlink.org.cn/cloudream/common/utils/io2"
- )
-
- type GRPCChunkedWriter interface {
- Send(*ChunkedData) error
- }
-
- type GRPCChunkedReader interface {
- Recv() (*ChunkedData, error)
- }
-
- type ChunkedWriter struct {
- gw GRPCChunkedWriter
- }
-
- func NewChunkedWriter(stream GRPCChunkedWriter) *ChunkedWriter {
- return &ChunkedWriter{gw: stream}
- }
-
- // 开始写入一个新Part。每次只能有一个Part在写入。
- func (w *ChunkedWriter) BeginPart(name string) io.Writer {
- err := w.gw.Send(&ChunkedData{Type: ChunkedDataType_NewPart, Data: []byte(name)})
- if err != nil {
- return io2.ErrorWriter(fmt.Errorf("write part name: %w", err))
- }
- return &PartWriter{cw: w}
- }
-
- func (w *ChunkedWriter) WriteDataPart(name string, data []byte) error {
- pw := w.BeginPart(name)
- return io2.WriteAll(pw, data)
- }
-
- func (w *ChunkedWriter) WriteStreamPart(name string, stream io.Reader) (int64, error) {
- pw := w.BeginPart(name)
- n, err := io.Copy(pw, stream)
- return n, err
- }
-
- // 发送ErrorPart,不会关闭连接。
- func (w *ChunkedWriter) Abort(msg string) error {
- return w.gw.Send(&ChunkedData{Type: ChunkedDataType_Error, Data: []byte(msg)})
- }
-
- // 发送EOFPart,不会关闭连接。
- func (w *ChunkedWriter) Finish() error {
- err := w.gw.Send(&ChunkedData{Type: ChunkedDataType_EOF, Data: []byte{}})
- if err != nil {
- return fmt.Errorf("write EOF: %w", err)
- }
-
- return nil
- }
-
- type PartWriter struct {
- cw *ChunkedWriter
- }
-
- func (w *PartWriter) Write(data []byte) (int, error) {
- err := w.cw.gw.Send(&ChunkedData{Type: ChunkedDataType_Data, Data: data})
- if err != nil {
- return 0, fmt.Errorf("write data: %w", err)
- }
-
- return len(data), nil
- }
-
// ChunkedAbortError is reported by ChunkedReader and its part readers when
// an Error chunk arrives, i.e. when the peer called ChunkedWriter.Abort.
type ChunkedAbortError struct {
	Message string // text carried by the Error chunk
}

// Error returns the peer's abort message unchanged.
func (e *ChunkedAbortError) Error() string {
	msg := e.Message
	return msg
}
-
- type ChunkedReader struct {
- gr GRPCChunkedReader
- chunk *ChunkedData
- err error
- }
-
- func NewChunkedReader(gr GRPCChunkedReader) *ChunkedReader {
- return &ChunkedReader{gr: gr}
- }
-
- // 读取下一个Part。每次只能读取一个Part,且必须将其全部读取完毕才能读取下一个
- func (r *ChunkedReader) NextPart() (string, io.Reader, error) {
- if r.err != nil {
- return "", nil, r.err
- }
-
- if r.chunk == nil {
- var err error
- r.chunk, err = r.gr.Recv()
- if err != nil {
- r.err = fmt.Errorf("receive chunk: %w", err)
- return "", nil, r.err
- }
- }
-
- switch r.chunk.Type {
- case ChunkedDataType_NewPart:
- return string(r.chunk.Data), &PartReader{creader: r}, nil
-
- case ChunkedDataType_Data:
- r.err = fmt.Errorf("unexpected data part")
- return "", nil, r.err
-
- case ChunkedDataType_EOF:
- r.err = io.EOF
- return "", nil, r.err
-
- case ChunkedDataType_Error:
- r.err = &ChunkedAbortError{Message: string(r.chunk.Data)}
- return "", nil, r.err
-
- default:
- r.err = fmt.Errorf("unknown part type: %d", r.chunk.Type)
- return "", nil, r.err
- }
- }
-
- func (r *ChunkedReader) NextDataPart() (string, []byte, error) {
- partName, partReader, err := r.NextPart()
- if err != nil {
- return "", nil, err
- }
-
- data, err := io.ReadAll(partReader)
- if err != nil {
- return "", nil, err
- }
-
- return partName, data, nil
- }
-
- type PartReader struct {
- creader *ChunkedReader
- data []byte
- }
-
- func (r *PartReader) Read(p []byte) (int, error) {
- if len(r.data) > 0 {
- n := copy(p, r.data)
- r.data = r.data[n:]
- return n, nil
- }
-
- chunk, err := r.creader.gr.Recv()
- if err == io.EOF {
- r.creader.err = io.ErrUnexpectedEOF
- return 0, io.ErrUnexpectedEOF
- }
- if err != nil {
- r.creader.err = fmt.Errorf("receive chunk: %w", err)
- return 0, r.creader.err
- }
-
- switch chunk.Type {
- case ChunkedDataType_NewPart:
- r.creader.chunk = chunk
- return 0, io.EOF
-
- case ChunkedDataType_Data:
- r.data = chunk.Data
- return r.Read(p)
-
- case ChunkedDataType_EOF:
- r.creader.err = io.EOF
- return 0, io.EOF
-
- case ChunkedDataType_Error:
- r.creader.err = &ChunkedAbortError{Message: string(chunk.Data)}
- return 0, r.creader.err
-
- default:
- r.creader.err = fmt.Errorf("unknown part type: %d", chunk.Type)
- return 0, r.creader.err
- }
- }
|