whiteList.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. package middleware
  2. import (
  3. "log"
  4. "net"
  5. "net/http"
  6. "strings"
  7. )
  8. // WhiteListCheck 用于 IP 白名单验证的中间件函数
  9. func WhiteListCheck(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc {
  10. ipNets, err := parseIPNets(allowedIPs)
  11. if err != nil {
  12. log.Fatal(err)
  13. }
  14. return func(next http.HandlerFunc) http.HandlerFunc {
  15. return func(w http.ResponseWriter, r *http.Request) {
  16. ip := getClientIP(r)
  17. if !isAllowedIP(ip, ipNets) {
  18. http.Error(w, "Forbidden", http.StatusForbidden)
  19. return
  20. }
  21. next.ServeHTTP(w, r)
  22. }
  23. }
  24. }
  25. //func WhiteListMiddleware(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc {
  26. // return func(next http.HandlerFunc) http.HandlerFunc {
  27. // return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  28. // ip := getClientIP(r)
  29. // if !isAllowedIP(ip, allowedIPs) {
  30. // http.Error(w, "Forbidden", http.StatusForbidden)
  31. // return
  32. // }
  33. // next.ServeHTTP(w, r)
  34. // })
  35. // }
  36. //}
  // getClientIP returns the client address used for the whitelist decision,
// in a form that net.ParseIP can parse.
//
// If the X-Forwarded-For header is present and its first (left-most)
// comma-separated entry is non-empty after trimming whitespace, that entry is
// returned. Otherwise r.RemoteAddr is used with its port stripped:
// RemoteAddr is normally "host:port" (IPv6 hosts are bracketed, e.g.
// "[::1]:8080"), and passing it unmodified to net.ParseIP yields nil, which
// previously caused every direct (non-proxied) request to be rejected.
// A RemoteAddr without a port is returned unchanged.
//
// NOTE(security): X-Forwarded-For is supplied by the client and is trusted
// here unconditionally, so any caller can claim a whitelisted address by
// sending the header itself. This is only safe when every request passes
// through a reverse proxy that overwrites the header — TODO confirm the
// deployment, or restrict header use to known proxy addresses.
func getClientIP(r *http.Request) string {
	if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
		// By convention the left-most entry is the originating client and later
		// entries are the proxies it passed through.
		first := strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0])
		if first != "" {
			return first
		}
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port present (e.g. RemoteAddr set by a custom server or a test).
		return r.RemoteAddr
	}
	return host
}
  // isAllowedIP reports whether ip, parsed with net.ParseIP, lies inside at
// least one network in ipNets. Input that is not a bare IP literal (for
// example a "host:port" string or "") parses to nil, and a nil address is
// contained in no network, so it yields false. An empty ipNets also yields
// false.
func isAllowedIP(ip string, ipNets []*net.IPNet) bool {
	candidate := net.ParseIP(ip)

	allowed := false
	// Stop scanning as soon as one network matches.
	for i := 0; i < len(ipNets) && !allowed; i++ {
		allowed = ipNets[i].Contains(candidate)
	}
	return allowed
}
  58. //func isAllowedIP(ip string, whitelist []string) bool {
  59. // for _, allowedIP := range whitelist {
  60. // if ip == allowedIP {
  61. // return true
  62. // }
  63. // }
  64. // return false
  65. //}
  // parseIPNets converts whitelist entries into networks for matching.
//
// Each entry may be CIDR notation ("10.0.0.0/8", "2001:db8::/32") or a bare
// IP address ("127.0.0.1", "::1"). A bare address becomes a single-host
// network: /32 for IPv4 (including IPv4-mapped IPv6) and /128 for IPv6.
// Previously only CIDR was accepted, so a plain address made the whole
// whitelist fail to load. Whitespace around each entry is ignored.
//
// On the first entry that is neither form, it returns a nil slice and a
// *net.ParseError naming that entry. A nil or empty allowedIPs yields an
// empty, non-nil slice, which matches no address.
func parseIPNets(allowedIPs []string) ([]*net.IPNet, error) {
	ipNets := make([]*net.IPNet, 0, len(allowedIPs))
	for _, raw := range allowedIPs {
		entry := strings.TrimSpace(raw)

		if _, ipNet, err := net.ParseCIDR(entry); err == nil {
			ipNets = append(ipNets, ipNet)
			continue
		}

		ip := net.ParseIP(entry)
		if ip == nil {
			return nil, &net.ParseError{Type: "IP address or CIDR", Text: entry}
		}
		// To4 normalizes IPv4 to its 4-byte form so the mask length matches.
		if v4 := ip.To4(); v4 != nil {
			ipNets = append(ipNets, &net.IPNet{IP: v4, Mask: net.CIDRMask(32, 32)})
		} else {
			ipNets = append(ipNets, &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)})
		}
	}
	return ipNets, nil
}