langchain/langchain/output_parsers/structured.py
Harrison Chase ce5d97bcb3
Harrison/guarded output parser (#1804)
Co-authored-by: jerwelborn <jeremy.welborn@gmail.com>
2023-03-21 22:07:23 -07:00

50 lines
1.5 KiB
Python

from __future__ import annotations
import json
from typing import List
from pydantic import BaseModel
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
from langchain.schema import BaseOutputParser, OutputParserException
# Template for one line of the schema shown in the format instructions:
# a tab, the quoted key name, its type, then the description after `//`.
line_template = '\t"{name}": {type} // {description}'
class ResponseSchema(BaseModel):
    """Describe a single key expected in the structured JSON output."""

    # Key name: rendered in the format instructions and looked up in the
    # parsed JSON object.
    name: str
    # Human-readable explanation of the key, rendered after `//` in the
    # format instructions.
    description: str
def _get_sub_string(schema: ResponseSchema) -> str:
    """Render one schema entry as a line of the format instructions.

    The type slot is always filled with the literal ``string``.
    """
    placeholders = {
        "name": schema.name,
        "type": "string",
        "description": schema.description,
    }
    return line_template.format(**placeholders)
class StructuredOutputParser(BaseOutputParser):
    """Parse a ```json fenced snippet and check it against expected keys.

    The parser extracts the text following a ```json fence, decodes it as
    JSON, and verifies that every key named in ``response_schemas`` is
    present in the resulting object.
    """

    # Schemas naming the keys that must appear in the parsed JSON object.
    response_schemas: List[ResponseSchema]

    @classmethod
    def from_response_schemas(
        cls, response_schemas: List[ResponseSchema]
    ) -> StructuredOutputParser:
        """Build a parser from a list of response schemas."""
        return cls(response_schemas=response_schemas)

    def get_format_instructions(self) -> str:
        """Return the format instructions with one line per response schema."""
        schema_str = "\n".join(
            _get_sub_string(schema) for schema in self.response_schemas
        )
        return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)

    def parse(self, text: str) -> dict:
        """Extract and validate the JSON object embedded in ``text``.

        Args:
            text: Raw text expected to contain a ```json fenced snippet.

        Returns:
            The decoded JSON object as a dict.

        Raises:
            OutputParserException: If ``text`` has no ```json fence, the
                fenced content is not valid JSON, the JSON value is not an
                object, or a key named in ``response_schemas`` is missing.
        """
        segments = text.split("```json")
        if len(segments) < 2:
            # Without this guard the indexing below fails with a bare
            # IndexError that says nothing about the malformed output.
            raise OutputParserException(
                "Got invalid return object. Expected markdown code snippet "
                f"with JSON object, but got:\n{text}"
            )

        # str.strip takes a set of characters, not a substring, so "`" trims
        # the closing fence (and any stray backticks) at either end.
        json_string = segments[1].strip().strip("`").strip()
        try:
            json_obj = json.loads(json_string)
        except json.JSONDecodeError as err:
            raise OutputParserException(
                f"Got invalid JSON object. Error: {err}"
            ) from err

        # A JSON string or list would otherwise be tested with substring or
        # element membership below, and a number would raise TypeError.
        if not isinstance(json_obj, dict):
            raise OutputParserException(
                "Got invalid return object. Expected a JSON object, "
                f"but got {json_obj}"
            )

        for schema in self.response_schemas:
            if schema.name not in json_obj:
                raise OutputParserException(
                    f"Got invalid return object. Expected key `{schema.name}` "
                    f"to be present, but got {json_obj}"
                )
        return json_obj