Changes to LLMChain (#6328)

- Return raw and full output (but keep the run shortcut method functional);
see the usage sketch below
- Change the output parser to take in generations (useful for working with
messages)
- Add an output parser to the base class and always run it (the default
preserves current behavior)
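
A minimal usage sketch of the new surface (assumes an OpenAI API key is configured; the prompt and variable names here are invented for illustration):

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate

llm = ChatOpenAI(temperature=0)
prompt = PromptTemplate.from_template("Say hello to {name}.")

# Default behavior is unchanged: the NoOpOutputParser returns the text as is,
# so the chain output is {"name": ..., "text": ...}.
chain = LLMChain(llm=llm, prompt=prompt)

# return_final_only=False additionally surfaces the raw generations under a
# "full_generation" key alongside the parsed "text" key.
verbose_chain = LLMChain(llm=llm, prompt=prompt, return_final_only=False)
result = verbose_chain({"name": "Bob"})

# The run shortcut stays functional even then, because LLMChain pins its
# single run output key to "text" via the new _run_output_key property.
greeting = verbose_chain.run(name="Bob")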

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

@ -17,7 +17,16 @@
"execution_count": 1,
"id": "34f04daf",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains import create_extraction_chain, create_extraction_chain_pydantic\n",
@ -71,7 +80,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"id": "640bd005",
"metadata": {},
"outputs": [],
@ -84,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "64313214",
"metadata": {},
"outputs": [],
@ -102,7 +111,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"id": "cc5436ed",
"metadata": {},
"outputs": [
@ -119,7 +128,7 @@
" 'person_hair_color': 'brunette'}]"
]
},
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -150,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"id": "6792866b",
"metadata": {},
"outputs": [],
@ -161,7 +170,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"id": "36a63761",
"metadata": {},
"outputs": [],
@ -176,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"id": "8ffd1e57",
"metadata": {},
"outputs": [],
@ -186,7 +195,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"id": "24baa954",
"metadata": {
"scrolled": false
@ -220,7 +229,7 @@
" Properties(person_name='Claudia', person_height=6, person_hair_color='brunette', dog_breed=None, dog_name=None)]"
]
},
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@ -228,13 +237,21 @@
"source": [
"chain.run(inp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0df61283",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "general",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "general"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -246,7 +263,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.1"
}
},
"nbformat": 4,

@ -0,0 +1,181 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9b5c258f",
"metadata": {},
"source": [
"# Question-Answering Citations\n",
"\n",
"This notebook shows how to use OpenAI functions ability to extract citations from text."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "eae4ca3e",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.chains import create_citation_fuzzy_match_chain\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2c6e62ee",
"metadata": {},
"outputs": [],
"source": [
"question = \"What did the author do during college?\"\n",
"context = \"\"\"\n",
"My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.\n",
"I went to an arts highschool but in university I studied Computational Mathematics and physics. \n",
"As part of coop I worked at many companies including Stitchfix, Facebook.\n",
"I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.\n",
"\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "078e0300",
"metadata": {},
"outputs": [],
"source": [
"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02cad6d0",
"metadata": {},
"outputs": [],
"source": [
"chain = create_citation_fuzzy_match_chain(llm)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "e3c6e7ba",
"metadata": {},
"outputs": [],
"source": [
"result = chain.run(question=question, context=context)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6f7615f2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n"
]
}
],
"source": [
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3be6f366",
"metadata": {},
"outputs": [],
"source": [
"def highlight(text, span):\n",
" return (\n",
" \"...\"\n",
" + text[span[0] - 20 : span[0]]\n",
" + \"*\"\n",
" + \"\\033[91m\"\n",
" + text[span[0] : span[1]]\n",
" + \"\\033[0m\"\n",
" + \"*\"\n",
" + text[span[1] : span[1] + 20]\n",
" + \"...\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "636c4528",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Statement: The author studied Computational Mathematics and physics in university.\n",
"Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n",
"As part of coop I...\n",
"\n",
"Statement: The author started the Data Science club at the University of Waterloo.\n",
"Citation: ...titchfix, Facebook.\n",
"*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
"\n",
"Statement: The author was the president of the Data Science club for 2 years.\n",
"Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n",
"...\n",
"\n"
]
}
],
"source": [
"for fact in result.answer:\n",
" print(\"Statement:\", fact.fact)\n",
" for span in fact.get_spans(context):\n",
" print(\"Citation:\", highlight(context, span))\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8409cab0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

@ -17,7 +17,16 @@
"execution_count": 1,
"id": "bafb496a",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.chains import create_tagging_chain, create_tagging_chain_pydantic\n",
@ -52,7 +61,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"id": "8329f943",
"metadata": {},
"outputs": [],
@ -68,7 +77,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"id": "6146ae70",
"metadata": {},
"outputs": [],
@ -88,7 +97,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 5,
"id": "5509b6a6",
"metadata": {},
"outputs": [
@ -98,7 +107,7 @@
"{'sentiment': 'positive', 'language': 'Spanish'}"
]
},
"execution_count": 59,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@ -110,17 +119,17 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 6,
"id": "9154474c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}"
"{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'es'}"
]
},
"execution_count": 60,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -132,7 +141,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 7,
"id": "aae85b27",
"metadata": {},
"outputs": [
@ -142,7 +151,7 @@
"{'sentiment': 'positive', 'aggressiveness': 0, 'language': 'English'}"
]
},
"execution_count": 61,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -176,7 +185,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"id": "6a5f7961",
"metadata": {},
"outputs": [],
@ -200,7 +209,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 9,
"id": "e5a5881f",
"metadata": {},
"outputs": [],
@ -218,7 +227,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 10,
"id": "d9b9d53d",
"metadata": {},
"outputs": [
@ -228,7 +237,7 @@
"{'sentiment': 'happy', 'aggressiveness': 0, 'language': 'spanish'}"
]
},
"execution_count": 13,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@ -240,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 11,
"id": "1c12fa00",
"metadata": {},
"outputs": [
@ -250,7 +259,7 @@
"{'sentiment': 'sad', 'aggressiveness': 10, 'language': 'spanish'}"
]
},
"execution_count": 14,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
@ -262,7 +271,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 12,
"id": "0bdfcb05",
"metadata": {},
"outputs": [
@ -272,7 +281,7 @@
"{'sentiment': 'neutral', 'aggressiveness': 0, 'language': 'english'}"
]
},
"execution_count": 15,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@ -304,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 13,
"id": "bf1f367e",
"metadata": {},
"outputs": [],
@ -315,7 +324,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 14,
"id": "83a2e826",
"metadata": {},
"outputs": [],
@ -334,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 15,
"id": "6e404892",
"metadata": {},
"outputs": [],
@ -344,7 +353,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 16,
"id": "b5fc43c4",
"metadata": {},
"outputs": [],
@ -355,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 17,
"id": "5074bcc3",
"metadata": {},
"outputs": [
@ -365,7 +374,7 @@
"Tags(sentiment='sad', aggressiveness=10, language='spanish')"
]
},
"execution_count": 26,
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@ -377,9 +386,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "general",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "general"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@ -391,7 +400,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.1"
}
},
"nbformat": 4,

@ -24,6 +24,7 @@ from langchain.chains.mapreduce import MapReduceChain
from langchain.chains.moderation import OpenAIModerationChain
from langchain.chains.natbot.base import NatBotChain
from langchain.chains.openai_functions import (
create_citation_fuzzy_match_chain,
create_extraction_chain,
create_extraction_chain_pydantic,
create_tagging_chain,
@ -93,4 +94,5 @@ __all__ = [
"create_tagging_chain",
"create_tagging_chain_pydantic",
"load_chain",
"create_citation_fuzzy_match_chain",
]

@ -247,6 +247,15 @@ class Chain(Serializable, ABC):
"""Call the chain on all inputs in the list."""
return [self(inputs, callbacks=callbacks) for inputs in input_list]
@property
def _run_output_key(self) -> str:
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
return self.output_keys[0]
def run(
self,
*args: Any,
@ -255,19 +264,16 @@ class Chain(Serializable, ABC):
**kwargs: Any,
) -> str:
"""Run the chain as text in, text out or multiple variables, text out."""
if len(self.output_keys) != 1:
raise ValueError(
f"`run` not supported when there is not exactly "
f"one output key. Got {self.output_keys}."
)
# Run at start to make sure this is possible/defined
_output_key = self._run_output_key
if args and not kwargs:
if len(args) != 1:
raise ValueError("`run` supports only one positional argument.")
return self(args[0], callbacks=callbacks, tags=tags)[self.output_keys[0]]
return self(args[0], callbacks=callbacks, tags=tags)[_output_key]
if kwargs and not args:
return self(kwargs, callbacks=callbacks, tags=tags)[self.output_keys[0]]
return self(kwargs, callbacks=callbacks, tags=tags)[_output_key]
if not kwargs and not args:
raise ValueError(

@ -1,9 +1,10 @@
"""Chain that just formats a prompt and calls an LLM."""
from __future__ import annotations
import warnings
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from pydantic import Extra
from pydantic import Extra, Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
@ -18,7 +19,12 @@ from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import LLMResult, PromptValue
from langchain.schema import (
BaseLLMOutputParser,
LLMResult,
NoOpOutputParser,
PromptValue,
)
class LLMChain(Chain):
@ -42,7 +48,16 @@ class LLMChain(Chain):
prompt: BasePromptTemplate
"""Prompt object to use."""
llm: BaseLanguageModel
"""Language model to call."""
output_key: str = "text" #: :meta private:
output_parser: BaseLLMOutputParser = Field(default_factory=NoOpOutputParser)
"""Output parser to use.
Defaults to one that takes the most likely string but does not change it
otherwise."""
return_final_only: bool = True
"""Whether to return only the final parsed result. Defaults to True.
If false, the full generation is also returned under the "full_generation" key."""
llm_kwargs: dict = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
@ -64,7 +79,10 @@ class LLMChain(Chain):
:meta private:
"""
return [self.output_key]
if self.return_final_only:
return [self.output_key]
else:
return [self.output_key, "full_generation"]
def _call(
self,
@ -82,7 +100,10 @@ class LLMChain(Chain):
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
return self.llm.generate_prompt(
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
**self.llm_kwargs,
)
async def agenerate(
@ -93,7 +114,10 @@ class LLMChain(Chain):
"""Generate LLM result from inputs."""
prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
return await self.llm.agenerate_prompt(
prompts, stop, callbacks=run_manager.get_child() if run_manager else None
prompts,
stop,
callbacks=run_manager.get_child() if run_manager else None,
**self.llm_kwargs,
)
def prep_prompts(
@ -184,13 +208,23 @@ class LLMChain(Chain):
await run_manager.on_chain_end({"outputs": outputs})
return outputs
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
@property
def _run_output_key(self) -> str:
return self.output_key
def create_outputs(self, llm_result: LLMResult) -> List[Dict[str, Any]]:
"""Create outputs from response."""
return [
result = [
# Get the text of the top generated string.
{self.output_key: generation[0].text}
for generation in response.generations
{
self.output_key: self.output_parser.parse_result(generation),
"full_generation": generation,
}
for generation in llm_result.generations
]
if self.return_final_only:
result = [{self.output_key: r[self.output_key]} for r in result]
return result
async def _acall(
self,
@ -238,6 +272,10 @@ class LLMChain(Chain):
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, Any]]:
"""Call predict and then parse the results."""
warnings.warn(
"The predict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = self.predict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
@ -248,6 +286,10 @@ class LLMChain(Chain):
self, callbacks: Callbacks = None, **kwargs: Any
) -> Union[str, List[str], Dict[str, str]]:
"""Call apredict and then parse the results."""
warnings.warn(
"The apredict_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = await self.apredict(callbacks=callbacks, **kwargs)
if self.prompt.output_parser is not None:
return self.prompt.output_parser.parse(result)
@ -258,25 +300,34 @@ class LLMChain(Chain):
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The apply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = self.apply(input_list, callbacks=callbacks)
return self._parse_result(result)
return self._parse_generation(result)
def _parse_result(
self, result: List[Dict[str, str]]
def _parse_generation(
self, generation: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
return [
self.prompt.output_parser.parse(res[self.output_key]) for res in result
self.prompt.output_parser.parse(res[self.output_key])
for res in generation
]
else:
return result
return generation
async def aapply_and_parse(
self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
"""Call apply and then parse the results."""
warnings.warn(
"The aapply_and_parse method is deprecated, "
"instead pass an output parser directly to LLMChain."
)
result = await self.aapply(input_list, callbacks=callbacks)
return self._parse_result(result)
return self._parse_generation(result)
@property
def _chain_type(self) -> str:

@ -24,7 +24,11 @@ from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChai
from langchain.chains.retrieval_qa.base import RetrievalQA, VectorDBQA
from langchain.chains.sql_database.base import SQLDatabaseChain
from langchain.llms.loading import load_llm, load_llm_from_config
from langchain.prompts.loading import load_prompt, load_prompt_from_config
from langchain.prompts.loading import (
_load_output_parser,
load_prompt,
load_prompt_from_config,
)
from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/chains/"
@ -47,6 +51,7 @@ def _load_llm_chain(config: dict, **kwargs: Any) -> LLMChain:
prompt = load_prompt(config.pop("prompt_path"))
else:
raise ValueError("One of `prompt` or `prompt_path` must be present.")
_load_output_parser(config)
return LLMChain(llm=llm, prompt=prompt, **config)

@ -1,233 +0,0 @@
import json
from functools import partial
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.sequential import SimpleSequentialChain
from langchain.chains.transform import TransformChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
"""
Resolves the $ref keys in a JSON schema object using the provided definitions.
"""
if isinstance(schema, list):
for i, item in enumerate(schema):
schema[i] = _resolve_schema_references(item, definitions)
elif isinstance(schema, dict):
if "$ref" in schema:
ref_key = schema.pop("$ref").split("/")[-1]
ref = definitions.get(ref_key, {})
schema.update(ref)
else:
for key, value in schema.items():
schema[key] = _resolve_schema_references(value, definitions)
return schema
def _get_function_arguments(inputs: dict) -> str:
message = inputs["input"]
try:
func_call = message.additional_kwargs["function_call"]
except ValueError as exc:
raise ValueError(f"Could not parse function call: {exc}")
return func_call["arguments"]
def _parse_tag(inputs: dict) -> dict:
args = _get_function_arguments(inputs)
return {"output": json.loads(args)}
def _parse_tag_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
args = _get_function_arguments(inputs)
args = pydantic_schema.parse_raw(args)
return {"output": args}
def _parse_entities(inputs: dict) -> dict:
args = _get_function_arguments(inputs)
return {"output": json.loads(args)["info"]}
def _parse_entities_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
args = _get_function_arguments(inputs)
pydantic_args = pydantic_schema.parse_raw(args)
return {"output": pydantic_args.info}
class OpenAIFunctionsChain(Chain):
prompt: BasePromptTemplate
llm: BaseLanguageModel
functions: List[Dict]
kwargs: Dict = Field(default_factory=dict)
@property
def input_keys(self) -> List[str]:
return self.prompt.input_variables
@property
def output_keys(self) -> List[str]:
return ["output"]
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**_inputs)
messages = prompt.to_messages()
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks, **self.kwargs
)
return {"output": predicted_message}
async def _acall(
self,
inputs: Dict[str, Any],
run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
_inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables}
prompt = self.prompt.format_prompt(**_inputs)
messages = prompt.to_messages()
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks, **self.kwargs
)
return {"output": predicted_message}
def _convert_schema(schema: dict) -> dict:
props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
return {
"type": "object",
"properties": props,
"required": schema.get("required", []),
}
def _get_extraction_functions(entity_schema: dict) -> List[dict]:
return [
{
"name": EXTRACTION_NAME,
"description": "Extracts the relevant information from the passage.",
"parameters": {
"type": "object",
"properties": {
"info": {"type": "array", "items": _convert_schema(entity_schema)}
},
"required": ["info"],
},
}
]
def _get_tagging_functions(schema: dict) -> List[dict]:
return [
{
"name": EXTRACTION_NAME,
"description": "Extracts the relevant information from the passage.",
"parameters": _convert_schema(schema),
}
]
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
in the following passage together with their properties.
Passage:
{input}
"""
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_extraction_functions(schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
chain = OpenAIFunctionsChain(
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
)
parsing_chain = TransformChain(
transform=_parse_entities,
input_variables=["input"],
output_variables=["output"],
)
return SimpleSequentialChain(chains=[chain, parsing_chain])
def create_extraction_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
class PydanticSchema(BaseModel):
info: List[pydantic_schema] # type: ignore
openai_schema = PydanticSchema.schema()
openai_schema = _resolve_schema_references(
openai_schema, openai_schema["definitions"]
)
functions = _get_extraction_functions(openai_schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
chain = OpenAIFunctionsChain(
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
)
pydantic_parsing_chain = TransformChain(
transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema),
input_variables=["input"],
output_variables=["output"],
)
return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])
_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
Passage:
{input}
"""
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_tagging_functions(schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
chain = OpenAIFunctionsChain(
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
)
parsing_chain = TransformChain(
transform=_parse_tag, input_variables=["input"], output_variables=["output"]
)
return SimpleSequentialChain(chains=[chain, parsing_chain])
def create_tagging_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
openai_schema = pydantic_schema.schema()
functions = _get_tagging_functions(openai_schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
chain = OpenAIFunctionsChain(
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
)
pydantic_parsing_chain = TransformChain(
transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema),
input_variables=["input"],
output_variables=["output"],
)
return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])

@ -0,0 +1,19 @@
from langchain.chains.openai_functions.citation_fuzzy_match import (
create_citation_fuzzy_match_chain,
)
from langchain.chains.openai_functions.extraction import (
create_extraction_chain,
create_extraction_chain_pydantic,
)
from langchain.chains.openai_functions.tagging import (
create_tagging_chain,
create_tagging_chain_pydantic,
)
__all__ = [
"create_tagging_chain",
"create_tagging_chain_pydantic",
"create_extraction_chain_pydantic",
"create_extraction_chain",
"create_citation_fuzzy_match_chain",
]

@ -0,0 +1,101 @@
from typing import Iterator, List, Tuple
from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.output_parsers.openai_functions import (
PydanticOutputFunctionsParser,
)
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage, SystemMessage
class FactWithEvidence(BaseModel):
"""Class representing single statement.
Each fact has a body and a list of sources.
If there are multiple facts, make sure to break them apart
such that each one only uses a set of sources that are relevant to it.
"""
fact: str = Field(..., description="Body of the sentence, as part of a response")
substring_quote: List[str] = Field(
...,
description=(
"Each source should be a direct quote from the context, "
"as a substring of the original content"
),
)
def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[Tuple[int, int]]:
import regex
minor = quote
major = context
errs_ = 0
s = regex.search(f"({minor}){{e<={errs_}}}", major)
while s is None and errs_ <= errs:
errs_ += 1
s = regex.search(f"({minor}){{e<={errs_}}}", major)
if s is not None:
yield from s.spans()
def get_spans(self, context: str) -> Iterator[Tuple[int, int]]:
for quote in self.substring_quote:
yield from self._get_span(quote, context)
class QuestionAnswer(BaseModel):
"""A question and its answer as a list of facts each one should have a source.
each sentence contains a body and a list of sources."""
question: str = Field(..., description="Question that was asked")
answer: List[FactWithEvidence] = Field(
...,
description=(
"Body of the answer, each fact should be "
"its separate object with a body and a list of sources"
),
)
def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
schema = QuestionAnswer.schema()
functions = [
{
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
]
kwargs = {"function_call": {"name": schema["title"]}}
messages = [
SystemMessage(
content=(
"You are a world class algorithm to answer "
"questions with correct and exact citations."
)
),
HumanMessage(content="Answer question using the following context"),
HumanMessagePromptTemplate.from_template("{context}"),
HumanMessagePromptTemplate.from_template("Question: {question}"),
HumanMessage(
content=(
"Tips: Make sure to cite your sources, "
"and use the exact words from the context."
)
),
]
prompt = ChatPromptTemplate(messages=messages)
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **kwargs},
output_parser=output_parser,
)
return chain

@ -0,0 +1,81 @@
from typing import Any, List
from pydantic import BaseModel
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import (
_convert_schema,
_resolve_schema_references,
)
from langchain.output_parsers.openai_functions import (
JsonKeyOutputFunctionsParser,
PydanticAttrOutputFunctionsParser,
)
from langchain.prompts import ChatPromptTemplate
EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
def _get_extraction_functions(entity_schema: dict) -> List[dict]:
return [
{
"name": EXTRACTION_NAME,
"description": "Extracts the relevant information from the passage.",
"parameters": {
"type": "object",
"properties": {
"info": {"type": "array", "items": _convert_schema(entity_schema)}
},
"required": ["info"],
},
}
]
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
in the following passage together with their properties.
Passage:
{input}
"""
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_extraction_functions(schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
output_parser=output_parser,
)
return chain
def create_extraction_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
class PydanticSchema(BaseModel):
info: List[pydantic_schema] # type: ignore
openai_schema = PydanticSchema.schema()
openai_schema = _resolve_schema_references(
openai_schema, openai_schema["definitions"]
)
functions = _get_extraction_functions(openai_schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = PydanticAttrOutputFunctionsParser(
pydantic_schema=PydanticSchema, attr_name="info"
)
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
output_parser=output_parser,
)
return chain

@ -0,0 +1,61 @@
from typing import Any, List
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import _convert_schema
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.prompts import ChatPromptTemplate
EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
def _get_tagging_functions(schema: dict) -> List[dict]:
return [
{
"name": EXTRACTION_NAME,
"description": "Extracts the relevant information from the passage.",
"parameters": _convert_schema(schema),
}
]
_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
Passage:
{input}
"""
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_tagging_functions(schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = JsonOutputFunctionsParser()
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
output_parser=output_parser,
)
return chain
def create_tagging_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
openai_schema = pydantic_schema.schema()
functions = _get_tagging_functions(openai_schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
output_parser=output_parser,
)
return chain

@ -0,0 +1,28 @@
from typing import Any, Dict
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
"""
Resolves the $ref keys in a JSON schema object using the provided definitions.
"""
if isinstance(schema, list):
for i, item in enumerate(schema):
schema[i] = _resolve_schema_references(item, definitions)
elif isinstance(schema, dict):
if "$ref" in schema:
ref_key = schema.pop("$ref").split("/")[-1]
ref = definitions.get(ref_key, {})
schema.update(ref)
else:
for key, value in schema.items():
schema[key] = _resolve_schema_references(value, definitions)
return schema
def _convert_schema(schema: dict) -> dict:
props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
return {
"type": "object",
"properties": props,
"required": schema.get("required", []),
}

@ -0,0 +1,51 @@
import json
from typing import Any, List
from langchain.schema import BaseLLMOutputParser, ChatGeneration, Generation
class OutputFunctionsParser(BaseLLMOutputParser[Any]):
def parse_result(self, result: List[Generation]) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise ValueError(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
func_call = message.additional_kwargs["function_call"]
except KeyError as exc:
raise ValueError(f"Could not parse function call: {exc}")
return func_call["arguments"]
class JsonOutputFunctionsParser(OutputFunctionsParser):
def parse_result(self, result: List[Generation]) -> Any:
_args = super().parse_result(result)
return json.loads(_args)
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
key_name: str
def parse_result(self, result: List[Generation]) -> Any:
res = super().parse_result(result)
return res[self.key_name]
class PydanticOutputFunctionsParser(OutputFunctionsParser):
pydantic_schema: Any
def parse_result(self, result: List[Generation]) -> Any:
_args = super().parse_result(result)
pydantic_args = self.pydantic_schema.parse_raw(_args)
return pydantic_args
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
attr_name: str
def parse_result(self, result: List[Generation]) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)

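A quick sketch of the new parse_result interface (the function-call payload below is fabricated for illustration):

from langchain.output_parsers.openai_functions import JsonKeyOutputFunctionsParser
from langchain.schema import AIMessage, ChatGeneration

# Simulate a chat generation whose message carries an OpenAI function call.
message = AIMessage(
    content="",
    additional_kwargs={
        "function_call": {
            "name": "information_extraction",
            "arguments": '{"info": [{"person_name": "Alex"}]}',
        }
    },
)
parser = JsonKeyOutputFunctionsParser(key_name="info")
# parse_result takes the full list of generations rather than a plain string,
# which is what lets these parsers inspect message metadata.
parser.parse_result([ChatGeneration(message=message)])
# -> [{'person_name': 'Alex'}]
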
@ -11,6 +11,7 @@ from langchain.output_parsers.regex import RegexParser
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLLMOutputParser, NoOpOutputParser
from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"
@ -78,7 +79,9 @@ def _load_output_parser(config: dict) -> dict:
_config = config.pop("output_parser")
output_parser_type = _config.pop("_type")
if output_parser_type == "regex_parser":
output_parser = RegexParser(**_config)
output_parser: BaseLLMOutputParser = RegexParser(**_config)
elif output_parser_type == "default":
output_parser = NoOpOutputParser(**_config)
else:
raise ValueError(f"Unsupported output parser {output_parser_type}")
config["output_parser"] = output_parser

@ -339,12 +339,21 @@ Memory = BaseMemory
T = TypeVar("T")
class BaseOutputParser(Serializable, ABC, Generic[T]):
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
@abstractmethod
def parse_result(self, result: List[Generation]) -> T:
"""Parse LLM Result."""
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
"""Class to parse the output of an LLM call.
Output parsers help structure language model responses.
"""
def parse_result(self, result: List[Generation]) -> T:
return self.parse(result[0].text)
@abstractmethod
def parse(self, text: str) -> T:
"""Parse the output of an LLM call.
@ -394,6 +403,21 @@ class BaseOutputParser(Serializable, ABC, Generic[T]):
return output_parser_dict
class NoOpOutputParser(BaseOutputParser[str]):
"""Output parser that just returns the text as is."""
@property
def lc_serializable(self) -> bool:
return True
@property
def _type(self) -> str:
return "default"
def parse(self, text: str) -> str:
return text
class OutputParserException(ValueError):
"""Exception that output parsers should raise to signify a parsing error.

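Backwards compatibility in one sketch: existing string-based parsers keep working, because the default parse_result simply parses the text of the first generation (values below are illustrative):

from langchain.schema import Generation, NoOpOutputParser

parser = NoOpOutputParser()
# BaseOutputParser.parse_result delegates to parse(result[0].text),
# so parsers written against the old string interface are unaffected.
parser.parse_result([Generation(text="hello")])  # -> "hello"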