|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268 |
- package http2
-
- import (
- "encoding/binary"
- "fmt"
- "io"
-
- "gitlink.org.cn/cloudream/common/utils/io2"
- "gitlink.org.cn/cloudream/common/utils/math2"
- )
-
// Part type tags. Every frame on the wire starts with a 3-byte header: one
// type byte followed by a little-endian uint16 length field.
const (
	// PartTypeEOF marks the normal end of the whole stream.
	PartTypeEOF = 0x00
	// PartTypeNewPart opens a new named part; the length field is the size
	// of the part name that follows the header.
	PartTypeNewPart = 0x01
	// PartTypeData carries payload for the current part; the length field is
	// the number of payload bytes that follow the header.
	PartTypeData = 0x02
	// PartTypeError aborts the stream; the length field is the size of the
	// error message that follows the header.
	PartTypeError = 0xff
)
-
// ChunkedWriter frames a sequence of named parts onto a single underlying
// stream.
type ChunkedWriter struct {
	// stream receives every header and payload byte and is closed by Abort,
	// Finish and Close.
	stream io.WriteCloser
}

// NewChunkedWriter returns a ChunkedWriter that writes its frames to stream.
func NewChunkedWriter(stream io.WriteCloser) *ChunkedWriter {
	cw := new(ChunkedWriter)
	cw.stream = stream
	return cw
}
-
- // 开始写入一个新Part。每次只能有一个Part在写入。
- func (w *ChunkedWriter) BeginPart(name string) io.Writer {
- header := []byte{PartTypeNewPart, 0, 0}
- binary.LittleEndian.PutUint16(header[1:], uint16(len(name)))
-
- err := io2.WriteAll(w.stream, header)
- if err != nil {
- return io2.ErrorWriter(fmt.Errorf("write header: %w", err))
- }
-
- err = io2.WriteAll(w.stream, []byte(name))
- if err != nil {
- return io2.ErrorWriter(fmt.Errorf("write part name: %w", err))
- }
-
- return &PartWriter{stream: w.stream}
- }
-
- func (w *ChunkedWriter) WriteDataPart(name string, data []byte) error {
- pw := w.BeginPart(name)
- return io2.WriteAll(pw, data)
- }
-
- func (w *ChunkedWriter) WriteStreamPart(name string, stream io.Reader) (int64, error) {
- pw := w.BeginPart(name)
- n, err := io.Copy(pw, stream)
- return n, err
- }
-
- // 发送ErrorPart并关闭连接。无论是否返回错误,连接都会关闭
- func (w *ChunkedWriter) Abort(msg string) error {
- defer w.stream.Close()
-
- header := []byte{PartTypeError, 0, 0}
- binary.LittleEndian.PutUint16(header[1:], uint16(len(msg)))
-
- err := io2.WriteAll(w.stream, header)
- if err != nil {
- return fmt.Errorf("write header: %w", err)
- }
-
- err = io2.WriteAll(w.stream, []byte(msg))
- if err != nil {
- return fmt.Errorf("write error message: %w", err)
- }
-
- return nil
- }
-
- // 发送EOFPart并关闭连接。无论是否返回错误,连接都会关闭
- func (w *ChunkedWriter) Finish() error {
- defer w.stream.Close()
-
- header := []byte{PartTypeEOF, 0, 0}
- err := io2.WriteAll(w.stream, header)
- if err != nil {
- return fmt.Errorf("write header: %w", err)
- }
-
- return nil
- }
-
- // 直接关闭连接,不发送EOFPart也不发送ErrorPart。接收端会产生一个io.UnexpectedEOF错误
- func (w *ChunkedWriter) Close() {
- w.stream.Close()
- }
-
- type PartWriter struct {
- stream io.WriteCloser
- }
-
- func (w *PartWriter) Write(data []byte) (int, error) {
- sendLen := math2.Min(len(data), 0xffff)
-
- header := []byte{PartTypeData, 0, 0}
- binary.LittleEndian.PutUint16(header[1:], uint16(sendLen))
-
- err := io2.WriteAll(w.stream, header)
- if err != nil {
- return 0, fmt.Errorf("write header: %w", err)
- }
-
- err = io2.WriteAll(w.stream, data[:sendLen])
- if err != nil {
- return 0, fmt.Errorf("write data: %w", err)
- }
-
- return sendLen, nil
- }
-
// ChunkedAbortError is the error reported by the reading side when an error
// part is received from the stream.
type ChunkedAbortError struct {
	// Message is the text carried by the error part.
	Message string
}

// Compile-time check that *ChunkedAbortError satisfies the error interface.
var _ error = (*ChunkedAbortError)(nil)

