mirror of https://github.com/hwchase17/langchain
core[minor], openai[minor], langchain[patch]: BaseLanguageModel.with_structured_output #17302)
```python class Foo(BaseModel): bar: str structured_llm = ChatOpenAI().with_structured_output(Foo) ``` --------- Co-authored-by: Erick Friis <erick@langchain.dev>pull/17296/head^2
parent
f685d2f50c
commit
b5f8cf9509
@ -0,0 +1,62 @@
|
||||
import json
|
||||
from typing import Any, List, Type
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
|
||||
class PydanticOutputParser(JsonOutputParser):
|
||||
"""Parse an output using a pydantic model."""
|
||||
|
||||
pydantic_object: Type[BaseModel]
|
||||
"""The pydantic model to parse.
|
||||
|
||||
Attention: To avoid potential compatibility issues, it's recommended to use
|
||||
pydantic <2 or leverage the v1 namespace in pydantic >= 2.
|
||||
"""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
json_object = super().parse_result(result)
|
||||
try:
|
||||
return self.pydantic_object.parse_obj(json_object)
|
||||
except ValidationError as e:
|
||||
name = self.pydantic_object.__name__
|
||||
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
|
||||
raise OutputParserException(msg, llm_output=json_object)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
# Copy schema to avoid altering original Pydantic schema.
|
||||
schema = {k: v for k, v in self.pydantic_object.schema().items()}
|
||||
|
||||
# Remove extraneous fields.
|
||||
reduced_schema = schema
|
||||
if "title" in reduced_schema:
|
||||
del reduced_schema["title"]
|
||||
if "type" in reduced_schema:
|
||||
del reduced_schema["type"]
|
||||
# Ensure json in context is well-formed with double quotes.
|
||||
schema_str = json.dumps(reduced_schema)
|
||||
|
||||
return _PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "pydantic"
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[BaseModel]:
|
||||
"""Return the pydantic model."""
|
||||
return self.pydantic_object
|
||||
|
||||
|
||||
_PYDANTIC_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
|
||||
|
||||
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
|
||||
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
|
||||
|
||||
Here is the output schema:
|
||||
```
|
||||
{schema}
|
||||
```""" # noqa: E501
|
@ -1,53 +1,3 @@
|
||||
import json
|
||||
from typing import Any, List, Type
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
from langchain_core.outputs import Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel, ValidationError
|
||||
|
||||
from langchain.output_parsers.format_instructions import PYDANTIC_FORMAT_INSTRUCTIONS
|
||||
|
||||
|
||||
class PydanticOutputParser(JsonOutputParser):
|
||||
"""Parse an output using a pydantic model."""
|
||||
|
||||
pydantic_object: Type[BaseModel]
|
||||
"""The pydantic model to parse.
|
||||
|
||||
Attention: To avoid potential compatibility issues, it's recommended to use
|
||||
pydantic <2 or leverage the v1 namespace in pydantic >= 2.
|
||||
"""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
json_object = super().parse_result(result)
|
||||
try:
|
||||
return self.pydantic_object.parse_obj(json_object)
|
||||
except ValidationError as e:
|
||||
name = self.pydantic_object.__name__
|
||||
msg = f"Failed to parse {name} from completion {json_object}. Got: {e}"
|
||||
raise OutputParserException(msg, llm_output=json_object)
|
||||
|
||||
def get_format_instructions(self) -> str:
|
||||
# Copy schema to avoid altering original Pydantic schema.
|
||||
schema = {k: v for k, v in self.pydantic_object.schema().items()}
|
||||
|
||||
# Remove extraneous fields.
|
||||
reduced_schema = schema
|
||||
if "title" in reduced_schema:
|
||||
del reduced_schema["title"]
|
||||
if "type" in reduced_schema:
|
||||
del reduced_schema["type"]
|
||||
# Ensure json in context is well-formed with double quotes.
|
||||
schema_str = json.dumps(reduced_schema)
|
||||
|
||||
return PYDANTIC_FORMAT_INSTRUCTIONS.format(schema=schema_str)
|
||||
|
||||
@property
|
||||
def _type(self) -> str:
|
||||
return "pydantic"
|
||||
|
||||
@property
|
||||
def OutputType(self) -> Type[BaseModel]:
|
||||
"""Return the pydantic model."""
|
||||
return self.pydantic_object
|
||||
__all__ = ["PydanticOutputParser"]
|
||||
|
@ -0,0 +1,11 @@
|
||||
from langchain_openai.output_parsers.tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
JsonOutputToolsParser,
|
||||
PydanticToolsParser,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JsonOutputToolsParser",
|
||||
"JsonOutputKeyToolsParser",
|
||||
"PydanticToolsParser",
|
||||
]
|
@ -0,0 +1,123 @@
|
||||
import copy
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, List, Type
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseGenerationOutputParser
|
||||
from langchain_core.output_parsers.json import parse_partial_json
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
|
||||
|
||||
class JsonOutputToolsParser(BaseGenerationOutputParser[Any]):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
strict: bool = False
|
||||
"""Whether to allow non-JSON-compliant strings.
|
||||
|
||||
See: https://docs.python.org/3/library/json.html#encoders-and-decoders
|
||||
|
||||
Useful when the parsed output may include unicode characters or new lines.
|
||||
"""
|
||||
return_id: bool = False
|
||||
"""Whether to return the tool call id."""
|
||||
first_tool_only: bool = False
|
||||
"""Whether to return only the first tool call.
|
||||
|
||||
If False, the result will be a list of tool calls, or an empty list
|
||||
if no tool calls are found.
|
||||
|
||||
If true, and multiple tool calls are found, only the first one will be returned,
|
||||
and the other tool calls will be ignored.
|
||||
If no tool calls are found, None will be returned.
|
||||
"""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
generation = result[0]
|
||||
if not isinstance(generation, ChatGeneration):
|
||||
raise OutputParserException(
|
||||
"This output parser can only be used with a chat generation."
|
||||
)
|
||||
message = generation.message
|
||||
try:
|
||||
tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
|
||||
except KeyError:
|
||||
return []
|
||||
|
||||
final_tools = []
|
||||
exceptions = []
|
||||
for tool_call in tool_calls:
|
||||
if "function" not in tool_call:
|
||||
continue
|
||||
try:
|
||||
if partial:
|
||||
function_args = parse_partial_json(
|
||||
tool_call["function"]["arguments"], strict=self.strict
|
||||
)
|
||||
else:
|
||||
function_args = json.loads(
|
||||
tool_call["function"]["arguments"], strict=self.strict
|
||||
)
|
||||
except JSONDecodeError as e:
|
||||
exceptions.append(
|
||||
f"Function {tool_call['function']['name']} arguments:\n\n"
|
||||
f"{tool_call['function']['arguments']}\n\nare not valid JSON. "
|
||||
f"Received JSONDecodeError {e}"
|
||||
)
|
||||
continue
|
||||
parsed = {
|
||||
"type": tool_call["function"]["name"],
|
||||
"args": function_args,
|
||||
}
|
||||
if self.return_id:
|
||||
parsed["id"] = tool_call["id"]
|
||||
final_tools.append(parsed)
|
||||
if exceptions:
|
||||
raise OutputParserException("\n\n".join(exceptions))
|
||||
if self.first_tool_only:
|
||||
return final_tools[0] if final_tools else None
|
||||
return final_tools
|
||||
|
||||
|
||||
class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
key_name: str
|
||||
"""The type of tools to return."""
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
parsed_result = super().parse_result(result, partial=partial)
|
||||
if self.first_tool_only:
|
||||
single_result = (
|
||||
parsed_result
|
||||
if parsed_result and parsed_result["type"] == self.key_name
|
||||
else None
|
||||
)
|
||||
if self.return_id:
|
||||
return single_result
|
||||
elif single_result:
|
||||
return single_result["args"]
|
||||
else:
|
||||
return None
|
||||
parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
|
||||
if not self.return_id:
|
||||
parsed_result = [res["args"] for res in parsed_result]
|
||||
return parsed_result
|
||||
|
||||
|
||||
class PydanticToolsParser(JsonOutputToolsParser):
|
||||
"""Parse tools from OpenAI response."""
|
||||
|
||||
tools: List[Type[BaseModel]]
|
||||
|
||||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
|
||||
parsed_result = super().parse_result(result, partial=partial)
|
||||
name_dict = {tool.__name__: tool for tool in self.tools}
|
||||
if self.first_tool_only:
|
||||
return (
|
||||
name_dict[parsed_result["type"]](**parsed_result["args"])
|
||||
if parsed_result
|
||||
else None
|
||||
)
|
||||
return [name_dict[res["type"]](**res["args"]) for res in parsed_result]
|
Loading…
Reference in New Issue