forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
48 lines
1.3 KiB
Python
48 lines
1.3 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, List
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
|
|
from langchain.output_parsers.json import parse_and_check_json_markdown
|
|
from langchain.schema import BaseOutputParser
|
|
|
|
# Per-field line used in the format instructions: a tab, the quoted field
# name, its type label, and a `//` comment carrying the description.
line_template = "\t\"{name}\": {type} // {description}"
|
|
|
|
|
|
class ResponseSchema(BaseModel):
    """Description of one field expected in the structured output.

    Each instance is rendered as one line of the format instructions (see
    ``_get_sub_string``), and its ``name`` is used as one of the expected
    keys when ``StructuredOutputParser.parse`` checks the model output.
    """

    # Key of the field in the output object.
    name: str
    # Human-readable explanation shown after `//` in the format instructions.
    description: str
    # Type label shown in the format instructions; defaults to "string".
    type: str = "string"
|
|
|
|
|
|
def _get_sub_string(schema: ResponseSchema) -> str:
    """Render one response schema as a single format-instructions line.

    The schema's ``name``, ``type`` and ``description`` are substituted into
    the module-level ``line_template``.
    """
    substitutions = {
        "name": schema.name,
        "type": schema.type,
        "description": schema.description,
    }
    return line_template.format(**substitutions)
|
|
|
|
|
|
class StructuredOutputParser(BaseOutputParser):
    """Output parser driven by a list of ``ResponseSchema`` entries.

    The schemas serve two purposes. They are rendered into the format
    instructions given to the model, and their names are passed as the
    expected keys when parsing the model's reply.
    """

    # One entry per field the model is asked to produce.
    response_schemas: List[ResponseSchema]

    @classmethod
    def from_response_schemas(
        cls, response_schemas: List[ResponseSchema]
    ) -> StructuredOutputParser:
        """Alternate constructor: build a parser from ``response_schemas``."""
        return cls(response_schemas=response_schemas)

    def get_format_instructions(self) -> str:
        """Return the format instructions describing every expected field.

        Each schema becomes one line (via ``_get_sub_string``). The lines are
        newline-joined and substituted into ``STRUCTURED_FORMAT_INSTRUCTIONS``
        as its ``format`` placeholder.
        """
        rendered = (_get_sub_string(rs) for rs in self.response_schemas)
        joined_fields = "\n".join(rendered)
        return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=joined_fields)

    def parse(self, text: str) -> Any:
        """Parse model output ``text`` using ``parse_and_check_json_markdown``.

        The names of all configured schemas are passed as the expected keys.
        Whatever that helper returns or raises is propagated unchanged.
        """
        wanted_keys = [schema.name for schema in self.response_schemas]
        return parse_and_check_json_markdown(text, wanted_keys)

    @property
    def _type(self) -> str:
        # Identifier string reported for this parser type.
        return "structured"
|