diff --git a/av/container/input.pyi b/av/container/input.pyi
index 90154c331..b9d93c6b4 100644
--- a/av/container/input.pyi
+++ b/av/container/input.pyi
@@ -3,7 +3,7 @@ from typing import Any, Iterator, overload
 from av.audio.frame import AudioFrame
 from av.audio.stream import AudioStream
 from av.packet import Packet
-from av.stream import Stream
+from av.stream import AttachmentStream, DataStream, Stream
 from av.subtitles.stream import SubtitleStream
 from av.subtitles.subtitle import SubtitleSet
 from av.video.frame import VideoFrame
@@ -19,7 +19,48 @@ class InputContainer(Container):
 
     def __enter__(self) -> InputContainer: ...
     def close(self) -> None: ...
-    def demux(self, *args: Any, **kwargs: Any) -> Iterator[Packet]: ...
+    @overload
+    def demux(self, video_stream: VideoStream) -> Iterator[Packet[VideoStream]]: ...
+    @overload
+    def demux(
+        self, video_streams: tuple[VideoStream, ...]
+    ) -> Iterator[Packet[VideoStream]]: ...
+    @overload
+    def demux(self, *, video: Any) -> Iterator[Packet[VideoStream]]: ...
+    @overload
+    def demux(self, audio_stream: AudioStream) -> Iterator[Packet[AudioStream]]: ...
+    @overload
+    def demux(
+        self, audio_streams: tuple[AudioStream, ...]
+    ) -> Iterator[Packet[AudioStream]]: ...
+    @overload
+    def demux(self, *, audio: Any) -> Iterator[Packet[AudioStream]]: ...
+    @overload
+    def demux(
+        self, subtitle_stream: SubtitleStream
+    ) -> Iterator[Packet[SubtitleStream]]: ...
+    @overload
+    def demux(
+        self, subtitle_streams: tuple[SubtitleStream, ...]
+    ) -> Iterator[Packet[SubtitleStream]]: ...
+    @overload
+    def demux(self, data_stream: DataStream) -> Iterator[Packet[DataStream]]: ...
+    @overload
+    def demux(
+        self, data_streams: tuple[DataStream, ...]
+    ) -> Iterator[Packet[DataStream]]: ...
+    @overload
+    def demux(self, *, data: Any) -> Iterator[Packet[DataStream]]: ...
+    @overload
+    def demux(
+        self, attachment_stream: AttachmentStream
+    ) -> Iterator[Packet[AttachmentStream]]: ...
+    @overload
+    def demux(
+        self, attachment_streams: tuple[AttachmentStream, ...]
+    ) -> Iterator[Packet[AttachmentStream]]: ...
+    @overload
+    def demux(self, *args: Any, **kwargs: Any) -> Iterator[Packet[Stream]]: ...
     @overload
     def decode(self, video: int) -> Iterator[VideoFrame]: ...
     @overload
diff --git a/av/packet.pyi b/av/packet.pyi
index 6b1a271c2..d129e5e7f 100644
--- a/av/packet.pyi
+++ b/av/packet.pyi
@@ -1,7 +1,13 @@
 from fractions import Fraction
-from typing import Iterator, Literal
+from typing import Generic, Iterator, Literal, TypeVar, overload
 
-from av.subtitles.subtitle import SubtitleSet
+from av.audio.frame import AudioFrame
+from av.audio.stream import AudioStream
+from av.stream import AttachmentStream, DataStream, Stream
+from av.subtitles.stream import SubtitleStream
+from av.subtitles.subtitle import AssSubtitle, BitmapSubtitle, SubtitleSet
+from av.video.frame import VideoFrame
+from av.video.stream import VideoStream
 
 from .buffer import Buffer
 from .stream import Stream
