Browse Source

ip白名单

songxiaohang 1 year ago
parent
commit
9e41755f96

+ 4 - 0
app/cmd/dtgateway/dtgateway.go

@@ -1,6 +1,7 @@
 package main
 package main
 
 
 import (
 import (
+	"GtDataStore/common/middleware"
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 
 
@@ -23,6 +24,9 @@ func main() {
 	server := rest.MustNewServer(c.RestConf)
 	server := rest.MustNewServer(c.RestConf)
 	defer server.Stop()
 	defer server.Stop()
 
 
+	ipWhiteList := []string{"127.0.0.1/32", "47.96.12.136/32", "120.55.44.4/32", "58.214.245.78/32"}
+	server.Use(middleware.WhiteListCheck(ipWhiteList))
+
 	ctx := svc.NewServiceContext(c)
 	ctx := svc.NewServiceContext(c)
 	handler.RegisterHandlers(server, ctx)
 	handler.RegisterHandlers(server, ctx)
 
 

+ 8 - 8
app/cmd/dtgateway/etc/dtgateway.yaml

@@ -18,13 +18,13 @@ Log:
   Level: error
   Level: error
 
 
 #rpc service
 #rpc service
-#OrganizationRpcConf:
-#  Endpoints:
-#    - 127.0.0.1:1117
-#  NonBlock: true
-#  Timeout: 0
-
 OrganizationRpcConf:
 OrganizationRpcConf:
-  Timeout: 50000
-  Target: k8s://gt-datacenter/organization-rpc-svc:1117 #goctl kube 默认生成的k8s yaml的serviceName: {rpc中定义的name}-svc
+  Endpoints:
+    - 127.0.0.1:1117
+  NonBlock: true
+  Timeout: 0
+
+#OrganizationRpcConf:
+#  Timeout: 50000
+#  Target: k8s://gt-datacenter/organization-rpc-svc:1117 #goctl kube 默认生成的k8s yaml的serviceName: {rpc中定义的name}-svc
 
 

BIN
common/.DS_Store


+ 87 - 0
common/middleware/whiteList.go

@@ -0,0 +1,87 @@
+package middleware
+
+import (
+	"fmt"
+	"log"
+	"net"
+	"net/http"
+	"strings"
+)
+
+// WhiteListCheck 用于 IP 白名单验证的中间件函数
+func WhiteListCheck(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc {
+	ipNets, err := parseIPNets(allowedIPs)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return func(next http.HandlerFunc) http.HandlerFunc {
+		return func(w http.ResponseWriter, r *http.Request) {
+			ip := getClientIP(r)
+			if !isAllowedIP(ip, ipNets) {
+				http.Error(w, "Forbidden", http.StatusForbidden)
+				return
+			}
+			next.ServeHTTP(w, r)
+		}
+	}
+}
+
+//func WhiteListMiddleware(allowedIPs []string) func(http.HandlerFunc) http.HandlerFunc {
+//	return func(next http.HandlerFunc) http.HandlerFunc {
+//		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+//			ip := getClientIP(r)
+//			if !isAllowedIP(ip, allowedIPs) {
+//				http.Error(w, "Forbidden", http.StatusForbidden)
+//				return
+//			}
+//			next.ServeHTTP(w, r)
+//		})
+//	}
+//}
+
+// 获取客户端IP地址
+func getClientIP(r *http.Request) string {
+	forwardedFor := r.Header.Get("X-Forwarded-For")
+	if forwardedFor != "" {
+		ips := strings.Split(forwardedFor, ",")
+		if len(ips) > 0 {
+			return ips[0]
+		}
+	}
+	return r.RemoteAddr
+}
+
+// 判断IP是否在白名单中
+func isAllowedIP(ip string, ipNets []*net.IPNet) bool {
+	clientIP := net.ParseIP(ip)
+	fmt.Println("clientIP:", clientIP)
+	for _, ipNet := range ipNets {
+		if ipNet.Contains(clientIP) {
+			return true
+		}
+	}
+	return false
+}
+
+//func isAllowedIP(ip string, whitelist []string) bool {
+//	for _, allowedIP := range whitelist {
+//		if ip == allowedIP {
+//			return true
+//		}
+//	}
+//	return false
+//}
+
+func parseIPNets(allowedIPs []string) ([]*net.IPNet, error) {
+	ipNets := make([]*net.IPNet, 0, len(allowedIPs))
+	for _, ipStr := range allowedIPs {
+		_, ipNet, err := net.ParseCIDR(ipStr)
+		if err != nil {
+			return nil, err
+		}
+		ipNets = append(ipNets, ipNet)
+	}
+	fmt.Println("ipNets:", ipNets)
+	return ipNets, nil
+}