Files
NetworkAuth/utils/csrf.go
skyle1995 c93ee377fe Enhance user authentication and authentication
Fix the modification of personal information
Fix the formatted page template
2025-10-26 03:05:27 +08:00

168 lines
4.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package utils
import (
"crypto/rand"
"crypto/subtle"
"encoding/base64"
"net/http"
)
// CSRF settings shared by the token helpers and middleware in this file.
const (
	// CSRFCookieName is the cookie in which the issued token is stored.
	CSRFCookieName = "csrf_token"
	// CSRFHeaderName is the header used to submit the token and to echo it
	// back in responses.
	CSRFHeaderName = "X-CSRF-Token"
	// CSRFFormField is the form field name used to submit the token.
	CSRFFormField = "csrf_token"
	// CSRFTokenLength is the number of random bytes in a token, before
	// base64 encoding.
	CSRFTokenLength = 32
)
// generateRandomBytes returns a new slice of length bytes filled from
// crypto/rand. If the read fails it returns a nil slice and the read error.
func generateRandomBytes(length int) ([]byte, error) {
	buf := make([]byte, length)
	if _, readErr := rand.Read(buf); readErr != nil {
		return nil, readErr
	}
	return buf, nil
}
// GenerateCSRFToken creates a new random CSRF token.
//
// It draws CSRFTokenLength random bytes and encodes them with URL-safe base64
// (base64.URLEncoding, which includes '=' padding). On failure it returns ""
// and the underlying error.
func GenerateCSRFToken() (string, error) {
	raw, err := generateRandomBytes(CSRFTokenLength)
	if err != nil {
		return "", err
	}
	encoded := base64.URLEncoding.EncodeToString(raw)
	return encoded, nil
}
// SetCSRFToken issues token to the client in two ways:
//   - as the CSRFCookieName cookie, built by CreateSecureCookie with a
//     lifetime argument of 3600 (one hour, assuming the argument is in
//     seconds — confirm against CreateSecureCookie);
//   - in the CSRFHeaderName response header, so JavaScript can read it.
func SetCSRFToken(w http.ResponseWriter, token string) {
	http.SetCookie(w, CreateSecureCookie(CSRFCookieName, token, 3600))
	// Use the shared constant rather than a literal so the header written
	// here always matches the header name read back from requests.
	w.Header().Set(CSRFHeaderName, token)
}
// GetCSRFTokenFromRequest returns the CSRF token found on r. It checks, in
// order, the CSRFHeaderName header, the CSRFFormField form value (which may
// parse the request body), and finally the CSRFCookieName cookie. It returns
// "" if none of them yields a non-empty value.
//
// Because of the cookie fallback, a non-empty result does not prove the
// client explicitly submitted a token. Comparing this value against the
// cookie is therefore not a valid CSRF check.
func GetCSRFTokenFromRequest(r *http.Request) string {
	headerToken := r.Header.Get(CSRFHeaderName)
	if headerToken != "" {
		return headerToken
	}
	if formToken := r.FormValue(CSRFFormField); formToken != "" {
		return formToken
	}
	// Last resort: whatever the cookie holds ("" when absent).
	return GetCSRFTokenFromCookie(r)
}
// GetCSRFTokenFromCookie returns the value of the CSRFCookieName cookie on r,
// or "" when the request does not carry that cookie.
func GetCSRFTokenFromCookie(r *http.Request) string {
	if c, err := r.Cookie(CSRFCookieName); err == nil {
		return c.Value
	}
	return ""
}
// ValidateCSRFToken performs the double-submit-cookie check. It reports
// whether r carries a non-empty CSRFCookieName cookie and a non-empty token
// explicitly submitted by the client (CSRFHeaderName header or CSRFFormField
// form value), and whether the two are equal.
//
// The submitted token is deliberately not taken from GetCSRFTokenFromRequest.
// That function falls back to the cookie, which would make the comparison
// cookie-vs-cookie and accept any request that merely carries the cookie,
// which is exactly what a cross-site request does.
func ValidateCSRFToken(r *http.Request) bool {
	cookieToken := GetCSRFTokenFromCookie(r)
	if cookieToken == "" {
		return false
	}
	requestToken := csrfTokenSubmittedByClient(r)
	if requestToken == "" {
		return false
	}
	// Constant-time comparison avoids leaking token prefixes via timing.
	return subtle.ConstantTimeCompare([]byte(cookieToken), []byte(requestToken)) == 1
}

