query.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. package runtime
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/url"
  6. "regexp"
  7. "strconv"
  8. "strings"
  9. "time"
  10. "github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
  11. "google.golang.org/grpc/grpclog"
  12. "google.golang.org/protobuf/encoding/protojson"
  13. "google.golang.org/protobuf/proto"
  14. "google.golang.org/protobuf/reflect/protoreflect"
  15. "google.golang.org/protobuf/reflect/protoregistry"
  16. "google.golang.org/protobuf/types/known/durationpb"
  17. field_mask "google.golang.org/protobuf/types/known/fieldmaskpb"
  18. "google.golang.org/protobuf/types/known/structpb"
  19. "google.golang.org/protobuf/types/known/timestamppb"
  20. "google.golang.org/protobuf/types/known/wrapperspb"
  21. )
  22. var valuesKeyRegexp = regexp.MustCompile(`^(.*)\[(.*)\]$`)
  23. var currentQueryParser QueryParameterParser = &DefaultQueryParser{}
  24. // QueryParameterParser defines interface for all query parameter parsers
  25. type QueryParameterParser interface {
  26. Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error
  27. }
  28. // PopulateQueryParameters parses query parameters
  29. // into "msg" using current query parser
  30. func PopulateQueryParameters(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
  31. return currentQueryParser.Parse(msg, values, filter)
  32. }
  33. // DefaultQueryParser is a QueryParameterParser which implements the default
  34. // query parameters parsing behavior.
  35. //
  36. // See https://github.com/grpc-ecosystem/grpc-gateway/issues/2632 for more context.
  37. type DefaultQueryParser struct{}
  38. // Parse populates "values" into "msg".
  39. // A value is ignored if its key starts with one of the elements in "filter".
  40. func (*DefaultQueryParser) Parse(msg proto.Message, values url.Values, filter *utilities.DoubleArray) error {
  41. for key, values := range values {
  42. if match := valuesKeyRegexp.FindStringSubmatch(key); len(match) == 3 {
  43. key = match[1]
  44. values = append([]string{match[2]}, values...)
  45. }
  46. fieldPath := strings.Split(key, ".")
  47. if filter.HasCommonPrefix(fieldPath) {
  48. continue
  49. }
  50. if err := populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, values); err != nil {
  51. return err
  52. }
  53. }
  54. return nil
  55. }
  56. // PopulateFieldFromPath sets a value in a nested Protobuf structure.
  57. func PopulateFieldFromPath(msg proto.Message, fieldPathString string, value string) error {
  58. fieldPath := strings.Split(fieldPathString, ".")
  59. return populateFieldValueFromPath(msg.ProtoReflect(), fieldPath, []string{value})
  60. }
  61. func populateFieldValueFromPath(msgValue protoreflect.Message, fieldPath []string, values []string) error {
  62. if len(fieldPath) < 1 {
  63. return errors.New("no field path")
  64. }
  65. if len(values) < 1 {
  66. return errors.New("no value provided")
  67. }
  68. var fieldDescriptor protoreflect.FieldDescriptor
  69. for i, fieldName := range fieldPath {
  70. fields := msgValue.Descriptor().Fields()
  71. // Get field by name
  72. fieldDescriptor = fields.ByName(protoreflect.Name(fieldName))
  73. if fieldDescriptor == nil {
  74. fieldDescriptor = fields.ByJSONName(fieldName)
  75. if fieldDescriptor == nil {
  76. // We're not returning an error here because this could just be
  77. // an extra query parameter that isn't part of the request.
  78. grpclog.Infof("field not found in %q: %q", msgValue.Descriptor().FullName(), strings.Join(fieldPath, "."))
  79. return nil
  80. }
  81. }
  82. // If this is the last element, we're done
  83. if i == len(fieldPath)-1 {
  84. break
  85. }
  86. // Only singular message fields are allowed
  87. if fieldDescriptor.Message() == nil || fieldDescriptor.Cardinality() == protoreflect.Repeated {
  88. return fmt.Errorf("invalid path: %q is not a message", fieldName)
  89. }
  90. // Get the nested message
  91. msgValue = msgValue.Mutable(fieldDescriptor).Message()
  92. }
  93. // Check if oneof already set
  94. if of := fieldDescriptor.ContainingOneof(); of != nil {
  95. if f := msgValue.WhichOneof(of); f != nil {
  96. return fmt.Errorf("field already set for oneof %q", of.FullName().Name())
  97. }
  98. }
  99. switch {
  100. case fieldDescriptor.IsList():
  101. return populateRepeatedField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).List(), values)
  102. case fieldDescriptor.IsMap():
  103. return populateMapField(fieldDescriptor, msgValue.Mutable(fieldDescriptor).Map(), values)
  104. }
  105. if len(values) > 1 {
  106. return fmt.Errorf("too many values for field %q: %s", fieldDescriptor.FullName().Name(), strings.Join(values, ", "))
  107. }
  108. return populateField(fieldDescriptor, msgValue, values[0])
  109. }
  110. func populateField(fieldDescriptor protoreflect.FieldDescriptor, msgValue protoreflect.Message, value string) error {
  111. v, err := parseField(fieldDescriptor, value)
  112. if err != nil {
  113. return fmt.Errorf("parsing field %q: %w", fieldDescriptor.FullName().Name(), err)
  114. }
  115. msgValue.Set(fieldDescriptor, v)
  116. return nil
  117. }
  118. func populateRepeatedField(fieldDescriptor protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
  119. for _, value := range values {
  120. v, err := parseField(fieldDescriptor, value)
  121. if err != nil {
  122. return fmt.Errorf("parsing list %q: %w", fieldDescriptor.FullName().Name(), err)
  123. }
  124. list.Append(v)
  125. }
  126. return nil
  127. }
  128. func populateMapField(fieldDescriptor protoreflect.FieldDescriptor, mp protoreflect.Map, values []string) error {
  129. if len(values) != 2 {
  130. return fmt.Errorf("more than one value provided for key %q in map %q", values[0], fieldDescriptor.FullName())
  131. }
  132. key, err := parseField(fieldDescriptor.MapKey(), values[0])
  133. if err != nil {
  134. return fmt.Errorf("parsing map key %q: %w", fieldDescriptor.FullName().Name(), err)
  135. }
  136. value, err := parseField(fieldDescriptor.MapValue(), values[1])
  137. if err != nil {
  138. return fmt.Errorf("parsing map value %q: %w", fieldDescriptor.FullName().Name(), err)
  139. }
  140. mp.Set(key.MapKey(), value)
  141. return nil
  142. }
  143. func parseField(fieldDescriptor protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
  144. switch fieldDescriptor.Kind() {
  145. case protoreflect.BoolKind:
  146. v, err := strconv.ParseBool(value)
  147. if err != nil {
  148. return protoreflect.Value{}, err
  149. }
  150. return protoreflect.ValueOfBool(v), nil
  151. case protoreflect.EnumKind:
  152. enum, err := protoregistry.GlobalTypes.FindEnumByName(fieldDescriptor.Enum().FullName())
  153. if err != nil {
  154. if errors.Is(err, protoregistry.NotFound) {
  155. return protoreflect.Value{}, fmt.Errorf("enum %q is not registered", fieldDescriptor.Enum().FullName())
  156. }
  157. return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
  158. }
  159. // Look for enum by name
  160. v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
  161. if v == nil {
  162. i, err := strconv.Atoi(value)
  163. if err != nil {
  164. return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
  165. }
  166. // Look for enum by number
  167. if v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)); v == nil {
  168. return protoreflect.Value{}, fmt.Errorf("%q is not a valid value", value)
  169. }
  170. }
  171. return protoreflect.ValueOfEnum(v.Number()), nil
  172. case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
  173. v, err := strconv.ParseInt(value, 10, 32)
  174. if err != nil {
  175. return protoreflect.Value{}, err
  176. }
  177. return protoreflect.ValueOfInt32(int32(v)), nil
  178. case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
  179. v, err := strconv.ParseInt(value, 10, 64)
  180. if err != nil {
  181. return protoreflect.Value{}, err
  182. }
  183. return protoreflect.ValueOfInt64(v), nil
  184. case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
  185. v, err := strconv.ParseUint(value, 10, 32)
  186. if err != nil {
  187. return protoreflect.Value{}, err
  188. }
  189. return protoreflect.ValueOfUint32(uint32(v)), nil
  190. case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
  191. v, err := strconv.ParseUint(value, 10, 64)
  192. if err != nil {
  193. return protoreflect.Value{}, err
  194. }
  195. return protoreflect.ValueOfUint64(v), nil
  196. case protoreflect.FloatKind:
  197. v, err := strconv.ParseFloat(value, 32)
  198. if err != nil {
  199. return protoreflect.Value{}, err
  200. }
  201. return protoreflect.ValueOfFloat32(float32(v)), nil
  202. case protoreflect.DoubleKind:
  203. v, err := strconv.ParseFloat(value, 64)
  204. if err != nil {
  205. return protoreflect.Value{}, err
  206. }
  207. return protoreflect.ValueOfFloat64(v), nil
  208. case protoreflect.StringKind:
  209. return protoreflect.ValueOfString(value), nil
  210. case protoreflect.BytesKind:
  211. v, err := Bytes(value)
  212. if err != nil {
  213. return protoreflect.Value{}, err
  214. }
  215. return protoreflect.ValueOfBytes(v), nil
  216. case protoreflect.MessageKind, protoreflect.GroupKind:
  217. return parseMessage(fieldDescriptor.Message(), value)
  218. default:
  219. panic(fmt.Sprintf("unknown field kind: %v", fieldDescriptor.Kind()))
  220. }
  221. }
  222. func parseMessage(msgDescriptor protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
  223. var msg proto.Message
  224. switch msgDescriptor.FullName() {
  225. case "google.protobuf.Timestamp":
  226. t, err := time.Parse(time.RFC3339Nano, value)
  227. if err != nil {
  228. return protoreflect.Value{}, err
  229. }
  230. msg = timestamppb.New(t)
  231. case "google.protobuf.Duration":
  232. d, err := time.ParseDuration(value)
  233. if err != nil {
  234. return protoreflect.Value{}, err
  235. }
  236. msg = durationpb.New(d)
  237. case "google.protobuf.DoubleValue":
  238. v, err := strconv.ParseFloat(value, 64)
  239. if err != nil {
  240. return protoreflect.Value{}, err
  241. }
  242. msg = wrapperspb.Double(v)
  243. case "google.protobuf.FloatValue":
  244. v, err := strconv.ParseFloat(value, 32)
  245. if err != nil {
  246. return protoreflect.Value{}, err
  247. }
  248. msg = wrapperspb.Float(float32(v))
  249. case "google.protobuf.Int64Value":
  250. v, err := strconv.ParseInt(value, 10, 64)
  251. if err != nil {
  252. return protoreflect.Value{}, err
  253. }
  254. msg = wrapperspb.Int64(v)
  255. case "google.protobuf.Int32Value":
  256. v, err := strconv.ParseInt(value, 10, 32)
  257. if err != nil {
  258. return protoreflect.Value{}, err
  259. }
  260. msg = wrapperspb.Int32(int32(v))
  261. case "google.protobuf.UInt64Value":
  262. v, err := strconv.ParseUint(value, 10, 64)
  263. if err != nil {
  264. return protoreflect.Value{}, err
  265. }
  266. msg = wrapperspb.UInt64(v)
  267. case "google.protobuf.UInt32Value":
  268. v, err := strconv.ParseUint(value, 10, 32)
  269. if err != nil {
  270. return protoreflect.Value{}, err
  271. }
  272. msg = wrapperspb.UInt32(uint32(v))
  273. case "google.protobuf.BoolValue":
  274. v, err := strconv.ParseBool(value)
  275. if err != nil {
  276. return protoreflect.Value{}, err
  277. }
  278. msg = wrapperspb.Bool(v)
  279. case "google.protobuf.StringValue":
  280. msg = wrapperspb.String(value)
  281. case "google.protobuf.BytesValue":
  282. v, err := Bytes(value)
  283. if err != nil {
  284. return protoreflect.Value{}, err
  285. }
  286. msg = wrapperspb.Bytes(v)
  287. case "google.protobuf.FieldMask":
  288. fm := &field_mask.FieldMask{}
  289. fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
  290. msg = fm
  291. case "google.protobuf.Value":
  292. var v structpb.Value
  293. if err := protojson.Unmarshal([]byte(value), &v); err != nil {
  294. return protoreflect.Value{}, err
  295. }
  296. msg = &v
  297. case "google.protobuf.Struct":
  298. var v structpb.Struct
  299. if err := protojson.Unmarshal([]byte(value), &v); err != nil {
  300. return protoreflect.Value{}, err
  301. }
  302. msg = &v
  303. default:
  304. return protoreflect.Value{}, fmt.Errorf("unsupported message type: %q", string(msgDescriptor.FullName()))
  305. }
  306. return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
  307. }