diff --git a/README.md b/README.md index 1092e22..edee043 100644 --- a/README.md +++ b/README.md @@ -26,15 +26,18 @@ _Please note: this is not a replacement for `http.Client`, but rather a companio `go get -u github.com/go-pkgz/requester` - -## `Requester` middlewares - ## Overview +*Built-in middlewares:* + - `Header` - appends user-defined headers to all requests. +- `MaxConcurrent` - sets maximum concurrency +- `Retry` - sets retry on errors and status codes - `JSON` - sets headers `"Content-Type": "application/json"` and `"Accept": "application/json"` - `BasicAuth(user, passwd string)` - adds HTTP Basic Authentication -- `MaxConcurrent` - sets maximum concurrency + +*Interfaces for external middlewares:* + - `Repeater` - sets repeater to retry failed requests. Doesn't provide repeater implementation but wraps it. Compatible with any repeater (for example [go-pkgz/repeater](https://github.com/go-pkgz/repeater)) implementing a single method interface `Do(ctx context.Context, fun func() error, errors ...error) (err error)` interface. - `Cache` - sets any `LoadingCache` implementation to be used for request/response caching. Doesn't provide cache, but wraps it. Compatible with any cache (for example a family of caches from [go-pkgz/lcw](https://github.com/go-pkgz/lcw)) implementing a single-method interface `Get(key string, fn func() (interface{}, error)) (val interface{}, err error)` - `Logger` - sets logger, compatible with any implementation of a single-method interface `Logf(format string, args ...interface{})`, for example [go-pkgz/lgr](https://github.com/go-pkgz/lgr) @@ -45,7 +48,131 @@ Convenient functional adapter `middleware.RoundTripperFunc` provided. See examples of the usage in [_example](https://github.com/go-pkgz/requester/tree/master/_example) -### Logging +### Header middleware + +`Header` middleware adds user-defined headers to all requests. It expects a map of headers to be added. For example: + +```go +rq := requester.New(http.Client{}, middleware.Header("X-Auth", "123456789")) +``` +### MaxConcurrent middleware + +`MaxConcurrent` middleware can be used to limit the concurrency of a given requester and limit overall concurrency for multiple requesters. For the first case, `MaxConcurrent(N)` should be created in the requester chain of middlewares. For example, `rq := requester.New(http.Client{Timeout: 3 * time.Second}, middleware.MaxConcurrent(8))`. To make it global, `MaxConcurrent` should be created once, outside the chain, and passed into each requester. For example: + +```go +mc := middleware.MaxConcurrent(16) +rq1 := requester.New(http.Client{Timeout: 3 * time.Second}, mc) +rq2 := requester.New(http.Client{Timeout: 1 * time.Second}, middleware.JSON, mc) +``` +### Retry middleware + +Retry middleware provides a flexible retry mechanism with different backoff strategies. By default, it retries on network errors and 5xx responses. + +```go +// retry 3 times with exponential backoff, starting from 100ms +rq := requester.New(http.Client{}, middleware.Retry(3, 100*time.Millisecond)) + +// retry with custom options +rq := requester.New(http.Client{}, middleware.Retry(3, 100*time.Millisecond, + middleware.RetryWithBackoff(middleware.BackoffLinear), // use linear backoff + middleware.RetryMaxDelay(5*time.Second), // cap maximum delay + middleware.RetryWithJitter(0.1), // add 10% randomization + middleware.RetryOnCodes(503, 502), // retry only on specific codes + // or middleware.RetryExcludeCodes(404, 401), // alternatively, retry on all except these codes +)) +``` + +Default configuration: +- 3 attempts +- Initial delay: 100ms +- Max delay: 30s +- Exponential backoff +- 10% jitter +- Retries on 5xx status codes + +Retry Options: +- `RetryWithBackoff(t BackoffType)` - set backoff strategy (Constant, Linear, or Exponential) +- `RetryMaxDelay(d time.Duration)` - cap the maximum delay between retries +- `RetryWithJitter(f float64)` - add randomization to delays (0-1.0 factor) +- `RetryOnCodes(codes ...int)` - retry only on specific status codes +- `RetryExcludeCodes(codes ...int)` - retry on all codes except specified + +Note: `RetryOnCodes` and `RetryExcludeCodes` are mutually exclusive and can't be used together. + +### Cache middleware + +Cache middleware provides an **in-memory caching layer** for HTTP responses. It improves performance by avoiding repeated network calls for the same request. + +#### **Basic Usage** + +```go +rq := requester.New(http.Client{}, middleware.Cache()) +``` + +By default: + +- Only GET requests are cached +- TTL (Time-To-Live) is 5 minutes +- Maximum cache size is 1000 entries +- Caches only HTTP 200 responses + + +#### **Cache Configuration Options** + +```go +rq := requester.New(http.Client{}, middleware.Cache( + middleware.CacheTTL(10*time.Minute), // change TTL to 10 minutes + middleware.CacheSize(500), // limit cache to 500 entries + middleware.CacheMethods(http.MethodGet, http.MethodPost), // allow caching for GET and POST + middleware.CacheStatuses(200, 201, 204), // cache only responses with these status codes + middleware.CacheWithBody, // include request body in cache key + middleware.CacheWithHeaders("Authorization", "X-Custom-Header"), // include selected headers in cache key +)) +``` + +#### Cache Key Composition + +By default, the cache key is generated using: + +- HTTP **method** +- Full **URL** +- (Optional) **Headers** (if `CacheWithHeaders` is enabled) +- (Optional) **Body** (if `CacheWithBody` is enabled) + +For example, enabling `CacheWithHeaders("Authorization")` will cache the same URL differently **for each unique Authorization token**. + +#### Cache Eviction Strategy + +- **Entries expire** when the TTL is reached. +- **If the cache reaches its maximum size**, the **oldest entry is evicted** (FIFO order). + + +#### Cache Limitations + +- **Only caches complete HTTP responses.** Streaming responses are **not** supported. +- **Does not cache responses with status codes other than 200** (unless explicitly allowed). +- **Uses in-memory storage**, meaning the cache **resets on application restart**. + + +### JSON middleware + +`JSON` middleware sets headers `"Content-Type": "application/json"` and `"Accept": "application/json"`. + +```go +rq := requester.New(http.Client{}, middleware.JSON) +``` + +### BasicAuth middleware + +`BasicAuth` middleware adds HTTP Basic Authentication to all requests. It expects a username and password. For example: + +```go +rq := requester.New(http.Client{}, middleware.BasicAuth("user", "passwd")) +``` + +---- + +### Logging middleware interface Logger should implement `Logger` interface with a single method `Logf(format string, args ...interface{})`. For convenience, func type `LoggerFunc` is provided as an adapter to allow the use of ordinary functions as `Logger`. @@ -63,19 +190,11 @@ logging options: Note: If logging is allowed, it will log the URL, method, and may log headers and the request body. It may affect application security. For example, if a request passes some sensitive information as part of the body or header. In this case, consider turning logging off or providing your own logger to suppress all that you need to hide. -### MaxConcurrent - -MaxConcurrent middleware can be used to limit the concurrency of a given requester and limit overall concurrency for multiple requesters. For the first case, `MaxConcurrent(N)` should be created in the requester chain of middlewares. For example, `rq := requester.New(http.Client{Timeout: 3 * time.Second}, middleware.MaxConcurrent(8))`. To make it global, `MaxConcurrent` should be created once, outside the chain, and passed into each requester. For example: - -```go -mc := middleware.MaxConcurrent(16) -rq1 := requester.New(http.Client{Timeout: 3 * time.Second}, mc) -rq2 := requester.New(http.Client{Timeout: 1 * time.Second}, middleware.JSON, mc) -``` If the request is limited, it will wait till the limit is released. -### Cache +### Cache middleware interface + Cache expects the `LoadingCache` interface to implement a single method: `Get(key string, fn func() (interface{}, error)) (val interface{}, err error)`. [LCW](https://github.com/go-pkgz/lcw/) can be used directly, and in order to adopt other caches, see the provided `LoadingCacheFunc`. #### Caching Key and Allowed Requests @@ -92,12 +211,11 @@ Several options define what part of the request will be used for the key: example: `cache.New(lruCache, cache.Methods("GET", "POST"), cache.KeyFunc() {func(r *http.Request) string {return r.Host})` - #### cache and streaming response `Cache` is **not compatible** with HTTP streaming mode. Practically, this is rare and exotic, but allowing `Cache` will effectively transform the streaming response into a "get it all" typical response. This is due to the fact that the cache has to read the response body fully to save it, so technically streaming will be working, but the client will receive all the data at once. -### Repeater +### Repeater middleware interface `Repeater` expects a single method interface `Do(fn func() error, failOnCodes ...error) (err error)`. [repeater](github.com/go-pkgz/repeater) can be used directly. diff --git a/middleware/cache.go b/middleware/cache.go new file mode 100644 index 0000000..ca75773 --- /dev/null +++ b/middleware/cache.go @@ -0,0 +1,221 @@ +package middleware + +import ( + "bytes" + "crypto/sha256" + "fmt" + "io" + "net/http" + "sort" + "strings" + "sync" + "time" +) + +// CacheEntry represents a cached response with metadata +type CacheEntry struct { + body []byte + headers http.Header + status int + createdAt time.Time +} + +// CacheMiddleware implements in-memory cache for HTTP responses with TTL-based eviction +type CacheMiddleware struct { + next http.RoundTripper + ttl time.Duration + maxKeys int + includeBody bool + headers []string + allowedCodes []int + allowedMethods []string + + cache map[string]CacheEntry + keys []string // Maintains insertion order + mu sync.Mutex +} + +// RoundTrip implements http.RoundTripper +func (c *CacheMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { + // check if method is allowed + methodAllowed := false + for _, m := range c.allowedMethods { + if req.Method == m { + methodAllowed = true + break + } + } + if !methodAllowed { + return c.next.RoundTrip(req) + } + + key := c.makeKey(req) // generate cache key based on request + + c.mu.Lock() + // remove expired entries + for len(c.keys) > 0 { + oldestKey := c.keys[0] + if time.Since(c.cache[oldestKey].createdAt) < c.ttl { + break // Stop once we find a non-expired entry + } + delete(c.cache, oldestKey) + c.keys = c.keys[1:] + } + // check cache + entry, found := c.cache[key] + c.mu.Unlock() + + if found { + // cache hit - reconstruct response + return &http.Response{ + Status: fmt.Sprintf("%d %s", entry.status, http.StatusText(entry.status)), + StatusCode: entry.status, + Header: entry.headers.Clone(), + Body: io.NopCloser(bytes.NewReader(entry.body)), + Request: req, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: int64(len(entry.body)), + }, nil + } + + // fetch fresh response + resp, err := c.next.RoundTrip(req) + if err != nil { + return resp, err + } + + // check if response code is allowed for caching + if !c.shouldCache(resp.StatusCode) { + return resp, nil + } + + // read and store response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return resp, err + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(body)) + + // store in cache + c.mu.Lock() + defer c.mu.Unlock() + + // evict oldest if maxKeys reached + if len(c.cache) >= c.maxKeys { + oldestKey := c.keys[0] + delete(c.cache, oldestKey) + c.keys = c.keys[1:] + } + + // store new entry + c.cache[key] = CacheEntry{body: body, headers: resp.Header.Clone(), status: resp.StatusCode, createdAt: time.Now()} + c.keys = append(c.keys, key) // maintain order of keys for LRU eviction + + return resp, nil +} + +// makeKey generates a cache key based on the request details +func (c *CacheMiddleware) makeKey(req *http.Request) string { + var sb strings.Builder + sb.WriteString(req.Method) + sb.WriteString(":") + sb.WriteString(req.URL.String()) + + if c.includeBody && req.Body != nil { + body, err := io.ReadAll(req.Body) + if err == nil { + sb.Write(body) + req.Body = io.NopCloser(bytes.NewReader(body)) + } + } + + if len(c.headers) > 0 { + var headers []string + for _, h := range c.headers { + if vals := req.Header.Values(h); len(vals) > 0 { + headers = append(headers, fmt.Sprintf("%s:%s", h, strings.Join(vals, ","))) + } + } + sort.Strings(headers) + sb.WriteString(strings.Join(headers, "||")) + } + + hash := sha256.Sum256([]byte(sb.String())) + return fmt.Sprintf("%x", hash) +} + +func (c *CacheMiddleware) shouldCache(code int) bool { + for _, allowed := range c.allowedCodes { + if code == allowed { + return true + } + } + return false +} + +// Cache creates caching middleware with provided options +func Cache(opts ...CacheOption) RoundTripperHandler { + return func(next http.RoundTripper) http.RoundTripper { + c := &CacheMiddleware{ + next: next, + ttl: 5 * time.Minute, + maxKeys: 1000, + allowedCodes: []int{200}, + allowedMethods: []string{http.MethodGet}, + cache: make(map[string]CacheEntry), + keys: make([]string, 0, 1000), + } + + for _, opt := range opts { + opt(c) + } + + return c + } +} + +// CacheOption represents cache middleware options +type CacheOption func(c *CacheMiddleware) + +// CacheTTL sets cache TTL +func CacheTTL(ttl time.Duration) CacheOption { + return func(c *CacheMiddleware) { + c.ttl = ttl + } +} + +// CacheSize sets maximum number of cached entries +func CacheSize(size int) CacheOption { + return func(c *CacheMiddleware) { + c.maxKeys = size + } +} + +// CacheWithBody includes request body in cache key +func CacheWithBody(c *CacheMiddleware) { + c.includeBody = true +} + +// CacheWithHeaders includes specified headers in cache key +func CacheWithHeaders(headers ...string) CacheOption { + return func(c *CacheMiddleware) { + c.headers = headers + } +} + +// CacheStatuses sets which response status codes should be cached +func CacheStatuses(codes ...int) CacheOption { + return func(c *CacheMiddleware) { + c.allowedCodes = codes + } +} + +// CacheMethods sets which HTTP methods should be cached +func CacheMethods(methods ...string) CacheOption { + return func(c *CacheMiddleware) { + c.allowedMethods = methods + } +} diff --git a/middleware/cache_test.go b/middleware/cache_test.go new file mode 100644 index 0000000..b478e9b --- /dev/null +++ b/middleware/cache_test.go @@ -0,0 +1,451 @@ +package middleware + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/requester/middleware/mocks" +) + +func TestCache_BasicCaching(t *testing.T) { + t.Run("caches GET request", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Header: http.Header{"X-Test": []string{"value"}}, + Body: io.NopCloser(strings.NewReader("response body")), + }, nil + }} + + h := Cache()(rmock) + req, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody) + require.NoError(t, err) + + // first request - cache miss + resp1, err := h.RoundTrip(req) + require.NoError(t, err) + body1, err := io.ReadAll(resp1.Body) + require.NoError(t, err) + _ = resp1.Body.Close() + + // second request - should be cached + resp2, err := h.RoundTrip(req) + require.NoError(t, err) + body2, err := io.ReadAll(resp2.Body) + require.NoError(t, err) + _ = resp2.Body.Close() + + assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount)) + assert.Equal(t, 200, resp1.StatusCode) + assert.Equal(t, 200, resp2.StatusCode) + assert.Equal(t, "response body", string(body1)) + assert.Equal(t, "response body", string(body2)) + assert.Equal(t, "value", resp1.Header.Get("X-Test")) + assert.Equal(t, "value", resp2.Header.Get("X-Test")) + }) + + t.Run("does not cache POST by default", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("response body")), + }, nil + }} + + h := Cache()(rmock) + req, err := http.NewRequest(http.MethodPost, "http://example.com/", http.NoBody) + require.NoError(t, err) + + // make two requests + _, err = h.RoundTrip(req) + require.NoError(t, err) + _, err = h.RoundTrip(req) + require.NoError(t, err) + + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount)) + }) + + t.Run("does not cache non-200 by default", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 404, + Body: io.NopCloser(strings.NewReader("not found")), + }, nil + }} + + h := Cache()(rmock) + req, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody) + require.NoError(t, err) + + // make two requests + _, err = h.RoundTrip(req) + require.NoError(t, err) + _, err = h.RoundTrip(req) + require.NoError(t, err) + + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount)) + }) +} + +func TestCache_Options(t *testing.T) { + + t.Run("respects TTL", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("response body")), + }, nil + }} + + h := Cache(CacheTTL(50 * time.Millisecond))(rmock) + req, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody) + require.NoError(t, err) + + // first request + _, err = h.RoundTrip(req) + require.NoError(t, err) + + // second request - should be cached + _, err = h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount)) + + // wait for TTL to expire + time.Sleep(100 * time.Millisecond) + + // third request - should hit the backend + _, err = h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount)) + }) + + t.Run("respects cache size", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Header: http.Header{"X-Test": []string{"value"}}, + Body: io.NopCloser(strings.NewReader(fmt.Sprintf("response %d", atomic.AddInt32(&requestCount, 1)))), + }, nil + }} + + h := Cache(CacheSize(2))(rmock) + + // first request - should be cached + req1, _ := http.NewRequest(http.MethodGet, "http://example.com/1", http.NoBody) + _, _ = h.RoundTrip(req1) // first call: should hit the backend + _, _ = h.RoundTrip(req1) // second call: should be served from cache + + // second request - should be cached + req2, _ := http.NewRequest(http.MethodGet, "http://example.com/2", http.NoBody) + _, _ = h.RoundTrip(req2) // First call: should hit the backend + _, _ = h.RoundTrip(req2) // Second call: should be served from cache + + // third request - triggers eviction of first request + req3, _ := http.NewRequest(http.MethodGet, "http://example.com/3", http.NoBody) + _, _ = h.RoundTrip(req3) + + // first request should be evicted, making a new backend call + _, _ = h.RoundTrip(req1) + + assert.Equal(t, int32(4), atomic.LoadInt32(&requestCount), "First request should be evicted and re-fetched") + }) + + t.Run("respects allowed methods", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Header: http.Header{"X-Test": []string{"value"}}, + Body: io.NopCloser(strings.NewReader("response body")), + }, nil + }} + + h := Cache(CacheMethods(http.MethodGet, http.MethodPost))(rmock) + + // GET request should be cached + req1, _ := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody) + _, err := h.RoundTrip(req1) + require.NoError(t, err) + _, err = h.RoundTrip(req1) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount), "GET requests should be cached") + + // POST request should be cached + req2, _ := http.NewRequest(http.MethodPost, "http://example.com/", http.NoBody) + _, err = h.RoundTrip(req2) + require.NoError(t, err) + _, err = h.RoundTrip(req2) + require.NoError(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "POST requests should use different cache key") + + // PUT request should not be cached + req3, _ := http.NewRequest(http.MethodPut, "http://example.com/", http.NoBody) + _, err = h.RoundTrip(req3) + require.NoError(t, err) + _, err = h.RoundTrip(req3) + require.NoError(t, err) + assert.Equal(t, int32(4), atomic.LoadInt32(&requestCount), "PUT requests should not be cached") + }) +} + +func TestCache_Keys(t *testing.T) { + t.Run("different URLs get different cache entries", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(r.URL.Path)), + }, nil + }} + + h := Cache()(rmock) + + req1, err := http.NewRequest(http.MethodGet, "http://example.com/1", http.NoBody) + require.NoError(t, err) + resp1, err := h.RoundTrip(req1) + require.NoError(t, err) + body1, err := io.ReadAll(resp1.Body) + require.NoError(t, err) + err = resp1.Body.Close() + require.NoError(t, err) + + req2, err := http.NewRequest(http.MethodGet, "http://example.com/2", http.NoBody) + require.NoError(t, err) + resp2, err := h.RoundTrip(req2) + require.NoError(t, err) + body2, err := io.ReadAll(resp2.Body) + require.NoError(t, err) + err = resp2.Body.Close() + require.NoError(t, err) + + assert.Equal(t, "/1", string(body1)) + assert.Equal(t, "/2", string(body2)) + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount)) + }) + + t.Run("includes headers in cache key when configured", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(r.Header.Get("X-Test"))), + }, nil + }} + + h := Cache(CacheWithHeaders("X-Test"))(rmock) + req1, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody) + require.NoError(t, err) + req1.Header.Set("X-Test", "value1") + resp1, err := h.RoundTrip(req1) + require.NoError(t, err) + body1, err := io.ReadAll(resp1.Body) + require.NoError(t, err) + err = resp1.Body.Close() + require.NoError(t, err) + + req2, err := http.NewRequest(http.MethodGet, "http://example.com/", http.NoBody) + require.NoError(t, err) + req2.Header.Set("X-Test", "value2") + resp2, err := h.RoundTrip(req2) + require.NoError(t, err) + body2, err := io.ReadAll(resp2.Body) + require.NoError(t, err) + err = resp2.Body.Close() + require.NoError(t, err) + + assert.Equal(t, "value1", string(body1)) + assert.Equal(t, "value2", string(body2)) + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount)) + }) + + t.Run("includes body in cache key when configured", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + body, err := io.ReadAll(r.Body) + require.NoError(t, err) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(bytes.NewReader(body)), + }, nil + }} + + h := Cache(CacheWithBody)(rmock) + req1, err := http.NewRequest(http.MethodGet, "http://example.com/", strings.NewReader("body1")) + require.NoError(t, err) + resp1, err := h.RoundTrip(req1) + require.NoError(t, err) + body1, err := io.ReadAll(resp1.Body) + require.NoError(t, err) + err = resp1.Body.Close() + require.NoError(t, err) + + req2, err := http.NewRequest(http.MethodGet, "http://example.com/", strings.NewReader("body2")) + require.NoError(t, err) + resp2, err := h.RoundTrip(req2) + require.NoError(t, err) + body2, err := io.ReadAll(resp2.Body) + require.NoError(t, err) + err = resp2.Body.Close() + require.NoError(t, err) + + assert.Equal(t, "body1", string(body1)) + assert.Equal(t, "body2", string(body2)) + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount)) + }) +} + +func TestCache_EdgeCases(t *testing.T) { + + t.Run("expired cache entry should be ignored", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("fresh response")), + }, nil + }} + + h := Cache(CacheTTL(50 * time.Millisecond))(rmock) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/expired", http.NoBody) + require.NoError(t, err, "failed to create request") + + _, err = h.RoundTrip(req) + require.NoError(t, err, "first request should not fail") + assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount), "first request should hit backend") + + time.Sleep(100 * time.Millisecond) + + _, err = h.RoundTrip(req) + require.NoError(t, err, "second request should not fail") + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "expired cache entry should not be used") + }) + + t.Run("cache size 1 should evict immediately", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader("cached response")), + }, nil + }} + + h := Cache(CacheSize(1))(rmock) + + req1, err := http.NewRequest(http.MethodGet, "http://example.com/1", http.NoBody) + require.NoError(t, err, "failed to create request 1") + req2, err := http.NewRequest(http.MethodGet, "http://example.com/2", http.NoBody) + require.NoError(t, err, "failed to create request 2") + req3, err := http.NewRequest(http.MethodGet, "http://example.com/3", http.NoBody) + require.NoError(t, err, "failed to create request 3") + + _, err = h.RoundTrip(req1) + require.NoError(t, err) + + _, err = h.RoundTrip(req2) + require.NoError(t, err) + + _, err = h.RoundTrip(req3) + require.NoError(t, err) + + _, err = h.RoundTrip(req1) + require.NoError(t, err) + + assert.Equal(t, int32(4), atomic.LoadInt32(&requestCount), "each request should evict the previous one") + }) + + t.Run("only specified status codes should be cached", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 202, // not in allowed list + Body: io.NopCloser(strings.NewReader("not cached")), + }, nil + }} + + h := Cache(CacheStatuses(200, 201, 204))(rmock) + + req, err := http.NewRequest(http.MethodGet, "http://example.com/status", http.NoBody) + require.NoError(t, err, "failed to create request") + + _, err = h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, int32(1), atomic.LoadInt32(&requestCount), "first request should hit backend") + + _, err = h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "non-allowed status codes should not be cached") + }) + + t.Run("headers should be included in cache key when configured", func(t *testing.T) { + var requestCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&requestCount, 1) + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(r.Header.Get("Authorization"))), + }, nil + }} + + h := Cache(CacheWithHeaders("Authorization"))(rmock) + + req1, err := http.NewRequest(http.MethodGet, "http://example.com/auth", http.NoBody) + require.NoError(t, err, "failed to create request 1") + req1.Header.Set("Authorization", "Bearer token1") + + resp1, err := h.RoundTrip(req1) + require.NoError(t, err) + body1, err := io.ReadAll(resp1.Body) + require.NoError(t, err) + err = resp1.Body.Close() + require.NoError(t, err) + + resp2, err := h.RoundTrip(req1) // second call should hit cache + require.NoError(t, err) + body2, err := io.ReadAll(resp2.Body) + require.NoError(t, err) + err = resp2.Body.Close() + require.NoError(t, err) + + req2, err := http.NewRequest(http.MethodGet, "http://example.com/auth", http.NoBody) + require.NoError(t, err, "failed to create request 2") + req2.Header.Set("Authorization", "Bearer token2") + + resp3, err := h.RoundTrip(req2) + require.NoError(t, err) + body3, err := io.ReadAll(resp3.Body) + require.NoError(t, err) + err = resp3.Body.Close() + require.NoError(t, err) + + assert.Equal(t, "Bearer token1", string(body1), "first request should be cached separately") + assert.Equal(t, "Bearer token1", string(body2), "second request should be served from cache") + assert.Equal(t, "Bearer token2", string(body3), "third request should be a new cache entry") + assert.Equal(t, int32(2), atomic.LoadInt32(&requestCount), "each authorization header should generate a new cache entry") + }) +} diff --git a/middleware/retry.go b/middleware/retry.go new file mode 100644 index 0000000..b02939f --- /dev/null +++ b/middleware/retry.go @@ -0,0 +1,194 @@ +package middleware + +import ( + "fmt" + "math" + "math/rand" + "net/http" + "time" +) + +// BackoffType represents backoff strategy +type BackoffType int + +const ( + // BackoffConstant is a backoff strategy with constant delay + BackoffConstant BackoffType = iota + // BackoffLinear is a backoff strategy with linear delay + BackoffLinear + // BackoffExponential is a backoff strategy with exponential delay + BackoffExponential +) + +// RetryMiddleware implements a retry mechanism for http requests with configurable backoff strategies. +// It supports constant, linear and exponential backoff with optional jitter for better load distribution. +// By default retries on network errors and 5xx responses. Can be configured to retry on specific status codes +// or to exclude specific codes from retry. +// +// Default configuration: +// - 3 attempts +// - Initial delay: 100ms +// - Max delay: 30s +// - Exponential backoff +// - 10% jitter +// - Retries on 5xx status codes +type RetryMiddleware struct { + next http.RoundTripper + attempts int + initialDelay time.Duration + maxDelay time.Duration + backoff BackoffType + jitterFactor float64 + retryCodes []int + excludeCodes []int +} + +// Retry creates retry middleware with provided options +func Retry(attempts int, initialDelay time.Duration, opts ...RetryOption) RoundTripperHandler { + return func(next http.RoundTripper) http.RoundTripper { + r := &RetryMiddleware{ + next: next, + attempts: attempts, + initialDelay: initialDelay, + maxDelay: 30 * time.Second, + backoff: BackoffExponential, + jitterFactor: 0.1, + } + + for _, opt := range opts { + opt(r) + } + + if len(r.retryCodes) > 0 && len(r.excludeCodes) > 0 { + panic("retry: cannot use both RetryOnCodes and RetryExcludeCodes") + } + + return r + } +} + +// RoundTrip implements http.RoundTripper +func (r *RetryMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { + var lastResponse *http.Response + var lastError error + + for attempt := 0; attempt < r.attempts; attempt++ { + if req.Context().Err() != nil { + return nil, req.Context().Err() + } + + if attempt > 0 { + delay := r.calcDelay(attempt) + select { + case <-req.Context().Done(): + return nil, req.Context().Err() + case <-time.After(delay): + } + } + + resp, err := r.next.RoundTrip(req) + if err != nil { + lastError = err + lastResponse = resp + continue + } + + if !r.shouldRetry(resp) { + return resp, nil + } + + lastResponse = resp + } + + if lastError != nil { + return lastResponse, fmt.Errorf("retry: transport error after %d attempts: %w", r.attempts, lastError) + } + return lastResponse, nil +} + +func (r *RetryMiddleware) calcDelay(attempt int) time.Duration { + if attempt == 0 { + return 0 + } + + var delay time.Duration + switch r.backoff { + case BackoffConstant: + delay = r.initialDelay + case BackoffLinear: + delay = r.initialDelay * time.Duration(attempt) + case BackoffExponential: + delay = r.initialDelay * time.Duration(math.Pow(2, float64(attempt-1))) + } + + if delay > r.maxDelay { + delay = r.maxDelay + } + + if r.jitterFactor > 0 { + jitter := float64(delay) * r.jitterFactor + delay = time.Duration(float64(delay) + rand.Float64()*jitter - jitter/2) //nolint:gosec // week randomness is acceptable + } + + return delay +} + +func (r *RetryMiddleware) shouldRetry(resp *http.Response) bool { + if len(r.retryCodes) > 0 { + for _, code := range r.retryCodes { + if resp.StatusCode == code { + return true + } + } + return false + } + + if len(r.excludeCodes) > 0 { + for _, code := range r.excludeCodes { + if resp.StatusCode == code { + return false + } + } + return true + } + + return resp.StatusCode >= 500 +} + +// RetryOption represents option type for retry middleware +type RetryOption func(r *RetryMiddleware) + +// RetryMaxDelay sets maximum delay between retries +func RetryMaxDelay(d time.Duration) RetryOption { + return func(r *RetryMiddleware) { + r.maxDelay = d + } +} + +// RetryWithBackoff sets backoff strategy +func RetryWithBackoff(t BackoffType) RetryOption { + return func(r *RetryMiddleware) { + r.backoff = t + } +} + +// RetryWithJitter sets how much randomness to add to delay (0-1.0) +func RetryWithJitter(f float64) RetryOption { + return func(r *RetryMiddleware) { + r.jitterFactor = f + } +} + +// RetryOnCodes sets status codes that should trigger a retry +func RetryOnCodes(codes ...int) RetryOption { + return func(r *RetryMiddleware) { + r.retryCodes = codes + } +} + +// RetryExcludeCodes sets status codes that should not trigger a retry +func RetryExcludeCodes(codes ...int) RetryOption { + return func(r *RetryMiddleware) { + r.excludeCodes = codes + } +} diff --git a/middleware/retry_test.go b/middleware/retry_test.go new file mode 100644 index 0000000..49c423a --- /dev/null +++ b/middleware/retry_test.go @@ -0,0 +1,305 @@ +package middleware + +import ( + "context" + "errors" + "net/http" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/go-pkgz/requester/middleware/mocks" +) + +func TestRetry_BasicBehavior(t *testing.T) { + t.Run("retries on network error", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + count := atomic.AddInt32(&attemptCount, 1) + if count < 3 { + return nil, errors.New("network error") + } + return &http.Response{StatusCode: 200}, nil + }} + + h := Retry(3, time.Millisecond)(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + resp, err := h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, int32(3), atomic.LoadInt32(&attemptCount)) + }) + + t.Run("retries on 5xx status by default", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + count := atomic.AddInt32(&attemptCount, 1) + if count < 3 { + return &http.Response{StatusCode: 503}, nil + } + return &http.Response{StatusCode: 200}, nil + }} + + h := Retry(3, time.Millisecond)(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + resp, err := h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, int32(3), atomic.LoadInt32(&attemptCount)) + }) + + t.Run("no retry on 4xx by default", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&attemptCount, 1) + return &http.Response{StatusCode: 404}, nil + }} + + h := Retry(3, time.Millisecond)(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + resp, err := h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 404, resp.StatusCode) + assert.Equal(t, int32(1), atomic.LoadInt32(&attemptCount)) + }) + + t.Run("fails with error after max attempts", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&attemptCount, 1) + return nil, errors.New("persistent network error") + }} + + h := Retry(3, time.Millisecond)(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + resp, err := h.RoundTrip(req) + assert.Nil(t, resp) + require.Error(t, err) + assert.Contains(t, err.Error(), "retry: transport error after 3 attempts") + assert.Equal(t, int32(3), atomic.LoadInt32(&attemptCount)) + }) + + t.Run("respects request context cancellation", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&attemptCount, 1) + return nil, errors.New("network failure") + }} + + ctx, cancel := context.WithCancel(context.Background()) + req, err := http.NewRequestWithContext(ctx, "GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + h := Retry(5, 50*time.Millisecond)(rmock) + + // Cancel request after first attempt + time.AfterFunc(20*time.Millisecond, cancel) + + _, err = h.RoundTrip(req) + require.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + assert.Equal(t, int32(1), atomic.LoadInt32(&attemptCount), "should stop retrying after context cancellation") + }) +} + +func TestRetry_BackoffStrategies(t *testing.T) { + tests := []struct { + name string + backoff BackoffType + minDuration time.Duration + maxDuration time.Duration + }{ + { + name: "constant backoff", + backoff: BackoffConstant, + minDuration: 3 * time.Millisecond, // 1ms * 3 + maxDuration: 5 * time.Millisecond, // some buffer for execution time + }, + { + name: "linear backoff", + backoff: BackoffLinear, + minDuration: 6 * time.Millisecond, // 1ms + 2ms + 3ms + maxDuration: 8 * time.Millisecond, + }, + { + name: "exponential backoff", + backoff: BackoffExponential, + minDuration: 7 * time.Millisecond, // 1ms + 2ms + 4ms + maxDuration: 9 * time.Millisecond, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + count := atomic.AddInt32(&attemptCount, 1) + if count < 4 { + return &http.Response{StatusCode: 503}, nil + } + return &http.Response{StatusCode: 200}, nil + }} + + start := time.Now() + h := Retry(4, time.Millisecond, RetryWithBackoff(tt.backoff))(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + resp, err := h.RoundTrip(req) + duration := time.Since(start) + + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, int32(4), atomic.LoadInt32(&attemptCount)) + assert.GreaterOrEqual(t, duration, tt.minDuration) + assert.LessOrEqual(t, duration, tt.maxDuration) + }) + } + + t.Run("max delay limits backoff", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + atomic.AddInt32(&attemptCount, 1) + return &http.Response{StatusCode: 503}, nil + }} + + start := time.Now() + h := Retry(3, 10*time.Millisecond, + RetryMaxDelay(15*time.Millisecond), + RetryWithBackoff(BackoffExponential), + )(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + _, _ = h.RoundTrip(req) + duration := time.Since(start) + + // With exponential backoff and 10ms initial delay, without max delay + // it would be 10ms + 20ms + 40ms = 70ms, but with max delay of 15ms + // it should be 10ms + 15ms + 15ms = 40ms + assert.Less(t, duration, 50*time.Millisecond) + }) + + t.Run("jitter factor affects delay", func(t *testing.T) { + var callTimes []time.Time + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + callTimes = append(callTimes, time.Now()) + return &http.Response{StatusCode: 503}, nil + }} + + h := Retry(3, 10*time.Millisecond, + RetryWithJitter(0.5), + RetryWithBackoff(BackoffConstant), + )(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + _, _ = h.RoundTrip(req) + + require.Greater(t, len(callTimes), 2) + delay1 := callTimes[1].Sub(callTimes[0]) + delay2 := callTimes[2].Sub(callTimes[1]) + // With 0.5 jitter and 10ms delay, delays should be different + assert.NotEqual(t, delay1, delay2) + // But both should be in range 5ms-15ms (10ms ±50%) + assert.Greater(t, delay1, 5*time.Millisecond) + assert.Less(t, delay1, 15*time.Millisecond) + assert.Greater(t, delay2, 5*time.Millisecond) + assert.Less(t, delay2, 15*time.Millisecond) + }) + + t.Run("verifies retry actually introduces delay", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + count := atomic.AddInt32(&attemptCount, 1) + if count < 4 { + return &http.Response{StatusCode: 503}, nil + } + return &http.Response{StatusCode: 200}, nil + }} + + start := time.Now() + h := Retry(4, 5*time.Millisecond, RetryWithBackoff(BackoffExponential))(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + resp, err := h.RoundTrip(req) + duration := time.Since(start) + + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, int32(4), atomic.LoadInt32(&attemptCount)) + + // expected delay: 5ms + 10ms + 20ms = 35ms (exponential backoff) + expectedMin := 30 * time.Millisecond + expectedMax := 40 * time.Millisecond + + assert.Greater(t, duration, expectedMin, "retries should introduce actual delay") + assert.LessOrEqual(t, duration, expectedMax, "delay should not exceed expected range") + }) +} + +func TestRetry_RetryConditions(t *testing.T) { + t.Run("retry specific codes", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + count := atomic.AddInt32(&attemptCount, 1) + if count < 3 { + return &http.Response{StatusCode: 418}, nil // teapot error + } + return &http.Response{StatusCode: 200}, nil + }} + + h := Retry(3, time.Millisecond, RetryOnCodes(418))(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + resp, err := h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, int32(3), atomic.LoadInt32(&attemptCount)) + }) + + t.Run("exclude codes from retry", func(t *testing.T) { + var attemptCount int32 + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + count := atomic.AddInt32(&attemptCount, 1) + if count < 3 { + return &http.Response{StatusCode: 404}, nil + } + return &http.Response{StatusCode: 200}, nil + }} + + h := Retry(3, time.Millisecond, RetryExcludeCodes(503, 404))(rmock) + req, err := http.NewRequest("GET", "http://example.com/", http.NoBody) + require.NoError(t, err) + + resp, err := h.RoundTrip(req) + require.NoError(t, err) + assert.Equal(t, 404, resp.StatusCode) + assert.Equal(t, int32(1), atomic.LoadInt32(&attemptCount)) + }) + + t.Run("cannot use both include and exclude codes", func(t *testing.T) { + assert.Panics(t, func() { + rmock := &mocks.RoundTripper{RoundTripFunc: func(r *http.Request) (*http.Response, error) { + return &http.Response{StatusCode: 200}, nil + }} + _ = Retry(3, time.Millisecond, + RetryOnCodes(503), + RetryExcludeCodes(404), + )(rmock) + }) + }) +}