55import threading
66import time
77from abc import ABC , abstractmethod
8- from typing import TYPE_CHECKING , List , Literal , Optional , Union
8+ from typing import TYPE_CHECKING , Dict , List , Literal , Optional , Union
99
1010from redis .typing import Number
1111
@@ -463,31 +463,26 @@ class OSSNodeMigratedNotification(MaintenanceNotification):
463463
464464 Args:
465465 id (int): Unique identifier for this notification
466- node_address (Optional[str]): Address of the node that has completed migration
467- in the format "host:port"
468- slots (Optional[List[int]]): List of slots that have been migrated
466+ nodes_to_slots_mapping (Dict[str, str]): Mapping of node addresses to slots
469467 """
470468
471469 DEFAULT_TTL = 30
472470
473471 def __init__ (
474472 self ,
475473 id : int ,
476- node_address : str ,
477- slots : Optional [List [int ]] = None ,
474+ nodes_to_slots_mapping : Dict [str , str ],
478475 ):
479476 super ().__init__ (id , OSSNodeMigratedNotification .DEFAULT_TTL )
480- self .node_address = node_address
481- self .slots = slots
477+ self .nodes_to_slots_mapping = nodes_to_slots_mapping
482478
483479 def __repr__ (self ) -> str :
484480 expiry_time = self .creation_time + self .ttl
485481 remaining = max (0 , expiry_time - time .monotonic ())
486482 return (
487483 f"{ self .__class__ .__name__ } ("
488484 f"id={ self .id } , "
489- f"node_address={ self .node_address } , "
490- f"slots={ self .slots } , "
485+ f"nodes_to_slots_mapping={ self .nodes_to_slots_mapping } , "
491486 f"ttl={ self .ttl } , "
492487 f"creation_time={ self .creation_time } , "
493488 f"expires_at={ expiry_time } , "
@@ -899,12 +894,14 @@ def handle_notification(self, notification: MaintenanceNotification):
899894 return
900895
901896 if notification_type :
902- self .handle_maintenance_start_notification (MaintenanceState .MAINTENANCE )
897+ self .handle_maintenance_start_notification (
898+ MaintenanceState .MAINTENANCE , notification
899+ )
903900 else :
904901 self .handle_maintenance_completed_notification ()
905902
906903 def handle_maintenance_start_notification (
907- self , maintenance_state : MaintenanceState
904+ self , maintenance_state : MaintenanceState , notification : MaintenanceNotification
908905 ):
909906 if (
910907 self .connection .maintenance_state == MaintenanceState .MOVING
@@ -918,6 +915,11 @@ def handle_maintenance_start_notification(
918915 )
919916 # extend the timeout for all created connections
920917 self .connection .update_current_socket_timeout (self .config .relaxed_timeout )
918+ if isinstance (notification , OSSNodeMigratingNotification ):
919+ # add the notification id to the set of processed start maint notifications
920+ # this is used to skip the unrelaxing of the timeouts if we have received more than
921+ # one start notification before the the final end notification
922+ self .connection .add_maint_start_notification (notification .id )
921923
922924 def handle_maintenance_completed_notification (self ):
923925 # Only reset timeouts if state is not MOVING and relaxed timeouts are enabled
@@ -931,6 +933,9 @@ def handle_maintenance_completed_notification(self):
931933 # timeouts by providing -1 as the relaxed timeout
932934 self .connection .update_current_socket_timeout (- 1 )
933935 self .connection .maintenance_state = MaintenanceState .NONE
936+ # reset the sets that keep track of received start maint
937+ # notifications and skipped end maint notifications
938+ self .connection .reset_received_notifications ()
934939
935940
936941class OSSMaintNotificationsHandler :
@@ -999,40 +1004,55 @@ def handle_oss_maintenance_completed_notification(
9991004
10001005 # Updates the cluster slots cache with the new slots mapping
10011006 # This will also update the nodes cache with the new nodes mapping
1002- new_node_host , new_node_port = notification .node_address .split (":" )
1007+ additional_startup_nodes_info = []
1008+ for node_address , _ in notification .nodes_to_slots_mapping .items ():
1009+ new_node_host , new_node_port = node_address .split (":" )
1010+ additional_startup_nodes_info .append (
1011+ (new_node_host , int (new_node_port ))
1012+ )
10031013 self .cluster_client .nodes_manager .initialize (
10041014 disconnect_startup_nodes_pools = False ,
1005- additional_startup_nodes_info = [( new_node_host , int ( new_node_port ))] ,
1015+ additional_startup_nodes_info = additional_startup_nodes_info ,
10061016 )
1007- # mark for reconnect all in use connections to the node - this will force them to
1008- # disconnect after they complete their current commands
1009- # Some of them might be used by sub sub and we don't know which ones - so we disconnect
1010- # all in flight connections after they are done with current command execution
1011- for conn in (
1012- current_node .redis_connection .connection_pool ._get_in_use_connections ()
1013- ):
1014- conn .mark_for_reconnect ()
1017+ with current_node .redis_connection .connection_pool ._lock :
1018+ # mark for reconnect all in use connections to the node - this will force them to
1019+ # disconnect after they complete their current commands
1020+ # Some of them might be used by sub sub and we don't know which ones - so we disconnect
1021+ # all in flight connections after they are done with current command execution
1022+ for conn in current_node .redis_connection .connection_pool ._get_in_use_connections ():
1023+ conn .mark_for_reconnect ()
10151024
1016- if (
1017- current_node
1018- not in self .cluster_client .nodes_manager .nodes_cache .values ()
1019- ):
1020- # disconnect all free connections to the node - this node will be dropped
1021- # from the cluster, so we don't need to revert the timeouts
1022- for conn in current_node .redis_connection .connection_pool ._get_free_connections ():
1023- conn .disconnect ()
1024- else :
1025- if self .config .is_relaxed_timeouts_enabled ():
1026- # reset the timeouts for the node to which the connection is connected
1027- # TODO: add check if other maintenance ops are in progress for the same node - CAE-1038
1028- # and if so, don't reset the timeouts
1029- for conn in (
1030- * current_node .redis_connection .connection_pool ._get_in_use_connections (),
1031- * current_node .redis_connection .connection_pool ._get_free_connections (),
1032- ):
1033- conn .reset_tmp_settings (reset_relaxed_timeout = True )
1034- conn .update_current_socket_timeout (relaxed_timeout = - 1 )
1035- conn .maintenance_state = MaintenanceState .NONE
1025+ if (
1026+ current_node
1027+ not in self .cluster_client .nodes_manager .nodes_cache .values ()
1028+ ):
1029+ # disconnect all free connections to the node - this node will be dropped
1030+ # from the cluster, so we don't need to revert the timeouts
1031+ for conn in current_node .redis_connection .connection_pool ._get_free_connections ():
1032+ conn .disconnect ()
1033+ else :
1034+ if self .config .is_relaxed_timeouts_enabled ():
1035+ # reset the timeouts for the node to which the connection is connected
1036+ # Perform check if other maintenance ops are in progress for the same node
1037+ # and if so, don't reset the timeouts and wait for the last maintenance
1038+ # to complete
1039+ for conn in (
1040+ * current_node .redis_connection .connection_pool ._get_in_use_connections (),
1041+ * current_node .redis_connection .connection_pool ._get_free_connections (),
1042+ ):
1043+ if (
1044+ len (conn .get_processed_start_notifications ())
1045+ > len (conn .get_skipped_end_notifications ()) + 1
1046+ ):
1047+ # we have received more start notifications than end notifications
1048+ # for this connection - we should not reset the timeouts
1049+ # and add the notification id to the set of skipped end notifications
1050+ conn .add_skipped_end_notification (notification .id )
1051+ else :
1052+ conn .reset_tmp_settings (reset_relaxed_timeout = True )
1053+ conn .update_current_socket_timeout (relaxed_timeout = - 1 )
1054+ conn .maintenance_state = MaintenanceState .NONE
1055+ conn .reset_received_notifications ()
10361056
10371057 # mark the notification as processed
10381058 self ._processed_notifications .add (notification )
0 commit comments