From c945078efa03492c292ddb06c4555847fd4995c9 Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 00:20:31 +0000 Subject: [PATCH 01/10] Add tests for preamble. --- python/tests/test_translator.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/python/tests/test_translator.py b/python/tests/test_translator.py index 1245ef34..28488000 100644 --- a/python/tests/test_translator.py +++ b/python/tests/test_translator.py @@ -50,4 +50,31 @@ def test_translator_with_single_failure(snapshot: Any): t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) asyncio.run(t.translate("Get me stuff.")) - assert m.conversation == snapshot \ No newline at end of file + assert m.conversation == snapshot + +def test_translator_with_single_failure_and_str_preamble(snapshot: Any): + m = FixedModel([ + '{ "a": "hello", "b": true }', + '{ "a": "hello", "b": true, "c": 1234 }', + ]) + t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) + asyncio.run(t.translate( + "Get me stuff.", + prompt_preamble="Just so you know, I need some stuff.", + )) + + assert m.conversation == snapshot + +def test_translator_with_single_failure_and_list_preamble_1(snapshot: Any): + m = FixedModel([ + '{ "a": "hello", "b": true }', + '{ "a": "hello", "b": true, "c": 1234 }', + ]) + t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) + asyncio.run(t.translate("Get me stuff.", prompt_preamble=[ + {"role": "user", "content": "Hey, I need some stuff."}, + {"role": "assistant", "content": "Okay, what kind of stuff?"}, + ])) + + assert m.conversation == snapshot + From b0a3e1f649a66f29ccb2cd5ea93338995d96b565 Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 00:31:45 +0000 Subject: [PATCH 02/10] Update snapshots. --- .../tests/__snapshots__/test_translator.ambr | 212 ++++++++++++++++++ 1 file changed, 212 insertions(+) diff --git a/python/tests/__snapshots__/test_translator.ambr b/python/tests/__snapshots__/test_translator.ambr index 594786fb..604bd8d2 100644 --- a/python/tests/__snapshots__/test_translator.ambr +++ b/python/tests/__snapshots__/test_translator.ambr @@ -135,3 +135,215 @@ }), ]) # --- +# name: test_translator_with_single_failure_and_list_preamble_1 + list([ + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': 'Get me stuff.', + 'role': 'user', + }), + dict({ + 'content': 'Hey, I need some stuff.', + 'role': 'user', + }), + dict({ + 'content': 'Okay, what kind of stuff?', + 'role': 'assistant', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': ''' + + The above JSON object is invalid for the following reason: + ''' + Validation path `c` failed for value `{"a": "hello", "b": true}` because: + Field required + ''' + The following is a revised JSON object: + + ''', + 'role': 'user', + }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true }', + }), + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': 'Get me stuff.', + 'role': 'user', + }), + dict({ + 'content': 'Hey, I need some stuff.', + 'role': 'user', + }), + dict({ + 'content': 'Okay, what kind of stuff?', + 'role': 'assistant', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': ''' + + The above JSON object is invalid for the following reason: + ''' + Validation path `c` failed for value `{"a": "hello", "b": true}` because: + Field required + ''' + The following is a revised JSON object: + + ''', + 'role': 'user', + }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true, "c": 1234 }', + }), + ]) +# --- +# name: test_translator_with_single_failure_and_str_preamble + list([ + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': 'Get me stuff.', + 'role': 'user', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': ''' + + The above JSON object is invalid for the following reason: + ''' + Validation path `c` failed for value `{"a": "hello", "b": true}` because: + Field required + ''' + The following is a revised JSON object: + + ''', + 'role': 'user', + }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true }', + }), + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': 'Get me stuff.', + 'role': 'user', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': ''' + + The above JSON object is invalid for the following reason: + ''' + Validation path `c` failed for value `{"a": "hello", "b": true}` because: + Field required + ''' + The following is a revised JSON object: + + ''', + 'role': 'user', + }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true, "c": 1234 }', + }), + ]) +# --- From 9a944311296ea154a7c3aff54a8d318606df45a7 Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 00:20:38 +0000 Subject: [PATCH 03/10] Fix translator - correct ordering and include assistant response on recovery. --- python/src/typechat/_internal/translator.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/src/typechat/_internal/translator.py b/python/src/typechat/_internal/translator.py index 6c649c1e..09dbf50b 100644 --- a/python/src/typechat/_internal/translator.py +++ b/python/src/typechat/_internal/translator.py @@ -64,12 +64,10 @@ async def translate(self, input: str, *, prompt_preamble: str | list[PromptSecti messages: list[PromptSection] = [] - messages.append({"role": "user", "content": input}) if prompt_preamble: if isinstance(prompt_preamble, str): prompt_preamble = [{"role": "user", "content": prompt_preamble}] - else: - messages.extend(prompt_preamble) + messages.extend(prompt_preamble) messages.append({"role": "user", "content": self._create_request_prompt(input)}) @@ -95,6 +93,7 @@ async def translate(self, input: str, *, prompt_preamble: str | list[PromptSecti if num_repairs_attempted >= self._max_repair_attempts: return Failure(error_message) num_repairs_attempted += 1 + messages.append({"role": "assistant", "content": text_response}) messages.append({"role": "user", "content": self._create_repair_prompt(error_message)}) def _create_request_prompt(self, intent: str) -> str: From e16eb2783f50af9a72450f70f956afd0b2206e82 Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 00:32:17 +0000 Subject: [PATCH 04/10] Update snapshots. --- .../tests/__snapshots__/test_translator.ambr | 48 ++++++++++--------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/python/tests/__snapshots__/test_translator.ambr b/python/tests/__snapshots__/test_translator.ambr index 604bd8d2..3773cadd 100644 --- a/python/tests/__snapshots__/test_translator.ambr +++ b/python/tests/__snapshots__/test_translator.ambr @@ -4,10 +4,6 @@ dict({ 'kind': 'CLIENT REQUEST', 'payload': list([ - dict({ - 'content': 'Get me stuff.', - 'role': 'user', - }), dict({ 'content': ''' @@ -42,10 +38,6 @@ dict({ 'kind': 'CLIENT REQUEST', 'payload': list([ - dict({ - 'content': 'Get me stuff.', - 'role': 'user', - }), dict({ 'content': ''' @@ -67,6 +59,10 @@ ''', 'role': 'user', }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), dict({ 'content': ''' @@ -89,10 +85,6 @@ dict({ 'kind': 'CLIENT REQUEST', 'payload': list([ - dict({ - 'content': 'Get me stuff.', - 'role': 'user', - }), dict({ 'content': ''' @@ -114,6 +106,10 @@ ''', 'role': 'user', }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), dict({ 'content': ''' @@ -140,10 +136,6 @@ dict({ 'kind': 'CLIENT REQUEST', 'payload': list([ - dict({ - 'content': 'Get me stuff.', - 'role': 'user', - }), dict({ 'content': 'Hey, I need some stuff.', 'role': 'user', @@ -173,6 +165,10 @@ ''', 'role': 'user', }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), dict({ 'content': ''' @@ -195,10 +191,6 @@ dict({ 'kind': 'CLIENT REQUEST', 'payload': list([ - dict({ - 'content': 'Get me stuff.', - 'role': 'user', - }), dict({ 'content': 'Hey, I need some stuff.', 'role': 'user', @@ -228,6 +220,10 @@ ''', 'role': 'user', }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), dict({ 'content': ''' @@ -255,7 +251,7 @@ 'kind': 'CLIENT REQUEST', 'payload': list([ dict({ - 'content': 'Get me stuff.', + 'content': 'Just so you know, I need some stuff.', 'role': 'user', }), dict({ @@ -279,6 +275,10 @@ ''', 'role': 'user', }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), dict({ 'content': ''' @@ -302,7 +302,7 @@ 'kind': 'CLIENT REQUEST', 'payload': list([ dict({ - 'content': 'Get me stuff.', + 'content': 'Just so you know, I need some stuff.', 'role': 'user', }), dict({ @@ -326,6 +326,10 @@ ''', 'role': 'user', }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), dict({ 'content': ''' From 564b4cbe0f4cf8055e7e78c0ad4baac4948b060b Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 00:56:34 +0000 Subject: [PATCH 05/10] Catch JSON parse errors. --- python/src/typechat/_internal/translator.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/python/src/typechat/_internal/translator.py b/python/src/typechat/_internal/translator.py index 09dbf50b..fb45ad35 100644 --- a/python/src/typechat/_internal/translator.py +++ b/python/src/typechat/_internal/translator.py @@ -83,11 +83,20 @@ async def translate(self, input: str, *, prompt_preamble: str | list[PromptSecti error_message: str if 0 <= first_curly < last_curly: trimmed_response = text_response[first_curly:last_curly] - parsed_response = pydantic_core.from_json(trimmed_response, allow_inf_nan=False, cache_strings=False) - result = self.validator.validate_object(parsed_response) - if isinstance(result, Success): - return result - error_message = result.message + try: + parsed_response = pydantic_core.from_json(trimmed_response, allow_inf_nan=False, cache_strings=False) + except ValueError as e: + error_message = f""" +Error: {e} + +Attempted to parse: +{trimmed_response} +""" + else: + result = self.validator.validate_object(parsed_response) + if isinstance(result, Success): + return result + error_message = result.message else: error_message = "Response did not contain any text resembling JSON." if num_repairs_attempted >= self._max_repair_attempts: From 0e6329f2f7c9323d2b649e9a9c19db22b22b096a Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 00:58:24 +0000 Subject: [PATCH 06/10] Add test for invalid JSON responses. --- python/tests/test_translator.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/tests/test_translator.py b/python/tests/test_translator.py index 28488000..adb2dd8f 100644 --- a/python/tests/test_translator.py +++ b/python/tests/test_translator.py @@ -52,6 +52,16 @@ def test_translator_with_single_failure(snapshot: Any): assert m.conversation == snapshot +def test_translator_with_invalid_json(snapshot: Any): + m = FixedModel([ + '{ "a": "hello" "b": true }', + '{ "a": "hello" "b": true, "c": 1234 }', + ]) + t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) + asyncio.run(t.translate("Get me stuff.")) + + assert m.conversation == snapshot + def test_translator_with_single_failure_and_str_preamble(snapshot: Any): m = FixedModel([ '{ "a": "hello", "b": true }', From e943312de262ae202b84c79fdc60ec7d9a7e32cf Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 01:16:21 +0000 Subject: [PATCH 07/10] Copy lists from the translator so that output is correctly snapshotted/not duplicated. --- python/tests/test_translator.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/tests/test_translator.py b/python/tests/test_translator.py index adb2dd8f..a86502b3 100644 --- a/python/tests/test_translator.py +++ b/python/tests/test_translator.py @@ -20,6 +20,11 @@ def __init__(self, responses: list[str]) -> None: @override async def complete(self, prompt: str | list[typechat.PromptSection]) -> typechat.Result[str]: + # Capture a snapshot because the translator + # can choose to pass in the same underlying list. + if isinstance(prompt, list): + prompt = prompt.copy() + self.conversation.append({ "kind": "CLIENT REQUEST", "payload": prompt }) response = next(self.responses) self.conversation.append({ "kind": "MODEL RESPONSE", "payload": response }) From 5e656143633f0190664a1725b411068d68e8d96c Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 01:16:32 +0000 Subject: [PATCH 08/10] Updated snapshot. --- .../tests/__snapshots__/test_translator.ambr | 110 ++++++++++++------ 1 file changed, 72 insertions(+), 38 deletions(-) diff --git a/python/tests/__snapshots__/test_translator.ambr b/python/tests/__snapshots__/test_translator.ambr index 3773cadd..634ed3d4 100644 --- a/python/tests/__snapshots__/test_translator.ambr +++ b/python/tests/__snapshots__/test_translator.ambr @@ -33,7 +33,7 @@ }), ]) # --- -# name: test_translator_with_single_failure +# name: test_translator_with_invalid_json list([ dict({ 'kind': 'CLIENT REQUEST', @@ -59,8 +59,38 @@ ''', 'role': 'user', }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello" "b": true }', + }), + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ dict({ - 'content': '{ "a": "hello", "b": true }', + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': '{ "a": "hello" "b": true }', 'role': 'assistant', }), dict({ @@ -68,8 +98,12 @@ The above JSON object is invalid for the following reason: ''' - Validation path `c` failed for value `{"a": "hello", "b": true}` because: - Field required + + Error: expected `,` or `}` at line 1 column 16 + + Attempted to parse: + { "a": "hello" "b": true } + ''' The following is a revised JSON object: @@ -78,6 +112,40 @@ }), ]), }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello" "b": true, "c": 1234 }', + }), + ]) +# --- +# name: test_translator_with_single_failure + list([ + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + ]), + }), dict({ 'kind': 'MODEL RESPONSE', 'payload': '{ "a": "hello", "b": true }', @@ -165,23 +233,6 @@ ''', 'role': 'user', }), - dict({ - 'content': '{ "a": "hello", "b": true }', - 'role': 'assistant', - }), - dict({ - 'content': ''' - - The above JSON object is invalid for the following reason: - ''' - Validation path `c` failed for value `{"a": "hello", "b": true}` because: - Field required - ''' - The following is a revised JSON object: - - ''', - 'role': 'user', - }), ]), }), dict({ @@ -275,23 +326,6 @@ ''', 'role': 'user', }), - dict({ - 'content': '{ "a": "hello", "b": true }', - 'role': 'assistant', - }), - dict({ - 'content': ''' - - The above JSON object is invalid for the following reason: - ''' - Validation path `c` failed for value `{"a": "hello", "b": true}` because: - Field required - ''' - The following is a revised JSON object: - - ''', - 'role': 'user', - }), ]), }), dict({ From 194e3b04f395aee9e9dc76c43812a21c1d31240f Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 01:30:14 +0000 Subject: [PATCH 09/10] Flatten string, give more context. --- python/src/typechat/_internal/translator.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/src/typechat/_internal/translator.py b/python/src/typechat/_internal/translator.py index fb45ad35..86b57126 100644 --- a/python/src/typechat/_internal/translator.py +++ b/python/src/typechat/_internal/translator.py @@ -86,19 +86,14 @@ async def translate(self, input: str, *, prompt_preamble: str | list[PromptSecti try: parsed_response = pydantic_core.from_json(trimmed_response, allow_inf_nan=False, cache_strings=False) except ValueError as e: - error_message = f""" -Error: {e} - -Attempted to parse: -{trimmed_response} -""" + error_message = f"Error: {e}\n\nAttempted to parse:\n\n{trimmed_response}" else: result = self.validator.validate_object(parsed_response) if isinstance(result, Success): return result error_message = result.message else: - error_message = "Response did not contain any text resembling JSON." + error_message = f"Response did not contain any text resembling JSON.\nResponse was\n\n{text_response}" if num_repairs_attempted >= self._max_repair_attempts: return Failure(error_message) num_repairs_attempted += 1 From 0a8fdb7d829bc3d2ba61a28f132b2aee0052ed5b Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Tue, 30 Apr 2024 01:34:47 +0000 Subject: [PATCH 10/10] Update snapshots. --- python/tests/__snapshots__/test_translator.ambr | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tests/__snapshots__/test_translator.ambr b/python/tests/__snapshots__/test_translator.ambr index 634ed3d4..84e22594 100644 --- a/python/tests/__snapshots__/test_translator.ambr +++ b/python/tests/__snapshots__/test_translator.ambr @@ -98,12 +98,11 @@ The above JSON object is invalid for the following reason: ''' - Error: expected `,` or `}` at line 1 column 16 Attempted to parse: - { "a": "hello" "b": true } + { "a": "hello" "b": true } ''' The following is a revised JSON object: