extractor.go 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. package request
  2. import (
  3. "errors"
  4. "net/http"
  5. "strings"
  6. )
  7. // Errors
  8. var (
  9. ErrNoTokenInRequest = errors.New("no token present in request")
  10. )
  11. // Extractor is an interface for extracting a token from an HTTP request.
  12. // The ExtractToken method should return a token string or an error.
  13. // If no token is present, you must return ErrNoTokenInRequest.
  14. type Extractor interface {
  15. ExtractToken(*http.Request) (string, error)
  16. }
  17. // HeaderExtractor is an extractor for finding a token in a header.
  18. // Looks at each specified header in order until there's a match
  19. type HeaderExtractor []string
  20. func (e HeaderExtractor) ExtractToken(req *http.Request) (string, error) {
  21. // loop over header names and return the first one that contains data
  22. for _, header := range e {
  23. if ah := req.Header.Get(header); ah != "" {
  24. return ah, nil
  25. }
  26. }
  27. return "", ErrNoTokenInRequest
  28. }
  29. // ArgumentExtractor extracts a token from request arguments. This includes a POSTed form or
  30. // GET URL arguments. Argument names are tried in order until there's a match.
  31. // This extractor calls `ParseMultipartForm` on the request
  32. type ArgumentExtractor []string
  33. func (e ArgumentExtractor) ExtractToken(req *http.Request) (string, error) {
  34. // Make sure form is parsed
  35. req.ParseMultipartForm(10e6)
  36. // loop over arg names and return the first one that contains data
  37. for _, arg := range e {
  38. if ah := req.Form.Get(arg); ah != "" {
  39. return ah, nil
  40. }
  41. }
  42. return "", ErrNoTokenInRequest
  43. }
  44. // MultiExtractor tries Extractors in order until one returns a token string or an error occurs
  45. type MultiExtractor []Extractor
  46. func (e MultiExtractor) ExtractToken(req *http.Request) (string, error) {
  47. // loop over header names and return the first one that contains data
  48. for _, extractor := range e {
  49. if tok, err := extractor.ExtractToken(req); tok != "" {
  50. return tok, nil
  51. } else if !errors.Is(err, ErrNoTokenInRequest) {
  52. return "", err
  53. }
  54. }
  55. return "", ErrNoTokenInRequest
  56. }
  57. // PostExtractionFilter wraps an Extractor in this to post-process the value before it's handed off.
  58. // See AuthorizationHeaderExtractor for an example
  59. type PostExtractionFilter struct {
  60. Extractor
  61. Filter func(string) (string, error)
  62. }
  63. func (e *PostExtractionFilter) ExtractToken(req *http.Request) (string, error) {
  64. if tok, err := e.Extractor.ExtractToken(req); tok != "" {
  65. return e.Filter(tok)
  66. } else {
  67. return "", err
  68. }
  69. }
  70. // BearerExtractor extracts a token from the Authorization header.
  71. // The header is expected to match the format "Bearer XX", where "XX" is the
  72. // JWT token.
  73. type BearerExtractor struct{}
  74. func (e BearerExtractor) ExtractToken(req *http.Request) (string, error) {
  75. tokenHeader := req.Header.Get("Authorization")
  76. // The usual convention is for "Bearer" to be title-cased. However, there's no
  77. // strict rule around this, and it's best to follow the robustness principle here.
  78. if tokenHeader == "" || !strings.HasPrefix(strings.ToLower(tokenHeader), "bearer ") {
  79. return "", ErrNoTokenInRequest
  80. }
  81. return tokenHeader[7:], nil
  82. }