package middleware

import (
	"fmt"
	"log"
	"net"
	"net/http"
	"strings"
)

// WhiteListCheck returns middleware that only passes a request to the next
// handler when the client IP (as determined by getClientIP) falls inside one
// of allowedIPs; every other request receives 403 Forbidden.
//
// Each entry in allowedIPs may be a CIDR range ("10.0.0.0/8", "2001:db8::/32")
// or a single address ("192.168.1.5", "::1"). An invalid entry is treated as a
// configuration error: it is reported via log.Fatalf (which exits the process)
// when WhiteListCheck itself is called, not per request.
func WhiteListCheck(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc {
	ipNets, err := parseIPNets(allowedIPs)
	if err != nil {
		log.Fatalf("middleware: invalid IP whitelist: %v", err)
	}
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			if !isAllowedIP(getClientIP(r), ipNets) {
				http.Error(w, "Forbidden", http.StatusForbidden)
				return
			}
			next(w, r)
		}
	}
}

// getClientIP returns the client's IP address, without any port.
//
// If an X-Forwarded-For header is present and its first (left-most) entry is
// non-empty, that entry is used; otherwise the host part of r.RemoteAddr is
// used.
//
// NOTE(review): X-Forwarded-For is supplied by the client unless a trusted
// reverse proxy overwrites it, so any caller can claim a whitelisted address
// through this header. Only rely on it when the server sits behind a proxy
// that strips or replaces incoming X-Forwarded-For values.
func getClientIP(r *http.Request) string {
	if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
		// Entries are comma-separated and commonly padded with spaces.
		first := strings.TrimSpace(strings.SplitN(forwardedFor, ",", 2)[0])
		if first != "" {
			return stripPort(first)
		}
	}
	// net/http sets RemoteAddr to "IP:port"; the port must be removed before
	// the value can be parsed with net.ParseIP.
	return stripPort(r.RemoteAddr)
}

// stripPort returns the host part of addr when addr has the form "host:port"
// or "[ipv6]:port"; otherwise it returns addr with any surrounding IPv6
// brackets removed.
func stripPort(addr string) string {
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
}

// isAllowedIP reports whether ip parses as an IP address contained in any of
// ipNets. Input that does not parse as an IP (including "") is never allowed.
func isAllowedIP(ip string, ipNets []*net.IPNet) bool {
	clientIP := net.ParseIP(ip)
	if clientIP == nil {
		return false
	}
	for _, ipNet := range ipNets {
		if ipNet.Contains(clientIP) {
			return true
		}
	}
	return false
}

// parseIPNets converts whitelist entries into networks.
//
// An entry containing "/" is parsed as CIDR notation; any other entry must be a
// single IPv4 or IPv6 address and becomes a /32 or /128 network respectively.
// Surrounding whitespace is ignored. The first invalid entry aborts parsing and
// is returned as an error that names it.
func parseIPNets(allowedIPs []string) ([]*net.IPNet, error) {
	ipNets := make([]*net.IPNet, 0, len(allowedIPs))
	for _, raw := range allowedIPs {
		entry := strings.TrimSpace(raw)

		if strings.Contains(entry, "/") {
			_, ipNet, err := net.ParseCIDR(entry)
			if err != nil {
				return nil, fmt.Errorf("parsing whitelist entry %q: %w", raw, err)
			}
			ipNets = append(ipNets, ipNet)
			continue
		}

		ip := net.ParseIP(entry)
		if ip == nil {
			return nil, fmt.Errorf("parsing whitelist entry %q: invalid IP address", raw)
		}
		// Use the 4-byte form for IPv4 so the mask length matches the address.
		bits := 8 * net.IPv6len
		if v4 := ip.To4(); v4 != nil {
			ip = v4
			bits = 8 * net.IPv4len
		}
		ipNets = append(ipNets, &net.IPNet{IP: ip, Mask: net.CIDRMask(bits, bits)})
	}
	return ipNets, nil
}