core[patch]: don't serialize BasePromptTemplate.input_types (#24516)

Candidate fix for #24513
This commit is contained in:
Bagatur 2024-07-22 13:30:16 -07:00 committed by GitHub
parent df357f82ca
commit 8a140ee77c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 230 additions and 4 deletions

View File

@ -47,7 +47,7 @@ class BasePromptTemplate(
prompt."""
optional_variables: List[str] = Field(default=[])
"""A list of the names of the variables that are optional in the prompt."""
input_types: Dict[str, Any] = Field(default_factory=dict)
input_types: Dict[str, Any] = Field(default_factory=dict, exclude=True)
"""A dictionary of the types of the variables the prompt template expects.
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None

View File

@ -1216,3 +1216,220 @@
'type': 'object',
})
# ---
# name: test_chat_prompt_w_msgs_placeholder_ser_des[chat_prompt]
dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'chat',
'ChatPromptTemplate',
]),
'name': 'ChatPromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'ChatPromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'chat',
'ChatPromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
'bar',
]),
'messages': list([
dict({
'id': list([
'langchain',
'prompts',
'chat',
'SystemMessagePromptTemplate',
]),
'kwargs': dict({
'prompt': dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': 'foo',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'MessagesPlaceholder',
]),
'kwargs': dict({
'variable_name': 'bar',
}),
'lc': 1,
'type': 'constructor',
}),
dict({
'id': list([
'langchain',
'prompts',
'chat',
'HumanMessagePromptTemplate',
]),
'kwargs': dict({
'prompt': dict({
'graph': dict({
'edges': list([
dict({
'source': 0,
'target': 1,
}),
dict({
'source': 1,
'target': 2,
}),
]),
'nodes': list([
dict({
'data': 'PromptInput',
'id': 0,
'type': 'schema',
}),
dict({
'data': dict({
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'name': 'PromptTemplate',
}),
'id': 1,
'type': 'runnable',
}),
dict({
'data': 'PromptTemplateOutput',
'id': 2,
'type': 'schema',
}),
]),
}),
'id': list([
'langchain',
'prompts',
'prompt',
'PromptTemplate',
]),
'kwargs': dict({
'input_variables': list([
]),
'template': 'baz',
'template_format': 'f-string',
}),
'lc': 1,
'name': 'PromptTemplate',
'type': 'constructor',
}),
}),
'lc': 1,
'type': 'constructor',
}),
]),
}),
'lc': 1,
'name': 'ChatPromptTemplate',
'type': 'constructor',
})
# ---
# name: test_chat_prompt_w_msgs_placeholder_ser_des[placholder]
dict({
'id': list([
'langchain',
'prompts',
'chat',
'MessagesPlaceholder',
]),
'kwargs': dict({
'variable_name': 'bar',
}),
'lc': 1,
'type': 'constructor',
})
# ---

View File

@ -6,9 +6,8 @@ from typing import Any, List, Union
import pytest
from syrupy import SnapshotAssertion
from langchain_core._api.deprecation import (
LangChainPendingDeprecationWarning,
)
from langchain_core._api.deprecation import LangChainPendingDeprecationWarning
from langchain_core.load import dumpd, load
from langchain_core.messages import (
AIMessage,
BaseMessage,
@ -806,3 +805,13 @@ def test_chat_input_schema(snapshot: SnapshotAssertion) -> None:
assert set(prompt_optional.input_variables) == {"input"}
prompt_optional.input_schema(input="") # won't raise error
assert prompt_optional.input_schema.schema() == snapshot(name="partial")
def test_chat_prompt_w_msgs_placeholder_ser_des(snapshot: SnapshotAssertion) -> None:
prompt = ChatPromptTemplate.from_messages(
[("system", "foo"), MessagesPlaceholder("bar"), ("human", "baz")]
)
assert dumpd(MessagesPlaceholder("bar")) == snapshot(name="placholder")
assert load(dumpd(MessagesPlaceholder("bar"))) == MessagesPlaceholder("bar")
assert dumpd(prompt) == snapshot(name="chat_prompt")
assert load(dumpd(prompt)) == prompt