From 6a4a950a3c7cc6fdd54e8ae1713ec91075f8a656 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Sun, 18 Jun 2023 22:49:47 -0700 Subject: [PATCH] changes to llm chain (#6328) - return raw and full output (but keep run shortcut method functional) - change output parser to take in generations (good for working with messages) - add output parser to base class, always run (default to same as current) --------- Co-authored-by: Eugene Yurtsev --- .../chains/additional/extraction.ipynb | 43 +++- .../chains/additional/qa_citations.ipynb | 181 ++++++++++++++ .../modules/chains/additional/tagging.ipynb | 63 +++-- langchain/chains/__init__.py | 2 + langchain/chains/base.py | 20 +- langchain/chains/llm.py | 81 ++++-- langchain/chains/loading.py | 7 +- langchain/chains/openai_functions.py | 233 ------------------ langchain/chains/openai_functions/__init__.py | 19 ++ .../openai_functions/citation_fuzzy_match.py | 101 ++++++++ .../chains/openai_functions/extraction.py | 81 ++++++ langchain/chains/openai_functions/tagging.py | 61 +++++ langchain/chains/openai_functions/utils.py | 28 +++ langchain/output_parsers/openai_functions.py | 51 ++++ langchain/prompts/loading.py | 5 +- langchain/schema.py | 26 +- 16 files changed, 704 insertions(+), 298 deletions(-) create mode 100644 docs/extras/modules/chains/additional/qa_citations.ipynb delete mode 100644 langchain/chains/openai_functions.py create mode 100644 langchain/chains/openai_functions/__init__.py create mode 100644 langchain/chains/openai_functions/citation_fuzzy_match.py create mode 100644 langchain/chains/openai_functions/extraction.py create mode 100644 langchain/chains/openai_functions/tagging.py create mode 100644 langchain/chains/openai_functions/utils.py create mode 100644 langchain/output_parsers/openai_functions.py diff --git a/docs/extras/modules/chains/additional/extraction.ipynb b/docs/extras/modules/chains/additional/extraction.ipynb index 3f7b4f51..9ede44dd 100644 --- a/docs/extras/modules/chains/additional/extraction.ipynb +++ b/docs/extras/modules/chains/additional/extraction.ipynb @@ -17,7 +17,16 @@ "execution_count": 1, "id": "34f04daf", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. 
It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from langchain.chat_models import ChatOpenAI\n", "from langchain.chains import create_extraction_chain, create_extraction_chain_pydantic\n", @@ -71,7 +80,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 4, "id": "640bd005", "metadata": {}, "outputs": [], @@ -84,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "64313214", "metadata": {}, "outputs": [], @@ -102,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "id": "cc5436ed", "metadata": {}, "outputs": [ @@ -119,7 +128,7 @@ " 'person_hair_color': 'brunette'}]" ] }, - "execution_count": 8, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -150,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "id": "6792866b", "metadata": {}, "outputs": [], @@ -161,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "id": "36a63761", "metadata": {}, "outputs": [], @@ -176,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "8ffd1e57", "metadata": {}, "outputs": [], @@ -186,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "id": "24baa954", "metadata": { "scrolled": false @@ -220,7 +229,7 @@ " Properties(person_name='Claudia', person_height=6, person_hair_color='brunette', dog_breed=None, dog_name=None)]" ] }, - "execution_count": 13, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -228,13 +237,21 @@ "source": [ "chain.run(inp)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0df61283", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "general", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "general" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -246,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/docs/extras/modules/chains/additional/qa_citations.ipynb b/docs/extras/modules/chains/additional/qa_citations.ipynb new file mode 100644 index 00000000..b53c3405 --- /dev/null +++ b/docs/extras/modules/chains/additional/qa_citations.ipynb @@ -0,0 +1,181 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9b5c258f", + "metadata": {}, + "source": [ + "# Question-Answering Citations\n", + "\n", + "This notebook shows how to use OpenAI functions ability to extract citations from text." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "eae4ca3e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. 
It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from langchain.chains import create_citation_fuzzy_match_chain\n", + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2c6e62ee", + "metadata": {}, + "outputs": [], + "source": [ + "question = \"What did the author do during college?\"\n", + "context = \"\"\"\n", + "My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.\n", + "I went to an arts highschool but in university I studied Computational Mathematics and physics. \n", + "As part of coop I worked at many companies including Stitchfix, Facebook.\n", + "I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "078e0300", + "metadata": {}, + "outputs": [], + "source": [ + "llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "02cad6d0", + "metadata": {}, + "outputs": [], + "source": [ + "chain = create_citation_fuzzy_match_chain(llm)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e3c6e7ba", + "metadata": {}, + "outputs": [], + "source": [ + "result = chain.run(question=question, context=context)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6f7615f2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n" + ] + } + ], + "source": [ + "print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3be6f366", + "metadata": {}, + "outputs": [], + "source": [ + "def highlight(text, span):\n", + " return (\n", + " \"...\"\n", + " + text[span[0] - 20 : span[0]]\n", + " + \"*\"\n", + " + \"\\033[91m\"\n", + " + text[span[0] : span[1]]\n", + " + \"\\033[0m\"\n", + " + \"*\"\n", + " + text[span[1] : span[1] + 20]\n", + " + \"...\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "636c4528", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Statement: The author studied Computational Mathematics and physics in university.\n", + "Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. 
\n", + "As part of coop I...\n", + "\n", + "Statement: The author started the Data Science club at the University of Waterloo.\n", + "Citation: ...titchfix, Facebook.\n", + "*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n", + "\n", + "Statement: The author was the president of the Data Science club for 2 years.\n", + "Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n", + "...\n", + "\n" + ] + } + ], + "source": [ + "for fact in result.answer:\n", + " print(\"Statement:\", fact.fact)\n", + " for span in fact.get_spans(context):\n", + " print(\"Citation:\", highlight(context, span))\n", + " print()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8409cab0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/extras/modules/chains/additional/tagging.ipynb b/docs/extras/modules/chains/additional/tagging.ipynb index 49e85abd..7c6487fe 100644 --- a/docs/extras/modules/chains/additional/tagging.ipynb +++ b/docs/extras/modules/chains/additional/tagging.ipynb @@ -17,7 +17,16 @@ "execution_count": 1, "id": "bafb496a", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. 
It's recommended that you update to the latest version using `pip install -U deeplake`.\n", + " warnings.warn(\n" + ] + } + ], "source": [ "from langchain.chat_models import ChatOpenAI\n", "from langchain.chains import create_tagging_chain, create_tagging_chain_pydantic\n", @@ -52,7 +61,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "8329f943", "metadata": {}, "outputs": [], @@ -68,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "6146ae70", "metadata": {}, "outputs": [], @@ -88,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 5, "id": "5509b6a6", "metadata": {}, "outputs": [ @@ -98,7 +107,7 @@ "{'sentiment': 'positive', 'language': 'Spanish'}" ] }, - "execution_count": 59, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -110,17 +119,17 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 6, "id": "9154474c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}" + "{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'es'}" ] }, - "execution_count": 60, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -132,7 +141,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 7, "id": "aae85b27", "metadata": {}, "outputs": [ @@ -142,7 +151,7 @@ "{'sentiment': 'positive', 'aggressiveness': 0, 'language': 'English'}" ] }, - "execution_count": 61, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -176,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "id": "6a5f7961", "metadata": {}, "outputs": [], @@ -200,7 +209,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 9, "id": "e5a5881f", "metadata": {}, "outputs": [], @@ -218,7 +227,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "id": "d9b9d53d", "metadata": {}, "outputs": [ @@ -228,7 +237,7 @@ "{'sentiment': 'happy', 'aggressiveness': 0, 'language': 'spanish'}" ] }, - "execution_count": 13, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -240,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "id": "1c12fa00", "metadata": {}, "outputs": [ @@ -250,7 +259,7 @@ "{'sentiment': 'sad', 'aggressiveness': 10, 'language': 'spanish'}" ] }, - "execution_count": 14, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -262,7 +271,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "id": "0bdfcb05", "metadata": {}, "outputs": [ @@ -272,7 +281,7 @@ "{'sentiment': 'neutral', 'aggressiveness': 0, 'language': 'english'}" ] }, - "execution_count": 15, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -304,7 +313,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "id": "bf1f367e", "metadata": {}, "outputs": [], @@ -315,7 +324,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 14, "id": "83a2e826", "metadata": {}, "outputs": [], @@ -334,7 +343,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 15, "id": "6e404892", "metadata": {}, "outputs": [], @@ -344,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 16, "id": "b5fc43c4", "metadata": {}, "outputs": [], @@ -355,7 +364,7 @@ }, { "cell_type": "code", - 
"execution_count": 26, + "execution_count": 17, "id": "5074bcc3", "metadata": {}, "outputs": [ @@ -365,7 +374,7 @@ "Tags(sentiment='sad', aggressiveness=10, language='spanish')" ] }, - "execution_count": 26, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -377,9 +386,9 @@ ], "metadata": { "kernelspec": { - "display_name": "general", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "general" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -391,7 +400,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.1" } }, "nbformat": 4, diff --git a/langchain/chains/__init__.py b/langchain/chains/__init__.py index 747973c9..ccc8a2a9 100644 --- a/langchain/chains/__init__.py +++ b/langchain/chains/__init__.py @@ -24,6 +24,7 @@ from langchain.chains.mapreduce import MapReduceChain from langchain.chains.moderation import OpenAIModerationChain from langchain.chains.natbot.base import NatBotChain from langchain.chains.openai_functions import ( + create_citation_fuzzy_match_chain, create_extraction_chain, create_extraction_chain_pydantic, create_tagging_chain, @@ -93,4 +94,5 @@ __all__ = [ "create_tagging_chain", "create_tagging_chain_pydantic", "load_chain", + "create_citation_fuzzy_match_chain", ] diff --git a/langchain/chains/base.py b/langchain/chains/base.py index dfa1b23a..2e54402a 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -247,6 +247,15 @@ class Chain(Serializable, ABC): """Call the chain on all inputs in the list.""" return [self(inputs, callbacks=callbacks) for inputs in input_list] + @property + def _run_output_key(self) -> str: + if len(self.output_keys) != 1: + raise ValueError( + f"`run` not supported when there is not exactly " + f"one output key. Got {self.output_keys}." + ) + return self.output_keys[0] + def run( self, *args: Any, @@ -255,19 +264,16 @@ class Chain(Serializable, ABC): **kwargs: Any, ) -> str: """Run the chain as text in, text out or multiple variables, text out.""" - if len(self.output_keys) != 1: - raise ValueError( - f"`run` not supported when there is not exactly " - f"one output key. Got {self.output_keys}." 
- ) + # Run at start to make sure this is possible/defined + _output_key = self._run_output_key if args and not kwargs: if len(args) != 1: raise ValueError("`run` supports only one positional argument.") - return self(args[0], callbacks=callbacks, tags=tags)[self.output_keys[0]] + return self(args[0], callbacks=callbacks, tags=tags)[_output_key] if kwargs and not args: - return self(kwargs, callbacks=callbacks, tags=tags)[self.output_keys[0]] + return self(kwargs, callbacks=callbacks, tags=tags)[_output_key] if not kwargs and not args: raise ValueError( diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index f8154c8c..a8c7f155 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -1,9 +1,10 @@ """Chain that just formats a prompt and calls an LLM.""" from __future__ import annotations +import warnings from typing import Any, Dict, List, Optional, Sequence, Tuple, Union -from pydantic import Extra +from pydantic import Extra, Field from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import ( @@ -18,7 +19,12 @@ from langchain.input import get_colored_text from langchain.load.dump import dumpd from langchain.prompts.base import BasePromptTemplate from langchain.prompts.prompt import PromptTemplate -from langchain.schema import LLMResult, PromptValue +from langchain.schema import ( + BaseLLMOutputParser, + LLMResult, + NoOpOutputParser, + PromptValue, +) class LLMChain(Chain): @@ -42,7 +48,16 @@ class LLMChain(Chain): prompt: BasePromptTemplate """Prompt object to use.""" llm: BaseLanguageModel + """Language model to call.""" output_key: str = "text" #: :meta private: + output_parser: BaseLLMOutputParser = Field(default_factory=NoOpOutputParser) + """Output parser to use. + Defaults to one that takes the most likely string but does not change it + otherwise.""" + return_final_only: bool = True + """Whether to return only the final parsed result. Defaults to True. 
+ If false, will return a bunch of extra information about the generation.""" + llm_kwargs: dict = Field(default_factory=dict) class Config: """Configuration for this pydantic object.""" @@ -64,7 +79,10 @@ class LLMChain(Chain): :meta private: """ - return [self.output_key] + if self.return_final_only: + return [self.output_key] + else: + return [self.output_key, "full_generation"] def _call( self, @@ -82,7 +100,10 @@ class LLMChain(Chain): """Generate LLM result from inputs.""" prompts, stop = self.prep_prompts(input_list, run_manager=run_manager) return self.llm.generate_prompt( - prompts, stop, callbacks=run_manager.get_child() if run_manager else None + prompts, + stop, + callbacks=run_manager.get_child() if run_manager else None, + **self.llm_kwargs, ) async def agenerate( @@ -93,7 +114,10 @@ class LLMChain(Chain): """Generate LLM result from inputs.""" prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager) return await self.llm.agenerate_prompt( - prompts, stop, callbacks=run_manager.get_child() if run_manager else None + prompts, + stop, + callbacks=run_manager.get_child() if run_manager else None, + **self.llm_kwargs, ) def prep_prompts( @@ -184,13 +208,23 @@ class LLMChain(Chain): await run_manager.on_chain_end({"outputs": outputs}) return outputs - def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: + @property + def _run_output_key(self) -> str: + return self.output_key + + def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]: """Create outputs from response.""" - return [ + result = [ # Get the text of the top generated string. - {self.output_key: generation[0].text} - for generation in response.generations + { + self.output_key: self.output_parser.parse_result(generation), + "full_generation": generation, + } + for generation in llm_result.generations ] + if self.return_final_only: + result = [{self.output_key: r[self.output_key]} for r in result] + return result async def _acall( self, @@ -238,6 +272,10 @@ class LLMChain(Chain): self, callbacks: Callbacks = None, **kwargs: Any ) -> Union[str, List[str], Dict[str, Any]]: """Call predict and then parse the results.""" + warnings.warn( + "The predict_and_parse method is deprecated, " + "instead pass an output parser directly to LLMChain." + ) result = self.predict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: return self.prompt.output_parser.parse(result) @@ -248,6 +286,10 @@ class LLMChain(Chain): self, callbacks: Callbacks = None, **kwargs: Any ) -> Union[str, List[str], Dict[str, str]]: """Call apredict and then parse the results.""" + warnings.warn( + "The apredict_and_parse method is deprecated, " + "instead pass an output parser directly to LLMChain." + ) result = await self.apredict(callbacks=callbacks, **kwargs) if self.prompt.output_parser is not None: return self.prompt.output_parser.parse(result) @@ -258,25 +300,34 @@ class LLMChain(Chain): self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" + warnings.warn( + "The apply_and_parse method is deprecated, " + "instead pass an output parser directly to LLMChain." 
+ ) result = self.apply(input_list, callbacks=callbacks) - return self._parse_result(result) + return self._parse_generation(result) - def _parse_result( - self, result: List[Dict[str, str]] + def _parse_generation( + self, generation: List[Dict[str, str]] ) -> Sequence[Union[str, List[str], Dict[str, str]]]: if self.prompt.output_parser is not None: return [ - self.prompt.output_parser.parse(res[self.output_key]) for res in result + self.prompt.output_parser.parse(res[self.output_key]) + for res in generation ] else: - return result + return generation async def aapply_and_parse( self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None ) -> Sequence[Union[str, List[str], Dict[str, str]]]: """Call apply and then parse the results.""" + warnings.warn( + "The aapply_and_parse method is deprecated, " + "instead pass an output parser directly to LLMChain." + ) result = await self.aapply(input_list, callbacks=callbacks) - return self._parse_result(result) + return self._parse_generation(result) @property def _chain_type(self) -> str: diff --git a/langchain/chains/loading.py b/langchain/chains/loading.py index f115b905..a01872fb 100644 --- a/langchain/chains/loading.py +++ b/langchain/chains/loading.py @@ -24,7 +24,11 @@ from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChai from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA from langchain.chains.sql_database.base import SQLDatabaseChain from langchain.llms.loading import load_llm, load_llm_from_config -from langchain.prompts.loading import load_prompt, load_prompt_from_config +from langchain.prompts.loading import ( + _load_output_parser, + load_prompt, + load_prompt_from_config, +) from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/" @@ -47,6 +51,7 @@ def _load_llm_chain(config: dict, **kwargs: Any) -> LLMChain: prompt = load_prompt(config.pop("prompt_path")) else: raise ValueError("One of `prompt` or `prompt_path` must be present.") + _load_output_parser(config) return LLMChain(llm=llm, prompt=prompt, **config) diff --git a/langchain/chains/openai_functions.py b/langchain/chains/openai_functions.py deleted file mode 100644 index 5b0a7f2d..00000000 --- a/langchain/chains/openai_functions.py +++ /dev/null @@ -1,233 +0,0 @@ -import json -from functools import partial -from typing import Any, Dict, List, Optional - -from pydantic import BaseModel, Field - -from langchain.base_language import BaseLanguageModel -from langchain.callbacks.manager import ( - AsyncCallbackManagerForChainRun, - CallbackManagerForChainRun, -) -from langchain.chains.base import Chain -from langchain.chains.sequential import SimpleSequentialChain -from langchain.chains.transform import TransformChain -from langchain.prompts.base import BasePromptTemplate -from langchain.prompts.chat import ChatPromptTemplate - -EXTRACTION_NAME = "information_extraction" -EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}} - - -def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any: - """ - Resolves the $ref keys in a JSON schema object using the provided definitions. 
- """ - if isinstance(schema, list): - for i, item in enumerate(schema): - schema[i] = _resolve_schema_references(item, definitions) - elif isinstance(schema, dict): - if "$ref" in schema: - ref_key = schema.pop("$ref").split("/")[-1] - ref = definitions.get(ref_key, {}) - schema.update(ref) - else: - for key, value in schema.items(): - schema[key] = _resolve_schema_references(value, definitions) - return schema - - -def _get_function_arguments(inputs: dict) -> str: - message = inputs["input"] - try: - func_call = message.additional_kwargs["function_call"] - except ValueError as exc: - raise ValueError(f"Could not parse function call: {exc}") - - return func_call["arguments"] - - -def _parse_tag(inputs: dict) -> dict: - args = _get_function_arguments(inputs) - return {"output": json.loads(args)} - - -def _parse_tag_pydantic(inputs: dict, pydantic_schema: Any) -> dict: - args = _get_function_arguments(inputs) - args = pydantic_schema.parse_raw(args) - return {"output": args} - - -def _parse_entities(inputs: dict) -> dict: - args = _get_function_arguments(inputs) - return {"output": json.loads(args)["info"]} - - -def _parse_entities_pydantic(inputs: dict, pydantic_schema: Any) -> dict: - args = _get_function_arguments(inputs) - pydantic_args = pydantic_schema.parse_raw(args) - return {"output": pydantic_args.info} - - -class OpenAIFunctionsChain(Chain): - prompt: BasePromptTemplate - llm: BaseLanguageModel - functions: List[Dict] - kwargs: Dict = Field(default_factory=dict) - - @property - def input_keys(self) -> List[str]: - return self.prompt.input_variables - - @property - def output_keys(self) -> List[str]: - return ["output"] - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables} - prompt = self.prompt.format_prompt(**_inputs) - messages = prompt.to_messages() - _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - predicted_message = self.llm.predict_messages( - messages, functions=self.functions, callbacks=callbacks, **self.kwargs - ) - return {"output": predicted_message} - - async def _acall( - self, - inputs: Dict[str, Any], - run_manager: Optional[AsyncCallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables} - prompt = self.prompt.format_prompt(**_inputs) - messages = prompt.to_messages() - _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() - callbacks = _run_manager.get_child() - predicted_message = await self.llm.apredict_messages( - messages, functions=self.functions, callbacks=callbacks, **self.kwargs - ) - return {"output": predicted_message} - - -def _convert_schema(schema: dict) -> dict: - props = {k: {"title": k, **v} for k, v in schema["properties"].items()} - return { - "type": "object", - "properties": props, - "required": schema.get("required", []), - } - - -def _get_extraction_functions(entity_schema: dict) -> List[dict]: - return [ - { - "name": EXTRACTION_NAME, - "description": "Extracts the relevant information from the passage.", - "parameters": { - "type": "object", - "properties": { - "info": {"type": "array", "items": _convert_schema(entity_schema)} - }, - "required": ["info"], - }, - } - ] - - -def _get_tagging_functions(schema: dict) -> List[dict]: - return [ - { - "name": EXTRACTION_NAME, - "description": "Extracts the relevant 
information from the passage.", - "parameters": _convert_schema(schema), - } - ] - - -_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\ - in the following passage together with their properties. - -Passage: -{input} -""" - - -def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain: - functions = _get_extraction_functions(schema) - prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) - chain = OpenAIFunctionsChain( - llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS - ) - parsing_chain = TransformChain( - transform=_parse_entities, - input_variables=["input"], - output_variables=["output"], - ) - return SimpleSequentialChain(chains=[chain, parsing_chain]) - - -def create_extraction_chain_pydantic( - pydantic_schema: Any, llm: BaseLanguageModel -) -> Chain: - class PydanticSchema(BaseModel): - info: List[pydantic_schema] # type: ignore - - openai_schema = PydanticSchema.schema() - openai_schema = _resolve_schema_references( - openai_schema, openai_schema["definitions"] - ) - - functions = _get_extraction_functions(openai_schema) - prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) - chain = OpenAIFunctionsChain( - llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS - ) - pydantic_parsing_chain = TransformChain( - transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema), - input_variables=["input"], - output_variables=["output"], - ) - return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain]) - - -_TAGGING_TEMPLATE = """Extract the desired information from the following passage. - -Passage: -{input} -""" - - -def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain: - functions = _get_tagging_functions(schema) - prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) - chain = OpenAIFunctionsChain( - llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS - ) - parsing_chain = TransformChain( - transform=_parse_tag, input_variables=["input"], output_variables=["output"] - ) - return SimpleSequentialChain(chains=[chain, parsing_chain]) - - -def create_tagging_chain_pydantic( - pydantic_schema: Any, llm: BaseLanguageModel -) -> Chain: - openai_schema = pydantic_schema.schema() - - functions = _get_tagging_functions(openai_schema) - prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) - chain = OpenAIFunctionsChain( - llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS - ) - pydantic_parsing_chain = TransformChain( - transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema), - input_variables=["input"], - output_variables=["output"], - ) - - return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain]) diff --git a/langchain/chains/openai_functions/__init__.py b/langchain/chains/openai_functions/__init__.py new file mode 100644 index 00000000..079dbd3b --- /dev/null +++ b/langchain/chains/openai_functions/__init__.py @@ -0,0 +1,19 @@ +from langchain.chains.openai_functions.citation_fuzzy_match import ( + create_citation_fuzzy_match_chain, +) +from langchain.chains.openai_functions.extraction import ( + create_extraction_chain, + create_extraction_chain_pydantic, +) +from langchain.chains.openai_functions.tagging import ( + create_tagging_chain, + create_tagging_chain_pydantic, +) + +__all__ = [ + "create_tagging_chain", + "create_tagging_chain_pydantic", + "create_extraction_chain_pydantic", + "create_extraction_chain", + "create_citation_fuzzy_match_chain", +] diff --git 
a/langchain/chains/openai_functions/citation_fuzzy_match.py b/langchain/chains/openai_functions/citation_fuzzy_match.py new file mode 100644 index 00000000..9f4ba6f9 --- /dev/null +++ b/langchain/chains/openai_functions/citation_fuzzy_match.py @@ -0,0 +1,101 @@ +from typing import Iterator, List + +from pydantic import BaseModel, Field + +from langchain.base_language import BaseLanguageModel +from langchain.chains.llm import LLMChain +from langchain.output_parsers.openai_functions import ( + PydanticOutputFunctionsParser, +) +from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain.schema import HumanMessage, SystemMessage + + +class FactWithEvidence(BaseModel): + """Class representing single statement. + + Each fact has a body and a list of sources. + If there are multiple facts make sure to break them apart + such that each one only uses a set of sources that are relevant to it. + """ + + fact: str = Field(..., description="Body of the sentence, as part of a response") + substring_quote: List[str] = Field( + ..., + description=( + "Each source should be a direct quote from the context, " + "as a substring of the original content" + ), + ) + + def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]: + import regex + + minor = quote + major = context + + errs_ = 0 + s = regex.search(f"({minor}){{e<={errs_}}}", major) + while s is None and errs_ <= errs: + errs_ += 1 + s = regex.search(f"({minor}){{e<={errs_}}}", major) + + if s is not None: + yield from s.spans() + + def get_spans(self, context: str) -> Iterator[str]: + for quote in self.substring_quote: + yield from self._get_span(quote, context) + + +class QuestionAnswer(BaseModel): + """A question and its answer as a list of facts each one should have a source. + each sentence contains a body and a list of sources.""" + + question: str = Field(..., description="Question that was asked") + answer: List[FactWithEvidence] = Field( + ..., + description=( + "Body of the answer, each fact should be " + "its separate object with a body and a list of sources" + ), + ) + + +def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain: + output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer) + schema = QuestionAnswer.schema() + functions = [ + { + "name": schema["title"], + "description": schema["description"], + "parameters": schema, + } + ] + kwargs = {"function_call": {"name": schema["title"]}} + messages = [ + SystemMessage( + content=( + "You are a world class algorithm to answer " + "questions with correct and exact citations." + ) + ), + HumanMessage(content="Answer question using the following context"), + HumanMessagePromptTemplate.from_template("{context}"), + HumanMessagePromptTemplate.from_template("Question: {question}"), + HumanMessage( + content=( + "Tips: Make sure to cite your sources, " + "and use the exact words from the context." 
+ ) + ), + ] + prompt = ChatPromptTemplate(messages=messages) + + chain = LLMChain( + llm=llm, + prompt=prompt, + llm_kwargs={**{"functions": functions}, **kwargs}, + output_parser=output_parser, + ) + return chain diff --git a/langchain/chains/openai_functions/extraction.py b/langchain/chains/openai_functions/extraction.py new file mode 100644 index 00000000..0db5ddd6 --- /dev/null +++ b/langchain/chains/openai_functions/extraction.py @@ -0,0 +1,81 @@ +from typing import Any, List + +from pydantic import BaseModel + +from langchain.base_language import BaseLanguageModel +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.openai_functions.utils import ( + _convert_schema, + _resolve_schema_references, +) +from langchain.output_parsers.openai_functions import ( + JsonKeyOutputFunctionsParser, + PydanticAttrOutputFunctionsParser, +) +from langchain.prompts import ChatPromptTemplate + +EXTRACTION_NAME = "information_extraction" +EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}} + + +def _get_extraction_functions(entity_schema: dict) -> List[dict]: + return [ + { + "name": EXTRACTION_NAME, + "description": "Extracts the relevant information from the passage.", + "parameters": { + "type": "object", + "properties": { + "info": {"type": "array", "items": _convert_schema(entity_schema)} + }, + "required": ["info"], + }, + } + ] + + +_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\ + in the following passage together with their properties. + +Passage: +{input} +""" + + +def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain: + functions = _get_extraction_functions(schema) + prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) + output_parser = JsonKeyOutputFunctionsParser(key_name="info") + chain = LLMChain( + llm=llm, + prompt=prompt, + llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, + output_parser=output_parser, + ) + return chain + + +def create_extraction_chain_pydantic( + pydantic_schema: Any, llm: BaseLanguageModel +) -> Chain: + class PydanticSchema(BaseModel): + info: List[pydantic_schema] # type: ignore + + openai_schema = PydanticSchema.schema() + openai_schema = _resolve_schema_references( + openai_schema, openai_schema["definitions"] + ) + + functions = _get_extraction_functions(openai_schema) + prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) + output_parser = PydanticAttrOutputFunctionsParser( + pydantic_schema=PydanticSchema, attr_name="info" + ) + chain = LLMChain( + llm=llm, + prompt=prompt, + llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, + output_parser=output_parser, + ) + return chain diff --git a/langchain/chains/openai_functions/tagging.py b/langchain/chains/openai_functions/tagging.py new file mode 100644 index 00000000..61c2fbfa --- /dev/null +++ b/langchain/chains/openai_functions/tagging.py @@ -0,0 +1,61 @@ +from typing import Any, List + +from langchain.base_language import BaseLanguageModel +from langchain.chains.base import Chain +from langchain.chains.llm import LLMChain +from langchain.chains.openai_functions.utils import _convert_schema +from langchain.output_parsers.openai_functions import ( + JsonOutputFunctionsParser, + PydanticOutputFunctionsParser, +) +from langchain.prompts import ChatPromptTemplate + +EXTRACTION_NAME = "information_extraction" +EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}} + + +def _get_tagging_functions(schema: dict) -> List[dict]: + 
return [ + { + "name": EXTRACTION_NAME, + "description": "Extracts the relevant information from the passage.", + "parameters": _convert_schema(schema), + } + ] + + +_TAGGING_TEMPLATE = """Extract the desired information from the following passage. + +Passage: +{input} +""" + + +def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain: + functions = _get_tagging_functions(schema) + prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) + output_parser = JsonOutputFunctionsParser() + chain = LLMChain( + llm=llm, + prompt=prompt, + llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, + output_parser=output_parser, + ) + return chain + + +def create_tagging_chain_pydantic( + pydantic_schema: Any, llm: BaseLanguageModel +) -> Chain: + openai_schema = pydantic_schema.schema() + + functions = _get_tagging_functions(openai_schema) + prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) + output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema) + chain = LLMChain( + llm=llm, + prompt=prompt, + llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, + output_parser=output_parser, + ) + return chain diff --git a/langchain/chains/openai_functions/utils.py b/langchain/chains/openai_functions/utils.py new file mode 100644 index 00000000..4ad2e3e6 --- /dev/null +++ b/langchain/chains/openai_functions/utils.py @@ -0,0 +1,28 @@ +from typing import Any, Dict + + +def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any: + """ + Resolves the $ref keys in a JSON schema object using the provided definitions. + """ + if isinstance(schema, list): + for i, item in enumerate(schema): + schema[i] = _resolve_schema_references(item, definitions) + elif isinstance(schema, dict): + if "$ref" in schema: + ref_key = schema.pop("$ref").split("/")[-1] + ref = definitions.get(ref_key, {}) + schema.update(ref) + else: + for key, value in schema.items(): + schema[key] = _resolve_schema_references(value, definitions) + return schema + + +def _convert_schema(schema: dict) -> dict: + props = {k: {"title": k, **v} for k, v in schema["properties"].items()} + return { + "type": "object", + "properties": props, + "required": schema.get("required", []), + } diff --git a/langchain/output_parsers/openai_functions.py b/langchain/output_parsers/openai_functions.py new file mode 100644 index 00000000..9f516ac4 --- /dev/null +++ b/langchain/output_parsers/openai_functions.py @@ -0,0 +1,51 @@ +import json +from typing import Any, List + +from langchain.schema import BaseLLMOutputParser, ChatGeneration, Generation + + +class OutputFunctionsParser(BaseLLMOutputParser[Any]): + def parse_result(self, result: List[Generation]) -> Any: + generation = result[0] + if not isinstance(generation, ChatGeneration): + raise ValueError( + "This output parser can only be used with a chat generation." 
+ ) + message = generation.message + try: + func_call = message.additional_kwargs["function_call"] + except ValueError as exc: + raise ValueError(f"Could not parse function call: {exc}") + + return func_call["arguments"] + + +class JsonOutputFunctionsParser(OutputFunctionsParser): + def parse_result(self, result: List[Generation]) -> Any: + _args = super().parse_result(result) + return json.loads(_args) + + +class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): + key_name: str + + def parse_result(self, result: List[Generation]) -> Any: + res = super().parse_result(result) + return res[self.key_name] + + +class PydanticOutputFunctionsParser(OutputFunctionsParser): + pydantic_schema: Any + + def parse_result(self, result: List[Generation]) -> Any: + _args = super().parse_result(result) + pydantic_args = self.pydantic_schema.parse_raw(_args) + return pydantic_args + + +class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser): + attr_name: str + + def parse_result(self, result: List[Generation]) -> Any: + result = super().parse_result(result) + return getattr(result, self.attr_name) diff --git a/langchain/prompts/loading.py b/langchain/prompts/loading.py index 4e73f569..20c8f8d7 100644 --- a/langchain/prompts/loading.py +++ b/langchain/prompts/loading.py @@ -11,6 +11,7 @@ from langchain.output_parsers.regex import RegexParser from langchain.prompts.base import BasePromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.prompt import PromptTemplate +from langchain.schema import BaseLLMOutputParser, NoOpOutputParser from langchain.utilities.loading import try_load_from_hub URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/" @@ -78,7 +79,9 @@ def _load_output_parser(config: dict) -> dict: _config = config.pop("output_parser") output_parser_type = _config.pop("_type") if output_parser_type == "regex_parser": - output_parser = RegexParser(**_config) + output_parser: BaseLLMOutputParser = RegexParser(**_config) + elif output_parser_type == "default": + output_parser = NoOpOutputParser(**_config) else: raise ValueError(f"Unsupported output parser {output_parser_type}") config["output_parser"] = output_parser diff --git a/langchain/schema.py b/langchain/schema.py index 77fa6989..b8d75d2b 100644 --- a/langchain/schema.py +++ b/langchain/schema.py @@ -339,12 +339,21 @@ Memory = BaseMemory T = TypeVar("T") -class BaseOutputParser(Serializable, ABC, Generic[T]): +class BaseLLMOutputParser(Serializable, ABC, Generic[T]): + @abstractmethod + def parse_result(self, result: List[Generation]) -> T: + """Parse LLM Result.""" + + +class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]): """Class to parse the output of an LLM call. Output parsers help structure language model responses. """ + def parse_result(self, result: List[Generation]) -> T: + return self.parse(result[0].text) + @abstractmethod def parse(self, text: str) -> T: """Parse the output of an LLM call. @@ -394,6 +403,21 @@ class BaseOutputParser(Serializable, ABC, Generic[T]): return output_parser_dict +class NoOpOutputParser(BaseOutputParser[str]): + """Output parser that just returns the text as is.""" + + @property + def lc_serializable(self) -> bool: + return True + + @property + def _type(self) -> str: + return "default" + + def parse(self, text: str) -> str: + return text + + class OutputParserException(ValueError): """Exception that output parsers should raise to signify a parsing error.
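
Usage sketch for the new LLMChain fields added above (`output_parser`, `return_final_only`): a minimal, illustrative example, not a prescribed API. The `FakeListLLM` stand-in and the prompt text are assumptions chosen so the snippet runs offline.

from langchain.chains import LLMChain
from langchain.llms.fake import FakeListLLM
from langchain.prompts import PromptTemplate
from langchain.schema import NoOpOutputParser

chain = LLMChain(
    llm=FakeListLLM(responses=["Hello, Bob!"]),  # offline stand-in for a real LLM
    prompt=PromptTemplate.from_template("Say hello to {name}."),
    output_parser=NoOpOutputParser(),  # the default: return the top string unchanged
    return_final_only=False,           # also surface the raw generations
)
result = chain({"name": "Bob"})
# create_outputs() now emits both keys:
#   result["text"]            -> the parsed output ("Hello, Bob!")
#   result["full_generation"] -> the underlying List[Generation]
# chain.run(name="Bob") still returns just the text: LLMChain._run_output_key
# pins the `run` shortcut to `output_key`, keeping the shortcut functional.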
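
The OpenAI-functions chains above are now plain LLMChains wired through `llm_kwargs` plus an output parser. A sketch of composing a custom one the same way `create_tagging_chain` does; the `record_sentiment` function schema and prompt are hypothetical, and the model name follows the notebooks above.

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate

# Hypothetical function schema, for illustration only.
functions = [
    {
        "name": "record_sentiment",
        "description": "Records the sentiment of the passage.",
        "parameters": {
            "type": "object",
            "properties": {"sentiment": {"type": "string"}},
            "required": ["sentiment"],
        },
    }
]

chain = LLMChain(
    llm=ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613"),
    prompt=ChatPromptTemplate.from_template("Passage:\n{input}"),
    # Force the model to call the function, exactly as EXTRACTION_KWARGS does.
    llm_kwargs={"functions": functions, "function_call": {"name": "record_sentiment"}},
    output_parser=JsonOutputFunctionsParser(),
)
chain.run("I love this library!")  # expected shape: {'sentiment': 'positive'}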
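
The new `BaseLLMOutputParser.parse_result` hook receives the full list of generations rather than a single string, so custom parsers can inspect chat messages or multiple candidates. An illustrative subclass (the `FirstLineParser` name and behavior are invented for this sketch):

from typing import List

from langchain.schema import BaseLLMOutputParser, Generation


class FirstLineParser(BaseLLMOutputParser[str]):
    """Keep only the first line of the most likely generation."""

    def parse_result(self, result: List[Generation]) -> str:
        # result is ordered most-likely first; take the top text's first line.
        return result[0].text.split("\n", 1)[0]

# Passed as LLMChain(output_parser=FirstLineParser(), ...), this replaces the
# predict_and_parse pattern that the patch deprecates with warnings above.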