diff --git a/internal/mocks.go b/internal/mocks.go index 1b8b984..90512fb 100644 --- a/internal/mocks.go +++ b/internal/mocks.go @@ -183,16 +183,15 @@ func (m *MockCacheInvalidator) InvalidateCache( var _ ValidationResponseHandler = (*MockValidationResponseHandler)(nil) type MockValidationResponseHandler struct { - HandleValidationResponseFunc func(ctx RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) + HandleValidationResponseFunc func(ctx RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) } func (m *MockValidationResponseHandler) HandleValidationResponse( ctx RevalidationContext, req *http.Request, resp *http.Response, - err error, ) (*http.Response, error) { - return m.HandleValidationResponseFunc(ctx, req, resp, err) + return m.HandleValidationResponseFunc(ctx, req, resp) } var _ VaryMatcher = (*MockVaryMatcher)(nil) diff --git a/internal/validationresponsehandler.go b/internal/validationresponsehandler.go index 292e705..5824d47 100644 --- a/internal/validationresponsehandler.go +++ b/internal/validationresponsehandler.go @@ -24,7 +24,6 @@ type ValidationResponseHandler interface { ctx RevalidationContext, req *http.Request, resp *http.Response, - err error, ) (*http.Response, error) } @@ -75,9 +74,8 @@ func (r *validationResponseHandler) HandleValidationResponse( ctx RevalidationContext, req *http.Request, resp *http.Response, - err error, ) (*http.Response, error) { - if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { + if req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified { // RFC 9111 §4.3.3 Handling Validation Responses (304 Not Modified) // RFC 9111 §4.3.4 Freshening Stored Responses upon Validation updateStoredHeaders(ctx.Stored.Data, resp) @@ -90,7 +88,7 @@ func (r *validationResponseHandler) HandleValidationResponse( ccResp CCResponseDirectives ccRespOnce bool ) - if (err != nil || isStaleErrorAllowed(resp.StatusCode)) && req.Method == http.MethodGet { + if isStaleErrorAllowed(resp.StatusCode) && req.Method == http.MethodGet { ccResp = ParseCCResponseDirectives(resp.Header) ccRespOnce = true if r.siep.CanStaleOnError(ctx.Freshness, ccResp) { @@ -103,10 +101,6 @@ func (r *validationResponseHandler) HandleValidationResponse( } } - if err != nil { - return nil, err - } - if !ccRespOnce { ccResp = ParseCCResponseDirectives(resp.Header) } diff --git a/internal/validationresponsehandler_test.go b/internal/validationresponsehandler_test.go index e2e8740..998e669 100644 --- a/internal/validationresponsehandler_test.go +++ b/internal/validationresponsehandler_test.go @@ -15,7 +15,6 @@ package internal import ( - "errors" "log/slog" "net/http" "net/url" @@ -48,9 +47,8 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) { } type args struct { - req *http.Request - resp *http.Response - inputErr error + req *http.Request + resp *http.Response } tests := []struct { @@ -99,31 +97,7 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) { }, }, { - name: "GET with error, stale allowed", - handler: &validationResponseHandler{ - l: noopLogger, - siep: &MockStaleIfErrorPolicy{ - CanStaleOnErrorFunc: func(*Freshness, ...StaleIfErrorer) bool { return true }, - }, - clock: &MockClock{NowResult: base}, - }, - setup: func(tt *testing.T, handler *validationResponseHandler) args { - return args{ - req: &http.Request{Method: http.MethodGet}, - resp: &http.Response{ - StatusCode: http.StatusInternalServerError, - Header: http.Header{"Cache-Control": {"stale-if-error=60"}}, - }, - inputErr: errors.New("network error"), - } - }, - assert: func(tt *testing.T, got *http.Response, err error) { - testutil.RequireNoError(tt, err) - testutil.AssertEqual(tt, http.StatusOK, got.StatusCode) - }, - }, - { - name: "GET with error, stale not allowed", + name: "GET with error status, stale not allowed", handler: &validationResponseHandler{ l: noopLogger, siep: &MockStaleIfErrorPolicy{ @@ -132,18 +106,20 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) { clock: &MockClock{NowResult: base}, }, setup: func(tt *testing.T, handler *validationResponseHandler) args { + handler.ce = CacheabilityEvaluatorFunc( + func(*http.Response, CCRequestDirectives, CCResponseDirectives) bool { return false }, + ) return args{ req: &http.Request{Method: http.MethodGet}, resp: &http.Response{ StatusCode: http.StatusInternalServerError, Header: http.Header{}, }, - inputErr: errors.New("network error"), } }, assert: func(tt *testing.T, got *http.Response, err error) { - testutil.RequireError(tt, err) - testutil.AssertNil(tt, got) + testutil.RequireNoError(tt, err) + testutil.AssertEqual(tt, "BYPASS", got.Header.Get(CacheStatusHeader)) }, }, { @@ -202,7 +178,7 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { a := tt.setup(t, tt.handler) - got, err := tt.handler.HandleValidationResponse(ctx, a.req, a.resp, a.inputErr) + got, err := tt.handler.HandleValidationResponse(ctx, a.req, a.resp) tt.assert(t, got, err) }) } diff --git a/roundtripper.go b/roundtripper.go index f2d9ab1..dba6579 100644 --- a/roundtripper.go +++ b/roundtripper.go @@ -340,6 +340,9 @@ func (r *transport) handleCacheHit( revalidate: req = withConditionalHeaders(req, stored.Data.Header) resp, start, end, err := r.roundTripTimed(req) + if err != nil { + return nil, err + } ctx := internal.RevalidationContext{ URLKey: urlKey, Start: start, @@ -350,7 +353,7 @@ revalidate: RefIndex: refIndex, Freshness: freshness, } - return r.vrh.HandleValidationResponse(ctx, req, resp, err) + return r.vrh.HandleValidationResponse(ctx, req, resp) } func (r *transport) serveFromCache( @@ -441,7 +444,7 @@ func (r *transport) backgroundRevalidate( Freshness: freshness, } //nolint:bodyclose // The response is not used, so we don't need to close it. - _, err = r.vrh.HandleValidationResponse(revalCtx, req, resp, nil) + _, err = r.vrh.HandleValidationResponse(revalCtx, req, resp) errc <- err }() diff --git a/roundtripper_test.go b/roundtripper_test.go index 69771b1..9feaebd 100644 --- a/roundtripper_test.go +++ b/roundtripper_test.go @@ -64,8 +64,8 @@ func mockTransport(fields func(rt *transport)) *transport { ci: &internal.MockCacheInvalidator{}, rs: &internal.MockResponseStorer{}, vrh: &internal.MockValidationResponseHandler{ - HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) { - return resp, err + HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) { + return resp, nil }, }, clock: &internal.MockClock{NowResult: time.Now()}, @@ -224,9 +224,9 @@ func Test_transport_CacheHit_MustRevalidate_Stale(t *testing.T) { }, } rt.vrh = &internal.MockValidationResponseHandler{ - HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) { + HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) { mockVHCalled = true - return resp, err + return resp, nil }, } }) @@ -259,9 +259,9 @@ func Test_transport_CacheHit_NoCacheUnqualified(t *testing.T) { }, } rt.vrh = &internal.MockValidationResponseHandler{ - HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) { + HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) { mockVHCalled = true - return resp, err + return resp, nil }, } }) @@ -539,10 +539,10 @@ func Test_transport_RevalidationPath(t *testing.T) { }, } rt.vrh = &internal.MockValidationResponseHandler{ - HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) { + HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) { mockVHCalled = true internal.CacheStatusRevalidated.ApplyTo(resp.Header) - return resp, err + return resp, nil }, } }) @@ -596,9 +596,9 @@ func Test_transport_SWR_NormalPath(t *testing.T) { rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0} rt.siep = &internal.MockStaleIfErrorPolicy{} rt.vrh = &internal.MockValidationResponseHandler{ - HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) { + HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) { revalidateCalled <- struct{}{} // Signal that revalidation was called - return resp, err + return resp, nil }, } rt.swrTimeout = DefaultSWRTimeout @@ -669,7 +669,7 @@ func Test_transport_SWR_NormalPathAndError(t *testing.T) { rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0} rt.swrTimeout = swrTimeout rt.vrh = &internal.MockValidationResponseHandler{ - HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) { + HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) { defer func() { revalidateCalled <- struct{}{} }() // Signal that revalidation was called return nil, errors.New("revalidation error") }, @@ -736,9 +736,9 @@ func Test_transport_SWR_Timeout(t *testing.T) { rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0} rt.swrTimeout = swrTimeout rt.vrh = &internal.MockValidationResponseHandler{ - HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) { + HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) { revalidateCalled <- struct{}{} // Signal that revalidation was called - return resp, err + return resp, nil }, } rt.upstream = &internal.MockRoundTripper{