// Error returns the message carried by the error part.
func (e *ChunkedAbortError) Error() string {
	text := e.Message
	return text
}
-
// ChunkedReader decodes a stream of named parts from an underlying stream.
type ChunkedReader struct {
	// stream is the source of all header and payload bytes.
	stream io.ReadCloser
	// partHeader holds a new-part header that a PartReader already consumed
	// from the stream while looking for the end of its own part; NextPart
	// uses it instead of reading a fresh header.
	partHeader []byte
	// err is the sticky terminal state; once set, NextPart keeps returning it.
	err error
}

// NewChunkedReader returns a ChunkedReader that reads its frames from stream.
func NewChunkedReader(stream io.ReadCloser) *ChunkedReader {
	cr := new(ChunkedReader)
	cr.stream = stream
	return cr
}
-
- // 读取下一个Part。每次只能读取一个Part,且必须将其全部读取完毕才能读取下一个。
- // 返回EOF代表没有更多Part了。
- func (r *ChunkedReader) NextPart() (string, io.Reader, error) {
- if r.err != nil {
- return "", nil, r.err
- }
-
- if r.partHeader == nil {
- r.partHeader = make([]byte, 3)
- _, err := io.ReadFull(r.stream, r.partHeader)
- if err != nil {
- r.err = fmt.Errorf("read header: %w", err)
- return "", nil, r.err
- }
- }
-
- partType := r.partHeader[0]
- switch partType {
- case PartTypeNewPart:
- partNameLen := int(binary.LittleEndian.Uint16(r.partHeader[1:]))
- partName := make([]byte, partNameLen)
-
- _, err := io.ReadFull(r.stream, partName)
- if err != nil {
- r.err = fmt.Errorf("read part name: %w", err)
- return "", nil, r.err
- }
-
- return string(partName), &PartReader{creader: r}, nil
-
- case PartTypeData:
- r.err = fmt.Errorf("unexpected data part")
- return "", nil, r.err
-
- case PartTypeEOF:
- r.err = io.EOF
- return "", nil, r.err
-
- case PartTypeError:
- msgLen := int(binary.LittleEndian.Uint16(r.partHeader[1:]))
- msg := make([]byte, msgLen)
-
- _, err := io.ReadFull(r.stream, msg)
- if err != nil {
- r.err = fmt.Errorf("read error message: %w", err)
- return "", nil, r.err
- }
-
- r.err = &ChunkedAbortError{Message: string(msg)}
- return "", nil, r.err
-
- default:
- r.err = fmt.Errorf("unknown part type: %d", partType)
- return "", nil, r.err
- }
- }
-
- func (r *ChunkedReader) NextDataPart() (string, []byte, error) {
- partName, partReader, err := r.NextPart()
- if err != nil {
- return "", nil, err
- }
-
- data, err := io.ReadAll(partReader)
- if err != nil {
- return "", nil, err
- }
-
- return partName, data, nil
- }
-
- func (r *ChunkedReader) Close() {
- r.stream.Close()
- }
-
- type PartReader struct {
- creader *ChunkedReader
- partLen int
- partReadLen int
- }
-
- func (r *PartReader) Read(p []byte) (int, error) {
- // 允许有空的DataPart,因此用循环来跳过空的Part
- for r.partLen-r.partReadLen == 0 {
- header := make([]byte, 3)
- _, err := io.ReadFull(r.creader.stream, header)
- if err != nil {
- r.creader.err = err
- return 0, err
- }
-
- partType := header[0]
- switch partType {
- case PartTypeNewPart:
- r.creader.partHeader = header
- return 0, io.EOF
-
- case PartTypeData:
- r.partLen = int(binary.LittleEndian.Uint16(header[1:]))
- r.partReadLen = 0
-
- case PartTypeEOF:
- r.creader.err = io.EOF
- return 0, io.EOF
-
- case PartTypeError:
- msgLen := int(binary.LittleEndian.Uint16(header[1:]))
- msg := make([]byte, msgLen)
-
- _, err := io.ReadFull(r.creader.stream, msg)
- if err != nil {
- r.creader.err = fmt.Errorf("read error message: %w", err)
- return 0, fmt.Errorf("read error message: %w", err)
- }
-
- r.creader.err = &ChunkedAbortError{Message: string(msg)}
- return 0, r.creader.err
- }
- }
-
- readLen := math2.Min(len(p), r.partLen-r.partReadLen)
- n, err := r.creader.stream.Read(p[:readLen])
- if err == io.EOF {
- r.creader.err = io.ErrUnexpectedEOF
- return 0, io.ErrUnexpectedEOF
- }
- if err != nil {
- r.creader.err = err
- return 0, err
- }
-
- r.partReadLen += n
- return n, nil
- }
|