You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

326 lines
9.1 KiB

package service
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
mathrand "math/rand"
"net/http"
"strings"
"time"
"gofaster/internal/auth/model"
"gofaster/internal/auth/repository"
"gofaster/internal/shared/jwt"
"golang.org/x/crypto/bcrypt"
)
type AuthService interface {
Login(ctx context.Context, req *model.LoginRequest, clientIP string) (*model.LoginResponse, error)
Logout(ctx context.Context, token string) error
RefreshToken(ctx context.Context, refreshToken string) (*model.LoginResponse, error)
GenerateCaptcha(ctx context.Context) (*model.CaptchaResponse, error)
ValidateCaptcha(ctx context.Context, captchaID, captchaText string) error
GetUserInfo(ctx context.Context, userID uint) (*model.UserInfo, error)
}
type authService struct {
userRepo repository.UserRepository
captchaRepo repository.CaptchaRepository
jwtManager jwt.JWTManager
}
func NewAuthService(userRepo repository.UserRepository, captchaRepo repository.CaptchaRepository, jwtManager jwt.JWTManager) AuthService {
return &authService{
userRepo: userRepo,
captchaRepo: captchaRepo,
jwtManager: jwtManager,
}
}
// Login 用户登录
func (s *authService) Login(ctx context.Context, req *model.LoginRequest, clientIP string) (*model.LoginResponse, error) {
// 1. 验证验证码
if err := s.ValidateCaptcha(ctx, req.CaptchaID, req.Captcha); err != nil {
return nil, fmt.Errorf("验证码错误: %w", err)
}
// 2. 根据用户名查找用户
user, err := s.userRepo.GetByUsername(ctx, req.Username)
if err != nil {
return nil, fmt.Errorf("用户不存在")
}
// 3. 检查用户状态
if !user.CanLogin() {
if user.IsLocked() {
return nil, fmt.Errorf("账户已被锁定,请30分钟后再试")
}
return nil, fmt.Errorf("账户已被禁用")
}
// 4. 验证密码
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
// 密码错误,增加错误次数
if err := s.userRepo.IncrementPasswordError(ctx, user.ID); err != nil {
return nil, fmt.Errorf("系统错误,请稍后重试")
}
// 检查是否被锁定
if user.PasswordErrorCount >= 4 { // 已经是第5次错误
return nil, fmt.Errorf("密码错误次数过多,账户已被锁定30分钟")
}
remaining := 5 - user.PasswordErrorCount - 1
return nil, fmt.Errorf("密码错误,还可尝试%d次", remaining)
}
// 5. 密码正确,重置错误次数并更新登录信息
if err := s.userRepo.ResetPasswordError(ctx, user.ID); err != nil {
return nil, fmt.Errorf("系统错误,请稍后重试")
}
if err := s.userRepo.UpdateLastLogin(ctx, user.ID, clientIP); err != nil {
// 登录信息更新失败不影响登录流程
fmt.Printf("更新登录信息失败: %v\n", err)
}
// 6. 生成JWT令牌
claims := map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"email": user.Email,
}
token, err := s.jwtManager.GenerateToken(claims, 24*time.Hour) // 24小时有效期
if err != nil {
return nil, fmt.Errorf("生成令牌失败: %w", err)
}
refreshToken, err := s.jwtManager.GenerateToken(claims, 7*24*time.Hour) // 7天有效期
if err != nil {
return nil, fmt.Errorf("生成刷新令牌失败: %w", err)
}
// 7. 获取用户角色信息
userWithRoles, err := s.userRepo.GetUserWithRoles(ctx, user.ID)
if err != nil {
return nil, fmt.Errorf("获取用户信息失败: %w", err)
}
// 8. 构建响应
userInfo := s.buildUserInfo(userWithRoles)
return &model.LoginResponse{
Token: token,
TokenType: "Bearer",
ExpiresIn: 24 * 60 * 60, // 24小时,单位秒
RefreshToken: refreshToken,
User: *userInfo,
}, nil
}
// Logout 用户登出
func (s *authService) Logout(ctx context.Context, token string) error {
// 这里可以实现令牌黑名单机制
// 目前简单返回成功
return nil
}
// RefreshToken 刷新令牌
func (s *authService) RefreshToken(ctx context.Context, refreshToken string) (*model.LoginResponse, error) {
// 验证刷新令牌
claims, err := s.jwtManager.ValidateToken(refreshToken)
if err != nil {
return nil, fmt.Errorf("刷新令牌无效: %w", err)
}
// 获取用户信息
userID, ok := claims["user_id"].(float64)
if !ok {
return nil, fmt.Errorf("令牌格式错误")
}
user, err := s.userRepo.GetUserWithRoles(ctx, uint(userID))
if err != nil {
return nil, fmt.Errorf("用户不存在: %w", err)
}
// 检查用户状态
if !user.CanLogin() {
return nil, fmt.Errorf("用户状态异常")
}
// 生成新的访问令牌
newClaims := map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"email": user.Email,
}
newToken, err := s.jwtManager.GenerateToken(newClaims, 24*time.Hour)
if err != nil {
return nil, fmt.Errorf("生成新令牌失败: %w", err)
}
// 构建响应
userInfo := s.buildUserInfo(user)
return &model.LoginResponse{
Token: newToken,
TokenType: "Bearer",
ExpiresIn: 24 * 60 * 60,
RefreshToken: refreshToken, // 保持原刷新令牌
User: *userInfo,
}, nil
}
// GenerateCaptcha 生成验证码
func (s *authService) GenerateCaptcha(ctx context.Context) (*model.CaptchaResponse, error) {
// 生成随机验证码ID
captchaID, err := s.generateRandomID()
if err != nil {
return nil, fmt.Errorf("生成验证码ID失败: %w", err)
}
// 生成4位随机验证码
captchaText := s.generateRandomCaptcha(4)
// 设置5分钟过期时间
expiresAt := time.Now().Add(5 * time.Minute)
// 保存到数据库
if err := s.captchaRepo.Create(ctx, captchaID, captchaText, expiresAt); err != nil {
return nil, fmt.Errorf("保存验证码失败: %w", err)
}
// 生成验证码图片(这里简化处理,实际应该生成图片)
captchaImage := s.generateCaptchaImage(captchaText)
return &model.CaptchaResponse{
CaptchaID: captchaID,
CaptchaImage: captchaImage,
ExpiresIn: 5 * 60, // 5分钟,单位秒
}, nil
}
// ValidateCaptcha 验证验证码
func (s *authService) ValidateCaptcha(ctx context.Context, captchaID, captchaText string) error {
// 从数据库获取验证码
storedText, err := s.captchaRepo.Get(ctx, captchaID)
if err != nil {
return fmt.Errorf("验证码已过期或不存在")
}
// 比较验证码(不区分大小写)
if !strings.EqualFold(storedText, captchaText) {
return fmt.Errorf("验证码错误")
}
return nil
}
// GetUserInfo 获取用户信息
func (s *authService) GetUserInfo(ctx context.Context, userID uint) (*model.UserInfo, error) {
user, err := s.userRepo.GetUserWithRoles(ctx, userID)
if err != nil {
return nil, fmt.Errorf("用户不存在: %w", err)
}
return s.buildUserInfo(user), nil
}
// buildUserInfo 构建用户信息
func (s *authService) buildUserInfo(user *model.User) *model.UserInfo {
roles := make([]model.RoleInfo, 0, len(user.Roles))
for _, role := range user.Roles {
roles = append(roles, model.RoleInfo{
ID: role.ID,
Name: role.Name,
Code: role.Code,
})
}
return &model.UserInfo{
ID: user.ID,
Username: user.Username,
Email: user.Email,
Phone: user.Phone,
Status: user.Status,
LastLoginAt: user.LastLoginAt,
LastLoginIP: user.LastLoginIP,
Roles: roles,
}
}
// generateRandomID 生成随机ID
func (s *authService) generateRandomID() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// generateRandomCaptcha 生成随机验证码
func (s *authService) generateRandomCaptcha(length int) string {
const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
result := make([]byte, length)
for i := range result {
result[i] = chars[mathrand.Intn(len(chars))]
}
return string(result)
}
// generateCaptchaImage 生成验证码图片(简化版本,返回Base64编码)
func (s *authService) generateCaptchaImage(text string) string {
// 创建一个简单的文本图片作为验证码
// 使用HTML5 Canvas风格的文本渲染
width := 120
height := 40
// 创建一个简单的SVG图片,包含验证码文本
svg := fmt.Sprintf(`<svg width="%d" height="%d" xmlns="http://www.w3.org/2000/svg">
<rect width="%d" height="%d" fill="#f0f0f0"/>
<text x="%d" y="%d" font-family="Arial, sans-serif" font-size="24" font-weight="bold"
fill="#333" text-anchor="middle" dominant-baseline="middle">%s</text>
</svg>`, width, height, width, height, width/2, height/2, text)
// 将SVG转换为Base64
return fmt.Sprintf("data:image/svg+xml;base64,%s", base64.StdEncoding.EncodeToString([]byte(svg)))
}
// GetClientIP 获取客户端IP地址
func GetClientIP(r *http.Request) string {
// 尝试从各种头部获取真实IP
ip := r.Header.Get("X-Real-IP")
if ip != "" {
return ip
}
ip = r.Header.Get("X-Forwarded-For")
if ip != "" {
// X-Forwarded-For可能包含多个IP,取第一个
if idx := strings.Index(ip, ","); idx != -1 {
ip = ip[:idx]
}
return strings.TrimSpace(ip)
}
ip = r.Header.Get("X-Forwarded")
if ip != "" {
return ip
}
// 从RemoteAddr获取
if r.RemoteAddr != "" {
if idx := strings.Index(r.RemoteAddr, ":"); idx != -1 {
return r.RemoteAddr[:idx]
}
return r.RemoteAddr
}
return "127.0.0.1"
}