Harrison/new output parser (#1617)

This commit is contained in:
Harrison Chase 2023-03-13 15:08:39 -07:00 committed by GitHub
parent 039d05c808
commit df6c33d4b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 405 additions and 0 deletions

View File

@ -0,0 +1,334 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "084ee2f0",
"metadata": {},
"source": [
"# Output Parsers\n",
"\n",
"Language models output text. But many times you may want to get more structured information than just text back. This is where output parsers come in.\n",
"\n",
"Output parsers are classes that help structure language model responses. There are two main methods an output parser must implement:\n",
"\n",
"- `get_format_instructions() -> str`: A method which returns a string containing instructions for how the output of a language model should be formatted.\n",
"- `parse(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and parses it into some structure.\n",
"\n",
"Below we go over some examples of output parsers."
]
},
{
"cell_type": "markdown",
"id": "91871002",
"metadata": {},
"source": [
"## Structured Output Parser\n",
"\n",
"This output parser can be used when you want to return multiple fields."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b492997a",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers import StructuredOutputParser, ResponseSchema"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ffb7fc57",
"metadata": {},
"outputs": [],
"source": [
"from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n",
"from langchain.llms import OpenAI\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "markdown",
"id": "09473dce",
"metadata": {},
"source": [
"Here we define the response schema we want to receive."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "432ac44a",
"metadata": {},
"outputs": [],
"source": [
"response_schemas = [\n",
" ResponseSchema(name=\"answer\", description=\"answer to the user's question\"),\n",
" ResponseSchema(name=\"source\", description=\"source used to answer the user's question, should be a website.\")\n",
"]\n",
"output_parser = StructuredOutputParser.from_response_schemas(response_schemas)"
]
},
{
"cell_type": "markdown",
"id": "7b92ce96",
"metadata": {},
"source": [
"We now get a string that contains instructions for how the response should be formatted, and we then insert that into our prompt."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "593cfc25",
"metadata": {},
"outputs": [],
"source": [
"format_instructions = output_parser.get_format_instructions()\n",
"prompt = PromptTemplate(\n",
" template=\"answer the users question as best as possible.\\n{format_instructions}\\n{question}\",\n",
" input_variables=[\"question\"],\n",
" partial_variables={\"format_instructions\": format_instructions}\n",
")"
]
},
{
"cell_type": "markdown",
"id": "0943e783",
"metadata": {},
"source": [
"We can now use this to format a prompt to send to the language model, and then parse the returned result."
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "106f1ba6",
"metadata": {},
"outputs": [],
"source": [
"model = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "86d9d24f",
"metadata": {},
"outputs": [],
"source": [
"_input = prompt.format_prompt(question=\"what's the capital of france\")\n",
"output = model(_input.to_string())"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "956bdc99",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_parser.parse(output)"
]
},
{
"cell_type": "markdown",
"id": "da639285",
"metadata": {},
"source": [
"And here's an example of using this in a chat model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8f483d7d",
"metadata": {},
"outputs": [],
"source": [
"chat_model = ChatOpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f761cbf1",
"metadata": {},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate(\n",
" messages=[\n",
" HumanMessagePromptTemplate.from_template(\"answer the users question as best as possible.\\n{format_instructions}\\n{question}\") \n",
" ],\n",
" input_variables=[\"question\"],\n",
" partial_variables={\"format_instructions\": format_instructions}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "edd73ae3",
"metadata": {},
"outputs": [],
"source": [
"_input = prompt.format_prompt(question=\"what's the capital of france\")\n",
"output = chat_model(_input.to_messages())"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "a3c8b91e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_parser.parse(output.content)"
]
},
{
"cell_type": "markdown",
"id": "9936fa27",
"metadata": {},
"source": [
"## CommaSeparatedListOutputParser\n",
"\n",
"This output parser can be used to get a list of items as output."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "872246d7",
"metadata": {},
"outputs": [],
"source": [
"from langchain.output_parsers import CommaSeparatedListOutputParser"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "c3f9aee6",
"metadata": {},
"outputs": [],
"source": [
"output_parser = CommaSeparatedListOutputParser()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "e77871b7",
"metadata": {},
"outputs": [],
"source": [
"format_instructions = output_parser.get_format_instructions()\n",
"prompt = PromptTemplate(\n",
" template=\"List five {subject}.\\n{format_instructions}\",\n",
" input_variables=[\"subject\"],\n",
" partial_variables={\"format_instructions\": format_instructions}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a71cb5d3",
"metadata": {},
"outputs": [],
"source": [
"model = OpenAI(temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "783d7d98",
"metadata": {},
"outputs": [],
"source": [
"_input = prompt.format(subject=\"ice cream flavors\")\n",
"output = model(_input)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "fcb81344",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['Vanilla',\n",
" 'Chocolate',\n",
" 'Strawberry',\n",
" 'Mint Chocolate Chip',\n",
" 'Cookies and Cream']"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output_parser.parse(output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cba6d8e3",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -4,10 +4,13 @@ from langchain.output_parsers.list import (
ListOutputParser,
)
from langchain.output_parsers.regex import RegexParser
from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser
__all__ = [
"RegexParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"BaseOutputParser",
"StructuredOutputParser",
"ResponseSchema",
]

View File

@ -13,6 +13,9 @@ class BaseOutputParser(BaseModel, ABC):
def parse(self, text: str) -> Any:
"""Parse the output of an LLM call."""
def get_format_instructions(self) -> str:
raise NotImplementedError
@property
def _type(self) -> str:
"""Return the type key."""

View File

@ -0,0 +1,9 @@
# flake8: noqa
STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet formatted in the following schema:
```json
{{
{format}
}}
```"""

View File

@ -17,6 +17,12 @@ class ListOutputParser(BaseOutputParser):
class CommaSeparatedListOutputParser(ListOutputParser):
"""Parse out comma separated lists."""
def get_format_instructions(self) -> str:
return (
"Your response should be a list of comma separated values, "
"eg: `foo, bar, baz`"
)
def parse(self, text: str) -> List[str]:
"""Parse the output of an LLM call."""
return text.strip().split(", ")

View File

@ -0,0 +1,49 @@
from __future__ import annotations
import json
from typing import List
from pydantic import BaseModel
from langchain.output_parsers.base import BaseOutputParser
from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
line_template = '\t"{name}": {type} // {description}'
class ResponseSchema(BaseModel):
name: str
description: str
def _get_sub_string(schema: ResponseSchema) -> str:
return line_template.format(
name=schema.name, description=schema.description, type="string"
)
class StructuredOutputParser(BaseOutputParser):
response_schemas: List[ResponseSchema]
@classmethod
def from_response_schemas(
cls, response_schemas: List[ResponseSchema]
) -> StructuredOutputParser:
return cls(response_schemas=response_schemas)
def get_format_instructions(self) -> str:
schema_str = "\n".join(
[_get_sub_string(schema) for schema in self.response_schemas]
)
return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
def parse(self, text: str) -> BaseModel:
json_string = text.split("```json")[1].strip().strip("```").strip()
json_obj = json.loads(json_string)
for schema in self.response_schemas:
if schema.name not in json_obj:
raise ValueError(
f"Got invalid return object. Expected key `{schema.name}` "
f"to be present, but got {json_obj}"
)
return json_obj

View File

@ -159,6 +159,7 @@ class ChatPromptTemplate(BasePromptTemplate, ABC):
return self.format_prompt(**kwargs).to_string()
def format_prompt(self, **kwargs: Any) -> PromptValue:
kwargs = self._merge_partial_and_user_variables(**kwargs)
result = []
for message_template in self.messages:
if isinstance(message_template, BaseMessage):