Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/19310.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add cancel_task API to the task scheduler.
72 changes: 70 additions & 2 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,16 @@
SynapseError,
)
from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
from synapse.types import JsonDict, JsonValue, Requester, UserID, create_requester
from synapse.types import (
JsonDict,
JsonMapping,
JsonValue,
Requester,
ScheduledTask,
TaskStatus,
UserID,
create_requester,
)
from synapse.util.caches.descriptors import cached
from synapse.util.duration import Duration
from synapse.util.stringutils import parse_and_validate_mxc_uri
Expand All @@ -46,6 +55,8 @@
MAX_AVATAR_URL_LEN = 1000
# Field name length is specced at 255 bytes.
MAX_CUSTOM_FIELD_LEN = 255
UPDATE_JOIN_STATES_ACTION_NAME = "update_join_states"
UPDATE_JOIN_STATES_LOCK_NAME = "update_join_states_lock"


class ProfileHandler:
Expand Down Expand Up @@ -78,6 +89,12 @@ def __init__(self, hs: "HomeServer"):

self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules

self._task_scheduler = hs.get_task_scheduler()
self._task_scheduler.register_action(
self._update_join_states_task, UPDATE_JOIN_STATES_ACTION_NAME
)
self._worker_locks = hs.get_worker_locks_handler()