@@ -52,8 +58,8 @@ PktSideDataT = Literal[
 
 class PacketSideData(Buffer):
     @staticmethod
-    def from_packet(packet: Packet, dtype: PktSideDataT) -> PacketSideData: ...
-    def to_packet(self, packet: Packet, move: bool = False): ...
+    def from_packet(packet: Packet[Stream], dtype: PktSideDataT) -> PacketSideData: ...
+    def to_packet(self, packet: Packet[Stream], move: bool = False): ...
     @property
     def data_type(self) -> str: ...
     @property
@@ -65,8 +71,11 @@ class PacketSideData(Buffer):
 def packet_sidedata_type_to_literal(dtype: int) -> PktSideDataT: ...
 def packet_sidedata_type_from_literal(dtype: PktSideDataT) -> int: ...
 
-class Packet(Buffer):
-    stream: Stream
+# TypeVar for stream types - bound to Stream so it can be any stream type
+StreamT = TypeVar("StreamT", bound=Stream)
+
+class Packet(Buffer, Generic[StreamT]):
+    stream: StreamT
     stream_index: int
     time_base: Fraction
     pts: int | None
@@ -81,8 +90,21 @@ class Packet(Buffer):
     is_trusted: bool
     is_disposable: bool
 
-    def __init__(self, input: int | bytes | None = None) -> None: ...
-    def decode(self) -> list[SubtitleSet]: ...
+    def __init__(self: Packet[Stream], input: int | bytes | None = None) -> None: ...
+
+    # Overloads that return the same type as the stream's decode method
+    @overload
+    def decode(self: Packet[VideoStream]) -> list[VideoFrame]: ...
+    @overload
+    def decode(self: Packet[AudioStream]) -> list[AudioFrame]: ...
+    @overload
+    def decode(
+        self: Packet[SubtitleStream],
+    ) -> list[AssSubtitle] | list[BitmapSubtitle]: ...
+    @overload
+    def decode(
+        self,
+    ) -> list[VideoFrame | AudioFrame | AssSubtitle | BitmapSubtitle]: ...
     def has_sidedata(self, dtype: PktSideDataT) -> bool: ...
     def get_sidedata(self, dtype: PktSideDataT) -> PacketSideData: ...
     def set_sidedata(self, sidedata: PacketSideData, move: bool = False) -> None: ...
diff --git a/tests/test_decode.py b/tests/test_decode.py
index d7fffbd4c..3277dd776 100644
--- a/tests/test_decode.py
+++ b/tests/test_decode.py
@@ -72,11 +72,13 @@ def test_decode_audio_corrupt(self) -> None:
 
         packet_count = 0
         frame_count = 0
+        audio_frame: av.AudioFrame | None = None
 
         with av.open(path) as container:
             for packet in container.demux(audio=0):
                 for frame in packet.decode():
                     frame_count += 1
+                    audio_frame = frame
                 packet_count += 1
 
         assert packet_count == 1
@@ -109,8 +111,11 @@ def test_decoded_time_base(self) -> None:
 
         assert stream.time_base == Fraction(1, 25)
 
+        video_frame: av.VideoFrame | None = None
+
         for packet in container.demux(stream):
             for frame in packet.decode():
+                video_frame = frame
                 assert not isinstance(frame, SubtitleSet)
                 assert packet.time_base == frame.time_base
                 assert stream.time_base == frame.time_base
diff --git a/tests/test_streams.py b/tests/test_streams.py
index 9387d68cc..4f7defef0 100644
--- a/tests/test_streams.py
+++ b/tests/test_streams.py
@@ -179,15 +179,15 @@ def test_data_stream_from_template(self) -> None:
                     input_data_stream
                 )
 
-                for packet in input_container.demux(input_data_stream):
-                    payload = bytes(packet)
+                for data_packet in input_container.demux(input_data_stream):
+                    payload = bytes(data_packet)
                     if not payload:
                         continue
                     copied_payloads.append(payload)
                     clone = av.Packet(payload)
-                    clone.pts = packet.pts
-                    clone.dts = packet.dts
-                    clone.time_base = packet.time_base
+                    clone.pts = data_packet.pts
+                    clone.dts = data_packet.dts
+                    clone.time_base = data_packet.time_base
                     clone.stream = output_data_stream
                     output_container.mux(clone)
 
@@ -196,8 +196,8 @@ def test_data_stream_from_template(self) -> None:
             assert output_stream.codec_context is None
 
             remuxed_payloads: list[bytes] = []
-            for packet in remuxed.demux(output_stream):
-                payload = bytes(packet)
+            for data_packet in remuxed.demux(output_stream):
+                payload = bytes(data_packet)
                 if payload:
                     remuxed_payloads.append(payload)
 
@@ -287,7 +287,11 @@ def test_attachment_stream(self) -> None:
             for packet in ic.demux(ic.streams.video):
                 if packet.dts is None:
                     continue
-                packet.stream = stream_map[packet.stream.index]
+                updated_stream = stream_map.get(packet.stream.index)
+                if isinstance(updated_stream, av.video.stream.VideoStream):
+                    packet.stream = updated_stream
+                else:
+                    raise ValueError("Expected a VideoStream")
                 oc.mux(packet)
 
         with av.open(out2_path) as c: