container.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. package restful
  2. // Copyright 2013 Ernest Micklei. All rights reserved.
  3. // Use of this source code is governed by a license
  4. // that can be found in the LICENSE file.
  5. import (
  6. "bytes"
  7. "errors"
  8. "fmt"
  9. "net/http"
  10. "os"
  11. "runtime"
  12. "strings"
  13. "sync"
  14. "github.com/emicklei/go-restful/v3/log"
  15. )
  16. // Container holds a collection of WebServices and a http.ServeMux to dispatch http requests.
  17. // The requests are further dispatched to routes of WebServices using a RouteSelector
  18. type Container struct {
  19. webServicesLock sync.RWMutex
  20. webServices []*WebService
  21. ServeMux *http.ServeMux
  22. isRegisteredOnRoot bool
  23. containerFilters []FilterFunction
  24. doNotRecover bool // default is true
  25. recoverHandleFunc RecoverHandleFunction
  26. serviceErrorHandleFunc ServiceErrorHandleFunction
  27. router RouteSelector // default is a CurlyRouter (RouterJSR311 is a slower alternative)
  28. contentEncodingEnabled bool // default is false
  29. }
  30. // NewContainer creates a new Container using a new ServeMux and default router (CurlyRouter)
  31. func NewContainer() *Container {
  32. return &Container{
  33. webServices: []*WebService{},
  34. ServeMux: http.NewServeMux(),
  35. isRegisteredOnRoot: false,
  36. containerFilters: []FilterFunction{},
  37. doNotRecover: true,
  38. recoverHandleFunc: logStackOnRecover,
  39. serviceErrorHandleFunc: writeServiceError,
  40. router: CurlyRouter{},
  41. contentEncodingEnabled: false}
  42. }
  43. // RecoverHandleFunction declares functions that can be used to handle a panic situation.
  44. // The first argument is what recover() returns. The second must be used to communicate an error response.
  45. type RecoverHandleFunction func(interface{}, http.ResponseWriter)
  46. // RecoverHandler changes the default function (logStackOnRecover) to be called
  47. // when a panic is detected. DoNotRecover must be have its default value (=false).
  48. func (c *Container) RecoverHandler(handler RecoverHandleFunction) {
  49. c.recoverHandleFunc = handler
  50. }
  51. // ServiceErrorHandleFunction declares functions that can be used to handle a service error situation.
  52. // The first argument is the service error, the second is the request that resulted in the error and
  53. // the third must be used to communicate an error response.
  54. type ServiceErrorHandleFunction func(ServiceError, *Request, *Response)
  55. // ServiceErrorHandler changes the default function (writeServiceError) to be called
  56. // when a ServiceError is detected.
  57. func (c *Container) ServiceErrorHandler(handler ServiceErrorHandleFunction) {
  58. c.serviceErrorHandleFunc = handler
  59. }
  60. // DoNotRecover controls whether panics will be caught to return HTTP 500.
  61. // If set to true, Route functions are responsible for handling any error situation.
  62. // Default value is true.
  63. func (c *Container) DoNotRecover(doNot bool) {
  64. c.doNotRecover = doNot
  65. }
  66. // Router changes the default Router (currently CurlyRouter)
  67. func (c *Container) Router(aRouter RouteSelector) {
  68. c.router = aRouter
  69. }
  70. // EnableContentEncoding (default=false) allows for GZIP or DEFLATE encoding of responses.
  71. func (c *Container) EnableContentEncoding(enabled bool) {
  72. c.contentEncodingEnabled = enabled
  73. }
  74. // Add a WebService to the Container. It will detect duplicate root paths and exit in that case.
  75. func (c *Container) Add(service *WebService) *Container {
  76. c.webServicesLock.Lock()
  77. defer c.webServicesLock.Unlock()
  78. // if rootPath was not set then lazy initialize it
  79. if len(service.rootPath) == 0 {
  80. service.Path("/")
  81. }
  82. // cannot have duplicate root paths
  83. for _, each := range c.webServices {
  84. if each.RootPath() == service.RootPath() {
  85. log.Printf("WebService with duplicate root path detected:['%v']", each)
  86. os.Exit(1)
  87. }
  88. }
  89. // If not registered on root then add specific mapping
  90. if !c.isRegisteredOnRoot {
  91. c.isRegisteredOnRoot = c.addHandler(service, c.ServeMux)
  92. }
  93. c.webServices = append(c.webServices, service)
  94. return c
  95. }
  96. // addHandler may set a new HandleFunc for the serveMux
  97. // this function must run inside the critical region protected by the webServicesLock.
  98. // returns true if the function was registered on root ("/")
  99. func (c *Container) addHandler(service *WebService, serveMux *http.ServeMux) bool {
  100. pattern := fixedPrefixPath(service.RootPath())
  101. // check if root path registration is needed
  102. if "/" == pattern || "" == pattern {
  103. serveMux.HandleFunc("/", c.dispatch)
  104. return true
  105. }
  106. // detect if registration already exists
  107. alreadyMapped := false
  108. for _, each := range c.webServices {
  109. if each.RootPath() == service.RootPath() {
  110. alreadyMapped = true
  111. break
  112. }
  113. }
  114. if !alreadyMapped {
  115. serveMux.HandleFunc(pattern, c.dispatch)
  116. if !strings.HasSuffix(pattern, "/") {
  117. serveMux.HandleFunc(pattern+"/", c.dispatch)
  118. }
  119. }
  120. return false
  121. }
  122. func (c *Container) Remove(ws *WebService) error {
  123. if c.ServeMux == http.DefaultServeMux {
  124. errMsg := fmt.Sprintf("cannot remove a WebService from a Container using the DefaultServeMux: ['%v']", ws)
  125. log.Print(errMsg)
  126. return errors.New(errMsg)
  127. }
  128. c.webServicesLock.Lock()
  129. defer c.webServicesLock.Unlock()
  130. // build a new ServeMux and re-register all WebServices
  131. newServeMux := http.NewServeMux()
  132. newServices := []*WebService{}
  133. newIsRegisteredOnRoot := false
  134. for _, each := range c.webServices {
  135. if each.rootPath != ws.rootPath {
  136. // If not registered on root then add specific mapping
  137. if !newIsRegisteredOnRoot {
  138. newIsRegisteredOnRoot = c.addHandler(each, newServeMux)
  139. }
  140. newServices = append(newServices, each)
  141. }
  142. }
  143. c.webServices, c.ServeMux, c.isRegisteredOnRoot = newServices, newServeMux, newIsRegisteredOnRoot
  144. return nil
  145. }
  // logStackOnRecover is the default RecoverHandleFunction. It is called when
