forked from Archives/langchain
ce5d97bcb3
Co-authored-by: jerwelborn <jeremy.welborn@gmail.com>
50 lines
1.5 KiB
Python
50 lines
1.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import List
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
|
|
from langchain.schema import BaseOutputParser, OutputParserException
|
|
|
|
# Template for a single line of the format instructions: the quoted key,
# its type, and a trailing `//` comment carrying the description.
line_template = "\t\"{name}\": {type} // {description}"
|
|
|
|
|
|
class ResponseSchema(BaseModel):
    """Describe one key that the model's JSON response is expected to contain."""

    # Key that must be present in the parsed JSON object.
    name: str
    # Human-readable explanation of the value; rendered into the format
    # instructions next to the key.
    description: str
|
|
|
|
|
|
def _get_sub_string(schema: ResponseSchema) -> str:
    """Render a single response schema as one line of the format instructions.

    Every key is currently advertised with the type ``string``.
    """
    substitutions = {
        "name": schema.name,
        "description": schema.description,
        "type": "string",
    }
    return line_template.format(**substitutions)
|
|
|
|
|
|
def _extract_json_string(text: str) -> str:
    """Return the JSON payload embedded in ``text``.

    If ``text`` contains a ```json fence, the content between that fence and the
    next ``` is returned (surrounding whitespace stripped). Otherwise the whole
    text, with surrounding whitespace and backticks removed, is returned.
    """
    _, fence, after_fence = text.partition("```json")
    if fence:
        # Stop at the closing fence so any prose after it is not fed to json.loads.
        payload, _, _ = after_fence.partition("```")
        return payload.strip()
    # No ```json fence: fall back to treating the (de-backticked) text as JSON.
    return text.strip().strip("`").strip()


class StructuredOutputParser(BaseOutputParser):
    """Parse an LLM reply containing a JSON object with a fixed set of keys.

    ``get_format_instructions`` tells the model which keys to emit, and
    ``parse`` extracts the JSON object and checks that every key is present.
    """

    # One entry per key the parsed JSON object must contain.
    response_schemas: List[ResponseSchema]

    @classmethod
    def from_response_schemas(
        cls, response_schemas: List[ResponseSchema]
    ) -> StructuredOutputParser:
        """Build a parser expecting the keys described by ``response_schemas``."""
        return cls(response_schemas=response_schemas)

    def get_format_instructions(self) -> str:
        """Return prompt text describing the expected JSON keys.

        Each schema is rendered as one line by ``_get_sub_string``. The lines are
        joined with newlines and substituted into ``STRUCTURED_FORMAT_INSTRUCTIONS``
        as ``format``.
        """
        schema_str = "\n".join(
            _get_sub_string(schema) for schema in self.response_schemas
        )
        return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)

    def parse(self, text: str) -> dict:
        """Extract and validate the JSON object in an LLM reply.

        Args:
            text: Raw model output. It should contain the JSON object, ideally
                inside a ```json fenced block.

        Returns:
            The decoded JSON object as a dict.

        Raises:
            OutputParserException: The payload is not valid JSON, is not a JSON
                object, or is missing a key required by ``response_schemas``.
        """
        json_string = _extract_json_string(text)
        try:
            json_obj = json.loads(json_string)
        except json.JSONDecodeError as e:
            raise OutputParserException(f"Got invalid JSON object. Error: {e}") from e
        # The key checks below use `in`, which only means "has key" for a dict
        # (e.g. `"a" in "abc"` is True for a JSON string), so require an object.
        if not isinstance(json_obj, dict):
            raise OutputParserException(
                f"Got invalid return object. Expected a JSON object, "
                f"but got {json_obj}"
            )
        for schema in self.response_schemas:
            if schema.name not in json_obj:
                raise OutputParserException(
                    f"Got invalid return object. Expected key `{schema.name}` "
                    f"to be present, but got {json_obj}"
                )
        return json_obj
|