|
@@ -0,0 +1,87 @@
|
|
|
|
+package middleware
|
|
|
|
+
|
|
|
|
+import (
|
|
|
|
+ "fmt"
|
|
|
|
+ "log"
|
|
|
|
+ "net"
|
|
|
|
+ "net/http"
|
|
|
|
+ "strings"
|
|
|
|
+)
|
|
|
|
+
|
|
|
|
+// WhiteListCheck 用于 IP 白名单验证的中间件函数
|
|
|
|
+func WhiteListCheck(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc {
|
|
|
|
+ ipNets, err := parseIPNets(allowedIPs)
|
|
|
|
+ if err != nil {
|
|
|
|
+ log.Fatal(err)
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ return func(next http.HandlerFunc) http.HandlerFunc {
|
|
|
|
+ return func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
+ ip := getClientIP(r)
|
|
|
|
+ if !isAllowedIP(ip, ipNets) {
|
|
|
|
+ http.Error(w, "Forbidden", http.StatusForbidden)
|
|
|
|
+ return
|
|
|
|
+ }
|
|
|
|
+ next.ServeHTTP(w, r)
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+//func WhiteListMiddleware(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc {
|
|
|
|
+// return func(next http.HandlerFunc) http.HandlerFunc {
|
|
|
|
+// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
|
|
+// ip := getClientIP(r)
|
|
|
|
+// if !isAllowedIP(ip, allowedIPs) {
|
|
|
|
+// http.Error(w, "Forbidden", http.StatusForbidden)
|
|
|
|
+// return
|
|
|
|
+// }
|
|
|
|
+// next.ServeHTTP(w, r)
|
|
|
|
+// })
|
|
|
|
+// }
|
|
|
|
+//}
|
|
|
|
+
|
|
|
|
+// 获取客户端IP地址
|
|
|
|
+func getClientIP(r *http.Request) string {
|
|
|
|
+ forwardedFor := r.Header.Get("X-Forwarded-For")
|
|
|
|
+ if forwardedFor != "" {
|
|
|
|
+ ips := strings.Split(forwardedFor, ",")
|
|
|
|
+ if len(ips) > 0 {
|
|
|
|
+ return ips[0]
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return r.RemoteAddr
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+// 判断IP是否在白名单中
|
|
|
|
+func isAllowedIP(ip string, ipNets []*net.IPNet) bool {
|
|
|
|
+ clientIP := net.ParseIP(ip)
|
|
|
|
+ fmt.Println("clientIP:", clientIP)
|
|
|
|
+ for _, ipNet := range ipNets {
|
|
|
|
+ if ipNet.Contains(clientIP) {
|
|
|
|
+ return true
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ return false
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+//func isAllowedIP(ip string, whitelist []string) bool {
|
|
|
|
+// for _, allowedIP := range whitelist {
|
|
|
|
+// if ip == allowedIP {
|
|
|
|
+// return true
|
|
|
|
+// }
|
|
|
|
+// }
|
|
|
|
+// return false
|
|
|
|
+//}
|
|
|
|
+
|
|
|
|
+func parseIPNets(allowedIPs []string) ([]*net.IPNet, error) {
|
|
|
|
+ ipNets := make([]*net.IPNet, 0, len(allowedIPs))
|
|
|
|
+ for _, ipStr := range allowedIPs {
|
|
|
|
+ _, ipNet, err := net.ParseCIDR(ipStr)
|
|
|
|
+ if err != nil {
|
|
|
|
+ return nil, err
|
|
|
|
+ }
|
|
|
|
+ ipNets = append(ipNets, ipNet)
|
|
|
|
+ }
|
|
|
|
+ fmt.Println("ipNets:", ipNets)
|
|
|
|
+ return ipNets, nil
|
|
|
|
+}
|