|
|
|
package controller
|
|
|
|
|
|
|
|
import (
|
|
|
|
"crypto/rand"
|
|
|
|
"encoding/base64"
|
|
|
|
"fmt"
|
|
|
|
"math/big"
|
|
|
|
"net/http"
|
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"gofaster/internal/auth/model"
|
|
|
|
"gofaster/internal/auth/service"
|
|
|
|
"gofaster/internal/shared/response"
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
)
|
|
|
|
|
|
|
|
type AuthController struct {
|
|
|
|
authService *service.AuthService
|
|
|
|
}
|
|
|
|
|
|
|
|
func NewAuthController(authService *service.AuthService) *AuthController {
|
|
|
|
fmt.Println("🔍 [验证码] AuthController 初始化 - 新版本代码已加载!")
|
|
|
|
return &AuthController{
|
|
|
|
authService: authService,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Login 用户登录
|
|
|
|
// @Summary 用户登录
|
|
|
|
// @Description 用户登录接口,支持验证码验证和密码错误次数限制
|
|
|
|
// @Tags 认证
|
|
|
|
// @Accept json
|
|
|
|
// @Produce json
|
|
|
|
// @Param request body model.LoginRequest true "登录请求参数"
|
|
|
|
// @Success 200 {object} response.Response{data=model.LoginResponse} "登录成功"
|
|
|
|
// @Failure 400 {object} response.Response "请求参数错误"
|
|
|
|
// @Failure 401 {object} response.Response "认证失败"
|
|
|
|
// @Failure 423 {object} response.Response "账户被锁定"
|
|
|
|
// @Router /auth/login [post]
|
|
|
|
func (c *AuthController) Login(ctx *gin.Context) {
|
|
|
|
var req model.LoginRequest
|
|
|
|
if err := ctx.ShouldBindJSON(&req); err != nil {
|
|
|
|
response.Error(ctx, http.StatusBadRequest, "请求参数错误", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// 调用服务层处理登录
|
|
|
|
resp, err := c.authService.Login(ctx, req.Username, req.Password, req.Captcha, req.CaptchaID)
|
|
|
|
if err != nil {
|
|
|
|
// 根据错误类型返回不同的状态码
|
|
|
|
if isLockedError(err) {
|
|
|
|
response.Error(ctx, http.StatusLocked, "账户被锁定", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if isAuthError(err) {
|
|
|
|
response.Error(ctx, http.StatusUnauthorized, "认证失败", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
response.Error(ctx, http.StatusInternalServerError, "系统错误", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
response.Success(ctx, "登录成功", resp)
|
|
|
|
}
|
|
|
|
|
|
|
|
// Logout 用户登出
|
|
|
|
// @Summary 用户登出
|
|
|
|
// @Description 用户登出接口
|
|
|
|
// @Tags 认证
|
|
|
|
// @Accept json
|
|
|
|
// @Produce json
|
|
|
|
// @Param request body model.LogoutRequest true "登出请求参数"
|
|
|
|
// @Success 200 {object} response.Response "登出成功"
|
|
|
|
// @Router /auth/logout [post]
|
|
|
|
func (c *AuthController) Logout(ctx *gin.Context) {
|
|
|
|
var req model.LogoutRequest
|
|
|
|
if err := ctx.ShouldBindJSON(&req); err != nil {
|
|
|
|
response.Error(ctx, http.StatusBadRequest, "请求参数错误", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
err := c.authService.Logout(ctx, req.Token)
|
|
|
|
if err != nil {
|
|
|
|
response.Error(ctx, http.StatusInternalServerError, "登出失败", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
response.Success(ctx, "登出成功", nil)
|
|
|
|
}
|
|
|
|
|
|
|
|
// RefreshToken 刷新访问令牌
|
|
|
|
// @Summary 刷新访问令牌
|
|
|
|
// @Description 使用刷新令牌获取新的访问令牌
|
|
|
|
// @Tags 认证
|
|
|
|
// @Accept json
|
|
|
|
// @Produce json
|
|
|
|
// @Param request body model.RefreshTokenRequest true "刷新令牌请求参数"
|
|
|
|
// @Success 200 {object} response.Response{data=model.LoginResponse} "刷新成功"
|
|
|
|
// @Failure 400 {object} response.Response "请求参数错误"
|
|
|
|
// @Failure 401 {object} response.Response "刷新令牌无效"
|
|
|
|
// @Router /auth/refresh [post]
|
|
|
|
func (c *AuthController) RefreshToken(ctx *gin.Context) {
|
|
|
|
var req model.RefreshTokenRequest
|
|
|
|
if err := ctx.ShouldBindJSON(&req); err != nil {
|
|
|
|
response.Error(ctx, http.StatusBadRequest, "请求参数错误", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := c.authService.RefreshToken(ctx, req.RefreshToken)
|
|
|
|
if err != nil {
|
|
|
|
response.Error(ctx, http.StatusUnauthorized, "刷新令牌无效", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
response.Success(ctx, "令牌刷新成功", resp)
|
|
|
|
}
|
|
|
|
|
|
|
|
// GenerateCaptcha 生成验证码
|
|
|
|
// @Summary 生成验证码
|
|
|
|
// @Description 生成图形验证码,用于登录验证
|
|
|
|
// @Tags 认证
|
|
|
|
// @Produce json
|
|
|
|
// @Success 200 {object} response.Response{data=model.CaptchaResponse} "验证码生成成功"
|
|
|
|
// @Router /auth/captcha [get]
|
|
|
|
func (c *AuthController) GenerateCaptcha(ctx *gin.Context) {
|
|
|
|
fmt.Println("🔍 [验证码] 开始生成验证码...")
|
|
|
|
|
|
|
|
// 生成4位数字验证码
|
|
|
|
captchaText := generateRandomCaptcha(4)
|
|
|
|
fmt.Printf("🔍 [验证码] 生成的验证码文本: %s\n", captchaText)
|
|
|
|
|
|
|
|
captchaID := generateCaptchaID()
|
|
|
|
fmt.Printf("🔍 [验证码] 生成的验证码ID: %s\n", captchaID)
|
|
|
|
|
|
|
|
// 存储验证码到数据库(5分钟过期)
|
|
|
|
expiresAt := time.Now().Add(5 * time.Minute)
|
|
|
|
if err := c.authService.GetCaptchaRepo().Create(ctx, captchaID, captchaText, expiresAt); err != nil {
|
|
|
|
fmt.Printf("🔍 [验证码] 存储验证码失败: %v\n", err)
|
|
|
|
response.Error(ctx, http.StatusInternalServerError, "验证码生成失败", "存储验证码失败")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// 生成验证码图片(SVG格式)
|
|
|
|
captchaImage := generateCaptchaSVG(captchaText)
|
|
|
|
fmt.Printf("🔍 [验证码] 生成的图片长度: %d\n", len(captchaImage))
|
|
|
|
fmt.Printf("🔍 [验证码] 图片前缀: %s\n", captchaImage[:50])
|
|
|
|
|
|
|
|
// 设置过期时间(5分钟)
|
|
|
|
expiresIn := int64(300)
|
|
|
|
|
|
|
|
// 返回验证码响应
|
|
|
|
resp := &model.CaptchaResponse{
|
|
|
|
CaptchaID: captchaID,
|
|
|
|
CaptchaImage: captchaImage,
|
|
|
|
ExpiresIn: expiresIn,
|
|
|
|
}
|
|
|
|
|
|
|
|
fmt.Println("🔍 [验证码] 验证码生成完成,准备返回响应")
|
|
|
|
response.Success(ctx, "验证码生成成功", resp)
|
|
|
|
}
|
|
|
|
|
|
|
|
// GetUserInfo 获取用户信息
|
|
|
|
// @Summary 获取用户信息
|
|
|
|
// @Description 获取当前登录用户的详细信息
|
|
|
|
// @Tags 认证
|
|
|
|
// @Accept json
|
|
|
|
// @Produce json
|
|
|
|
// @Security BearerAuth
|
|
|
|
// @Success 200 {object} response.Response{data=model.UserInfo} "获取成功"
|
|
|
|
// @Failure 401 {object} response.Response "未授权"
|
|
|
|
// @Router /auth/userinfo [get]
|
|
|
|
func (c *AuthController) GetUserInfo(ctx *gin.Context) {
|
|
|
|
// 从JWT令牌中获取用户ID
|
|
|
|
userID, exists := ctx.Get("user_id")
|
|
|
|
if !exists {
|
|
|
|
response.Error(ctx, http.StatusUnauthorized, "未授权", "用户ID不存在")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// 转换用户ID类型
|
|
|
|
var uid uint
|
|
|
|
switch v := userID.(type) {
|
|
|
|
case float64:
|
|
|
|
uid = uint(v)
|
|
|
|
case int:
|
|
|
|
uid = uint(v)
|
|
|
|
case string:
|
|
|
|
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
|
|
|
|
uid = uint(parsed)
|
|
|
|
} else {
|
|
|
|
response.Error(ctx, http.StatusBadRequest, "用户ID格式错误", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
default:
|
|
|
|
response.Error(ctx, http.StatusBadRequest, "用户ID类型错误", "无法识别的用户ID类型")
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
// 获取用户信息
|
|
|
|
userInfo, err := c.authService.GetUserInfo(ctx, uid)
|
|
|
|
if err != nil {
|
|
|
|
response.Error(ctx, http.StatusInternalServerError, "获取用户信息失败", err.Error())
|
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
response.Success(ctx, "获取用户信息成功", userInfo)
|
|
|
|
}
|
|
|
|
|
|
|
|
// getClientIP 获取客户端IP地址
|
|
|
|
func getClientIP(r *http.Request) string {
|
|
|
|
// 尝试从各种头部获取真实IP
|
|
|
|
ip := r.Header.Get("X-Real-IP")
|
|
|
|
if ip != "" {
|
|
|
|
return ip
|
|
|
|
}
|
|
|
|
|
|
|
|
ip = r.Header.Get("X-Forwarded-For")
|
|
|
|
if ip != "" {
|
|
|
|
// X-Forwarded-For可能包含多个IP,取第一个
|
|
|
|
if idx := strings.Index(ip, ","); idx != -1 {
|
|
|
|
ip = ip[:idx]
|
|
|
|
}
|
|
|
|
return strings.TrimSpace(ip)
|
|
|
|
}
|
|
|
|
|
|
|
|
ip = r.Header.Get("X-Forwarded")
|
|
|
|
if ip != "" {
|
|
|
|
return ip
|
|
|
|
}
|
|
|
|
|
|
|
|
// 从RemoteAddr获取
|
|
|
|
if r.RemoteAddr != "" {
|
|
|
|
if idx := strings.Index(r.RemoteAddr, ":"); idx != -1 {
|
|
|
|
return r.RemoteAddr[:idx]
|
|
|
|
}
|
|
|
|
return r.RemoteAddr
|
|
|
|
}
|
|
|
|
|
|
|
|
return "127.0.0.1"
|
|
|
|
}
|
|
|
|
|
|
|
|
// isLockedError 检查是否为锁定错误
|
|
|
|
func isLockedError(err error) bool {
|
|
|
|
return strings.Contains(err.Error(), "锁定")
|
|
|
|
}
|
|
|
|
|
|
|
|
// isAuthError 检查是否为认证错误
|
|
|
|
func isAuthError(err error) bool {
|
|
|
|
return strings.Contains(err.Error(), "密码错误") ||
|
|
|
|
strings.Contains(err.Error(), "用户不存在") ||
|
|
|
|
strings.Contains(err.Error(), "验证码错误")
|
|
|
|
}
|
|
|
|
|
|
|
|
// generateRandomCaptcha 生成随机验证码
|
|
|
|
func generateRandomCaptcha(length int) string {
|
|
|
|
fmt.Printf("🔍 [验证码] 开始生成%d位随机验证码\n", length)
|
|
|
|
const charset = "0123456789"
|
|
|
|
result := make([]byte, length)
|
|
|
|
for i := range result {
|
|
|
|
randomIndex, _ := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
|
|
|
result[i] = charset[randomIndex.Int64()]
|
|
|
|
}
|
|
|
|
generated := string(result)
|
|
|
|
fmt.Printf("🔍 [验证码] 生成的验证码: %s\n", generated)
|
|
|
|
return generated
|
|
|
|
}
|
|
|
|
|
|
|
|
// generateCaptchaID 生成验证码ID
|
|
|
|
func generateCaptchaID() string {
|
|
|
|
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
|
|
|
result := make([]byte, 16)
|
|
|
|
for i := range result {
|
|
|
|
randomIndex, _ := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
|
|
|
|
result[i] = charset[randomIndex.Int64()]
|
|
|
|
}
|
|
|
|
return string(result)
|
|
|
|
}
|
|
|
|
|
|
|
|
// generateCaptchaSVG 生成验证码SVG图片
|
|
|
|
func generateCaptchaSVG(text string) string {
|
|
|
|
fmt.Printf("🔍 [验证码] 开始生成SVG图片,文本: %s\n", text)
|
|
|
|
|
|
|
|
// 生成随机干扰元素
|
|
|
|
line1Y1, _ := rand.Int(rand.Reader, big.NewInt(20))
|
|
|
|
line1Y2, _ := rand.Int(rand.Reader, big.NewInt(20))
|
|
|
|
line2Y1, _ := rand.Int(rand.Reader, big.NewInt(20))
|
|
|
|
line2Y2, _ := rand.Int(rand.Reader, big.NewInt(20))
|
|
|
|
circle1X, _ := rand.Int(rand.Reader, big.NewInt(120))
|
|
|
|
circle1Y, _ := rand.Int(rand.Reader, big.NewInt(40))
|
|
|
|
circle2X, _ := rand.Int(rand.Reader, big.NewInt(120))
|
|
|
|
circle2Y, _ := rand.Int(rand.Reader, big.NewInt(40))
|
|
|
|
circle3X, _ := rand.Int(rand.Reader, big.NewInt(120))
|
|
|
|
circle3Y, _ := rand.Int(rand.Reader, big.NewInt(40))
|
|
|
|
|
|
|
|
// 生成SVG验证码图片
|
|
|
|
svg := fmt.Sprintf(`<svg width="120" height="40" xmlns="http://www.w3.org/2000/svg">
|
|
|
|
<rect width="120" height="40" fill="#f0f0f0"/>
|
|
|
|
<text x="60" y="25" font-family="Arial, sans-serif" font-size="18" font-weight="bold" text-anchor="middle" fill="#333">%s</text>
|
|
|
|
<line x1="0" y1="%d" x2="120" y2="%d" stroke="#ccc" stroke-width="1"/>
|
|
|
|
<line x1="0" y1="%d" x2="120" y2="%d" stroke="#ccc" stroke-width="1"/>
|
|
|
|
<circle cx="%d" cy="%d" r="1" fill="#999"/>
|
|
|
|
<circle cx="%d" cy="%d" r="1" fill="#999"/>
|
|
|
|
<circle cx="%d" cy="%d" r="1" fill="#999"/>
|
|
|
|
</svg>`,
|
|
|
|
text,
|
|
|
|
line1Y1.Int64()+10, line1Y2.Int64()+10,
|
|
|
|
line2Y1.Int64()+20, line2Y2.Int64()+20,
|
|
|
|
circle1X.Int64(), circle1Y.Int64(),
|
|
|
|
circle2X.Int64(), circle2Y.Int64(),
|
|
|
|
circle3X.Int64(), circle3Y.Int64())
|
|
|
|
|
|
|
|
encoded := "data:image/svg+xml;base64," + base64.StdEncoding.EncodeToString([]byte(svg))
|
|
|
|
fmt.Printf("🔍 [验证码] SVG图片生成完成,长度: %d\n", len(encoded))
|
|
|
|
return encoded
|
|
|
|
}
|