|
|
|
@ -7,7 +7,6 @@ from typing import (
|
|
|
|
|
Mapping,
|
|
|
|
|
Optional,
|
|
|
|
|
Sequence,
|
|
|
|
|
Set,
|
|
|
|
|
Type,
|
|
|
|
|
Union,
|
|
|
|
|
)
|
|
|
|
@ -15,20 +14,17 @@ from typing import (
|
|
|
|
|
from langchain_core._api.beta_decorator import beta
|
|
|
|
|
from langchain_core.language_models.base import BaseLanguageModel
|
|
|
|
|
from langchain_core.prompts.chat import (
|
|
|
|
|
BaseChatPromptTemplate,
|
|
|
|
|
BaseMessagePromptTemplate,
|
|
|
|
|
ChatPromptTemplate,
|
|
|
|
|
MessageLikeRepresentation,
|
|
|
|
|
MessagesPlaceholder,
|
|
|
|
|
_convert_to_message,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel
|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
|
from langchain_core.runnables.base import (
|
|
|
|
|
Other,
|
|
|
|
|
Runnable,
|
|
|
|
|
RunnableSequence,
|
|
|
|
|
RunnableSerializable,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.utils import get_pydantic_field_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@beta()
|
|
|
|
@ -37,6 +33,26 @@ class StructuredPrompt(ChatPromptTemplate):
|
|
|
|
|
|
|
|
|
|
schema_: Union[Dict, Type[BaseModel]]
|
|
|
|
|
"""Schema for the structured prompt."""
|
|
|
|
|
structured_output_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
messages: Sequence[MessageLikeRepresentation],
|
|
|
|
|
schema_: Optional[Union[Dict, Type[BaseModel]]] = None,
|
|
|
|
|
*,
|
|
|
|
|
structured_output_kwargs: Optional[Dict[str, Any]] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> None:
|
|
|
|
|
schema_ = schema_ or kwargs.pop("schema")
|
|
|
|
|
structured_output_kwargs = structured_output_kwargs or {}
|
|
|
|
|
for k in set(kwargs).difference(get_pydantic_field_names(self.__class__)):
|
|
|
|
|
structured_output_kwargs[k] = kwargs.pop(k)
|
|
|
|
|
super().__init__(
|
|
|
|
|
messages=messages,
|
|
|
|
|
schema_=schema_,
|
|
|
|
|
structured_output_kwargs=structured_output_kwargs,
|
|
|
|
|
**kwargs,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_lc_namespace(cls) -> List[str]:
|
|
|
|
@ -52,6 +68,7 @@ class StructuredPrompt(ChatPromptTemplate):
|
|
|
|
|
cls,
|
|
|
|
|
messages: Sequence[MessageLikeRepresentation],
|
|
|
|
|
schema: Union[Dict, Type[BaseModel]],
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> ChatPromptTemplate:
|
|
|
|
|
"""Create a chat prompt template from a variety of message formats.
|
|
|
|
|
|
|
|
|
@ -61,11 +78,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
from langchain_core.prompts import StructuredPrompt
|
|
|
|
|
|
|
|
|
|
class OutputSchema(BaseModel):
|
|
|
|
|
name: str
|
|
|
|
|
value: int
|
|
|
|
|
|
|
|
|
|
template = ChatPromptTemplate.from_messages(
|
|
|
|
|
template = StructuredPrompt(
|
|
|
|
|
[
|
|
|
|
|
("human", "Hello, how are you?"),
|
|
|
|
|
("ai", "I'm doing well, thanks!"),
|
|
|
|
@ -82,29 +101,13 @@ class StructuredPrompt(ChatPromptTemplate):
|
|
|
|
|
(4) 2-tuple of (message class, template), (5) a string which is
|
|
|
|
|
shorthand for ("human", template); e.g., "{user_input}"
|
|
|
|
|
schema: a dictionary representation of function call, or a Pydantic model.
|
|
|
|
|
kwargs: Any additional kwargs to pass through to
|
|
|
|
|
``ChatModel.with_structured_output(schema, **kwargs)``.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
a structured prompt template
|
|
|
|
|
"""
|
|
|
|
|
_messages = [_convert_to_message(message) for message in messages]
|
|
|
|
|
|
|
|
|
|
# Automatically infer input variables from messages
|
|
|
|
|
input_vars: Set[str] = set()
|
|
|
|
|
partial_vars: Dict[str, Any] = {}
|
|
|
|
|
for _message in _messages:
|
|
|
|
|
if isinstance(_message, MessagesPlaceholder) and _message.optional:
|
|
|
|
|
partial_vars[_message.variable_name] = []
|
|
|
|
|
elif isinstance(
|
|
|
|
|
_message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
|
|
|
|
|
):
|
|
|
|
|
input_vars.update(_message.input_variables)
|
|
|
|
|
|
|
|
|
|
return cls(
|
|
|
|
|
input_variables=sorted(input_vars),
|
|
|
|
|
messages=_messages,
|
|
|
|
|
partial_variables=partial_vars,
|
|
|
|
|
schema_=schema,
|
|
|
|
|
)
|
|
|
|
|
return cls(messages, schema, **kwargs)
|
|
|
|
|
|
|
|
|
|
def __or__(
|
|
|
|
|
self,
|
|
|
|
@ -115,27 +118,16 @@ class StructuredPrompt(ChatPromptTemplate):
|
|
|
|
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
|
|
|
|
],
|
|
|
|
|
) -> RunnableSerializable[Dict, Other]:
|
|
|
|
|
if isinstance(other, BaseLanguageModel) or hasattr(
|
|
|
|
|
other, "with_structured_output"
|
|
|
|
|
):
|
|
|
|
|
try:
|
|
|
|
|
return RunnableSequence(
|
|
|
|
|
self, other.with_structured_output(self.schema_)
|
|
|
|
|
)
|
|
|
|
|
except NotImplementedError as e:
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"Structured prompts must be piped to a language model that "
|
|
|
|
|
"implements with_structured_output."
|
|
|
|
|
) from e
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"Structured prompts must be piped to a language model that "
|
|
|
|
|
"implements with_structured_output."
|
|
|
|
|
)
|
|
|
|
|
return self.pipe(other)
|
|
|
|
|
|
|
|
|
|
def pipe(
|
|
|
|
|
self,
|
|
|
|
|
*others: Union[Runnable[Any, Other], Callable[[Any], Other]],
|
|
|
|
|
*others: Union[
|
|
|
|
|
Runnable[Any, Other],
|
|
|
|
|
Callable[[Any], Other],
|
|
|
|
|
Callable[[Iterator[Any]], Iterator[Other]],
|
|
|
|
|
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
|
|
|
|
|
],
|
|
|
|
|
name: Optional[str] = None,
|
|
|
|
|
) -> RunnableSerializable[Dict, Other]:
|
|
|
|
|
"""Pipe the structured prompt to a language model.
|
|
|
|
@ -158,7 +150,9 @@ class StructuredPrompt(ChatPromptTemplate):
|
|
|
|
|
):
|
|
|
|
|
return RunnableSequence(
|
|
|
|
|
self,
|
|
|
|
|
others[0].with_structured_output(self.schema_),
|
|
|
|
|
others[0].with_structured_output(
|
|
|
|
|
self.schema_, **self.structured_output_kwargs
|
|
|
|
|
),
|
|
|
|
|
*others[1:],
|
|
|
|
|
name=name,
|
|
|
|
|
)
|
|
|
|
|