// DoNotRecover is false and no other handler was installed via RecoverHandler.
// It logs the panic value followed by one " file:line" entry per stack frame,
// then writes the same text as the body of a 500 response.
// This may be a security issue as it exposes sourcecode information.
func logStackOnRecover(panicReason interface{}, httpWriter http.ResponseWriter) {
	var buffer bytes.Buffer
	fmt.Fprintf(&buffer, "recover from panic situation: - %v\r\n", panicReason)
	// Frames 0 and 1 are this function and its direct caller; list the rest
	// until runtime.Caller reports that the stack is exhausted.
	for depth := 2; ; depth++ {
		_, file, line, ok := runtime.Caller(depth)
		if !ok {
			break
		}
		fmt.Fprintf(&buffer, " %s:%d\r\n", file, line)
	}
	log.Print(buffer.String())
	httpWriter.WriteHeader(http.StatusInternalServerError)
	httpWriter.Write(buffer.Bytes())
}
  164. // writeServiceError is the default ServiceErrorHandleFunction and is called
  165. // when a ServiceError is returned during route selection. Default implementation
  166. // calls resp.WriteErrorString(err.Code, err.Message)
  167. func writeServiceError(err ServiceError, req *Request, resp *Response) {
  168. for header, values := range err.Header {
  169. for _, value := range values {
  170. resp.Header().Add(header, value)
  171. }
  172. }
  173. resp.WriteErrorString(err.Code, err.Message)
  174. }
  175. // Dispatch the incoming Http Request to a matching WebService.
  176. func (c *Container) Dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
  177. if httpWriter == nil {
  178. panic("httpWriter cannot be nil")
  179. }
  180. if httpRequest == nil {
  181. panic("httpRequest cannot be nil")
  182. }
  183. c.dispatch(httpWriter, httpRequest)
  184. }
  185. // Dispatch the incoming Http Request to a matching WebService.
  186. func (c *Container) dispatch(httpWriter http.ResponseWriter, httpRequest *http.Request) {
  187. // so we can assign a compressing one later
  188. writer := httpWriter
  189. // CompressingResponseWriter should be closed after all operations are done
  190. defer func() {
  191. if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
  192. compressWriter.Close()
  193. }
  194. }()
  195. // Instal panic recovery unless told otherwise
  196. if !c.doNotRecover { // catch all for 500 response
  197. defer func() {
  198. if r := recover(); r != nil {
  199. c.recoverHandleFunc(r, writer)
  200. return
  201. }
  202. }()
  203. }
  204. // Find best match Route ; err is non nil if no match was found
  205. var webService *WebService
  206. var route *Route
  207. var err error
  208. func() {
  209. c.webServicesLock.RLock()
  210. defer c.webServicesLock.RUnlock()
  211. webService, route, err = c.router.SelectRoute(
  212. c.webServices,
  213. httpRequest)
  214. }()
  215. if err != nil {
  216. // a non-200 response (may be compressed) has already been written
  217. // run container filters anyway ; they should not touch the response...
  218. chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
  219. switch err.(type) {
  220. case ServiceError:
  221. ser := err.(ServiceError)
  222. c.serviceErrorHandleFunc(ser, req, resp)
  223. }
  224. // TODO
  225. }}
  226. chain.ProcessFilter(NewRequest(httpRequest), NewResponse(writer))
  227. return
  228. }
  229. // Unless httpWriter is already an CompressingResponseWriter see if we need to install one
  230. if _, isCompressing := httpWriter.(*CompressingResponseWriter); !isCompressing {
  231. // Detect if compression is needed
  232. // assume without compression, test for override
  233. contentEncodingEnabled := c.contentEncodingEnabled
  234. if route != nil && route.contentEncodingEnabled != nil {
  235. contentEncodingEnabled = *route.contentEncodingEnabled
  236. }
  237. if contentEncodingEnabled {
  238. doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter)
  239. if doCompress {
  240. var err error
  241. writer, err = NewCompressingResponseWriter(httpWriter, encoding)
  242. if err != nil {
  243. log.Print("unable to install compressor: ", err)
  244. httpWriter.WriteHeader(http.StatusInternalServerError)
  245. return
  246. }
  247. }
  248. }
  249. }
  250. pathProcessor, routerProcessesPath := c.router.(PathProcessor)
  251. if !routerProcessesPath {
  252. pathProcessor = defaultPathProcessor{}
  253. }
  254. pathParams := pathProcessor.ExtractParameters(route, webService, httpRequest.URL.Path)
  255. wrappedRequest, wrappedResponse := route.wrapRequestResponse(writer, httpRequest, pathParams)
  256. // pass through filters (if any)
  257. if size := len(c.containerFilters) + len(webService.filters) + len(route.Filters); size > 0 {
  258. // compose filter chain
  259. allFilters := make([]FilterFunction, 0, size)
  260. allFilters = append(allFilters, c.containerFilters...)
  261. allFilters = append(allFilters, webService.filters...)
  262. allFilters = append(allFilters, route.Filters...)
  263. chain := FilterChain{
  264. Filters: allFilters,
  265. Target: route.Function,
  266. ParameterDocs: route.ParameterDocs,
  267. Operation: route.Operation,
  268. }
  269. chain.ProcessFilter(wrappedRequest, wrappedResponse)
  270. } else {
  271. // no filters, handle request by route
  272. route.Function(wrappedRequest, wrappedResponse)
  273. }
  274. }
  // fixedPrefixPath returns the fixed part of pathspec: everything before the
