From 92ad8ab9dc351bea14894a89a220933f59514321 Mon Sep 17 00:00:00 2001 From: abhishek-singla-97 <“abhishek.singla@angelbroking.com”> Date: Fri, 19 Dec 2025 12:59:04 +0530 Subject: [PATCH 1/2] feat: add RemoveKeys #17 --- batch_test.go | 251 +++++++++++++++++++++++++++++++ group.go | 78 ++++++++++ instance_test.go | 4 +- stats.go | 30 ++++ stats_test.go | 4 + transport/http_transport.go | 118 ++++++++++++++- transport/http_transport_test.go | 7 + transport/mock_transport.go | 6 + transport/pb/groupcache.pb.go | 162 ++++++++++++++++++-- transport/pb/groupcache.proto | 9 ++ transport/peer/client.go | 5 + transport/types.go | 1 + 12 files changed, 655 insertions(+), 20 deletions(-) create mode 100644 batch_test.go diff --git a/batch_test.go b/batch_test.go new file mode 100644 index 0000000..cc7da7b --- /dev/null +++ b/batch_test.go @@ -0,0 +1,251 @@ +/* +Copyright 2024 Groupcache Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package groupcache_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/groupcache/groupcache-go/v3" + "github.com/groupcache/groupcache-go/v3/cluster" + "github.com/groupcache/groupcache-go/v3/transport" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRemoveKeys(t *testing.T) { + ctx := context.Background() + + err := cluster.Start(ctx, 3, groupcache.Options{}) + require.NoError(t, err) + defer func() { _ = cluster.Shutdown(ctx) }() + + callCount := make(map[string]int) + getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error { + callCount[key]++ + return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5)) + }) + + // Register the group on ALL daemons (required for broadcast) + group, err := cluster.DaemonAt(0).NewGroup("test-remove-keys", 3000000, getter) + require.NoError(t, err) + for i := 1; i < 3; i++ { + _, err := cluster.DaemonAt(i).NewGroup("test-remove-keys", 3000000, getter) + require.NoError(t, err) + } + + keys := []string{"key1", "key2", "key3"} + + // First, populate the cache by getting each key + for _, key := range keys { + var value string + err := group.Get(ctx, key, transport.StringSink(&value)) + require.NoError(t, err) + assert.Equal(t, fmt.Sprintf("value-%s", key), value) + } + + // Verify getter was called for each key + for _, key := range keys { + assert.Equal(t, 1, callCount[key], "getter should be called once for %s", key) + } + + // Now remove all keys using variadic signature + err = group.RemoveKeys(ctx, "key1", "key2", "key3") + require.NoError(t, err) + + // Fetch again - getter should be called again since keys were removed + for _, key := range keys { + var value string + err := group.Get(ctx, key, transport.StringSink(&value)) + require.NoError(t, err) + } + + // Verify getter was called again for each key + for _, key := range keys { + assert.Equal(t, 2, callCount[key], "getter should be called twice for %s after removal", key) + } +} + +func TestRemoveKeysEmpty(t *testing.T) { + ctx := context.Background() + + err := cluster.Start(ctx, 2, groupcache.Options{}) + require.NoError(t, err) + defer func() { _ = cluster.Shutdown(ctx) }() + + getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error { + return dest.SetString("value", time.Now().Add(time.Minute)) + }) + + // Register the group on ALL daemons + group, err := cluster.DaemonAt(0).NewGroup("test-remove-empty", 3000000, getter) + require.NoError(t, err) + _, err = cluster.DaemonAt(1).NewGroup("test-remove-empty", 3000000, getter) + require.NoError(t, err) + + // Test RemoveKeys with no keys - should not error + err = group.RemoveKeys(ctx) + require.NoError(t, err) +} + +func TestRemoveKeysWithSlice(t *testing.T) { + ctx := context.Background() + + err := cluster.Start(ctx, 2, groupcache.Options{}) + require.NoError(t, err) + defer func() { _ = cluster.Shutdown(ctx) }() + + getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error { + return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5)) + }) + + // Register the group on ALL daemons + group, err := cluster.DaemonAt(0).NewGroup("test-remove-slice", 3000000, getter) + require.NoError(t, err) + _, err = cluster.DaemonAt(1).NewGroup("test-remove-slice", 3000000, getter) + require.NoError(t, err) + + keys := []string{"key1", "key2", "key3"} + + // Populate cache + for _, key := range keys { + var value string + err := group.Get(ctx, key, transport.StringSink(&value)) + require.NoError(t, err) + } + + // Test RemoveKeys with slice expansion + err = group.RemoveKeys(ctx, keys...) + require.NoError(t, err) +} + +func TestRemoveKeysStats(t *testing.T) { + ctx := context.Background() + + err := cluster.Start(ctx, 2, groupcache.Options{}) + require.NoError(t, err) + defer func() { _ = cluster.Shutdown(ctx) }() + + getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error { + return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5)) + }) + + // Register the group on ALL daemons + group, err := cluster.DaemonAt(0).NewGroup("test-remove-stats", 3000000, getter) + require.NoError(t, err) + _, err = cluster.DaemonAt(1).NewGroup("test-remove-stats", 3000000, getter) + require.NoError(t, err) + + // Call RemoveKeys + err = group.RemoveKeys(ctx, "key1", "key2", "key3") + require.NoError(t, err) + + // Note: Stats are internal to the group implementation + // The batch stats are incremented but not directly accessible via interface + // This test verifies that the operation completes without error +} + +func BenchmarkRemoveKeys(b *testing.B) { + ctx := context.Background() + + err := cluster.Start(ctx, 3, groupcache.Options{}) + if err != nil { + b.Fatal(err) + } + defer func() { _ = cluster.Shutdown(ctx) }() + + getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error { + return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5)) + }) + + // Register the group on ALL daemons + group, err := cluster.DaemonAt(0).NewGroup("bench-remove", 3000000, getter) + if err != nil { + b.Fatal(err) + } + for i := 1; i < 3; i++ { + _, err := cluster.DaemonAt(i).NewGroup("bench-remove", 3000000, getter) + if err != nil { + b.Fatal(err) + } + } + + // Prepare keys + keys := make([]string, 100) + for i := 0; i < 100; i++ { + keys[i] = fmt.Sprintf("key-%d", i) + } + + // Populate cache first + for _, key := range keys { + var value string + _ = group.Get(ctx, key, transport.StringSink(&value)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = group.RemoveKeys(ctx, keys...) + } +} + +func BenchmarkRemoveKeysVsLoop(b *testing.B) { + ctx := context.Background() + + err := cluster.Start(ctx, 3, groupcache.Options{}) + if err != nil { + b.Fatal(err) + } + defer func() { _ = cluster.Shutdown(ctx) }() + + getter := groupcache.GetterFunc(func(ctx context.Context, key string, dest transport.Sink) error { + return dest.SetString(fmt.Sprintf("value-%s", key), time.Now().Add(time.Minute*5)) + }) + + // Register the group on ALL daemons + group, err := cluster.DaemonAt(0).NewGroup("bench-compare", 3000000, getter) + if err != nil { + b.Fatal(err) + } + for i := 1; i < 3; i++ { + _, err := cluster.DaemonAt(i).NewGroup("bench-compare", 3000000, getter) + if err != nil { + b.Fatal(err) + } + } + + // Prepare keys + keys := make([]string, 50) + for i := 0; i < 50; i++ { + keys[i] = fmt.Sprintf("key-%d", i) + } + + b.Run("RemoveKeys", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = group.RemoveKeys(ctx, keys...) + } + }) + + b.Run("LoopRemove", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, key := range keys { + _ = group.Remove(ctx, key) + } + } + }) +} diff --git a/group.go b/group.go index 83a23ad..be039d7 100644 --- a/group.go +++ b/group.go @@ -42,6 +42,7 @@ type Group interface { Remove(context.Context, string) error UsedBytes() (int64, int64) Name() string + RemoveKeys(ctx context.Context, keys ...string) error } // A Getter loads data for a key. @@ -443,6 +444,79 @@ func (g *group) LocalRemove(key string) { }) } +func (g *group) RemoveKeys(ctx context.Context, keys ...string) error { + if len(keys) == 0 { + return nil + } + + g.Stats.BatchRemoves.Add(1) + g.Stats.BatchKeysRemoved.Add(int64(len(keys))) + + keysByOwner := make(map[peer.Client][]string) + var localKeys []string + + for _, key := range keys { + owner, isRemote := g.instance.PickPeer(key) + if isRemote { + keysByOwner[owner] = append(keysByOwner[owner], key) + } else { + localKeys = append(localKeys, key) + } + } + + for _, key := range localKeys { + g.LocalRemove(key) + } + + multiErr := &MultiError{} + errCh := make(chan error) + + // Step 3: Send batch requests to owners (parallel) + var wg sync.WaitGroup + for owner, ownerKeys := range keysByOwner { + wg.Add(1) + go func(p peer.Client, k []string) { + errCh <- p.RemoveKeys(ctx, &pb.RemoveMultiRequest{ + Group: &g.name, + Keys: k, + }) + wg.Done() + }(owner, ownerKeys) + } + + allPeers := g.instance.getAllPeers() + for _, p := range allPeers { + if p.PeerInfo().IsSelf { + continue + } + if _, isOwner := keysByOwner[p]; isOwner { + continue + } + + wg.Add(1) + go func(peer peer.Client) { + errCh <- peer.RemoveKeys(ctx, &pb.RemoveMultiRequest{ + Group: &g.name, + Keys: keys, + }) + wg.Done() + }(p) + } + + go func() { + wg.Wait() + close(errCh) + }() + + for err := range errCh { + if err != nil { + multiErr.Add(err) + } + } + + return multiErr.NilOrError() +} + func (g *group) populateCache(key string, value transport.ByteView, cache Cache) { if g.maxCacheBytes <= 0 { return @@ -524,6 +598,8 @@ func (g *group) registerInstruments(meter otelmetric.Meter) error { o.ObserveInt64(instruments.LocalLoadsCounter(), g.Stats.LocalLoads.Get(), observeOptions...) o.ObserveInt64(instruments.LocalLoadErrsCounter(), g.Stats.LocalLoadErrs.Get(), observeOptions...) o.ObserveInt64(instruments.GetFromPeersLatencyMaxGauge(), g.Stats.GetFromPeersLatencyLower.Get(), observeOptions...) + o.ObserveInt64(instruments.BatchRemovesCounter(), g.Stats.BatchRemoves.Get(), observeOptions...) + o.ObserveInt64(instruments.BatchKeysRemovedCounter(), g.Stats.BatchKeysRemoved.Get(), observeOptions...) return nil }, @@ -536,6 +612,8 @@ func (g *group) registerInstruments(meter otelmetric.Meter) error { instruments.LocalLoadsCounter(), instruments.LocalLoadErrsCounter(), instruments.GetFromPeersLatencyMaxGauge(), + instruments.BatchRemovesCounter(), + instruments.BatchKeysRemovedCounter(), ) return err diff --git a/instance_test.go b/instance_test.go index 4107c32..a00a5a3 100644 --- a/instance_test.go +++ b/instance_test.go @@ -523,11 +523,13 @@ func TestNewGroupRegistersMetricsWithMeterProvider(t *testing.T) { "groupcache.group.loads.deduped", "groupcache.group.local.loads", "groupcache.group.local.load_errors", + "groupcache.group.batch.removes", + "groupcache.group.batch.keys_removed", } assert.Equal(t, expectedCounters, recMeter.counterNames) assert.Equal(t, []string{"groupcache.group.peer.latency_max_ms"}, recMeter.updownNames) assert.True(t, recMeter.callbackRegistered, "expected callback registration for metrics") - assert.Equal(t, 9, recMeter.instrumentCount) + assert.Equal(t, 11, recMeter.instrumentCount) } func TestNewGroupFailsWhenMetricRegistrationFails(t *testing.T) { diff --git a/stats.go b/stats.go index 8d97363..43df6b8 100644 --- a/stats.go +++ b/stats.go @@ -80,6 +80,8 @@ type GroupStats struct { LoadsDeduped AtomicInt // after singleflight LocalLoads AtomicInt // total good local loads LocalLoadErrs AtomicInt // total bad local loads + BatchRemoves AtomicInt // total RemoveKeys requests + BatchKeysRemoved AtomicInt // total keys removed via RemoveKeys } type MeterProviderOption func(*MeterProvider) @@ -124,6 +126,8 @@ type groupInstruments struct { localLoadsCounter metric.Int64ObservableCounter localLoadErrsCounter metric.Int64ObservableCounter getFromPeersLatencyMaxGauge metric.Int64ObservableUpDownCounter + batchRemovesCounter metric.Int64ObservableCounter + batchKeysRemovedCounter metric.Int64ObservableCounter } // newGroupInstruments registers all instruments that map to GroupStats counters. @@ -200,6 +204,22 @@ func newGroupInstruments(meter metric.Meter) (*groupInstruments, error) { return nil, err } + batchRemovesCounter, err := meter.Int64ObservableCounter( + "groupcache.group.batch.removes", + metric.WithDescription("Total RemoveKeys requests"), + ) + if err != nil { + return nil, err + } + + batchKeysRemovedCounter, err := meter.Int64ObservableCounter( + "groupcache.group.batch.keys_removed", + metric.WithDescription("Total keys removed via RemoveKeys"), + ) + if err != nil { + return nil, err + } + return &groupInstruments{ getsCounter: getsCounter, hitsCounter: hitsCounter, @@ -210,6 +230,8 @@ func newGroupInstruments(meter metric.Meter) (*groupInstruments, error) { localLoadsCounter: localLoadsCounter, localLoadErrsCounter: localLoadErrsCounter, getFromPeersLatencyMaxGauge: getFromPeersLatencyMaxGauge, + batchRemovesCounter: batchRemovesCounter, + batchKeysRemovedCounter: batchKeysRemovedCounter, }, nil } @@ -249,6 +271,14 @@ func (gm *groupInstruments) GetFromPeersLatencyMaxGauge() metric.Int64Observable return gm.getFromPeersLatencyMaxGauge } +func (gm *groupInstruments) BatchRemovesCounter() metric.Int64ObservableCounter { + return gm.batchRemovesCounter +} + +func (gm *groupInstruments) BatchKeysRemovedCounter() metric.Int64ObservableCounter { + return gm.batchKeysRemovedCounter +} + type cacheInstruments struct { rejectedCounter metric.Int64Counter bytesGauge metric.Int64UpDownCounter diff --git a/stats_test.go b/stats_test.go index db4ee2e..990bceb 100644 --- a/stats_test.go +++ b/stats_test.go @@ -67,6 +67,8 @@ func TestNewGroupInstrumentsRegistersAllCounters(t *testing.T) { "groupcache.group.loads.deduped", "groupcache.group.local.loads", "groupcache.group.local.load_errors", + "groupcache.group.batch.removes", + "groupcache.group.batch.keys_removed", } assert.Equal(t, expectedCounters, meter.counterNames) assert.Equal(t, []string{"groupcache.group.peer.latency_max_ms"}, meter.updownNames) @@ -80,6 +82,8 @@ func TestNewGroupInstrumentsRegistersAllCounters(t *testing.T) { assert.NotNil(t, inst.LocalLoadsCounter()) assert.NotNil(t, inst.LocalLoadErrsCounter()) assert.NotNil(t, inst.GetFromPeersLatencyMaxGauge()) + assert.NotNil(t, inst.BatchRemovesCounter()) + assert.NotNil(t, inst.BatchKeysRemovedCounter()) } func TestNewGroupInstrumentsErrorsOnCounterFailure(t *testing.T) { diff --git a/transport/http_transport.go b/transport/http_transport.go index 191569e..0032803 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -22,6 +22,7 @@ import ( "bytes" "context" "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -93,6 +94,12 @@ type Transport interface { ListenAddress() string } +type transportMethods interface { + Get(ctx context.Context, key string, dest Sink) error + RemoteSet(string, []byte, time.Time) + LocalRemove(string) +} + // HttpTransportOptions options for creating a new HttpTransport type HttpTransportOptions struct { // Context (Optional) specifies a context for the server to use when it @@ -307,12 +314,6 @@ func (t *HttpTransport) ServeHTTP(w http.ResponseWriter, r *http.Request) { groupName := parts[0] key := parts[1] - type transportMethods interface { - Get(ctx context.Context, key string, dest Sink) error - RemoteSet(string, []byte, time.Time) - LocalRemove(string) - } - // Fetch the value for this group/key. group := t.instance.GetGroup(groupName).(transportMethods) if group == nil { @@ -321,6 +322,11 @@ func (t *HttpTransport) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if strings.HasPrefix(key, "_batch/") { + t.handleBatchRequest(ctx, w, r, groupName, key, group, recordError) + return + } + // Delete the key and return 200 if r.Method == http.MethodDelete { group.LocalRemove(key) @@ -401,6 +407,51 @@ func (t *HttpTransport) ServeHTTP(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(body) } +func (t *HttpTransport) handleBatchRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, groupName string, path string, group transportMethods, recordError func()) { + if r.Method != http.MethodPost { + http.Error(w, "batch operations require POST method", http.StatusMethodNotAllowed) + recordError() + return + } + + defer r.Body.Close() + + b := bufferPool.Get().(*bytes.Buffer) + b.Reset() + defer bufferPool.Put(b) + _, err := io.Copy(b, r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + recordError() + return + } + + switch path { + case "_batch/remove": + t.handleBatchRemove(ctx, w, b.Bytes(), group, recordError) + default: + http.Error(w, "unknown batch operation: "+path, http.StatusBadRequest) + recordError() + } +} + +// handleBatchRemove handles a batch remove request +func (t *HttpTransport) handleBatchRemove(ctx context.Context, w http.ResponseWriter, body []byte, group transportMethods, recordError func()) { + var req pb.RemoveMultiRequest + if err := json.Unmarshal(body, &req); err != nil { + http.Error(w, "invalid request: "+err.Error(), http.StatusBadRequest) + recordError() + return + } + + // Remove each key locally + for _, key := range req.Keys { + group.LocalRemove(key) + } + + w.WriteHeader(http.StatusOK) +} + // NewClient creates a new http client for the provided peer func (t *HttpTransport) NewClient(_ context.Context, p peer.Info) (peer.Client, error) { return &HttpClient{ @@ -581,6 +632,33 @@ func (h *HttpClient) Remove(ctx context.Context, in *pb.GetRequest) error { return nil } +func (h *HttpClient) RemoveKeys(ctx context.Context, in *pb.RemoveMultiRequest) error { + ctx, span, endSpan := h.startSpan(ctx, "GroupCache.RemoveKeys") + defer endSpan() + + body, err := json.Marshal(in) + if err != nil { + werr := fmt.Errorf("while marshaling RemoveMultiRequest body: %w", err) + return recordSpanError(span, werr) + } + + var res http.Response + if err := h.makeBatchRequest(ctx, http.MethodPost, in.GetGroup(), "_batch/remove", bytes.NewReader(body), &res); err != nil { + return recordSpanError(span, err) + } + + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + msg, _ := io.ReadAll(res.Body) + err := fmt.Errorf("server returned status %d: %s", res.StatusCode, msg) + return recordSpanError(span, err) + } + + markSpanOK(span) + return nil +} + func (h *HttpClient) PeerInfo() peer.Info { return h.info } @@ -594,6 +672,34 @@ type request interface { GetKey() string } +func (h *HttpClient) makeBatchRequest(ctx context.Context, method string, group string, path string, body io.Reader, out *http.Response) error { + u := fmt.Sprintf( + "%v%v/%v", + h.endpoint, + url.PathEscape(group), + path, + ) + + var bodyBytes []byte + if body != nil { + bodyBytes, _ = io.ReadAll(body) + body = bytes.NewReader(bodyBytes) + } + + req, err := http.NewRequestWithContext(ctx, method, u, body) + if err != nil { + return err + } + + res, err := h.client.Do(req) + if err != nil { + return err + } + + *out = *res + return nil +} + func (h *HttpClient) makeRequest(ctx context.Context, m string, in request, b io.Reader, out *http.Response) error { u := fmt.Sprintf( "%v%v/%v", diff --git a/transport/http_transport_test.go b/transport/http_transport_test.go index 6b13b91..7c52ff7 100644 --- a/transport/http_transport_test.go +++ b/transport/http_transport_test.go @@ -335,6 +335,13 @@ func (t *tracingGroup) Name() string { return groupName } +func (t *tracingGroup) RemoveKeys(ctx context.Context, keys ...string) error { + for _, key := range keys { + t.LocalRemove(key) + } + return nil +} + func spanHasAttribute(span sdktrace.ReadOnlySpan, key attribute.Key, expected string) bool { for _, attr := range span.Attributes() { if attr.Key == key && attr.Value.AsString() == expected { diff --git a/transport/mock_transport.go b/transport/mock_transport.go index e927bed..27aaeea 100644 --- a/transport/mock_transport.go +++ b/transport/mock_transport.go @@ -144,6 +144,12 @@ func (c *MockClient) Set(ctx context.Context, in *pb.SetRequest) error { return nil } +func (c *MockClient) RemoveKeys(ctx context.Context, in *pb.RemoveMultiRequest) error { + c.addCall("RemoveKeys", len(in.Keys)) + // TODO: Implement when needed + return nil +} + func (c *MockClient) PeerInfo() peer.Info { c.addCall("PeerInfo", 1) return c.peer diff --git a/transport/pb/groupcache.pb.go b/transport/pb/groupcache.pb.go index 6958cbc..6988153 100644 --- a/transport/pb/groupcache.pb.go +++ b/transport/pb/groupcache.pb.go @@ -16,7 +16,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.32.0 -// protoc (unknown) +// protoc v6.33.2 // source: transport/pb/groupcache.proto package pb @@ -224,6 +224,108 @@ func (x *SetRequest) GetExpire() int64 { return 0 } +type RemoveMultiRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Group *string `protobuf:"bytes,1,req,name=group" json:"group,omitempty"` + Keys []string `protobuf:"bytes,2,rep,name=keys" json:"keys,omitempty"` +} + +func (x *RemoveMultiRequest) Reset() { + *x = RemoveMultiRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_transport_pb_groupcache_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RemoveMultiRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RemoveMultiRequest) ProtoMessage() {} + +func (x *RemoveMultiRequest) ProtoReflect() protoreflect.Message { + mi := &file_transport_pb_groupcache_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RemoveMultiRequest.ProtoReflect.Descriptor instead. +func (*RemoveMultiRequest) Descriptor() ([]byte, []int) { + return file_transport_pb_groupcache_proto_rawDescGZIP(), []int{3} +} + +func (x *RemoveMultiRequest) GetGroup() string { + if x != nil && x.Group != nil { + return *x.Group + } + return "" +} + +func (x *RemoveMultiRequest) GetKeys() []string { + if x != nil { + return x.Keys + } + return nil +} + +type RemoveMultiResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + FailedKeys []string `protobuf:"bytes,1,rep,name=failed_keys,json=failedKeys" json:"failed_keys,omitempty"` +} + +func (x *RemoveMultiResponse) Reset() { + *x = RemoveMultiResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_transport_pb_groupcache_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *RemoveMultiResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RemoveMultiResponse) ProtoMessage() {} + +func (x *RemoveMultiResponse) ProtoReflect() protoreflect.Message { + mi := &file_transport_pb_groupcache_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RemoveMultiResponse.ProtoReflect.Descriptor instead. +func (*RemoveMultiResponse) Descriptor() ([]byte, []int) { + return file_transport_pb_groupcache_proto_rawDescGZIP(), []int{4} +} + +func (x *RemoveMultiResponse) GetFailedKeys() []string { + if x != nil { + return x.FailedKeys + } + return nil +} + var File_transport_pb_groupcache_proto protoreflect.FileDescriptor var file_transport_pb_groupcache_proto_rawDesc = []byte{ @@ -244,13 +346,21 @@ var file_transport_pb_groupcache_proto_rawDesc = []byte{ 0x18, 0x02, 0x20, 0x02, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x03, 0x52, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x32, 0x36, 0x0a, 0x0a, 0x47, 0x72, 0x6f, - 0x75, 0x70, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, 0x28, 0x0a, 0x03, 0x47, 0x65, 0x74, 0x12, 0x0e, - 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, - 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x42, 0x28, 0x5a, 0x26, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, - 0x67, 0x72, 0x6f, 0x75, 0x70, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2f, 0x67, 0x72, 0x6f, 0x75, 0x70, - 0x63, 0x61, 0x63, 0x68, 0x65, 0x2d, 0x67, 0x6f, 0x2f, 0x70, 0x62, + 0x03, 0x52, 0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x22, 0x3e, 0x0a, 0x12, 0x52, 0x65, 0x6d, + 0x6f, 0x76, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, + 0x14, 0x0a, 0x05, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x18, 0x01, 0x20, 0x02, 0x28, 0x09, 0x52, 0x05, + 0x67, 0x72, 0x6f, 0x75, 0x70, 0x12, 0x12, 0x0a, 0x04, 0x6b, 0x65, 0x79, 0x73, 0x18, 0x02, 0x20, + 0x03, 0x28, 0x09, 0x52, 0x04, 0x6b, 0x65, 0x79, 0x73, 0x22, 0x36, 0x0a, 0x13, 0x52, 0x65, 0x6d, + 0x6f, 0x76, 0x65, 0x4d, 0x75, 0x6c, 0x74, 0x69, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x1f, 0x0a, 0x0b, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x5f, 0x6b, 0x65, 0x79, 0x73, 0x18, + 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x66, 0x61, 0x69, 0x6c, 0x65, 0x64, 0x4b, 0x65, 0x79, + 0x73, 0x32, 0x36, 0x0a, 0x0a, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x43, 0x61, 0x63, 0x68, 0x65, 0x12, + 0x28, 0x0a, 0x03, 0x47, 0x65, 0x74, 0x12, 0x0e, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x0f, 0x2e, 0x70, 0x62, 0x2e, 0x47, 0x65, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x28, 0x5a, 0x26, 0x67, 0x69, 0x74, + 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x63, 0x61, 0x63, + 0x68, 0x65, 0x2f, 0x67, 0x72, 0x6f, 0x75, 0x70, 0x63, 0x61, 0x63, 0x68, 0x65, 0x2d, 0x67, 0x6f, + 0x2f, 0x70, 0x62, } var ( @@ -265,11 +375,13 @@ func file_transport_pb_groupcache_proto_rawDescGZIP() []byte { return file_transport_pb_groupcache_proto_rawDescData } -var file_transport_pb_groupcache_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_transport_pb_groupcache_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_transport_pb_groupcache_proto_goTypes = []interface{}{ - (*GetRequest)(nil), // 0: pb.GetRequest - (*GetResponse)(nil), // 1: pb.GetResponse - (*SetRequest)(nil), // 2: pb.SetRequest + (*GetRequest)(nil), // 0: pb.GetRequest + (*GetResponse)(nil), // 1: pb.GetResponse + (*SetRequest)(nil), // 2: pb.SetRequest + (*RemoveMultiRequest)(nil), // 3: pb.RemoveMultiRequest + (*RemoveMultiResponse)(nil), // 4: pb.RemoveMultiResponse } var file_transport_pb_groupcache_proto_depIdxs = []int32{ 0, // 0: pb.GroupCache.Get:input_type -> pb.GetRequest @@ -323,6 +435,30 @@ func file_transport_pb_groupcache_proto_init() { return nil } } + file_transport_pb_groupcache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RemoveMultiRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_transport_pb_groupcache_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*RemoveMultiResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -330,7 +466,7 @@ func file_transport_pb_groupcache_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_transport_pb_groupcache_proto_rawDesc, NumEnums: 0, - NumMessages: 3, + NumMessages: 5, NumExtensions: 0, NumServices: 1, }, diff --git a/transport/pb/groupcache.proto b/transport/pb/groupcache.proto index 6dc5185..c9820f8 100644 --- a/transport/pb/groupcache.proto +++ b/transport/pb/groupcache.proto @@ -37,6 +37,15 @@ message SetRequest { optional int64 expire = 4; } +message RemoveMultiRequest { + required string group = 1; + repeated string keys = 2; +} + +message RemoveMultiResponse { + repeated string failed_keys = 1; +} + service GroupCache { rpc Get(GetRequest) returns (GetResponse) { }; diff --git a/transport/peer/client.go b/transport/peer/client.go index d909b44..518e549 100644 --- a/transport/peer/client.go +++ b/transport/peer/client.go @@ -27,6 +27,7 @@ type Client interface { Get(context context.Context, in *pb.GetRequest, out *pb.GetResponse) error Remove(context context.Context, in *pb.GetRequest) error Set(context context.Context, in *pb.SetRequest) error + RemoveKeys(context context.Context, in *pb.RemoveMultiRequest) error PeerInfo() Info HashKey() string } @@ -49,6 +50,10 @@ func (e *NoOpClient) Set(context context.Context, in *pb.SetRequest) error { return nil } +func (e *NoOpClient) RemoveKeys(context context.Context, in *pb.RemoveMultiRequest) error { + return nil +} + func (e *NoOpClient) PeerInfo() Info { return e.Info } diff --git a/transport/types.go b/transport/types.go index 38c4b3f..bd0f5e1 100644 --- a/transport/types.go +++ b/transport/types.go @@ -27,4 +27,5 @@ type Group interface { Remove(context.Context, string) error UsedBytes() (int64, int64) Name() string + RemoveKeys(ctx context.Context, keys ...string) error } From 110178d19cafabd2e12b3732424abab396aff287 Mon Sep 17 00:00:00 2001 From: abhishek-singla-97 <“abhishek.singla@angelbroking.com”> Date: Sat, 20 Dec 2025 13:34:55 +0530 Subject: [PATCH 2/2] Rename batch methods to removeKeys for consistency, add groupStats, minor refactoring --- group.go | 24 +++++++----- instance_test.go | 4 +- batch_test.go => remove_keys_test.go | 19 +++++++--- stats.go | 28 +++++++------- stats_test.go | 8 ++-- transport/http_transport.go | 55 +++++++++------------------ transport/mock_transport.go | 2 +- transport/pb/groupcache.pb.go | 57 ++++++++++++++-------------- transport/pb/groupcache.proto | 4 +- transport/peer/client.go | 4 +- 10 files changed, 101 insertions(+), 104 deletions(-) rename batch_test.go => remove_keys_test.go (89%) diff --git a/group.go b/group.go index be039d7..404e39d 100644 --- a/group.go +++ b/group.go @@ -43,6 +43,7 @@ type Group interface { UsedBytes() (int64, int64) Name() string RemoveKeys(ctx context.Context, keys ...string) error + GroupStats() GroupStats } // A Getter loads data for a key. @@ -109,6 +110,11 @@ func (g *group) Name() string { return g.name } +// GroupStats returns the stats for this group. +func (g *group) GroupStats() GroupStats { + return g.Stats +} + // UsedBytes returns the total number of bytes used by the main and hot caches func (g *group) UsedBytes() (mainCache int64, hotCache int64) { return g.mainCache.Bytes(), g.hotCache.Bytes() @@ -449,8 +455,8 @@ func (g *group) RemoveKeys(ctx context.Context, keys ...string) error { return nil } - g.Stats.BatchRemoves.Add(1) - g.Stats.BatchKeysRemoved.Add(int64(len(keys))) + g.Stats.RemoveKeysRequests.Add(1) + g.Stats.RemovedKeys.Add(int64(len(keys))) keysByOwner := make(map[peer.Client][]string) var localKeys []string @@ -471,12 +477,12 @@ func (g *group) RemoveKeys(ctx context.Context, keys ...string) error { multiErr := &MultiError{} errCh := make(chan error) - // Step 3: Send batch requests to owners (parallel) + // Send removeKeys requests to owners (parallel) var wg sync.WaitGroup for owner, ownerKeys := range keysByOwner { wg.Add(1) go func(p peer.Client, k []string) { - errCh <- p.RemoveKeys(ctx, &pb.RemoveMultiRequest{ + errCh <- p.RemoveKeys(ctx, &pb.RemoveKeysRequest{ Group: &g.name, Keys: k, }) @@ -495,7 +501,7 @@ func (g *group) RemoveKeys(ctx context.Context, keys ...string) error { wg.Add(1) go func(peer peer.Client) { - errCh <- peer.RemoveKeys(ctx, &pb.RemoveMultiRequest{ + errCh <- peer.RemoveKeys(ctx, &pb.RemoveKeysRequest{ Group: &g.name, Keys: keys, }) @@ -598,8 +604,8 @@ func (g *group) registerInstruments(meter otelmetric.Meter) error { o.ObserveInt64(instruments.LocalLoadsCounter(), g.Stats.LocalLoads.Get(), observeOptions...) o.ObserveInt64(instruments.LocalLoadErrsCounter(), g.Stats.LocalLoadErrs.Get(), observeOptions...) o.ObserveInt64(instruments.GetFromPeersLatencyMaxGauge(), g.Stats.GetFromPeersLatencyLower.Get(), observeOptions...) - o.ObserveInt64(instruments.BatchRemovesCounter(), g.Stats.BatchRemoves.Get(), observeOptions...) - o.ObserveInt64(instruments.BatchKeysRemovedCounter(), g.Stats.BatchKeysRemoved.Get(), observeOptions...) + o.ObserveInt64(instruments.RemoveKeysRequestsCounter(), g.Stats.RemoveKeysRequests.Get(), observeOptions...) + o.ObserveInt64(instruments.RemovedKeysCounter(), g.Stats.RemovedKeys.Get(), observeOptions...) return nil }, @@ -612,8 +618,8 @@ func (g *group) registerInstruments(meter otelmetric.Meter) error { instruments.LocalLoadsCounter(), instruments.LocalLoadErrsCounter(), instruments.GetFromPeersLatencyMaxGauge(), - instruments.BatchRemovesCounter(), - instruments.BatchKeysRemovedCounter(), + instruments.RemoveKeysRequestsCounter(), + instruments.RemovedKeysCounter(), ) return err diff --git a/instance_test.go b/instance_test.go index a00a5a3..03b90e5 100644 --- a/instance_test.go +++ b/instance_test.go @@ -523,8 +523,8 @@ func TestNewGroupRegistersMetricsWithMeterProvider(t *testing.T) { "groupcache.group.loads.deduped", "groupcache.group.local.loads", "groupcache.group.local.load_errors", - "groupcache.group.batch.removes", - "groupcache.group.batch.keys_removed", + "groupcache.group.remove_keys.requests", + "groupcache.group.removed_keys", } assert.Equal(t, expectedCounters, recMeter.counterNames) assert.Equal(t, []string{"groupcache.group.peer.latency_max_ms"}, recMeter.updownNames) diff --git a/batch_test.go b/remove_keys_test.go similarity index 89% rename from batch_test.go rename to remove_keys_test.go index cc7da7b..c20b088 100644 --- a/batch_test.go +++ b/remove_keys_test.go @@ -147,18 +147,27 @@ func TestRemoveKeysStats(t *testing.T) { }) // Register the group on ALL daemons - group, err := cluster.DaemonAt(0).NewGroup("test-remove-stats", 3000000, getter) + transportGroup, err := cluster.DaemonAt(0).NewGroup("test-remove-stats", 3000000, getter) require.NoError(t, err) _, err = cluster.DaemonAt(1).NewGroup("test-remove-stats", 3000000, getter) require.NoError(t, err) - // Call RemoveKeys + // Cast to groupcache.Group to access GroupStats() + group, ok := transportGroup.(groupcache.Group) + require.True(t, ok, "expected transportGroup to implement groupcache.Group") + + // Capture stats before RemoveKeys + statsBefore := group.GroupStats() + removeKeysRequestsBefore := statsBefore.RemoveKeysRequests.Get() + removedKeysBefore := statsBefore.RemovedKeys.Get() + err = group.RemoveKeys(ctx, "key1", "key2", "key3") require.NoError(t, err) - // Note: Stats are internal to the group implementation - // The batch stats are incremented but not directly accessible via interface - // This test verifies that the operation completes without error + // Verify stats were incremented correctly + statsAfter := group.GroupStats() + assert.Equal(t, removeKeysRequestsBefore+1, statsAfter.RemoveKeysRequests.Get(), "RemoveKeysRequests should be incremented by 1") + assert.Equal(t, removedKeysBefore+3, statsAfter.RemovedKeys.Get(), "RemovedKeys should be incremented by 3") } func BenchmarkRemoveKeys(b *testing.B) { diff --git a/stats.go b/stats.go index 43df6b8..959eb46 100644 --- a/stats.go +++ b/stats.go @@ -80,8 +80,8 @@ type GroupStats struct { LoadsDeduped AtomicInt // after singleflight LocalLoads AtomicInt // total good local loads LocalLoadErrs AtomicInt // total bad local loads - BatchRemoves AtomicInt // total RemoveKeys requests - BatchKeysRemoved AtomicInt // total keys removed via RemoveKeys + RemoveKeysRequests AtomicInt // total RemoveKeys requests + RemovedKeys AtomicInt // total keys removed via RemoveKeys } type MeterProviderOption func(*MeterProvider) @@ -126,8 +126,8 @@ type groupInstruments struct { localLoadsCounter metric.Int64ObservableCounter localLoadErrsCounter metric.Int64ObservableCounter getFromPeersLatencyMaxGauge metric.Int64ObservableUpDownCounter - batchRemovesCounter metric.Int64ObservableCounter - batchKeysRemovedCounter metric.Int64ObservableCounter + removeKeysRequestsCounter metric.Int64ObservableCounter + removedKeysCounter metric.Int64ObservableCounter } // newGroupInstruments registers all instruments that map to GroupStats counters. @@ -204,16 +204,16 @@ func newGroupInstruments(meter metric.Meter) (*groupInstruments, error) { return nil, err } - batchRemovesCounter, err := meter.Int64ObservableCounter( - "groupcache.group.batch.removes", + removeKeysRequestsCounter, err := meter.Int64ObservableCounter( + "groupcache.group.remove_keys.requests", metric.WithDescription("Total RemoveKeys requests"), ) if err != nil { return nil, err } - batchKeysRemovedCounter, err := meter.Int64ObservableCounter( - "groupcache.group.batch.keys_removed", + removedKeysCounter, err := meter.Int64ObservableCounter( + "groupcache.group.removed_keys", metric.WithDescription("Total keys removed via RemoveKeys"), ) if err != nil { @@ -230,8 +230,8 @@ func newGroupInstruments(meter metric.Meter) (*groupInstruments, error) { localLoadsCounter: localLoadsCounter, localLoadErrsCounter: localLoadErrsCounter, getFromPeersLatencyMaxGauge: getFromPeersLatencyMaxGauge, - batchRemovesCounter: batchRemovesCounter, - batchKeysRemovedCounter: batchKeysRemovedCounter, + removeKeysRequestsCounter: removeKeysRequestsCounter, + removedKeysCounter: removedKeysCounter, }, nil } @@ -271,12 +271,12 @@ func (gm *groupInstruments) GetFromPeersLatencyMaxGauge() metric.Int64Observable return gm.getFromPeersLatencyMaxGauge } -func (gm *groupInstruments) BatchRemovesCounter() metric.Int64ObservableCounter { - return gm.batchRemovesCounter +func (gm *groupInstruments) RemoveKeysRequestsCounter() metric.Int64ObservableCounter { + return gm.removeKeysRequestsCounter } -func (gm *groupInstruments) BatchKeysRemovedCounter() metric.Int64ObservableCounter { - return gm.batchKeysRemovedCounter +func (gm *groupInstruments) RemovedKeysCounter() metric.Int64ObservableCounter { + return gm.removedKeysCounter } type cacheInstruments struct { diff --git a/stats_test.go b/stats_test.go index 990bceb..3a5ecab 100644 --- a/stats_test.go +++ b/stats_test.go @@ -67,8 +67,8 @@ func TestNewGroupInstrumentsRegistersAllCounters(t *testing.T) { "groupcache.group.loads.deduped", "groupcache.group.local.loads", "groupcache.group.local.load_errors", - "groupcache.group.batch.removes", - "groupcache.group.batch.keys_removed", + "groupcache.group.remove_keys.requests", + "groupcache.group.removed_keys", } assert.Equal(t, expectedCounters, meter.counterNames) assert.Equal(t, []string{"groupcache.group.peer.latency_max_ms"}, meter.updownNames) @@ -82,8 +82,8 @@ func TestNewGroupInstrumentsRegistersAllCounters(t *testing.T) { assert.NotNil(t, inst.LocalLoadsCounter()) assert.NotNil(t, inst.LocalLoadErrsCounter()) assert.NotNil(t, inst.GetFromPeersLatencyMaxGauge()) - assert.NotNil(t, inst.BatchRemovesCounter()) - assert.NotNil(t, inst.BatchKeysRemovedCounter()) + assert.NotNil(t, inst.RemoveKeysRequestsCounter()) + assert.NotNil(t, inst.RemovedKeysCounter()) } func TestNewGroupInstrumentsErrorsOnCounterFailure(t *testing.T) { diff --git a/transport/http_transport.go b/transport/http_transport.go index 0032803..2e45b5c 100644 --- a/transport/http_transport.go +++ b/transport/http_transport.go @@ -322,11 +322,6 @@ func (t *HttpTransport) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if strings.HasPrefix(key, "_batch/") { - t.handleBatchRequest(ctx, w, r, groupName, key, group, recordError) - return - } - // Delete the key and return 200 if r.Method == http.MethodDelete { group.LocalRemove(key) @@ -364,8 +359,18 @@ func (t *HttpTransport) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if r.Method == http.MethodPost { + if strings.HasPrefix(key, "_remove-keys/") { + t.handleRemoveKeysRequest(ctx, w, r, group, recordError) + return + } + http.Error(w, "invalid path for POST method", http.StatusNotFound) + recordError() + return + } + if r.Method != http.MethodGet { - http.Error(w, "Only GET, DELETE, PUT are supported", http.StatusMethodNotAllowed) + http.Error(w, "Only GET, DELETE, PUT, POST are supported", http.StatusMethodNotAllowed) recordError() return } @@ -407,13 +412,7 @@ func (t *HttpTransport) ServeHTTP(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(body) } -func (t *HttpTransport) handleBatchRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, groupName string, path string, group transportMethods, recordError func()) { - if r.Method != http.MethodPost { - http.Error(w, "batch operations require POST method", http.StatusMethodNotAllowed) - recordError() - return - } - +func (t *HttpTransport) handleRemoveKeysRequest(ctx context.Context, w http.ResponseWriter, r *http.Request, group transportMethods, recordError func()) { defer r.Body.Close() b := bufferPool.Get().(*bytes.Buffer) @@ -426,25 +425,13 @@ func (t *HttpTransport) handleBatchRequest(ctx context.Context, w http.ResponseW return } - switch path { - case "_batch/remove": - t.handleBatchRemove(ctx, w, b.Bytes(), group, recordError) - default: - http.Error(w, "unknown batch operation: "+path, http.StatusBadRequest) - recordError() - } -} - -// handleBatchRemove handles a batch remove request -func (t *HttpTransport) handleBatchRemove(ctx context.Context, w http.ResponseWriter, body []byte, group transportMethods, recordError func()) { - var req pb.RemoveMultiRequest - if err := json.Unmarshal(body, &req); err != nil { + var req pb.RemoveKeysRequest + if err := json.Unmarshal(b.Bytes(), &req); err != nil { http.Error(w, "invalid request: "+err.Error(), http.StatusBadRequest) recordError() return } - // Remove each key locally for _, key := range req.Keys { group.LocalRemove(key) } @@ -632,18 +619,18 @@ func (h *HttpClient) Remove(ctx context.Context, in *pb.GetRequest) error { return nil } -func (h *HttpClient) RemoveKeys(ctx context.Context, in *pb.RemoveMultiRequest) error { +func (h *HttpClient) RemoveKeys(ctx context.Context, in *pb.RemoveKeysRequest) error { ctx, span, endSpan := h.startSpan(ctx, "GroupCache.RemoveKeys") defer endSpan() body, err := json.Marshal(in) if err != nil { - werr := fmt.Errorf("while marshaling RemoveMultiRequest body: %w", err) + werr := fmt.Errorf("while marshaling RemoveKeysRequest body: %w", err) return recordSpanError(span, werr) } var res http.Response - if err := h.makeBatchRequest(ctx, http.MethodPost, in.GetGroup(), "_batch/remove", bytes.NewReader(body), &res); err != nil { + if err := h.makeRemoveKeysRequest(ctx, http.MethodPost, in.GetGroup(), "_remove-keys/", bytes.NewReader(body), &res); err != nil { return recordSpanError(span, err) } @@ -672,7 +659,7 @@ type request interface { GetKey() string } -func (h *HttpClient) makeBatchRequest(ctx context.Context, method string, group string, path string, body io.Reader, out *http.Response) error { +func (h *HttpClient) makeRemoveKeysRequest(ctx context.Context, method string, group string, path string, body io.Reader, out *http.Response) error { u := fmt.Sprintf( "%v%v/%v", h.endpoint, @@ -680,12 +667,6 @@ func (h *HttpClient) makeBatchRequest(ctx context.Context, method string, group path, ) - var bodyBytes []byte - if body != nil { - bodyBytes, _ = io.ReadAll(body) - body = bytes.NewReader(bodyBytes) - } - req, err := http.NewRequestWithContext(ctx, method, u, body) if err != nil { return err diff --git a/transport/mock_transport.go b/transport/mock_transport.go index 27aaeea..f72ecc9 100644 --- a/transport/mock_transport.go +++ b/transport/mock_transport.go @@ -144,7 +144,7 @@ func (c *MockClient) Set(ctx context.Context, in *pb.SetRequest) error { return nil } -func (c *MockClient) RemoveKeys(ctx context.Context, in *pb.RemoveMultiRequest) error { +func (c *MockClient) RemoveKeys(ctx context.Context, in *pb.RemoveKeysRequest) error { c.addCall("RemoveKeys", len(in.Keys)) // TODO: Implement when needed return nil diff --git a/transport/pb/groupcache.pb.go b/transport/pb/groupcache.pb.go index 6988153..69561aa 100644 --- a/transport/pb/groupcache.pb.go +++ b/transport/pb/groupcache.pb.go @@ -22,10 +22,11 @@ package pb import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -224,7 +225,7 @@ func (x *SetRequest) GetExpire() int64 { return 0 } -type RemoveMultiRequest struct { +type RemoveKeysRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields @@ -233,8 +234,8 @@ type RemoveMultiRequest struct { Keys []string `protobuf:"bytes,2,rep,name=keys" json:"keys,omitempty"` } -func (x *RemoveMultiRequest) Reset() { - *x = RemoveMultiRequest{} +func (x *RemoveKeysRequest) Reset() { + *x = RemoveKeysRequest{} if protoimpl.UnsafeEnabled { mi := &file_transport_pb_groupcache_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -242,13 +243,13 @@ func (x *RemoveMultiRequest) Reset() { } } -func (x *RemoveMultiRequest) String() string { +func (x *RemoveKeysRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*RemoveMultiRequest) ProtoMessage() {} +func (*RemoveKeysRequest) ProtoMessage() {} -func (x *RemoveMultiRequest) ProtoReflect() protoreflect.Message { +func (x *RemoveKeysRequest) ProtoReflect() protoreflect.Message { mi := &file_transport_pb_groupcache_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -260,26 +261,26 @@ func (x *RemoveMultiRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use RemoveMultiRequest.ProtoReflect.Descriptor instead. -func (*RemoveMultiRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use RemoveKeysRequest.ProtoReflect.Descriptor instead. +func (*RemoveKeysRequest) Descriptor() ([]byte, []int) { return file_transport_pb_groupcache_proto_rawDescGZIP(), []int{3} } -func (x *RemoveMultiRequest) GetGroup() string { +func (x *RemoveKeysRequest) GetGroup() string { if x != nil && x.Group != nil { return *x.Group } return "" } -func (x *RemoveMultiRequest) GetKeys() []string { +func (x *RemoveKeysRequest) GetKeys() []string { if x != nil { return x.Keys } return nil } -type RemoveMultiResponse struct { +type RemoveKeysResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields @@ -287,8 +288,8 @@ type RemoveMultiResponse struct { FailedKeys []string `protobuf:"bytes,1,rep,name=failed_keys,json=failedKeys" json:"failed_keys,omitempty"` } -func (x *RemoveMultiResponse) Reset() { - *x = RemoveMultiResponse{} +func (x *RemoveKeysResponse) Reset() { + *x = RemoveKeysResponse{} if protoimpl.UnsafeEnabled { mi := &file_transport_pb_groupcache_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -296,13 +297,13 @@ func (x *RemoveMultiResponse) Reset() { } } -func (x *RemoveMultiResponse) String() string { +func (x *RemoveKeysResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*RemoveMultiResponse) ProtoMessage() {} +func (*RemoveKeysResponse) ProtoMessage() {} -func (x *RemoveMultiResponse) ProtoReflect() protoreflect.Message { +func (x *RemoveKeysResponse) ProtoReflect() protoreflect.Message { mi := &file_transport_pb_groupcache_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -314,12 +315,12 @@ func (x *RemoveMultiResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use RemoveMultiResponse.ProtoReflect.Descriptor instead. -func (*RemoveMultiResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use RemoveKeysResponse.ProtoReflect.Descriptor instead. +func (*RemoveKeysResponse) Descriptor() ([]byte, []int) { return file_transport_pb_groupcache_proto_rawDescGZIP(), []int{4} } -func (x *RemoveMultiResponse) GetFailedKeys() []string { +func (x *RemoveKeysResponse) GetFailedKeys() []string { if x != nil { return x.FailedKeys } @@ -377,11 +378,11 @@ func file_transport_pb_groupcache_proto_rawDescGZIP() []byte { var file_transport_pb_groupcache_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_transport_pb_groupcache_proto_goTypes = []interface{}{ - (*GetRequest)(nil), // 0: pb.GetRequest - (*GetResponse)(nil), // 1: pb.GetResponse - (*SetRequest)(nil), // 2: pb.SetRequest - (*RemoveMultiRequest)(nil), // 3: pb.RemoveMultiRequest - (*RemoveMultiResponse)(nil), // 4: pb.RemoveMultiResponse + (*GetRequest)(nil), // 0: pb.GetRequest + (*GetResponse)(nil), // 1: pb.GetResponse + (*SetRequest)(nil), // 2: pb.SetRequest + (*RemoveKeysRequest)(nil), // 3: pb.RemoveKeysRequest + (*RemoveKeysResponse)(nil), // 4: pb.RemoveKeysResponse } var file_transport_pb_groupcache_proto_depIdxs = []int32{ 0, // 0: pb.GroupCache.Get:input_type -> pb.GetRequest @@ -436,7 +437,7 @@ func file_transport_pb_groupcache_proto_init() { } } file_transport_pb_groupcache_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemoveMultiRequest); i { + switch v := v.(*RemoveKeysRequest); i { case 0: return &v.state case 1: @@ -448,7 +449,7 @@ func file_transport_pb_groupcache_proto_init() { } } file_transport_pb_groupcache_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RemoveMultiResponse); i { + switch v := v.(*RemoveKeysResponse); i { case 0: return &v.state case 1: diff --git a/transport/pb/groupcache.proto b/transport/pb/groupcache.proto index c9820f8..e28ee59 100644 --- a/transport/pb/groupcache.proto +++ b/transport/pb/groupcache.proto @@ -37,12 +37,12 @@ message SetRequest { optional int64 expire = 4; } -message RemoveMultiRequest { +message RemoveKeysRequest { required string group = 1; repeated string keys = 2; } -message RemoveMultiResponse { +message RemoveKeysResponse { repeated string failed_keys = 1; } diff --git a/transport/peer/client.go b/transport/peer/client.go index 518e549..8227326 100644 --- a/transport/peer/client.go +++ b/transport/peer/client.go @@ -27,7 +27,7 @@ type Client interface { Get(context context.Context, in *pb.GetRequest, out *pb.GetResponse) error Remove(context context.Context, in *pb.GetRequest) error Set(context context.Context, in *pb.SetRequest) error - RemoveKeys(context context.Context, in *pb.RemoveMultiRequest) error + RemoveKeys(context context.Context, in *pb.RemoveKeysRequest) error PeerInfo() Info HashKey() string } @@ -50,7 +50,7 @@ func (e *NoOpClient) Set(context context.Context, in *pb.SetRequest) error { return nil } -func (e *NoOpClient) RemoveKeys(context context.Context, in *pb.RemoveMultiRequest) error { +func (e *NoOpClient) RemoveKeys(context context.Context, in *pb.RemoveKeysRequest) error { return nil }