mirror of
https://github.com/skyle1995/NetworkAuth.git
synced 2026-05-25 02:24:05 +08:00
Use the gin framework
This commit is contained in:
@@ -1,97 +0,0 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"networkDev/web"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// JsonResponse 通用JSON响应函数
|
||||
// 将 success 转换为 code:true -> 0, false -> 1,并输出 data
|
||||
func JsonResponse(w http.ResponseWriter, status int, success bool, message string, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
|
||||
// 将success转换为code格式:true -> 0, false -> 1
|
||||
code := 1
|
||||
if success {
|
||||
code = 0
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"code": code,
|
||||
"msg": message,
|
||||
"data": data,
|
||||
})
|
||||
}
|
||||
|
||||
// RenderTemplate 通用模板渲染函数
|
||||
// templateName: 模板文件名
|
||||
// data: 模板数据
|
||||
// w: HTTP响应写入器
|
||||
func RenderTemplate(w http.ResponseWriter, templateName string, data map[string]interface{}) error {
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
|
||||
tmpl, err := web.ParseTemplates()
|
||||
if err != nil {
|
||||
http.Error(w, "模板解析失败", http.StatusInternalServerError)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tmpl.ExecuteTemplate(w, templateName, data); err != nil {
|
||||
http.Error(w, "模板渲染失败", http.StatusInternalServerError)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetDefaultTemplateData 获取默认模板数据
|
||||
// 返回包含系统基础信息的数据映射
|
||||
func GetDefaultTemplateData() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"SystemName": "网络验证系统",
|
||||
"FooterText": "© 2025 凌动技术 保留所有权利",
|
||||
}
|
||||
}
|
||||
|
||||
// GetTemplateDataWithCSRF 获取包含CSRF令牌的模板数据
|
||||
// 合并默认数据和CSRF令牌,用于需要CSRF保护的页面
|
||||
func GetTemplateDataWithCSRF(r *http.Request, additionalData map[string]interface{}) map[string]interface{} {
|
||||
// 获取默认模板数据
|
||||
data := GetDefaultTemplateData()
|
||||
|
||||
// 添加CSRF令牌
|
||||
data["CSRFToken"] = GetCSRFTokenForTemplate(r)
|
||||
|
||||
// 合并额外数据
|
||||
for key, value := range additionalData {
|
||||
data[key] = value
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// GetClientIP 获取客户端IP地址
|
||||
// 优先从 X-Forwarded-For 和 X-Real-IP 头部获取,否则使用 RemoteAddr
|
||||
func GetClientIP(r *http.Request) string {
|
||||
// 检查 X-Forwarded-For 头部
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// X-Forwarded-For 可能包含多个IP,取第一个
|
||||
if idx := strings.Index(xff, ","); idx != -1 {
|
||||
return strings.TrimSpace(xff[:idx])
|
||||
}
|
||||
return strings.TrimSpace(xff)
|
||||
}
|
||||
|
||||
// 检查 X-Real-IP 头部
|
||||
if xri := r.Header.Get("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// 使用 RemoteAddr
|
||||
if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 {
|
||||
return r.RemoteAddr[:idx]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
@@ -121,15 +121,6 @@ func DecryptString(enc string) (string, error) {
|
||||
return string(plain), nil
|
||||
}
|
||||
|
||||
// ResetCrypto 重置加密管理器(用于配置更新后重新初始化)
|
||||
func ResetCrypto() {
|
||||
cryptoManager.mutex.Lock()
|
||||
defer cryptoManager.mutex.Unlock()
|
||||
cryptoManager.inited = false
|
||||
cryptoManager.key = nil
|
||||
cryptoManager.gcm = nil
|
||||
}
|
||||
|
||||
// EncryptStringBatch 批量加密字符串
|
||||
// 减少锁竞争,提高批量处理性能
|
||||
func EncryptStringBatch(plains []string) ([]string, error) {
|
||||
@@ -281,12 +272,12 @@ func DecryptStringWithSalt(enc, salt string) (string, error) {
|
||||
if len(combined) < len(salt) {
|
||||
return "", errors.New("decrypted data too short")
|
||||
}
|
||||
|
||||
|
||||
// 验证盐值是否匹配
|
||||
if combined[len(combined)-len(salt):] != salt {
|
||||
return "", errors.New("salt mismatch")
|
||||
}
|
||||
|
||||
|
||||
return combined[:len(combined)-len(salt)], nil
|
||||
}
|
||||
|
||||
|
||||
109
utils/csrf.go
109
utils/csrf.go
@@ -5,6 +5,8 @@ import (
|
||||
"crypto/subtle"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -34,55 +36,51 @@ func GenerateCSRFToken() (string, error) {
|
||||
}
|
||||
|
||||
// SetCSRFToken 设置CSRF令牌到Cookie和响应头
|
||||
func SetCSRFToken(w http.ResponseWriter, token string) {
|
||||
// 设置CSRF令牌到Cookie
|
||||
cookie := CreateSecureCookie(CSRFCookieName, token, 3600) // 1小时过期
|
||||
http.SetCookie(w, cookie)
|
||||
|
||||
// 设置CSRF令牌到响应头,方便JavaScript获取
|
||||
w.Header().Set("X-CSRF-Token", token)
|
||||
func SetCSRFToken(c *gin.Context, token string) {
|
||||
c.SetCookie(CSRFCookieName, token, 3600*24, "/", "", false, true)
|
||||
c.Header(CSRFHeaderName, token)
|
||||
}
|
||||
|
||||
// GetCSRFTokenFromRequest 从请求中获取CSRF令牌
|
||||
// GetCSRFTokenFromRequest 从Gin请求中获取CSRF令牌
|
||||
// 优先级:Header > Form > Cookie
|
||||
func GetCSRFTokenFromRequest(r *http.Request) string {
|
||||
func GetCSRFTokenFromRequest(c *gin.Context) string {
|
||||
// 1. 从Header获取
|
||||
if token := r.Header.Get(CSRFHeaderName); token != "" {
|
||||
if token := c.GetHeader(CSRFHeaderName); token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// 2. 从Form获取
|
||||
if token := r.FormValue(CSRFFormField); token != "" {
|
||||
if token := c.PostForm(CSRFFormField); token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
// 3. 从Cookie获取(作为备选)
|
||||
if cookie, err := r.Cookie(CSRFCookieName); err == nil {
|
||||
return cookie.Value
|
||||
if cookie, err := c.Cookie(CSRFCookieName); err == nil {
|
||||
return cookie
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetCSRFTokenFromCookie 从Cookie中获取CSRF令牌
|
||||
func GetCSRFTokenFromCookie(r *http.Request) string {
|
||||
cookie, err := r.Cookie(CSRFCookieName)
|
||||
func GetCSRFTokenFromCookie(c *gin.Context) string {
|
||||
cookie, err := c.Cookie(CSRFCookieName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return cookie.Value
|
||||
return cookie
|
||||
}
|
||||
|
||||
// ValidateCSRFToken 验证CSRF令牌
|
||||
func ValidateCSRFToken(r *http.Request) bool {
|
||||
func ValidateCSRFToken(c *gin.Context) bool {
|
||||
// 获取Cookie中的令牌(服务器端存储的)
|
||||
cookieToken := GetCSRFTokenFromCookie(r)
|
||||
cookieToken := GetCSRFTokenFromCookie(c)
|
||||
if cookieToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取请求中的令牌(客户端提交的)
|
||||
requestToken := GetCSRFTokenFromRequest(r)
|
||||
requestToken := GetCSRFTokenFromRequest(c)
|
||||
if requestToken == "" {
|
||||
return false
|
||||
}
|
||||
@@ -92,47 +90,62 @@ func ValidateCSRFToken(r *http.Request) bool {
|
||||
}
|
||||
|
||||
// CSRFProtection CSRF保护中间件
|
||||
func CSRFProtection(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
func CSRFProtection() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 对于GET、HEAD、OPTIONS请求,只生成令牌,不验证
|
||||
if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
|
||||
if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead || c.Request.Method == http.MethodOptions {
|
||||
// 生成新的CSRF令牌
|
||||
token, err := GenerateCSRFToken()
|
||||
if err != nil {
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 1,
|
||||
"msg": "Internal Server Error",
|
||||
"data": nil,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
SetCSRFToken(w, token)
|
||||
next(w, r)
|
||||
SetCSRFToken(c, token)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 对于POST、PUT、DELETE等修改性请求,验证CSRF令牌
|
||||
if !ValidateCSRFToken(r) {
|
||||
JsonResponse(w, http.StatusForbidden, false, "CSRF令牌验证失败", nil)
|
||||
if !ValidateCSRFToken(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"code": 1,
|
||||
"msg": "CSRF令牌验证失败",
|
||||
"data": nil,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 验证通过,继续处理请求
|
||||
next(w, r)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireCSRFToken 要求CSRF令牌的中间件(用于特定路由)
|
||||
func RequireCSRFToken(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if !ValidateCSRFToken(r) {
|
||||
JsonResponse(w, http.StatusForbidden, false, "CSRF令牌验证失败", nil)
|
||||
func RequireCSRFToken() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !ValidateCSRFToken(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"code": 1,
|
||||
"msg": "CSRF令牌验证失败",
|
||||
"data": nil,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
next(w, r)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetCSRFTokenForTemplate 获取用于模板的CSRF令牌
|
||||
func GetCSRFTokenForTemplate(r *http.Request) string {
|
||||
func GetCSRFTokenForTemplate(c *gin.Context) string {
|
||||
// 尝试从Cookie获取现有令牌
|
||||
if token := GetCSRFTokenFromCookie(r); token != "" {
|
||||
if token := GetCSRFTokenFromCookie(c); token != "" {
|
||||
return token
|
||||
}
|
||||
|
||||
@@ -145,24 +158,36 @@ func GetCSRFTokenForTemplate(r *http.Request) string {
|
||||
}
|
||||
|
||||
// CSRFTokenHandler 专门用于获取CSRF令牌的API端点
|
||||
func CSRFTokenHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodGet {
|
||||
JsonResponse(w, http.StatusMethodNotAllowed, false, "只支持GET请求", nil)
|
||||
func CSRFTokenHandler(c *gin.Context) {
|
||||
if c.Request.Method != http.MethodGet {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"code": 1,
|
||||
"msg": "只支持GET请求",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新的CSRF令牌
|
||||
token, err := GenerateCSRFToken()
|
||||
if err != nil {
|
||||
JsonResponse(w, http.StatusInternalServerError, false, "生成CSRF令牌失败", nil)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 1,
|
||||
"msg": "生成CSRF令牌失败",
|
||||
"data": nil,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 设置令牌到Cookie和响应头
|
||||
SetCSRFToken(w, token)
|
||||
SetCSRFToken(c, token)
|
||||
|
||||
// 返回令牌给前端
|
||||
JsonResponse(w, http.StatusOK, true, "CSRF令牌获取成功", map[string]interface{}{
|
||||
"csrf_token": token,
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"msg": "CSRF令牌生成成功",
|
||||
"data": gin.H{
|
||||
"csrf_token": token,
|
||||
},
|
||||
})
|
||||
}
|
||||
@@ -4,10 +4,10 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -68,27 +68,13 @@ type LogEntry struct {
|
||||
Line int `json:"line"` // 源文件行号
|
||||
}
|
||||
|
||||
// WriteJSONResponse 写入JSON响应
|
||||
// w: HTTP响应写入器
|
||||
// statusCode: HTTP状态码
|
||||
// response: 响应数据
|
||||
func WriteJSONResponse(w http.ResponseWriter, statusCode int, response interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
LogError("Failed to encode JSON response", err, nil)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteErrorResponse 写入错误响应
|
||||
// w: HTTP响应写入器
|
||||
// c: Gin上下文
|
||||
// statusCode: HTTP状态码
|
||||
// message: 错误消息
|
||||
// errorCode: 错误代码
|
||||
// data: 附加数据
|
||||
func WriteErrorResponse(w http.ResponseWriter, statusCode int, message, errorCode string, data interface{}) {
|
||||
func WriteErrorResponse(c *gin.Context, statusCode int, message, errorCode string, data interface{}) {
|
||||
response := ErrorResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
@@ -97,15 +83,15 @@ func WriteErrorResponse(w http.ResponseWriter, statusCode int, message, errorCod
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
WriteJSONResponse(w, statusCode, response)
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// WriteSuccessResponse 写入成功响应
|
||||
// w: HTTP响应写入器
|
||||
// c: Gin上下文
|
||||
// statusCode: HTTP状态码
|
||||
// message: 成功消息
|
||||
// data: 响应数据
|
||||
func WriteSuccessResponse(w http.ResponseWriter, statusCode int, message string, data interface{}) {
|
||||
func WriteSuccessResponse(c *gin.Context, statusCode int, message string, data interface{}) {
|
||||
response := SuccessResponse{
|
||||
Success: true,
|
||||
Message: message,
|
||||
@@ -113,57 +99,57 @@ func WriteSuccessResponse(w http.ResponseWriter, statusCode int, message string,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
WriteJSONResponse(w, statusCode, response)
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// HandleDatabaseError 处理数据库错误
|
||||
// w: HTTP响应写入器
|
||||
// c: Gin上下文
|
||||
// err: 数据库错误
|
||||
// operation: 操作描述
|
||||
func HandleDatabaseError(w http.ResponseWriter, err error, operation string) {
|
||||
func HandleDatabaseError(c *gin.Context, err error, operation string) {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
LogWarn(fmt.Sprintf("Record not found during %s", operation), map[string]interface{}{
|
||||
"operation": operation,
|
||||
"error": err.Error(),
|
||||
})
|
||||
WriteErrorResponse(w, http.StatusNotFound, "记录不存在", ErrCodeNotFound, nil)
|
||||
WriteErrorResponse(c, 404, "记录不存在", ErrCodeNotFound, nil)
|
||||
return
|
||||
}
|
||||
|
||||
LogError(fmt.Sprintf("Database error during %s", operation), err, map[string]interface{}{
|
||||
"operation": operation,
|
||||
})
|
||||
WriteErrorResponse(w, http.StatusInternalServerError, "数据库操作失败", ErrCodeDatabaseError, nil)
|
||||
WriteErrorResponse(c, 500, "数据库操作失败", ErrCodeDatabaseError, nil)
|
||||
}
|
||||
|
||||
// HandleValidationError 处理验证错误
|
||||
// w: HTTP响应写入器
|
||||
// c: Gin上下文
|
||||
// message: 验证错误消息
|
||||
// details: 验证错误详情
|
||||
func HandleValidationError(w http.ResponseWriter, message string, details interface{}) {
|
||||
func HandleValidationError(c *gin.Context, message string, details interface{}) {
|
||||
LogWarn("Validation error: "+message, map[string]interface{}{
|
||||
"details": details,
|
||||
})
|
||||
WriteErrorResponse(w, http.StatusBadRequest, message, ErrCodeValidationError, details)
|
||||
WriteErrorResponse(c, 400, message, ErrCodeValidationError, details)
|
||||
}
|
||||
|
||||
// HandleUnauthorizedError 处理未授权错误
|
||||
// w: HTTP响应写入器
|
||||
// c: Gin上下文
|
||||
// message: 错误消息
|
||||
func HandleUnauthorizedError(w http.ResponseWriter, message string) {
|
||||
func HandleUnauthorizedError(c *gin.Context, message string) {
|
||||
LogWarn("Unauthorized access: "+message, nil)
|
||||
WriteErrorResponse(w, http.StatusUnauthorized, message, ErrCodeUnauthorized, nil)
|
||||
WriteErrorResponse(c, 401, message, ErrCodeUnauthorized, nil)
|
||||
}
|
||||
|
||||
// HandleInternalError 处理内部错误
|
||||
// w: HTTP响应写入器
|
||||
// c: Gin上下文
|
||||
// err: 错误
|
||||
// operation: 操作描述
|
||||
func HandleInternalError(w http.ResponseWriter, err error, operation string) {
|
||||
func HandleInternalError(c *gin.Context, err error, operation string) {
|
||||
LogError(fmt.Sprintf("Internal error during %s", operation), err, map[string]interface{}{
|
||||
"operation": operation,
|
||||
})
|
||||
WriteErrorResponse(w, http.StatusInternalServerError, "服务器内部错误", ErrCodeInternalError, nil)
|
||||
WriteErrorResponse(c, 500, "服务器内部错误", ErrCodeInternalError, nil)
|
||||
}
|
||||
|
||||
// LogInfo 记录信息日志
|
||||
|
||||
Reference in New Issue
Block a user