async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict:
"""
Get a user's profile as a JSON dictionary.
Expand Down Expand Up @@ -587,7 +604,53 @@ async def _update_join_states(
await self.clock.sleep(Duration(seconds=random.randint(1, 10)))
return

room_ids = await self.store.get_rooms_for_user(target_user.to_string())
target_user_str = target_user.to_string()

async with self._worker_locks.acquire_read_write_lock(
UPDATE_JOIN_STATES_LOCK_NAME,
target_user_str,
write=True,
):
tasks_to_cancel = await self._task_scheduler.get_tasks(
actions=[UPDATE_JOIN_STATES_ACTION_NAME],
resource_id=target_user_str,
statuses=[TaskStatus.ACTIVE, TaskStatus.SCHEDULED],
)

for task in tasks_to_cancel:
await self._task_scheduler.cancel_task(task.id)

await self._task_scheduler.schedule_task(
UPDATE_JOIN_STATES_ACTION_NAME,
resource_id=target_user_str,
params={
"requester_authenticated_entity": requester.authenticated_entity,
},
)

async def _update_join_states_task(
self,
task: ScheduledTask,
) -> tuple[TaskStatus, JsonMapping | None, str | None]:
assert task.resource_id
assert task.params

target_user = UserID.from_string(task.resource_id)
room_ids = sorted(await self.store.get_rooms_for_user(target_user.to_string()))

last_room_id = task.result.get("last_room_id", None) if task.result else None

if last_room_id:
unhandled_room_ids = []
for room_id in room_ids:
if room_id > last_room_id:
unhandled_room_ids.append(room_id)
room_ids = unhandled_room_ids

requester = create_requester(
user_id=target_user,
authenticated_entity=task.params.get("requester_authenticated_entity"),
)

for room_id in room_ids:
handler = self.hs.get_room_member_handler()
Expand All @@ -605,6 +668,11 @@ async def _update_join_states(
logger.warning(
"Failed to update join event for room %s - %s", room_id, str(e)
)
await self._task_scheduler.update_task(
task.id, result={"last_room_id": last_room_id}
)

return TaskStatus.COMPLETE, None, None

async def check_profile_query_allowed(
self, target_user: UserID, requester: UserID | None = None
Expand Down
12 changes: 12 additions & 0 deletions synapse/replication/tcp/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,17 @@ class NewActiveTaskCommand(_SimpleCommand):
NAME = "NEW_ACTIVE_TASK"


class CancelTaskCommand(_SimpleCommand):
    """Notifies the instance that runs background tasks that a scheduled task
    has been cancelled, so any in-flight execution of it should be terminated.

    Format::

        CANCEL_TASK "<task_id>"
    """

    NAME = "CANCEL_TASK"


_COMMANDS: tuple[type[Command], ...] = (
ServerCommand,
RdataCommand,
Expand All @@ -520,6 +531,7 @@ class NewActiveTaskCommand(_SimpleCommand):
ClearUserSyncsCommand,
LockReleasedCommand,
NewActiveTaskCommand,
CancelTaskCommand,
)

# Map of command name to command type.
Expand Down
29 changes: 28 additions & 1 deletion synapse/replication/tcp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@

from synapse.metrics import SERVER_NAME_LABEL, LaterGauge
from synapse.replication.tcp.commands import (
CancelTaskCommand,
ClearUserSyncsCommand,
Command,
FederationAckCommand,
Expand Down Expand Up @@ -269,6 +270,14 @@ def __init__(self, hs: "HomeServer"):
for stream_name in self._streams
}

self._global_command_queue = BackgroundQueue[
tuple[CancelTaskCommand, IReplicationConnection]
](
hs,
"process-replication-data",
self._process_global_command,
)

# For each connection, the incoming stream names that have received a POSITION
# from that connection.
self._streams_by_connection: dict[IReplicationConnection, set[str]] = {}
Expand Down Expand Up @@ -376,6 +385,14 @@ async def _process_command(
# This shouldn't be possible
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)

async def _process_global_command(
self, item: tuple[CancelTaskCommand, IReplicationConnection]
) -> None:
cmd, conn = item
if isinstance(cmd, CancelTaskCommand):
if self._task_scheduler:
await self._task_scheduler.on_cancel_task(cmd.data)

def start_replication(self, hs: "HomeServer") -> None:
"""Helper method to start replication."""
from synapse.replication.tcp.redis import RedisDirectTcpReplicationClientFactory
Expand Down Expand Up @@ -746,10 +763,16 @@ def on_LOCK_RELEASED(
def on_NEW_ACTIVE_TASK(
self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
) -> None:
"""Called when get a new NEW_ACTIVE_TASK command."""
"""Called when we get a new NEW_ACTIVE_TASK command."""
if self._task_scheduler:
self._task_scheduler.on_new_task(cmd.data)

async def on_CANCEL_TASK(
self, conn: IReplicationConnection, cmd: CancelTaskCommand
) -> None:
"""Called when we get a new CANCEL_TASK command."""
self._global_command_queue.add((cmd, conn))

def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
self._connections.append(connection)
Expand Down Expand Up @@ -872,6 +895,10 @@ def send_new_active_task(self, task_id: str) -> None:
"""Called when a new task has been scheduled for immediate launch and is ACTIVE."""
self.send_command(NewActiveTaskCommand(task_id))

def send_cancel_task(self, task_id: str) -> None:
"""Called when a scheduled task has been cancelled annd should be terminated."""
self.send_command(CancelTaskCommand(task_id))


UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")
Expand Down
1 change: 1 addition & 0 deletions synapse/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,7 @@ class TaskStatus(str, Enum):
COMPLETE = "complete"
# Task is over and either returned a failed status, or had an exception
FAILED = "failed"
CANCELLED = "cancelled"


@attr.s(auto_attribs=True, frozen=True, slots=True)
Expand Down
38 changes: 34 additions & 4 deletions synapse/util/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Callable

from twisted.internet import defer
from twisted.python.failure import Failure

from synapse.logging.context import (
Expand Down Expand Up @@ -111,7 +112,7 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self._store = hs.get_datastores().main
self._clock = hs.get_clock()
self._running_tasks: set[str] = set()
self._running_tasks: dict[str, defer.Deferred] = {}
# A map between action names and their registered function
self._actions: dict[
str,
Expand Down Expand Up @@ -325,6 +326,35 @@ async def delete_task(self, id: str) -> None:
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted")
await self._store.delete_scheduled_task(id)

async def cancel_task(self, id: str) -> None:
"""Cancel an ACTIVE or SCHEDULED task.

