You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

422 lines
9.7 KiB

package service
import (
"context"
"errors"
"fmt"
"math/rand"
"time"
"gofaster/internal/auth/model"
"gofaster/internal/auth/repository"
"golang.org/x/crypto/bcrypt"
)
// PasswordService enforces the password policy and performs password changes
// and resets using the repositories it was constructed with.
type PasswordService struct {
	// userRepo loads and persists user records (password hash, change time,
	// force-change flag).
	userRepo repository.UserRepository
	// passwordPolicyRepo supplies the active policy and stores policy updates.
	passwordPolicyRepo repository.PasswordPolicyRepository
	// passwordHistoryRepo stores previous password hashes for the reuse check.
	passwordHistoryRepo repository.PasswordHistoryRepository
	// passwordResetRepo is set by the constructor.
	// NOTE(review): no method visible in this file reads it — confirm whether
	// it is still needed.
	passwordResetRepo repository.PasswordResetRepository
}
// NewPasswordService builds a PasswordService backed by the given
// repositories. The dependencies are stored as-is; none are validated.
func NewPasswordService(
	userRepo repository.UserRepository,
	passwordPolicyRepo repository.PasswordPolicyRepository,
	passwordHistoryRepo repository.PasswordHistoryRepository,
	passwordResetRepo repository.PasswordResetRepository,
) *PasswordService {
	svc := new(PasswordService)
	svc.userRepo = userRepo
	svc.passwordPolicyRepo = passwordPolicyRepo
	svc.passwordHistoryRepo = passwordHistoryRepo
	svc.passwordResetRepo = passwordResetRepo
	return svc
}
// GetPasswordPolicy returns the currently active password policy as reported
// by the policy repository. Repository errors are returned unchanged.
func (ps *PasswordService) GetPasswordPolicy() (*model.PasswordPolicy, error) {
	active, err := ps.passwordPolicyRepo.GetActivePolicy()
	return active, err
}
// ValidatePassword checks newPassword against the active password policy for
// the given user.
//
// Checks run in this order and stop at the first failure: minimum length (in
// characters), bcrypt's 72-byte input limit, number of distinct character
// classes, minimum strength level, and reuse of a recent password.
//
// A policy violation is reported through the returned result (IsValid false,
// one message in Errors). A non-nil error is returned only when the policy
// cannot be loaded. Strength and Level are always populated.
//
// ctx is currently unused; it is kept for interface compatibility.
func (ps *PasswordService) ValidatePassword(ctx context.Context, userID uint, newPassword string) (*model.PasswordValidationResult, error) {
	// bcrypt only uses the first 72 bytes of its input. Newer versions of
	// golang.org/x/crypto/bcrypt refuse longer input. Either way, a longer
	// password cannot be stored as typed.
	const bcryptMaxBytes = 72

	policy, err := ps.GetPasswordPolicy()
	if err != nil {
		return nil, err
	}

	strength := ps.calculatePasswordStrength(newPassword)
	level := ps.calculatePasswordLevel(newPassword)

	// invalid builds the failure result carrying a single error message.
	invalid := func(msg string) (*model.PasswordValidationResult, error) {
		return &model.PasswordValidationResult{
			IsValid:  false,
			Strength: strength,
			Level:    level,
			Errors:   []string{msg},
		}, nil
	}

	// Length policies are about characters, so count runes, not bytes.
	// Otherwise a short non-ASCII password (e.g. 3 bytes per CJK character)
	// would satisfy MinLength.
	if len([]rune(newPassword)) < policy.MinLength {
		return invalid(fmt.Sprintf("密码长度不能少于%d位", policy.MinLength))
	}
	if len(newPassword) > bcryptMaxBytes {
		return invalid(fmt.Sprintf("密码长度不能超过%d个字节", bcryptMaxBytes))
	}

	// Character classes: ASCII uppercase, ASCII lowercase, ASCII digit, and
	// everything else (counted as "special", including non-ASCII runes).
	var hasUpper, hasLower, hasDigit, hasSpecial bool
	for _, r := range newPassword {
		switch {
		case r >= 'A' && r <= 'Z':
			hasUpper = true
		case r >= 'a' && r <= 'z':
			hasLower = true
		case r >= '0' && r <= '9':
			hasDigit = true
		default:
			hasSpecial = true
		}
	}
	charTypes := 0
	for _, present := range [...]bool{hasUpper, hasLower, hasDigit, hasSpecial} {
		if present {
			charTypes++
		}
	}
	if charTypes < policy.MinCharTypes {
		return invalid(fmt.Sprintf("密码必须包含至少%d种字符类型", policy.MinCharTypes))
	}

	if level < policy.MinRequiredLevel {
		return invalid(fmt.Sprintf("密码强度等级不能低于%d级", policy.MinRequiredLevel))
	}

	// Runs last because every history entry costs a bcrypt comparison. Any
	// error it returns (including repository errors) is reported as a
	// validation message, as before.
	if err := ps.CheckPasswordReuse(userID, newPassword); err != nil {
		return invalid(err.Error())
	}

	return &model.PasswordValidationResult{
		IsValid:  true,
		Strength: strength,
		Level:    level,
		Errors:   []string{},
	}, nil
}
// calculatePasswordStrength scores password on a 0-100 scale.
//
// Score components:
//   - length in characters (runes): 20 for >= 8, 15 for >= 6, otherwise 10;
//   - character classes: +15 each for ASCII uppercase, ASCII lowercase and
//     ASCII digits present, +20 if any other character is present;
//   - variety bonus: +20 for 4 classes, +15 for 3, +10 for 2.
//
// The total is capped at 100.
func (ps *PasswordService) calculatePasswordStrength(password string) int {
	// Measure length in characters so non-ASCII passwords are scored by what
	// the user typed, not by their UTF-8 byte size.
	length := len([]rune(password))

	var score int
	switch {
	case length >= 8:
		score = 20
	case length >= 6:
		score = 15
	default:
		score = 10
	}

	var hasUpper, hasLower, hasDigit, hasSpecial bool
	for _, r := range password {
		switch {
		case r >= 'A' && r <= 'Z':
			hasUpper = true
		case r >= 'a' && r <= 'z':
			hasLower = true
		case r >= '0' && r <= '9':
			hasDigit = true
		default:
			hasSpecial = true
		}
	}

	classes := 0
	if hasUpper {
		classes++
		score += 15
	}
	if hasLower {
		classes++
		score += 15
	}
	if hasDigit {
		classes++
		score += 15
	}
	if hasSpecial {
		classes++
		score += 20
	}

	switch {
	case classes >= 4:
		score += 20
	case classes == 3:
		score += 15
	case classes == 2:
		score += 10
	}

	if score > 100 {
		score = 100
	}
	return score
}
// calculatePasswordLevel grades password on a 0-5 scale from its length in
// characters (runes) and the number of distinct character classes (ASCII
// uppercase, ASCII lowercase, ASCII digit, anything else) it contains:
//
//	fewer than 6 characters           -> 0
//	>= 8 characters and 4 classes     -> 5
//	>= 8 characters and 3 classes     -> 4
//	>= 6 characters and >= 3 classes  -> 3
//	>= 6 characters and 2 classes     -> 2
//	>= 6 characters and 1 class       -> 1
func (ps *PasswordService) calculatePasswordLevel(password string) int {
	length := len([]rune(password))
	if length < 6 {
		return 0
	}

	var hasUpper, hasLower, hasDigit, hasSpecial bool
	for _, r := range password {
		switch {
		case r >= 'A' && r <= 'Z':
			hasUpper = true
		case r >= 'a' && r <= 'z':
			hasLower = true
		case r >= '0' && r <= '9':
			hasDigit = true
		default:
			hasSpecial = true
		}
	}
	classes := 0
	for _, present := range [...]bool{hasUpper, hasLower, hasDigit, hasSpecial} {
		if present {
			classes++
		}
	}

	// length >= 6 holds from here on, and a non-empty password always has at
	// least one class, so the minimum result is 1.
	switch {
	case length >= 8 && classes >= 4:
		return 5
	case length >= 8 && classes >= 3:
		return 4
	case classes >= 3:
		return 3
	case classes == 2:
		return 2
	default:
		return 1
	}
}
// CheckPasswordReuse reports whether newPassword matches one of the user's
// recent passwords, as limited by the policy's PreventReuse count.
//
// It returns nil in two cases: reuse checking is disabled
// (PreventReuse <= 0), or no returned history entry matches. It returns a
// descriptive error when a match is found. Errors from loading the policy or
// the history are returned unchanged.
//
// Each history entry is compared with bcrypt, so every entry adds a
// deliberately slow hash comparison.
func (ps *PasswordService) CheckPasswordReuse(userID uint, newPassword string) error {
	policy, err := ps.GetPasswordPolicy()
	if err != nil {
		return err
	}
	if policy.PreventReuse <= 0 {
		return nil // reuse checking disabled
	}

	history, err := ps.passwordHistoryRepo.GetRecentPasswords(userID, policy.PreventReuse)
	if err != nil {
		return err
	}

	for _, record := range history {
		if ps.verifyPassword(newPassword, record.Password) {
			return fmt.Errorf("新密码不能与前%d次使用的密码重复", policy.PreventReuse)
		}
	}
	return nil
}
// ChangePassword replaces the password of userID. It first verifies
// currentPassword, then validates newPassword against the active policy via
// ValidatePassword, which includes the reuse check.
//
// On success:
//   - the new bcrypt hash is stored on the user;
//   - PasswordChangedAt is set to the current time;
//   - ForceChangePassword is cleared;
//   - the hash is recorded in the password history.
//
// Repository and hashing errors are returned unchanged. The error
// "当前密码不正确" means currentPassword did not match. The error
// "新密码不符合要求" means newPassword failed validation.
func (ps *PasswordService) ChangePassword(ctx context.Context, userID uint, currentPassword, newPassword string) error {
	user, err := ps.userRepo.GetByID(ctx, userID)
	if err != nil {
		return err
	}
	if !ps.verifyPassword(currentPassword, user.Password) {
		return errors.New("当前密码不正确")
	}

	// ValidatePassword already runs CheckPasswordReuse as its last step; a
	// second reuse check here would only repeat the bcrypt comparisons.
	validationResult, err := ps.ValidatePassword(ctx, userID, newPassword)
	if err != nil {
		return err
	}
	if !validationResult.IsValid {
		return errors.New("新密码不符合要求")
	}

	newPasswordHash, err := ps.hashPassword(newPassword)
	if err != nil {
		return err
	}

	// PasswordChangedAt must be the real change time:
	// CheckPasswordExpiration measures expiry from it. A zero time.Time
	// (year 1) would make the new password count as expired immediately.
	// The same timestamp is used for the history entry.
	now := time.Now()
	user.Password = string(newPasswordHash)
	user.PasswordChangedAt = &now
	user.ForceChangePassword = false
	if err := ps.userRepo.Update(ctx, user); err != nil {
		return err
	}

	// NOTE(review): the user update and the history insert are not atomic;
	// if Create fails, the new password is already active but missing from
	// the history.
	return ps.passwordHistoryRepo.Create(&model.PasswordHistory{
		UserID:    userID,
		Password:  string(newPasswordHash),
		ChangedAt: now,
	})
}
// ResetPassword replaces the user's password with a newly generated temporary
// password and sets ForceChangePassword so the user must change it.
//
// This method does not return the temporary password. Callers that need to
// deliver it to the user should call ResetPasswordWithTemp instead.
func (ps *PasswordService) ResetPassword(ctx context.Context, userID uint) error {
	_, err := ps.ResetPasswordWithTemp(ctx, userID)
	return err
}

