You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

chunked2.go 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. package http2
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "gitlink.org.cn/cloudream/common/utils/io2"
  7. "gitlink.org.cn/cloudream/common/utils/math2"
  8. )
  9. const (
  10. PartTypeError = 0xff
  11. PartTypeEOF = 0x00
  12. PartTypeNewPart = 0x01
  13. PartTypeData = 0x02
  14. )
  15. type ChunkedWriter struct {
  16. stream io.WriteCloser
  17. }
  18. func NewChunkedWriter(stream io.WriteCloser) *ChunkedWriter {
  19. return &ChunkedWriter{stream: stream}
  20. }
  21. // 开始写入一个新Part。每次只能有一个Part在写入。
  22. func (w *ChunkedWriter) BeginPart(name string) io.Writer {
  23. header := []byte{PartTypeNewPart, 0, 0}
  24. binary.LittleEndian.PutUint16(header[1:], uint16(len(name)))
  25. err := io2.WriteAll(w.stream, header)
  26. if err != nil {
  27. return io2.ErrorWriter(fmt.Errorf("write header: %w", err))
  28. }
  29. err = io2.WriteAll(w.stream, []byte(name))
  30. if err != nil {
  31. return io2.ErrorWriter(fmt.Errorf("write part name: %w", err))
  32. }
  33. return &PartWriter{stream: w.stream}
  34. }
  35. func (w *ChunkedWriter) WriteDataPart(name string, data []byte) error {
  36. pw := w.BeginPart(name)
  37. return io2.WriteAll(pw, data)
  38. }
  39. func (w *ChunkedWriter) WriteStreamPart(name string, stream io.Reader) (int64, error) {
  40. pw := w.BeginPart(name)
  41. n, err := io.Copy(pw, stream)
  42. return n, err
  43. }
  44. // 发送ErrorPart并关闭连接。无论是否返回错误,连接都会关闭
  45. func (w *ChunkedWriter) Abort(msg string) error {
  46. defer w.stream.Close()
  47. header := []byte{PartTypeError, 0, 0}
  48. binary.LittleEndian.PutUint16(header[1:], uint16(len(msg)))
  49. err := io2.WriteAll(w.stream, header)
  50. if err != nil {
  51. return fmt.Errorf("write header: %w", err)
  52. }
  53. err = io2.WriteAll(w.stream, []byte(msg))
  54. if err != nil {
  55. return fmt.Errorf("write error message: %w", err)
  56. }
  57. return nil
  58. }
  59. // 发送EOFPart并关闭连接。无论是否返回错误,连接都会关闭
  60. func (w *ChunkedWriter) Finish() error {
  61. defer w.stream.Close()
  62. header := []byte{PartTypeEOF, 0, 0}
  63. err := io2.WriteAll(w.stream, header)
  64. if err != nil {
  65. return fmt.Errorf("write header: %w", err)
  66. }
  67. return nil
  68. }
  69. // 直接关闭连接,不发送EOFPart也不发送ErrorPart。接收端会产生一个io.UnexpectedEOF错误
  70. func (w *ChunkedWriter) Close() {
  71. w.stream.Close()
  72. }
  73. type PartWriter struct {
  74. stream io.WriteCloser
  75. }
  76. func (w *PartWriter) Write(data []byte) (int, error) {
  77. sendLen := math2.Min(len(data), 0xffff)
  78. header := []byte{PartTypeData, 0, 0}
  79. binary.LittleEndian.PutUint16(header[1:], uint16(sendLen))
  80. err := io2.WriteAll(w.stream, header)
  81. if err != nil {
  82. return 0, fmt.Errorf("write header: %w", err)
  83. }
  84. err = io2.WriteAll(w.stream, data[:sendLen])
  85. if err != nil {
  86. return 0, fmt.Errorf("write data: %w", err)
  87. }
  88. return sendLen, nil
  89. }
  90. type ChunkedAbortError struct {
  91. Message string
  92. }
  93. func (e *ChunkedAbortError) Error() string {
  94. return e.Message
  95. }
  96. type ChunkedReader struct {
  97. stream io.ReadCloser
  98. partHeader []byte
  99. err error
  100. }
  101. func NewChunkedReader(stream io.ReadCloser) *ChunkedReader {
  102. return &ChunkedReader{stream: stream}
  103. }
  104. // 读取下一个Part。每次只能读取一个Part,且必须将其全部读取完毕才能读取下一个
  105. func (r *ChunkedReader) NextPart() (string, io.Reader, error) {
  106. if r.err != nil {
  107. return "", nil, r.err
  108. }
  109. if r.partHeader == nil {
  110. r.partHeader = make([]byte, 3)
  111. _, err := io.ReadFull(r.stream, r.partHeader)
  112. if err != nil {
  113. r.err = fmt.Errorf("read header: %w", err)
  114. return "", nil, r.err
  115. }
  116. }
  117. partType := r.partHeader[0]
  118. switch partType {
  119. case PartTypeNewPart:
  120. partNameLen := int(binary.LittleEndian.Uint16(r.partHeader[1:]))
  121. partName := make([]byte, partNameLen)
  122. _, err := io.ReadFull(r.stream, partName)
  123. if err != nil {
  124. r.err = fmt.Errorf("read part name: %w", err)
  125. return "", nil, r.err
  126. }
  127. return string(partName), &PartReader{creader: r}, nil
  128. case PartTypeData:
  129. r.err = fmt.Errorf("unexpected data part")
  130. return "", nil, r.err
  131. case PartTypeEOF:
  132. r.err = io.EOF
  133. return "", nil, r.err
  134. case PartTypeError:
  135. msgLen := int(binary.LittleEndian.Uint16(r.partHeader[1:]))
  136. msg := make([]byte, msgLen)
  137. _, err := io.ReadFull(r.stream, msg)
  138. if err != nil {
  139. r.err = fmt.Errorf("read error message: %w", err)
  140. return "", nil, r.err
  141. }
  142. r.err = &ChunkedAbortError{Message: string(msg)}
  143. return "", nil, r.err
  144. default:
  145. r.err = fmt.Errorf("unknown part type: %d", partType)
  146. return "", nil, r.err
  147. }
  148. }
  149. func (r *ChunkedReader) NextDataPart() (string, []byte, error) {
  150. partName, partReader, err := r.NextPart()
  151. if err != nil {
  152. return "", nil, err
  153. }
  154. data, err := io.ReadAll(partReader)
  155. if err != nil {
  156. return "", nil, err
  157. }
  158. return partName, data, nil
  159. }
  160. func (r *ChunkedReader) Close() {
  161. r.stream.Close()
  162. }
  163. type PartReader struct {
  164. creader *ChunkedReader
  165. partLen int
  166. partReadLen int
  167. }
  168. func (r *PartReader) Read(p []byte) (int, error) {
  169. // 允许有空的DataPart,因此用循环来跳过空的Part
  170. for r.partLen-r.partReadLen == 0 {
  171. header := make([]byte, 3)
  172. _, err := io.ReadFull(r.creader.stream, header)
  173. if err != nil {
  174. r.creader.err = err
  175. return 0, err
  176. }
  177. partType := header[0]
  178. switch partType {
  179. case PartTypeNewPart:
  180. r.creader.partHeader = header
  181. return 0, io.EOF
  182. case PartTypeData:
  183. r.partLen = int(binary.LittleEndian.Uint16(header[1:]))
  184. r.partReadLen = 0
  185. case PartTypeEOF:
  186. r.creader.err = io.EOF
  187. return 0, io.EOF
  188. case PartTypeError:
  189. msgLen := int(binary.LittleEndian.Uint16(header[1:]))
  190. msg := make([]byte, msgLen)
  191. _, err := io.ReadFull(r.creader.stream, msg)
  192. if err != nil {
  193. r.creader.err = fmt.Errorf("read error message: %w", err)
  194. return 0, fmt.Errorf("read error message: %w", err)
  195. }
  196. r.creader.err = &ChunkedAbortError{Message: string(msg)}
  197. return 0, r.creader.err
  198. }
  199. }
  200. readLen := math2.Min(len(p), r.partLen-r.partReadLen)
  201. n, err := r.creader.stream.Read(p[:readLen])
  202. if err == io.EOF {
  203. r.creader.err = io.ErrUnexpectedEOF
  204. return 0, io.ErrUnexpectedEOF
  205. }
  206. if err != nil {
  207. r.creader.err = err
  208. return 0, err
  209. }
  210. r.partReadLen += n
  211. return n, nil
  212. }