decode.go 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665
  1. // Copyright 2019 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package protojson
  5. import (
  6. "encoding/base64"
  7. "fmt"
  8. "math"
  9. "strconv"
  10. "strings"
  11. "google.golang.org/protobuf/internal/encoding/json"
  12. "google.golang.org/protobuf/internal/encoding/messageset"
  13. "google.golang.org/protobuf/internal/errors"
  14. "google.golang.org/protobuf/internal/flags"
  15. "google.golang.org/protobuf/internal/genid"
  16. "google.golang.org/protobuf/internal/pragma"
  17. "google.golang.org/protobuf/internal/set"
  18. "google.golang.org/protobuf/proto"
  19. "google.golang.org/protobuf/reflect/protoreflect"
  20. "google.golang.org/protobuf/reflect/protoregistry"
  21. )
  22. // Unmarshal reads the given []byte into the given proto.Message.
  23. // The provided message must be mutable (e.g., a non-nil pointer to a message).
  24. func Unmarshal(b []byte, m proto.Message) error {
  25. return UnmarshalOptions{}.Unmarshal(b, m)
  26. }
  27. // UnmarshalOptions is a configurable JSON format parser.
  28. type UnmarshalOptions struct {
  29. pragma.NoUnkeyedLiterals
  30. // If AllowPartial is set, input for messages that will result in missing
  31. // required fields will not return an error.
  32. AllowPartial bool
  33. // If DiscardUnknown is set, unknown fields are ignored.
  34. DiscardUnknown bool
  35. // Resolver is used for looking up types when unmarshaling
  36. // google.protobuf.Any messages or extension fields.
  37. // If nil, this defaults to using protoregistry.GlobalTypes.
  38. Resolver interface {
  39. protoregistry.MessageTypeResolver
  40. protoregistry.ExtensionTypeResolver
  41. }
  42. }
  43. // Unmarshal reads the given []byte and populates the given proto.Message
  44. // using options in the UnmarshalOptions object.
  45. // It will clear the message first before setting the fields.
  46. // If it returns an error, the given message may be partially set.
  47. // The provided message must be mutable (e.g., a non-nil pointer to a message).
  48. func (o UnmarshalOptions) Unmarshal(b []byte, m proto.Message) error {
  49. return o.unmarshal(b, m)
  50. }
  51. // unmarshal is a centralized function that all unmarshal operations go through.
  52. // For profiling purposes, avoid changing the name of this function or
  53. // introducing other code paths for unmarshal that do not go through this.
  54. func (o UnmarshalOptions) unmarshal(b []byte, m proto.Message) error {
  55. proto.Reset(m)
  56. if o.Resolver == nil {
  57. o.Resolver = protoregistry.GlobalTypes
  58. }
  59. dec := decoder{json.NewDecoder(b), o}
  60. if err := dec.unmarshalMessage(m.ProtoReflect(), false); err != nil {
  61. return err
  62. }
  63. // Check for EOF.
  64. tok, err := dec.Read()
  65. if err != nil {
  66. return err
  67. }
  68. if tok.Kind() != json.EOF {
  69. return dec.unexpectedTokenError(tok)
  70. }
  71. if o.AllowPartial {
  72. return nil
  73. }
  74. return proto.CheckInitialized(m)
  75. }
  76. type decoder struct {
  77. *json.Decoder
  78. opts UnmarshalOptions
  79. }
  80. // newError returns an error object with position info.
  81. func (d decoder) newError(pos int, f string, x ...interface{}) error {
  82. line, column := d.Position(pos)
  83. head := fmt.Sprintf("(line %d:%d): ", line, column)
  84. return errors.New(head+f, x...)
  85. }
  86. // unexpectedTokenError returns a syntax error for the given unexpected token.
  87. func (d decoder) unexpectedTokenError(tok json.Token) error {
  88. return d.syntaxError(tok.Pos(), "unexpected token %s", tok.RawString())
  89. }
  90. // syntaxError returns a syntax error for given position.
  91. func (d decoder) syntaxError(pos int, f string, x ...interface{}) error {
  92. line, column := d.Position(pos)
  93. head := fmt.Sprintf("syntax error (line %d:%d): ", line, column)
  94. return errors.New(head+f, x...)
  95. }
  96. // unmarshalMessage unmarshals a message into the given protoreflect.Message.
  97. func (d decoder) unmarshalMessage(m protoreflect.Message, skipTypeURL bool) error {
  98. if unmarshal := wellKnownTypeUnmarshaler(m.Descriptor().FullName()); unmarshal != nil {
  99. return unmarshal(d, m)
  100. }
  101. tok, err := d.Read()
  102. if err != nil {
  103. return err
  104. }
  105. if tok.Kind() != json.ObjectOpen {
  106. return d.unexpectedTokenError(tok)
  107. }
  108. messageDesc := m.Descriptor()
  109. if !flags.ProtoLegacy && messageset.IsMessageSet(messageDesc) {
  110. return errors.New("no support for proto1 MessageSets")
  111. }
  112. var seenNums set.Ints
  113. var seenOneofs set.Ints
  114. fieldDescs := messageDesc.Fields()
  115. for {
  116. // Read field name.
  117. tok, err := d.Read()
  118. if err != nil {
  119. return err
  120. }
  121. switch tok.Kind() {
  122. default:
  123. return d.unexpectedTokenError(tok)
  124. case json.ObjectClose:
  125. return nil
  126. case json.Name:
  127. // Continue below.
  128. }
  129. name := tok.Name()
  130. // Unmarshaling a non-custom embedded message in Any will contain the
  131. // JSON field "@type" which should be skipped because it is not a field
  132. // of the embedded message, but simply an artifact of the Any format.
  133. if skipTypeURL && name == "@type" {
  134. d.Read()
  135. continue
  136. }
  137. // Get the FieldDescriptor.
  138. var fd protoreflect.FieldDescriptor
  139. if strings.HasPrefix(name, "[") && strings.HasSuffix(name, "]") {
  140. // Only extension names are in [name] format.
  141. extName := protoreflect.FullName(name[1 : len(name)-1])
  142. extType, err := d.opts.Resolver.FindExtensionByName(extName)
  143. if err != nil && err != protoregistry.NotFound {
  144. return d.newError(tok.Pos(), "unable to resolve %s: %v", tok.RawString(), err)
  145. }
  146. if extType != nil {
  147. fd = extType.TypeDescriptor()
  148. if !messageDesc.ExtensionRanges().Has(fd.Number()) || fd.ContainingMessage().FullName() != messageDesc.FullName() {
  149. return d.newError(tok.Pos(), "message %v cannot be extended by %v", messageDesc.FullName(), fd.FullName())
  150. }
  151. }
  152. } else {
  153. // The name can either be the JSON name or the proto field name.
  154. fd = fieldDescs.ByJSONName(name)
  155. if fd == nil {
  156. fd = fieldDescs.ByTextName(name)
  157. }
  158. }
  159. if flags.ProtoLegacy {
  160. if fd != nil && fd.IsWeak() && fd.Message().IsPlaceholder() {
  161. fd = nil // reset since the weak reference is not linked in
  162. }
  163. }
  164. if fd == nil {
  165. // Field is unknown.
  166. if d.opts.DiscardUnknown {
  167. if err := d.skipJSONValue(); err != nil {
  168. return err
  169. }
  170. continue
  171. }
  172. return d.newError(tok.Pos(), "unknown field %v", tok.RawString())
  173. }
  174. // Do not allow duplicate fields.
  175. num := uint64(fd.Number())
  176. if seenNums.Has(num) {
  177. return d.newError(tok.Pos(), "duplicate field %v", tok.RawString())
  178. }
  179. seenNums.Set(num)
  180. // No need to set values for JSON null unless the field type is
  181. // google.protobuf.Value or google.protobuf.NullValue.
  182. if tok, _ := d.Peek(); tok.Kind() == json.Null && !isKnownValue(fd) && !isNullValue(fd) {
  183. d.Read()
  184. continue
  185. }
  186. switch {
  187. case fd.IsList():
  188. list := m.Mutable(fd).List()
  189. if err := d.unmarshalList(list, fd); err != nil {
  190. return err
  191. }
  192. case fd.IsMap():
  193. mmap := m.Mutable(fd).Map()
  194. if err := d.unmarshalMap(mmap, fd); err != nil {
  195. return err
  196. }
  197. default:
  198. // If field is a oneof, check if it has already been set.
  199. if od := fd.ContainingOneof(); od != nil {
  200. idx := uint64(od.Index())
  201. if seenOneofs.Has(idx) {
  202. return d.newError(tok.Pos(), "error parsing %s, oneof %v is already set", tok.RawString(), od.FullName())
  203. }
  204. seenOneofs.Set(idx)
  205. }
  206. // Required or optional fields.
  207. if err := d.unmarshalSingular(m, fd); err != nil {
  208. return err
  209. }
  210. }
  211. }
  212. }
  213. func isKnownValue(fd protoreflect.FieldDescriptor) bool {
  214. md := fd.Message()
  215. return md != nil && md.FullName() == genid.Value_message_fullname
  216. }
  217. func isNullValue(fd protoreflect.FieldDescriptor) bool {
  218. ed := fd.Enum()
  219. return ed != nil && ed.FullName() == genid.NullValue_enum_fullname
  220. }
  221. // unmarshalSingular unmarshals to the non-repeated field specified
  222. // by the given FieldDescriptor.
  223. func (d decoder) unmarshalSingular(m protoreflect.Message, fd protoreflect.FieldDescriptor) error {
  224. var val protoreflect.Value
  225. var err error
  226. switch fd.Kind() {
  227. case protoreflect.MessageKind, protoreflect.GroupKind:
  228. val = m.NewField(fd)
  229. err = d.unmarshalMessage(val.Message(), false)
  230. default:
  231. val, err = d.unmarshalScalar(fd)
  232. }
  233. if err != nil {
  234. return err
  235. }
  236. m.Set(fd, val)
  237. return nil
  238. }
  239. // unmarshalScalar unmarshals to a scalar/enum protoreflect.Value specified by
  240. // the given FieldDescriptor.
  241. func (d decoder) unmarshalScalar(fd protoreflect.FieldDescriptor) (protoreflect.Value, error) {
  242. const b32 int = 32
  243. const b64 int = 64
  244. tok, err := d.Read()
  245. if err != nil {
  246. return protoreflect.Value{}, err
  247. }
  248. kind := fd.Kind()
  249. switch kind {
  250. case protoreflect.BoolKind:
  251. if tok.Kind() == json.Bool {
  252. return protoreflect.ValueOfBool(tok.Bool()), nil
  253. }
  254. case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
  255. if v, ok := unmarshalInt(tok, b32); ok {
  256. return v, nil
  257. }
  258. case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
  259. if v, ok := unmarshalInt(tok, b64); ok {
  260. return v, nil
  261. }
  262. case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
  263. if v, ok := unmarshalUint(tok, b32); ok {
  264. return v, nil
  265. }
  266. case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
  267. if v, ok := unmarshalUint(tok, b64); ok {
  268. return v, nil
  269. }
  270. case protoreflect.FloatKind:
  271. if v, ok := unmarshalFloat(tok, b32); ok {
  272. return v, nil
  273. }
  274. case protoreflect.DoubleKind:
  275. if v, ok := unmarshalFloat(tok, b64); ok {
  276. return v, nil
  277. }
  278. case protoreflect.StringKind:
  279. if tok.Kind() == json.String {
  280. return protoreflect.ValueOfString(tok.ParsedString()), nil
  281. }
  282. case protoreflect.BytesKind:
  283. if v, ok := unmarshalBytes(tok); ok {
  284. return v, nil
  285. }
  286. case protoreflect.EnumKind:
  287. if v, ok := unmarshalEnum(tok, fd); ok {
  288. return v, nil
  289. }
  290. default:
  291. panic(fmt.Sprintf("unmarshalScalar: invalid scalar kind %v", kind))
  292. }
  293. return protoreflect.Value{}, d.newError(tok.Pos(), "invalid value for %v type: %v", kind, tok.RawString())
  294. }
  295. func unmarshalInt(tok json.Token, bitSize int) (protoreflect.Value, bool) {
  296. switch tok.Kind() {
  297. case json.Number:
  298. return getInt(tok, bitSize)
  299. case json.String:
  300. // Decode number from string.
  301. s := strings.TrimSpace(tok.ParsedString())
  302. if len(s) != len(tok.ParsedString()) {
  303. return protoreflect.Value{}, false
  304. }
  305. dec := json.NewDecoder([]byte(s))
  306. tok, err := dec.Read()
  307. if err != nil {
  308. return protoreflect.Value{}, false
  309. }
  310. return getInt(tok, bitSize)
  311. }
  312. return protoreflect.Value{}, false
  313. }
  314. func getInt(tok json.Token, bitSize int) (protoreflect.Value, bool) {
  315. n, ok := tok.Int(bitSize)
  316. if !ok {
  317. return protoreflect.Value{}, false
  318. }
  319. if bitSize == 32 {
  320. return protoreflect.ValueOfInt32(int32(n)), true
  321. }
  322. return protoreflect.ValueOfInt64(n), true
  323. }
  324. func unmarshalUint(tok json.Token, bitSize int) (protoreflect.Value, bool) {
  325. switch tok.Kind() {
  326. case json.Number:
  327. return getUint(tok, bitSize)
  328. case json.String:
  329. // Decode number from string.
  330. s := strings.TrimSpace(tok.ParsedString())
  331. if len(s) != len(tok.ParsedString()) {
  332. return protoreflect.Value{}, false
  333. }
  334. dec := json.NewDecoder([]byte(s))
  335. tok, err := dec.Read()
  336. if err != nil {
  337. return protoreflect.Value{}, false
  338. }
  339. return getUint(tok, bitSize)
  340. }
  341. return protoreflect.Value{}, false
  342. }
  343. func getUint(tok json.Token, bitSize int) (protoreflect.Value, bool) {
  344. n, ok := tok.Uint(bitSize)
  345. if !ok {
  346. return protoreflect.Value{}, false
  347. }
  348. if bitSize == 32 {
  349. return protoreflect.ValueOfUint32(uint32(n)), true
  350. }
  351. return protoreflect.ValueOfUint64(n), true
  352. }
  353. func unmarshalFloat(tok json.Token, bitSize int) (protoreflect.Value, bool) {
  354. switch tok.Kind() {
  355. case json.Number:
  356. return getFloat(tok, bitSize)
  357. case json.String:
  358. s := tok.ParsedString()
  359. switch s {
  360. case "NaN":
  361. if bitSize == 32 {
  362. return protoreflect.ValueOfFloat32(float32(math.NaN())), true
  363. }
  364. return protoreflect.ValueOfFloat64(math.NaN()), true
  365. case "Infinity":
  366. if bitSize == 32 {
  367. return protoreflect.ValueOfFloat32(float32(math.Inf(+1))), true
  368. }
  369. return protoreflect.ValueOfFloat64(math.Inf(+1)), true
  370. case "-Infinity":
  371. if bitSize == 32 {
  372. return protoreflect.ValueOfFloat32(float32(math.Inf(-1))), true
  373. }
  374. return protoreflect.ValueOfFloat64(math.Inf(-1)), true
  375. }
  376. // Decode number from string.
  377. if len(s) != len(strings.TrimSpace(s)) {
  378. return protoreflect.Value{}, false
  379. }
  380. dec := json.NewDecoder([]byte(s))
  381. tok, err := dec.Read()
  382. if err != nil {
  383. return protoreflect.Value{}, false
  384. }
  385. return getFloat(tok, bitSize)
  386. }
  387. return protoreflect.Value{}, false
  388. }
  389. func getFloat(tok json.Token, bitSize int) (protoreflect.Value, bool) {
  390. n, ok := tok.Float(bitSize)
  391. if !ok {
  392. return protoreflect.Value{}, false
  393. }
  394. if bitSize == 32 {
  395. return protoreflect.ValueOfFloat32(float32(n)), true
  396. }
  397. return protoreflect.ValueOfFloat64(n), true
  398. }
  399. func unmarshalBytes(tok json.Token) (protoreflect.Value, bool) {
  400. if tok.Kind() != json.String {
  401. return protoreflect.Value{}, false
  402. }
  403. s := tok.ParsedString()
  404. enc := base64.StdEncoding
  405. if strings.ContainsAny(s, "-_") {
  406. enc = base64.URLEncoding
  407. }
  408. if len(s)%4 != 0 {
  409. enc = enc.WithPadding(base64.NoPadding)
  410. }
  411. b, err := enc.DecodeString(s)
  412. if err != nil {
  413. return protoreflect.Value{}, false
  414. }
  415. return protoreflect.ValueOfBytes(b), true
  416. }
  417. func unmarshalEnum(tok json.Token, fd protoreflect.FieldDescriptor) (protoreflect.Value, bool) {
  418. switch tok.Kind() {
  419. case json.String:
  420. // Lookup EnumNumber based on name.
  421. s := tok.ParsedString()
  422. if enumVal := fd.Enum().Values().ByName(protoreflect.Name(s)); enumVal != nil {
  423. return protoreflect.ValueOfEnum(enumVal.Number()), true
  424. }
  425. case json.Number:
  426. if n, ok := tok.Int(32); ok {
  427. return protoreflect.ValueOfEnum(protoreflect.EnumNumber(n)), true
  428. }
  429. case json.Null:
  430. // This is only valid for google.protobuf.NullValue.
  431. if isNullValue(fd) {
  432. return protoreflect.ValueOfEnum(0), true
  433. }
  434. }
  435. return protoreflect.Value{}, false
  436. }
  437. func (d decoder) unmarshalList(list protoreflect.List, fd protoreflect.FieldDescriptor) error {
  438. tok, err := d.Read()
  439. if err != nil {
  440. return err
  441. }
  442. if tok.Kind() != json.ArrayOpen {
  443. return d.unexpectedTokenError(tok)
  444. }
  445. switch fd.Kind() {
  446. case protoreflect.MessageKind, protoreflect.GroupKind:
  447. for {
  448. tok, err := d.Peek()
  449. if err != nil {
  450. return err
  451. }
  452. if tok.Kind() == json.ArrayClose {
  453. d.Read()
  454. return nil
  455. }
  456. val := list.NewElement()
  457. if err := d.unmarshalMessage(val.Message(), false); err != nil {
  458. return err
  459. }
  460. list.Append(val)
  461. }
  462. default:
  463. for {
  464. tok, err := d.Peek()
  465. if err != nil {
  466. return err
  467. }
  468. if tok.Kind() == json.ArrayClose {
  469. d.Read()
  470. return nil
  471. }
  472. val, err := d.unmarshalScalar(fd)
  473. if err != nil {
  474. return err
  475. }
  476. list.Append(val)
  477. }
  478. }
  479. return nil
  480. }
  481. func (d decoder) unmarshalMap(mmap protoreflect.Map, fd protoreflect.FieldDescriptor) error {
  482. tok, err := d.Read()
  483. if err != nil {
  484. return err
  485. }
  486. if tok.Kind() != json.ObjectOpen {
  487. return d.unexpectedTokenError(tok)
  488. }
  489. // Determine ahead whether map entry is a scalar type or a message type in
  490. // order to call the appropriate unmarshalMapValue func inside the for loop
  491. // below.
  492. var unmarshalMapValue func() (protoreflect.Value, error)
  493. switch fd.MapValue().Kind() {
  494. case protoreflect.MessageKind, protoreflect.GroupKind:
  495. unmarshalMapValue = func() (protoreflect.Value, error) {
  496. val := mmap.NewValue()
  497. if err := d.unmarshalMessage(val.Message(), false); err != nil {
  498. return protoreflect.Value{}, err
  499. }
  500. return val, nil
  501. }
  502. default:
  503. unmarshalMapValue = func() (protoreflect.Value, error) {
  504. return d.unmarshalScalar(fd.MapValue())
  505. }
  506. }
  507. Loop:
  508. for {
  509. // Read field name.
  510. tok, err := d.Read()
  511. if err != nil {
  512. return err
  513. }
  514. switch tok.Kind() {
  515. default:
  516. return d.unexpectedTokenError(tok)
  517. case json.ObjectClose:
  518. break Loop
  519. case json.Name:
  520. // Continue.
  521. }
  522. // Unmarshal field name.
  523. pkey, err := d.unmarshalMapKey(tok, fd.MapKey())
  524. if err != nil {
  525. return err
  526. }
  527. // Check for duplicate field name.
  528. if mmap.Has(pkey) {
  529. return d.newError(tok.Pos(), "duplicate map key %v", tok.RawString())
  530. }
  531. // Read and unmarshal field value.
  532. pval, err := unmarshalMapValue()
  533. if err != nil {
  534. return err
  535. }
  536. mmap.Set(pkey, pval)
  537. }
  538. return nil
  539. }
  540. // unmarshalMapKey converts given token of Name kind into a protoreflect.MapKey.
  541. // A map key type is any integral or string type.
  542. func (d decoder) unmarshalMapKey(tok json.Token, fd protoreflect.FieldDescriptor) (protoreflect.MapKey, error) {
  543. const b32 = 32
  544. const b64 = 64
  545. const base10 = 10
  546. name := tok.Name()
  547. kind := fd.Kind()
  548. switch kind {
  549. case protoreflect.StringKind:
  550. return protoreflect.ValueOfString(name).MapKey(), nil
  551. case protoreflect.BoolKind:
  552. switch name {
  553. case "true":
  554. return protoreflect.ValueOfBool(true).MapKey(), nil
  555. case "false":
  556. return protoreflect.ValueOfBool(false).MapKey(), nil
  557. }
  558. case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
  559. if n, err := strconv.ParseInt(name, base10, b32); err == nil {
  560. return protoreflect.ValueOfInt32(int32(n)).MapKey(), nil
  561. }
  562. case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
  563. if n, err := strconv.ParseInt(name, base10, b64); err == nil {
  564. return protoreflect.ValueOfInt64(int64(n)).MapKey(), nil
  565. }
  566. case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
  567. if n, err := strconv.ParseUint(name, base10, b32); err == nil {
  568. return protoreflect.ValueOfUint32(uint32(n)).MapKey(), nil
  569. }
  570. case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
  571. if n, err := strconv.ParseUint(name, base10, b64); err == nil {
  572. return protoreflect.ValueOfUint64(uint64(n)).MapKey(), nil
  573. }
  574. default:
  575. panic(fmt.Sprintf("invalid kind for map key: %v", kind))
  576. }
  577. return protoreflect.MapKey{}, d.newError(tok.Pos(), "invalid value for %v key: %s", kind, tok.RawString())
  578. }