// csrfTokenSubmittedByClient returns the token the client explicitly sent in
// the CSRFHeaderName header (preferred) or the CSRFFormField form value,
// never the cookie. It returns "" when neither is present.
func csrfTokenSubmittedByClient(r *http.Request) string {
	if token := r.Header.Get(CSRFHeaderName); token != "" {
		return token
	}
	return r.FormValue(CSRFFormField)
}
// CSRFProtection CSRF保护中间件
func CSRFProtection(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// 对于GET、HEAD、OPTIONS请求只生成令牌不验证
if r.Method == http.MethodGet || r.Method == http.MethodHead || r.Method == http.MethodOptions {
// 生成新的CSRF令牌
token, err := GenerateCSRFToken()
if err != nil {
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
SetCSRFToken(w, token)
next(w, r)
return
}
// 对于POST、PUT、DELETE等修改性请求验证CSRF令牌
if !ValidateCSRFToken(r) {
JsonResponse(w, http.StatusForbidden, false, "CSRF令牌验证失败", nil)
return
}
// 验证通过,继续处理请求
next(w, r)
}
}
// RequireCSRFToken wraps next for individual routes that always require a
// valid CSRF token, regardless of HTTP method. Requests that fail
// ValidateCSRFToken get a 403 JSON response and next is not called.
func RequireCSRFToken(next http.HandlerFunc) http.HandlerFunc {
	guarded := func(w http.ResponseWriter, r *http.Request) {
		if valid := ValidateCSRFToken(r); valid {
			next(w, r)
			return
		}
		JsonResponse(w, http.StatusForbidden, false, "CSRF令牌验证失败", nil)
	}
	return guarded
}
// GetCSRFTokenForTemplate returns a CSRF token to embed in a rendered page.
//
// It prefers the token already stored in the request's CSRFCookieName
// cookie. If there is none, it generates a fresh token but does not set it on
// any response. Such a token will only pass ValidateCSRFToken if the caller
// also issues it to the client (e.g. with SetCSRFToken). It returns "" if
// token generation fails.
func GetCSRFTokenForTemplate(r *http.Request) string {
	if existing := GetCSRFTokenFromCookie(r); existing != "" {
		return existing
	}
	if fresh, err := GenerateCSRFToken(); err == nil {
		return fresh
	}
	return ""
}
// CSRFTokenHandler is an API endpoint that issues a new CSRF token.
//
// Only GET is accepted. Other methods receive a 405 JSON response together
// with the Allow header that RFC 9110 requires for 405 responses.
//
// On success, the handler:
//   - generates a new token;
//   - issues it with SetCSRFToken (cookie plus response header), overwriting
//     any previous token cookie;
//   - returns it in the JSON body under "csrf_token".
//
// If generation fails, it responds with a 500 JSON error.
func CSRFTokenHandler(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		// Headers must be set before JsonResponse writes the status line.
		w.Header().Set("Allow", http.MethodGet)
		JsonResponse(w, http.StatusMethodNotAllowed, false, "只支持GET请求", nil)
		return
	}

	token, err := GenerateCSRFToken()
	if err != nil {
		JsonResponse(w, http.StatusInternalServerError, false, "生成CSRF令牌失败", nil)
		return
	}

	SetCSRFToken(w, token)

	JsonResponse(w, http.StatusOK, true, "CSRF令牌获取成功", map[string]interface{}{
		"csrf_token": token,
	})
}