You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

decode.go 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. // Copyright 2018 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package proto
  5. import (
  6. "google.golang.org/protobuf/encoding/protowire"
  7. "google.golang.org/protobuf/internal/encoding/messageset"
  8. "google.golang.org/protobuf/internal/errors"
  9. "google.golang.org/protobuf/internal/flags"
  10. "google.golang.org/protobuf/internal/pragma"
  11. "google.golang.org/protobuf/reflect/protoreflect"
  12. "google.golang.org/protobuf/reflect/protoregistry"
  13. "google.golang.org/protobuf/runtime/protoiface"
  14. )
  15. // UnmarshalOptions configures the unmarshaler.
  16. //
  17. // Example usage:
  18. // err := UnmarshalOptions{DiscardUnknown: true}.Unmarshal(b, m)
  19. type UnmarshalOptions struct {
  20. pragma.NoUnkeyedLiterals
  21. // Merge merges the input into the destination message.
  22. // The default behavior is to always reset the message before unmarshaling,
  23. // unless Merge is specified.
  24. Merge bool
  25. // AllowPartial accepts input for messages that will result in missing
  26. // required fields. If AllowPartial is false (the default), Unmarshal will
  27. // return an error if there are any missing required fields.
  28. AllowPartial bool
  29. // If DiscardUnknown is set, unknown fields are ignored.
  30. DiscardUnknown bool
  31. // Resolver is used for looking up types when unmarshaling extension fields.
  32. // If nil, this defaults to using protoregistry.GlobalTypes.
  33. Resolver interface {
  34. FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
  35. FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error)
  36. }
  37. }
  38. // Unmarshal parses the wire-format message in b and places the result in m.
  39. func Unmarshal(b []byte, m Message) error {
  40. _, err := UnmarshalOptions{}.unmarshal(b, m.ProtoReflect())
  41. return err
  42. }
  43. // Unmarshal parses the wire-format message in b and places the result in m.
  44. func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
  45. _, err := o.unmarshal(b, m.ProtoReflect())
  46. return err
  47. }
  48. // UnmarshalState parses a wire-format message and places the result in m.
  49. //
  50. // This method permits fine-grained control over the unmarshaler.
  51. // Most users should use Unmarshal instead.
  52. func (o UnmarshalOptions) UnmarshalState(in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
  53. return o.unmarshal(in.Buf, in.Message)
  54. }
  55. func (o UnmarshalOptions) unmarshal(b []byte, m protoreflect.Message) (out protoiface.UnmarshalOutput, err error) {
  56. if o.Resolver == nil {
  57. o.Resolver = protoregistry.GlobalTypes
  58. }
  59. if !o.Merge {
  60. Reset(m.Interface()) // TODO
  61. }
  62. allowPartial := o.AllowPartial
  63. o.Merge = true
  64. o.AllowPartial = true
  65. methods := protoMethods(m)
  66. if methods != nil && methods.Unmarshal != nil &&
  67. !(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
  68. in := protoiface.UnmarshalInput{
  69. Message: m,
  70. Buf: b,
  71. Resolver: o.Resolver,
  72. }
  73. if o.DiscardUnknown {
  74. in.Flags |= protoiface.UnmarshalDiscardUnknown
  75. }
  76. out, err = methods.Unmarshal(in)
  77. } else {
  78. err = o.unmarshalMessageSlow(b, m)
  79. }
  80. if err != nil {
  81. return out, err
  82. }
  83. if allowPartial || (out.Flags&protoiface.UnmarshalInitialized != 0) {
  84. return out, nil
  85. }
  86. return out, checkInitialized(m)
  87. }
  88. func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
  89. _, err := o.unmarshal(b, m)
  90. return err
  91. }
  92. func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
  93. md := m.Descriptor()
  94. if messageset.IsMessageSet(md) {
  95. return unmarshalMessageSet(b, m, o)
  96. }
  97. fields := md.Fields()
  98. for len(b) > 0 {
  99. // Parse the tag (field number and wire type).
  100. num, wtyp, tagLen := protowire.ConsumeTag(b)
  101. if tagLen < 0 {
  102. return protowire.ParseError(tagLen)
  103. }
  104. if num > protowire.MaxValidNumber {
  105. return errors.New("invalid field number")
  106. }
  107. // Find the field descriptor for this field number.
  108. fd := fields.ByNumber(num)
  109. if fd == nil && md.ExtensionRanges().Has(num) {
  110. extType, err := o.Resolver.FindExtensionByNumber(md.FullName(), num)
  111. if err != nil && err != protoregistry.NotFound {
  112. return errors.New("%v: unable to resolve extension %v: %v", md.FullName(), num, err)
  113. }
  114. if extType != nil {
  115. fd = extType.TypeDescriptor()
  116. }
  117. }
  118. var err error
  119. if fd == nil {
  120. err = errUnknown
  121. } else if flags.ProtoLegacy {
  122. if fd.IsWeak() && fd.Message().IsPlaceholder() {
  123. err = errUnknown // weak referent is not linked in
  124. }
  125. }
  126. // Parse the field value.
  127. var valLen int
  128. switch {
  129. case err != nil:
  130. case fd.IsList():
  131. valLen, err = o.unmarshalList(b[tagLen:], wtyp, m.Mutable(fd).List(), fd)
  132. case fd.IsMap():
  133. valLen, err = o.unmarshalMap(b[tagLen:], wtyp, m.Mutable(fd).Map(), fd)
  134. default:
  135. valLen, err = o.unmarshalSingular(b[tagLen:], wtyp, m, fd)
  136. }
  137. if err != nil {
  138. if err != errUnknown {
  139. return err
  140. }
  141. valLen = protowire.ConsumeFieldValue(num, wtyp, b[tagLen:])
  142. if valLen < 0 {
  143. return protowire.ParseError(valLen)
  144. }
  145. if !o.DiscardUnknown {
  146. m.SetUnknown(append(m.GetUnknown(), b[:tagLen+valLen]...))
  147. }
  148. }
  149. b = b[tagLen+valLen:]
  150. }
  151. return nil
  152. }
  153. func (o UnmarshalOptions) unmarshalSingular(b []byte, wtyp protowire.Type, m protoreflect.Message, fd protoreflect.FieldDescriptor) (n int, err error) {
  154. v, n, err := o.unmarshalScalar(b, wtyp, fd)
  155. if err != nil {
  156. return 0, err
  157. }
  158. switch fd.Kind() {
  159. case protoreflect.GroupKind, protoreflect.MessageKind:
  160. m2 := m.Mutable(fd).Message()
  161. if err := o.unmarshalMessage(v.Bytes(), m2); err != nil {
  162. return n, err
  163. }
  164. default:
  165. // Non-message scalars replace the previous value.
  166. m.Set(fd, v)
  167. }
  168. return n, nil
  169. }
  170. func (o UnmarshalOptions) unmarshalMap(b []byte, wtyp protowire.Type, mapv protoreflect.Map, fd protoreflect.FieldDescriptor) (n int, err error) {
  171. if wtyp != protowire.BytesType {
  172. return 0, errUnknown
  173. }
  174. b, n = protowire.ConsumeBytes(b)
  175. if n < 0 {
  176. return 0, protowire.ParseError(n)
  177. }
  178. var (
  179. keyField = fd.MapKey()
  180. valField = fd.MapValue()
  181. key protoreflect.Value
  182. val protoreflect.Value
  183. haveKey bool
  184. haveVal bool
  185. )
  186. switch valField.Kind() {
  187. case protoreflect.GroupKind, protoreflect.MessageKind:
  188. val = mapv.NewValue()
  189. }
  190. // Map entries are represented as a two-element message with fields
  191. // containing the key and value.
  192. for len(b) > 0 {
  193. num, wtyp, n := protowire.ConsumeTag(b)
  194. if n < 0 {
  195. return 0, protowire.ParseError(n)
  196. }
  197. if num > protowire.MaxValidNumber {
  198. return 0, errors.New("invalid field number")
  199. }
  200. b = b[n:]
  201. err = errUnknown
  202. switch num {
  203. case 1:
  204. key, n, err = o.unmarshalScalar(b, wtyp, keyField)
  205. if err != nil {
  206. break
  207. }
  208. haveKey = true
  209. case 2:
  210. var v protoreflect.Value
  211. v, n, err = o.unmarshalScalar(b, wtyp, valField)
  212. if err != nil {
  213. break
  214. }
  215. switch valField.Kind() {
  216. case protoreflect.GroupKind, protoreflect.MessageKind:
  217. if err := o.unmarshalMessage(v.Bytes(), val.Message()); err != nil {
  218. return 0, err
  219. }
  220. default:
  221. val = v
  222. }
  223. haveVal = true
  224. }
  225. if err == errUnknown {
  226. n = protowire.ConsumeFieldValue(num, wtyp, b)
  227. if n < 0 {
  228. return 0, protowire.ParseError(n)
  229. }
  230. } else if err != nil {
  231. return 0, err
  232. }
  233. b = b[n:]
  234. }
  235. // Every map entry should have entries for key and value, but this is not strictly required.
  236. if !haveKey {
  237. key = keyField.Default()
  238. }
  239. if !haveVal {
  240. switch valField.Kind() {
  241. case protoreflect.GroupKind, protoreflect.MessageKind:
  242. default:
  243. val = valField.Default()
  244. }
  245. }
  246. mapv.Set(key.MapKey(), val)
  247. return n, nil
  248. }
  // errUnknown is an internal sentinel. It signals that the record just
// examined could not be decoded (for example, an unrecognized field number or
// an unexpected wire type) and should instead be treated as an unknown field.
// Callers inside this package intercept it; it is never meant to be returned
// from an exported function, hence the "BUG" text should it ever escape.
var (
	errUnknown = errors.New("BUG: internal error (unknown)")
)