You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
326 lines
9.1 KiB
326 lines
9.1 KiB
package service |
|
|
|
import ( |
|
"context" |
|
"crypto/rand" |
|
"encoding/base64" |
|
"fmt" |
|
mathrand "math/rand" |
|
"net/http" |
|
"strings" |
|
"time" |
|
|
|
"gofaster/internal/auth/model" |
|
"gofaster/internal/auth/repository" |
|
"gofaster/internal/shared/jwt" |
|
|
|
"golang.org/x/crypto/bcrypt" |
|
) |
|
|
|
type AuthService interface { |
|
Login(ctx context.Context, req *model.LoginRequest, clientIP string) (*model.LoginResponse, error) |
|
Logout(ctx context.Context, token string) error |
|
RefreshToken(ctx context.Context, refreshToken string) (*model.LoginResponse, error) |
|
GenerateCaptcha(ctx context.Context) (*model.CaptchaResponse, error) |
|
ValidateCaptcha(ctx context.Context, captchaID, captchaText string) error |
|
GetUserInfo(ctx context.Context, userID uint) (*model.UserInfo, error) |
|
} |
|
|
|
type authService struct { |
|
userRepo repository.UserRepository |
|
captchaRepo repository.CaptchaRepository |
|
jwtManager jwt.JWTManager |
|
} |
|
|
|
func NewAuthService(userRepo repository.UserRepository, captchaRepo repository.CaptchaRepository, jwtManager jwt.JWTManager) AuthService { |
|
return &authService{ |
|
userRepo: userRepo, |
|
captchaRepo: captchaRepo, |
|
jwtManager: jwtManager, |
|
} |
|
} |
|
|
|
// Login 用户登录 |
|
func (s *authService) Login(ctx context.Context, req *model.LoginRequest, clientIP string) (*model.LoginResponse, error) { |
|
// 1. 验证验证码 |
|
if err := s.ValidateCaptcha(ctx, req.CaptchaID, req.Captcha); err != nil { |
|
return nil, fmt.Errorf("验证码错误: %w", err) |
|
} |
|
|
|
// 2. 根据用户名查找用户 |
|
user, err := s.userRepo.GetByUsername(ctx, req.Username) |
|
if err != nil { |
|
return nil, fmt.Errorf("用户不存在") |
|
} |
|
|
|
// 3. 检查用户状态 |
|
if !user.CanLogin() { |
|
if user.IsLocked() { |
|
return nil, fmt.Errorf("账户已被锁定,请30分钟后再试") |
|
} |
|
return nil, fmt.Errorf("账户已被禁用") |
|
} |
|
|
|
// 4. 验证密码 |
|
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil { |
|
// 密码错误,增加错误次数 |
|
if err := s.userRepo.IncrementPasswordError(ctx, user.ID); err != nil { |
|
return nil, fmt.Errorf("系统错误,请稍后重试") |
|
} |
|
|
|
// 检查是否被锁定 |
|
if user.PasswordErrorCount >= 4 { // 已经是第5次错误 |
|
return nil, fmt.Errorf("密码错误次数过多,账户已被锁定30分钟") |
|
} |
|
|
|
remaining := 5 - user.PasswordErrorCount - 1 |
|
return nil, fmt.Errorf("密码错误,还可尝试%d次", remaining) |
|
} |
|
|
|
// 5. 密码正确,重置错误次数并更新登录信息 |
|
if err := s.userRepo.ResetPasswordError(ctx, user.ID); err != nil { |
|
return nil, fmt.Errorf("系统错误,请稍后重试") |
|
} |
|
|
|
if err := s.userRepo.UpdateLastLogin(ctx, user.ID, clientIP); err != nil { |
|
// 登录信息更新失败不影响登录流程 |
|
fmt.Printf("更新登录信息失败: %v\n", err) |
|
} |
|
|
|
// 6. 生成JWT令牌 |
|
claims := map[string]interface{}{ |
|
"user_id": user.ID, |
|
"username": user.Username, |
|
"email": user.Email, |
|
} |
|
|
|
token, err := s.jwtManager.GenerateToken(claims, 24*time.Hour) // 24小时有效期 |
|
if err != nil { |
|
return nil, fmt.Errorf("生成令牌失败: %w", err) |
|
} |
|
|
|
refreshToken, err := s.jwtManager.GenerateToken(claims, 7*24*time.Hour) // 7天有效期 |
|
if err != nil { |
|
return nil, fmt.Errorf("生成刷新令牌失败: %w", err) |
|
} |
|
|
|
// 7. 获取用户角色信息 |
|
userWithRoles, err := s.userRepo.GetUserWithRoles(ctx, user.ID) |
|
if err != nil { |
|
return nil, fmt.Errorf("获取用户信息失败: %w", err) |
|
} |
|
|
|
// 8. 构建响应 |
|
userInfo := s.buildUserInfo(userWithRoles) |
|
|
|
return &model.LoginResponse{ |
|
Token: token, |
|
TokenType: "Bearer", |
|
ExpiresIn: 24 * 60 * 60, // 24小时,单位秒 |
|
RefreshToken: refreshToken, |
|
User: *userInfo, |
|
}, nil |
|
} |
|
|
|
// Logout 用户登出 |
|
func (s *authService) Logout(ctx context.Context, token string) error { |
|
// 这里可以实现令牌黑名单机制 |
|
// 目前简单返回成功 |
|
return nil |
|
} |
|
|
|
// RefreshToken 刷新令牌 |
|
func (s *authService) RefreshToken(ctx context.Context, refreshToken string) (*model.LoginResponse, error) { |
|
// 验证刷新令牌 |
|
claims, err := s.jwtManager.ValidateToken(refreshToken) |
|
if err != nil { |
|
return nil, fmt.Errorf("刷新令牌无效: %w", err) |
|
} |
|
|
|
// 获取用户信息 |
|
userID, ok := claims["user_id"].(float64) |
|
if !ok { |
|
return nil, fmt.Errorf("令牌格式错误") |
|
} |
|
|
|
user, err := s.userRepo.GetUserWithRoles(ctx, uint(userID)) |
|
if err != nil { |
|
return nil, fmt.Errorf("用户不存在: %w", err) |
|
} |
|
|
|
// 检查用户状态 |
|
if !user.CanLogin() { |
|
return nil, fmt.Errorf("用户状态异常") |
|
} |
|
|
|
// 生成新的访问令牌 |
|
newClaims := map[string]interface{}{ |
|
"user_id": user.ID, |
|
"username": user.Username, |
|
"email": user.Email, |
|
} |
|
|
|
newToken, err := s.jwtManager.GenerateToken(newClaims, 24*time.Hour) |
|
if err != nil { |
|
return nil, fmt.Errorf("生成新令牌失败: %w", err) |
|
} |
|
|
|
// 构建响应 |
|
userInfo := s.buildUserInfo(user) |
|
|
|
return &model.LoginResponse{ |
|
Token: newToken, |
|
TokenType: "Bearer", |
|
ExpiresIn: 24 * 60 * 60, |
|
RefreshToken: refreshToken, // 保持原刷新令牌 |
|
User: *userInfo, |
|
}, nil |
|
} |
|
|
|
// GenerateCaptcha 生成验证码 |
|
func (s *authService) GenerateCaptcha(ctx context.Context) (*model.CaptchaResponse, error) { |
|
// 生成随机验证码ID |
|
captchaID, err := s.generateRandomID() |
|
if err != nil { |
|
return nil, fmt.Errorf("生成验证码ID失败: %w", err) |
|
} |
|
|
|
// 生成4位随机验证码 |
|
captchaText := s.generateRandomCaptcha(4) |
|
|
|
// 设置5分钟过期时间 |
|
expiresAt := time.Now().Add(5 * time.Minute) |
|
|
|
// 保存到数据库 |
|
if err := s.captchaRepo.Create(ctx, captchaID, captchaText, expiresAt); err != nil { |
|
return nil, fmt.Errorf("保存验证码失败: %w", err) |
|
} |
|
|
|
// 生成验证码图片(这里简化处理,实际应该生成图片) |
|
captchaImage := s.generateCaptchaImage(captchaText) |
|
|
|
return &model.CaptchaResponse{ |
|
CaptchaID: captchaID, |
|
CaptchaImage: captchaImage, |
|
ExpiresIn: 5 * 60, // 5分钟,单位秒 |
|
}, nil |
|
} |
|
|
|
// ValidateCaptcha reports whether captchaText matches the code stored under
// captchaID, ignoring letter case.
//
// Any error from the repository lookup is reported as "expired or missing";
// the underlying error is not propagated.
//
// NOTE(review): this function does not delete the captcha after checking it.
// Unless captchaRepo.Get does so itself (not visible here), one solved
// id/answer pair can be replayed until it expires, for example to
// brute-force passwords through Login.
func (s *authService) ValidateCaptcha(ctx context.Context, captchaID, captchaText string) error {
	expected, lookupErr := s.captchaRepo.Get(ctx, captchaID)

	switch {
	case lookupErr != nil:
		return fmt.Errorf("验证码已过期或不存在")
	case !strings.EqualFold(expected, captchaText):
		return fmt.Errorf("验证码错误")
	default:
		return nil
	}
}
|
|
|
// GetUserInfo loads the user identified by userID together with their roles
// and returns the public profile view. Any repository error is wrapped under
// a "user does not exist" message.
func (s *authService) GetUserInfo(ctx context.Context, userID uint) (*model.UserInfo, error) {
	u, err := s.userRepo.GetUserWithRoles(ctx, userID)
	if err == nil {
		return s.buildUserInfo(u), nil
	}
	return nil, fmt.Errorf("用户不存在: %w", err)
}
|
|
|
// buildUserInfo 构建用户信息 |
|
func (s *authService) buildUserInfo(user *model.User) *model.UserInfo { |
|
roles := make([]model.RoleInfo, 0, len(user.Roles)) |
|
for _, role := range user.Roles { |
|
roles = append(roles, model.RoleInfo{ |
|
ID: role.ID, |
|
Name: role.Name, |
|
Code: role.Code, |
|
}) |
|
} |
|
|
|
return &model.UserInfo{ |
|
ID: user.ID, |
|
Username: user.Username, |
|
Email: user.Email, |
|
Phone: user.Phone, |
|
Status: user.Status, |
|
LastLoginAt: user.LastLoginAt, |
|
LastLoginIP: user.LastLoginIP, |
|
Roles: roles, |
|
} |
|
} |
|
|
|
// generateRandomID 生成随机ID |
|
func (s *authService) generateRandomID() (string, error) { |
|
b := make([]byte, 16) |
|
if _, err := rand.Read(b); err != nil { |
|
return "", err |
|
} |
|
return base64.URLEncoding.EncodeToString(b), nil |
|
} |
|
|
|
// generateRandomCaptcha returns length characters drawn uniformly from
// [A-Z0-9]. It returns "" when length <= 0; previously a negative length
// panicked in make.
//
// Characters are drawn through math/rand's helpers, but over a crypto/rand
// backed Source rather than the global math/rand source. The global source
// is predictable, and before Go 1.20 it was unseeded and produced the same
// captcha sequence on every process start.
func (s *authService) generateRandomCaptcha(length int) string {
	const chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

	if length <= 0 {
		return ""
	}

	// Rand.Intn uses rejection sampling, so the draw is unbiased for len(chars) = 36.
	rng := mathrand.New(captchaCryptoSource{})
	result := make([]byte, length)
	for i := range result {
		result[i] = chars[rng.Intn(len(chars))]
	}
	return string(result)
}

