mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
core[patch]: support additional kwargs on StructuredPrompt (#25645)
This commit is contained in:
parent
51dae57357
commit
933bc0d6ff
@ -7,7 +7,6 @@ from typing import (
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
@ -15,20 +14,17 @@ from typing import (
|
||||
from langchain_core._api.beta_decorator import beta
|
||||
from langchain_core.language_models.base import BaseLanguageModel
|
||||
from langchain_core.prompts.chat import (
|
||||
BaseChatPromptTemplate,
|
||||
BaseMessagePromptTemplate,
|
||||
ChatPromptTemplate,
|
||||
MessageLikeRepresentation,
|
||||
MessagesPlaceholder,
|
||||
_convert_to_message,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
from langchain_core.runnables.base import (
|
||||
Other,
|
||||
Runnable,
|
||||
RunnableSequence,
|
||||
RunnableSerializable,
|
||||
)
|
||||
from langchain_core.utils import get_pydantic_field_names
|
||||
|
||||
|
||||
@beta()
|
||||
@ -37,6 +33,26 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
|
||||
schema_: Union[Dict, Type[BaseModel]]
|
||||
"""Schema for the structured prompt."""
|
||||
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
|
||||
*,
|
||||
structured_output_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
schema_ = schema_ or kwargs.pop("schema")
|
||||
structured_output_kwargs = structured_output_kwargs or {}
|
||||
for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
|
||||
structured_output_kwargs[k] = kwargs.pop(k)
|
||||
super().__init__(
|
||||
messages=messages,
|
||||
schema_=schema_,
|
||||
structured_output_kwargs=structured_output_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_lc_namespace(cls) -> List[str]:
|
||||
@ -52,6 +68,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
cls,
|
||||
messages: Sequence[MessageLikeRepresentation],
|
||||
schema: Union[Dict, Type[BaseModel]],
|
||||
**kwargs: Any,
|
||||
) -> ChatPromptTemplate:
|
||||
"""Create a chat prompt template from a variety of message formats.
|
||||
|
||||
@ -61,11 +78,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_core.prompts import StructuredPrompt
|
||||
|
||||
class OutputSchema(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
template = ChatPromptTemplate.from_messages(
|
||||
template = StructuredPrompt(
|
||||
[
|
||||
("human", "Hello, how are you?"),
|
||||
("ai", "I'm doing well, thanks!"),
|
||||
@ -82,29 +101,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
(4) 2-tuple of (message class, template), (5) a string which is
|
||||
shorthand for ("human", template); e.g., "{user_input}"
|
||||
schema: a dictionary representation of function call, or a Pydantic model.
|
||||
kwargs: Any additional kwargs to pass through to
|
||||
``ChatModel.with_structured_output(schema, **kwargs)``.
|
||||
|
||||
Returns:
|
||||
a structured prompt template
|
||||
"""
|
||||
_messages = [_convert_to_message(message) for message in messages]
|
||||
|
||||
# Automatically infer input variables from messages
|
||||
input_vars: Set[str] = set()
|
||||
partial_vars: Dict[str, Any] = {}
|
||||
for _message in _messages:
|
||||
if isinstance(_message, MessagesPlaceholder) and _message.optional:
|
||||
partial_vars[_message.variable_name] = []
|
||||
elif isinstance(
|
||||
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
|
||||
):
|
||||
input_vars.update(_message.input_variables)
|
||||
|
||||
return cls(
|
||||
input_variables=sorted(input_vars),
|
||||
messages=_messages,
|
||||
partial_variables=partial_vars,
|
||||
schema_=schema,
|
||||
)
|
||||
return cls(messages, schema, **kwargs)
|
||||
|
||||
def __or__(
|
||||
self,
|
||||
@ -115,27 +118,16 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||
],
|
||||
) -> RunnableSerializable[Dict, Other]:
|
||||
if isinstance(other, BaseLanguageModel) or hasattr(
|
||||
other, "with_structured_output"
|
||||
):
|
||||
try:
|
||||
return RunnableSequence(
|
||||
self, other.with_structured_output(self.schema_)
|
||||
)
|
||||
except NotImplementedError as e:
|
||||
raise NotImplementedError(
|
||||
"Structured prompts must be piped to a language model that "
|
||||
"implements with_structured_output."
|
||||
) from e
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Structured prompts must be piped to a language model that "
|
||||
"implements with_structured_output."
|
||||
)
|
||||
return self.pipe(other)
|
||||
|
||||
def pipe(
|
||||
self,
|
||||
*others: Union[Runnable[Any, Other], Callable[[Any], Other]],
|
||||
*others: Union[
|
||||
Runnable[Any, Other],
|
||||
Callable[[Any], Other],
|
||||
Callable[[Iterator[Any]], Iterator[Other]],
|
||||
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
||||
],
|
||||
name: Optional[str] = None,
|
||||
) -> RunnableSerializable[Dict, Other]:
|
||||
"""Pipe the structured prompt to a language model.
|
||||
@ -158,7 +150,9 @@ class StructuredPrompt(ChatPromptTemplate):
|
||||
):
|
||||
return RunnableSequence(
|
||||
self,
|
||||
others[0].with_structured_output(self.schema_),
|
||||
others[0].with_structured_output(
|
||||
self.schema_, **self.structured_output_kwargs
|
||||
),
|
||||
*others[1:],
|
||||
name=name,
|
||||
)
|
||||
|
@ -12,13 +12,13 @@ from langchain_core.utils.pydantic import is_basemodel_subclass
|
||||
|
||||
|
||||
def _fake_runnable(
|
||||
schema: Union[Dict, Type[BaseModel]], _: Any
|
||||
input: Any, *, schema: Union[Dict, Type[BaseModel]], value: Any = 42, **_: Any
|
||||
) -> Union[BaseModel, Dict]:
|
||||
if isclass(schema) and is_basemodel_subclass(schema):
|
||||
return schema(name="yo", value=42)
|
||||
return schema(name="yo", value=value)
|
||||
else:
|
||||
params = cast(Dict, schema)["parameters"]
|
||||
return {k: 1 for k, v in params.items()}
|
||||
return {k: 1 if k != "value" else value for k, v in params.items()}
|
||||
|
||||
|
||||
class FakeStructuredChatModel(FakeListChatModel):
|
||||
@ -27,7 +27,7 @@ class FakeStructuredChatModel(FakeListChatModel):
|
||||
def with_structured_output(
|
||||
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
|
||||
) -> Runnable:
|
||||
return RunnableLambda(partial(_fake_runnable, schema))
|
||||
return RunnableLambda(partial(_fake_runnable, schema=schema, **kwargs))
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
@ -39,7 +39,7 @@ def test_structured_prompt_pydantic() -> None:
|
||||
name: str
|
||||
value: int
|
||||
|
||||
prompt = StructuredPrompt.from_messages_and_schema(
|
||||
prompt = StructuredPrompt(
|
||||
[
|
||||
("human", "I'm very structured, how about you?"),
|
||||
],
|
||||
@ -54,7 +54,7 @@ def test_structured_prompt_pydantic() -> None:
|
||||
|
||||
|
||||
def test_structured_prompt_dict() -> None:
|
||||
prompt = StructuredPrompt.from_messages_and_schema(
|
||||
prompt = StructuredPrompt(
|
||||
[
|
||||
("human", "I'm very structured, how about you?"),
|
||||
],
|
||||
@ -72,10 +72,47 @@ def test_structured_prompt_dict() -> None:
|
||||
|
||||
chain = prompt | model
|
||||
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
|
||||
|
||||
assert loads(dumps(prompt)) == prompt
|
||||
|
||||
chain = loads(dumps(prompt)) | model
|
||||
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
|
||||
|
||||
|
||||
def test_structured_prompt_kwargs() -> None:
|
||||
prompt = StructuredPrompt(
|
||||
[
|
||||
("human", "I'm very structured, how about you?"),
|
||||
],
|
||||
{
|
||||
"name": "yo",
|
||||
"description": "a structured output",
|
||||
"parameters": {
|
||||
"name": {"type": "string"},
|
||||
"value": {"type": "integer"},
|
||||
},
|
||||
},
|
||||
value=7,
|
||||
)
|
||||
model = FakeStructuredChatModel(responses=[])
|
||||
chain = prompt | model
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
|
||||
assert loads(dumps(prompt)) == prompt
|
||||
chain = loads(dumps(prompt)) | model
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
|
||||
|
||||
class OutputSchema(BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
prompt = StructuredPrompt(
|
||||
[("human", "I'm very structured, how about you?")], OutputSchema, value=7
|
||||
)
|
||||
|
||||
model = FakeStructuredChatModel(responses=[])
|
||||
|
||||
chain = prompt | model
|
||||
|
||||
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
|
||||
|
Loading…
Reference in New Issue
Block a user