handler.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. package runtime
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/textproto"
  9. "strings"
  10. "google.golang.org/genproto/googleapis/api/httpbody"
  11. "google.golang.org/grpc/codes"
  12. "google.golang.org/grpc/grpclog"
  13. "google.golang.org/grpc/status"
  14. "google.golang.org/protobuf/proto"
  15. )
  16. // ForwardResponseStream forwards the stream from gRPC server to REST client.
  17. func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  18. f, ok := w.(http.Flusher)
  19. if !ok {
  20. grpclog.Infof("Flush not supported in %T", w)
  21. http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
  22. return
  23. }
  24. md, ok := ServerMetadataFromContext(ctx)
  25. if !ok {
  26. grpclog.Infof("Failed to extract ServerMetadata from context")
  27. http.Error(w, "unexpected error", http.StatusInternalServerError)
  28. return
  29. }
  30. handleForwardResponseServerMetadata(w, mux, md)
  31. w.Header().Set("Transfer-Encoding", "chunked")
  32. if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
  33. HTTPError(ctx, mux, marshaler, w, req, err)
  34. return
  35. }
  36. var delimiter []byte
  37. if d, ok := marshaler.(Delimited); ok {
  38. delimiter = d.Delimiter()
  39. } else {
  40. delimiter = []byte("\n")
  41. }
  42. var wroteHeader bool
  43. for {
  44. resp, err := recv()
  45. if errors.Is(err, io.EOF) {
  46. return
  47. }
  48. if err != nil {
  49. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
  50. return
  51. }
  52. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  53. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
  54. return
  55. }
  56. if !wroteHeader {
  57. w.Header().Set("Content-Type", marshaler.ContentType(resp))
  58. }
  59. var buf []byte
  60. httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
  61. switch {
  62. case resp == nil:
  63. buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
  64. case isHTTPBody:
  65. buf = httpBody.GetData()
  66. default:
  67. result := map[string]interface{}{"result": resp}
  68. if rb, ok := resp.(responseBody); ok {
  69. result["result"] = rb.XXX_ResponseBody()
  70. }
  71. buf, err = marshaler.Marshal(result)
  72. }
  73. if err != nil {
  74. grpclog.Infof("Failed to marshal response chunk: %v", err)
  75. handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
  76. return
  77. }
  78. if _, err := w.Write(buf); err != nil {
  79. grpclog.Infof("Failed to send response chunk: %v", err)
  80. return
  81. }
  82. wroteHeader = true
  83. if _, err := w.Write(delimiter); err != nil {
  84. grpclog.Infof("Failed to send delimiter chunk: %v", err)
  85. return
  86. }
  87. f.Flush()
  88. }
  89. }
  90. func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
  91. for k, vs := range md.HeaderMD {
  92. if h, ok := mux.outgoingHeaderMatcher(k); ok {
  93. for _, v := range vs {
  94. w.Header().Add(h, v)
  95. }
  96. }
  97. }
  98. }
  99. func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
  100. for k := range md.TrailerMD {
  101. tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
  102. w.Header().Add("Trailer", tKey)
  103. }
  104. }
  105. func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
  106. for k, vs := range md.TrailerMD {
  107. tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
  108. for _, v := range vs {
  109. w.Header().Add(tKey, v)
  110. }
  111. }
  112. }
  // responseBody is implemented by generated response types whose
// `google.api.HttpRule` sets `response_body`. XXX_ResponseBody returns the
// value of that field, which is marshaled as the HTTP response body in place
// of the whole message.
type responseBody interface {
	XXX_ResponseBody() interface{}
}
  118. // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
  119. func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
  120. md, ok := ServerMetadataFromContext(ctx)
  121. if !ok {
  122. grpclog.Infof("Failed to extract ServerMetadata from context")
  123. }
  124. handleForwardResponseServerMetadata(w, mux, md)
  125. // RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
  126. // Unless the request includes a TE header field indicating "trailers"
  127. // is acceptable, as described in Section 4.3, a server SHOULD NOT
  128. // generate trailer fields that it believes are necessary for the user
  129. // agent to receive.
  130. doForwardTrailers := requestAcceptsTrailers(req)
  131. if doForwardTrailers {
  132. handleForwardResponseTrailerHeader(w, md)
  133. w.Header().Set("Transfer-Encoding", "chunked")
  134. }
  135. handleForwardResponseTrailerHeader(w, md)
  136. contentType := marshaler.ContentType(resp)
  137. w.Header().Set("Content-Type", contentType)
  138. if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
  139. HTTPError(ctx, mux, marshaler, w, req, err)
  140. return
  141. }
  142. var buf []byte
  143. var err error
  144. if rb, ok := resp.(responseBody); ok {
  145. buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
  146. } else {
  147. buf, err = marshaler.Marshal(resp)
  148. }
  149. if err != nil {
  150. grpclog.Infof("Marshal error: %v", err)
  151. HTTPError(ctx, mux, marshaler, w, req, err)
  152. return
  153. }
  154. if _, err = w.Write(buf); err != nil {
  155. grpclog.Infof("Failed to write response: %v", err)
  156. }
  157. if doForwardTrailers {
  158. handleForwardResponseTrailer(w, md)
  159. }
  160. }
  // requestAcceptsTrailers reports whether req's TE header field lists the
// "trailers" keyword, i.e. the client is willing to receive trailer fields
// (RFC 7230 section 4.3).
//
// Every TE header line is examined, because a list-valued field may be split
// across several lines. Each comma-separated element is compared as a whole
// token, case-insensitively and ignoring any ";" parameters, so values that
// merely contain the substring "trailers" do not match.
func requestAcceptsTrailers(req *http.Request) bool {
	for _, field := range req.Header.Values("TE") {
		for _, elem := range strings.Split(field, ",") {
			if i := strings.IndexByte(elem, ';'); i >= 0 {
				elem = elem[:i]
			}
			if strings.EqualFold(strings.TrimSpace(elem), "trailers") {
				return true
			}
		}
	}
	return false
}
  165. func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
  166. if len(opts) == 0 {
  167. return nil
  168. }
  169. for _, opt := range opts {
  170. if err := opt(ctx, w, resp); err != nil {
  171. grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
  172. return err
  173. }
  174. }
  175. return nil
  176. }
  177. func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error, delimiter []byte) {
  178. st := mux.streamErrorHandler(ctx, err)
  179. msg := errorChunk(st)
  180. if !wroteHeader {
  181. w.Header().Set("Content-Type", marshaler.ContentType(msg))
  182. w.WriteHeader(HTTPStatusFromCode(st.Code()))
  183. }
  184. buf, err := marshaler.Marshal(msg)
  185. if err != nil {
  186. grpclog.Infof("Failed to marshal an error: %v", err)
  187. return
  188. }
  189. if _, err := w.Write(buf); err != nil {
  190. grpclog.Infof("Failed to notify error to client: %v", err)
  191. return
  192. }
  193. if _, err := w.Write(delimiter); err != nil {
  194. grpclog.Infof("Failed to send delimiter chunk: %v", err)
  195. return
  196. }
  197. }
  198. func errorChunk(st *status.Status) map[string]proto.Message {
  199. return map[string]proto.Message{"error": st.Proto()}
  200. }