From 6b98207edaefdae997dda521f8a2e2eed3646c80 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 19 Aug 2024 15:27:36 -0700 Subject: [PATCH] infra: test chat prompt ser/des (#25557) --- .../prompts/__snapshots__/test_chat.ambr | 864 ++++++++++++++++++ .../tests/unit_tests/prompts/test_chat.py | 59 +- 2 files changed, 918 insertions(+), 5 deletions(-) diff --git a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr index 28ef4b9f0f..901b35e75c 100644 --- a/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr +++ b/libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr @@ -1451,3 +1451,867 @@ 'type': 'constructor', }) # --- +# name: test_chat_tmpl_serdes + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'chat', + 'ChatPromptTemplate', + ]), + 'name': 'ChatPromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ChatPromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'chat', + 'ChatPromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + 'foo', + 'more_history', + 'my_image', + 'my_other_image', + 'name', + ]), + 'messages': list([ + dict({ + 'id': list([ + 'langchain', + 'prompts', + 'chat', + 'SystemMessagePromptTemplate', + ]), + 'kwargs': dict({ + 'prompt': dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'name': 'PromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'PromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + 'name', + ]), + 'template': 'You are an AI assistant named {name}.', + 'template_format': 'f-string', + }), + 'lc': 1, + 'name': 'PromptTemplate', + 'type': 'constructor', + }), + }), + 'lc': 1, + 'type': 'constructor', + }), + dict({ + 'id': list([ + 'langchain', + 'prompts', + 'chat', + 'SystemMessagePromptTemplate', + ]), + 'kwargs': dict({ + 'prompt': list([ + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'name': 'PromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'PromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + 'name', + ]), + 'template': 'You are an AI assistant named {name}.', + 'template_format': 'f-string', + }), + 'lc': 1, + 'name': 'PromptTemplate', + 'type': 'constructor', + }), + ]), + }), + 'lc': 1, + 'type': 'constructor', + }), + dict({ + 'id': list([ + 'langchain', + 'prompts', + 'chat', + 'SystemMessagePromptTemplate', + ]), + 'kwargs': dict({ + 'prompt': dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'name': 'PromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'PromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + 'foo', + ]), + 'template': 'you are {foo}', + 'template_format': 'f-string', + }), + 'lc': 1, + 'name': 'PromptTemplate', + 'type': 'constructor', + }), + }), + 'lc': 1, + 'type': 'constructor', + }), + dict({ + 'id': list([ + 'langchain', + 'prompts', + 'chat', + 'HumanMessagePromptTemplate', + ]), + 'kwargs': dict({ + 'prompt': list([ + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'name': 'PromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'PromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + ]), + 'template': 'hello', + 'template_format': 'f-string', + }), + 'lc': 1, + 'name': 'PromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'name': 'PromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'PromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + ]), + 'template': "What's in this image?", + 'template_format': 'f-string', + }), + 'lc': 1, + 'name': 'PromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'name': 'PromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'PromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'prompt', + 'PromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + ]), + 'template': "What's in this image?", + 'template_format': 'f-string', + }), + 'lc': 1, + 'name': 'PromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'name': 'ImagePromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ImagePromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + 'my_image', + ]), + 'template': dict({ + 'url': 'data:image/jpeg;base64,{my_image}', + }), + }), + 'lc': 1, + 'name': 'ImagePromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'name': 'ImagePromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ImagePromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + 'my_image', + ]), + 'template': dict({ + 'url': 'data:image/jpeg;base64,{my_image}', + }), + }), + 'lc': 1, + 'name': 'ImagePromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'name': 'ImagePromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ImagePromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + 'my_other_image', + ]), + 'template': dict({ + 'url': '{my_other_image}', + }), + }), + 'lc': 1, + 'name': 'ImagePromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'name': 'ImagePromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ImagePromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + 'my_other_image', + ]), + 'template': dict({ + 'detail': 'medium', + 'url': '{my_other_image}', + }), + }), + 'lc': 1, + 'name': 'ImagePromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'name': 'ImagePromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ImagePromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + ]), + 'template': dict({ + 'url': 'https://www.langchain.com/image.png', + }), + }), + 'lc': 1, + 'name': 'ImagePromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'name': 'ImagePromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ImagePromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + ]), + 'template': dict({ + 'url': '', + }), + }), + 'lc': 1, + 'name': 'ImagePromptTemplate', + 'type': 'constructor', + }), + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'PromptInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'name': 'ImagePromptTemplate', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ImagePromptTemplateOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'prompts', + 'image', + 'ImagePromptTemplate', + ]), + 'kwargs': dict({ + 'input_variables': list([ + ]), + 'template': dict({ + 'url': '', + }), + }), + 'lc': 1, + 'name': 'ImagePromptTemplate', + 'type': 'constructor', + }), + ]), + }), + 'lc': 1, + 'type': 'constructor', + }), + dict({ + 'id': list([ + 'langchain', + 'prompts', + 'chat', + 'MessagesPlaceholder', + ]), + 'kwargs': dict({ + 'optional': True, + 'variable_name': 'chat_history', + }), + 'lc': 1, + 'type': 'constructor', + }), + dict({ + 'id': list([ + 'langchain', + 'prompts', + 'chat', + 'MessagesPlaceholder', + ]), + 'kwargs': dict({ + 'variable_name': 'more_history', + }), + 'lc': 1, + 'type': 'constructor', + }), + ]), + 'optional_variables': list([ + 'chat_history', + ]), + 'partial_variables': dict({ + 'chat_history': list([ + ]), + }), + }), + 'lc': 1, + 'name': 'ChatPromptTemplate', + 'type': 'constructor', + }) +# --- diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index d6e3837883..a280ff1dbb 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1,7 +1,7 @@ import base64 import tempfile from pathlib import Path -from typing import Any, List, Union +from typing import Any, List, Tuple, Union, cast import pytest from syrupy import SnapshotAssertion @@ -565,7 +565,7 @@ async def test_chat_tmpl_from_messages_multipart_text_with_template() -> None: async def test_chat_tmpl_from_messages_multipart_image() -> None: """Test multipart image URL formatting.""" base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" - other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" + other_base64_image = "other_iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA" template = ChatPromptTemplate.from_messages( [ ("system", "You are an AI assistant named {name}."), @@ -609,9 +609,7 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None: }, { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{other_base64_image}" - }, + "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}, }, { "type": "image_url", @@ -814,3 +812,54 @@ def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar") assert dumpd(prompt) == snapshot(name="chat_prompt") assert load(dumpd(prompt)) == prompt + + +async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None: + """Test chat prompt template ser/des.""" + template = ChatPromptTemplate( + [ + ("system", "You are an AI assistant named {name}."), + ("system", [{"text": "You are an AI assistant named {name}."}]), + SystemMessagePromptTemplate.from_template("you are {foo}"), + cast( + Tuple, + ( + "human", + [ + "hello", + {"text": "What's in this image?"}, + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "data:image/jpeg;base64,{my_image}", + }, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{my_image}"}, + }, + {"type": "image_url", "image_url": "{my_other_image}"}, + { + "type": "image_url", + "image_url": { + "url": "{my_other_image}", + "detail": "medium", + }, + }, + { + "type": "image_url", + "image_url": {"url": "https://www.langchain.com/image.png"}, + }, + { + "type": "image_url", + "image_url": {"url": ""}, + }, + {"image_url": {"url": ""}}, + ], + ), + ), + ("placeholder", "{chat_history}"), + MessagesPlaceholder("more_history", optional=False), + ] + ) + assert dumpd(template) == snapshot() + assert load(dumpd(template)) == template