forked from Archives/langchain

changes to llm chain (#6328)

- return raw and full output (but keep run shortcut method functional)
- change output parser to take in generations (good for working with messages)
- add output parser to base class, always run (default to same as current)

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>

branch: master
parent d3c2eab0b3
commit 6a4a950a3c
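For context, here is a minimal usage sketch (not part of the diff below) of what the change enables: an LLMChain can carry an output parser that consumes the generations themselves, while the `run` shortcut keeps working on top of it. The function schema, prompt, and input text are made up for illustration; `JsonOutputFunctionsParser`, `llm_kwargs`, and `output_parser` are used here the same way the new chain constructors in this commit use them, and the claim that `run` returns the parsed result follows from the description above rather than from tested code.

```python
# Minimal sketch, assumptions noted above.
from langchain.chains.llm import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate

# Illustrative OpenAI-functions schema: pull a single "info" string out of a passage.
functions = [
    {
        "name": "information_extraction",
        "description": "Extracts the relevant information from the passage.",
        "parameters": {
            "type": "object",
            "properties": {"info": {"type": "string"}},
            "required": ["info"],
        },
    }
]

chain = LLMChain(
    llm=ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613"),
    prompt=ChatPromptTemplate.from_template("Passage:\n{input}"),
    llm_kwargs={
        "functions": functions,
        "function_call": {"name": "information_extraction"},
    },
    # The parser receives the generations (chat messages), not just text,
    # so it can read the function_call arguments directly.
    output_parser=JsonOutputFunctionsParser(),
)

# The run shortcut still works; with the parser attached it should return the
# parsed dict, e.g. {"info": "..."}.
parsed = chain.run(input="LangChain is a framework for building LLM applications.")
```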
@@ -0,0 +1,181 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9b5c258f",
   "metadata": {},
   "source": [
    "# Question-Answering Citations\n",
    "\n",
"This notebook shows how to use OpenAI functions ability to extract citations from text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "eae4ca3e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from langchain.chains import create_citation_fuzzy_match_chain\n",
    "from langchain.chat_models import ChatOpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2c6e62ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"What did the author do during college?\"\n",
    "context = \"\"\"\n",
    "My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.\n",
    "I went to an arts highschool but in university I studied Computational Mathematics and physics. \n",
    "As part of coop I worked at many companies including Stitchfix, Facebook.\n",
    "I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "078e0300",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "02cad6d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = create_citation_fuzzy_match_chain(llm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e3c6e7ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = chain.run(question=question, context=context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6f7615f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n"
     ]
    }
   ],
   "source": [
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3be6f366",
   "metadata": {},
   "outputs": [],
   "source": [
    "def highlight(text, span):\n",
    "    return (\n",
    "        \"...\"\n",
    "        + text[span[0] - 20 : span[0]]\n",
    "        + \"*\"\n",
    "        + \"\\033[91m\"\n",
    "        + text[span[0] : span[1]]\n",
    "        + \"\\033[0m\"\n",
    "        + \"*\"\n",
    "        + text[span[1] : span[1] + 20]\n",
    "        + \"...\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "636c4528",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Statement: The author studied Computational Mathematics and physics in university.\n",
      "Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n",
      "As part of coop I...\n",
      "\n",
      "Statement: The author started the Data Science club at the University of Waterloo.\n",
      "Citation: ...titchfix, Facebook.\n",
      "*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
      "\n",
      "Statement: The author was the president of the Data Science club for 2 years.\n",
      "Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n",
      "...\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for fact in result.answer:\n",
    "    print(\"Statement:\", fact.fact)\n",
    "    for span in fact.get_spans(context):\n",
    "        print(\"Citation:\", highlight(context, span))\n",
    "        print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8409cab0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
@@ -1,233 +0,0 @@
import json
from functools import partial
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.chains.sequential import SimpleSequentialChain
from langchain.chains.transform import TransformChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate

EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}


def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
    """
    Resolves the $ref keys in a JSON schema object using the provided definitions.
    """
    if isinstance(schema, list):
        for i, item in enumerate(schema):
            schema[i] = _resolve_schema_references(item, definitions)
    elif isinstance(schema, dict):
        if "$ref" in schema:
            ref_key = schema.pop("$ref").split("/")[-1]
            ref = definitions.get(ref_key, {})
            schema.update(ref)
        else:
            for key, value in schema.items():
                schema[key] = _resolve_schema_references(value, definitions)
    return schema


def _get_function_arguments(inputs: dict) -> str:
    message = inputs["input"]
    try:
        func_call = message.additional_kwargs["function_call"]
    except ValueError as exc:
        raise ValueError(f"Could not parse function call: {exc}")

    return func_call["arguments"]


def _parse_tag(inputs: dict) -> dict:
    args = _get_function_arguments(inputs)
    return {"output": json.loads(args)}


def _parse_tag_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
    args = _get_function_arguments(inputs)
    args = pydantic_schema.parse_raw(args)
    return {"output": args}


def _parse_entities(inputs: dict) -> dict:
    args = _get_function_arguments(inputs)
    return {"output": json.loads(args)["info"]}


def _parse_entities_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
    args = _get_function_arguments(inputs)
    pydantic_args = pydantic_schema.parse_raw(args)
    return {"output": pydantic_args.info}


class OpenAIFunctionsChain(Chain):
    prompt: BasePromptTemplate
    llm: BaseLanguageModel
    functions: List[Dict]
    kwargs: Dict = Field(default_factory=dict)

    @property
    def input_keys(self) -> List[str]:
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        return ["output"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**_inputs)
        messages = prompt.to_messages()
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks, **self.kwargs
        )
        return {"output": predicted_message}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**_inputs)
        messages = prompt.to_messages()
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        predicted_message = await self.llm.apredict_messages(
            messages, functions=self.functions, callbacks=callbacks, **self.kwargs
        )
        return {"output": predicted_message}


def _convert_schema(schema: dict) -> dict:
    props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
    return {
        "type": "object",
        "properties": props,
        "required": schema.get("required", []),
    }


def _get_extraction_functions(entity_schema: dict) -> List[dict]:
    return [
        {
            "name": EXTRACTION_NAME,
            "description": "Extracts the relevant information from the passage.",
            "parameters": {
                "type": "object",
                "properties": {
                    "info": {"type": "array", "items": _convert_schema(entity_schema)}
                },
                "required": ["info"],
            },
        }
    ]


def _get_tagging_functions(schema: dict) -> List[dict]:
    return [
        {
            "name": EXTRACTION_NAME,
            "description": "Extracts the relevant information from the passage.",
            "parameters": _convert_schema(schema),
        }
    ]


_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
in the following passage together with their properties.

Passage:
{input}
"""


def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    functions = _get_extraction_functions(schema)
    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    chain = OpenAIFunctionsChain(
        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
    )
    parsing_chain = TransformChain(
        transform=_parse_entities,
        input_variables=["input"],
        output_variables=["output"],
    )
    return SimpleSequentialChain(chains=[chain, parsing_chain])


def create_extraction_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    class PydanticSchema(BaseModel):
        info: List[pydantic_schema]  # type: ignore

    openai_schema = PydanticSchema.schema()
    openai_schema = _resolve_schema_references(
        openai_schema, openai_schema["definitions"]
    )

    functions = _get_extraction_functions(openai_schema)
    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    chain = OpenAIFunctionsChain(
        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
    )
    pydantic_parsing_chain = TransformChain(
        transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema),
        input_variables=["input"],
        output_variables=["output"],
    )
    return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])


_TAGGING_TEMPLATE = """Extract the desired information from the following passage.

Passage:
{input}
"""


def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    functions = _get_tagging_functions(schema)
    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    chain = OpenAIFunctionsChain(
        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
    )
    parsing_chain = TransformChain(
        transform=_parse_tag, input_variables=["input"], output_variables=["output"]
    )
    return SimpleSequentialChain(chains=[chain, parsing_chain])


def create_tagging_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    openai_schema = pydantic_schema.schema()

    functions = _get_tagging_functions(openai_schema)
    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    chain = OpenAIFunctionsChain(
        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
    )
    pydantic_parsing_chain = TransformChain(
        transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema),
        input_variables=["input"],
        output_variables=["output"],
    )

    return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])
@@ -0,0 +1,19 @@
from langchain.chains.openai_functions.citation_fuzzy_match import (
    create_citation_fuzzy_match_chain,
)
from langchain.chains.openai_functions.extraction import (
    create_extraction_chain,
    create_extraction_chain_pydantic,
)
from langchain.chains.openai_functions.tagging import (
    create_tagging_chain,
    create_tagging_chain_pydantic,
)

__all__ = [
    "create_tagging_chain",
    "create_tagging_chain_pydantic",
    "create_extraction_chain_pydantic",
    "create_extraction_chain",
    "create_citation_fuzzy_match_chain",
]
@@ -0,0 +1,101 @@
from typing import Iterator, List

from pydantic import BaseModel, Field

from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.output_parsers.openai_functions import (
    PydanticOutputFunctionsParser,
)
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.schema import HumanMessage, SystemMessage


class FactWithEvidence(BaseModel):
"""Class representing single statement.
|
||||||
|
|
||||||
|
Each fact has a body and a list of sources.
|
||||||
|
If there are multiple facts make sure to break them apart
|
||||||
|
such that each one only uses a set of sources that are relevant to it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
    fact: str = Field(..., description="Body of the sentence, as part of a response")
    substring_quote: List[str] = Field(
        ...,
        description=(
            "Each source should be a direct quote from the context, "
            "as a substring of the original content"
        ),
    )

    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[str]:
        import regex

        minor = quote
        major = context

        errs_ = 0
        s = regex.search(f"({minor}){{e<={errs_}}}", major)
        while s is None and errs_ <= errs:
            errs_ += 1
            s = regex.search(f"({minor}){{e<={errs_}}}", major)

        if s is not None:
            yield from s.spans()

    def get_spans(self, context: str) -> Iterator[str]:
        for quote in self.substring_quote:
            yield from self._get_span(quote, context)


class QuestionAnswer(BaseModel):
"""A question and its answer as a list of facts each one should have a source.
|
||||||
|
each sentence contains a body and a list of sources."""
|
||||||
|
|
||||||
|
    question: str = Field(..., description="Question that was asked")
    answer: List[FactWithEvidence] = Field(
        ...,
        description=(
            "Body of the answer, each fact should be "
            "its separate object with a body and a list of sources"
        ),
    )


def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
    output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
    schema = QuestionAnswer.schema()
    functions = [
        {
            "name": schema["title"],
            "description": schema["description"],
            "parameters": schema,
        }
    ]
    kwargs = {"function_call": {"name": schema["title"]}}
    messages = [
        SystemMessage(
            content=(
                "You are a world class algorithm to answer "
                "questions with correct and exact citations."
            )
        ),
        HumanMessage(content="Answer question using the following context"),
        HumanMessagePromptTemplate.from_template("{context}"),
        HumanMessagePromptTemplate.from_template("Question: {question}"),
        HumanMessage(
            content=(
                "Tips: Make sure to cite your sources, "
                "and use the exact words from the context."
            )
        ),
    ]
    prompt = ChatPromptTemplate(messages=messages)

    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **kwargs},
        output_parser=output_parser,
    )
    return chain
@@ -0,0 +1,81 @@
from typing import Any, List

from pydantic import BaseModel

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import (
    _convert_schema,
    _resolve_schema_references,
)
from langchain.output_parsers.openai_functions import (
    JsonKeyOutputFunctionsParser,
    PydanticAttrOutputFunctionsParser,
)
from langchain.prompts import ChatPromptTemplate

EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}


def _get_extraction_functions(entity_schema: dict) -> List[dict]:
    return [
        {
            "name": EXTRACTION_NAME,
            "description": "Extracts the relevant information from the passage.",
            "parameters": {
                "type": "object",
                "properties": {
                    "info": {"type": "array", "items": _convert_schema(entity_schema)}
                },
                "required": ["info"],
            },
        }
    ]


_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
in the following passage together with their properties.

Passage:
{input}
"""


def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    functions = _get_extraction_functions(schema)
    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    output_parser = JsonKeyOutputFunctionsParser(key_name="info")
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
        output_parser=output_parser,
    )
    return chain


def create_extraction_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    class PydanticSchema(BaseModel):
        info: List[pydantic_schema]  # type: ignore

    openai_schema = PydanticSchema.schema()
    openai_schema = _resolve_schema_references(
        openai_schema, openai_schema["definitions"]
    )

    functions = _get_extraction_functions(openai_schema)
    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    output_parser = PydanticAttrOutputFunctionsParser(
        pydantic_schema=PydanticSchema, attr_name="info"
    )
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
        output_parser=output_parser,
    )
    return chain
@@ -0,0 +1,61 @@
from typing import Any, List

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import _convert_schema
from langchain.output_parsers.openai_functions import (
    JsonOutputFunctionsParser,
    PydanticOutputFunctionsParser,
)
from langchain.prompts import ChatPromptTemplate

EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}


def _get_tagging_functions(schema: dict) -> List[dict]:
    return [
        {
            "name": EXTRACTION_NAME,
            "description": "Extracts the relevant information from the passage.",
            "parameters": _convert_schema(schema),
        }
    ]


_TAGGING_TEMPLATE = """Extract the desired information from the following passage.

Passage:
{input}
"""


def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    functions = _get_tagging_functions(schema)
    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    output_parser = JsonOutputFunctionsParser()
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
        output_parser=output_parser,
    )
    return chain


def create_tagging_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    openai_schema = pydantic_schema.schema()

    functions = _get_tagging_functions(openai_schema)
    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
        output_parser=output_parser,
    )
    return chain
@@ -0,0 +1,28 @@
from typing import Any, Dict


def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
    """
    Resolves the $ref keys in a JSON schema object using the provided definitions.
    """
    if isinstance(schema, list):
        for i, item in enumerate(schema):
            schema[i] = _resolve_schema_references(item, definitions)
    elif isinstance(schema, dict):
        if "$ref" in schema:
            ref_key = schema.pop("$ref").split("/")[-1]
            ref = definitions.get(ref_key, {})
            schema.update(ref)
        else:
            for key, value in schema.items():
                schema[key] = _resolve_schema_references(value, definitions)
    return schema


def _convert_schema(schema: dict) -> dict:
    props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
    return {
        "type": "object",
        "properties": props,
        "required": schema.get("required", []),
    }
@@ -0,0 +1,51 @@
import json
from typing import Any, List

from langchain.schema import BaseLLMOutputParser, ChatGeneration, Generation


class OutputFunctionsParser(BaseLLMOutputParser[Any]):
    def parse_result(self, result: List[Generation]) -> Any:
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise ValueError(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            func_call = message.additional_kwargs["function_call"]
        except ValueError as exc:
            raise ValueError(f"Could not parse function call: {exc}")

        return func_call["arguments"]


class JsonOutputFunctionsParser(OutputFunctionsParser):
    def parse_result(self, result: List[Generation]) -> Any:
        _args = super().parse_result(result)
        return json.loads(_args)


class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
    key_name: str

    def parse_result(self, result: List[Generation]) -> Any:
        res = super().parse_result(result)
        return res[self.key_name]


class PydanticOutputFunctionsParser(OutputFunctionsParser):
    pydantic_schema: Any

    def parse_result(self, result: List[Generation]) -> Any:
        _args = super().parse_result(result)
        pydantic_args = self.pydantic_schema.parse_raw(_args)
        return pydantic_args


class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
    attr_name: str

    def parse_result(self, result: List[Generation]) -> Any:
        result = super().parse_result(result)
        return getattr(result, self.attr_name)