Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions taskiq_nats/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@
from nats.errors import TimeoutError as NatsTimeoutError
from nats.js import JetStreamContext
from nats.js.api import ConsumerConfig, StreamConfig
from nats.js.errors import NotFoundError
from taskiq import AckableMessage, AsyncBroker, AsyncResultBackend, BrokerMessage

_T = typing.TypeVar("_T") # (Too short)


JetStreamConsumerType = typing.TypeVar(
"JetStreamConsumerType",
)


logger = getLogger("taskiq_nats")


Expand Down Expand Up @@ -138,6 +137,23 @@ def __init__(

self.consumer: JetStreamConsumerType

async def _ensure_stream_exists(self) -> None:
"""Ensure stream exists, create if it doesn't."""
if self.stream_config.name is None:
self.stream_config.name = self.stream_name
if not self.stream_config.subjects:
self.stream_config.subjects = [self.subject]

try:
# Check if stream already exists
await self.js.stream_info(self.stream_config.name)
logger.debug("Stream %s already exists", self.stream_config.name)
except NotFoundError:
logger.debug("stream %s does not exist", self.stream_config.name)
# Stream doesn't exist, create it
await self.js.add_stream(config=self.stream_config)
logger.info("Created stream %s", self.stream_config.name)

async def startup(self) -> None:
"""
Startup event handler.
Expand All @@ -148,11 +164,9 @@ async def startup(self) -> None:
await super().startup()
await self.client.connect(self.servers, **self.connection_kwargs)
self.js = self.client.jetstream()
if self.stream_config.name is None:
self.stream_config.name = self.stream_name
if not self.stream_config.subjects:
self.stream_config.subjects = [self.subject]
await self.js.add_stream(config=self.stream_config)

# Ensure stream exists (won't recreate if it exists)
await self._ensure_stream_exists()
await self._startup_consumer()

async def shutdown(self) -> None:
Expand Down