diff --git a/README.md b/README.md index 0d9526f..bcdf9c8 100644 --- a/README.md +++ b/README.md @@ -42,10 +42,14 @@ func main() { os.Exit(1) } + examplePathRegex := regexp.MustCompile(`^/api/v1/.*`) + cacheClient, err := cache.NewClient( cache.ClientWithAdapter(memcached), cache.ClientWithTTL(10 * time.Minute), cache.ClientWithRefreshKey("opn"), + cache.ClientWithSkipCacheResponseHeader("x-skip-example"), + cache.ClientWithSkipCacheUriPathRegex(examplePathRegex) ) if err != nil { fmt.Println(err) @@ -73,15 +77,44 @@ import ( "server": ":6379", }, } + + examplePathRegex := regexp.MustCompile(`^/api/v1/.*`) + cacheClient, err := cache.NewClient( cache.ClientWithAdapter(redis.NewAdapter(ringOpt)), cache.ClientWithTTL(10 * time.Minute), cache.ClientWithRefreshKey("opn"), + cache.ClientWithSkipCacheResponseHeader("x-skip-example"), + cache.ClientWithSkipCacheUriPathRegex(examplePathRegex) ) ... ``` +Example of handler func skipping cache using response header +```go +... + cacheClient, err := cache.NewClient( + cache.ClientWithAdapter(memcached), + cache.ClientWithTTL(10 * time.Minute), + cache.ClientWithSkipCacheResponseHeader("x-skip-example"), + ) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + example := func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Skip-Example", "1") + w.Write([]byte(fmt.Sprintf("response value at %s", time.Now().UTC().String()))) + } + + handler := http.HandlerFunc(example) +... +``` + +``` + ## Benchmarks The benchmarks were based on [allegro/bigache](https://github.com/allegro/bigcache) tests and used to compare it with the http-cache memory adapter.
The tests were run using an Intel i5-2410M with 8GB RAM on Arch Linux 64bits.
diff --git a/adapter/redis/redis_test.go b/adapter/redis/redis_test.go index 635344c..b098159 100644 --- a/adapter/redis/redis_test.go +++ b/adapter/redis/redis_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/victorspringer/http-cache" + cache "github.com/victorspringer/http-cache" ) var a cache.Adapter diff --git a/cache.go b/cache.go index 0bb2638..8d249b6 100644 --- a/cache.go +++ b/cache.go @@ -34,6 +34,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "regexp" "sort" "strconv" "strings" @@ -62,11 +63,13 @@ type Response struct { // Client data structure for HTTP cache middleware. type Client struct { - adapter Adapter - ttl time.Duration - refreshKey string - methods []string - writeExpiresHeader bool + adapter Adapter + ttl time.Duration + refreshKey string + skipCacheResponseHeader string + skipCacheUriPathRegex *regexp.Regexp + methods []string + writeExpiresHeader bool } // ClientOption is used to set Client settings. @@ -88,7 +91,8 @@ type Adapter interface { // Middleware is the HTTP cache middleware handler. func (c *Client) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if c.cacheableMethod(r.Method) { + + if c.cacheableUriPath(r.URL) && c.cacheableMethod(r.Method) { sortURLParams(r.URL) key := generateKey(r.URL.String()) if r.Method == http.MethodPost && r.Body != nil { @@ -138,27 +142,36 @@ func (c *Client) Middleware(next http.Handler) http.Handler { rec := httptest.NewRecorder() next.ServeHTTP(rec, r) result := rec.Result() + headers := result.Header statusCode := result.StatusCode value := rec.Body.Bytes() - now := time.Now() - expires := now.Add(c.ttl) - if statusCode < 400 { - response := Response{ - Value: value, - Header: result.Header, - Expiration: expires, - LastAccess: now, - Frequency: 1, + + skipCachingResponse := headers.Get(c.skipCacheResponseHeader) != "" + + if !skipCachingResponse { + + now := time.Now() + expires := now.Add(c.ttl) + if statusCode < 400 { + response := Response{ + Value: value, + Header: result.Header, + Expiration: expires, + LastAccess: now, + Frequency: 1, + } + c.adapter.Set(key, response.Bytes(), response.Expiration) + } + if c.writeExpiresHeader { + w.Header().Set("Expires", expires.UTC().Format(http.TimeFormat)) } - c.adapter.Set(key, response.Bytes(), response.Expiration) + } + for k, v := range result.Header { w.Header().Set(k, strings.Join(v, ",")) } - if c.writeExpiresHeader { - w.Header().Set("Expires", expires.UTC().Format(http.TimeFormat)) - } w.WriteHeader(statusCode) w.Write(value) return @@ -176,6 +189,20 @@ func (c *Client) cacheableMethod(method string) bool { return false } +// cacheableUriPath takes the request url and see if it +// matches regex used for skipping cache based on request +// path +func (c *Client) cacheableUriPath(requestUrl *url.URL) bool { + + if c.skipCacheUriPathRegex == nil { + return true + } + + foundMatchingUriPath := c.skipCacheUriPathRegex.FindString(requestUrl.Path) + + return foundMatchingUriPath == "" +} + // BytesToResponse converts bytes array into Response data structure. func BytesToResponse(b []byte) Response { var r Response @@ -279,6 +306,27 @@ func ClientWithRefreshKey(refreshKey string) ClientOption { } } +// ClientWithSkipCacheResponseHeader sets the name of the response header +// that will be used to ensure a response does not get cached. +// Optional setting. +func ClientWithSkipCacheResponseHeader(headerName string) ClientOption { + return func(c *Client) error { + c.skipCacheResponseHeader = headerName + return nil + } +} + +// ClientWithSkipCacheUriPathRegex sets the regex that will be +// used to ensure that both request/response of matching path +// is free of cache. +// Optional setting. +func ClientWithSkipCacheUriPathRegex(uriPathRegex *regexp.Regexp) ClientOption { + return func(c *Client) error { + c.skipCacheUriPathRegex = uriPathRegex + return nil + } +} + // ClientWithMethods sets the acceptable HTTP methods to be cached. // Optional setting. If not set, default is "GET". func ClientWithMethods(methods []string) ClientOption { diff --git a/cache_test.go b/cache_test.go index 8ebca53..d98e76a 100644 --- a/cache_test.go +++ b/cache_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/url" "reflect" + "regexp" "sync" "testing" "time" @@ -48,6 +49,10 @@ func (errReader) Read(p []byte) (n int, err error) { func TestMiddleware(t *testing.T) { counter := 0 httpTestHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if q := r.URL.Query()["set-skip-header"]; len(q) > 0 { + w.Header().Add("X-Skip", "1") + } + w.Write([]byte(fmt.Sprintf("new value %v", counter))) }) @@ -72,28 +77,42 @@ func TestMiddleware(t *testing.T) { }, } + exampleRegex := regexp.MustCompile("^/test-4$") + client, _ := NewClient( ClientWithAdapter(adapter), ClientWithTTL(1*time.Minute), ClientWithRefreshKey("rk"), ClientWithMethods([]string{http.MethodGet, http.MethodPost}), + ClientWithSkipCacheResponseHeader("X-Skip"), + ClientWithSkipCacheUriPathRegex(exampleRegex), ) - handler := client.Middleware(httpTestHandler) + handlers := http.ServeMux{} + handlers.Handle("/test-1", httpTestHandler) + handlers.Handle("/test-2", httpTestHandler) + handlers.Handle("/test-3", httpTestHandler) + handlers.Handle("/test-4", httpTestHandler) + + handler := client.Middleware(&handlers) tests := []struct { - name string - url string - method string - body []byte - wantBody string - wantCode int + name string + url string + method string + body []byte + setSkipHeader bool + skipPath string + wantBody string + wantCode int }{ { "returns cached response", "http://foo.bar/test-1", "GET", nil, + false, + "", "value 1", 200, }, @@ -102,6 +121,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "PUT", nil, + false, + "", "new value 2", 200, }, @@ -110,6 +131,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "GET", nil, + false, + "", "value 2", 200, }, @@ -118,6 +141,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-3?zaz=baz&baz=zaz", "GET", nil, + false, + "", "new value 4", 200, }, @@ -126,6 +151,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-3?baz=zaz&zaz=baz", "GET", nil, + false, + "", "new value 4", 200, }, @@ -134,6 +161,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-3", "GET", nil, + false, + "", "new value 6", 200, }, @@ -142,6 +171,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2?rk=true", "GET", nil, + false, + "", "new value 7", 200, }, @@ -150,6 +181,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "GET", nil, + false, + "", "new value 7", 200, }, @@ -158,6 +191,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "POST", []byte(`{"foo": "bar"}`), + false, + "", "new value 9", 200, }, @@ -166,6 +201,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "POST", []byte(`{"foo": "bar"}`), + false, + "", "new value 9", 200, }, @@ -174,6 +211,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "GET", []byte(`{"foo": "bar"}`), + false, + "", "new value 7", 200, }, @@ -182,9 +221,61 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "POST", []byte(`{"foo": "bar"}`), + false, + "", "new value 12", 200, }, + { + "skip cached using header - new uncached response", + "http://foo.bar/test-2?set-skip-header=1", + "GET", + nil, + false, + "", + "new value 13", + 200, + }, + { + "skip cached using header - new uncached response (confirm)", + "http://foo.bar/test-2?set-skip-header=1", + "GET", + nil, + false, + "", + "new value 14", + 200, + }, + { + "skip cached using header - confirm didn't change cached value", + "http://foo.bar/test-2", + "GET", + nil, + false, + "", + "new value 7", + 200, + }, + { + "skip cache by regex path - returns new uncached response", + "http://foo.bar/test-4", + "GET", + nil, + false, + "", + "new value 16", + 200, + }, + { + "skip cache by regex path - returns new uncached response (confirm)", + "http://foo.bar/test-4", + "GET", + nil, + false, + "", + "new value 17", + 200, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {