From 933bc0d6ff58b2ed7f6cd4b8cce5b61065889d19 Mon Sep 17 00:00:00 2001 From: Bagatur <22008038+baskaryan@users.noreply.github.com> Date: Mon, 2 Sep 2024 14:55:26 -0700 Subject: [PATCH] core[patch]: support additional kwargs on StructuredPrompt (#25645) --- .../core/langchain_core/prompts/structured.py | 84 +++++++++---------- .../unit_tests/prompts/test_structured.py | 53 ++++++++++-- 2 files changed, 84 insertions(+), 53 deletions(-) diff --git a/libs/core/langchain_core/prompts/structured.py b/libs/core/langchain_core/prompts/structured.py index 5176b483d9..8ccb177338 100644 --- a/libs/core/langchain_core/prompts/structured.py +++ b/libs/core/langchain_core/prompts/structured.py @@ -7,7 +7,6 @@ from typing import ( Mapping, Optional, Sequence, - Set, Type, Union, ) @@ -15,20 +14,17 @@ from typing import ( from langchain_core._api.beta_decorator import beta from langchain_core.language_models.base import BaseLanguageModel from langchain_core.prompts.chat import ( - BaseChatPromptTemplate, - BaseMessagePromptTemplate, ChatPromptTemplate, MessageLikeRepresentation, - MessagesPlaceholder, - _convert_to_message, ) -from langchain_core.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.runnables.base import ( Other, Runnable, RunnableSequence, RunnableSerializable, ) +from langchain_core.utils import get_pydantic_field_names @beta() @@ -37,6 +33,26 @@ class StructuredPrompt(ChatPromptTemplate): schema_: Union[Dict, Type[BaseModel]] """Schema for the structured prompt.""" + structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict) + + def __init__( + self, + messages: Sequence[MessageLikeRepresentation], + schema_: Optional[Union[Dict, Type[BaseModel]]] = None, + *, + structured_output_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + schema_ = schema_ or kwargs.pop("schema") + structured_output_kwargs = structured_output_kwargs or {} + for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)): + structured_output_kwargs[k] = kwargs.pop(k) + super().__init__( + messages=messages, + schema_=schema_, + structured_output_kwargs=structured_output_kwargs, + **kwargs, + ) @classmethod def get_lc_namespace(cls) -> List[str]: @@ -52,6 +68,7 @@ class StructuredPrompt(ChatPromptTemplate): cls, messages: Sequence[MessageLikeRepresentation], schema: Union[Dict, Type[BaseModel]], + **kwargs: Any, ) -> ChatPromptTemplate: """Create a chat prompt template from a variety of message formats. @@ -61,11 +78,13 @@ class StructuredPrompt(ChatPromptTemplate): .. code-block:: python + from langchain_core.prompts import StructuredPrompt + class OutputSchema(BaseModel): name: str value: int - template = ChatPromptTemplate.from_messages( + template = StructuredPrompt( [ ("human", "Hello, how are you?"), ("ai", "I'm doing well, thanks!"), @@ -82,29 +101,13 @@ class StructuredPrompt(ChatPromptTemplate): (4) 2-tuple of (message class, template), (5) a string which is shorthand for ("human", template); e.g., "{user_input}" schema: a dictionary representation of function call, or a Pydantic model. + kwargs: Any additional kwargs to pass through to + ``ChatModel.with_structured_output(schema, **kwargs)``. Returns: a structured prompt template """ - _messages = [_convert_to_message(message) for message in messages] - - # Automatically infer input variables from messages - input_vars: Set[str] = set() - partial_vars: Dict[str, Any] = {} - for _message in _messages: - if isinstance(_message, MessagesPlaceholder) and _message.optional: - partial_vars[_message.variable_name] = [] - elif isinstance( - _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate) - ): - input_vars.update(_message.input_variables) - - return cls( - input_variables=sorted(input_vars), - messages=_messages, - partial_variables=partial_vars, - schema_=schema, - ) + return cls(messages, schema, **kwargs) def __or__( self, @@ -115,27 +118,16 @@ class StructuredPrompt(ChatPromptTemplate): Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], ) -> RunnableSerializable[Dict, Other]: - if isinstance(other, BaseLanguageModel) or hasattr( - other, "with_structured_output" - ): - try: - return RunnableSequence( - self, other.with_structured_output(self.schema_) - ) - except NotImplementedError as e: - raise NotImplementedError( - "Structured prompts must be piped to a language model that " - "implements with_structured_output." - ) from e - else: - raise NotImplementedError( - "Structured prompts must be piped to a language model that " - "implements with_structured_output." - ) + return self.pipe(other) def pipe( self, - *others: Union[Runnable[Any, Other], Callable[[Any], Other]], + *others: Union[ + Runnable[Any, Other], + Callable[[Any], Other], + Callable[[Iterator[Any]], Iterator[Other]], + Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], + ], name: Optional[str] = None, ) -> RunnableSerializable[Dict, Other]: """Pipe the structured prompt to a language model. @@ -158,7 +150,9 @@ class StructuredPrompt(ChatPromptTemplate): ): return RunnableSequence( self, - others[0].with_structured_output(self.schema_), + others[0].with_structured_output( + self.schema_, **self.structured_output_kwargs + ), *others[1:], name=name, ) diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index 17fd52b79f..923a69e97d 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -12,13 +12,13 @@ from langchain_core.utils.pydantic import is_basemodel_subclass def _fake_runnable( - schema: Union[Dict, Type[BaseModel]], _: Any + input: Any, *, schema: Union[Dict, Type[BaseModel]], value: Any = 42, **_: Any ) -> Union[BaseModel, Dict]: if isclass(schema) and is_basemodel_subclass(schema): - return schema(name="yo", value=42) + return schema(name="yo", value=value) else: params = cast(Dict, schema)["parameters"] - return {k: 1 for k, v in params.items()} + return {k: 1 if k != "value" else value for k, v in params.items()} class FakeStructuredChatModel(FakeListChatModel): @@ -27,7 +27,7 @@ class FakeStructuredChatModel(FakeListChatModel): def with_structured_output( self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any ) -> Runnable: - return RunnableLambda(partial(_fake_runnable, schema)) + return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs)) @property def _llm_type(self) -> str: @@ -39,7 +39,7 @@ def test_structured_prompt_pydantic() -> None: name: str value: int - prompt = StructuredPrompt.from_messages_and_schema( + prompt = StructuredPrompt( [ ("human", "I'm very structured, how about you?"), ], @@ -54,7 +54,7 @@ def test_structured_prompt_pydantic() -> None: def test_structured_prompt_dict() -> None: - prompt = StructuredPrompt.from_messages_and_schema( + prompt = StructuredPrompt( [ ("human", "I'm very structured, how about you?"), ], @@ -72,10 +72,47 @@ def test_structured_prompt_dict() -> None: chain = prompt | model - assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1} + assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} assert loads(dumps(prompt)) == prompt chain = loads(dumps(prompt)) | model - assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1} + assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} + + +def test_structured_prompt_kwargs() -> None: + prompt = StructuredPrompt( + [ + ("human", "I'm very structured, how about you?"), + ], + { + "name": "yo", + "description": "a structured output", + "parameters": { + "name": {"type": "string"}, + "value": {"type": "integer"}, + }, + }, + value=7, + ) + model = FakeStructuredChatModel(responses=[]) + chain = prompt | model + assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} + assert loads(dumps(prompt)) == prompt + chain = loads(dumps(prompt)) | model + assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} + + class OutputSchema(BaseModel): + name: str + value: int + + prompt = StructuredPrompt( + [("human", "I'm very structured, how about you?")], OutputSchema, value=7 + ) + + model = FakeStructuredChatModel(responses=[]) + + chain = prompt | model + + assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)