Skip to content

Commit d5d6a80

Browse files
committed
Applying the changes in the SMIGRATED notification format (#3868)
* Refactoring the logic related to SMIGRATED notification format. Applying the new format. * Add handling for parallel slot migrations with maintenance notifications flow for OSS Cluster API (#3869) * Adding handling of parallel slot migrations when OSS cluster api is used * Applying review comments
1 parent 10522fc commit d5d6a80

File tree

6 files changed

+278
-148
lines changed

6 files changed

+278
-148
lines changed

redis/_parsers/base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,14 @@ def parse_oss_maintenance_start_msg(response):
193193
@staticmethod
194194
def parse_oss_maintenance_completed_msg(response):
195195
# Expected message format is:
196-
# SMIGRATED <seq_number> <host:port> <slot, range1-range2,...>
196+
# SMIGRATED <seq_number> [<host:port> <slot, range1-range2,...>, ...]
197197
id = response[1]
198-
node_address = safe_str(response[2])
199-
slots = response[3]
198+
nodes_to_slots_mapping_data = response[2]
199+
nodes_to_slots_mapping = {}
200+
for node, slots in nodes_to_slots_mapping_data:
201+
nodes_to_slots_mapping[safe_str(node)] = safe_str(slots)
200202

201-
return OSSNodeMigratedNotification(id, node_address, slots)
203+
return OSSNodeMigratedNotification(id, nodes_to_slots_mapping)
202204

203205
@staticmethod
204206
def parse_maintenance_start_msg(response, notification_type):

redis/connection.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ def __init__(
319319
oss_cluster_maint_notifications_handler,
320320
parser,
321321
)
322+
self._processed_start_maint_notifications = set()
323+
self._skipped_end_maint_notifications = set()
322324

323325
@abstractmethod
324326
def _get_parser(self) -> Union[_HiredisParser, _RESP3Parser]:
@@ -667,6 +669,22 @@ def maintenance_state(self) -> MaintenanceState:
667669
def maintenance_state(self, state: "MaintenanceState"):
668670
self._maintenance_state = state
669671

672+
def add_maint_start_notification(self, id: int):
673+
self._processed_start_maint_notifications.add(id)
674+
675+
def get_processed_start_notifications(self) -> set:
676+
return self._processed_start_maint_notifications
677+
678+
def add_skipped_end_notification(self, id: int):
679+
self._skipped_end_maint_notifications.add(id)
680+
681+
def get_skipped_end_notifications(self) -> set:
682+
return self._skipped_end_maint_notifications
683+
684+
def reset_received_notifications(self):
685+
self._processed_start_maint_notifications.clear()
686+
self._skipped_end_maint_notifications.clear()
687+
670688
def getpeername(self):
671689
"""
672690
Returns the peer name of the connection.

redis/maint_notifications.py

Lines changed: 62 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import threading
66
import time
77
from abc import ABC, abstractmethod
8-
from typing import TYPE_CHECKING, List, Literal, Optional, Union
8+
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
99

1010
from redis.typing import Number
1111

@@ -463,31 +463,26 @@ class OSSNodeMigratedNotification(MaintenanceNotification):
463463
464464
Args:
465465
id (int): Unique identifier for this notification
466-
node_address (Optional[str]): Address of the node that has completed migration
467-
in the format "host:port"
468-
slots (Optional[List[int]]): List of slots that have been migrated
466+
nodes_to_slots_mapping (Dict[str, str]): Mapping of node addresses to slots
469467
"""
470468

471469
DEFAULT_TTL = 30
472470

473471
def __init__(
474472
self,
475473
id: int,
476-
node_address: str,
477-
slots: Optional[List[int]] = None,
474+
nodes_to_slots_mapping: Dict[str, str],
478475
):
479476
super().__init__(id, OSSNodeMigratedNotification.DEFAULT_TTL)
480-
self.node_address = node_address
481-
self.slots = slots
477+
self.nodes_to_slots_mapping = nodes_to_slots_mapping
482478

483479
def __repr__(self) -> str:
484480
expiry_time = self.creation_time + self.ttl
485481
remaining = max(0, expiry_time - time.monotonic())
486482
return (
487483
f"{self.__class__.__name__}("
488484
f"id={self.id}, "
489-
f"node_address={self.node_address}, "
490-
f"slots={self.slots}, "
485+
f"nodes_to_slots_mapping={self.nodes_to_slots_mapping}, "
491486
f"ttl={self.ttl}, "
492487
f"creation_time={self.creation_time}, "
493488
f"expires_at={expiry_time}, "
@@ -899,12 +894,14 @@ def handle_notification(self, notification: MaintenanceNotification):
899894
return
900895

901896
if notification_type:
902-
self.handle_maintenance_start_notification(MaintenanceState.MAINTENANCE)
897+
self.handle_maintenance_start_notification(
898+
MaintenanceState.MAINTENANCE, notification
899+
)
903900
else:
904901
self.handle_maintenance_completed_notification()
905902

906903
def handle_maintenance_start_notification(
907-
self, maintenance_state: MaintenanceState
904+
self, maintenance_state: MaintenanceState, notification: MaintenanceNotification
908905
):
909906
if (
910907
self.connection.maintenance_state == MaintenanceState.MOVING
@@ -918,6 +915,11 @@ def handle_maintenance_start_notification(
918915
)
919916
# extend the timeout for all created connections
920917
self.connection.update_current_socket_timeout(self.config.relaxed_timeout)
918+
if isinstance(notification, OSSNodeMigratingNotification):
919+
# add the notification id to the set of processed start maint notifications
920+
# this is used to skip the unrelaxing of the timeouts if we have received more than
921+
# one start notification before the final end notification
922+
self.connection.add_maint_start_notification(notification.id)
921923

922924
def handle_maintenance_completed_notification(self):
923925
# Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
@@ -931,6 +933,9 @@ def handle_maintenance_completed_notification(self):
931933
# timeouts by providing -1 as the relaxed timeout
932934
self.connection.update_current_socket_timeout(-1)
933935
self.connection.maintenance_state = MaintenanceState.NONE
936+
# reset the sets that keep track of received start maint
937+
# notifications and skipped end maint notifications
938+
self.connection.reset_received_notifications()
934939

935940

936941
class OSSMaintNotificationsHandler:
@@ -999,40 +1004,55 @@ def handle_oss_maintenance_completed_notification(
9991004

10001005
# Updates the cluster slots cache with the new slots mapping
10011006
# This will also update the nodes cache with the new nodes mapping
1002-
new_node_host, new_node_port = notification.node_address.split(":")
1007+
additional_startup_nodes_info = []
1008+
for node_address, _ in notification.nodes_to_slots_mapping.items():
1009+
new_node_host, new_node_port = node_address.split(":")
1010+
additional_startup_nodes_info.append(
1011+
(new_node_host, int(new_node_port))
1012+
)
10031013
self.cluster_client.nodes_manager.initialize(
10041014
disconnect_startup_nodes_pools=False,
1005-
additional_startup_nodes_info=[(new_node_host, int(new_node_port))],
1015+
additional_startup_nodes_info=additional_startup_nodes_info,
10061016
)
1007-
# mark for reconnect all in use connections to the node - this will force them to
1008-
# disconnect after they complete their current commands
1009-
# Some of them might be used by sub sub and we don't know which ones - so we disconnect
1010-
# all in flight connections after they are done with current command execution
1011-
for conn in (
1012-
current_node.redis_connection.connection_pool._get_in_use_connections()
1013-
):
1014-
conn.mark_for_reconnect()
1017+
with current_node.redis_connection.connection_pool._lock:
1018+
# mark for reconnect all in use connections to the node - this will force them to
1019+
# disconnect after they complete their current commands
1020+
# Some of them might be used by pub/sub and we don't know which ones - so we disconnect
1021+
# all in flight connections after they are done with current command execution
1022+
for conn in current_node.redis_connection.connection_pool._get_in_use_connections():
1023+
conn.mark_for_reconnect()
10151024

1016-
if (
1017-
current_node
1018-
not in self.cluster_client.nodes_manager.nodes_cache.values()
1019-
):
1020-
# disconnect all free connections to the node - this node will be dropped
1021-
# from the cluster, so we don't need to revert the timeouts
1022-
for conn in current_node.redis_connection.connection_pool._get_free_connections():
1023-
conn.disconnect()
1024-
else:
1025-
if self.config.is_relaxed_timeouts_enabled():
1026-
# reset the timeouts for the node to which the connection is connected
1027-
# TODO: add check if other maintenance ops are in progress for the same node - CAE-1038
1028-
# and if so, don't reset the timeouts
1029-
for conn in (
1030-
*current_node.redis_connection.connection_pool._get_in_use_connections(),
1031-
*current_node.redis_connection.connection_pool._get_free_connections(),
1032-
):
1033-
conn.reset_tmp_settings(reset_relaxed_timeout=True)
1034-
conn.update_current_socket_timeout(relaxed_timeout=-1)
1035-
conn.maintenance_state = MaintenanceState.NONE
1025+
if (
1026+
current_node
1027+
not in self.cluster_client.nodes_manager.nodes_cache.values()
1028+
):
1029+
# disconnect all free connections to the node - this node will be dropped
1030+
# from the cluster, so we don't need to revert the timeouts
1031+
for conn in current_node.redis_connection.connection_pool._get_free_connections():
1032+
conn.disconnect()
1033+
else:
1034+
if self.config.is_relaxed_timeouts_enabled():
1035+
# reset the timeouts for the node to which the connection is connected
1036+
# Perform check if other maintenance ops are in progress for the same node
1037+
# and if so, don't reset the timeouts and wait for the last maintenance
1038+
# to complete
1039+
for conn in (
1040+
*current_node.redis_connection.connection_pool._get_in_use_connections(),
1041+
*current_node.redis_connection.connection_pool._get_free_connections(),
1042+
):
1043+
if (
1044+
len(conn.get_processed_start_notifications())
1045+
> len(conn.get_skipped_end_notifications()) + 1
1046+
):
1047+
# we have received more start notifications than end notifications
1048+
# for this connection - we should not reset the timeouts
1049+
# and add the notification id to the set of skipped end notifications
1050+
conn.add_skipped_end_notification(notification.id)
1051+
else:
1052+
conn.reset_tmp_settings(reset_relaxed_timeout=True)
1053+
conn.update_current_socket_timeout(relaxed_timeout=-1)
1054+
conn.maintenance_state = MaintenanceState.NONE
1055+
conn.reset_received_notifications()
10361056

10371057
# mark the notification as processed
10381058
self._processed_notifications.add(notification)

tests/maint_notifications/proxy_server_helpers.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,37 +11,51 @@ class RespTranslator:
1111
"""Helper class to translate between RESP and other encodings."""
1212

1313
@staticmethod
14-
def str_or_list_to_resp(txt: str) -> str:
15-
"""
16-
Convert specific string or list to RESP format.
17-
"""
18-
if re.match(r"^<.*>$", txt):
19-
items = txt[1:-1].split(",")
20-
return f"*{len(items)}\r\n" + "\r\n".join(
21-
f"${len(x)}\r\n{x}" for x in items
14+
def oss_maint_notification_to_resp(txt: str) -> str:
15+
"""Convert query to RESP format."""
16+
if txt.startswith("SMIGRATED"):
17+
# Format: SMIGRATED SeqID host:port slot1,range1-range2 host1:port1 slot2,range3-range4
18+
# SMIGRATED 93923 abc.com:6789 123,789-1000 abc.com:4545 1000-2000 abc.com:4323 900,910,920
19+
# SMIGRATED - simple string
20+
# SeqID - integer
21+
# host and slots info are provided as array of arrays
22+
# host:port - simple string
23+
# slots - simple string
24+
25+
parts = txt.split()
26+
notification = parts[0]
27+
seq_id = parts[1]
28+
hosts_and_slots = parts[2:]
29+
resp = (
30+
">3\r\n" # Push message with 3 elements
31+
f"+{notification}\r\n" # Element 1: Command
32+
f":{seq_id}\r\n" # Element 2: SeqID
33+
f"*{len(hosts_and_slots) // 2}\r\n" # Element 3: Array of host:port, slots pairs
2234
)
35+
for i in range(0, len(hosts_and_slots), 2):
36+
resp += "*2\r\n"
37+
resp += f"+{hosts_and_slots[i]}\r\n"
38+
resp += f"+{hosts_and_slots[i + 1]}\r\n"
2339
else:
24-
return f"${len(txt)}\r\n{txt}"
25-
26-
@staticmethod
27-
def cluster_slots_to_resp(resp: str) -> str:
28-
"""Convert query to RESP format."""
29-
return (
30-
f"*{len(resp.split())}\r\n"
31-
+ "\r\n".join(f"${len(x)}\r\n{x}" for x in resp.split())
32-
+ "\r\n"
33-
)
34-
35-
@staticmethod
36-
def oss_maint_notification_to_resp(resp: str) -> str:
37-
"""Convert query to RESP format."""
38-
return (
39-
f">{len(resp.split())}\r\n"
40-
+ "\r\n".join(
41-
f"{RespTranslator.str_or_list_to_resp(x)}" for x in resp.split()
40+
# SMIGRATING
41+
# Format: SMIGRATING SeqID slot,range1-range2
42+
# SMIGRATING 93923 123,789-1000
43+
# SMIGRATING - simple string
44+
# SeqID - integer
45+
# slots - simple string
46+
47+
parts = txt.split()
48+
notification = parts[0]
49+
seq_id = parts[1]
50+
slots = parts[2]
51+
52+
resp = (
53+
">3\r\n" # Push message with 3 elements
54+
f"+{notification}\r\n" # Element 1: Command
55+
f":{seq_id}\r\n" # Element 2: SeqID
56+
f"+{slots}\r\n" # Element 3: slots (simple string)
4257
)
43-
+ "\r\n"
44-
)
58+
return resp
4559

4660

4761
@dataclass

0 commit comments

Comments
 (0)