infra: test chat prompt ser/des (#25557)

This commit is contained in:
Bagatur 2024-08-19 15:27:36 -07:00 committed by GitHub
parent c5bf114c0f
commit 6b98207eda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 918 additions and 5 deletions

View File

@ -1451,3 +1451,867 @@
'type': 'constructor',
})
# ---
# name: test_chat_tmpl_serdes
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'chat',
'ChatPromptTemplate',
]),
'name': 'ChatPromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ChatPromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'chat',
'ChatPromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'foo',
'more_history',
'my_image',
'my_other_image',
'name',
]),
'messages': list([
dict({
'id': list([
'langchain',
'prompts',
'chat',
'SystemMessagePromptTemplate',
]),
'kwargs': dict({
'prompt': dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'name',
]),
'template': 'You are an AI assistant named {name}.',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'SystemMessagePromptTemplate',
]),
'kwargs': dict({
'prompt': list([
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'name',
]),
'template': 'You are an AI assistant named {name}.',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
]),
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'SystemMessagePromptTemplate',
]),
'kwargs': dict({
'prompt': dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'foo',
]),
'template': 'you are {foo}',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'HumanMessagePromptTemplate',
]),
'kwargs': dict({
'prompt': list([
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': 'hello',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': "What's in this image?",
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': "What's in this image?",
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'name': 'ImagePromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ImagePromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'my_image',
]),
'template': dict({
'url': 'data:image/jpeg;base64,{my_image}',
}),
}),
'lc': 1,
'name': 'ImagePromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'name': 'ImagePromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ImagePromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'my_image',
]),
'template': dict({
'url': 'data:image/jpeg;base64,{my_image}',
}),
}),
'lc': 1,
'name': 'ImagePromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'name': 'ImagePromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ImagePromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'my_other_image',
]),
'template': dict({
'url': '{my_other_image}',
}),
}),
'lc': 1,
'name': 'ImagePromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'name': 'ImagePromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ImagePromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'my_other_image',
]),
'template': dict({
'detail': 'medium',
'url': '{my_other_image}',
}),
}),
'lc': 1,
'name': 'ImagePromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'name': 'ImagePromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ImagePromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': dict({
'url': 'https://www.langchain.com/image.png',
}),
}),
'lc': 1,
'name': 'ImagePromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'name': 'ImagePromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ImagePromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': dict({
'url': '',
}),
}),
'lc': 1,
'name': 'ImagePromptTemplate',
'type': 'constructor',
}),
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'name': 'ImagePromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ImagePromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'image',
'ImagePromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': dict({
'url': '',
}),
}),
'lc': 1,
'name': 'ImagePromptTemplate',
'type': 'constructor',
}),
]),
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'MessagesPlaceholder',
]),
'kwargs': dict({
'optional': True,
'variable_name': 'chat_history',
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'MessagesPlaceholder',
]),
'kwargs': dict({
'variable_name': 'more_history',
}),
'lc': 1,
'type': 'constructor',
}),
]),
'optional_variables': list([
'chat_history',
]),
'partial_variables': dict({
'chat_history': list([
]),
}),
}),
'lc': 1,
'name': 'ChatPromptTemplate',
'type': 'constructor',
})
# ---

View File

@ -1,7 +1,7 @@
import base64
import tempfile
from pathlib import Path
from typing import Any, List, Union
from typing import Any, List, Tuple, Union, cast
import pytest
from syrupy import SnapshotAssertion
@ -565,7 +565,7 @@ async def test_chat_tmpl_from_messages_multipart_text_with_template() -> None:
async def test_chat_tmpl_from_messages_multipart_image() -> None:
"""Test multipart image URL formatting."""
base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
other_base64_image = "iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
other_base64_image = "other_iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAA"
template = ChatPromptTemplate.from_messages(
[
("system", "You are an AI assistant named {name}."),
@ -609,9 +609,7 @@ async def test_chat_tmpl_from_messages_multipart_image() -> None:
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{other_base64_image}"
},
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
},
{
"type": "image_url",
@ -814,3 +812,54 @@ def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) ->
assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar")
assert dumpd(prompt) == snapshot(name="chat_prompt")
assert load(dumpd(prompt)) == prompt
async def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
"""Test chat prompt template ser/des."""
template = ChatPromptTemplate(
[
("system", "You are an AI assistant named {name}."),
("system", [{"text": "You are an AI assistant named {name}."}]),
SystemMessagePromptTemplate.from_template("you are {foo}"),
cast(
Tuple,
(
"human",
[
"hello",
{"text": "What's in this image?"},
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": "data:image/jpeg;base64,{my_image}",
},
{
"type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{my_image}"},
},
{"type": "image_url", "image_url": "{my_other_image}"},
{
"type": "image_url",
"image_url": {
"url": "{my_other_image}",
"detail": "medium",
},
},
{
"type": "image_url",
"image_url": {"url": "https://www.langchain.com/image.png"},
},
{
"type": "image_url",
"image_url": {"url": ""},
},
{"image_url": {"url": ""}},
],
),
),
("placeholder", "{chat_history}"),
MessagesPlaceholder("more_history", optional=False),
]
)
assert dumpd(template) == snapshot()
assert load(dumpd(template)) == template