|
|
package middleware |
|
|
|
|
|
import ( |
|
|
"fmt" |
|
|
"net/http" |
|
|
"strconv" |
|
|
"strings" |
|
|
|
|
|
"github.com/gin-gonic/gin" |
|
|
"github.com/golang-jwt/jwt/v5" |
|
|
) |
|
|
|
|
|
// JWTAuth JWT认证中间件 |
|
|
func JWTAuth() gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
// 从请求头获取token |
|
|
authHeader := c.GetHeader("Authorization") |
|
|
if authHeader == "" { |
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供认证令牌"}) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
|
|
|
// 检查Bearer前缀 |
|
|
if !strings.HasPrefix(authHeader, "Bearer ") { |
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌格式错误"}) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
|
|
|
// 提取token |
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ") |
|
|
|
|
|
// 解析和验证token |
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { |
|
|
// 验证签名方法 |
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { |
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) |
|
|
} |
|
|
// 返回密钥 - 使用配置中的密钥 |
|
|
return []byte("your-secret-key"), nil |
|
|
}) |
|
|
|
|
|
if err != nil { |
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的认证令牌"}) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
|
|
|
// 验证token是否有效 |
|
|
if !token.Valid { |
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌已失效"}) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
|
|
|
// 获取claims |
|
|
claims, ok := token.Claims.(jwt.MapClaims) |
|
|
if !ok { |
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "无法解析认证令牌"}) |
|
|
c.Abort() |
|
|
return |
|
|
} |
|
|
|
|
|
// 将用户信息存储到上下文中 |
|
|
c.Set("user_id", claims["user_id"]) |
|
|
c.Set("username", claims["username"]) |
|
|
|
|
|
c.Next() |
|
|
} |
|
|
} |
|
|
|
|
|
// OptionalJWTAuth 可选的JWT认证中间件(不强制要求认证) |
|
|
func OptionalJWTAuth() gin.HandlerFunc { |
|
|
return func(c *gin.Context) { |
|
|
// 获取Authorization头部 |
|
|
authHeader := c.GetHeader("Authorization") |
|
|
if authHeader == "" { |
|
|
// 没有令牌,继续处理请求 |
|
|
c.Next() |
|
|
return |
|
|
} |
|
|
|
|
|
// 检查Bearer前缀 |
|
|
if !strings.HasPrefix(authHeader, "Bearer ") { |
|
|
// 令牌格式错误,继续处理请求(不强制要求) |
|
|
c.Next() |
|
|
return |
|
|
} |
|
|
|
|
|
// 提取令牌 |
|
|
_ = strings.TrimPrefix(authHeader, "Bearer ") |
|
|
|
|
|
// 创建JWT管理器 |
|
|
// jwtManager := jwt.NewJWTManager(config.SecretKey, config.Issuer) // This line was removed as per the new_code |
|
|
|
|
|
// 验证令牌 |
|
|
// claims, err := jwtManager.ValidateToken(tokenString) // This line was removed as per the new_code |
|
|
// if err != nil { // This line was removed as per the new_code |
|
|
// // 令牌无效,继续处理请求(不强制要求) // This line was removed as per the new_code |
|
|
// c.Next() // This line was removed as per the new_code |
|
|
// return // This line was removed as per the new_code |
|
|
// } // This line was removed as per the new_code |
|
|
|
|
|
// 将用户信息存储到上下文中 |
|
|
// if userID, exists := claims["user_id"]; exists { // This block was removed as per the new_code |
|
|
// c.Set("user_id", userID) // This line was removed as per the new_code |
|
|
// } // This block was removed as per the new_code |
|
|
// if username, exists := claims["username"]; exists { // This block was removed as per the new_code |
|
|
// c.Set("username", username) // This line was removed as per the new_code |
|
|
// } // This block was removed as per the new_code |
|
|
// if email, exists := claims["email"]; exists { // This block was removed as per the new_code |
|
|
// c.Set("email", email) // This line was removed as per the new_code |
|
|
// } // This block was removed as per the new_code |
|
|
|
|
|
// 继续处理请求 |
|
|
c.Next() |
|
|
} |
|
|
} |
|
|
|
|
|
// GetUserID 安全地从上下文中获取用户ID |
|
|
func GetUserID(c *gin.Context) uint { |
|
|
userID, exists := c.Get("user_id") |
|
|
if !exists { |
|
|
fmt.Printf("🔍 GetUserID函数开始执行\n") |
|
|
fmt.Printf("❌ GetUserID - user_id不存在于上下文中\n") |
|
|
return 0 |
|
|
} |
|
|
|
|
|
// 根据不同的类型进行安全转换 |
|
|
switch v := userID.(type) { |
|
|
case uint: |
|
|
return v |
|
|
case int: |
|
|
if v < 0 { |
|
|
return 0 |
|
|
} |
|
|
return uint(v) |
|
|
case int64: |
|
|
if v < 0 { |
|
|
return 0 |
|
|
} |
|
|
return uint(v) |
|
|
case float64: |
|
|
if v < 0 || v > float64(^uint(0)) { |
|
|
return 0 |
|
|
} |
|
|
return uint(v) |
|
|
case string: |
|
|
if parsed, err := strconv.ParseUint(v, 10, 64); err == nil { |
|
|
return uint(parsed) |
|
|
} |
|
|
return 0 |
|
|
default: |
|
|
fmt.Printf("⚠️ GetUserID - 未知的用户ID类型: %T, 值: %v\n", v, v) |
|
|
return 0 |
|
|
} |
|
|
} |
|
|
|
|
|
// GetUsername 从上下文中获取用户名 |
|
|
func GetUsername(c *gin.Context) (string, bool) { |
|
|
username, exists := c.Get("username") |
|
|
if !exists { |
|
|
return "", false |
|
|
} |
|
|
if str, ok := username.(string); ok { |
|
|
return str, true |
|
|
} |
|
|
return "", false |
|
|
} |
|
|
|
|
|
// GetEmail 从上下文中获取邮箱 |
|
|
func GetEmail(c *gin.Context) (string, bool) { |
|
|
email, exists := c.Get("email") |
|
|
if !exists { |
|
|
return "", false |
|
|
} |
|
|
if str, ok := email.(string); ok { |
|
|
return str, true |
|
|
} |
|
|
return "", false |
|
|
}
|
|
|
|