core[patch]: support additional kwargs on StructuredPrompt (#25645)

pull/22120/head^2
Bagatur 2 weeks ago committed by GitHub
parent 51dae57357
commit 933bc0d6ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -7,7 +7,6 @@ from typing import (
Mapping,
Optional,
Sequence,
Set,
Type,
Union,
)
@ -15,20 +14,17 @@ from typing import (
from langchain_core._api.beta_decorator import beta
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
BaseMessagePromptTemplate,
ChatPromptTemplate,
MessageLikeRepresentation,
MessagesPlaceholder,
_convert_to_message,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.base import (
Other,
Runnable,
RunnableSequence,
RunnableSerializable,
)
from langchain_core.utils import get_pydantic_field_names
@beta()
@ -37,6 +33,26 @@ class StructuredPrompt(ChatPromptTemplate):
schema_: Union[Dict, Type[BaseModel]]
"""Schema for the structured prompt."""
# Extra kwargs collected by __init__ and forwarded verbatim to
# ChatModel.with_structured_output(schema, **structured_output_kwargs).
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
def __init__(
    self,
    messages: Sequence[MessageLikeRepresentation],
    schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
    *,
    structured_output_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Create a structured prompt over ``messages`` with an output ``schema_``.

    Args:
        messages: sequence of message-like representations for the chat prompt.
        schema_: dict representation of a function call, or a Pydantic model;
            may alternatively be supplied via the legacy ``schema`` kwarg.
        structured_output_kwargs: extra kwargs to forward to
            ``ChatModel.with_structured_output(schema, **kwargs)``.
        **kwargs: any key that is not a declared pydantic field of this class
            is treated as an additional structured-output kwarg.
    """
    # Accept the legacy ``schema`` keyword spelling as well (raises KeyError,
    # as before, when neither schema_ nor schema is given).
    schema_ = schema_ or kwargs.pop("schema")
    # Copy so we never mutate a caller-supplied dict in place.
    structured_output_kwargs = dict(structured_output_kwargs or {})
    # Route every non-field kwarg to with_structured_output() instead of
    # letting pydantic reject it as an unknown field.
    for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
        structured_output_kwargs[k] = kwargs.pop(k)
    super().__init__(
        messages=messages,
        schema_=schema_,
        structured_output_kwargs=structured_output_kwargs,
        **kwargs,
    )
@classmethod
def get_lc_namespace(cls) -> List[str]:
@classmethod
def from_messages_and_schema(
    cls,
    messages: Sequence[MessageLikeRepresentation],
    schema: Union[Dict, Type[BaseModel]],
    **kwargs: Any,
) -> ChatPromptTemplate:
    """Create a chat prompt template from a variety of message formats.

    Examples:

        Instantiation from a list of message templates:

        .. code-block:: python

            from langchain_core.prompts import StructuredPrompt

            class OutputSchema(BaseModel):
                name: str
                value: int

            template = StructuredPrompt(
                [
                    ("human", "Hello, how are you?"),
                    ("ai", "I'm doing well, thanks!"),
                ]
            )

    Args:
        messages: sequence of message representations; e.g. a BaseMessage, a
            2-tuple of (message type, template), a 2-tuple of
            (message class, template), or a string which is shorthand for
            ("human", template); e.g., "{user_input}"
        schema: a dictionary representation of function call, or a Pydantic model.
        kwargs: Any additional kwargs to pass through to
            ``ChatModel.with_structured_output(schema, **kwargs)``.

    Returns:
        a structured prompt template
    """
    # __init__ performs input-variable inference itself, so simply delegate.
    return cls(messages, schema, **kwargs)
def __or__(
    self,
    other: Union[
        Runnable[Any, Other],
        Callable[[Any], Other],
        Callable[[Iterator[Any]], Iterator[Other]],
        Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
    ],
) -> RunnableSerializable[Dict, Other]:
    """Compose this prompt with *other* via the ``|`` operator.

    Delegates to :meth:`pipe`, which applies the stored schema and
    structured-output kwargs to the downstream model.
    """
    return self.pipe(other)
def pipe(
    self,
    *others: Union[
        Runnable[Any, Other],
        Callable[[Any], Other],
        Callable[[Iterator[Any]], Iterator[Other]],
        Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
    ],
    name: Optional[str] = None,
) -> RunnableSerializable[Dict, Other]:
    """Pipe the structured prompt to a language model.

    The first step must implement ``with_structured_output``; the prompt's
    ``schema_`` and any stored ``structured_output_kwargs`` are applied to it.

    Args:
        *others: runnables to pipe to; the first must support
            ``with_structured_output``.
        name: optional name for the resulting sequence.

    Returns:
        a RunnableSequence producing structured output.

    Raises:
        NotImplementedError: if the first step does not support
            ``with_structured_output``.
    """
    # NOTE(review): mid-body lines are elided in the diff view; this
    # reconstruction keeps the visible tail and error message — confirm
    # the exact guard against upstream.
    if others and (
        isinstance(others[0], BaseLanguageModel)
        or hasattr(others[0], "with_structured_output")
    ):
        return RunnableSequence(
            self,
            others[0].with_structured_output(
                self.schema_, **self.structured_output_kwargs
            ),
            *others[1:],
            name=name,
        )
    raise NotImplementedError(
        "Structured prompts must be piped to a language model that "
        "implements with_structured_output."
    )

@ -12,13 +12,13 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
def _fake_runnable(
schema: Union[Dict, Type[BaseModel]], _: Any
input: Any, *, schema: Union[Dict, Type[BaseModel]], value: Any = 42, **_: Any
) -> Union[BaseModel, Dict]:
if isclass(schema) and is_basemodel_subclass(schema):
return schema(name="yo", value=42)
return schema(name="yo", value=value)
else:
params = cast(Dict, schema)["parameters"]
return {k: 1 for k, v in params.items()}
return {k: 1 if k != "value" else value for k, v in params.items()}
class FakeStructuredChatModel(FakeListChatModel):
@ -27,7 +27,7 @@ class FakeStructuredChatModel(FakeListChatModel):
def with_structured_output(
    self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable:
    """Return a runnable faking structured output for *schema*.

    Extra **kwargs (e.g. ``value``) are bound into the fake so tests can
    check that StructuredPrompt forwards its structured_output_kwargs.
    """
    return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs))
@property
def _llm_type(self) -> str:
@ -39,7 +39,7 @@ def test_structured_prompt_pydantic() -> None:
name: str
value: int
prompt = StructuredPrompt.from_messages_and_schema(
prompt = StructuredPrompt(
[
("human", "I'm very structured, how about you?"),
],
@ -54,7 +54,7 @@ def test_structured_prompt_pydantic() -> None:
def test_structured_prompt_dict() -> None:
prompt = StructuredPrompt.from_messages_and_schema(
prompt = StructuredPrompt(
[
("human", "I'm very structured, how about you?"),
],
@ -72,10 +72,47 @@ def test_structured_prompt_dict() -> None:
chain = prompt | model
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
assert loads(dumps(prompt)) == prompt
chain = loads(dumps(prompt)) | model
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
def test_structured_prompt_kwargs() -> None:
    # ``value`` is not a prompt field, so __init__ must route it into
    # structured_output_kwargs and it must surface in the fake model output.
    prompt = StructuredPrompt(
        [
            ("human", "I'm very structured, how about you?"),
        ],
        {
            "name": "yo",
            "description": "a structured output",
            "parameters": {
                "name": {"type": "string"},
                "value": {"type": "integer"},
            },
        },
        value=7,
    )
    model = FakeStructuredChatModel(responses=[])
    chain = prompt | model
    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
    # The extra kwarg must survive a serialization round-trip.
    assert loads(dumps(prompt)) == prompt
    chain = loads(dumps(prompt)) | model
    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}

    class OutputSchema(BaseModel):
        name: str
        value: int

    # Same kwarg routing with a Pydantic schema instead of a dict schema.
    prompt = StructuredPrompt(
        [("human", "I'm very structured, how about you?")], OutputSchema, value=7
    )
    model = FakeStructuredChatModel(responses=[])
    chain = prompt | model
    assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)

Loading…
Cancel
Save