Harrison/refactor functions (#6408)

This commit is contained in:
Harrison Chase 2023-06-18 23:13:42 -07:00 committed by GitHub
parent 6a4a950a3c
commit e9c2b280db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 52 additions and 57 deletions

View File

@ -218,7 +218,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 11,
"id": "f771df58",
"metadata": {},
"outputs": [
@ -229,7 +229,7 @@
" Properties(person_name='Claudia', person_height=6, person_hair_color='brunette', dog_breed=None, dog_name=None)]"
]
},
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}

View File

@ -86,7 +86,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n"
"question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo and was the president of the club for 2 years.', substring_quote=['started the Data Science club at the University of Waterloo', 'president of the club for 2 years'])]\n"
]
}
],
@ -129,12 +129,10 @@
"Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n",
"As part of coop I...\n",
"\n",
"Statement: The author started the Data Science club at the University of Waterloo.\n",
"Citation: ...titchfix, Facebook.\n",
"*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
"\n",
"Statement: The author was the president of the Data Science club for 2 years.\n",
"Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n",
"Statement: The author started the Data Science club at the University of Waterloo and was the president of the club for 2 years.\n",
"Citation: ...x, Facebook.\n",
"I also *\u001b[91mstarted the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
"Citation: ...erloo and I was the *\u001b[91mpresident of the club for 2 years\u001b[0m*.\n",
"...\n",
"\n"
]

View File

@ -126,7 +126,7 @@
{
"data": {
"text/plain": [
"{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'es'}"
"{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}"
]
},
"execution_count": 6,

View File

@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs
from langchain.output_parsers.openai_functions import (
PydanticOutputFunctionsParser,
)
@ -65,14 +66,12 @@ class QuestionAnswer(BaseModel):
def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
schema = QuestionAnswer.schema()
functions = [
{
function = {
"name": schema["title"],
"description": schema["description"],
"parameters": schema,
}
]
kwargs = {"function_call": {"name": schema["title"]}}
llm_kwargs = get_llm_kwargs(function)
messages = [
SystemMessage(
content=(
@ -95,7 +94,7 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **kwargs},
llm_kwargs=llm_kwargs,
output_parser=output_parser,
)
return chain

View File

@ -8,6 +8,7 @@ from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import (
_convert_schema,
_resolve_schema_references,
get_llm_kwargs,
)
from langchain.output_parsers.openai_functions import (
JsonKeyOutputFunctionsParser,
@ -15,14 +16,10 @@ from langchain.output_parsers.openai_functions import (
)
from langchain.prompts import ChatPromptTemplate
EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
def _get_extraction_functions(entity_schema: dict) -> List[dict]:
return [
{
"name": EXTRACTION_NAME,
def _get_extraction_function(entity_schema: dict) -> dict:
return {
"name": "information_extraction",
"description": "Extracts the relevant information from the passage.",
"parameters": {
"type": "object",
@ -32,7 +29,6 @@ def _get_extraction_functions(entity_schema: dict) -> List[dict]:
"required": ["info"],
},
}
]
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
@ -44,13 +40,14 @@ Passage:
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_extraction_functions(schema)
function = _get_extraction_function(schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = JsonKeyOutputFunctionsParser(key_name="info")
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
llm_kwargs=llm_kwargs,
output_parser=output_parser,
)
return chain
@ -67,15 +64,16 @@ def create_extraction_chain_pydantic(
openai_schema, openai_schema["definitions"]
)
functions = _get_extraction_functions(openai_schema)
function = _get_extraction_function(openai_schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
output_parser = PydanticAttrOutputFunctionsParser(
pydantic_schema=PydanticSchema, attr_name="info"
)
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
llm_kwargs=llm_kwargs,
output_parser=output_parser,
)
return chain

View File

@ -1,27 +1,22 @@
from typing import Any, List
from typing import Any
from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import _convert_schema
from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.prompts import ChatPromptTemplate
EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
def _get_tagging_functions(schema: dict) -> List[dict]:
return [
{
"name": EXTRACTION_NAME,
def _get_tagging_function(schema: dict) -> dict:
return {
"name": "information_extraction",
"description": "Extracts the relevant information from the passage.",
"parameters": _convert_schema(schema),
}
]
_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
@ -32,13 +27,14 @@ Passage:
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_tagging_functions(schema)
function = _get_tagging_function(schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = JsonOutputFunctionsParser()
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
llm_kwargs=llm_kwargs,
output_parser=output_parser,
)
return chain
@ -48,14 +44,14 @@ def create_tagging_chain_pydantic(
pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
openai_schema = pydantic_schema.schema()
functions = _get_tagging_functions(openai_schema)
function = _get_tagging_function(openai_schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
llm_kwargs = get_llm_kwargs(function)
chain = LLMChain(
llm=llm,
prompt=prompt,
llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
llm_kwargs=llm_kwargs,
output_parser=output_parser,
)
return chain

View File

@ -26,3 +26,7 @@ def _convert_schema(schema: dict) -> dict:
"properties": props,
"required": schema.get("required", []),
}
def get_llm_kwargs(function: dict) -> dict:
    """Build LLM keyword arguments that force a single OpenAI function call.

    Wraps *function* (an OpenAI function-calling schema with at least a
    ``"name"`` key) into the ``functions`` list and pins ``function_call``
    to that function's name so the model must invoke it.
    """
    function_name = function["name"]
    return {
        "functions": [function],
        "function_call": {"name": function_name},
    }