// ResetPasswordWithTemp performs the same reset as ResetPassword and also
// returns the generated plaintext temporary password so it can be delivered
// to the user. On error the returned string is empty.
//
// PasswordChangedAt is set to the current time and ForceChangePassword to
// true. The temporary password is not recorded in the password history.
// The returned plaintext is sensitive; callers must not log it.
func (ps *PasswordService) ResetPasswordWithTemp(ctx context.Context, userID uint) (string, error) {
	// Load the user first so an unknown userID fails before paying for a
	// bcrypt hash.
	user, err := ps.userRepo.GetByID(ctx, userID)
	if err != nil {
		return "", err
	}

	tempPassword := ps.generateTempPassword()
	hashedPassword, err := ps.hashPassword(tempPassword)
	if err != nil {
		return "", err
	}

	now := time.Now()
	user.Password = string(hashedPassword)
	user.PasswordChangedAt = &now
	user.ForceChangePassword = true
	if err := ps.userRepo.Update(ctx, user); err != nil {
		return "", err
	}
	return tempPassword, nil
}
// CheckPasswordExpiration reports whether user's password has expired under
// the active policy.
//
// It returns false in two cases without checking the age:
//   - the user has no recorded change time (PasswordChangedAt is nil);
//   - the policy's ExpirationDays is <= 0, treated as "expiration disabled"
//     (mirroring how CheckPasswordReuse treats PreventReuse <= 0).
//
// Otherwise the password is expired once more than ExpirationDays * 24h have
// elapsed since PasswordChangedAt. An error loading the policy is returned
// with false.
func (ps *PasswordService) CheckPasswordExpiration(user *model.User) (bool, error) {
	if user.PasswordChangedAt == nil {
		return false, nil
	}

	policy, err := ps.GetPasswordPolicy()
	if err != nil {
		return false, err
	}

	// Without this guard a non-positive ExpirationDays puts the expiry at or
	// before the change time, reporting every password as already expired.
	// NOTE(review): confirm that 0 is meant to disable expiration.
	if policy.ExpirationDays <= 0 {
		return false, nil
	}

	maxAge := time.Duration(policy.ExpirationDays) * 24 * time.Hour
	return time.Since(*user.PasswordChangedAt) > maxAge, nil
}
// CheckForceChangePassword reports the user's ForceChangePassword flag. In
// this service the flag is set by ResetPassword and cleared by
// ChangePassword.
func (ps *PasswordService) CheckForceChangePassword(user *model.User) bool {
	mustChange := user.ForceChangePassword
	return mustChange
}
// generateTempPassword returns an 8-character password whose characters are
// picked with math/rand from ASCII uppercase letters, lowercase letters and
// digits.
//
// NOTE(review): math/rand is not a cryptographically secure generator.
// Before Go 1.20 its global source is also deterministically seeded unless
// rand.Seed is called. Temporary passwords produced here may therefore be
// predictable. Switch to crypto/rand; that requires replacing the file's
// math/rand import, which is used only by this function.
func (ps *PasswordService) generateTempPassword() string {
	const (
		alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
		size     = 8
	)
	out := make([]byte, 0, size)
	for len(out) < size {
		out = append(out, alphabet[rand.Intn(len(alphabet))])
	}
	return string(out)
}
// hashPassword returns the bcrypt hash of password at bcrypt.DefaultCost.
// Any error from bcrypt is passed through unchanged.
func (ps *PasswordService) hashPassword(password string) ([]byte, error) {
	cost := bcrypt.DefaultCost
	digest, err := bcrypt.GenerateFromPassword([]byte(password), cost)
	return digest, err
}
// verifyPassword reports whether password matches the bcrypt hash
// hashedPassword. Any comparison error (a mismatch or a malformed hash)
// yields false.
func (ps *PasswordService) verifyPassword(password, hashedPassword string) bool {
	err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
	return err == nil
}
// CheckPasswordStatus loads the user and summarizes the state of their
// password: whether a change is forced, whether it has expired under the
// active policy, and when it was last changed. Errors from loading the user
// or the policy are returned with a nil status.
func (ps *PasswordService) CheckPasswordStatus(ctx context.Context, userID uint) (*model.PasswordStatus, error) {
	user, err := ps.userRepo.GetByID(ctx, userID)
	if err != nil {
		return nil, err
	}

	expired, err := ps.CheckPasswordExpiration(user)
	if err != nil {
		return nil, err
	}

	status := &model.PasswordStatus{
		PasswordChangedAt:   user.PasswordChangedAt,
		PasswordExpired:     expired,
		ForceChangePassword: ps.CheckForceChangePassword(user),
	}
	return status, nil
}
// UpdatePasswordPolicy persists policy through the policy repository and
// returns the repository's error, if any.
func (ps *PasswordService) UpdatePasswordPolicy(policy *model.PasswordPolicy) error {
	if err := ps.passwordPolicyRepo.Update(policy); err != nil {
		return err
	}
	return nil
}