mirror of https://github.com/hwchase17/langchain
changes to llm chain (#6328)
- return raw and full output (but keep run shortcut method functional) - change output parser to take in generations (good for working with messages) - add output parser to base class, always run (default to same as current) --------- Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>pull/6408/head
parent
d3c2eab0b3
commit
6a4a950a3c
@ -0,0 +1,181 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9b5c258f",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Question-Answering Citations\n",
|
||||
"\n",
|
||||
"This notebook shows how to use OpenAI functions ability to extract citations from text."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "eae4ca3e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/Users/harrisonchase/.pyenv/versions/3.9.1/envs/langchain/lib/python3.9/site-packages/deeplake/util/check_latest_version.py:32: UserWarning: A newer version of deeplake (3.6.4) is available. It's recommended that you update to the latest version using `pip install -U deeplake`.\n",
|
||||
" warnings.warn(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain.chains import create_citation_fuzzy_match_chain\n",
|
||||
"from langchain.chat_models import ChatOpenAI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "2c6e62ee",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"What did the author do during college?\"\n",
|
||||
"context = \"\"\"\n",
|
||||
"My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.\n",
|
||||
"I went to an arts highschool but in university I studied Computational Mathematics and physics. \n",
|
||||
"As part of coop I worked at many companies including Stitchfix, Facebook.\n",
|
||||
"I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.\n",
|
||||
"\"\"\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "078e0300",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo-0613\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "02cad6d0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"chain = create_citation_fuzzy_match_chain(llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "e3c6e7ba",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"result = chain.run(question=question, context=context)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "6f7615f2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "3be6f366",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def highlight(text, span):\n",
|
||||
" return (\n",
|
||||
" \"...\"\n",
|
||||
" + text[span[0] - 20 : span[0]]\n",
|
||||
" + \"*\"\n",
|
||||
" + \"\\033[91m\"\n",
|
||||
" + text[span[0] : span[1]]\n",
|
||||
" + \"\\033[0m\"\n",
|
||||
" + \"*\"\n",
|
||||
" + text[span[1] : span[1] + 20]\n",
|
||||
" + \"...\"\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "636c4528",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Statement: The author studied Computational Mathematics and physics in university.\n",
|
||||
"Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n",
|
||||
"As part of coop I...\n",
|
||||
"\n",
|
||||
"Statement: The author started the Data Science club at the University of Waterloo.\n",
|
||||
"Citation: ...titchfix, Facebook.\n",
|
||||
"*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
|
||||
"\n",
|
||||
"Statement: The author was the president of the Data Science club for 2 years.\n",
|
||||
"Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n",
|
||||
"...\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for fact in result.answer:\n",
|
||||
" print(\"Statement:\", fact.fact)\n",
|
||||
" for span in fact.get_spans(context):\n",
|
||||
" print(\"Citation:\", highlight(context, span))\n",
|
||||
" print()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "8409cab0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -1,233 +0,0 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SimpleSequentialChain
|
||||
from langchain.chains.transform import TransformChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
EXTRACTION_NAME = "information_extraction"
|
||||
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
|
||||
|
||||
|
||||
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Resolves the $ref keys in a JSON schema object using the provided definitions.
|
||||
"""
|
||||
if isinstance(schema, list):
|
||||
for i, item in enumerate(schema):
|
||||
schema[i] = _resolve_schema_references(item, definitions)
|
||||
elif isinstance(schema, dict):
|
||||
if "$ref" in schema:
|
||||
ref_key = schema.pop("$ref").split("/")[-1]
|
||||
ref = definitions.get(ref_key, {})
|
||||
schema.update(ref)
|
||||
else:
|
||||
for key, value in schema.items():
|
||||
schema[key] = _resolve_schema_references(value, definitions)
|
||||
return schema
|
||||
|
||||
|
||||
def _get_function_arguments(inputs: dict) -> str:
|
||||
message = inputs["input"]
|
||||
try:
|
||||
func_call = message.additional_kwargs["function_call"]
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Could not parse function call: {exc}")
|
||||
|
||||
return func_call["arguments"]
|
||||
|
||||
|
||||
def _parse_tag(inputs: dict) -> dict:
    """Decode the function-call arguments of a tagging response into a dict."""
    raw_args = _get_function_arguments(inputs)
    return {"output": json.loads(raw_args)}
|
||||
|
||||
|
||||
def _parse_tag_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
    """Decode function-call arguments into an instance of *pydantic_schema*."""
    raw_args = _get_function_arguments(inputs)
    parsed = pydantic_schema.parse_raw(raw_args)
    return {"output": parsed}
|
||||
|
||||
|
||||
def _parse_entities(inputs: dict) -> dict:
    """Decode extraction results, returning only the ``info`` list."""
    raw_args = _get_function_arguments(inputs)
    decoded = json.loads(raw_args)
    return {"output": decoded["info"]}
|
||||
|
||||
|
||||
def _parse_entities_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
    """Decode extraction results into *pydantic_schema*; return its ``info``."""
    raw_args = _get_function_arguments(inputs)
    parsed = pydantic_schema.parse_raw(raw_args)
    return {"output": parsed.info}
|
||||
|
||||
|
||||
class OpenAIFunctionsChain(Chain):
    """Chain that calls an LLM with OpenAI function definitions attached.

    The prompt is formatted into messages, sent to ``llm`` together with
    ``functions``, and the raw predicted message is returned under the
    ``"output"`` key; parsing is left to a downstream chain.
    """

    # Prompt whose input variables define this chain's input keys.
    prompt: BasePromptTemplate
    # Model used for prediction; must support ``predict_messages``.
    llm: BaseLanguageModel
    # OpenAI function definitions forwarded on every call.
    functions: List[Dict]
    # Extra keyword arguments for the model (e.g. a forced ``function_call``).
    kwargs: Dict = Field(default_factory=dict)

    @property
    def input_keys(self) -> List[str]:
        """Input keys are exactly the prompt's input variables."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Single output key holding the predicted message."""
        return ["output"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Format the prompt, call the model with functions, return the message."""
        # Drop any extra inputs the prompt does not reference.
        _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**_inputs)
        messages = prompt.to_messages()
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        predicted_message = self.llm.predict_messages(
            messages, functions=self.functions, callbacks=callbacks, **self.kwargs
        )
        return {"output": predicted_message}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Async counterpart of ``_call``."""
        _inputs = {k: v for k, v in inputs.items() if k in self.prompt.input_variables}
        prompt = self.prompt.format_prompt(**_inputs)
        messages = prompt.to_messages()
        _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        callbacks = _run_manager.get_child()
        predicted_message = await self.llm.apredict_messages(
            messages, functions=self.functions, callbacks=callbacks, **self.kwargs
        )
        return {"output": predicted_message}
|
||||
|
||||
|
||||
def _convert_schema(schema: dict) -> dict:
|
||||
props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": schema.get("required", []),
|
||||
}
|
||||
|
||||
|
||||
def _get_extraction_functions(entity_schema: dict) -> List[dict]:
    """Build the OpenAI function spec for extracting a list of entities."""
    parameters = {
        "type": "object",
        "properties": {
            "info": {"type": "array", "items": _convert_schema(entity_schema)}
        },
        "required": ["info"],
    }
    return [
        {
            "name": EXTRACTION_NAME,
            "description": "Extracts the relevant information from the passage.",
            "parameters": parameters,
        }
    ]
|
||||
|
||||
|
||||
def _get_tagging_functions(schema: dict) -> List[dict]:
    """Build the OpenAI function spec for tagging a passage with *schema*."""
    function = {
        "name": EXTRACTION_NAME,
        "description": "Extracts the relevant information from the passage.",
        "parameters": _convert_schema(schema),
    }
    return [function]
|
||||
|
||||
|
||||
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
|
||||
in the following passage together with their properties.
|
||||
|
||||
Passage:
|
||||
{input}
|
||||
"""
|
||||
|
||||
|
||||
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    """Create a chain that extracts a list of entities matching *schema*.

    The first link calls the LLM with an extraction function definition;
    the second parses the function-call arguments into plain dicts.
    """
    functions = _get_extraction_functions(schema)
    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    chain = OpenAIFunctionsChain(
        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
    )
    parsing_chain = TransformChain(
        transform=_parse_entities,
        input_variables=["input"],
        output_variables=["output"],
    )
    return SimpleSequentialChain(chains=[chain, parsing_chain])
|
||||
|
||||
|
||||
def create_extraction_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    """Create an extraction chain whose results are validated pydantic objects."""

    # Wrapper model so the function schema is {"info": [entity, ...]}.
    class PydanticSchema(BaseModel):
        info: List[pydantic_schema]  # type: ignore

    openai_schema = PydanticSchema.schema()
    # Inline $ref definitions so the function schema is self-contained.
    openai_schema = _resolve_schema_references(
        openai_schema, openai_schema["definitions"]
    )

    functions = _get_extraction_functions(openai_schema)
    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    chain = OpenAIFunctionsChain(
        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
    )
    pydantic_parsing_chain = TransformChain(
        transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema),
        input_variables=["input"],
        output_variables=["output"],
    )
    return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])
|
||||
|
||||
|
||||
_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
|
||||
|
||||
Passage:
|
||||
{input}
|
||||
"""
|
||||
|
||||
|
||||
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    """Create a two-step chain that tags a passage per *schema*.

    The LLM step forces a function call; the transform step decodes the
    call's JSON arguments into a plain dict.
    """
    functions = _get_tagging_functions(schema)
    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    chain = OpenAIFunctionsChain(
        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
    )
    parsing_chain = TransformChain(
        transform=_parse_tag, input_variables=["input"], output_variables=["output"]
    )
    return SimpleSequentialChain(chains=[chain, parsing_chain])
|
||||
|
||||
|
||||
def create_tagging_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    """Create a two-step tagging chain returning a *pydantic_schema* instance."""
    openai_schema = pydantic_schema.schema()

    functions = _get_tagging_functions(openai_schema)
    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    chain = OpenAIFunctionsChain(
        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
    )
    pydantic_parsing_chain = TransformChain(
        transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema),
        input_variables=["input"],
        output_variables=["output"],
    )

    return SimpleSequentialChain(chains=[chain, pydantic_parsing_chain])
