Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from collections import OrderedDict
import inspect
import sys
import typing
import typing_extensions
from dataclasses import MISSING, Field, dataclass
Expand Down Expand Up @@ -79,6 +80,15 @@ class Dataclassish(Protocol):
class TypeOfTypedDict(Protocol):
__total__: bool

if sys.version_info >= (3, 12) and typing.TypeAliasType is not typing_extensions.TypeAliasType:
# Sometimes typing_extensions aliases TypeAliasType,
# sometimes it's its own declaration.
def is_type_alias_type(py_type: object) -> TypeGuard[TypeAliasType]:
return isinstance(py_type, typing.TypeAliasType | typing_extensions.TypeAliasType)
else:
def is_type_alias_type(py_type: object) -> TypeGuard[TypeAliasType]:
return isinstance(py_type, typing_extensions.TypeAliasType)


def is_generic(py_type: object) -> TypeGuard[GenericAliasish]:
return hasattr(py_type, "__origin__") and hasattr(py_type, "__args__")
Expand All @@ -88,9 +98,8 @@ def is_dataclass(py_type: object) -> TypeGuard[Dataclassish]:

TypeReferenceTarget: TypeAlias = type | TypeAliasType | TypeVar | GenericAliasish


def is_python_type_or_alias(origin: object) -> TypeGuard[type | TypeAliasType]:
return isinstance(origin, TypeAliasType | type)
return isinstance(origin, type) or is_type_alias_type(origin)


_KNOWN_GENERIC_SPECIAL_FORMS: frozenset[Any] = frozenset(
Expand Down Expand Up @@ -393,7 +402,7 @@ def declare_type(py_type: object):
reserve_name(py_type)

return InterfaceDeclarationNode(py_type.__name__, None, "", None, [])
if isinstance(py_type, TypeAliasType):
if is_type_alias_type(py_type):
type_params = [TypeParameterDeclarationNode(type_param.__name__) for type_param in py_type.__type_params__]

reserve_name(py_type)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Entry point is: 'FirstOrSecond'

type FirstOrSecond<T> = First<T> | Second<T>

interface Second<T> {
kind: "second";
second_attr: T;
}

interface First<T> {
kind: "first";
first_attr: T;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Entry point is: 'StrOrInt'

type StrOrInt = string | number
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Entry point is: 'Nested'

interface Nested {
item: FirstOrSecond<string>;
}

type FirstOrSecond<T> = First<T> | Second<T>

interface Second<T> {
kind: "second";
second_attr: T;
}

interface First<T> {
kind: "first";
first_attr: T;
}
1 change: 0 additions & 1 deletion python/tests/test_generic_alias_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

T = TypeVar("T", covariant=True)


class First(Generic[T], TypedDict):
kind: Literal["first"]
first_attr: T
Expand Down
4 changes: 1 addition & 3 deletions python/tests/test_generic_alias_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,5 @@ class Nested(TypedDict):
item: FirstOrSecond[str]



def test_generic_alias1(snapshot: Any):
def test_generic_alias2(snapshot: Any):
assert(python_type_to_typescript_schema(Nested) == snapshot(extension_class=TypeScriptSchemaSnapshotExtension))

20 changes: 20 additions & 0 deletions python/tests/test_generic_alias_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Any
from .utilities import check_snapshot_for_module_string_if_3_12_plus

module_str = """
from typing import Literal, TypedDict
class First[T](TypedDict):
kind: Literal["first"]
first_attr: T


class Second[T](TypedDict):
kind: Literal["second"]
second_attr: T


type FirstOrSecond[T] = First[T] | Second[T]
"""

def test_generic_alias3(snapshot: Any):
check_snapshot_for_module_string_if_3_12_plus(snapshot, input_type_str="FirstOrSecond", module_str=module_str)
23 changes: 23 additions & 0 deletions python/tests/test_generic_alias_4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Any
from .utilities import check_snapshot_for_module_string_if_3_12_plus

module_str = """
from typing import Literal, TypedDict
class First[T](TypedDict):
kind: Literal["first"]
first_attr: T


class Second[T](TypedDict):
kind: Literal["second"]
second_attr: T


type FirstOrSecond[T] = First[T] | Second[T]

class Nested(TypedDict):
item: FirstOrSecond[str]
"""

def test_generic_alias4(snapshot: Any):
check_snapshot_for_module_string_if_3_12_plus(snapshot, input_type_str="Nested", module_str=module_str)
7 changes: 7 additions & 0 deletions python/tests/test_type_alias_syntax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import Any
from .utilities import check_snapshot_for_module_string_if_3_12_plus

module_str = "type StrOrInt = str | int"

def test_type_alias_union1(snapshot: Any):
check_snapshot_for_module_string_if_3_12_plus(snapshot, "StrOrInt", module_str)
16 changes: 15 additions & 1 deletion python/tests/utilities.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from pathlib import Path
import sys
import types

from typing_extensions import Any, override
import pytest

from syrupy.extensions.single_file import SingleFileSnapshotExtension, WriteMode
from syrupy.location import PyTestLocation

from typechat._internal.ts_conversion import TypeScriptSchemaConversionResult
from typechat._internal.ts_conversion import TypeScriptSchemaConversionResult, python_type_to_typescript_schema

class TypeScriptSchemaSnapshotExtension(SingleFileSnapshotExtension):
_write_mode = WriteMode.TEXT
Expand Down Expand Up @@ -41,6 +42,19 @@ def dirname(cls, *, test_location: PyTestLocation) -> str:
)
return str(result)

class PyVersioned3_12_PlusSnapshotExtension(PyVersionedTypeScriptSchemaSnapshotExtension):
py_ver_dir: str = f"__py3.12+_snapshots__"

def check_snapshot_for_module_string_if_3_12_plus(snapshot: Any, input_type_str: str, module_str: str):
if sys.version_info < (3, 12):
pytest.skip("requires python 3.12 or higher")

module = types.ModuleType("test_module")
exec(module_str, module.__dict__)
type_obj = eval(input_type_str, globals(), module.__dict__)

assert(python_type_to_typescript_schema(type_obj) == snapshot(extension_class=PyVersioned3_12_PlusSnapshotExtension))

@pytest.fixture
def snapshot_schema(snapshot: Any):
return snapshot.with_defaults(extension_class=TypeScriptSchemaSnapshotExtension)