diff --git a/cmd/root.go b/cmd/root.go index ebf331c..6ed647c 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -2,6 +2,7 @@ package cmd import ( "NetworkAuth/config" + "NetworkAuth/utils" "NetworkAuth/utils/logger" "io" "os" @@ -67,7 +68,7 @@ func setupLogrusForNonHTTP() { if cfgFile != "" { config.Init(cfgFile) } else { - config.Init("./config.json") + config.Init("config.json") } // 根据配置文件进一步配置logrus @@ -105,7 +106,11 @@ func setupLogrusFromConfig() { logFile := viper.GetString("log.file") if logFile != "" { // 确保日志目录存在 - logDir := filepath.Dir(logFile) + path := logFile + if !filepath.IsAbs(path) { + path = filepath.Join(utils.GetRootDir(), path) + } + logDir := filepath.Dir(path) if err := os.MkdirAll(logDir, 0755); err != nil { logrus.WithError(err).Error("创建日志目录失败") return @@ -113,7 +118,7 @@ func setupLogrusFromConfig() { // 配置lumberjack日志轮转 lumberjackLogger := &lumberjack.Logger{ - Filename: logFile, + Filename: path, MaxSize: viper.GetInt("log.max_size"), // MB MaxBackups: viper.GetInt("log.max_backups"), // 保留的旧日志文件数量 MaxAge: viper.GetInt("log.max_age"), // 天数 diff --git a/config/config.go b/config/config.go index 387ed89..09b56db 100644 --- a/config/config.go +++ b/config/config.go @@ -6,6 +6,8 @@ import ( "os" "path/filepath" + "NetworkAuth/utils" + log "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -107,7 +109,7 @@ func GetDefaultAppConfig() *AppConfig { MaxOpenConns: 100, }, SQLite: SQLiteConfig{ - Path: "./database.db", + Path: "database.db", }, }, Redis: RedisConfig{ @@ -118,7 +120,7 @@ func GetDefaultAppConfig() *AppConfig { }, Log: LogConfig{ Level: "info", - File: "./logs/app.log", + File: "logs/app.log", MaxSize: 100, MaxBackups: 5, MaxAge: 30, @@ -128,6 +130,9 @@ func GetDefaultAppConfig() *AppConfig { // Init 初始化配置文件 func Init(cfgFilePath string) { + if !filepath.IsAbs(cfgFilePath) { + cfgFilePath = filepath.Join(utils.GetRootDir(), cfgFilePath) + } currentConfigFilePath = cfgFilePath viper.SetConfigFile(cfgFilePath) viper.SetConfigType("json") @@ -204,7 +209,10 @@ func SaveConfig(appConfig *AppConfig) error { return err } if currentConfigFilePath == "" { - currentConfigFilePath = "./config.json" + currentConfigFilePath = "config.json" + } + if !filepath.IsAbs(currentConfigFilePath) { + currentConfigFilePath = filepath.Join(utils.GetRootDir(), currentConfigFilePath) } if err := os.MkdirAll(filepath.Dir(currentConfigFilePath), 0755); err != nil { return err diff --git a/config/validator.go b/config/validator.go index 50e3943..06764b5 100644 --- a/config/validator.go +++ b/config/validator.go @@ -9,6 +9,8 @@ import ( "strconv" "strings" + "NetworkAuth/utils" + log "github.com/sirupsen/logrus" "github.com/spf13/viper" ) @@ -125,8 +127,13 @@ func validateSQLiteConfig(config *SQLiteConfig) error { return errors.New("SQLite数据库路径不能为空") } + path := config.Path + if !filepath.IsAbs(path) { + path = filepath.Join(utils.GetRootDir(), path) + } + // 检查目录是否存在,不存在则创建 - dir := filepath.Dir(config.Path) + dir := filepath.Dir(path) if _, err := os.Stat(dir); os.IsNotExist(err) { if err := os.MkdirAll(dir, 0755); err != nil { return fmt.Errorf("创建SQLite数据库目录失败: %w", err) @@ -160,7 +167,11 @@ func validateLogConfig(config *LogConfig) error { // 检查日志文件目录(仅当日志文件路径不为空时) if config.File != "" { - dir := filepath.Dir(config.File) + path := config.File + if !filepath.IsAbs(path) { + path = filepath.Join(utils.GetRootDir(), path) + } + dir := filepath.Dir(path) if _, err := os.Stat(dir); os.IsNotExist(err) { if err := os.MkdirAll(dir, 0755); err != nil { return fmt.Errorf("创建日志目录失败: %w", err) diff --git a/database/database.go b/database/database.go index 80ec360..81f5083 100644 --- a/database/database.go +++ b/database/database.go @@ -7,6 +7,7 @@ import ( "fmt" "log" "os" + "path/filepath" "sync" "time" @@ -103,7 +104,10 @@ func performInitFromViper() error { case "sqlite": dbPath := cfg.Database.SQLite.Path if dbPath == "" { - dbPath = "./database.db" + dbPath = "database.db" + } + if !filepath.IsAbs(dbPath) { + dbPath = filepath.Join(utils.GetRootDir(), dbPath) } if _, err := os.Stat(dbPath); os.IsNotExist(err) { logrus.Info("SQLite 数据库文件不存在,系统尚未安装,跳过数据库连接") @@ -209,7 +213,10 @@ func buildGormLogger(level string) gLogger.Interface { func initSQLite(sqliteConfig *appconfig.SQLiteConfig, logLevel string) error { path := sqliteConfig.Path if path == "" { - path = "./database.db" + path = "database.db" + } + if !filepath.IsAbs(path) { + path = filepath.Join(utils.GetRootDir(), path) } dsn := fmt.Sprintf("file:%s?cache=shared&_busy_timeout=5000&_fk=1", path) db, err := gorm.Open(sqlite.Open(dsn), &gorm.Config{Logger: buildGormLogger(logLevel)}) diff --git a/server/routes.go b/server/routes.go index 27e4239..70bc642 100644 --- a/server/routes.go +++ b/server/routes.go @@ -2,12 +2,14 @@ package server import ( "NetworkAuth/public" + "NetworkAuth/utils" "io" "io/fs" "net/http" "net/http/httputil" "net/url" "os" + "path/filepath" "strings" "github.com/gin-gonic/gin" @@ -48,6 +50,9 @@ func registerFrontendRoutes(r *gin.Engine) { return // 反向代理接管了所有非 API 路由,直接返回 } else { // 使用本地外部目录 + if !filepath.IsAbs(distConfig) { + distConfig = filepath.Join(utils.GetRootDir(), distConfig) + } fileServer = http.FileServer(http.Dir(distConfig)) // 拦截并处理静态资源请求 diff --git a/services/request/resty_client.go b/services/request/resty_client.go index 55bd246..5816cb5 100644 --- a/services/request/resty_client.go +++ b/services/request/resty_client.go @@ -1,326 +1,343 @@ -package request - -import ( - "bytes" - "compress/flate" - "compress/gzip" - "encoding/json" - "io" - "net/http" - "net/http/cookiejar" - "reflect" - "strings" - "time" - "unsafe" - - "github.com/andybalholm/brotli" - "github.com/go-resty/resty/v2" - "github.com/skycheung803/go-bypasser" -) - -type RestyClient struct { - client *resty.Client -} - -func (request *RestyClient) Resty() *resty.Client { - return request.client -} - -// NewClient 创建一个基于 uTLS 指纹与 HTTP/2 指纹的 Resty 客户端 -// baseURL 不为空则设置默认 BaseURL;proxyStr 不为空则启用 HTTP 代理(仅 HTTP/1.1) -// persistCookies 启用持久化 Cookie;followRedirect 启用重定向跟随;timeout 设置超时时间(秒,0 或负数则默认 60 秒) -func NewClient(baseURL string, proxyStr string, persistCookies bool, timeout int) *RestyClient { - rc := resty.New() - - if baseURL != "" { - rc.SetBaseURL(baseURL) - } - - if persistCookies { - jar, _ := cookiejar.New(nil) - rc.SetCookieJar(jar) - } - - // 设置请求超时时间,如果传入 0 或负数则默认 60 秒 - if timeout <= 0 { - timeout = 60 - } - rc.SetTimeout(time.Duration(timeout) * time.Second) - - // 统一设置客户端默认请求头(调用级 headers 可覆盖),字段按字母顺序排列 - rc.SetHeader("accept", "*/*") - rc.SetHeader("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6") - rc.SetHeader("connection", "keep-alive") - rc.SetHeader("pragma", "no-cache") - rc.SetHeader("priority", "u=1,i") - rc.SetHeader("sec-ch-ua", "\"Chromium\";v=\"146\", \"Not-A.Brand\";v=\"24\", \"Google Chrome\";v=\"146\"") - rc.SetHeader("sec-ch-ua-mobile", "?0") - rc.SetHeader("sec-ch-ua-platform", "\"macOS\"") - rc.SetHeader("sec-fetch-dest", "empty") - rc.SetHeader("sec-fetch-mode", "cors") - rc.SetHeader("sec-fetch-site", "same-origin") - rc.SetHeader("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Safari/537.36") - - // 初始化 go-bypasser 替代原有的 spoofed-round-tripper - opts := []bypasser.BypasserOption{ - bypasser.WithInsecureSkipVerify(true), - } - if proxyStr != "" { - opts = append(opts, bypasser.WithProxy(proxyStr)) - } - - bypass, err := bypasser.NewBypasser(opts...) - if err != nil { - panic(err) - } - - rc.SetTransport(bypass.Transport) - - return &RestyClient{client: rc} -} - -// fillResponseBody 使用反射强制填充响应体 -// 当 Resty 因为重定向策略错误而提前返回时,它可能不会读取 Body -// 此方法手动读取 RawResponse.Body 并回填到 resty.Response 的私有 body 字段中 -func (request *RestyClient) fillResponseBody(resp *resty.Response) { - if resp == nil || resp.RawResponse == nil { - return - } - // 如果已经有 body 内容,则不处理 - if len(resp.Body()) > 0 { - return - } - - // 读取底层 Body - bodyBytes, err := io.ReadAll(resp.RawResponse.Body) - if err != nil { - return - } - resp.RawResponse.Body.Close() - // 重置 Body 以便后续可能得读取 - resp.RawResponse.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - - // 使用反射设置私有字段 body - v := reflect.ValueOf(resp).Elem() - f := v.FieldByName("body") - if f.IsValid() { - // 必须使用 UnsafeAddr 获取未导出字段的地址 - rf := reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem() - rf.SetBytes(bodyBytes) - } - - // 设置 size 字段 - s := v.FieldByName("size") - if s.IsValid() { - rs := reflect.NewAt(s.Type(), unsafe.Pointer(s.UnsafeAddr())).Elem() - rs.SetInt(int64(len(bodyBytes))) - } -} - -// makeReq 构造带可选请求头的 resty.Request -// 功能:基于客户端创建请求对象,并在传入 headers 时进行设置 -// 返回:带有请求头的请求对象 -func (request *RestyClient) makeReq(headers map[string]string, cookies []*http.Cookie) *resty.Request { - req := request.client.R() - if len(headers) > 0 { - req = req.SetHeaders(headers) - } - if len(cookies) > 0 { - req = req.SetCookies(cookies) - } - return req -} - -// doWithEncodingFallback 封装请求发送并在出现压缩相关错误时进行一次降级重试 -// 逻辑:首次请求失败且错误包含 gzip/zstd/brotli/magic number mismatch 时,设置 accept-encoding 为 identity 重试一次 -func (request *RestyClient) doWithEncodingFallback(headers map[string]string, cookies []*http.Cookie, allowRedirect bool, do func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) { - req := request.makeReq(headers, cookies) - if allowRedirect { - request.client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(10)) - } else { - // 使用 http.ErrUseLastResponse 确保 302 响应被返回且 Body 可读,而不是报错 - request.client.SetRedirectPolicy(resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - })) - } - resp, err := do(req) - - // 尝试补救响应体(特别是当重定向被禁用导致报错时) - request.fillResponseBody(resp) - - if err == nil { - return resp, nil - } - s := err.Error() - if strings.Contains(s, "gzip: invalid header") || strings.Contains(s, "magic number mismatch") || strings.Contains(s, "zstd") || strings.Contains(s, "brotli") { - h2 := map[string]string{} - for k, v := range headers { - if strings.ToLower(k) != "accept-encoding" { - h2[k] = v - } - } - h2["Accept-Encoding"] = "identity" - req2 := request.makeReq(h2, cookies) - resp2, err2 := do(req2) - request.fillResponseBody(resp2) - if err2 == nil { - return resp2, nil - } - } - return resp, err -} - -// decodeResponse 处理响应解压与 JSON 解析 -// 功能:自动识别 gzip 压缩并解压;在 result 非空时按 JSON 解析到 result -// 返回:解析错误(成功时为 nil) -func (request *RestyClient) decodeResponse(resp *resty.Response, result interface{}) error { - if resp == nil { - return nil - } - ct := strings.ToLower(resp.Header().Get("Content-Type")) - ce := strings.ToLower(resp.Header().Get("Content-Encoding")) - body := resp.Body() - if strings.Contains(ce, "gzip") && len(body) > 0 { - gr, gerr := gzip.NewReader(bytes.NewReader(body)) - if gerr == nil { - defer gr.Close() - if dec, derr := io.ReadAll(gr); derr == nil { - body = dec - resp.SetBody(body) - } - } - } else if strings.Contains(ce, "deflate") && len(body) > 0 { - // 处理 deflate 压缩 - dr := flate.NewReader(bytes.NewReader(body)) - defer dr.Close() - if dec, derr := io.ReadAll(dr); derr == nil { - body = dec - resp.SetBody(body) - } - } else if strings.Contains(ce, "br") && len(body) > 0 { - // 处理 brotli 压缩 - br := brotli.NewReader(bytes.NewReader(body)) - if dec, derr := io.ReadAll(br); derr == nil { - body = dec - resp.SetBody(body) // 将解压后的 body 写回 response - } - } - if result != nil && (strings.Contains(ct, "application/json") || json.Valid(body)) { - if err := json.Unmarshal(body, result); err != nil { - return err - } - } - return nil -} - -// RestyGet 发送 GET 请求 -func (request *RestyClient) RestyGet(path string, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { - resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { - return r.Get(path) - }) - if resp == nil && err != nil { - return nil, err - } - - if err := request.decodeResponse(resp, result); err != nil { - return nil, err - } - - return resp, err -} - -// RestyPost 发送 POST 请求 -func (request *RestyClient) RestyPost(path string, data any, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { - resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { - return r.SetBody(data).Post(path) - }) - if resp == nil && err != nil { - return nil, err - } - - if err := request.decodeResponse(resp, result); err != nil { - return nil, err - } - - return resp, err -} - -// RestyPut 发送 PUT 请求 -// 功能:发送 PUT,支持请求级 headers 覆盖客户端默认,自动识别 gzip 并解析 JSON -// 返回:响应对象与错误信息 -func (request *RestyClient) RestyPut(path string, data any, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { - resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { - return r.SetBody(data).Put(path) - }) - if resp == nil && err != nil { - return nil, err - } - - if err := request.decodeResponse(resp, result); err != nil { - return nil, err - } - - return resp, err -} - -// RestyPatch 发送 PATCH 请求 -// 功能:发送 PATCH,支持请求级 headers 覆盖客户端默认,自动识别 gzip 并解析 JSON -// 返回:响应对象与错误信息 -func (request *RestyClient) RestyPatch(path string, data any, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { - resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { - return r.SetBody(data).Patch(path) - }) - if resp == nil && err != nil { - return nil, err - } - - if err := request.decodeResponse(resp, result); err != nil { - return nil, err - } - - return resp, err -} - -// RestyDelete 发送 DELETE 请求 -// 功能:发送 DELETE,支持请求级 headers 覆盖客户端默认,自动识别 gzip 并解析 JSON -// 返回:响应对象与错误信息 -func (request *RestyClient) RestyDelete(path string, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { - resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { - return r.Delete(path) - }) - if resp == nil && err != nil { - return nil, err - } - - if err := request.decodeResponse(resp, result); err != nil { - return nil, err - } - - return resp, err -} - -// RestyHead 发送 HEAD 请求 -// 功能:发送 HEAD,支持请求级 headers 覆盖客户端默认;HEAD 通常无正文 -// 返回:响应对象与错误信息 -func (request *RestyClient) RestyHead(path string, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { - resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { - return r.Head(path) - }) - if resp == nil && err != nil { - return nil, err - } - return resp, err -} - -// RestyOptions 发送 OPTIONS 请求 -// 功能:发送 OPTIONS,支持请求级 headers 覆盖客户端默认 -// 返回:响应对象与错误信息 -func (request *RestyClient) RestyOptions(path string, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { - resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { - return r.Options(path) - }) - if resp == nil && err != nil { - return nil, err - } - return resp, err -} +package request + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "encoding/json" + "io" + "net/http" + "net/http/cookiejar" + "reflect" + "strings" + "time" + "unsafe" + + "github.com/andybalholm/brotli" + "github.com/go-resty/resty/v2" + "github.com/skycheung803/go-bypasser" +) + +type RestyClient struct { + client *resty.Client +} + +func (request *RestyClient) Resty() *resty.Client { + return request.client +} + +// NewClient 创建一个基于 uTLS 指纹与 HTTP/2 指纹的 Resty 客户端 +// baseURL 不为空则设置默认 BaseURL;proxyStr 不为空则启用 HTTP 代理(仅 HTTP/1.1) +// persistCookies 启用持久化 Cookie;followRedirect 启用重定向跟随;timeout 设置超时时间(秒,0 或负数则默认 60 秒) +func NewClient(baseURL string, proxyStr string, persistCookies bool, timeout int) *RestyClient { + rc := resty.New() + + if baseURL != "" { + rc.SetBaseURL(baseURL) + } + + if persistCookies { + jar, _ := cookiejar.New(nil) + rc.SetCookieJar(jar) + } + + // 设置请求超时时间,如果传入 0 或负数则默认 60 秒 + if timeout <= 0 { + timeout = 60 + } + rc.SetTimeout(time.Duration(timeout) * time.Second) + + // 统一设置客户端默认请求头(调用级 headers 可覆盖),字段按字母顺序排列 + rc.SetHeader("accept", "*/*") + rc.SetHeader("accept-language", "zh-CN,zh;q=0.9,en;q=0.8,en-GB;q=0.7,en-US;q=0.6") + rc.SetHeader("connection", "keep-alive") + rc.SetHeader("pragma", "no-cache") + rc.SetHeader("priority", "u=1,i") + rc.SetHeader("sec-ch-ua", "\"Chromium\";v=\"146\", \"Not-A.Brand\";v=\"24\", \"Google Chrome\";v=\"146\"") + rc.SetHeader("sec-ch-ua-mobile", "?0") + rc.SetHeader("sec-ch-ua-platform", "\"macOS\"") + rc.SetHeader("sec-fetch-dest", "empty") + rc.SetHeader("sec-fetch-mode", "cors") + rc.SetHeader("sec-fetch-site", "same-origin") + rc.SetHeader("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/146.0.0.0 Safari/537.36") + + // 初始化 go-bypasser 替代原有的 spoofed-round-tripper + opts := []bypasser.BypasserOption{ + bypasser.WithInsecureSkipVerify(true), + } + if proxyStr != "" { + opts = append(opts, bypasser.WithProxy(proxyStr)) + } + + bypass, err := bypasser.NewBypasser(opts...) + if err != nil { + panic(err) + } + + rc.SetTransport(&sanitizeTransport{t: bypass.Transport}) + + return &RestyClient{client: rc} +} + +// sanitizeTransport 包装 http.RoundTripper 以修复底层库可能违背 Go 接口约定的行为 +type sanitizeTransport struct { + t http.RoundTripper +} + +func (s *sanitizeTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := s.t.RoundTrip(req) + // net/http 规定 RoundTripper 要么返回有效的 resp 和 nil error,要么返回 nil resp 和有效的 error。 + // 某些第三方库(如部分 tls-client 封装)在遇到网络小问题时会同时返回 resp 和 err。 + // 这会导致 net/http 打印 "RoundTripper returned a response & error; ignoring response" 并强制丢弃响应。 + // 在这里我们进行修正:如果已经拿到了响应(哪怕是不完整的),我们优先保留响应并将 err 置空,让上层通过读取 Body 自行发现错误。 + if resp != nil && err != nil { + err = nil + } + return resp, err +} + +// fillResponseBody 使用反射强制填充响应体 +// 当 Resty 因为重定向策略错误而提前返回时,它可能不会读取 Body +// 此方法手动读取 RawResponse.Body 并回填到 resty.Response 的私有 body 字段中 +func (request *RestyClient) fillResponseBody(resp *resty.Response) { + if resp == nil || resp.RawResponse == nil { + return + } + // 如果已经有 body 内容,则不处理 + if len(resp.Body()) > 0 { + return + } + + // 读取底层 Body + bodyBytes, err := io.ReadAll(resp.RawResponse.Body) + if err != nil { + return + } + resp.RawResponse.Body.Close() + // 重置 Body 以便后续可能得读取 + resp.RawResponse.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + // 使用反射设置私有字段 body + v := reflect.ValueOf(resp).Elem() + f := v.FieldByName("body") + if f.IsValid() { + // 必须使用 UnsafeAddr 获取未导出字段的地址 + rf := reflect.NewAt(f.Type(), unsafe.Pointer(f.UnsafeAddr())).Elem() + rf.SetBytes(bodyBytes) + } + + // 设置 size 字段 + s := v.FieldByName("size") + if s.IsValid() { + rs := reflect.NewAt(s.Type(), unsafe.Pointer(s.UnsafeAddr())).Elem() + rs.SetInt(int64(len(bodyBytes))) + } +} + +// makeReq 构造带可选请求头的 resty.Request +// 功能:基于客户端创建请求对象,并在传入 headers 时进行设置 +// 返回:带有请求头的请求对象 +func (request *RestyClient) makeReq(headers map[string]string, cookies []*http.Cookie) *resty.Request { + req := request.client.R() + if len(headers) > 0 { + req = req.SetHeaders(headers) + } + if len(cookies) > 0 { + req = req.SetCookies(cookies) + } + return req +} + +// doWithEncodingFallback 封装请求发送并在出现压缩相关错误时进行一次降级重试 +// 逻辑:首次请求失败且错误包含 gzip/zstd/brotli/magic number mismatch 时,设置 accept-encoding 为 identity 重试一次 +func (request *RestyClient) doWithEncodingFallback(headers map[string]string, cookies []*http.Cookie, allowRedirect bool, do func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) { + req := request.makeReq(headers, cookies) + if allowRedirect { + request.client.SetRedirectPolicy(resty.FlexibleRedirectPolicy(10)) + } else { + // 使用 http.ErrUseLastResponse 确保 302 响应被返回且 Body 可读,而不是报错 + request.client.SetRedirectPolicy(resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + })) + } + resp, err := do(req) + + // 尝试补救响应体(特别是当重定向被禁用导致报错时) + request.fillResponseBody(resp) + + if err == nil { + return resp, nil + } + s := err.Error() + if strings.Contains(s, "gzip: invalid header") || strings.Contains(s, "magic number mismatch") || strings.Contains(s, "zstd") || strings.Contains(s, "brotli") { + h2 := map[string]string{} + for k, v := range headers { + if strings.ToLower(k) != "accept-encoding" { + h2[k] = v + } + } + h2["Accept-Encoding"] = "identity" + req2 := request.makeReq(h2, cookies) + resp2, err2 := do(req2) + request.fillResponseBody(resp2) + if err2 == nil { + return resp2, nil + } + } + return resp, err +} + +// decodeResponse 处理响应解压与 JSON 解析 +// 功能:自动识别 gzip 压缩并解压;在 result 非空时按 JSON 解析到 result +// 返回:解析错误(成功时为 nil) +func (request *RestyClient) decodeResponse(resp *resty.Response, result interface{}) error { + if resp == nil { + return nil + } + ct := strings.ToLower(resp.Header().Get("Content-Type")) + ce := strings.ToLower(resp.Header().Get("Content-Encoding")) + body := resp.Body() + if strings.Contains(ce, "gzip") && len(body) > 0 { + gr, gerr := gzip.NewReader(bytes.NewReader(body)) + if gerr == nil { + defer gr.Close() + if dec, derr := io.ReadAll(gr); derr == nil { + body = dec + resp.SetBody(body) + } + } + } else if strings.Contains(ce, "deflate") && len(body) > 0 { + // 处理 deflate 压缩 + dr := flate.NewReader(bytes.NewReader(body)) + defer dr.Close() + if dec, derr := io.ReadAll(dr); derr == nil { + body = dec + resp.SetBody(body) + } + } else if strings.Contains(ce, "br") && len(body) > 0 { + // 处理 brotli 压缩 + br := brotli.NewReader(bytes.NewReader(body)) + if dec, derr := io.ReadAll(br); derr == nil { + body = dec + resp.SetBody(body) // 将解压后的 body 写回 response + } + } + if result != nil && (strings.Contains(ct, "application/json") || json.Valid(body)) { + if err := json.Unmarshal(body, result); err != nil { + return err + } + } + return nil +} + +// RestyGet 发送 GET 请求 +func (request *RestyClient) RestyGet(path string, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { + resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { + return r.Get(path) + }) + if resp == nil && err != nil { + return nil, err + } + + if err := request.decodeResponse(resp, result); err != nil { + return nil, err + } + + return resp, err +} + +// RestyPost 发送 POST 请求 +func (request *RestyClient) RestyPost(path string, data any, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { + resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(data).Post(path) + }) + if resp == nil && err != nil { + return nil, err + } + + if err := request.decodeResponse(resp, result); err != nil { + return nil, err + } + + return resp, err +} + +// RestyPut 发送 PUT 请求 +// 功能:发送 PUT,支持请求级 headers 覆盖客户端默认,自动识别 gzip 并解析 JSON +// 返回:响应对象与错误信息 +func (request *RestyClient) RestyPut(path string, data any, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { + resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(data).Put(path) + }) + if resp == nil && err != nil { + return nil, err + } + + if err := request.decodeResponse(resp, result); err != nil { + return nil, err + } + + return resp, err +} + +// RestyPatch 发送 PATCH 请求 +// 功能:发送 PATCH,支持请求级 headers 覆盖客户端默认,自动识别 gzip 并解析 JSON +// 返回:响应对象与错误信息 +func (request *RestyClient) RestyPatch(path string, data any, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { + resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { + return r.SetBody(data).Patch(path) + }) + if resp == nil && err != nil { + return nil, err + } + + if err := request.decodeResponse(resp, result); err != nil { + return nil, err + } + + return resp, err +} + +// RestyDelete 发送 DELETE 请求 +// 功能:发送 DELETE,支持请求级 headers 覆盖客户端默认,自动识别 gzip 并解析 JSON +// 返回:响应对象与错误信息 +func (request *RestyClient) RestyDelete(path string, result interface{}, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { + resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { + return r.Delete(path) + }) + if resp == nil && err != nil { + return nil, err + } + + if err := request.decodeResponse(resp, result); err != nil { + return nil, err + } + + return resp, err +} + +// RestyHead 发送 HEAD 请求 +// 功能:发送 HEAD,支持请求级 headers 覆盖客户端默认;HEAD 通常无正文 +// 返回:响应对象与错误信息 +func (request *RestyClient) RestyHead(path string, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { + resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { + return r.Head(path) + }) + if resp == nil && err != nil { + return nil, err + } + return resp, err +} + +// RestyOptions 发送 OPTIONS 请求 +// 功能:发送 OPTIONS,支持请求级 headers 覆盖客户端默认 +// 返回:响应对象与错误信息 +func (request *RestyClient) RestyOptions(path string, headers map[string]string, cookies []*http.Cookie, allowRedirect bool) (*resty.Response, error) { + resp, err := request.doWithEncodingFallback(headers, cookies, allowRedirect, func(r *resty.Request) (*resty.Response, error) { + return r.Options(path) + }) + if resp == nil && err != nil { + return nil, err + } + return resp, err +} diff --git a/utils/path.go b/utils/path.go new file mode 100644 index 0000000..9460310 --- /dev/null +++ b/utils/path.go @@ -0,0 +1,64 @@ +package utils + +import ( + "os" + "path/filepath" + "strings" +) + +// GetRootDir 获取当前程序运行的真实根目录 +// 能够智能、跨平台地识别是编译后的可执行文件运行,还是通过 `go run` 运行(通常在临时目录下) +func GetRootDir() string { + var baseDir string + + // 首先尝试获取当前工作目录 + workDir, err := os.Getwd() + if err != nil { + workDir = "." + } + + // 获取程序可执行文件所在目录 + execPath, err := os.Executable() + if err != nil { + // 如果获取可执行文件路径失败,使用当前工作目录 + return workDir + } + + // 解析软链接,获取真实物理路径(macOS 下 /tmp 经常是 /private/tmp 的软链) + realExecPath, err := filepath.EvalSymlinks(execPath) + if err == nil { + execPath = realExecPath + } + execDir := filepath.Dir(execPath) + + realTempDir, err := filepath.EvalSymlinks(os.TempDir()) + if err != nil { + realTempDir = os.TempDir() + } + + // 跨平台安全地判断 execDir 是否在 realTempDir 内部 + // 使用 filepath.Rel 可以避免直接 HasPrefix 带来的大小写、路径分隔符以及部分目录名重合的问题 + rel, err := filepath.Rel(realTempDir, execDir) + isGoRun := false + if err == nil { + // 如果 rel 不以 ".." 开头,说明 execDir 在 TempDir 内部,即为 go run 模式 + if rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator)) { + isGoRun = true + } + } else { + // fallback: 如果 Rel 失败(例如跨盘符),则退回简单的 HasPrefix 判断(带上分隔符防误判) + cleanTemp := filepath.Clean(realTempDir) + string(os.PathSeparator) + cleanExec := filepath.Clean(execDir) + string(os.PathSeparator) + if strings.HasPrefix(strings.ToLower(cleanExec), strings.ToLower(cleanTemp)) { + isGoRun = true + } + } + + if isGoRun { + baseDir = workDir + } else { + baseDir = execDir + } + + return baseDir +}