|
@ -0,0 +1,19 @@
|
||||
from langchain.chains.openai_functions.citation_fuzzy_match import (
|
||||
create_citation_fuzzy_match_chain,
|
||||
)
|
||||
from langchain.chains.openai_functions.extraction import (
|
||||
create_extraction_chain,
|
||||
create_extraction_chain_pydantic,
|
||||
)
|
||||
from langchain.chains.openai_functions.tagging import (
|
||||
create_tagging_chain,
|
||||
create_tagging_chain_pydantic,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_tagging_chain",
|
||||
"create_tagging_chain_pydantic",
|
||||
"create_extraction_chain_pydantic",
|
||||
"create_extraction_chain",
|
||||
"create_citation_fuzzy_match_chain",
|
||||
]
|
@ -0,0 +1,101 @@
|
||||
from typing import Iterator, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.output_parsers.openai_functions import (
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
|
||||
# NOTE: the class docstring and field descriptions below are serialized into
# the OpenAI function schema (via ``QuestionAnswer.schema()``) and reach the
# model verbatim -- edit them as prompt text, not just documentation.
class FactWithEvidence(BaseModel):
    """Class representing single statement.

    Each fact has a body and a list of sources.
    If there are multiple facts make sure to break them apart
    such that each one only uses a set of sources that are relevant to it.
    """

    fact: str = Field(..., description="Body of the sentence, as part of a response")
    substring_quote: List[str] = Field(
        ...,
        description=(
            "Each source should be a direct quote from the context, "
            "as a substring of the original content"
        ),
    )

    def _get_span(self, quote: str, context: str, errs: int = 100) -> Iterator[tuple]:
        """Yield ``(start, end)`` offsets where *quote* fuzzily matches *context*.

        Matching starts exact and relaxes one edit at a time (up to *errs*
        edits) until the first match; yields nothing if no match is found
        within the budget.
        """
        # Local import: the third-party ``regex`` package is only needed here.
        import regex

        minor = quote
        major = context

        # ``{e<=n}`` is the regex module's fuzzy-match syntax (at most n edits).
        # NOTE(review): ``minor`` is interpolated unescaped, so a quote that
        # contains regex metacharacters may fail to match -- confirm inputs.
        errs_ = 0
        s = regex.search(f"({minor}){{e<={errs_}}}", major)
        while s is None and errs_ <= errs:
            errs_ += 1
            s = regex.search(f"({minor}){{e<={errs_}}}", major)

        if s is not None:
            # spans() returns (start, end) character offsets, not strings.
            yield from s.spans()

    def get_spans(self, context: str) -> Iterator[tuple]:
        """Yield spans for every stored quote found in *context*."""
        for quote in self.substring_quote:
            yield from self._get_span(quote, context)
|
||||
|
||||
|
||||
# NOTE: the docstring and field descriptions are serialized into the OpenAI
# function schema and act as prompt text for the model -- leave wording
# changes to deliberate prompt edits.
class QuestionAnswer(BaseModel):
    """A question and its answer as a list of facts each one should have a source.
    each sentence contains a body and a list of sources."""

    question: str = Field(..., description="Question that was asked")
    answer: List[FactWithEvidence] = Field(
        ...,
        description=(
            "Body of the answer, each fact should be "
            "its separate object with a body and a list of sources"
        ),
    )
|
||||
|
||||
|
||||
def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
    """Create a question-answering chain that returns answers with citations.

    The model is forced (via ``function_call``) to respond with a
    ``QuestionAnswer`` function call, which the output parser turns back
    into a ``QuestionAnswer`` pydantic object.

    Args:
        llm: Language model to use; must support OpenAI function calling.

    Returns:
        An ``LLMChain`` expecting ``context`` and ``question`` inputs.
    """
    output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
    schema = QuestionAnswer.schema()
    # The pydantic schema doubles as the OpenAI function definition: its
    # title, description, and fields are exactly what the model sees.
    functions = [
        {
            "name": schema["title"],
            "description": schema["description"],
            "parameters": schema,
        }
    ]
    # Force the model to call the function instead of replying in free text.
    kwargs = {"function_call": {"name": schema["title"]}}
    messages = [
        SystemMessage(
            content=(
                "You are a world class algorithm to answer "
                "questions with correct and exact citations."
            )
        ),
        HumanMessage(content="Answer question using the following context"),
        HumanMessagePromptTemplate.from_template("{context}"),
        HumanMessagePromptTemplate.from_template("Question: {question}"),
        HumanMessage(
            content=(
                "Tips: Make sure to cite your sources, "
                "and use the exact words from the context."
            )
        ),
    ]
    prompt = ChatPromptTemplate(messages=messages)

    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **kwargs},
        output_parser=output_parser,
    )
    return chain
