mux.go 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. package runtime
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net/http"
  7. "net/textproto"
  8. "regexp"
  9. "strings"
  10. "github.com/grpc-ecosystem/grpc-gateway/v2/internal/httprule"
  11. "google.golang.org/grpc/codes"
  12. "google.golang.org/grpc/grpclog"
  13. "google.golang.org/grpc/health/grpc_health_v1"
  14. "google.golang.org/grpc/metadata"
  15. "google.golang.org/grpc/status"
  16. "google.golang.org/protobuf/proto"
  17. )
  18. // UnescapingMode defines the behavior of ServeMux when unescaping path parameters.
  19. type UnescapingMode int
  20. const (
  21. // UnescapingModeLegacy is the default V2 behavior, which escapes the entire
  22. // path string before doing any routing.
  23. UnescapingModeLegacy UnescapingMode = iota
  24. // UnescapingModeAllExceptReserved unescapes all path parameters except RFC 6570
  25. // reserved characters.
  26. UnescapingModeAllExceptReserved
  27. // UnescapingModeAllExceptSlash unescapes URL path parameters except path
  28. // separators, which will be left as "%2F".
  29. UnescapingModeAllExceptSlash
  30. // UnescapingModeAllCharacters unescapes all URL path parameters.
  31. UnescapingModeAllCharacters
  32. // UnescapingModeDefault is the default escaping type.
  33. // TODO(v3): default this to UnescapingModeAllExceptReserved per grpc-httpjson-transcoding's
  34. // reference implementation
  35. UnescapingModeDefault = UnescapingModeLegacy
  36. )
  37. var encodedPathSplitter = regexp.MustCompile("(/|%2F)")
  38. // A HandlerFunc handles a specific pair of path pattern and HTTP method.
  39. type HandlerFunc func(w http.ResponseWriter, r *http.Request, pathParams map[string]string)
  40. // ServeMux is a request multiplexer for grpc-gateway.
  41. // It matches http requests to patterns and invokes the corresponding handler.
  42. type ServeMux struct {
  43. // handlers maps HTTP method to a list of handlers.
  44. handlers map[string][]handler
  45. forwardResponseOptions []func(context.Context, http.ResponseWriter, proto.Message) error
  46. marshalers marshalerRegistry
  47. incomingHeaderMatcher HeaderMatcherFunc
  48. outgoingHeaderMatcher HeaderMatcherFunc
  49. metadataAnnotators []func(context.Context, *http.Request) metadata.MD
  50. errorHandler ErrorHandlerFunc
  51. streamErrorHandler StreamErrorHandlerFunc
  52. routingErrorHandler RoutingErrorHandlerFunc
  53. disablePathLengthFallback bool
  54. unescapingMode UnescapingMode
  55. }
  56. // ServeMuxOption is an option that can be given to a ServeMux on construction.
  57. type ServeMuxOption func(*ServeMux)
  58. // WithForwardResponseOption returns a ServeMuxOption representing the forwardResponseOption.
  59. //
  60. // forwardResponseOption is an option that will be called on the relevant context.Context,
  61. // http.ResponseWriter, and proto.Message before every forwarded response.
  62. //
  63. // The message may be nil in the case where just a header is being sent.
  64. func WithForwardResponseOption(forwardResponseOption func(context.Context, http.ResponseWriter, proto.Message) error) ServeMuxOption {
  65. return func(serveMux *ServeMux) {
  66. serveMux.forwardResponseOptions = append(serveMux.forwardResponseOptions, forwardResponseOption)
  67. }
  68. }
  69. // WithUnescapingMode sets the escaping type. See the definitions of UnescapingMode
  70. // for more information.
  71. func WithUnescapingMode(mode UnescapingMode) ServeMuxOption {
  72. return func(serveMux *ServeMux) {
  73. serveMux.unescapingMode = mode
  74. }
  75. }
  76. // SetQueryParameterParser sets the query parameter parser, used to populate message from query parameters.
  77. // Configuring this will mean the generated OpenAPI output is no longer correct, and it should be
  78. // done with careful consideration.
  79. func SetQueryParameterParser(queryParameterParser QueryParameterParser) ServeMuxOption {
  80. return func(serveMux *ServeMux) {
  81. currentQueryParser = queryParameterParser
  82. }
  83. }
  84. // HeaderMatcherFunc checks whether a header key should be forwarded to/from gRPC context.
  85. type HeaderMatcherFunc func(string) (string, bool)
  86. // DefaultHeaderMatcher is used to pass http request headers to/from gRPC context. This adds permanent HTTP header
  87. // keys (as specified by the IANA, e.g: Accept, Cookie, Host) to the gRPC metadata with the grpcgateway- prefix. If you want to know which headers are considered permanent, you can view the isPermanentHTTPHeader function.
  88. // HTTP headers that start with 'Grpc-Metadata-' are mapped to gRPC metadata after removing the prefix 'Grpc-Metadata-'.
  89. // Other headers are not added to the gRPC metadata.
  90. func DefaultHeaderMatcher(key string) (string, bool) {
  91. switch key = textproto.CanonicalMIMEHeaderKey(key); {
  92. case isPermanentHTTPHeader(key):
  93. return MetadataPrefix + key, true
  94. case strings.HasPrefix(key, MetadataHeaderPrefix):
  95. return key[len(MetadataHeaderPrefix):], true
  96. }
  97. return "", false
  98. }
  99. // WithIncomingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for incoming request to gateway.
  100. //
  101. // This matcher will be called with each header in http.Request. If matcher returns true, that header will be
  102. // passed to gRPC context. To transform the header before passing to gRPC context, matcher should return modified header.
  103. func WithIncomingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
  104. for _, header := range fn.matchedMalformedHeaders() {
  105. grpclog.Warningf("The configured forwarding filter would allow %q to be sent to the gRPC server, which will likely cause errors. See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more information.", header)
  106. }
  107. return func(mux *ServeMux) {
  108. mux.incomingHeaderMatcher = fn
  109. }
  110. }
  111. // matchedMalformedHeaders returns the malformed headers that would be forwarded to gRPC server.
  112. func (fn HeaderMatcherFunc) matchedMalformedHeaders() []string {
  113. if fn == nil {
  114. return nil
  115. }
  116. headers := make([]string, 0)
  117. for header := range malformedHTTPHeaders {
  118. out, accept := fn(header)
  119. if accept && isMalformedHTTPHeader(out) {
  120. headers = append(headers, out)
  121. }
  122. }
  123. return headers
  124. }
  125. // WithOutgoingHeaderMatcher returns a ServeMuxOption representing a headerMatcher for outgoing response from gateway.
  126. //
  127. // This matcher will be called with each header in response header metadata. If matcher returns true, that header will be
  128. // passed to http response returned from gateway. To transform the header before passing to response,
  129. // matcher should return modified header.
  130. func WithOutgoingHeaderMatcher(fn HeaderMatcherFunc) ServeMuxOption {
  131. return func(mux *ServeMux) {
  132. mux.outgoingHeaderMatcher = fn
  133. }
  134. }
  135. // WithMetadata returns a ServeMuxOption for passing metadata to a gRPC context.
  136. //
  137. // This can be used by services that need to read from http.Request and modify gRPC context. A common use case
  138. // is reading token from cookie and adding it in gRPC context.
  139. func WithMetadata(annotator func(context.Context, *http.Request) metadata.MD) ServeMuxOption {
  140. return func(serveMux *ServeMux) {
  141. serveMux.metadataAnnotators = append(serveMux.metadataAnnotators, annotator)
  142. }
  143. }
  144. // WithErrorHandler returns a ServeMuxOption for configuring a custom error handler.
  145. //
  146. // This can be used to configure a custom error response.
  147. func WithErrorHandler(fn ErrorHandlerFunc) ServeMuxOption {
  148. return func(serveMux *ServeMux) {
  149. serveMux.errorHandler = fn
  150. }
  151. }
  152. // WithStreamErrorHandler returns a ServeMuxOption that will use the given custom stream
  153. // error handler, which allows for customizing the error trailer for server-streaming
  154. // calls.
  155. //
  156. // For stream errors that occur before any response has been written, the mux's
  157. // ErrorHandler will be invoked. However, once data has been written, the errors must
  158. // be handled differently: they must be included in the response body. The response body's
  159. // final message will include the error details returned by the stream error handler.
  160. func WithStreamErrorHandler(fn StreamErrorHandlerFunc) ServeMuxOption {
  161. return func(serveMux *ServeMux) {
  162. serveMux.streamErrorHandler = fn
  163. }
  164. }
  165. // WithRoutingErrorHandler returns a ServeMuxOption for configuring a custom error handler to handle http routing errors.
  166. //
  167. // Method called for errors which can happen before gRPC route selected or executed.
  168. // The following error codes: StatusMethodNotAllowed StatusNotFound StatusBadRequest
  169. func WithRoutingErrorHandler(fn RoutingErrorHandlerFunc) ServeMuxOption {
  170. return func(serveMux *ServeMux) {
  171. serveMux.routingErrorHandler = fn
  172. }
  173. }
  174. // WithDisablePathLengthFallback returns a ServeMuxOption for disable path length fallback.
  175. func WithDisablePathLengthFallback() ServeMuxOption {
  176. return func(serveMux *ServeMux) {
  177. serveMux.disablePathLengthFallback = true
  178. }
  179. }
  180. // WithHealthEndpointAt returns a ServeMuxOption that will add an endpoint to the created ServeMux at the path specified by endpointPath.
  181. // When called the handler will forward the request to the upstream grpc service health check (defined in the
  182. // gRPC Health Checking Protocol).
  183. //
  184. // See here https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/health_check/ for more information on how
  185. // to setup the protocol in the grpc server.
  186. //
  187. // If you define a service as query parameter, this will also be forwarded as service in the HealthCheckRequest.
  188. func WithHealthEndpointAt(healthCheckClient grpc_health_v1.HealthClient, endpointPath string) ServeMuxOption {
  189. return func(s *ServeMux) {
  190. // error can be ignored since pattern is definitely valid
  191. _ = s.HandlePath(
  192. http.MethodGet, endpointPath, func(w http.ResponseWriter, r *http.Request, _ map[string]string,
  193. ) {
  194. _, outboundMarshaler := MarshalerForRequest(s, r)
  195. resp, err := healthCheckClient.Check(r.Context(), &grpc_health_v1.HealthCheckRequest{
  196. Service: r.URL.Query().Get("service"),
  197. })
  198. if err != nil {
  199. s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
  200. return
  201. }
  202. w.Header().Set("Content-Type", "application/json")
  203. if resp.GetStatus() != grpc_health_v1.HealthCheckResponse_SERVING {
  204. switch resp.GetStatus() {
  205. case grpc_health_v1.HealthCheckResponse_NOT_SERVING, grpc_health_v1.HealthCheckResponse_UNKNOWN:
  206. err = status.Error(codes.Unavailable, resp.String())
  207. case grpc_health_v1.HealthCheckResponse_SERVICE_UNKNOWN:
  208. err = status.Error(codes.NotFound, resp.String())
  209. }
  210. s.errorHandler(r.Context(), s, outboundMarshaler, w, r, err)
  211. return
  212. }
  213. _ = outboundMarshaler.NewEncoder(w).Encode(resp)
  214. })
  215. }
  216. }
  217. // WithHealthzEndpoint returns a ServeMuxOption that will add a /healthz endpoint to the created ServeMux.
  218. //
  219. // See WithHealthEndpointAt for the general implementation.
  220. func WithHealthzEndpoint(healthCheckClient grpc_health_v1.HealthClient) ServeMuxOption {
  221. return WithHealthEndpointAt(healthCheckClient, "/healthz")
  222. }
  223. // NewServeMux returns a new ServeMux whose internal mapping is empty.
  224. func NewServeMux(opts ...ServeMuxOption) *ServeMux {
  225. serveMux := &ServeMux{
  226. handlers: make(map[string][]handler),
  227. forwardResponseOptions: make([]func(context.Context, http.ResponseWriter, proto.Message) error, 0),
  228. marshalers: makeMarshalerMIMERegistry(),
  229. errorHandler: DefaultHTTPErrorHandler,
  230. streamErrorHandler: DefaultStreamErrorHandler,
  231. routingErrorHandler: DefaultRoutingErrorHandler,
  232. unescapingMode: UnescapingModeDefault,
  233. }
  234. for _, opt := range opts {
  235. opt(serveMux)
  236. }
  237. if serveMux.incomingHeaderMatcher == nil {
  238. serveMux.incomingHeaderMatcher = DefaultHeaderMatcher
  239. }
  240. if serveMux.outgoingHeaderMatcher == nil {
  241. serveMux.outgoingHeaderMatcher = func(key string) (string, bool) {
  242. return fmt.Sprintf("%s%s", MetadataHeaderPrefix, key), true
  243. }
  244. }
  245. return serveMux
  246. }
  247. // Handle associates "h" to the pair of HTTP method and path pattern.
  248. func (s *ServeMux) Handle(meth string, pat Pattern, h HandlerFunc) {
  249. s.handlers[meth] = append([]handler{{pat: pat, h: h}}, s.handlers[meth]...)
  250. }
  251. // HandlePath allows users to configure custom path handlers.
  252. // refer: https://grpc-ecosystem.github.io/grpc-gateway/docs/operations/inject_router/
  253. func (s *ServeMux) HandlePath(meth string, pathPattern string, h HandlerFunc) error {
  254. compiler, err := httprule.Parse(pathPattern)
  255. if err != nil {
  256. return fmt.Errorf("parsing path pattern: %w", err)
  257. }
  258. tp := compiler.Compile()
  259. pattern, err := NewPattern(tp.Version, tp.OpCodes, tp.Pool, tp.Verb)
  260. if err != nil {
  261. return fmt.Errorf("creating new pattern: %w", err)
  262. }
  263. s.Handle(meth, pattern, h)
  264. return nil
  265. }
  266. // ServeHTTP dispatches the request to the first handler whose pattern matches to r.Method and r.URL.Path.
  267. func (s *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
  268. ctx := r.Context()
  269. path := r.URL.Path
  270. if !strings.HasPrefix(path, "/") {
  271. _, outboundMarshaler := MarshalerForRequest(s, r)
  272. s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusBadRequest)
  273. return
  274. }
  275. // TODO(v3): remove UnescapingModeLegacy
  276. if s.unescapingMode != UnescapingModeLegacy && r.URL.RawPath != "" {
  277. path = r.URL.RawPath
  278. }
  279. if override := r.Header.Get("X-HTTP-Method-Override"); override != "" && s.isPathLengthFallback(r) {
  280. r.Method = strings.ToUpper(override)
  281. if err := r.ParseForm(); err != nil {
  282. _, outboundMarshaler := MarshalerForRequest(s, r)
  283. sterr := status.Error(codes.InvalidArgument, err.Error())
  284. s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  285. return
  286. }
  287. }
  288. var pathComponents []string
  289. // since in UnescapeModeLegacy, the URL will already have been fully unescaped, if we also split on "%2F"
  290. // in this escaping mode we would be double unescaping but in UnescapingModeAllCharacters, we still do as the
  291. // path is the RawPath (i.e. unescaped). That does mean that the behavior of this function will change its default
  292. // behavior when the UnescapingModeDefault gets changed from UnescapingModeLegacy to UnescapingModeAllExceptReserved
  293. if s.unescapingMode == UnescapingModeAllCharacters {
  294. pathComponents = encodedPathSplitter.Split(path[1:], -1)
  295. } else {
  296. pathComponents = strings.Split(path[1:], "/")
  297. }
  298. lastPathComponent := pathComponents[len(pathComponents)-1]
  299. for _, h := range s.handlers[r.Method] {
  300. // If the pattern has a verb, explicitly look for a suffix in the last
  301. // component that matches a colon plus the verb. This allows us to
  302. // handle some cases that otherwise can't be correctly handled by the
  303. // former LastIndex case, such as when the verb literal itself contains
  304. // a colon. This should work for all cases that have run through the
  305. // parser because we know what verb we're looking for, however, there
  306. // are still some cases that the parser itself cannot disambiguate. See
  307. // the comment there if interested.
  308. var verb string
  309. patVerb := h.pat.Verb()
  310. idx := -1
  311. if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
  312. idx = len(lastPathComponent) - len(patVerb) - 1
  313. }
  314. if idx == 0 {
  315. _, outboundMarshaler := MarshalerForRequest(s, r)
  316. s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
  317. return
  318. }
  319. comps := make([]string, len(pathComponents))
  320. copy(comps, pathComponents)
  321. if idx > 0 {
  322. comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
  323. }
  324. pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
  325. if err != nil {
  326. var mse MalformedSequenceError
  327. if ok := errors.As(err, &mse); ok {
  328. _, outboundMarshaler := MarshalerForRequest(s, r)
  329. s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
  330. HTTPStatus: http.StatusBadRequest,
  331. Err: mse,
  332. })
  333. }
  334. continue
  335. }
  336. h.h(w, r, pathParams)
  337. return
  338. }
  339. // if no handler has found for the request, lookup for other methods
  340. // to handle POST -> GET fallback if the request is subject to path
  341. // length fallback.
  342. // Note we are not eagerly checking the request here as we want to return the
  343. // right HTTP status code, and we need to process the fallback candidates in
  344. // order to do that.
  345. for m, handlers := range s.handlers {
  346. if m == r.Method {
  347. continue
  348. }
  349. for _, h := range handlers {
  350. var verb string
  351. patVerb := h.pat.Verb()
  352. idx := -1
  353. if patVerb != "" && strings.HasSuffix(lastPathComponent, ":"+patVerb) {
  354. idx = len(lastPathComponent) - len(patVerb) - 1
  355. }
  356. comps := make([]string, len(pathComponents))
  357. copy(comps, pathComponents)
  358. if idx > 0 {
  359. comps[len(comps)-1], verb = lastPathComponent[:idx], lastPathComponent[idx+1:]
  360. }
  361. pathParams, err := h.pat.MatchAndEscape(comps, verb, s.unescapingMode)
  362. if err != nil {
  363. var mse MalformedSequenceError
  364. if ok := errors.As(err, &mse); ok {
  365. _, outboundMarshaler := MarshalerForRequest(s, r)
  366. s.errorHandler(ctx, s, outboundMarshaler, w, r, &HTTPStatusError{
  367. HTTPStatus: http.StatusBadRequest,
  368. Err: mse,
  369. })
  370. }
  371. continue
  372. }
  373. // X-HTTP-Method-Override is optional. Always allow fallback to POST.
  374. // Also, only consider POST -> GET fallbacks, and avoid falling back to
  375. // potentially dangerous operations like DELETE.
  376. if s.isPathLengthFallback(r) && m == http.MethodGet {
  377. if err := r.ParseForm(); err != nil {
  378. _, outboundMarshaler := MarshalerForRequest(s, r)
  379. sterr := status.Error(codes.InvalidArgument, err.Error())
  380. s.errorHandler(ctx, s, outboundMarshaler, w, r, sterr)
  381. return
  382. }
  383. h.h(w, r, pathParams)
  384. return
  385. }
  386. _, outboundMarshaler := MarshalerForRequest(s, r)
  387. s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusMethodNotAllowed)
  388. return
  389. }
  390. }
  391. _, outboundMarshaler := MarshalerForRequest(s, r)
  392. s.routingErrorHandler(ctx, s, outboundMarshaler, w, r, http.StatusNotFound)
  393. }
  394. // GetForwardResponseOptions returns the ForwardResponseOptions associated with this ServeMux.
  395. func (s *ServeMux) GetForwardResponseOptions() []func(context.Context, http.ResponseWriter, proto.Message) error {
  396. return s.forwardResponseOptions
  397. }
  398. func (s *ServeMux) isPathLengthFallback(r *http.Request) bool {
  399. return !s.disablePathLengthFallback && r.Method == "POST" && r.Header.Get("Content-Type") == "application/x-www-form-urlencoded"
  400. }
  401. type handler struct {
  402. pat Pattern
  403. h HandlerFunc
  404. }