fieldmask.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. package runtime
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sort"
  8. "google.golang.org/protobuf/proto"
  9. "google.golang.org/protobuf/reflect/protoreflect"
  10. field_mask "google.golang.org/protobuf/types/known/fieldmaskpb"
  11. )
  12. func getFieldByName(fields protoreflect.FieldDescriptors, name string) protoreflect.FieldDescriptor {
  13. fd := fields.ByName(protoreflect.Name(name))
  14. if fd != nil {
  15. return fd
  16. }
  17. return fields.ByJSONName(name)
  18. }
  19. // FieldMaskFromRequestBody creates a FieldMask printing all complete paths from the JSON body.
  20. func FieldMaskFromRequestBody(r io.Reader, msg proto.Message) (*field_mask.FieldMask, error) {
  21. fm := &field_mask.FieldMask{}
  22. var root interface{}
  23. if err := json.NewDecoder(r).Decode(&root); err != nil {
  24. if errors.Is(err, io.EOF) {
  25. return fm, nil
  26. }
  27. return nil, err
  28. }
  29. queue := []fieldMaskPathItem{{node: root, msg: msg.ProtoReflect()}}
  30. for len(queue) > 0 {
  31. // dequeue an item
  32. item := queue[0]
  33. queue = queue[1:]
  34. m, ok := item.node.(map[string]interface{})
  35. switch {
  36. case ok:
  37. // if the item is an object, then enqueue all of its children
  38. for k, v := range m {
  39. if item.msg == nil {
  40. return nil, errors.New("JSON structure did not match request type")
  41. }
  42. fd := getFieldByName(item.msg.Descriptor().Fields(), k)
  43. if fd == nil {
  44. return nil, fmt.Errorf("could not find field %q in %q", k, item.msg.Descriptor().FullName())
  45. }
  46. if isDynamicProtoMessage(fd.Message()) {
  47. for _, p := range buildPathsBlindly(string(fd.FullName().Name()), v) {
  48. newPath := p
  49. if item.path != "" {
  50. newPath = item.path + "." + newPath
  51. }
  52. queue = append(queue, fieldMaskPathItem{path: newPath})
  53. }
  54. continue
  55. }
  56. if isProtobufAnyMessage(fd.Message()) && !fd.IsList() {
  57. _, hasTypeField := v.(map[string]interface{})["@type"]
  58. if hasTypeField {
  59. queue = append(queue, fieldMaskPathItem{path: k})
  60. continue
  61. } else {
  62. return nil, fmt.Errorf("could not find field @type in %q in message %q", k, item.msg.Descriptor().FullName())
  63. }
  64. }
  65. child := fieldMaskPathItem{
  66. node: v,
  67. }
  68. if item.path == "" {
  69. child.path = string(fd.FullName().Name())
  70. } else {
  71. child.path = item.path + "." + string(fd.FullName().Name())
  72. }
  73. switch {
  74. case fd.IsList(), fd.IsMap():
  75. // As per: https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/field_mask.proto#L85-L86
  76. // Do not recurse into repeated fields. The repeated field goes on the end of the path and we stop.
  77. fm.Paths = append(fm.Paths, child.path)
  78. case fd.Message() != nil:
  79. child.msg = item.msg.Get(fd).Message()
  80. fallthrough
  81. default:
  82. queue = append(queue, child)
  83. }
  84. }
  85. case len(item.path) > 0:
  86. // otherwise, it's a leaf node so print its path
  87. fm.Paths = append(fm.Paths, item.path)
  88. }
  89. }
  90. // Sort for deterministic output in the presence
  91. // of repeated fields.
  92. sort.Strings(fm.Paths)
  93. return fm, nil
  94. }
  95. func isProtobufAnyMessage(md protoreflect.MessageDescriptor) bool {
  96. return md != nil && (md.FullName() == "google.protobuf.Any")
  97. }
  98. func isDynamicProtoMessage(md protoreflect.MessageDescriptor) bool {
  99. return md != nil && (md.FullName() == "google.protobuf.Struct" || md.FullName() == "google.protobuf.Value")
  100. }
  101. // buildPathsBlindly does not attempt to match proto field names to the
  102. // json value keys. Instead it relies completely on the structure of
  103. // the unmarshalled json contained within in.
  104. // Returns a slice containing all subpaths with the root at the
  105. // passed in name and json value.
  106. func buildPathsBlindly(name string, in interface{}) []string {
  107. m, ok := in.(map[string]interface{})
  108. if !ok {
  109. return []string{name}
  110. }
  111. var paths []string
  112. queue := []fieldMaskPathItem{{path: name, node: m}}
  113. for len(queue) > 0 {
  114. cur := queue[0]
  115. queue = queue[1:]
  116. m, ok := cur.node.(map[string]interface{})
  117. if !ok {
  118. // This should never happen since we should always check that we only add
  119. // nodes of type map[string]interface{} to the queue.
  120. continue
  121. }
  122. for k, v := range m {
  123. if mi, ok := v.(map[string]interface{}); ok {
  124. queue = append(queue, fieldMaskPathItem{path: cur.path + "." + k, node: mi})
  125. } else {
  126. // This is not a struct, so there are no more levels to descend.
  127. curPath := cur.path + "." + k
  128. paths = append(paths, curPath)
  129. }
  130. }
  131. }
  132. return paths
  133. }
  134. // fieldMaskPathItem stores a in-progress deconstruction of a path for a fieldmask
  135. type fieldMaskPathItem struct {
  136. // the list of prior fields leading up to node connected by dots
  137. path string
  138. // a generic decoded json object the current item to inspect for further path extraction
  139. node interface{}
  140. // parent message
  141. msg protoreflect.Message
  142. }