|
@ -0,0 +1,81 @@
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.openai_functions.utils import (
|
||||
_convert_schema,
|
||||
_resolve_schema_references,
|
||||
)
|
||||
from langchain.output_parsers.openai_functions import (
|
||||
JsonKeyOutputFunctionsParser,
|
||||
PydanticAttrOutputFunctionsParser,
|
||||
)
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
EXTRACTION_NAME = "information_extraction"
|
||||
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
|
||||
|
||||
|
||||
def _get_extraction_functions(entity_schema: dict) -> List[dict]:
    """Build the OpenAI function spec for extracting a list of entities."""
    items_schema = _convert_schema(entity_schema)
    return [
        {
            "name": EXTRACTION_NAME,
            "description": "Extracts the relevant information from the passage.",
            "parameters": {
                "type": "object",
                "properties": {
                    "info": {"type": "array", "items": items_schema}
                },
                "required": ["info"],
            },
        }
    ]
|
||||
|
||||
|
||||
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
|
||||
in the following passage together with their properties.
|
||||
|
||||
Passage:
|
||||
{input}
|
||||
"""
|
||||
|
||||
|
||||
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    """Create a chain that extracts entities matching *schema* from a passage.

    Args:
        schema: JSON-schema-like dict with ``properties`` (and optionally
            ``required``) describing one entity.
        llm: Model supporting OpenAI function calling.

    Returns:
        An ``LLMChain`` taking ``input`` text and producing the list of
        extracted entities (plain dicts).
    """
    functions = _get_extraction_functions(schema)
    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    # The model returns {"info": [...]}; the parser unwraps the "info" key.
    output_parser = JsonKeyOutputFunctionsParser(key_name="info")
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
        output_parser=output_parser,
    )
    return chain
|
||||
|
||||
|
||||
def create_extraction_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    """Create an extraction chain whose results are *pydantic_schema* instances.

    Args:
        pydantic_schema: Pydantic model class describing one entity.
        llm: Model supporting OpenAI function calling.

    Returns:
        An ``LLMChain`` taking ``input`` text and producing a list of
        validated ``pydantic_schema`` objects.
    """

    # Wrapper model so the function schema is {"info": [entity, ...]}.
    class PydanticSchema(BaseModel):
        info: List[pydantic_schema]  # type: ignore

    openai_schema = PydanticSchema.schema()
    # Inline $ref definitions: the OpenAI function schema must be self-contained.
    openai_schema = _resolve_schema_references(
        openai_schema, openai_schema["definitions"]
    )

    functions = _get_extraction_functions(openai_schema)
    prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    output_parser = PydanticAttrOutputFunctionsParser(
        pydantic_schema=PydanticSchema, attr_name="info"
    )
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
        output_parser=output_parser,
    )
    return chain
|
@ -0,0 +1,61 @@
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.openai_functions.utils import _convert_schema
|
||||
from langchain.output_parsers.openai_functions import (
|
||||
JsonOutputFunctionsParser,
|
||||
PydanticOutputFunctionsParser,
|
||||
)
|
||||
from langchain.prompts import ChatPromptTemplate
|
||||
|
||||
EXTRACTION_NAME = "information_extraction"
|
||||
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
|
||||
|
||||
|
||||
def _get_tagging_functions(schema: dict) -> List[dict]:
    """Build the OpenAI function spec for tagging a passage with *schema*."""
    function = {
        "name": EXTRACTION_NAME,
        "description": "Extracts the relevant information from the passage.",
        "parameters": _convert_schema(schema),
    }
    return [function]
|
||||
|
||||
|
||||
_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
|
||||
|
||||
Passage:
|
||||
{input}
|
||||
"""
|
||||
|
||||
|
||||
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    """Create a chain that tags a passage with the fields described in *schema*.

    Returns an ``LLMChain`` taking ``input`` text and producing a dict of
    tagged values decoded from the model's function-call arguments.
    """
    functions = _get_tagging_functions(schema)
    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    output_parser = JsonOutputFunctionsParser()
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
        output_parser=output_parser,
    )
    return chain
