You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

chunked2.go 6.1 kB

2 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. package http2
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "gitlink.org.cn/cloudream/common/utils/io2"
  7. "gitlink.org.cn/cloudream/common/utils/math2"
  8. )
  // Part type tags. Every frame on the wire starts with a 3-byte header: one
  // type byte followed by a little-endian uint16 length field.
  const (
  	// PartTypeEOF marks a clean end of the stream; its length field is 0.
  	PartTypeEOF = 0x00
  	// PartTypeNewPart starts a new part; the length field is the byte length of
  	// the part name that follows the header.
  	PartTypeNewPart = 0x01
  	// PartTypeData carries one chunk of the current part's payload; the length
  	// field is the chunk size (at most 0xffff bytes).
  	PartTypeData = 0x02
  	// PartTypeError aborts the stream; the length field is the byte length of the
  	// error message that follows the header.
  	PartTypeError = 0xff
  )
  // ChunkedWriter frames multiple named parts onto a single underlying stream.
  // Only one part may be in progress at a time.
  type ChunkedWriter struct {
  	stream io.WriteCloser
  }

  // NewChunkedWriter returns a ChunkedWriter that writes frames to stream.
  // The writer takes ownership of stream: Abort, Finish and Close all close it.
  func NewChunkedWriter(stream io.WriteCloser) *ChunkedWriter {
  	cw := new(ChunkedWriter)
  	cw.stream = stream
  	return cw
  }
  21. // 开始写入一个新Part。每次只能有一个Part在写入。
  22. func (w *ChunkedWriter) BeginPart(name string) io.Writer {
  23. header := []byte{PartTypeNewPart, 0, 0}
  24. binary.LittleEndian.PutUint16(header[1:], uint16(len(name)))
  25. err := io2.WriteAll(w.stream, header)
  26. if err != nil {
  27. return io2.ErrorWriter(fmt.Errorf("write header: %w", err))
  28. }
  29. err = io2.WriteAll(w.stream, []byte(name))
  30. if err != nil {
  31. return io2.ErrorWriter(fmt.Errorf("write part name: %w", err))
  32. }
  33. return &PartWriter{stream: w.stream}
  34. }
  35. func (w *ChunkedWriter) WriteDataPart(name string, data []byte) error {
  36. pw := w.BeginPart(name)
  37. return io2.WriteAll(pw, data)
  38. }
  39. func (w *ChunkedWriter) WriteStreamPart(name string, stream io.Reader) (int64, error) {
  40. pw := w.BeginPart(name)
  41. n, err := io.Copy(pw, stream)
  42. return n, err
  43. }
  44. // 发送ErrorPart并关闭连接。无论是否返回错误,连接都会关闭
  45. func (w *ChunkedWriter) Abort(msg string) error {
  46. defer w.stream.Close()
  47. header := []byte{PartTypeError, 0, 0}
  48. binary.LittleEndian.PutUint16(header[1:], uint16(len(msg)))
  49. err := io2.WriteAll(w.stream, header)
  50. if err != nil {
  51. return fmt.Errorf("write header: %w", err)
  52. }
  53. err = io2.WriteAll(w.stream, []byte(msg))
  54. if err != nil {
  55. return fmt.Errorf("write error message: %w", err)
  56. }
  57. return nil
  58. }
  59. // 发送EOFPart并关闭连接。无论是否返回错误,连接都会关闭
  60. func (w *ChunkedWriter) Finish() error {
  61. defer w.stream.Close()
  62. header := []byte{PartTypeEOF, 0, 0}
  63. err := io2.WriteAll(w.stream, header)
  64. if err != nil {
  65. return fmt.Errorf("write header: %w", err)
  66. }
  67. return nil
  68. }
  69. // 直接关闭连接,不发送EOFPart也不发送ErrorPart。接收端会产生一个io.UnexpectedEOF错误
  70. func (w *ChunkedWriter) Close() {
  71. w.stream.Close()
  72. }
  73. type PartWriter struct {
  74. stream io.WriteCloser
  75. }
  76. func (w *PartWriter) Write(data []byte) (int, error) {
  77. sendLen := math2.Min(len(data), 0xffff)
  78. header := []byte{PartTypeData, 0, 0}
  79. binary.LittleEndian.PutUint16(header[1:], uint16(sendLen))
  80. err := io2.WriteAll(w.stream, header)
  81. if err != nil {
  82. return 0, fmt.Errorf("write header: %w", err)
  83. }
  84. err = io2.WriteAll(w.stream, data[:sendLen])
  85. if err != nil {
  86. return 0, fmt.Errorf("write data: %w", err)
  87. }
  88. return sendLen, nil
  89. }
  // ChunkedAbortError is the error the reader reports when the writer aborted
  // the stream with an Error frame. Message is the text carried by that frame.
  type ChunkedAbortError struct {
  	Message string
  }

  // Error returns the abort message unchanged.
  func (e *ChunkedAbortError) Error() string {
  	msg := e.Message
  	return msg
  }
  // ChunkedReader splits a stream produced by ChunkedWriter back into parts.
  type ChunkedReader struct {
  	stream io.ReadCloser
  	// partHeader holds a 3-byte frame header that has already been read from
  	// stream but not yet processed by NextPart; nil means a fresh header must
  	// be read.
  	partHeader []byte
  	// err is sticky: once set, NextPart returns it on every call. io.EOF means
  	// the writer finished cleanly.
  	err error
  }

  // NewChunkedReader returns a ChunkedReader that reads frames from stream.
  func NewChunkedReader(stream io.ReadCloser) *ChunkedReader {
  	r := new(ChunkedReader)
  	r.stream = stream
  	return r
  }
  104. // 读取下一个Part。每次只能读取一个Part,且必须将其全部读取完毕才能读取下一个。
  105. // 返回EOF代表没有更多Part了。
  106. func (r *ChunkedReader) NextPart() (string, io.Reader, error) {
  107. if r.err != nil {
  108. return "", nil, r.err
  109. }
  110. if r.partHeader == nil {
  111. r.partHeader = make([]byte, 3)
  112. _, err := io.ReadFull(r.stream, r.partHeader)
  113. if err != nil {
  114. r.err = fmt.Errorf("read header: %w", err)
  115. return "", nil, r.err
  116. }
  117. }
  118. partType := r.partHeader[0]
  119. switch partType {
  120. case PartTypeNewPart:
  121. partNameLen := int(binary.LittleEndian.Uint16(r.partHeader[1:]))
  122. partName := make([]byte, partNameLen)
  123. _, err := io.ReadFull(r.stream, partName)
  124. if err != nil {
  125. r.err = fmt.Errorf("read part name: %w", err)
  126. return "", nil, r.err
  127. }
  128. return string(partName), &PartReader{creader: r}, nil
  129. case PartTypeData:
  130. r.err = fmt.Errorf("unexpected data part")
  131. return "", nil, r.err
  132. case PartTypeEOF:
  133. r.err = io.EOF
  134. return "", nil, r.err
  135. case PartTypeError:
  136. msgLen := int(binary.LittleEndian.Uint16(r.partHeader[1:]))
  137. msg := make([]byte, msgLen)
  138. _, err := io.ReadFull(r.stream, msg)
  139. if err != nil {
  140. r.err = fmt.Errorf("read error message: %w", err)
  141. return "", nil, r.err
  142. }
  143. r.err = &ChunkedAbortError{Message: string(msg)}
  144. return "", nil, r.err
  145. default:
  146. r.err = fmt.Errorf("unknown part type: %d", partType)
  147. return "", nil, r.err
  148. }
  149. }
  150. func (r *ChunkedReader) NextDataPart() (string, []byte, error) {
  151. partName, partReader, err := r.NextPart()
  152. if err != nil {
  153. return "", nil, err
  154. }
  155. data, err := io.ReadAll(partReader)
  156. if err != nil {
  157. return "", nil, err
  158. }
  159. return partName, data, nil
  160. }
  161. func (r *ChunkedReader) Close() {
  162. r.stream.Close()
  163. }
  164. type PartReader struct {
  165. creader *ChunkedReader
  166. partLen int
  167. partReadLen int
  168. }
  169. func (r *PartReader) Read(p []byte) (int, error) {
  170. // 允许有空的DataPart,因此用循环来跳过空的Part
  171. for r.partLen-r.partReadLen == 0 {
  172. header := make([]byte, 3)
  173. _, err := io.ReadFull(r.creader.stream, header)
  174. if err != nil {
  175. r.creader.err = err
  176. return 0, err
  177. }
  178. partType := header[0]
  179. switch partType {
  180. case PartTypeNewPart:
  181. r.creader.partHeader = header
  182. return 0, io.EOF
  183. case PartTypeData:
  184. r.partLen = int(binary.LittleEndian.Uint16(header[1:]))
  185. r.partReadLen = 0
  186. case PartTypeEOF:
  187. r.creader.err = io.EOF
  188. return 0, io.EOF
  189. case PartTypeError:
  190. msgLen := int(binary.LittleEndian.Uint16(header[1:]))
  191. msg := make([]byte, msgLen)
  192. _, err := io.ReadFull(r.creader.stream, msg)
  193. if err != nil {
  194. r.creader.err = fmt.Errorf("read error message: %w", err)
  195. return 0, fmt.Errorf("read error message: %w", err)
  196. }
  197. r.creader.err = &ChunkedAbortError{Message: string(msg)}
  198. return 0, r.creader.err
  199. }
  200. }
  201. readLen := math2.Min(len(p), r.partLen-r.partReadLen)
  202. n, err := r.creader.stream.Read(p[:readLen])
  203. if err == io.EOF {
  204. r.creader.err = io.ErrUnexpectedEOF
  205. return 0, io.ErrUnexpectedEOF
  206. }
  207. if err != nil {
  208. r.creader.err = err
  209. return 0, err
  210. }
  211. r.partReadLen += n
  212. return n, nil
  213. }