// first template variable "{", or the whole pathspec when it contains none.
func fixedPrefixPath(pathspec string) string {
	if varBegin := strings.IndexByte(pathspec, '{'); varBegin >= 0 {
		return pathspec[:varBegin]
	}
	return pathspec
}
  283. // ServeHTTP implements net/http.Handler therefore a Container can be a Handler in a http.Server
  284. func (c *Container) ServeHTTP(httpWriter http.ResponseWriter, httpRequest *http.Request) {
  285. // Skip, if content encoding is disabled
  286. if !c.contentEncodingEnabled {
  287. c.ServeMux.ServeHTTP(httpWriter, httpRequest)
  288. return
  289. }
  290. // content encoding is enabled
  291. // Skip, if httpWriter is already an CompressingResponseWriter
  292. if _, ok := httpWriter.(*CompressingResponseWriter); ok {
  293. c.ServeMux.ServeHTTP(httpWriter, httpRequest)
  294. return
  295. }
  296. writer := httpWriter
  297. // CompressingResponseWriter should be closed after all operations are done
  298. defer func() {
  299. if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
  300. compressWriter.Close()
  301. }
  302. }()
  303. doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter)
  304. if doCompress {
  305. var err error
  306. writer, err = NewCompressingResponseWriter(httpWriter, encoding)
  307. if err != nil {
  308. log.Print("unable to install compressor: ", err)
  309. httpWriter.WriteHeader(http.StatusInternalServerError)
  310. return
  311. }
  312. }
  313. c.ServeMux.ServeHTTP(writer, httpRequest)
  314. }
  315. // Handle registers the handler for the given pattern. If a handler already exists for pattern, Handle panics.
  316. func (c *Container) Handle(pattern string, handler http.Handler) {
  317. c.ServeMux.Handle(pattern, http.HandlerFunc(func(httpWriter http.ResponseWriter, httpRequest *http.Request) {
  318. // Skip, if httpWriter is already an CompressingResponseWriter
  319. if _, ok := httpWriter.(*CompressingResponseWriter); ok {
  320. handler.ServeHTTP(httpWriter, httpRequest)
  321. return
  322. }
  323. writer := httpWriter
  324. // CompressingResponseWriter should be closed after all operations are done
  325. defer func() {
  326. if compressWriter, ok := writer.(*CompressingResponseWriter); ok {
  327. compressWriter.Close()
  328. }
  329. }()
  330. if c.contentEncodingEnabled {
  331. doCompress, encoding := wantsCompressedResponse(httpRequest, httpWriter)
  332. if doCompress {
  333. var err error
  334. writer, err = NewCompressingResponseWriter(httpWriter, encoding)
  335. if err != nil {
  336. log.Print("unable to install compressor: ", err)
  337. httpWriter.WriteHeader(http.StatusInternalServerError)
  338. return
  339. }
  340. }
  341. }
  342. handler.ServeHTTP(writer, httpRequest)
  343. }))
  344. }
  345. // HandleWithFilter registers the handler for the given pattern.
  346. // Container's filter chain is applied for handler.
  347. // If a handler already exists for pattern, HandleWithFilter panics.
  348. func (c *Container) HandleWithFilter(pattern string, handler http.Handler) {
  349. f := func(httpResponse http.ResponseWriter, httpRequest *http.Request) {
  350. if len(c.containerFilters) == 0 {
  351. handler.ServeHTTP(httpResponse, httpRequest)
  352. return
  353. }
  354. chain := FilterChain{Filters: c.containerFilters, Target: func(req *Request, resp *Response) {
  355. handler.ServeHTTP(resp, req.Request)
  356. }}
  357. chain.ProcessFilter(NewRequest(httpRequest), NewResponse(httpResponse))
  358. }
  359. c.Handle(pattern, http.HandlerFunc(f))
  360. }
  361. // Filter appends a container FilterFunction. These are called before dispatching
  362. // a http.Request to a WebService from the container
  363. func (c *Container) Filter(filter FilterFunction) {
  364. c.containerFilters = append(c.containerFilters, filter)
  365. }
  366. // RegisteredWebServices returns the collections of added WebServices
  367. func (c *Container) RegisteredWebServices() []*WebService {
  368. c.webServicesLock.RLock()
  369. defer c.webServicesLock.RUnlock()
  370. result := make([]*WebService, len(c.webServices))
  371. for ix := range c.webServices {
  372. result[ix] = c.webServices[ix]
  373. }
  374. return result
  375. }
  376. // computeAllowedMethods returns a list of HTTP methods that are valid for a Request
  377. func (c *Container) computeAllowedMethods(req *Request) []string {
  378. // Go through all RegisteredWebServices() and all its Routes to collect the options
  379. methods := []string{}
  380. requestPath := req.Request.URL.Path
  381. for _, ws := range c.RegisteredWebServices() {
  382. matches := ws.pathExpr.Matcher.FindStringSubmatch(requestPath)
  383. if matches != nil {
  384. finalMatch := matches[len(matches)-1]
  385. for _, rt := range ws.Routes() {
  386. matches := rt.pathExpr.Matcher.FindStringSubmatch(finalMatch)
  387. if matches != nil {
  388. lastMatch := matches[len(matches)-1]
  389. if lastMatch == "" || lastMatch == "/" { // do not include if value is neither empty nor ‘/’.
  390. methods = append(methods, rt.Method)
  391. }
  392. }
  393. }
  394. }
  395. }
  396. // methods = append(methods, "OPTIONS") not sure about this
  397. return methods
  398. }
  399. // newBasicRequestResponse creates a pair of Request,Response from its http versions.
  400. // It is basic because no parameter or (produces) content-type information is given.
  401. func newBasicRequestResponse(httpWriter http.ResponseWriter, httpRequest *http.Request) (*Request, *Response) {
  402. resp := NewResponse(httpWriter)
  403. resp.requestAccept = httpRequest.Header.Get(HEADER_Accept)
  404. return NewRequest(httpRequest), resp
  405. }