core[patch]: fix chat prompt partial messages placeholder var (#16918)

pull/16956/head
Bagatur 4 months ago committed by GitHub
parent 3b0fa9079d
commit c29e9b6412
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -47,9 +47,7 @@ class BasePromptTemplate(
If not provided, all variables are assumed to be strings."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
partial_variables: Mapping[str, Any] = Field(default_factory=dict)
@classmethod
def get_lc_namespace(cls) -> List[str]:
@@ -143,8 +141,7 @@ class BasePromptTemplate(
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
k: v if not callable(v) else v() for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}

@@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@@ -130,13 +129,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
f"variable {self.variable_name} should be a list of base messages, "
f"got {value}"
)
for v in convert_to_messages(value):
if not isinstance(v, BaseMessage):
raise ValueError(
f"variable {self.variable_name} should be a list of base messages,"
f" got {value}"
)
return value
return convert_to_messages(value)
@property
def input_variables(self) -> List[str]:
@@ -755,13 +748,20 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
# Automatically infer input variables from messages
input_vars: Set[str] = set()
partial_vars: Dict[str, Any] = {}
for _message in _messages:
if isinstance(
if isinstance(_message, MessagesPlaceholder) and _message.optional:
partial_vars[_message.variable_name] = []
elif isinstance(
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
):
input_vars.update(_message.input_variables)
return cls(input_variables=sorted(input_vars), messages=_messages)
return cls(
input_variables=sorted(input_vars),
messages=_messages,
partial_variables=partial_vars,
)
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.
@@ -799,7 +799,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate):
raise ValueError(f"Unexpected input: {message_template}")
return result
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> ChatPromptTemplate:
def partial(self, **kwargs: Any) -> ChatPromptTemplate:
"""Get a new ChatPromptTemplate with some input variables already filled in.
Args:

@@ -503,9 +503,27 @@ def test_messages_placeholder() -> None:
prompt.format_messages()
prompt = MessagesPlaceholder("history", optional=True)
assert prompt.format_messages() == []
prompt.format_messages(
assert prompt.format_messages(
history=[("system", "You are an AI assistant."), "Hello!"]
) == [
SystemMessage(content="You are an AI assistant."),
HumanMessage(content="Hello!"),
]
def test_chat_prompt_message_placeholder_partial() -> None:
    """A ``MessagesPlaceholder`` variable can be pre-filled via ``partial``."""
    expected_foo = [SystemMessage(content="foo")]

    # Required placeholder: the partial supplies "history" ...
    required = ChatPromptTemplate.from_messages([MessagesPlaceholder("history")])
    required = required.partial(history=[("system", "foo")])
    assert required.format_messages() == expected_foo
    # ... and an explicit argument still overrides the partial.
    assert required.format_messages(history=[("system", "bar")]) == [
        SystemMessage(content="bar")
    ]

    # Optional placeholder: formats to empty without input, and accepts
    # a partial just like the required one.
    optional = ChatPromptTemplate.from_messages(
        [MessagesPlaceholder("history", optional=True)]
    )
    assert optional.format_messages() == []
    optional = optional.partial(history=[("system", "foo")])
    assert optional.format_messages() == expected_foo

@@ -1544,7 +1544,8 @@
}
}
}
]
],
"partial_variables": {}
}
},
"middle": [],
@@ -1617,7 +1618,8 @@
}
}
}
]
],
"partial_variables": {}
}
},
"middle": [],

@@ -191,7 +191,8 @@
}
}
}
]
],
"partial_variables": {}
}
}
}

Loading…
Cancel
Save