mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[patch]: support additional kwargs on StructuredPrompt (#25645)
This commit is contained in:
parent
51dae57357
commit
933bc0d6ff
@ -7,7 +7,6 @@ from typing import (
|
|||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
@ -15,20 +14,17 @@ from typing import (
|
|||||||
from langchain_core._api.beta_decorator import beta
|
from langchain_core._api.beta_decorator import beta
|
||||||
from langchain_core.language_models.base import BaseLanguageModel
|
from langchain_core.language_models.base import BaseLanguageModel
|
||||||
from langchain_core.prompts.chat import (
|
from langchain_core.prompts.chat import (
|
||||||
BaseChatPromptTemplate,
|
|
||||||
BaseMessagePromptTemplate,
|
|
||||||
ChatPromptTemplate,
|
ChatPromptTemplate,
|
||||||
MessageLikeRepresentation,
|
MessageLikeRepresentation,
|
||||||
MessagesPlaceholder,
|
|
||||||
_convert_to_message,
|
|
||||||
)
|
)
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||||
from langchain_core.runnables.base import (
|
from langchain_core.runnables.base import (
|
||||||
Other,
|
Other,
|
||||||
Runnable,
|
Runnable,
|
||||||
RunnableSequence,
|
RunnableSequence,
|
||||||
RunnableSerializable,
|
RunnableSerializable,
|
||||||
)
|
)
|
||||||
|
from langchain_core.utils import get_pydantic_field_names
|
||||||
|
|
||||||
|
|
||||||
@beta()
|
@beta()
|
||||||
@ -37,6 +33,26 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
|
|
||||||
schema_: Union[Dict, Type[BaseModel]]
|
schema_: Union[Dict, Type[BaseModel]]
|
||||||
"""Schema for the structured prompt."""
|
"""Schema for the structured prompt."""
|
||||||
|
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
messages: Sequence[MessageLikeRepresentation],
|
||||||
|
schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
|
||||||
|
*,
|
||||||
|
structured_output_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
schema_ = schema_ or kwargs.pop("schema")
|
||||||
|
structured_output_kwargs = structured_output_kwargs or {}
|
||||||
|
for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
|
||||||
|
structured_output_kwargs[k] = kwargs.pop(k)
|
||||||
|
super().__init__(
|
||||||
|
messages=messages,
|
||||||
|
schema_=schema_,
|
||||||
|
structured_output_kwargs=structured_output_kwargs,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_lc_namespace(cls) -> List[str]:
|
def get_lc_namespace(cls) -> List[str]:
|
||||||
@ -52,6 +68,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
cls,
|
cls,
|
||||||
messages: Sequence[MessageLikeRepresentation],
|
messages: Sequence[MessageLikeRepresentation],
|
||||||
schema: Union[Dict, Type[BaseModel]],
|
schema: Union[Dict, Type[BaseModel]],
|
||||||
|
**kwargs: Any,
|
||||||
) -> ChatPromptTemplate:
|
) -> ChatPromptTemplate:
|
||||||
"""Create a chat prompt template from a variety of message formats.
|
"""Create a chat prompt template from a variety of message formats.
|
||||||
|
|
||||||
@ -61,11 +78,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_core.prompts import StructuredPrompt
|
||||||
|
|
||||||
class OutputSchema(BaseModel):
|
class OutputSchema(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
value: int
|
value: int
|
||||||
|
|
||||||
template = ChatPromptTemplate.from_messages(
|
template = StructuredPrompt(
|
||||||
[
|
[
|
||||||
("human", "Hello, how are you?"),
|
("human", "Hello, how are you?"),
|
||||||
("ai", "I'm doing well, thanks!"),
|
("ai", "I'm doing well, thanks!"),
|
||||||
@ -82,29 +101,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
(4) 2-tuple of (message class, template), (5) a string which is
|
(4) 2-tuple of (message class, template), (5) a string which is
|
||||||
shorthand for ("human", template); e.g., "{user_input}"
|
shorthand for ("human", template); e.g., "{user_input}"
|
||||||
schema: a dictionary representation of function call, or a Pydantic model.
|
schema: a dictionary representation of function call, or a Pydantic model.
|
||||||
|
kwargs: Any additional kwargs to pass through to
|
||||||
|
``ChatModel.with_structured_output(schema, **kwargs)``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a structured prompt template
|
a structured prompt template
|
||||||
"""
|
"""
|
||||||
_messages = [_convert_to_message(message) for message in messages]
|
return cls(messages, schema, **kwargs)
|
||||||
|
|
||||||
# Automatically infer input variables from messages
|
|
||||||
input_vars: Set[str] = set()
|
|
||||||
partial_vars: Dict[str, Any] = {}
|
|
||||||
for _message in _messages:
|
|
||||||
if isinstance(_message, MessagesPlaceholder) and _message.optional:
|
|
||||||
partial_vars[_message.variable_name] = []
|
|
||||||
elif isinstance(
|
|
||||||
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
|
|
||||||
):
|
|
||||||
input_vars.update(_message.input_variables)
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
input_variables=sorted(input_vars),
|
|
||||||
messages=_messages,
|
|
||||||
partial_variables=partial_vars,
|
|
||||||
schema_=schema,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __or__(
|
def __or__(
|
||||||
self,
|
self,
|
||||||
@ -115,27 +118,16 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||||
],
|
],
|
||||||
) -> RunnableSerializable[Dict, Other]:
|
) -> RunnableSerializable[Dict, Other]:
|
||||||
if isinstance(other, BaseLanguageModel) or hasattr(
|
return self.pipe(other)
|
||||||
other, "with_structured_output"
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
return RunnableSequence(
|
|
||||||
self, other.with_structured_output(self.schema_)
|
|
||||||
)
|
|
||||||
except NotImplementedError as e:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Structured prompts must be piped to a language model that "
|
|
||||||
"implements with_structured_output."
|
|
||||||
) from e
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Structured prompts must be piped to a language model that "
|
|
||||||
"implements with_structured_output."
|
|
||||||
)
|
|
||||||
|
|
||||||
def pipe(
|
def pipe(
|
||||||
self,
|
self,
|
||||||
*others: Union[Runnable[Any, Other], Callable[[Any], Other]],
|
*others: Union[
|
||||||
|
Runnable[Any, Other],
|
||||||
|
Callable[[Any], Other],
|
||||||
|
Callable[[Iterator[Any]], Iterator[Other]],
|
||||||
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||||
|
],
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
) -> RunnableSerializable[Dict, Other]:
|
) -> RunnableSerializable[Dict, Other]:
|
||||||
"""Pipe the structured prompt to a language model.
|
"""Pipe the structured prompt to a language model.
|
||||||
@ -158,7 +150,9 @@ class StructuredPrompt(ChatPromptTemplate):
|
|||||||
):
|
):
|
||||||
return RunnableSequence(
|
return RunnableSequence(
|
||||||
self,
|
self,
|
||||||
others[0].with_structured_output(self.schema_),
|
others[0].with_structured_output(
|
||||||
|
self.schema_, **self.structured_output_kwargs
|
||||||
|
),
|
||||||
*others[1:],
|
*others[1:],
|
||||||
name=name,
|
name=name,
|
||||||
)
|
)
|
||||||
|
@ -12,13 +12,13 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
|
|||||||
|
|
||||||
|
|
||||||
def _fake_runnable(
|
def _fake_runnable(
|
||||||
schema: Union[Dict, Type[BaseModel]], _: Any
|
input: Any, *, schema: Union[Dict, Type[BaseModel]], value: Any = 42, **_: Any
|
||||||
) -> Union[BaseModel, Dict]:
|
) -> Union[BaseModel, Dict]:
|
||||||
if isclass(schema) and is_basemodel_subclass(schema):
|
if isclass(schema) and is_basemodel_subclass(schema):
|
||||||
return schema(name="yo", value=42)
|
return schema(name="yo", value=value)
|
||||||
else:
|
else:
|
||||||
params = cast(Dict, schema)["parameters"]
|
params = cast(Dict, schema)["parameters"]
|
||||||
return {k: 1 for k, v in params.items()}
|
return {k: 1 if k != "value" else value for k, v in params.items()}
|
||||||
|
|
||||||
|
|
||||||
class FakeStructuredChatModel(FakeListChatModel):
|
class FakeStructuredChatModel(FakeListChatModel):
|
||||||
@ -27,7 +27,7 @@ class FakeStructuredChatModel(FakeListChatModel):
|
|||||||
def with_structured_output(
|
def with_structured_output(
|
||||||
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||||
) -> Runnable:
|
) -> Runnable:
|
||||||
return RunnableLambda(partial(_fake_runnable, schema))
|
return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
@ -39,7 +39,7 @@ def test_structured_prompt_pydantic() -> None:
|
|||||||
name: str
|
name: str
|
||||||
value: int
|
value: int
|
||||||
|
|
||||||
prompt = StructuredPrompt.from_messages_and_schema(
|
prompt = StructuredPrompt(
|
||||||
[
|
[
|
||||||
("human", "I'm very structured, how about you?"),
|
("human", "I'm very structured, how about you?"),
|
||||||
],
|
],
|
||||||
@ -54,7 +54,7 @@ def test_structured_prompt_pydantic() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_structured_prompt_dict() -> None:
|
def test_structured_prompt_dict() -> None:
|
||||||
prompt = StructuredPrompt.from_messages_and_schema(
|
prompt = StructuredPrompt(
|
||||||
[
|
[
|
||||||
("human", "I'm very structured, how about you?"),
|
("human", "I'm very structured, how about you?"),
|
||||||
],
|
],
|
||||||
@ -72,10 +72,47 @@ def test_structured_prompt_dict() -> None:
|
|||||||
|
|
||||||
chain = prompt | model
|
chain = prompt | model
|
||||||
|
|
||||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
|
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
|
||||||
|
|
||||||
assert loads(dumps(prompt)) == prompt
|
assert loads(dumps(prompt)) == prompt
|
||||||
|
|
||||||
chain = loads(dumps(prompt)) | model
|
chain = loads(dumps(prompt)) | model
|
||||||
|
|
||||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
|
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
|
||||||
|
|
||||||
|
|
||||||
|
def test_structured_prompt_kwargs() -> None:
|
||||||
|
prompt = StructuredPrompt(
|
||||||
|
[
|
||||||
|
("human", "I'm very structured, how about you?"),
|
||||||
|
],
|
||||||
|
{
|
||||||
|
"name": "yo",
|
||||||
|
"description": "a structured output",
|
||||||
|
"parameters": {
|
||||||
|
"name": {"type": "string"},
|
||||||
|
"value": {"type": "integer"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
value=7,
|
||||||
|
)
|
||||||
|
model = FakeStructuredChatModel(responses=[])
|
||||||
|
chain = prompt | model
|
||||||
|
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
|
||||||
|
assert loads(dumps(prompt)) == prompt
|
||||||
|
chain = loads(dumps(prompt)) | model
|
||||||
|
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
|
||||||
|
|
||||||
|
class OutputSchema(BaseModel):
|
||||||
|
name: str
|
||||||
|
value: int
|
||||||
|
|
||||||
|
prompt = StructuredPrompt(
|
||||||
|
[("human", "I'm very structured, how about you?")], OutputSchema, value=7
|
||||||
|
)
|
||||||
|
|
||||||
|
model = FakeStructuredChatModel(responses=[])
|
||||||
|
|
||||||
|
chain = prompt | model
|
||||||
|
|
||||||
|
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
|
||||||
|
Loading…
Reference in New Issue
Block a user