123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228 |
- package runtime
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "net/http"
- "net/textproto"
- "strings"
- "google.golang.org/genproto/googleapis/api/httpbody"
- "google.golang.org/grpc/codes"
- "google.golang.org/grpc/grpclog"
- "google.golang.org/grpc/status"
- "google.golang.org/protobuf/proto"
- )
- // ForwardResponseStream forwards the stream from gRPC server to REST client.
- func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
- f, ok := w.(http.Flusher)
- if !ok {
- grpclog.Infof("Flush not supported in %T", w)
- http.Error(w, "unexpected type of web server", http.StatusInternalServerError)
- return
- }
- md, ok := ServerMetadataFromContext(ctx)
- if !ok {
- grpclog.Infof("Failed to extract ServerMetadata from context")
- http.Error(w, "unexpected error", http.StatusInternalServerError)
- return
- }
- handleForwardResponseServerMetadata(w, mux, md)
- w.Header().Set("Transfer-Encoding", "chunked")
- if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil {
- HTTPError(ctx, mux, marshaler, w, req, err)
- return
- }
- var delimiter []byte
- if d, ok := marshaler.(Delimited); ok {
- delimiter = d.Delimiter()
- } else {
- delimiter = []byte("\n")
- }
- var wroteHeader bool
- for {
- resp, err := recv()
- if errors.Is(err, io.EOF) {
- return
- }
- if err != nil {
- handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
- return
- }
- if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
- handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
- return
- }
- if !wroteHeader {
- w.Header().Set("Content-Type", marshaler.ContentType(resp))
- }
- var buf []byte
- httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
- switch {
- case resp == nil:
- buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
- case isHTTPBody:
- buf = httpBody.GetData()
- default:
- result := map[string]interface{}{"result": resp}
- if rb, ok := resp.(responseBody); ok {
- result["result"] = rb.XXX_ResponseBody()
- }
- buf, err = marshaler.Marshal(result)
- }
- if err != nil {
- grpclog.Infof("Failed to marshal response chunk: %v", err)
- handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
- return
- }
- if _, err := w.Write(buf); err != nil {
- grpclog.Infof("Failed to send response chunk: %v", err)
- return
- }
- wroteHeader = true
- if _, err := w.Write(delimiter); err != nil {
- grpclog.Infof("Failed to send delimiter chunk: %v", err)
- return
- }
- f.Flush()
- }
- }
- func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) {
- for k, vs := range md.HeaderMD {
- if h, ok := mux.outgoingHeaderMatcher(k); ok {
- for _, v := range vs {
- w.Header().Add(h, v)
- }
- }
- }
- }
- func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) {
- for k := range md.TrailerMD {
- tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k))
- w.Header().Add("Trailer", tKey)
- }
- }
- func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) {
- for k, vs := range md.TrailerMD {
- tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)
- for _, v := range vs {
- w.Header().Add(tKey, v)
- }
- }
- }
// responseBody is implemented by response messages that expose a single field
// to be marshaled as the HTTP response body in place of the whole message.
// The method is generated for a response struct from the value of
// `response_body` in the `google.api.HttpRule`.
type responseBody interface {
	// XXX_ResponseBody returns the value to marshal into the response body.
	XXX_ResponseBody() interface{}
}
- // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client.
- func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) {
- md, ok := ServerMetadataFromContext(ctx)
- if !ok {
- grpclog.Infof("Failed to extract ServerMetadata from context")
- }
- handleForwardResponseServerMetadata(w, mux, md)
- // RFC 7230 https://tools.ietf.org/html/rfc7230#section-4.1.2
- // Unless the request includes a TE header field indicating "trailers"
- // is acceptable, as described in Section 4.3, a server SHOULD NOT
- // generate trailer fields that it believes are necessary for the user
- // agent to receive.
- doForwardTrailers := requestAcceptsTrailers(req)
- if doForwardTrailers {
- handleForwardResponseTrailerHeader(w, md)
- w.Header().Set("Transfer-Encoding", "chunked")
- }
- handleForwardResponseTrailerHeader(w, md)
- contentType := marshaler.ContentType(resp)
- w.Header().Set("Content-Type", contentType)
- if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil {
- HTTPError(ctx, mux, marshaler, w, req, err)
- return
- }
- var buf []byte
- var err error
- if rb, ok := resp.(responseBody); ok {
- buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
- } else {
- buf, err = marshaler.Marshal(resp)
- }
- if err != nil {
- grpclog.Infof("Marshal error: %v", err)
- HTTPError(ctx, mux, marshaler, w, req, err)
- return
- }
- if _, err = w.Write(buf); err != nil {
- grpclog.Infof("Failed to write response: %v", err)
- }
- if doForwardTrailers {
- handleForwardResponseTrailer(w, md)
- }
- }
// requestAcceptsTrailers reports whether the request's TE header lists the
// "trailers" coding, i.e. whether the client has declared that trailer fields
// are acceptable (RFC 7230, section 4.3).
//
// Every TE field line is examined. Each line is split into its
// comma-separated codings, anything after a ";" (such as a q-value) is
// ignored, and the remaining name is compared case-insensitively. Matching
// whole tokens avoids false positives from values that merely contain the
// substring "trailers".
func requestAcceptsTrailers(req *http.Request) bool {
	for _, field := range req.Header[http.CanonicalHeaderKey("TE")] {
		for _, coding := range strings.Split(field, ",") {
			if i := strings.IndexByte(coding, ';'); i >= 0 {
				coding = coding[:i]
			}
			if strings.EqualFold(strings.TrimSpace(coding), "trailers") {
				return true
			}
		}
	}
	return false
}
- func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error {
- if len(opts) == 0 {
- return nil
- }
- for _, opt := range opts {
- if err := opt(ctx, w, resp); err != nil {
- grpclog.Infof("Error handling ForwardResponseOptions: %v", err)
- return err
- }
- }
- return nil
- }
- func handleForwardResponseStreamError(ctx context.Context, wroteHeader bool, marshaler Marshaler, w http.ResponseWriter, req *http.Request, mux *ServeMux, err error, delimiter []byte) {
- st := mux.streamErrorHandler(ctx, err)
- msg := errorChunk(st)
- if !wroteHeader {
- w.Header().Set("Content-Type", marshaler.ContentType(msg))
- w.WriteHeader(HTTPStatusFromCode(st.Code()))
- }
- buf, err := marshaler.Marshal(msg)
- if err != nil {
- grpclog.Infof("Failed to marshal an error: %v", err)
- return
- }
- if _, err := w.Write(buf); err != nil {
- grpclog.Infof("Failed to notify error to client: %v", err)
- return
- }
- if _, err := w.Write(delimiter); err != nil {
- grpclog.Infof("Failed to send delimiter chunk: %v", err)
- return
- }
- }
- func errorChunk(st *status.Status) map[string]proto.Message {
- return map[string]proto.Message{"error": st.Proto()}
- }
|