package middleware import ( "strconv" "strings" "gofaster/internal/shared/jwt" "gofaster/internal/shared/response" "github.com/gin-gonic/gin" ) // JWTConfig JWT中间件配置 type JWTConfig struct { SecretKey string Issuer string } // JWTAuth JWT认证中间件 func JWTAuth(config JWTConfig) gin.HandlerFunc { return func(c *gin.Context) { // 获取Authorization头部 authHeader := c.GetHeader("Authorization") if authHeader == "" { response.Unauthorized(c, "未提供认证令牌", "Authorization头部缺失") c.Abort() return } // 检查Bearer前缀 if !strings.HasPrefix(authHeader, "Bearer ") { response.Unauthorized(c, "认证令牌格式错误", "令牌必须以Bearer开头") c.Abort() return } // 提取令牌 tokenString := strings.TrimPrefix(authHeader, "Bearer ") // 创建JWT管理器 jwtManager := jwt.NewJWTManager(config.SecretKey, config.Issuer) // 验证令牌 claims, err := jwtManager.ValidateToken(tokenString) if err != nil { response.Unauthorized(c, "认证令牌无效", err.Error()) c.Abort() return } // 将用户信息存储到上下文中 if userID, exists := claims["user_id"]; exists { c.Set("user_id", userID) } if username, exists := claims["username"]; exists { c.Set("username", username) } if email, exists := claims["email"]; exists { c.Set("email", email) } // 继续处理请求 c.Next() } } // OptionalJWTAuth 可选的JWT认证中间件(不强制要求认证) func OptionalJWTAuth(config JWTConfig) gin.HandlerFunc { return func(c *gin.Context) { // 获取Authorization头部 authHeader := c.GetHeader("Authorization") if authHeader == "" { // 没有令牌,继续处理请求 c.Next() return } // 检查Bearer前缀 if !strings.HasPrefix(authHeader, "Bearer ") { // 令牌格式错误,继续处理请求(不强制要求) c.Next() return } // 提取令牌 tokenString := strings.TrimPrefix(authHeader, "Bearer ") // 创建JWT管理器 jwtManager := jwt.NewJWTManager(config.SecretKey, config.Issuer) // 验证令牌 claims, err := jwtManager.ValidateToken(tokenString) if err != nil { // 令牌无效,继续处理请求(不强制要求) c.Next() return } // 将用户信息存储到上下文中 if userID, exists := claims["user_id"]; exists { c.Set("user_id", userID) } if username, exists := claims["username"]; exists { c.Set("username", username) } if email, exists := claims["email"]; exists { c.Set("email", email) } // 继续处理请求 c.Next() } } // GetUserID 从上下文中获取用户ID func GetUserID(c *gin.Context) (uint, bool) { userID, exists := c.Get("user_id") if !exists { return 0, false } switch v := userID.(type) { case float64: return uint(v), true case int: return uint(v), true case string: if parsed, err := strconv.ParseUint(v, 10, 32); err == nil { return uint(parsed), true } } return 0, false } // GetUsername 从上下文中获取用户名 func GetUsername(c *gin.Context) (string, bool) { username, exists := c.Get("username") if !exists { return "", false } if str, ok := username.(string); ok { return str, true } return "", false } // GetEmail 从上下文中获取邮箱 func GetEmail(c *gin.Context) (string, bool) { email, exists := c.Get("email") if !exists { return "", false } if str, ok := email.(string); ok { return str, true } return "", false }