|
|
|
|
package middleware
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"net/http"
|
|
|
|
|
"strconv"
|
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// JWTAuth JWT认证中间件
|
|
|
|
|
func JWTAuth() gin.HandlerFunc {
|
|
|
|
|
return func(c *gin.Context) {
|
|
|
|
|
// 从请求头获取token
|
|
|
|
|
authHeader := c.GetHeader("Authorization")
|
|
|
|
|
if authHeader == "" {
|
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "未提供认证令牌"})
|
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 检查Bearer前缀
|
|
|
|
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌格式错误"})
|
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 提取token
|
|
|
|
|
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
|
|
|
|
|
|
|
|
|
// 解析和验证token
|
|
|
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
|
|
|
// 验证签名方法
|
|
|
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
|
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
|
|
|
}
|
|
|
|
|
// 返回密钥 - 使用配置中的密钥
|
|
|
|
|
return []byte("your-secret-key"), nil
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "无效的认证令牌"})
|
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 验证token是否有效
|
|
|
|
|
if !token.Valid {
|
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "认证令牌已失效"})
|
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取claims
|
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
|
|
|
if !ok {
|
|
|
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "无法解析认证令牌"})
|
|
|
|
|
c.Abort()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 将用户信息存储到上下文中
|
|
|
|
|
c.Set("user_id", claims["user_id"])
|
|
|
|
|
c.Set("username", claims["username"])
|
|
|
|
|
|
|
|
|
|
c.Next()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// OptionalJWTAuth 可选的JWT认证中间件(不强制要求认证)
|
|
|
|
|
func OptionalJWTAuth() gin.HandlerFunc {
|
|
|
|
|
return func(c *gin.Context) {
|
|
|
|
|
// 获取Authorization头部
|
|
|
|
|
authHeader := c.GetHeader("Authorization")
|
|
|
|
|
if authHeader == "" {
|
|
|
|
|
// 没有令牌,继续处理请求
|
|
|
|
|
c.Next()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 检查Bearer前缀
|
|
|
|
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
|
|
|
|
// 令牌格式错误,继续处理请求(不强制要求)
|
|
|
|
|
c.Next()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 提取令牌
|
|
|
|
|
_ = strings.TrimPrefix(authHeader, "Bearer ")
|
|
|
|
|
|
|
|
|
|
// 创建JWT管理器
|
|
|
|
|
// jwtManager := jwt.NewJWTManager(config.SecretKey, config.Issuer) // This line was removed as per the new_code
|
|
|
|
|
|
|
|
|
|
// 验证令牌
|
|
|
|
|
// claims, err := jwtManager.ValidateToken(tokenString) // This line was removed as per the new_code
|
|
|
|
|
// if err != nil { // This line was removed as per the new_code
|
|
|
|
|
// // 令牌无效,继续处理请求(不强制要求) // This line was removed as per the new_code
|
|
|
|
|
// c.Next() // This line was removed as per the new_code
|
|
|
|
|
// return // This line was removed as per the new_code
|
|
|
|
|
// } // This line was removed as per the new_code
|
|
|
|
|
|
|
|
|
|
// 将用户信息存储到上下文中
|
|
|
|
|
// if userID, exists := claims["user_id"]; exists { // This block was removed as per the new_code
|
|
|
|
|
// c.Set("user_id", userID) // This line was removed as per the new_code
|
|
|
|
|
// } // This block was removed as per the new_code
|
|
|
|
|
// if username, exists := claims["username"]; exists { // This block was removed as per the new_code
|
|
|
|
|
// c.Set("username", username) // This line was removed as per the new_code
|
|
|
|
|
// } // This block was removed as per the new_code
|
|
|
|
|
// if email, exists := claims["email"]; exists { // This block was removed as per the new_code
|
|
|
|
|
// c.Set("email", email) // This line was removed as per the new_code
|
|
|
|
|
// } // This block was removed as per the new_code
|
|
|
|
|
|
|
|
|
|
// 继续处理请求
|
|
|
|
|
c.Next()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetUserID 安全地从上下文中获取用户ID
|
|
|
|
|
func GetUserID(c *gin.Context) uint {
|
|
|
|
|
userID, exists := c.Get("user_id")
|
|
|
|
|
if !exists {
|
|
|
|
|
fmt.Printf("🔍 GetUserID函数开始执行\n")
|
|
|
|
|
fmt.Printf("❌ GetUserID - user_id不存在于上下文中\n")
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 根据不同的类型进行安全转换
|
|
|
|
|
switch v := userID.(type) {
|
|
|
|
|
case uint:
|
|
|
|
|
return v
|
|
|
|
|
case int:
|
|
|
|
|
if v < 0 {
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
return uint(v)
|
|
|
|
|
case int64:
|
|
|
|
|
if v < 0 {
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
return uint(v)
|
|
|
|
|
case float64:
|
|
|
|
|
if v < 0 || v > float64(^uint(0)) {
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
return uint(v)
|
|
|
|
|
case string:
|
|
|
|
|
if parsed, err := strconv.ParseUint(v, 10, 64); err == nil {
|
|
|
|
|
return uint(parsed)
|
|
|
|
|
}
|
|
|
|
|
return 0
|
|
|
|
|
default:
|
|
|
|
|
fmt.Printf("⚠️ GetUserID - 未知的用户ID类型: %T, 值: %v\n", v, v)
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetUsername 从上下文中获取用户名
|
|
|
|
|
func GetUsername(c *gin.Context) (string, bool) {
|
|
|
|
|
username, exists := c.Get("username")
|
|
|
|
|
if !exists {
|
|
|
|
|
return "", false
|
|
|
|
|
}
|
|
|
|
|
if str, ok := username.(string); ok {
|
|
|
|
|
return str, true
|
|
|
|
|
}
|
|
|
|
|
return "", false
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// GetEmail 从上下文中获取邮箱
|
|
|
|
|
func GetEmail(c *gin.Context) (string, bool) {
|
|
|
|
|
email, exists := c.Get("email")
|
|
|
|
|
if !exists {
|
|
|
|
|
return "", false
|
|
|
|
|
}
|
|
|
|
|
if str, ok := email.(string); ok {
|
|
|
|
|
return str, true
|
|
|
|
|
}
|
|
|
|
|
return "", false
|
|
|
|
|
}
|