138 changes: 33 additions & 105 deletions cassandra/cluster.py
@@ -838,8 +838,8 @@ def default_retry_policy(self, policy):
Using ssl_options without ssl_context is deprecated and will be removed in the
next major release.

An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket``
when new sockets are created. This should be used when client encryption is enabled
An optional dict which will be used as kwargs for ``ssl.SSLContext.wrap_socket``
when new sockets are created. This should be used when client encryption is enabled
in Cassandra.

The following documentation only applies when ssl_options is used without ssl_context.
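
For reference, a minimal sketch of the non-deprecated combination described above, in which the ssl_options entries become keyword arguments for ``ssl.SSLContext.wrap_socket`` (the certificate path and addresses are placeholders, not values taken from this change):

    import ssl
    from cassandra.cluster import Cluster

    # Placeholder CA certificate; verification settings depend on the deployment.
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    ssl_context.load_verify_locations("/path/to/rootca.crt")

    # Each ssl_options entry is passed to ssl_context.wrap_socket()
    # whenever the driver opens a new connection.
    cluster = Cluster(
        contact_points=["10.0.0.1"],
        ssl_context=ssl_context,
        ssl_options={"server_hostname": "10.0.0.1"},
    )
    session = cluster.connect()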
@@ -1086,10 +1086,10 @@ def default_retry_policy(self, policy):
"""
Specifies a server-side timeout (in seconds) for all internal driver queries,
such as schema metadata lookups and cluster topology requests.

The timeout is enforced by appending `USING TIMEOUT <timeout>` to queries
executed by the driver.

- A value of `0` disables explicit timeout enforcement. In this case,
the driver does not add `USING TIMEOUT`, and the timeout is determined
by the server's defaults.
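
A rough illustration of the behaviour described above, using a hypothetical helper name rather than the driver's internal one:

    def _with_using_timeout(query, timeout_seconds):
        # A falsy timeout (0 or None) leaves the query untouched, so the
        # server-side default applies.
        if not timeout_seconds:
            return query
        return "{} USING TIMEOUT {}ms".format(query, int(timeout_seconds * 1000))

    # _with_using_timeout("SELECT * FROM system.local", 2)
    # -> "SELECT * FROM system.local USING TIMEOUT 2000ms"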
@@ -1683,14 +1683,7 @@ def protocol_downgrade(self, host_endpoint, previous_version):
"http://datastax.github.io/python-driver/api/cassandra/cluster.html#cassandra.cluster.Cluster.protocol_version", self.protocol_version, new_version, host_endpoint)
self.protocol_version = new_version

def _add_resolved_hosts(self):
for endpoint in self.endpoints_resolved:
host, new = self.add_host(endpoint, signal=False)
if new:
host.set_up()
for listener in self.listeners:
listener.on_add(host)

def _populate_hosts(self):
self.profile_manager.populate(
weakref.proxy(self), self.metadata.all_hosts())
self.load_balancing_policy.populate(
@@ -1717,17 +1710,10 @@ def connect(self, keyspace=None, wait_for_all_pools=False):
self.contact_points, self.protocol_version)
self.connection_class.initialize_reactor()
_register_cluster_shutdown(self)

self._add_resolved_hosts()

try:
self.control_connection.connect()

# we set all contact points up for connecting, but we won't infer state after this
for endpoint in self.endpoints_resolved:
h = self.metadata.get_host(endpoint)
if h and self.profile_manager.distance(h) == HostDistance.IGNORED:
h.is_up = None
self._populate_hosts()

log.debug("Control connection created")
except Exception:
@@ -2016,14 +2002,14 @@ def on_add(self, host, refresh_nodes=True):

log.debug("Handling new host %r and notifying listeners", host)

self.profile_manager.on_add(host)
self.control_connection.on_add(host, refresh_nodes)

distance = self.profile_manager.distance(host)
if distance != HostDistance.IGNORED:
self._prepare_all_queries(host)
log.debug("Done preparing queries for new host %r", host)

self.profile_manager.on_add(host)
self.control_connection.on_add(host, refresh_nodes)

if distance == HostDistance.IGNORED:
log.debug("Not adding connection pool for new host %r because the "
"load balancing policy has marked it as IGNORED", host)
@@ -3534,24 +3520,22 @@ def _set_new_connection(self, conn):
if old:
log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn)
old.close()
def _connect_host_in_lbp(self):

def _try_connect_to_hosts(self):
errors = {}
lbp = (
self._cluster.load_balancing_policy
if self._cluster._config_mode == _ConfigMode.LEGACY else
self._cluster._default_load_balancing_policy
)

for host in lbp.make_query_plan():
lbp = self._cluster.load_balancing_policy \
if self._cluster._config_mode == _ConfigMode.LEGACY else self._cluster._default_load_balancing_policy

for endpoint in chain((host.endpoint for host in lbp.make_query_plan()), self._cluster.endpoints_resolved):
try:
return (self._try_connect(host), None)
return (self._try_connect(endpoint), None)
except Exception as exc:
errors[str(host.endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", host, exc_info=True)
errors[str(endpoint)] = exc
log.warning("[control connection] Error connecting to %s:", endpoint, exc_info=True)
if self._is_shutdown:
raise DriverException("[control connection] Reconnection in progress during shutdown")

return (None, errors)

def _reconnect_internal(self):
@@ -3563,43 +3547,43 @@ def _reconnect_internal(self):
to the exception that was raised when an attempt was made to open
a connection to that host.
"""
(conn, _) = self._connect_host_in_lbp()
(conn, _) = self._try_connect_to_hosts()
if conn is not None:
return conn

# Try to re-resolve hostnames as a fallback when all hosts are unreachable
self._cluster._resolve_hostnames()

self._cluster._add_resolved_hosts()
self._cluster._populate_hosts()

(conn, errors) = self._connect_host_in_lbp()
(conn, errors) = self._try_connect_to_hosts()
if conn is not None:
return conn

raise NoHostAvailable("Unable to connect to any servers", errors)

def _try_connect(self, host):
def _try_connect(self, endpoint):
"""
Creates a new Connection, registers for pushed events, and refreshes
node/token and schema metadata.
"""
log.debug("[control connection] Opening new connection to %s", host)
log.debug("[control connection] Opening new connection to %s", endpoint)

while True:
try:
connection = self._cluster.connection_factory(host.endpoint, is_control_connection=True)
connection = self._cluster.connection_factory(endpoint, is_control_connection=True)
if self._is_shutdown:
connection.close()
raise DriverException("Reconnecting during shutdown")
break
except ProtocolVersionUnsupported as e:
self._cluster.protocol_downgrade(host.endpoint, e.startup_version)
self._cluster.protocol_downgrade(endpoint, e.startup_version)
except ProtocolException as e:
# protocol v5 is out of beta in C* >=4.0-beta5 and is now the default driver
# protocol version. If the protocol version was not explicitly specified,
# and that the server raises a beta protocol error, we should downgrade.
if not self._cluster._protocol_version_explicit and e.is_beta_protocol_error:
self._cluster.protocol_downgrade(host.endpoint, self._cluster.protocol_version)
self._cluster.protocol_downgrade(endpoint, self._cluster.protocol_version)
else:
raise

@@ -3814,67 +3798,10 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None,
self._cluster.metadata.cluster_name = cluster_name

partitioner = local_row.get("partitioner")
tokens = local_row.get("tokens")

host = self._cluster.metadata.get_host(connection.original_endpoint)
if host:
datacenter = local_row.get("data_center")
rack = local_row.get("rack")
self._update_location_info(host, datacenter, rack)

# support the use case of connecting only with public address
if isinstance(self._cluster.endpoint_factory, SniEndPointFactory):
new_endpoint = self._cluster.endpoint_factory.create(local_row)

if new_endpoint.address:
host.endpoint = new_endpoint

host.host_id = local_row.get("host_id")

found_host_ids.add(host.host_id)
found_endpoints.add(host.endpoint)

host.listen_address = local_row.get("listen_address")
host.listen_port = local_row.get("listen_port")
host.broadcast_address = _NodeInfo.get_broadcast_address(local_row)
host.broadcast_port = _NodeInfo.get_broadcast_port(local_row)

host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(local_row)
host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(local_row)
if host.broadcast_rpc_address is None:
if self._token_meta_enabled:
# local rpc_address is not available, use the connection endpoint
host.broadcast_rpc_address = connection.endpoint.address
host.broadcast_rpc_port = connection.endpoint.port
else:
# local rpc_address has not been queried yet, try to fetch it
# separately, which might fail because C* < 2.1.6 doesn't have rpc_address
# in system.local. See CASSANDRA-9436.
local_rpc_address_query = QueryMessage(
query=maybe_add_timeout_to_query(self._SELECT_LOCAL_NO_TOKENS_RPC_ADDRESS, self._metadata_request_timeout),
consistency_level=ConsistencyLevel.ONE)
success, local_rpc_address_result = connection.wait_for_response(
local_rpc_address_query, timeout=self._timeout, fail_on_error=False)
if success:
row = dict_factory(
local_rpc_address_result.column_names,
local_rpc_address_result.parsed_rows)
host.broadcast_rpc_address = _NodeInfo.get_broadcast_rpc_address(row[0])
host.broadcast_rpc_port = _NodeInfo.get_broadcast_rpc_port(row[0])
else:
host.broadcast_rpc_address = connection.endpoint.address
host.broadcast_rpc_port = connection.endpoint.port

host.release_version = local_row.get("release_version")
host.dse_version = local_row.get("dse_version")
host.dse_workload = local_row.get("workload")
host.dse_workloads = local_row.get("workloads")
tokens = local_row.get("tokens", None)

if partitioner and tokens:
token_map[host] = tokens
peers_result.insert(0, local_row)

self._cluster.metadata.update_host(host, old_endpoint=connection.endpoint)
connection.original_endpoint = connection.endpoint = host.endpoint
# Check metadata.partitioner to see if we haven't built anything yet. If
# every node in the cluster was in the contact points, we won't discover
# any new nodes, so we need this additional check. (See PYTHON-90)
@@ -4173,8 +4100,9 @@ def _get_peers_query(self, peers_query_type, connection=None):
query_template = (self._SELECT_SCHEMA_PEERS_TEMPLATE
if peers_query_type == self.PeersQueryType.PEERS_SCHEMA
else self._SELECT_PEERS_NO_TOKENS_TEMPLATE)
host_release_version = self._cluster.metadata.get_host(connection.original_endpoint).release_version
host_dse_version = self._cluster.metadata.get_host(connection.original_endpoint).dse_version
original_endpoint_host = self._cluster.metadata.get_host(connection.original_endpoint)
host_release_version = None if original_endpoint_host is None else original_endpoint_host.release_version
host_dse_version = None if original_endpoint_host is None else original_endpoint_host.dse_version
uses_native_address_query = (
host_dse_version and Version(host_dse_version) >= self._MINIMUM_NATIVE_ADDRESS_DSE_VERSION)

2 changes: 1 addition & 1 deletion cassandra/metadata.py
@@ -3481,7 +3481,7 @@ def group_keys_by_replica(session, keyspace, table, keys):
:class:`~.NO_VALID_REPLICA`
Example usage::
>>> result = group_keys_by_replica(
... session, "system", "peers",
... (("127.0.0.1", ), ("127.0.0.2", )))
18 changes: 5 additions & 13 deletions cassandra/policies.py
@@ -245,7 +245,6 @@ def __init__(self, local_dc='', used_hosts_per_remote_dc=0):
self.used_hosts_per_remote_dc = used_hosts_per_remote_dc
self._dc_live_hosts = {}
self._position = 0
self._endpoints = []
LoadBalancingPolicy.__init__(self)

def _dc(self, host):
@@ -255,11 +254,6 @@ def populate(self, cluster, hosts):
for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)):
self._dc_live_hosts[dc] = tuple({*dc_hosts, *self._dc_live_hosts.get(dc, [])})

if not self.local_dc:
self._endpoints = [
endpoint
for endpoint in cluster.endpoints_resolved]

self._position = randint(0, len(hosts) - 1) if hosts else 0

def distance(self, host):
@@ -301,13 +295,11 @@ def on_up(self, host):
# not worrying about threads because this will happen during
# control connection startup/refresh
if not self.local_dc and host.datacenter:
if host.endpoint in self._endpoints:
self.local_dc = host.datacenter
log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
"if incorrect, please specify a local_dc to the constructor, "
"or limit contact points to local cluster nodes" %
(self.local_dc, host.endpoint))
del self._endpoints
self.local_dc = host.datacenter
log.info("Using datacenter '%s' for DCAwareRoundRobinPolicy (via host '%s'); "
"if incorrect, please specify a local_dc to the constructor, "
"or limit contact points to local cluster nodes" %
(self.local_dc, host.endpoint))

dc = self._dc(host)
with self._hosts_lock:
2 changes: 1 addition & 1 deletion cassandra/pool.py
@@ -176,7 +176,7 @@ def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=No
self.endpoint = endpoint if isinstance(endpoint, EndPoint) else DefaultEndPoint(endpoint)
self.conviction_policy = conviction_policy_factory(self)
if not host_id:
host_id = uuid.uuid4()
raise ValueError("host_id may not be None")
self.host_id = host_id
Comment on lines 177 to 180

Commit: "Don't create Host instances with random host_id"

The change here is the one the commit message explains, and perhaps the chain((host.endpoint for host in lbp.make_query_plan()), self._cluster.endpoints_resolved) line is covered as well. The other changes are not explained and are not at all obvious to me.

When writing commits, please assume that the reader won't be as familiar with the relevant code as you are. That is almost always true - even if the reviewer is an active maintainer, there is a high chance they have not worked with this specific area recently.
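
For what it's worth, a rough sketch of the ordering that chain(...) line produces, with hypothetical names, just to spell out the fallback behaviour being discussed:

    from itertools import chain

    def candidate_endpoints(lbp, resolved_contact_points):
        # Endpoints from the load balancing policy's query plan come first;
        # the resolved contact points are appended so the control connection
        # can still be attempted when the policy does not yet know any live
        # hosts (e.g. on first connect or after losing every host).
        return chain(
            (host.endpoint for host in lbp.make_query_plan()),
            resolved_contact_points,
        )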

self.set_location_info(datacenter, rack)
self.lock = RLock()
36 changes: 24 additions & 12 deletions tests/integration/standard/test_cluster.py
@@ -900,8 +900,9 @@ def test_profile_lb_swap(self):
"""
Tests that profile load balancing policies are not shared

Creates two LBP, runs a few queries, and validates that each LBP is execised
seperately between EP's
Creates two LBP, runs a few queries, and validates that each LBP is exercised
separately between EP's. Each RoundRobinPolicy starts from its own random
position and maintains independent round-robin ordering.

@since 3.5
@jira_ticket PYTHON-569
@@ -916,17 +917,28 @@ def test_profile_lb_swap(self):
with TestCluster(execution_profiles=exec_profiles) as cluster:
session = cluster.connect(wait_for_all_pools=True)

# default is DCA RR for all hosts
expected_hosts = set(cluster.metadata.all_hosts())
rr1_queried_hosts = set()
rr2_queried_hosts = set()

rs = session.execute(query, execution_profile='rr1')
rr1_queried_hosts.add(rs.response_future._current_host)
rs = session.execute(query, execution_profile='rr2')
rr2_queried_hosts.add(rs.response_future._current_host)

assert rr2_queried_hosts == rr1_queried_hosts
num_hosts = len(expected_hosts)
assert num_hosts > 1, "Need at least 2 hosts for this test"

rr1_queried_hosts = []
rr2_queried_hosts = []

for _ in range(num_hosts * 2):
rs = session.execute(query, execution_profile='rr1')
rr1_queried_hosts.append(rs.response_future._current_host)
rs = session.execute(query, execution_profile='rr2')
rr2_queried_hosts.append(rs.response_future._current_host)

# Both policies should have queried all hosts
assert set(rr1_queried_hosts) == expected_hosts
assert set(rr2_queried_hosts) == expected_hosts

# The order of hosts should demonstrate round-robin behavior
# After num_hosts queries, the pattern should repeat
for i in range(num_hosts):
assert rr1_queried_hosts[i] == rr1_queried_hosts[i + num_hosts]
assert rr2_queried_hosts[i] == rr2_queried_hosts[i + num_hosts]

def test_ta_lbp(self):
"""
8 changes: 6 additions & 2 deletions tests/integration/standard/test_control_connection.py
@@ -101,8 +101,12 @@ def test_get_control_connection_host(self):

# reconnect and make sure that the new host is reflected correctly
self.cluster.control_connection._reconnect()
new_host = self.cluster.get_control_connection_host()
assert host != new_host
new_host1 = self.cluster.get_control_connection_host()

self.cluster.control_connection._reconnect()
new_host2 = self.cluster.get_control_connection_host()

assert new_host1 != new_host2

# TODO: enable after https://github.com/scylladb/python-driver/issues/121 is fixed
@unittest.skip('Fails on scylla due to the broadcast_rpc_port is None')
4 changes: 2 additions & 2 deletions tests/integration/standard/test_metrics.py
@@ -218,7 +218,7 @@ def test_metrics_per_cluster(self):
try:
# Test write
query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL)
with pytest.raises(WriteTimeout):
with pytest.raises((WriteTimeout, Unavailable)):
self.session.execute(query, timeout=None)
finally:
get_node(1).resume()
@@ -230,7 +230,7 @@ def test_metrics_per_cluster(self):
stats_cluster2 = cluster2.metrics.get_stats()

# Test direct access to stats
assert 1 == self.cluster.metrics.stats.write_timeouts
assert (1 == self.cluster.metrics.stats.write_timeouts or 1 == self.cluster.metrics.stats.unavailables)
assert 0 == cluster2.metrics.stats.write_timeouts

# Test direct access to a child stats