Files
NetworkAuth/config/validator.go

270 lines
7.0 KiB
Go
Raw Normal View History

2025-10-24 00:09:45 +08:00
package config
import (
"errors"
"fmt"
"net"
"os"
"path/filepath"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
2025-10-27 23:12:15 +08:00
// ============================================================================
// 公共函数
// ============================================================================
// ValidateConfig 验证配置
func ValidateConfig() (*AppConfig, error) {
2025-10-24 00:09:45 +08:00
var config AppConfig
// 解析配置到结构体
if err := viper.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("解析配置失败: %w", err)
}
// 验证配置
if err := validateConfig(&config); err != nil {
return nil, fmt.Errorf("配置验证失败: %w", err)
}
log.Info("配置验证通过")
return &config, nil
}
2025-10-27 23:12:15 +08:00
// ============================================================================
// 私有函数
// ============================================================================
2025-10-24 00:09:45 +08:00
// validateConfig 验证配置
func validateConfig(config *AppConfig) error {
// 验证服务器配置
if err := validateServerConfig(&config.Server); err != nil {
return fmt.Errorf("服务器配置错误: %w", err)
}
// 验证数据库配置
if err := validateDatabaseConfig(&config.Database); err != nil {
return fmt.Errorf("数据库配置错误: %w", err)
}
// 验证Redis配置
if err := validateRedisConfig(&config.Redis); err != nil {
return fmt.Errorf("redis配置错误: %w", err)
}
// 验证日志配置
if err := validateLogConfig(&config.Log); err != nil {
return fmt.Errorf("日志配置错误: %w", err)
}
// 验证安全配置
if err := validateSecurityConfig(&config.Security); err != nil {
return fmt.Errorf("安全配置错误: %w", err)
}
return nil
}
// validateServerConfig 验证服务器配置
func validateServerConfig(config *ServerConfig) error {
// 验证主机地址
if config.Host != "" {
if ip := net.ParseIP(config.Host); ip == nil && config.Host != "localhost" {
return fmt.Errorf("无效的主机地址: %s", config.Host)
}
}
// 验证端口
if config.Port < 1 || config.Port > 65535 {
return fmt.Errorf("无效的端口号: %d端口号必须在1-65535之间", config.Port)
}
return nil
}
// validateDatabaseConfig 验证数据库配置
func validateDatabaseConfig(config *DatabaseConfig) error {
// 验证数据库类型
validTypes := []string{"mysql", "sqlite"}
if !contains(validTypes, config.Type) {
return fmt.Errorf("不支持的数据库类型: %s支持的类型: %s", config.Type, strings.Join(validTypes, ", "))
}
// 根据类型验证具体配置
switch config.Type {
case "mysql":
return validateMySQLConfig(&config.MySQL)
case "sqlite":
return validateSQLiteConfig(&config.SQLite)
}
return nil
}
// validateMySQLConfig 验证MySQL配置
func validateMySQLConfig(config *MySQLConfig) error {
if config.Host == "" {
return errors.New("MySQL主机地址不能为空")
}
if config.Port < 1 || config.Port > 65535 {
return fmt.Errorf("无效的MySQL端口号: %d", config.Port)
}
if config.Username == "" {
return errors.New("MySQL用户名不能为空")
}
if config.Database == "" {
return errors.New("MySQL数据库名不能为空")
}
if config.MaxIdleConns < 0 {
return errors.New("MySQL最大空闲连接数不能为负数")
}
if config.MaxOpenConns < 0 {
return errors.New("MySQL最大打开连接数不能为负数")
}
return nil
}
// validateSQLiteConfig 验证SQLite配置
func validateSQLiteConfig(config *SQLiteConfig) error {
if config.Path == "" {
return errors.New("SQLite数据库路径不能为空")
}
// 检查目录是否存在,不存在则创建
dir := filepath.Dir(config.Path)
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("创建SQLite数据库目录失败: %w", err)
}
}
return nil
}
// validateRedisConfig 验证Redis配置
func validateRedisConfig(config *RedisConfig) error {
if config.Host == "" {
return errors.New("Redis主机地址不能为空")
}
if config.Port < 1 || config.Port > 65535 {
return fmt.Errorf("无效的Redis端口号: %d", config.Port)
}
if config.DB < 0 || config.DB > 15 {
return fmt.Errorf("无效的Redis数据库索引: %d必须在0-15之间", config.DB)
}
return nil
}
// validateLogConfig 验证日志配置
func validateLogConfig(config *LogConfig) error {
// 验证日志级别
validLevels := []string{"trace", "debug", "info", "warn", "error", "fatal", "panic"}
if !contains(validLevels, config.Level) {
return fmt.Errorf("无效的日志级别: %s支持的级别: %s", config.Level, strings.Join(validLevels, ", "))
}
// 检查日志文件目录(仅当日志文件路径不为空时)
if config.File != "" {
dir := filepath.Dir(config.File)
if _, err := os.Stat(dir); os.IsNotExist(err) {
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("创建日志目录失败: %w", err)
}
}
}
// 当日志文件路径为空时,不进行目录检查和创建
// 验证日志轮转配置
if config.MaxSize <= 0 {
return errors.New("日志文件最大大小必须大于0")
}
if config.MaxBackups < 0 {
return errors.New("日志备份文件数量不能为负数")
}
if config.MaxAge < 0 {
return errors.New("日志文件保留天数不能为负数")
}
return nil
}
// validateSecurityConfig 验证安全配置
func validateSecurityConfig(config *SecurityConfig) error {
if len(config.JWTSecret) < 16 {
return errors.New("JWT密钥长度不能少于16个字符")
}
if len(config.EncryptionKey) < 16 {
return errors.New("加密密钥长度不能少于16个字符")
}
2025-10-26 11:57:31 +08:00
if config.JWTRefresh < 1 || config.JWTRefresh > 23 {
2025-10-24 00:09:45 +08:00
return errors.New("JWT令牌刷新阈值必须在1-23小时之间")
}
// 检查是否使用默认值(生产环境警告)
if strings.Contains(config.JWTSecret, "default") {
log.Warn("检测到使用默认JWT密钥生产环境请更换为安全的密钥")
}
if strings.Contains(config.EncryptionKey, "default") {
log.Warn("检测到使用默认加密密钥,生产环境请更换为安全的密钥")
}
return nil
}
// contains 检查切片是否包含指定元素
func contains(slice []string, item string) bool {
for _, s := range slice {
if s == item {
return true
}
}
return false
}
// GetConfigValue 获取配置值,支持类型转换和默认值
func GetConfigValue[T any](key string, defaultValue T) T {
if !viper.IsSet(key) {
return defaultValue
}
value := viper.Get(key)
if result, ok := value.(T); ok {
return result
}
// 尝试类型转换
if converted, err := convertValue[T](value); err == nil {
return converted
}
return defaultValue
}
// convertValue 尝试类型转换
func convertValue[T any](value interface{}) (T, error) {
var zero T
str := fmt.Sprintf("%v", value)
switch any(zero).(type) {
case int:
if i, err := strconv.Atoi(str); err == nil {
return any(i).(T), nil
}
case string:
return any(str).(T), nil
case bool:
if b, err := strconv.ParseBool(str); err == nil {
return any(b).(T), nil
}
}
return zero, fmt.Errorf("无法转换类型")
}