diff --git a/cache.go b/cache.go index 2a2be18..6c6fa76 100644 --- a/cache.go +++ b/cache.go @@ -48,6 +48,9 @@ type Response struct { // Header is the cached response header. Header http.Header + // StatusCode is the cached response status code. + StatusCode int + // Expiration is the cached response expiration date. Expiration time.Time @@ -66,6 +69,7 @@ type Client struct { ttl time.Duration refreshKey string methods []string + statusCodeFilter func(int) bool writeExpiresHeader bool } @@ -120,13 +124,13 @@ func (c *Client) Middleware(next http.Handler) http.Handler { response.Frequency++ c.adapter.Set(key, response.Bytes(), response.Expiration) - //w.WriteHeader(http.StatusNotModified) for k, v := range response.Header { w.Header().Set(k, strings.Join(v, ",")) } if c.writeExpiresHeader { w.Header().Set("Expires", response.Expiration.UTC().Format(http.TimeFormat)) } + w.WriteHeader(response.StatusCode) w.Write(response.Value) return } @@ -143,10 +147,11 @@ func (c *Client) Middleware(next http.Handler) http.Handler { value := rec.Body.Bytes() now := time.Now() expires := now.Add(c.ttl) - if statusCode < 400 { + if c.statusCodeFilter(statusCode) { response := Response{ Value: value, Header: result.Header, + StatusCode: statusCode, Expiration: expires, LastAccess: now, Frequency: 1, @@ -244,6 +249,9 @@ func NewClient(opts ...ClientOption) (*Client, error) { if c.methods == nil { c.methods = []string{http.MethodGet} } + if c.statusCodeFilter == nil { + c.statusCodeFilter = func(code int) bool { return code < 400 } + } return c, nil } @@ -293,6 +301,15 @@ func ClientWithMethods(methods []string) ClientOption { } } +// ClientWithStatusCodeFilter sets the acceptable status codes to be cached. +// Optional setting. If not set, default filter allows caching of every response with status code below 400. +func ClientWithStatusCodeFilter(filter func(int) bool) ClientOption { + return func(c *Client) error { + c.statusCodeFilter = filter + return nil + } +} + // ClientWithExpiresHeader enables middleware to add an Expires header to responses. // Optional setting. If not set, default is false. func ClientWithExpiresHeader() ClientOption { diff --git a/cache_test.go b/cache_test.go index 8ebca53..a223a30 100644 --- a/cache_test.go +++ b/cache_test.go @@ -55,18 +55,22 @@ func TestMiddleware(t *testing.T) { store: map[uint64][]byte{ 14974843192121052621: Response{ Value: []byte("value 1"), + StatusCode: 200, Expiration: time.Now().Add(1 * time.Minute), }.Bytes(), 14974839893586167988: Response{ Value: []byte("value 2"), + StatusCode: 200, Expiration: time.Now().Add(1 * time.Minute), }.Bytes(), 14974840993097796199: Response{ Value: []byte("value 3"), + StatusCode: 200, Expiration: time.Now().Add(-1 * time.Minute), }.Bytes(), 10956846073361780255: Response{ Value: []byte("value 4"), + StatusCode: 200, Expiration: time.Now().Add(-1 * time.Minute), }.Bytes(), }, @@ -473,9 +477,63 @@ func TestNewClient(t *testing.T) { t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.want != nil { + got.statusCodeFilter = nil + tt.want.statusCodeFilter = nil + } + if !reflect.DeepEqual(got, tt.want) { t.Errorf("NewClient() = %v, want %v", got, tt.want) } }) } } + +func TestNewClientWithStatusCodeFilter(t *testing.T) { + adapter := &adapterMock{} + + tests := []struct { + name string + opts []ClientOption + wantCache []int + wantSkip []int + }{ + { + "returns new client with status code filter", + []ClientOption{ + ClientWithAdapter(adapter), + ClientWithTTL(1 * time.Millisecond), + }, + []int{200, 300}, + []int{400, 500}, + }, + { + "returns new client with status code filter", + []ClientOption{ + ClientWithAdapter(adapter), + ClientWithTTL(1 * time.Millisecond), + ClientWithStatusCodeFilter(func(code int) bool { return code < 350 || code > 450 }), + }, + []int{200, 300, 500}, + []int{400}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, _ := NewClient(tt.opts...) + + for _, c := range tt.wantCache { + if got.statusCodeFilter(c) == false { + t.Errorf("NewClient() allows caching of status code %v, don't want it to", c) + } + } + + for _, c := range tt.wantSkip { + if got.statusCodeFilter(c) == true { + t.Errorf("NewClient() doesn't allow caching of status code %v, want it to", c) + } + } + }) + } +}