Harrison/refactor functions (#6408)

This commit is contained in:
Harrison Chase 2023-06-18 23:13:42 -07:00 committed by GitHub
parent 6a4a950a3c
commit e9c2b280db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 52 additions and 57 deletions

View File

@@ -218,7 +218,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 11,
"id": "f771df58", "id": "f771df58",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@@ -229,7 +229,7 @@
" Properties(person_name='Claudia', person_height=6, person_hair_color='brunette', dog_breed=None, dog_name=None)]" " Properties(person_name='Claudia', person_height=6, person_hair_color='brunette', dog_breed=None, dog_name=None)]"
] ]
}, },
"execution_count": 10, "execution_count": 11,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }

View File

@@ -86,7 +86,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n" "question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo and was the president of the club for 2 years.', substring_quote=['started the Data Science club at the University of Waterloo', 'president of the club for 2 years'])]\n"
] ]
} }
], ],
@@ -129,12 +129,10 @@
"Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n", "Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n",
"As part of coop I...\n", "As part of coop I...\n",
"\n", "\n",
"Statement: The author started the Data Science club at the University of Waterloo.\n", "Statement: The author started the Data Science club at the University of Waterloo and was the president of the club for 2 years.\n",
"Citation: ...titchfix, Facebook.\n", "Citation: ...x, Facebook.\n",
"*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n", "I also *\u001b[91mstarted the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
"\n", "Citation: ...erloo and I was the *\u001b[91mpresident of the club for 2 years\u001b[0m*.\n",
"Statement: The author was the president of the Data Science club for 2 years.\n",
"Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n",
"...\n", "...\n",
"\n" "\n"
] ]

View File

@@ -126,7 +126,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'es'}" "{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}"
] ]
}, },
"execution_count": 6, "execution_count": 6,

View File

@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs
from langchain.output_parsers.openai_functions import ( from langchain.output_parsers.openai_functions import (
PydanticOutputFunctionsParser, PydanticOutputFunctionsParser,
) )
@@ -65,14 +66,12 @@ class QuestionAnswer(BaseModel):
def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain: def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer) output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
schema = QuestionAnswer.schema() schema = QuestionAnswer.schema()
functions = [ function = {
{ "name": schema["title"],
"name": schema["title"], "description": schema["description"],
"description": schema["description"], "parameters": schema,
"parameters": schema, }
} llm_kwargs = get_llm_kwargs(function)
]
kwargs = {"function_call": {"name": schema["title"]}}
messages = [ messages = [
SystemMessage( SystemMessage(
content=( content=(
@@ -95,7 +94,7 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
chain = LLMChain( chain = LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs={**{"functions": functions}, **kwargs}, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
) )
return chain return chain

View File

@@ -8,6 +8,7 @@ from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import ( from langchain.chains.openai_functions.utils import (
_convert_schema, _convert_schema,
_resolve_schema_references, _resolve_schema_references,
get_llm_kwargs,
) )
from langchain.output_parsers.openai_functions import ( from langchain.output_parsers.openai_functions import (
JsonKeyOutputFunctionsParser, JsonKeyOutputFunctionsParser,
@@ -15,24 +16,19 @@ from langchain.output_parsers.openai_functions import (
) )
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
def _get_extraction_function(entity_schema: dict) -> dict:
def _get_extraction_functions(entity_schema: dict) -> List[dict]: return {
return [ "name": "information_extraction",
{ "description": "Extracts the relevant information from the passage.",
"name": EXTRACTION_NAME, "parameters": {
"description": "Extracts the relevant information from the passage.", "type": "object",
"parameters": { "properties": {
"type": "object", "info": {"type": "array", "items": _convert_schema(entity_schema)}
"properties": {
"info": {"type": "array", "items": _convert_schema(entity_schema)}
},
"required": ["info"],
}, },
} "required": ["info"],
] },
}
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\ _EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
@@ -44,13 +40,14 @@ Passage:
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain: def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_extraction_functions(schema) function = _get_extraction_function(schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = JsonKeyOutputFunctionsParser(key_name="info") output_parser = JsonKeyOutputFunctionsParser(key_name="info")
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain( chain = LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
) )
return chain return chain
@@ -67,15 +64,16 @@ def create_extraction_chain_pydantic(
openai_schema, openai_schema["definitions"] openai_schema, openai_schema["definitions"]
) )
functions = _get_extraction_functions(openai_schema) function = _get_extraction_function(openai_schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = PydanticAttrOutputFunctionsParser( output_parser = PydanticAttrOutputFunctionsParser(
pydantic_schema=PydanticSchema, attr_name="info" pydantic_schema=PydanticSchema, attr_name="info"
) )
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain( chain = LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
) )
return chain return chain

View File

@@ -1,27 +1,22 @@
from typing import Any, List from typing import Any
from langchain.base_language import BaseLanguageModel from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import _convert_schema from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
from langchain.output_parsers.openai_functions import ( from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser, JsonOutputFunctionsParser,
PydanticOutputFunctionsParser, PydanticOutputFunctionsParser,
) )
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
def _get_tagging_function(schema: dict) -> dict:
def _get_tagging_functions(schema: dict) -> List[dict]: return {
return [ "name": "information_extraction",
{ "description": "Extracts the relevant information from the passage.",
"name": EXTRACTION_NAME, "parameters": _convert_schema(schema),
"description": "Extracts the relevant information from the passage.", }
"parameters": _convert_schema(schema),
}
]
_TAGGING_TEMPLATE = """Extract the desired information from the following passage. _TAGGING_TEMPLATE = """Extract the desired information from the following passage.
@@ -32,13 +27,14 @@ Passage:
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain: def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_tagging_functions(schema) function = _get_tagging_function(schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = JsonOutputFunctionsParser() output_parser = JsonOutputFunctionsParser()
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain( chain = LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
) )
return chain return chain
@@ -48,14 +44,14 @@ def create_tagging_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain: ) -> Chain:
openai_schema = pydantic_schema.schema() openai_schema = pydantic_schema.schema()
function = _get_tagging_function(openai_schema)
functions = _get_tagging_functions(openai_schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema) output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain( chain = LLMChain(
llm=llm, llm=llm,
prompt=prompt, prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, llm_kwargs=llm_kwargs,
output_parser=output_parser, output_parser=output_parser,
) )
return chain return chain

View File

@@ -26,3 +26,7 @@ def _convert_schema(schema: dict) -> dict:
"properties": props, "properties": props,
"required": schema.get("required", []), "required": schema.get("required", []),
} }
def get_llm_kwargs(function: dict) -> dict:
    """Build the LLM kwargs that expose *function* and force the model to call it.

    Args:
        function: An OpenAI function definition dict; must contain a "name" key.

    Returns:
        A dict with the single-element "functions" list and a "function_call"
        directive pinned to the function's name.
    """
    forced_call = {"name": function["name"]}
    return {"functions": [function], "function_call": forced_call}