context.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. package runtime
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "fmt"
  6. "net"
  7. "net/http"
  8. "net/textproto"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. "google.golang.org/grpc/codes"
  14. "google.golang.org/grpc/grpclog"
  15. "google.golang.org/grpc/metadata"
  16. "google.golang.org/grpc/status"
  17. )
  18. // MetadataHeaderPrefix is the http prefix that represents custom metadata
  19. // parameters to or from a gRPC call.
  20. const MetadataHeaderPrefix = "Grpc-Metadata-"
  21. // MetadataPrefix is prepended to permanent HTTP header keys (as specified
  22. // by the IANA) when added to the gRPC context.
  23. const MetadataPrefix = "grpcgateway-"
  24. // MetadataTrailerPrefix is prepended to gRPC metadata as it is converted to
  25. // HTTP headers in a response handled by grpc-gateway
  26. const MetadataTrailerPrefix = "Grpc-Trailer-"
  27. const metadataGrpcTimeout = "Grpc-Timeout"
  28. const metadataHeaderBinarySuffix = "-Bin"
  29. const xForwardedFor = "X-Forwarded-For"
  30. const xForwardedHost = "X-Forwarded-Host"
  31. // DefaultContextTimeout is used for gRPC call context.WithTimeout whenever a Grpc-Timeout inbound
  32. // header isn't present. If the value is 0 the sent `context` will not have a timeout.
  33. var DefaultContextTimeout = 0 * time.Second
  34. // malformedHTTPHeaders lists the headers that the gRPC server may reject outright as malformed.
  35. // See https://github.com/grpc/grpc-go/pull/4803#issuecomment-986093310 for more context.
  36. var malformedHTTPHeaders = map[string]struct{}{
  37. "connection": {},
  38. }
  39. type (
  40. rpcMethodKey struct{}
  41. httpPathPatternKey struct{}
  42. AnnotateContextOption func(ctx context.Context) context.Context
  43. )
  44. func WithHTTPPathPattern(pattern string) AnnotateContextOption {
  45. return func(ctx context.Context) context.Context {
  46. return withHTTPPathPattern(ctx, pattern)
  47. }
  48. }
  49. func decodeBinHeader(v string) ([]byte, error) {
  50. if len(v)%4 == 0 {
  51. // Input was padded, or padding was not necessary.
  52. return base64.StdEncoding.DecodeString(v)
  53. }
  54. return base64.RawStdEncoding.DecodeString(v)
  55. }
  56. /*
  57. AnnotateContext adds context information such as metadata from the request.
  58. At a minimum, the RemoteAddr is included in the fashion of "X-Forwarded-For",
  59. except that the forwarded destination is not another HTTP service but rather
  60. a gRPC service.
  61. */
  62. func AnnotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
  63. ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
  64. if err != nil {
  65. return nil, err
  66. }
  67. if md == nil {
  68. return ctx, nil
  69. }
  70. return metadata.NewOutgoingContext(ctx, md), nil
  71. }
  72. // AnnotateIncomingContext adds context information such as metadata from the request.
  73. // Attach metadata as incoming context.
  74. func AnnotateIncomingContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, error) {
  75. ctx, md, err := annotateContext(ctx, mux, req, rpcMethodName, options...)
  76. if err != nil {
  77. return nil, err
  78. }
  79. if md == nil {
  80. return ctx, nil
  81. }
  82. return metadata.NewIncomingContext(ctx, md), nil
  83. }
  84. func isValidGRPCMetadataKey(key string) bool {
  85. // Must be a valid gRPC "Header-Name" as defined here:
  86. // https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
  87. // This means 0-9 a-z _ - .
  88. // Only lowercase letters are valid in the wire protocol, but the client library will normalize
  89. // uppercase ASCII to lowercase, so uppercase ASCII is also acceptable.
  90. bytes := []byte(key) // gRPC validates strings on the byte level, not Unicode.
  91. for _, ch := range bytes {
  92. validLowercaseLetter := ch >= 'a' && ch <= 'z'
  93. validUppercaseLetter := ch >= 'A' && ch <= 'Z'
  94. validDigit := ch >= '0' && ch <= '9'
  95. validOther := ch == '.' || ch == '-' || ch == '_'
  96. if !validLowercaseLetter && !validUppercaseLetter && !validDigit && !validOther {
  97. return false
  98. }
  99. }
  100. return true
  101. }
  102. func isValidGRPCMetadataTextValue(textValue string) bool {
  103. // Must be a valid gRPC "ASCII-Value" as defined here:
  104. // https://github.com/grpc/grpc/blob/4b05dc88b724214d0c725c8e7442cbc7a61b1374/doc/PROTOCOL-HTTP2.md
  105. // This means printable ASCII (including/plus spaces); 0x20 to 0x7E inclusive.
  106. bytes := []byte(textValue) // gRPC validates strings on the byte level, not Unicode.
  107. for _, ch := range bytes {
  108. if ch < 0x20 || ch > 0x7E {
  109. return false
  110. }
  111. }
  112. return true
  113. }
  114. func annotateContext(ctx context.Context, mux *ServeMux, req *http.Request, rpcMethodName string, options ...AnnotateContextOption) (context.Context, metadata.MD, error) {
  115. ctx = withRPCMethod(ctx, rpcMethodName)
  116. for _, o := range options {
  117. ctx = o(ctx)
  118. }
  119. timeout := DefaultContextTimeout
  120. if tm := req.Header.Get(metadataGrpcTimeout); tm != "" {
  121. var err error
  122. timeout, err = timeoutDecode(tm)
  123. if err != nil {
  124. return nil, nil, status.Errorf(codes.InvalidArgument, "invalid grpc-timeout: %s", tm)
  125. }
  126. }
  127. var pairs []string
  128. for key, vals := range req.Header {
  129. key = textproto.CanonicalMIMEHeaderKey(key)
  130. for _, val := range vals {
  131. // For backwards-compatibility, pass through 'authorization' header with no prefix.
  132. if key == "Authorization" {
  133. pairs = append(pairs, "authorization", val)
  134. }
  135. if h, ok := mux.incomingHeaderMatcher(key); ok {
  136. if !isValidGRPCMetadataKey(h) {
  137. grpclog.Errorf("HTTP header name %q is not valid as gRPC metadata key; skipping", h)
  138. continue
  139. }
  140. // Handles "-bin" metadata in grpc, since grpc will do another base64
  141. // encode before sending to server, we need to decode it first.
  142. if strings.HasSuffix(key, metadataHeaderBinarySuffix) {
  143. b, err := decodeBinHeader(val)
  144. if err != nil {
  145. return nil, nil, status.Errorf(codes.InvalidArgument, "invalid binary header %s: %s", key, err)
  146. }
  147. val = string(b)
  148. } else if !isValidGRPCMetadataTextValue(val) {
  149. grpclog.Errorf("Value of HTTP header %q contains non-ASCII value (not valid as gRPC metadata): skipping", h)
  150. continue
  151. }
  152. pairs = append(pairs, h, val)
  153. }
  154. }
  155. }
  156. if host := req.Header.Get(xForwardedHost); host != "" {
  157. pairs = append(pairs, strings.ToLower(xForwardedHost), host)
  158. } else if req.Host != "" {
  159. pairs = append(pairs, strings.ToLower(xForwardedHost), req.Host)
  160. }
  161. if addr := req.RemoteAddr; addr != "" {
  162. if remoteIP, _, err := net.SplitHostPort(addr); err == nil {
  163. if fwd := req.Header.Get(xForwardedFor); fwd == "" {
  164. pairs = append(pairs, strings.ToLower(xForwardedFor), remoteIP)
  165. } else {
  166. pairs = append(pairs, strings.ToLower(xForwardedFor), fmt.Sprintf("%s, %s", fwd, remoteIP))
  167. }
  168. }
  169. }
  170. if timeout != 0 {
  171. //nolint:govet // The context outlives this function
  172. ctx, _ = context.WithTimeout(ctx, timeout)
  173. }
  174. if len(pairs) == 0 {
  175. return ctx, nil, nil
  176. }
  177. md := metadata.Pairs(pairs...)
  178. for _, mda := range mux.metadataAnnotators {
  179. md = metadata.Join(md, mda(ctx, req))
  180. }
  181. return ctx, md, nil
  182. }
  183. // ServerMetadata consists of metadata sent from gRPC server.
  184. type ServerMetadata struct {
  185. HeaderMD metadata.MD
  186. TrailerMD metadata.MD
  187. }
  188. type serverMetadataKey struct{}
  189. // NewServerMetadataContext creates a new context with ServerMetadata
  190. func NewServerMetadataContext(ctx context.Context, md ServerMetadata) context.Context {
  191. if ctx == nil {
  192. ctx = context.Background()
  193. }
  194. return context.WithValue(ctx, serverMetadataKey{}, md)
  195. }
  196. // ServerMetadataFromContext returns the ServerMetadata in ctx
  197. func ServerMetadataFromContext(ctx context.Context) (md ServerMetadata, ok bool) {
  198. if ctx == nil {
  199. return md, false
  200. }
  201. md, ok = ctx.Value(serverMetadataKey{}).(ServerMetadata)
  202. return
  203. }
  204. // ServerTransportStream implements grpc.ServerTransportStream.
  205. // It should only be used by the generated files to support grpc.SendHeader
  206. // outside of gRPC server use.
  207. type ServerTransportStream struct {
  208. mu sync.Mutex
  209. header metadata.MD
  210. trailer metadata.MD
  211. }
  212. // Method returns the method for the stream.
  213. func (s *ServerTransportStream) Method() string {
  214. return ""
  215. }
  216. // Header returns the header metadata of the stream.
  217. func (s *ServerTransportStream) Header() metadata.MD {
  218. s.mu.Lock()
  219. defer s.mu.Unlock()
  220. return s.header.Copy()
  221. }
  222. // SetHeader sets the header metadata.
  223. func (s *ServerTransportStream) SetHeader(md metadata.MD) error {
  224. if md.Len() == 0 {
  225. return nil
  226. }
  227. s.mu.Lock()
  228. s.header = metadata.Join(s.header, md)
  229. s.mu.Unlock()
  230. return nil
  231. }
  232. // SendHeader sets the header metadata.
  233. func (s *ServerTransportStream) SendHeader(md metadata.MD) error {
  234. return s.SetHeader(md)
  235. }
  236. // Trailer returns the cached trailer metadata.
  237. func (s *ServerTransportStream) Trailer() metadata.MD {
  238. s.mu.Lock()
  239. defer s.mu.Unlock()
  240. return s.trailer.Copy()
  241. }
  242. // SetTrailer sets the trailer metadata.
  243. func (s *ServerTransportStream) SetTrailer(md metadata.MD) error {
  244. if md.Len() == 0 {
  245. return nil
  246. }
  247. s.mu.Lock()
  248. s.trailer = metadata.Join(s.trailer, md)
  249. s.mu.Unlock()
  250. return nil
  251. }
  252. func timeoutDecode(s string) (time.Duration, error) {
  253. size := len(s)
  254. if size < 2 {
  255. return 0, fmt.Errorf("timeout string is too short: %q", s)
  256. }
  257. d, ok := timeoutUnitToDuration(s[size-1])
  258. if !ok {
  259. return 0, fmt.Errorf("timeout unit is not recognized: %q", s)
  260. }
  261. t, err := strconv.ParseInt(s[:size-1], 10, 64)
  262. if err != nil {
  263. return 0, err
  264. }
  265. return d * time.Duration(t), nil
  266. }
  267. func timeoutUnitToDuration(u uint8) (d time.Duration, ok bool) {
  268. switch u {
  269. case 'H':
  270. return time.Hour, true
  271. case 'M':
  272. return time.Minute, true
  273. case 'S':
  274. return time.Second, true
  275. case 'm':
  276. return time.Millisecond, true
  277. case 'u':
  278. return time.Microsecond, true
  279. case 'n':
  280. return time.Nanosecond, true
  281. default:
  282. return
  283. }
  284. }
  285. // isPermanentHTTPHeader checks whether hdr belongs to the list of
  286. // permanent request headers maintained by IANA.
  287. // http://www.iana.org/assignments/message-headers/message-headers.xml
  288. func isPermanentHTTPHeader(hdr string) bool {
  289. switch hdr {
  290. case
  291. "Accept",
  292. "Accept-Charset",
  293. "Accept-Language",
  294. "Accept-Ranges",
  295. "Authorization",
  296. "Cache-Control",
  297. "Content-Type",
  298. "Cookie",
  299. "Date",
  300. "Expect",
  301. "From",
  302. "Host",
  303. "If-Match",
  304. "If-Modified-Since",
  305. "If-None-Match",
  306. "If-Schedule-Tag-Match",
  307. "If-Unmodified-Since",
  308. "Max-Forwards",
  309. "Origin",
  310. "Pragma",
  311. "Referer",
  312. "User-Agent",
  313. "Via",
  314. "Warning":
  315. return true
  316. }
  317. return false
  318. }
  319. // isMalformedHTTPHeader checks whether header belongs to the list of
  320. // "malformed headers" and would be rejected by the gRPC server.
  321. func isMalformedHTTPHeader(header string) bool {
  322. _, isMalformed := malformedHTTPHeaders[strings.ToLower(header)]
  323. return isMalformed
  324. }
  325. // RPCMethod returns the method string for the server context. The returned
  326. // string is in the format of "/package.service/method".
  327. func RPCMethod(ctx context.Context) (string, bool) {
  328. m := ctx.Value(rpcMethodKey{})
  329. if m == nil {
  330. return "", false
  331. }
  332. ms, ok := m.(string)
  333. if !ok {
  334. return "", false
  335. }
  336. return ms, true
  337. }
  338. func withRPCMethod(ctx context.Context, rpcMethodName string) context.Context {
  339. return context.WithValue(ctx, rpcMethodKey{}, rpcMethodName)
  340. }
  341. // HTTPPathPattern returns the HTTP path pattern string relating to the HTTP handler, if one exists.
  342. // The format of the returned string is defined by the google.api.http path template type.
  343. func HTTPPathPattern(ctx context.Context) (string, bool) {
  344. m := ctx.Value(httpPathPatternKey{})
  345. if m == nil {
  346. return "", false
  347. }
  348. ms, ok := m.(string)
  349. if !ok {
  350. return "", false
  351. }
  352. return ms, true
  353. }
  354. func withHTTPPathPattern(ctx context.Context, httpPathPattern string) context.Context {
  355. return context.WithValue(ctx, httpPathPatternKey{}, httpPathPattern)
  356. }