core[patch]: support additional kwargs on StructuredPrompt (#25645)

Bagatur authored on 2024-09-02 14:55:26 -07:00 (committed by GitHub)
parent 51dae57357 · commit 933bc0d6ff
2 changed files with 84 additions and 53 deletions

File 1 of 2 — the StructuredPrompt implementation:

@@ -7,7 +7,6 @@ from typing import (
     Mapping,
     Optional,
     Sequence,
-    Set,
     Type,
     Union,
 )
@@ -15,20 +14,17 @@ from typing import (
 from langchain_core._api.beta_decorator import beta
 from langchain_core.language_models.base import BaseLanguageModel
 from langchain_core.prompts.chat import (
-    BaseChatPromptTemplate,
-    BaseMessagePromptTemplate,
     ChatPromptTemplate,
     MessageLikeRepresentation,
-    MessagesPlaceholder,
-    _convert_to_message,
 )
-from langchain_core.pydantic_v1 import BaseModel
+from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_core.runnables.base import (
     Other,
     Runnable,
     RunnableSequence,
     RunnableSerializable,
 )
+from langchain_core.utils import get_pydantic_field_names


 @beta()
@@ -37,6 +33,26 @@ class StructuredPrompt(ChatPromptTemplate):
     schema_: Union[Dict, Type[BaseModel]]
     """Schema for the structured prompt."""
+    structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+    def __init__(
+        self,
+        messages: Sequence[MessageLikeRepresentation],
+        schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
+        *,
+        structured_output_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> None:
+        schema_ = schema_ or kwargs.pop("schema")
+        structured_output_kwargs = structured_output_kwargs or {}
+        for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
+            structured_output_kwargs[k] = kwargs.pop(k)
+        super().__init__(
+            messages=messages,
+            schema_=schema_,
+            structured_output_kwargs=structured_output_kwargs,
+            **kwargs,
+        )

     @classmethod
     def get_lc_namespace(cls) -> List[str]:
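
The heart of the change is the set-difference in `__init__`: any keyword argument that is not a declared pydantic field of the class is siphoned into `structured_output_kwargs` for later forwarding. A minimal standalone sketch of that routing, with the field names hard-coded for illustration (the real code derives them via `get_pydantic_field_names`):

from typing import Any, Dict

# Hypothetical stand-in for get_pydantic_field_names(StructuredPrompt).
FIELD_NAMES = {"messages", "schema_", "structured_output_kwargs"}

def split_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Pop every kwarg that is not a declared field into a pass-through dict."""
    passthrough: Dict[str, Any] = {}
    for k in set(kwargs).difference(FIELD_NAMES):
        passthrough[k] = kwargs.pop(k)
    return passthrough

kwargs = {"schema_": {"name": "x"}, "method": "json_mode", "include_raw": True}
extra = split_kwargs(kwargs)
assert extra == {"method": "json_mode", "include_raw": True}
assert kwargs == {"schema_": {"name": "x"}}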
@@ -52,6 +68,7 @@ class StructuredPrompt(ChatPromptTemplate):
         cls,
         messages: Sequence[MessageLikeRepresentation],
         schema: Union[Dict, Type[BaseModel]],
+        **kwargs: Any,
     ) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats. """Create a chat prompt template from a variety of message formats.
@ -61,11 +78,13 @@ class StructuredPrompt(ChatPromptTemplate):
.. code-block:: python .. code-block:: python
from langchain_core.prompts import StructuredPrompt
class OutputSchema(BaseModel): class OutputSchema(BaseModel):
name: str name: str
value: int value: int
template = ChatPromptTemplate.from_messages( template = StructuredPrompt(
[ [
("human", "Hello, how are you?"), ("human", "Hello, how are you?"),
("ai", "I'm doing well, thanks!"), ("ai", "I'm doing well, thanks!"),
@@ -82,29 +101,13 @@ class StructuredPrompt(ChatPromptTemplate):
                 (4) 2-tuple of (message class, template), (5) a string which is
                 shorthand for ("human", template); e.g., "{user_input}"
             schema: a dictionary representation of function call, or a Pydantic model.
+            kwargs: Any additional kwargs to pass through to
+                ``ChatModel.with_structured_output(schema, **kwargs)``.

         Returns:
             a structured prompt template
         """
-        _messages = [_convert_to_message(message) for message in messages]
-
-        # Automatically infer input variables from messages
-        input_vars: Set[str] = set()
-        partial_vars: Dict[str, Any] = {}
-        for _message in _messages:
-            if isinstance(_message, MessagesPlaceholder) and _message.optional:
-                partial_vars[_message.variable_name] = []
-            elif isinstance(
-                _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
-            ):
-                input_vars.update(_message.input_variables)
-
-        return cls(
-            input_variables=sorted(input_vars),
-            messages=_messages,
-            partial_variables=partial_vars,
-            schema_=schema,
-        )
+        return cls(messages, schema, **kwargs)
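
With the variable inference now handled by the constructor, `from_messages_and_schema` simply forwards its arguments, so the two spellings below should be equivalent (a small sketch assuming a langchain-core build that includes this commit):

from langchain_core.prompts import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel

class Answer(BaseModel):
    name: str
    value: int

# Direct construction and the classmethod route through the same __init__.
a = StructuredPrompt([("human", "{q}")], Answer)
b = StructuredPrompt.from_messages_and_schema([("human", "{q}")], Answer)
assert a == b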
     def __or__(
         self,
@@ -115,27 +118,16 @@ class StructuredPrompt(ChatPromptTemplate):
             Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
         ],
     ) -> RunnableSerializable[Dict, Other]:
-        if isinstance(other, BaseLanguageModel) or hasattr(
-            other, "with_structured_output"
-        ):
-            try:
-                return RunnableSequence(
-                    self, other.with_structured_output(self.schema_)
-                )
-            except NotImplementedError as e:
-                raise NotImplementedError(
-                    "Structured prompts must be piped to a language model that "
-                    "implements with_structured_output."
-                ) from e
-        else:
-            raise NotImplementedError(
-                "Structured prompts must be piped to a language model that "
-                "implements with_structured_output."
-            )
+        return self.pipe(other)

     def pipe(
         self,
-        *others: Union[Runnable[Any, Other], Callable[[Any], Other]],
+        *others: Union[
+            Runnable[Any, Other],
+            Callable[[Any], Other],
+            Callable[[Iterator[Any]], Iterator[Other]],
+            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
+        ],
         name: Optional[str] = None,
     ) -> RunnableSerializable[Dict, Other]:
         """Pipe the structured prompt to a language model.
@@ -158,7 +150,9 @@ class StructuredPrompt(ChatPromptTemplate):
         ):
             return RunnableSequence(
                 self,
-                others[0].with_structured_output(self.schema_),
+                others[0].with_structured_output(
+                    self.schema_, **self.structured_output_kwargs
+                ),
                 *others[1:],
                 name=name,
             )
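
Taken together, `__or__` now delegates to `pipe`, and `pipe` threads the stored kwargs into the model. A short sketch of the observable behavior, assuming a langchain-core build that includes this commit (`include_raw` is just an illustrative kwarg that `with_structured_output` implementations commonly accept):

from langchain_core.prompts import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel

class Answer(BaseModel):
    name: str
    value: int

# "include_raw" is not a StructuredPrompt field, so the new __init__ moves it
# into structured_output_kwargs; piping to a model then calls
# model.with_structured_output(Answer, include_raw=True).
prompt = StructuredPrompt([("human", "{question}")], Answer, include_raw=True)
assert prompt.structured_output_kwargs == {"include_raw": True}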

File 2 of 2 — the unit tests:

@@ -12,13 +12,13 @@ from langchain_core.utils.pydantic import is_basemodel_subclass


 def _fake_runnable(
-    schema: Union[Dict, Type[BaseModel]], _: Any
+    input: Any, *, schema: Union[Dict, Type[BaseModel]], value: Any = 42, **_: Any
 ) -> Union[BaseModel, Dict]:
     if isclass(schema) and is_basemodel_subclass(schema):
-        return schema(name="yo", value=42)
+        return schema(name="yo", value=value)
     else:
         params = cast(Dict, schema)["parameters"]
-        return {k: 1 for k, v in params.items()}
+        return {k: 1 if k != "value" else value for k, v in params.items()}
 class FakeStructuredChatModel(FakeListChatModel):
@@ -27,7 +27,7 @@ class FakeStructuredChatModel(FakeListChatModel):
     def with_structured_output(
         self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
     ) -> Runnable:
-        return RunnableLambda(partial(_fake_runnable, schema))
+        return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs))

     @property
     def _llm_type(self) -> str:
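
The fake model relies on `functools.partial` to bind `schema` and any pass-through kwargs ahead of time, while `RunnableLambda` supplies the positional input at call time. A self-contained sketch of that binding trick:

from functools import partial

def fake_runnable(input, *, schema, value=42, **_):
    return {"schema": schema, "value": value}

# Keywords are fixed now; the positional argument arrives later, just as it
# does when RunnableLambda invokes the partial with the model input.
bound = partial(fake_runnable, schema="my-schema", value=7)
assert bound("hi") == {"schema": "my-schema", "value": 7}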
@@ -39,7 +39,7 @@ def test_structured_prompt_pydantic() -> None:
         name: str
         value: int

-    prompt = StructuredPrompt.from_messages_and_schema(
+    prompt = StructuredPrompt(
         [
             ("human", "I'm very structured, how about you?"),
         ],
@@ -54,7 +54,7 @@ def test_structured_prompt_pydantic() -> None:

 def test_structured_prompt_dict() -> None:
-    prompt = StructuredPrompt.from_messages_and_schema(
+    prompt = StructuredPrompt(
         [
             ("human", "I'm very structured, how about you?"),
         ],
@@ -72,10 +72,47 @@ def test_structured_prompt_dict() -> None:
     chain = prompt | model

-    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
+    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}

     assert loads(dumps(prompt)) == prompt

     chain = loads(dumps(prompt)) | model

-    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
+    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
+
+
+def test_structured_prompt_kwargs() -> None:
+    prompt = StructuredPrompt(
+        [
+            ("human", "I'm very structured, how about you?"),
+        ],
+        {
+            "name": "yo",
+            "description": "a structured output",
+            "parameters": {
+                "name": {"type": "string"},
+                "value": {"type": "integer"},
+            },
+        },
+        value=7,
+    )
+    model = FakeStructuredChatModel(responses=[])
+    chain = prompt | model
+    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
+    assert loads(dumps(prompt)) == prompt
+    chain = loads(dumps(prompt)) | model
+    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
+
+    class OutputSchema(BaseModel):
+        name: str
+        value: int
+
+    prompt = StructuredPrompt(
+        [("human", "I'm very structured, how about you?")], OutputSchema, value=7
+    )
+    model = FakeStructuredChatModel(responses=[])
+    chain = prompt | model
+    assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
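
For completeness, a hypothetical end-to-end use against a real provider. This assumes `langchain-openai` is installed with an API key configured and that `ChatOpenAI.with_structured_output` accepts a `method` kwarg; none of this is exercised by the commit itself:

from langchain_core.prompts import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI

class Joke(BaseModel):
    setup: str
    punchline: str

# `method` is captured into structured_output_kwargs and forwarded as
# model.with_structured_output(Joke, method="function_calling").
prompt = StructuredPrompt(
    [("human", "Tell me a joke about {topic}")], Joke, method="function_calling"
)
chain = prompt | ChatOpenAI(model="gpt-4o-mini")
# chain.invoke({"topic": "cats"}) -> Joke(setup=..., punchline=...)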