Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions av/container/input.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from typing import Any, Iterator, overload
from av.audio.frame import AudioFrame
from av.audio.stream import AudioStream
from av.packet import Packet
from av.stream import Stream
from av.stream import AttachmentStream, DataStream, Stream
from av.subtitles.stream import SubtitleStream
from av.subtitles.subtitle import SubtitleSet
from av.video.frame import VideoFrame
Expand All @@ -19,7 +19,48 @@ class InputContainer(Container):

def __enter__(self) -> InputContainer: ...
def close(self) -> None: ...
def demux(self, *args: Any, **kwargs: Any) -> Iterator[Packet]: ...
@overload
def demux(self, video_stream: VideoStream) -> Iterator[Packet[VideoStream]]: ...
@overload
def demux(
self, video_streams: tuple[VideoStream, ...]
) -> Iterator[Packet[VideoStream]]: ...
@overload
def demux(self, *, video: Any) -> Iterator[Packet[VideoStream]]: ...
@overload
def demux(self, audio_stream: AudioStream) -> Iterator[Packet[AudioStream]]: ...
@overload
def demux(
self, audio_streams: tuple[AudioStream, ...]
) -> Iterator[Packet[AudioStream]]: ...
@overload
def demux(self, *, audio: Any) -> Iterator[Packet[AudioStream]]: ...
@overload
def demux(
self, subtitle_stream: SubtitleStream
) -> Iterator[Packet[SubtitleStream]]: ...
@overload
def demux(
self, subtitle_streams: tuple[SubtitleStream, ...]
) -> Iterator[Packet[SubtitleStream]]: ...
@overload
def demux(self, data_stream: DataStream) -> Iterator[Packet[DataStream]]: ...
@overload
def demux(
self, data_streams: tuple[DataStream, ...]
) -> Iterator[Packet[DataStream]]: ...
@overload
def demux(self, *, data: Any) -> Iterator[Packet[DataStream]]: ...
@overload
def demux(
self, attachment_stream: AttachmentStream
) -> Iterator[Packet[AttachmentStream]]: ...
@overload
def demux(
self, attachment_streams: tuple[AttachmentStream, ...]
) -> Iterator[Packet[AttachmentStream]]: ...
@overload
def demux(self, *args: Any, **kwargs: Any) -> Iterator[Packet[Stream]]: ...
@overload
def decode(self, video: int) -> Iterator[VideoFrame]: ...
@overload
Expand Down
38 changes: 30 additions & 8 deletions av/packet.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from fractions import Fraction
from typing import Iterator, Literal
from typing import Generic, Iterator, Literal, TypeVar, overload

from av.subtitles.subtitle import SubtitleSet
from av.audio.frame import AudioFrame
from av.audio.stream import AudioStream
from av.stream import AttachmentStream, DataStream, Stream
from av.subtitles.stream import SubtitleStream
from av.subtitles.subtitle import AssSubtitle, BitmapSubtitle, SubtitleSet
from av.video.frame import VideoFrame
from av.video.stream import VideoStream

from .buffer import Buffer
from .stream import Stream
Expand Down Expand Up @@ -52,8 +58,8 @@ PktSideDataT = Literal[

class PacketSideData(Buffer):
@staticmethod
def from_packet(packet: Packet, dtype: PktSideDataT) -> PacketSideData: ...
def to_packet(self, packet: Packet, move: bool = False): ...
def from_packet(packet: Packet[Stream], dtype: PktSideDataT) -> PacketSideData: ...
def to_packet(self, packet: Packet[Stream], move: bool = False): ...
@property
def data_type(self) -> str: ...
@property
Expand All @@ -65,8 +71,11 @@ class PacketSideData(Buffer):
def packet_sidedata_type_to_literal(dtype: int) -> PktSideDataT: ...
def packet_sidedata_type_from_literal(dtype: PktSideDataT) -> int: ...

class Packet(Buffer):
stream: Stream
# TypeVar for stream types - bound to Stream so it can be any stream type
StreamT = TypeVar("StreamT", bound=Stream)

class Packet(Buffer, Generic[StreamT]):
stream: StreamT
stream_index: int
time_base: Fraction
pts: int | None
Expand All @@ -81,8 +90,21 @@ class Packet(Buffer):
is_trusted: bool
is_disposable: bool

def __init__(self, input: int | bytes | None = None) -> None: ...
def decode(self) -> list[SubtitleSet]: ...
def __init__(self: Packet[Stream], input: int | bytes | None = None) -> None: ...

# Overloads that return the same type as the stream's decode method
@overload
def decode(self: Packet[VideoStream]) -> list[VideoFrame]: ...
@overload
def decode(self: Packet[AudioStream]) -> list[AudioFrame]: ...
@overload
def decode(
self: Packet[SubtitleStream],
) -> list[AssSubtitle] | list[BitmapSubtitle]: ...
@overload
def decode(
self,
) -> list[VideoFrame | AudioFrame | AssSubtitle | BitmapSubtitle]: ...
def has_sidedata(self, dtype: PktSideDataT) -> bool: ...
def get_sidedata(self, dtype: PktSideDataT) -> PacketSideData: ...
def set_sidedata(self, sidedata: PacketSideData, move: bool = False) -> None: ...
Expand Down
5 changes: 5 additions & 0 deletions tests/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,13 @@ def test_decode_audio_corrupt(self) -> None:

packet_count = 0
frame_count = 0
audio_frame: av.AudioFrame | None = None

with av.open(path) as container:
for packet in container.demux(audio=0):
for frame in packet.decode():
frame_count += 1
audio_frame = frame
packet_count += 1

assert packet_count == 1
Expand Down Expand Up @@ -109,8 +111,11 @@ def test_decoded_time_base(self) -> None:

assert stream.time_base == Fraction(1, 25)

video_frame: av.VideoFrame | None = None

for packet in container.demux(stream):
for frame in packet.decode():
video_frame = frame
assert not isinstance(frame, SubtitleSet)
assert packet.time_base == frame.time_base
assert stream.time_base == frame.time_base
Expand Down
20 changes: 12 additions & 8 deletions tests/test_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,15 +179,15 @@ def test_data_stream_from_template(self) -> None:
input_data_stream
)

for packet in input_container.demux(input_data_stream):
payload = bytes(packet)
for data_packet in input_container.demux(input_data_stream):
payload = bytes(data_packet)
if not payload:
continue
copied_payloads.append(payload)
clone = av.Packet(payload)
clone.pts = packet.pts
clone.dts = packet.dts
clone.time_base = packet.time_base
clone.pts = data_packet.pts
clone.dts = data_packet.dts
clone.time_base = data_packet.time_base
clone.stream = output_data_stream
output_container.mux(clone)

Expand All @@ -196,8 +196,8 @@ def test_data_stream_from_template(self) -> None:
assert output_stream.codec_context is None

remuxed_payloads: list[bytes] = []
for packet in remuxed.demux(output_stream):
payload = bytes(packet)
for data_packet in remuxed.demux(output_stream):
payload = bytes(data_packet)
if payload:
remuxed_payloads.append(payload)

Expand Down Expand Up @@ -287,7 +287,11 @@ def test_attachment_stream(self) -> None:
for packet in ic.demux(ic.streams.video):
if packet.dts is None:
continue
packet.stream = stream_map[packet.stream.index]
updated_stream = stream_map.get(packet.stream.index)
if isinstance(updated_stream, av.video.stream.VideoStream):
packet.stream = updated_stream
else:
raise ValueError("Expected a VideoStream")
oc.mux(packet)

with av.open(out2_path) as c:
Expand Down
Loading