Args:
id: id of the task to cancel
"""
task = await self.get_task(id)
if not task:
return

if task.status == TaskStatus.SCHEDULED:
await self.update_task(id, status=TaskStatus.CANCELLED)
return

if not task.status == TaskStatus.ACTIVE:
return

if self._run_background_tasks:
await self.on_cancel_task(id)
else:
self.hs.get_replication_command_handler().send_cancel_task(id)

async def on_cancel_task(self, id: str) -> None:
if id in self._running_tasks:
defer = self._running_tasks[id]
defer.cancel()
self._running_tasks.pop(id)
await self.update_task(id, status=TaskStatus.CANCELLED)

def on_new_task(self, task_id: str) -> None:
"""Handle a notification that a new ready-to-run task has been added to the queue"""
# Just run the scheduler
Expand Down Expand Up @@ -458,7 +488,7 @@ async def wrapper() -> None:
result=result,
error=error,
)
self._running_tasks.remove(task.id)
self._running_tasks.pop(task.id)

current_time = self._clock.time()
usage = log_context.get_resource_usage()
Expand Down Expand Up @@ -489,6 +519,6 @@ async def wrapper() -> None:
if task.id in self._running_tasks:
return

self._running_tasks.add(task.id)
await self.update_task(task.id, status=TaskStatus.ACTIVE)
self.hs.run_as_background_process(f"task-{task.action}", wrapper)
defer = self.hs.run_as_background_process(f"task-{task.action}", wrapper)
self._running_tasks[task.id] = defer
72 changes: 70 additions & 2 deletions tests/util/test_task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.task_scheduler.register_action(self._sleeping_task, "_sleeping_task")
self.task_scheduler.register_action(self._raising_task, "_raising_task")
self.task_scheduler.register_action(self._resumable_task, "_resumable_task")
self.task_scheduler.register_action(
self._incrementing_task, "_incrementing_task"
)

async def _test_task(
self, task: ScheduledTask
Expand Down Expand Up @@ -187,15 +190,80 @@ def test_schedule_resumable_task(self) -> None:
self.assertEqual(task.status, TaskStatus.ACTIVE)

# Simulate a synapse restart by emptying the list of running tasks
self.task_scheduler._running_tasks = set()
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL.as_secs()))
self.task_scheduler._running_tasks = {}
self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL.as_secs())

task = self.get_success(self.task_scheduler.get_task(task_id))
assert task is not None
self.assertEqual(task.status, TaskStatus.COMPLETE)
assert task.result is not None
self.assertTrue(task.result.get("success"))

async def _incrementing_task(
self, task: ScheduledTask
) -> tuple[TaskStatus, JsonMapping | None, str | None]:
current_counter = 0
if task.result and "counter" in task.result:
current_counter = int(task.result["counter"])

return TaskStatus.ACTIVE, {"counter": current_counter + 1}, None

def test_cancel_task(self) -> None:
"""Schedule and then cancel a long running task that increments a counter."""

task_id = self.get_success(
self.task_scheduler.schedule_task(
"_incrementing_task",
)
)

task = self.get_success(self.task_scheduler.get_task(task_id))
assert task is not None
assert task.status == TaskStatus.ACTIVE

assert task.result and "counter" in task.result
current_counter = int(task.result["counter"])

self.reactor.advance(1)

task = self.get_success(self.task_scheduler.get_task(task_id))
assert task is not None
assert task.status == TaskStatus.ACTIVE

# At this point the task should have run at least one more time, let's check the counter
assert task.result and "counter" in task.result
new_counter = int(task.result["counter"])
assert new_counter > current_counter
current_counter = new_counter

# Cancelling active task
self.get_success(self.task_scheduler.cancel_task(task_id))

self.reactor.advance(1)

# Task should be marked as cancelled
task = self.get_success(self.task_scheduler.get_task(task_id))
assert task is not None
self.assertEqual(task.status, TaskStatus.CANCELLED)

# Task should be in the running tasks
assert task_id not in self.task_scheduler._running_tasks

# Counter should not increase anymore and stay the same
assert task.result and "counter" in task.result
new_counter = int(task.result["counter"])
assert new_counter == current_counter
current_counter = new_counter

# Let's check one more time to be sure that it is not increasing
self.reactor.advance(100)

task = self.get_success(self.task_scheduler.get_task(task_id))
assert task is not None
assert task.result and "counter" in task.result
new_counter = int(task.result["counter"])
assert new_counter == current_counter


class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
Expand Down
Loading