// captchaCryptoSource is a math/rand Source (and Source64) whose bits come
// from crypto/rand. Seed is a no-op because the OS CSPRNG is not reseedable.
type captchaCryptoSource struct{}

// Uint64 returns 64 bits read from crypto/rand. It panics if the OS random
// source fails: there is no safe fallback, and Go 1.24+ crypto/rand.Read
// itself treats such a failure as fatal.
func (captchaCryptoSource) Uint64() uint64 {
	var b [8]byte
	if _, err := rand.Read(b[:]); err != nil {
		panic(fmt.Sprintf("captcha: crypto/rand failed: %v", err))
	}
	var v uint64
	for _, x := range b {
		v = v<<8 | uint64(x)
	}
	return v
}

// Int63 returns a non-negative 63-bit value, as math/rand.Source requires.
func (c captchaCryptoSource) Int63() int64 {
	return int64(c.Uint64() >> 1)
}

// Seed satisfies math/rand.Source and intentionally does nothing.
func (captchaCryptoSource) Seed(int64) {}
|
|
|
// generateCaptchaImage 生成验证码图片(简化版本,返回Base64编码) |
|
func (s *authService) generateCaptchaImage(text string) string { |
|
// 创建一个简单的文本图片作为验证码 |
|
// 使用HTML5 Canvas风格的文本渲染 |
|
width := 120 |
|
height := 40 |
|
|
|
// 创建一个简单的SVG图片,包含验证码文本 |
|
svg := fmt.Sprintf(`<svg width="%d" height="%d" xmlns="http://www.w3.org/2000/svg"> |
|
<rect width="%d" height="%d" fill="#f0f0f0"/> |
|
<text x="%d" y="%d" font-family="Arial, sans-serif" font-size="24" font-weight="bold" |
|
fill="#333" text-anchor="middle" dominant-baseline="middle">%s</text> |
|
</svg>`, width, height, width, height, width/2, height/2, text) |
|
|
|
// 将SVG转换为Base64 |
|
return fmt.Sprintf("data:image/svg+xml;base64,%s", base64.StdEncoding.EncodeToString([]byte(svg))) |
|
} |
|
|
|
// GetClientIP returns a best-guess client IP address for r.
//
// Sources are consulted in order:
//  1. X-Real-IP
//  2. the first entry of X-Forwarded-For
//  3. X-Forwarded
//  4. r.RemoteAddr with its port removed
//
// Header values are whitespace-trimmed, and a header that yields an empty
// value is skipped. If nothing is available, "127.0.0.1" is returned.
//
// RemoteAddr is normally "host:port", and IPv6 hosts are bracketed
// ("[::1]:8080"). The previous split at the first ':' returned "[" for IPv6
// peers. Bracketed hosts are now unwrapped, and a value with several
// unbracketed colons (a bare IPv6 literal) is returned unchanged.
//
// NOTE(review): the forwarding headers are client-controlled and are trusted
// unconditionally. Unless every request passes through a proxy that
// overwrites them, clients can spoof the recorded IP. Consider honouring them
// only for requests arriving from known proxy addresses.
func GetClientIP(r *http.Request) string {
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		// X-Forwarded-For is a comma-separated chain; the first hop is the client.
		first := xff
		if idx := strings.IndexByte(xff, ','); idx != -1 {
			first = xff[:idx]
		}
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}

	if ip := strings.TrimSpace(r.Header.Get("X-Forwarded")); ip != "" {
		return ip
	}

	if addr := r.RemoteAddr; addr != "" {
		if strings.HasPrefix(addr, "[") {
			if end := strings.IndexByte(addr, ']'); end != -1 {
				return addr[1:end]
			}
			return addr
		}
		// Exactly one colon means "host:port"; zero or several means there is no
		// port to strip (a plain IPv4 address or a bare IPv6 literal).
		if idx := strings.IndexByte(addr, ':'); idx != -1 && idx == strings.LastIndexByte(addr, ':') {
			return addr[:idx]
		}
		return addr
	}

	return "127.0.0.1"
}
|
|
|