diff --git a/docs/modules/prompts/examples/output_parsers.ipynb b/docs/modules/prompts/examples/output_parsers.ipynb new file mode 100644 index 00000000..76b71df7 --- /dev/null +++ b/docs/modules/prompts/examples/output_parsers.ipynb @@ -0,0 +1,334 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "084ee2f0", + "metadata": {}, + "source": [ + "# Output Parsers\n", + "\n", + "Language models output text. But many times you may want to get more structured information than just text back. This is where output parsers come in.\n", + "\n", + "Output parsers are classes that help structure language model responses. There are two main methods an output parser must implement:\n", + "\n", + "- `get_format_instructions() -> str`: A method which returns a string containing instructions for how the output of a language model should be formatted.\n", + "- `parse(str) -> Any`: A method which takes in a string (assumed to be the response from a language model) and parses it into some structure.\n", + "\n", + "Below we go over some examples of output parsers." + ] + }, + { + "cell_type": "markdown", + "id": "91871002", + "metadata": {}, + "source": [ + "## Structured Output Parser\n", + "\n", + "This output parser can be used when you want to return multiple fields." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "b492997a", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.output_parsers import StructuredOutputParser, ResponseSchema" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ffb7fc57", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate\n", + "from langchain.llms import OpenAI\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "markdown", + "id": "09473dce", + "metadata": {}, + "source": [ + "Here we define the response schema we want to receive." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "432ac44a", + "metadata": {}, + "outputs": [], + "source": [ + "response_schemas = [\n", + " ResponseSchema(name=\"answer\", description=\"answer to the user's question\"),\n", + " ResponseSchema(name=\"source\", description=\"source used to answer the user's question, should be a website.\")\n", + "]\n", + "output_parser = StructuredOutputParser.from_response_schemas(response_schemas)" + ] + }, + { + "cell_type": "markdown", + "id": "7b92ce96", + "metadata": {}, + "source": [ + "We now get a string that contains instructions for how the response should be formatted, and we then insert that into our prompt." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "593cfc25", + "metadata": {}, + "outputs": [], + "source": [ + "format_instructions = output_parser.get_format_instructions()\n", + "prompt = PromptTemplate(\n", + " template=\"answer the users question as best as possible.\\n{format_instructions}\\n{question}\",\n", + " input_variables=[\"question\"],\n", + " partial_variables={\"format_instructions\": format_instructions}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0943e783", + "metadata": {}, + "source": [ + "We can now use this to format a prompt to send to the language model, and then parse the returned result." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "106f1ba6", + "metadata": {}, + "outputs": [], + "source": [ + "model = OpenAI(temperature=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "86d9d24f", + "metadata": {}, + "outputs": [], + "source": [ + "_input = prompt.format_prompt(question=\"what's the capital of france\")\n", + "output = model(_input.to_string())" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "956bdc99", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'answer': 'Paris', 'source': 'https://en.wikipedia.org/wiki/Paris'}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_parser.parse(output)" + ] + }, + { + "cell_type": "markdown", + "id": "da639285", + "metadata": {}, + "source": [ + "And here's an example of using this in a chat model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8f483d7d", + "metadata": {}, + "outputs": [], + "source": [ + "chat_model = ChatOpenAI(temperature=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f761cbf1", + "metadata": {}, + "outputs": [], + "source": [ + "prompt = ChatPromptTemplate(\n", + " messages=[\n", + " HumanMessagePromptTemplate.from_template(\"answer the users question as best as possible.\\n{format_instructions}\\n{question}\") \n", + " ],\n", + " input_variables=[\"question\"],\n", + " partial_variables={\"format_instructions\": format_instructions}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "edd73ae3", + "metadata": {}, + "outputs": [], + "source": [ + "_input = prompt.format_prompt(question=\"what's the capital of france\")\n", + "output = chat_model(_input.to_messages())" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a3c8b91e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'answer': 'Paris', 'source': 
'https://en.wikipedia.org/wiki/Paris'}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_parser.parse(output.content)" + ] + }, + { + "cell_type": "markdown", + "id": "9936fa27", + "metadata": {}, + "source": [ + "## CommaSeparatedListOutputParser\n", + "\n", + "This output parser can be used to get a list of items as output." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "872246d7", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.output_parsers import CommaSeparatedListOutputParser" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c3f9aee6", + "metadata": {}, + "outputs": [], + "source": [ + "output_parser = CommaSeparatedListOutputParser()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e77871b7", + "metadata": {}, + "outputs": [], + "source": [ + "format_instructions = output_parser.get_format_instructions()\n", + "prompt = PromptTemplate(\n", + " template=\"List five {subject}.\\n{format_instructions}\",\n", + " input_variables=[\"subject\"],\n", + " partial_variables={\"format_instructions\": format_instructions}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a71cb5d3", + "metadata": {}, + "outputs": [], + "source": [ + "model = OpenAI(temperature=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "783d7d98", + "metadata": {}, + "outputs": [], + "source": [ + "_input = prompt.format(subject=\"ice cream flavors\")\n", + "output = model(_input)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "fcb81344", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['Vanilla',\n", + " 'Chocolate',\n", + " 'Strawberry',\n", + " 'Mint Chocolate Chip',\n", + " 'Cookies and Cream']" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_parser.parse(output)" + ] + }, 
+ { + "cell_type": "code", + "execution_count": null, + "id": "cba6d8e3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/langchain/output_parsers/__init__.py b/langchain/output_parsers/__init__.py index 8509b6f2..cac915e0 100644 --- a/langchain/output_parsers/__init__.py +++ b/langchain/output_parsers/__init__.py @@ -4,10 +4,13 @@ from langchain.output_parsers.list import ( ListOutputParser, ) from langchain.output_parsers.regex import RegexParser +from langchain.output_parsers.structured import ResponseSchema, StructuredOutputParser __all__ = [ "RegexParser", "ListOutputParser", "CommaSeparatedListOutputParser", "BaseOutputParser", + "StructuredOutputParser", + "ResponseSchema", ] diff --git a/langchain/output_parsers/base.py b/langchain/output_parsers/base.py index 3ea30f9c..f3598416 100644 --- a/langchain/output_parsers/base.py +++ b/langchain/output_parsers/base.py @@ -13,6 +13,9 @@ class BaseOutputParser(BaseModel, ABC): def parse(self, text: str) -> Any: """Parse the output of an LLM call.""" + def get_format_instructions(self) -> str: + raise NotImplementedError + @property def _type(self) -> str: """Return the type key.""" diff --git a/langchain/output_parsers/format_instructions.py b/langchain/output_parsers/format_instructions.py new file mode 100644 index 00000000..3653d477 --- /dev/null +++ b/langchain/output_parsers/format_instructions.py @@ -0,0 +1,9 @@ +# flake8: noqa + +STRUCTURED_FORMAT_INSTRUCTIONS = """The output should be a markdown code snippet formatted in the following schema: + +```json +{{ 
+{format}
+}}
+```"""
diff --git a/langchain/output_parsers/list.py b/langchain/output_parsers/list.py
index 028685f6..aebbf409 100644
--- a/langchain/output_parsers/list.py
+++ b/langchain/output_parsers/list.py
@@ -17,6 +17,12 @@ class ListOutputParser(BaseOutputParser):
 class CommaSeparatedListOutputParser(ListOutputParser):
     """Parse out comma separated lists."""
 
+    def get_format_instructions(self) -> str:
+        return (
+            "Your response should be a list of comma separated values, "
+            "eg: `foo, bar, baz`"
+        )
+
     def parse(self, text: str) -> List[str]:
         """Parse the output of an LLM call."""
         return text.strip().split(", ")
diff --git a/langchain/output_parsers/structured.py b/langchain/output_parsers/structured.py
new file mode 100644
index 00000000..4a32d23f
--- /dev/null
+++ b/langchain/output_parsers/structured.py
@@ -0,0 +1,96 @@
+from __future__ import annotations
+
+import json
+from typing import Any, Dict, List
+
+from pydantic import BaseModel
+
+from langchain.output_parsers.base import BaseOutputParser
+from langchain.output_parsers.format_instructions import STRUCTURED_FORMAT_INSTRUCTIONS
+
+# One line of the schema shown to the model: `"<name>": <type> // <description>`.
+line_template = '\t"{name}": {type} // {description}'
+
+
+class ResponseSchema(BaseModel):
+    """Describe one key the model is asked to include in its JSON answer."""
+
+    # Key expected in the parsed JSON object.
+    name: str
+    # Explanation of the key, rendered as a trailing `//` comment in the prompt.
+    description: str
+
+
+def _get_sub_string(schema: ResponseSchema) -> str:
+    """Render ``schema`` as a single line of the format instructions."""
+    # Every field is advertised as a string; ResponseSchema carries no type.
+    return line_template.format(
+        name=schema.name, description=schema.description, type="string"
+    )
+
+
+def _extract_json_block(text: str) -> str:
+    """Return the contents of the first ```json fenced block in ``text``.
+
+    Raises:
+        ValueError: If ``text`` contains no ```json fence.
+    """
+    _, fence, remainder = text.partition("```json")
+    if not fence:
+        raise ValueError(
+            "Got invalid return object. Expected a ```json code block, "
+            f"but got {text}"
+        )
+    # Stop at the closing fence so prose after the block is not fed to json.loads.
+    # If the fence was never closed, the whole remainder is used.
+    body, _, _ = remainder.partition("```")
+    return body.strip()
+
+
+class StructuredOutputParser(BaseOutputParser):
+    """Parse a ```json fenced object and check it has every expected key."""
+
+    # Keys that must be present in the parsed object.
+    response_schemas: List[ResponseSchema]
+
+    @classmethod
+    def from_response_schemas(
+        cls, response_schemas: List[ResponseSchema]
+    ) -> StructuredOutputParser:
+        """Build a parser from a list of response schemas."""
+        return cls(response_schemas=response_schemas)
+
+    def get_format_instructions(self) -> str:
+        """Return prompt text describing the JSON snippet the model must emit."""
+        schema_str = "\n".join(
+            _get_sub_string(schema) for schema in self.response_schemas
+        )
+        return STRUCTURED_FORMAT_INSTRUCTIONS.format(format=schema_str)
+
+    def parse(self, text: str) -> Dict[str, Any]:
+        """Parse the output of an LLM call into a dict.
+
+        Args:
+            text: Model output containing a ```json fenced object.
+
+        Returns:
+            The decoded JSON object.
+
+        Raises:
+            ValueError: If no ```json block is found, the block is not valid
+                JSON (``json.JSONDecodeError`` is a ``ValueError``), the decoded
+                value is not an object, or an expected key is missing.
+        """
+        json_obj = json.loads(_extract_json_block(text))
+        if not isinstance(json_obj, dict):
+            raise ValueError(
+                "Got invalid return object. Expected a JSON object, "
+                f"but got {json_obj}"
+            )
+        for schema in self.response_schemas:
+            if schema.name not in json_obj:
+                raise ValueError(
+                    f"Got invalid return object. Expected key `{schema.name}` "
+                    f"to be present, but got {json_obj}"
+                )
+        return json_obj
diff --git a/langchain/prompts/chat.py b/langchain/prompts/chat.py
index c52df499..87f67a6d 100644
--- a/langchain/prompts/chat.py
+++ b/langchain/prompts/chat.py
@@ -159,6 +159,7 @@ class ChatPromptTemplate(BasePromptTemplate, ABC):
         return self.format_prompt(**kwargs).to_string()
 
     def format_prompt(self, **kwargs: Any) -> PromptValue:
+        kwargs = self._merge_partial_and_user_variables(**kwargs)
         result = []
         for message_template in self.messages:
             if isinstance(message_template, BaseMessage):