Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions internal/mocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,16 +183,15 @@ func (m *MockCacheInvalidator) InvalidateCache(
var _ ValidationResponseHandler = (*MockValidationResponseHandler)(nil)

type MockValidationResponseHandler struct {
HandleValidationResponseFunc func(ctx RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error)
HandleValidationResponseFunc func(ctx RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error)
}

func (m *MockValidationResponseHandler) HandleValidationResponse(
ctx RevalidationContext,
req *http.Request,
resp *http.Response,
err error,
) (*http.Response, error) {
return m.HandleValidationResponseFunc(ctx, req, resp, err)
return m.HandleValidationResponseFunc(ctx, req, resp)
}

var _ VaryMatcher = (*MockVaryMatcher)(nil)
10 changes: 2 additions & 8 deletions internal/validationresponsehandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type ValidationResponseHandler interface {
ctx RevalidationContext,
req *http.Request,
resp *http.Response,
err error,
) (*http.Response, error)
}

Expand Down Expand Up @@ -75,9 +74,8 @@ func (r *validationResponseHandler) HandleValidationResponse(
ctx RevalidationContext,
req *http.Request,
resp *http.Response,
err error,
) (*http.Response, error) {
if err == nil && req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified {
if req.Method == http.MethodGet && resp.StatusCode == http.StatusNotModified {
// RFC 9111 §4.3.3 Handling Validation Responses (304 Not Modified)
// RFC 9111 §4.3.4 Freshening Stored Responses upon Validation
updateStoredHeaders(ctx.Stored.Data, resp)
Expand All @@ -90,7 +88,7 @@ func (r *validationResponseHandler) HandleValidationResponse(
ccResp CCResponseDirectives
ccRespOnce bool
)
if (err != nil || isStaleErrorAllowed(resp.StatusCode)) && req.Method == http.MethodGet {
if isStaleErrorAllowed(resp.StatusCode) && req.Method == http.MethodGet {
ccResp = ParseCCResponseDirectives(resp.Header)
ccRespOnce = true
if r.siep.CanStaleOnError(ctx.Freshness, ccResp) {
Expand All @@ -103,10 +101,6 @@ func (r *validationResponseHandler) HandleValidationResponse(
}
}

if err != nil {
return nil, err
}

if !ccRespOnce {
ccResp = ParseCCResponseDirectives(resp.Header)
}
Expand Down
42 changes: 9 additions & 33 deletions internal/validationresponsehandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package internal

import (
"errors"
"log/slog"
"net/http"
"net/url"
Expand Down Expand Up @@ -48,9 +47,8 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) {
}

type args struct {
req *http.Request
resp *http.Response
inputErr error
req *http.Request
resp *http.Response
}

tests := []struct {
Expand Down Expand Up @@ -99,31 +97,7 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) {
},
},
{
name: "GET with error, stale allowed",
handler: &validationResponseHandler{
l: noopLogger,
siep: &MockStaleIfErrorPolicy{
CanStaleOnErrorFunc: func(*Freshness, ...StaleIfErrorer) bool { return true },
},
clock: &MockClock{NowResult: base},
},
setup: func(tt *testing.T, handler *validationResponseHandler) args {
return args{
req: &http.Request{Method: http.MethodGet},
resp: &http.Response{
StatusCode: http.StatusInternalServerError,
Header: http.Header{"Cache-Control": {"stale-if-error=60"}},
},
inputErr: errors.New("network error"),
}
},
assert: func(tt *testing.T, got *http.Response, err error) {
testutil.RequireNoError(tt, err)
testutil.AssertEqual(tt, http.StatusOK, got.StatusCode)
},
},
{
name: "GET with error, stale not allowed",
name: "GET with error status, stale not allowed",
handler: &validationResponseHandler{
l: noopLogger,
siep: &MockStaleIfErrorPolicy{
Expand All @@ -132,18 +106,20 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) {
clock: &MockClock{NowResult: base},
},
setup: func(tt *testing.T, handler *validationResponseHandler) args {
handler.ce = CacheabilityEvaluatorFunc(
func(*http.Response, CCRequestDirectives, CCResponseDirectives) bool { return false },
)
return args{
req: &http.Request{Method: http.MethodGet},
resp: &http.Response{
StatusCode: http.StatusInternalServerError,
Header: http.Header{},
},
inputErr: errors.New("network error"),
}
},
assert: func(tt *testing.T, got *http.Response, err error) {
testutil.RequireError(tt, err)
testutil.AssertNil(tt, got)
testutil.RequireNoError(tt, err)
testutil.AssertEqual(tt, "BYPASS", got.Header.Get(CacheStatusHeader))
},
},
{
Expand Down Expand Up @@ -202,7 +178,7 @@ func Test_validationResponseHandler_HandleValidationResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := tt.setup(t, tt.handler)
got, err := tt.handler.HandleValidationResponse(ctx, a.req, a.resp, a.inputErr)
got, err := tt.handler.HandleValidationResponse(ctx, a.req, a.resp)
tt.assert(t, got, err)
})
}
Expand Down
7 changes: 5 additions & 2 deletions roundtripper.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,9 @@ func (r *transport) handleCacheHit(
revalidate:
req = withConditionalHeaders(req, stored.Data.Header)
resp, start, end, err := r.roundTripTimed(req)
if err != nil {
return nil, err
}
ctx := internal.RevalidationContext{
URLKey: urlKey,
Start: start,
Expand All @@ -350,7 +353,7 @@ revalidate:
RefIndex: refIndex,
Freshness: freshness,
}
return r.vrh.HandleValidationResponse(ctx, req, resp, err)
return r.vrh.HandleValidationResponse(ctx, req, resp)
}

func (r *transport) serveFromCache(
Expand Down Expand Up @@ -441,7 +444,7 @@ func (r *transport) backgroundRevalidate(
Freshness: freshness,
}
//nolint:bodyclose // The response is not used, so we don't need to close it.
_, err = r.vrh.HandleValidationResponse(revalCtx, req, resp, nil)
_, err = r.vrh.HandleValidationResponse(revalCtx, req, resp)
errc <- err
}()

Expand Down
26 changes: 13 additions & 13 deletions roundtripper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func mockTransport(fields func(rt *transport)) *transport {
ci: &internal.MockCacheInvalidator{},
rs: &internal.MockResponseStorer{},
vrh: &internal.MockValidationResponseHandler{
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
return resp, err
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
return resp, nil
},
},
clock: &internal.MockClock{NowResult: time.Now()},
Expand Down Expand Up @@ -224,9 +224,9 @@ func Test_transport_CacheHit_MustRevalidate_Stale(t *testing.T) {
},
}
rt.vrh = &internal.MockValidationResponseHandler{
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
mockVHCalled = true
return resp, err
return resp, nil
},
}
})
Expand Down Expand Up @@ -259,9 +259,9 @@ func Test_transport_CacheHit_NoCacheUnqualified(t *testing.T) {
},
}
rt.vrh = &internal.MockValidationResponseHandler{
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
mockVHCalled = true
return resp, err
return resp, nil
},
}
})
Expand Down Expand Up @@ -539,10 +539,10 @@ func Test_transport_RevalidationPath(t *testing.T) {
},
}
rt.vrh = &internal.MockValidationResponseHandler{
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
mockVHCalled = true
internal.CacheStatusRevalidated.ApplyTo(resp.Header)
return resp, err
return resp, nil
},
}
})
Expand Down Expand Up @@ -596,9 +596,9 @@ func Test_transport_SWR_NormalPath(t *testing.T) {
rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0}
rt.siep = &internal.MockStaleIfErrorPolicy{}
rt.vrh = &internal.MockValidationResponseHandler{
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
revalidateCalled <- struct{}{} // Signal that revalidation was called
return resp, err
return resp, nil
},
}
rt.swrTimeout = DefaultSWRTimeout
Expand Down Expand Up @@ -669,7 +669,7 @@ func Test_transport_SWR_NormalPathAndError(t *testing.T) {
rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0}
rt.swrTimeout = swrTimeout
rt.vrh = &internal.MockValidationResponseHandler{
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
defer func() { revalidateCalled <- struct{}{} }() // Signal that revalidation was called
return nil, errors.New("revalidation error")
},
Expand Down Expand Up @@ -736,9 +736,9 @@ func Test_transport_SWR_Timeout(t *testing.T) {
rt.clock = &internal.MockClock{NowResult: base.Add(5 * time.Second), SinceResult: 0}
rt.swrTimeout = swrTimeout
rt.vrh = &internal.MockValidationResponseHandler{
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response, err error) (*http.Response, error) {
HandleValidationResponseFunc: func(ctx internal.RevalidationContext, req *http.Request, resp *http.Response) (*http.Response, error) {
revalidateCalled <- struct{}{} // Signal that revalidation was called
return resp, err
return resp, nil
},
}
rt.upstream = &internal.MockRoundTripper{
Expand Down