|
||||
|
||||
|
||||
def create_tagging_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    """Create a tagging chain whose output is a validated *pydantic_schema* instance."""
    openai_schema = pydantic_schema.schema()

    functions = _get_tagging_functions(openai_schema)
    prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
        output_parser=output_parser,
    )
    return chain
|
@ -0,0 +1,28 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
Resolves the $ref keys in a JSON schema object using the provided definitions.
|
||||
"""
|
||||
if isinstance(schema, list):
|
||||
for i, item in enumerate(schema):
|
||||
schema[i] = _resolve_schema_references(item, definitions)
|
||||
elif isinstance(schema, dict):
|
||||
if "$ref" in schema:
|
||||
ref_key = schema.pop("$ref").split("/")[-1]
|
||||
ref = definitions.get(ref_key, {})
|
||||
schema.update(ref)
|
||||
else:
|
||||
for key, value in schema.items():
|
||||
schema[key] = _resolve_schema_references(value, definitions)
|
||||
return schema
|
||||
|
||||
|
||||
def _convert_schema(schema: dict) -> dict:
|
||||
props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": props,
|
||||
"required": schema.get("required", []),
|
||||
}
|
@ -0,0 +1,51 @@
|
||||
import json
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.schema import BaseLLMOutputParser, ChatGeneration, Generation
|
||||
|
||||
|
||||
class OutputFunctionsParser(BaseLLMOutputParser[Any]):
    """Extract the raw function-call ``arguments`` string from a chat generation."""

    def parse_result(self, result: List[Generation]) -> Any:
        """Return the JSON ``arguments`` of the first generation's function call.

        Raises:
            ValueError: If the generation is not a chat generation, or if it
                carries no function call.
        """
        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            raise ValueError(
                "This output parser can only be used with a chat generation."
            )
        message = generation.message
        try:
            # Bug fix: a missing key raises KeyError, not ValueError --
            # catching the wrong type let the bare KeyError escape without
            # the intended explanatory message.
            func_call = message.additional_kwargs["function_call"]
        except KeyError as exc:
            raise ValueError(f"Could not parse function call: {exc}") from exc

        return func_call["arguments"]
|
||||
|
||||
|
||||
class JsonOutputFunctionsParser(OutputFunctionsParser):
    """Parse the function-call arguments string as JSON."""

    def parse_result(self, result: List[Generation]) -> Any:
        raw_args = super().parse_result(result)
        return json.loads(raw_args)
|
||||
|
||||
|
||||
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
    """Parse the JSON arguments and return a single key from the result."""

    # Name of the key to pluck from the decoded JSON object.
    key_name: str

    def parse_result(self, result: List[Generation]) -> Any:
        decoded = super().parse_result(result)
        return decoded[self.key_name]
|
||||
|
||||
|
||||
class PydanticOutputFunctionsParser(OutputFunctionsParser):
    """Parse the function-call arguments into a pydantic model instance."""

    # Pydantic model class used to validate/parse the raw JSON arguments.
    pydantic_schema: Any

    def parse_result(self, result: List[Generation]) -> Any:
        raw_args = super().parse_result(result)
        return self.pydantic_schema.parse_raw(raw_args)
|
||||
|
||||
|
||||
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
    """Parse into a pydantic model, then return one attribute of it."""

    # Attribute to read off the parsed model.
    attr_name: str

    def parse_result(self, result: List[Generation]) -> Any:
        # Fresh name: the original rebound the ``result`` parameter, which
        # shadowed its List[Generation] annotation.
        parsed = super().parse_result(result)
        return getattr(parsed, self.attr_name)
|
Loading…
Reference in New Issue