You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

380 lines
11 KiB

package service
import (
	"context"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	mathrand "math/rand"
	"net"
	"net/http"
	"strings"
	"time"

	"gofaster/internal/auth/model"
	"gofaster/internal/auth/repository"
	"gofaster/internal/shared/jwt"
	"golang.org/x/crypto/bcrypt"
)
// AuthService describes the authentication operations exposed by this
// package: login/logout, token refresh, captcha handling and user lookup.
type AuthService interface {
// Login checks the captcha and credentials carried by req and, on success,
// returns an access token, a refresh token and the user's profile.
Login(ctx context.Context, req *model.LoginRequest, clientIP string) (*model.LoginResponse, error)
// Logout ends the session identified by token.
Logout(ctx context.Context, token string) error
// RefreshToken validates refreshToken and returns a response carrying a
// newly generated access token.
RefreshToken(ctx context.Context, refreshToken string) (*model.LoginResponse, error)
// GenerateCaptcha creates a new captcha challenge and returns its ID, its
// image and its lifetime.
GenerateCaptcha(ctx context.Context) (*model.CaptchaResponse, error)
// ValidateCaptcha compares captchaText with the answer stored for captchaID
// and returns an error when they do not match or no answer is found.
ValidateCaptcha(ctx context.Context, captchaID, captchaText string) error
// GetUserInfo returns the profile, roles and permissions of the user with
// the given ID.
GetUserInfo(ctx context.Context, userID uint) (*model.UserInfo, error)
}
// authService is the AuthService implementation returned by NewAuthService.
// All of its collaborators are injected through the constructor.
type authService struct {
// userRepo looks users up and records failed-password counters and
// last-login metadata.
userRepo repository.UserRepository
// captchaRepo stores captcha answers by ID and retrieves them for validation.
captchaRepo repository.CaptchaRepository
// jwtManager generates and validates the tokens handed out to clients.
jwtManager jwt.JWTManager
}
// NewAuthService wires the given repositories and JWT manager into an
// AuthService. The arguments are stored as-is; no validation is performed.
func NewAuthService(userRepo repository.UserRepository, captchaRepo repository.CaptchaRepository, jwtManager jwt.JWTManager) AuthService {
	svc := new(authService)
	svc.userRepo = userRepo
	svc.captchaRepo = captchaRepo
	svc.jwtManager = jwtManager
	return svc
}
// Login authenticates a user and issues an access/refresh token pair.
//
// The flow is: validate the captcha, load the user by username, reject
// accounts that may not log in, verify the bcrypt password hash, reset the
// failed-password counter, record the login IP, generate both tokens and
// finally load the user's roles to build the response.
//
// clientIP is the address recorded as the user's last login IP. When it is
// empty, the value carried in req.ClientIP is used instead.
//
// On any failure the returned response is nil and the error describes why the
// login was rejected.
func (s *authService) Login(ctx context.Context, req *model.LoginRequest, clientIP string) (*model.LoginResponse, error) {
	// maxPasswordErrors is the number of wrong passwords after which the
	// caller is told the account has been locked.
	const maxPasswordErrors = 5

	// 1. Validate the captcha before touching any user data.
	if err := s.ValidateCaptcha(ctx, req.CaptchaID, req.Captcha); err != nil {
		return nil, fmt.Errorf("验证码错误: %w", err)
	}

	// 2. Look the user up by username.
	user, err := s.userRepo.GetByUsername(ctx, req.Username)
	if err != nil {
		return nil, fmt.Errorf("用户不存在")
	}

	// 3. Reject accounts that are not allowed to log in, distinguishing a
	// locked account from any other non-login state.
	if !user.CanLogin() {
		if user.IsLocked() {
			return nil, fmt.Errorf("账户已被锁定,请30分钟后再试")
		}
		return nil, fmt.Errorf("账户已被禁用")
	}

	// 4. Verify the password against the stored bcrypt hash.
	if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
		// Wrong password: record the failed attempt.
		if err := s.userRepo.IncrementPasswordError(ctx, user.ID); err != nil {
			return nil, fmt.Errorf("系统错误,请稍后重试")
		}
		// user was loaded before the increment, so PasswordErrorCount is the
		// count *prior to* this attempt; maxPasswordErrors-1 therefore means
		// this was the final allowed attempt.
		if user.PasswordErrorCount >= maxPasswordErrors-1 {
			return nil, fmt.Errorf("密码错误次数过多,账户已被锁定30分钟")
		}
		remaining := maxPasswordErrors - user.PasswordErrorCount - 1
		return nil, fmt.Errorf("密码错误,还可尝试%d次", remaining)
	}

	// 5. Password is correct: clear the failed-attempt counter.
	if err := s.userRepo.ResetPasswordError(ctx, user.ID); err != nil {
		return nil, fmt.Errorf("系统错误,请稍后重试")
	}

	// Record the last login IP. The explicit clientIP argument takes
	// precedence; the request field is only a fallback for callers that leave
	// the argument empty.
	loginIP := clientIP
	if loginIP == "" {
		loginIP = req.ClientIP
	}
	if err := s.userRepo.UpdateLastLogin(ctx, user.ID, loginIP); err != nil {
		// Failing to store login metadata must not block the login itself.
		fmt.Printf("更新登录信息失败: %v\n", err)
	}

	// 6. Generate the access token (24h) and the refresh token (7 days).
	claims := map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"email":    user.Email,
	}
	token, err := s.jwtManager.GenerateToken(claims, 24*time.Hour)
	if err != nil {
		return nil, fmt.Errorf("生成令牌失败: %w", err)
	}
	refreshToken, err := s.jwtManager.GenerateToken(claims, 7*24*time.Hour)
	if err != nil {
		return nil, fmt.Errorf("生成刷新令牌失败: %w", err)
	}

	// 7. Load the user together with roles for the response payload.
	userWithRoles, err := s.userRepo.GetUserWithRoles(ctx, user.ID)
	if err != nil {
		return nil, fmt.Errorf("获取用户信息失败: %w", err)
	}

	// 8. Build the response.
	userInfo := s.buildUserInfo(userWithRoles)
	return &model.LoginResponse{
		Token:        token,
		TokenType:    "Bearer",
		ExpiresIn:    24 * 60 * 60, // 24 hours, in seconds
		RefreshToken: refreshToken,
		User:         *userInfo,
	}, nil
}
// Logout ends the session identified by token.
//
// No server-side invalidation is performed yet: both arguments are ignored
// and the method always returns nil.
func (s *authService) Logout(ctx context.Context, token string) error {
// A token blacklist could be implemented here.
// For now this simply reports success.
return nil
}
// RefreshToken validates refreshToken and issues a new 24-hour access token
// for the user named in its "user_id" claim.
//
// The user is reloaded with roles and must still be allowed to log in. The
// refresh token itself is not rotated: the response carries the same
// refreshToken that was passed in.
//
// An error is returned when the token fails validation, when the "user_id"
// claim is missing or is not a non-negative whole number that fits in uint,
// when the user cannot be loaded, when the user may not log in, or when the
// new token cannot be generated.
func (s *authService) RefreshToken(ctx context.Context, refreshToken string) (*model.LoginResponse, error) {
	// Validate the refresh token and obtain its claims.
	claims, err := s.jwtManager.ValidateToken(refreshToken)
	if err != nil {
		return nil, fmt.Errorf("刷新令牌无效: %w", err)
	}

	// The claim is read as float64 (presumably because claims are decoded
	// from JSON -- confirm against the JWT manager).
	rawID, ok := claims["user_id"].(float64)
	if !ok {
		return nil, fmt.Errorf("令牌格式错误")
	}
	// Converting an out-of-range or NaN float64 to uint is
	// implementation-defined, so the range is checked first; the round-trip
	// comparison then rejects fractional values.
	const maxUserID = float64(^uint(0))
	if !(rawID >= 0 && rawID < maxUserID) || float64(uint(rawID)) != rawID {
		return nil, fmt.Errorf("令牌格式错误")
	}

	user, err := s.userRepo.GetUserWithRoles(ctx, uint(rawID))
	if err != nil {
		return nil, fmt.Errorf("用户不存在: %w", err)
	}

	// The account may have been locked or disabled since the token was issued.
	if !user.CanLogin() {
		return nil, fmt.Errorf("用户状态异常")
	}

	// Generate a fresh access token from the current user record.
	newClaims := map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"email":    user.Email,
	}
	newToken, err := s.jwtManager.GenerateToken(newClaims, 24*time.Hour)
	if err != nil {
		return nil, fmt.Errorf("生成新令牌失败: %w", err)
	}

	// Build the response.
	userInfo := s.buildUserInfo(user)
	return &model.LoginResponse{
		Token:        newToken,
		TokenType:    "Bearer",
		ExpiresIn:    24 * 60 * 60, // 24 hours, in seconds
		RefreshToken: refreshToken, // the original refresh token is kept
		User:         *userInfo,
	}, nil
}
// GenerateCaptcha creates a new captcha challenge.
//
// It generates a 4-character answer, stores it under a random opaque ID with
// a 5-minute expiry and returns the ID together with the rendered image.
//
// The ID is 16 bytes from crypto/rand, URL-safe base64 encoded without
// padding (22 characters). It deliberately carries no information about the
// answer: the ID is handed to the client, so deriving it from the captcha
// text would let the client decode the solution, and would make two
// challenges with the same text share one ID.
//
// An error is returned when no random ID can be generated or when the captcha
// cannot be stored.
func (s *authService) GenerateCaptcha(ctx context.Context) (*model.CaptchaResponse, error) {
	const (
		captchaLength     = 4
		captchaIDBytes    = 16
		captchaTTLSeconds = 5 * 60 // 5 minutes
	)

	captchaText := s.generateRandomCaptcha(captchaLength)

	idBytes := make([]byte, captchaIDBytes)
	if _, err := rand.Read(idBytes); err != nil {
		return nil, fmt.Errorf("生成验证码ID失败: %w", err)
	}
	captchaID := base64.RawURLEncoding.EncodeToString(idBytes)

	// Persist the answer so ValidateCaptcha can look it up by ID.
	expiresAt := time.Now().Add(captchaTTLSeconds * time.Second)
	if err := s.captchaRepo.Create(ctx, captchaID, captchaText, expiresAt); err != nil {
		return nil, fmt.Errorf("保存验证码失败: %w", err)
	}

	return &model.CaptchaResponse{
		CaptchaID:    captchaID,
		CaptchaImage: s.generateCaptchaImage(captchaText),
		ExpiresIn:    captchaTTLSeconds, // in seconds
	}, nil
}
// ValidateCaptcha checks captchaText against the answer stored for captchaID.
//
// The comparison ignores letter case. An error is returned when the stored
// answer cannot be retrieved or when the supplied text does not match it.
func (s *authService) ValidateCaptcha(ctx context.Context, captchaID, captchaText string) error {
	expected, lookupErr := s.captchaRepo.Get(ctx, captchaID)
	switch {
	case lookupErr != nil:
		// Any lookup failure is reported as "expired or missing".
		return fmt.Errorf("验证码已过期或不存在")
	case !strings.EqualFold(expected, captchaText):
		return fmt.Errorf("验证码错误")
	default:
		return nil
	}
}
// GetUserInfo loads the user with the given ID together with its roles and
// converts it to the UserInfo view. A lookup failure is wrapped and returned.
func (s *authService) GetUserInfo(ctx context.Context, userID uint) (*model.UserInfo, error) {
	record, err := s.userRepo.GetUserWithRoles(ctx, userID)
	if err == nil {
		return s.buildUserInfo(record), nil
	}
	return nil, fmt.Errorf("用户不存在: %w", err)
}
// buildUserInfo converts a user record into the UserInfo view, copying each
// role and, within each role, each permission. The role and permission slices
// in the result are always non-nil, even when empty.
func (s *authService) buildUserInfo(user *model.User) *model.UserInfo {
	roleInfos := make([]model.RoleInfo, len(user.Roles))
	for ri, r := range user.Roles {
		// Copy this role's permissions into their view type.
		permInfos := make([]model.PermissionInfo, len(r.Permissions))
		for pi, p := range r.Permissions {
			permInfos[pi] = model.PermissionInfo{
				ID:          p.ID,
				Name:        p.Name,
				Description: p.Description,
				Resource:    p.Resource,
				Action:      p.Action,
			}
		}

		roleInfos[ri] = model.RoleInfo{
			ID:          r.ID,
			Name:        r.Name,
			Code:        r.Code,
			Description: r.Description,
			Permissions: permInfos,
		}
	}

	info := &model.UserInfo{
		ID:          user.ID,
		Username:    user.Username,
		Email:       user.Email,
		Phone:       user.Phone,
		Status:      user.Status,
		CreatedAt:   user.CreatedAt,
		LastLoginAt: user.LastLoginAt,
		LastLoginIP: user.LastLoginIP,
		Roles:       roleInfos,
	}
	return info
}
// generateRandomCaptcha returns a random string of the given length drawn
// from the uppercase letters A-Z and the digits 0-9.
//
// Characters are chosen with crypto/rand so that captcha answers cannot be
// predicted from earlier ones. Random bytes that would introduce modulo bias
// are discarded (rejection sampling), so every character is equally likely.
// If the system random source fails, the remaining characters are drawn from
// math/rand so the function can keep its error-free signature.
//
// A length of zero or less yields the empty string.
func (s *authService) generateRandomCaptcha(length int) string {
	const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// cutoff is the largest multiple of len(chars) that fits in a byte; bytes
	// at or above it are rejected to keep the distribution uniform.
	const cutoff = 256 - 256%len(chars)

	if length <= 0 {
		return ""
	}

	result := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(result) < length {
		if _, err := rand.Read(buf); err != nil {
			// Best-effort fallback: finish with the non-cryptographic generator.
			for len(result) < length {
				result = append(result, chars[mathrand.Intn(len(chars))])
			}
			break
		}
		for _, b := range buf {
			if int(b) >= cutoff {
				continue
			}
			result = append(result, chars[int(b)%len(chars)])
			if len(result) == length {
				break
			}
		}
	}
	return string(result)
}
// generateCaptchaImage renders text as a 120x40 SVG captcha and returns it as
// a base64 "data:image/svg+xml" URI.
//
// The image consists of a light background, a noise filter layer, 3 random
// lines, 20 random dots, 5 random circles and one <text> element per
// character, each with a random vertical offset, rotation, color and size.
func (s *authService) generateCaptchaImage(text string) string {
	const (
		canvasW = 120
		canvasH = 40
	)

	// randomColor renders one random value as a "#rrggbb" color.
	randomColor := func() string {
		return fmt.Sprintf("#%06x", mathrand.Intn(0xFFFFFF))
	}

	// Random interference lines.
	var strokes []byte
	for n := 0; n < 3; n++ {
		startX := mathrand.Intn(canvasW)
		startY := mathrand.Intn(canvasH)
		endX := mathrand.Intn(canvasW)
		endY := mathrand.Intn(canvasH)
		strokes = append(strokes, fmt.Sprintf(`<line x1="%d" y1="%d" x2="%d" y2="%d" stroke="%s" stroke-width="1" opacity="0.3"/>`, startX, startY, endX, endY, randomColor())...)
	}

	// Random interference dots.
	var specks []byte
	for n := 0; n < 20; n++ {
		dotX := mathrand.Intn(canvasW)
		dotY := mathrand.Intn(canvasH)
		specks = append(specks, fmt.Sprintf(`<circle cx="%d" cy="%d" r="1" fill="%s" opacity="0.4"/>`, dotX, dotY, randomColor())...)
	}

	// Random interference circles with a radius of 2 to 4.
	var blobs []byte
	for n := 0; n < 5; n++ {
		centerX := mathrand.Intn(canvasW)
		centerY := mathrand.Intn(canvasH)
		radius := 2 + mathrand.Intn(3)
		blobs = append(blobs, fmt.Sprintf(`<circle cx="%d" cy="%d" r="%d" fill="%s" opacity="0.2"/>`, centerX, centerY, radius, randomColor())...)
	}

	// One <text> element per character, spread evenly across the width. The
	// horizontal slot is derived from the byte length of text and the byte
	// offset of each character.
	var glyphs []byte
	slot := canvasW / (len(text) + 1)
	for offset, glyph := range text {
		posX := slot * (offset + 1)
		posY := canvasH/2 + mathrand.Intn(6) - 3
		angle := mathrand.Intn(20) - 10
		ink := fmt.Sprintf("#%06x", mathrand.Intn(0x666666)+0x333333)
		size := 18 + mathrand.Intn(8)
		glyphs = append(glyphs, fmt.Sprintf(`<text x="%d" y="%d" font-family="Arial, sans-serif" font-size="%d" font-weight="bold"
fill="%s" text-anchor="middle" dominant-baseline="middle" transform="rotate(%d %d %d)">%c</text>`,
			posX, posY, size, ink, angle, posX, posY, glyph)...)
	}

	// Assemble the SVG document: background, noise layer, then the layers above.
	document := fmt.Sprintf(`<svg width="%d" height="%d" xmlns="http://www.w3.org/2000/svg">
<defs>
<filter id="noise" x="0%%" y="0%%" width="100%%" height="100%%">
<feTurbulence type="fractalNoise" baseFrequency="0.8" numOctaves="4" stitchTiles="stitch"/>
<feColorMatrix type="matrix" values="0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0.1 0"/>
</filter>
</defs>
<rect width="%d" height="%d" fill="#f8f9fa"/>
<rect width="%d" height="%d" fill="url(#noise)" opacity="0.1"/>
%s
%s
%s
%s
</svg>`, canvasW, canvasH, canvasW, canvasH, canvasW, canvasH, string(strokes), string(specks), string(blobs), string(glyphs))

	// Encode the SVG as a base64 data URI.
	return "data:image/svg+xml;base64," + base64.StdEncoding.EncodeToString([]byte(document))
}
// GetClientIP returns a best-effort client IP address for r.
//
// Sources are consulted in this order, and the first non-blank value wins:
//  1. the X-Real-IP header
//  2. the first entry of the X-Forwarded-For header
//  3. the X-Forwarded header
//  4. the host part of r.RemoteAddr
//
// RemoteAddr is split with net.SplitHostPort so bracketed IPv6 addresses such
// as "[::1]:8080" yield "::1". If it carries no port it is returned as-is,
// minus any surrounding brackets. When no source yields an address,
// "127.0.0.1" is returned.
//
// NOTE(review): the forwarding headers are taken at face value. Unless a
// trusted proxy overwrites them, a client can send arbitrary values, so the
// result should not be relied on for security decisions.
func GetClientIP(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	if forwarded := r.Header.Get("X-Forwarded-For"); forwarded != "" {
		// X-Forwarded-For may hold a comma-separated chain; use the first entry.
		if idx := strings.IndexByte(forwarded, ','); idx != -1 {
			forwarded = forwarded[:idx]
		}
		if ip := strings.TrimSpace(forwarded); ip != "" {
			return ip
		}
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Forwarded")); ip != "" {
		return ip
	}

	// Fall back to the address of the connection itself.
	if addr := strings.TrimSpace(r.RemoteAddr); addr != "" {
		host, _, err := net.SplitHostPort(addr)
		if err != nil {
			// No port present (or otherwise unparsable): use the raw value.
			host = strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
		}
		if host != "" {
			return host
		}
	}

	return "127.0.0.1"
}