Mirror of https://github.com/hwchase17/langchain, synced 2024-11-18 09:25:54 +00:00
Harrison/refactor functions (#6408)
This commit is contained in:
parent 6a4a950a3c
commit e9c2b280db
@@ -218,7 +218,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": 11,
     "id": "f771df58",
     "metadata": {},
     "outputs": [
@@ -229,7 +229,7 @@
      " Properties(person_name='Claudia', person_height=6, person_hair_color='brunette', dog_breed=None, dog_name=None)]"
     ]
    },
-   "execution_count": 10,
+   "execution_count": 11,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -86,7 +86,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n"
+    "question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo and was the president of the club for 2 years.', substring_quote=['started the Data Science club at the University of Waterloo', 'president of the club for 2 years'])]\n"
    ]
   }
  ],
@@ -129,12 +129,10 @@
     "Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n",
     "As part of coop I...\n",
     "\n",
-    "Statement: The author started the Data Science club at the University of Waterloo.\n",
-    "Citation: ...titchfix, Facebook.\n",
-    "*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
-    "\n",
-    "Statement: The author was the president of the Data Science club for 2 years.\n",
-    "Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n",
+    "Statement: The author started the Data Science club at the University of Waterloo and was the president of the club for 2 years.\n",
+    "Citation: ...x, Facebook.\n",
+    "I also *\u001b[91mstarted the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n",
+    "Citation: ...erloo and I was the *\u001b[91mpresident of the club for 2 years\u001b[0m*.\n",
     "...\n",
     "\n"
    ]
@@ -126,7 +126,7 @@
   {
    "data": {
     "text/plain": [
-     "{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'es'}"
+     "{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}"
     ]
    },
    "execution_count": 6,
@@ -4,6 +4,7 @@ from pydantic import BaseModel, Field

 from langchain.base_language import BaseLanguageModel
 from langchain.chains.llm import LLMChain
+from langchain.chains.openai_functions.utils import get_llm_kwargs
 from langchain.output_parsers.openai_functions import (
     PydanticOutputFunctionsParser,
 )
@@ -65,14 +66,12 @@ class QuestionAnswer(BaseModel):
 def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
     output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer)
     schema = QuestionAnswer.schema()
-    functions = [
-        {
-            "name": schema["title"],
-            "description": schema["description"],
-            "parameters": schema,
-        }
-    ]
-    kwargs = {"function_call": {"name": schema["title"]}}
+    function = {
+        "name": schema["title"],
+        "description": schema["description"],
+        "parameters": schema,
+    }
+    llm_kwargs = get_llm_kwargs(function)
     messages = [
         SystemMessage(
             content=(
@@ -95,7 +94,7 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain:
     chain = LLMChain(
         llm=llm,
         prompt=prompt,
-        llm_kwargs={**{"functions": functions}, **kwargs},
+        llm_kwargs=llm_kwargs,
         output_parser=output_parser,
     )
     return chain
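
The three hunks above refactor create_citation_fuzzy_match_chain in langchain/chains/openai_functions/citation_fuzzy_match.py (the module is identifiable from the imports): the inline functions list and the hand-built function_call kwargs are replaced by a single function dict passed through the new shared get_llm_kwargs helper. A minimal usage sketch follows; it is not part of this commit, and the import path and model name are assumptions based on the LangChain API of this period.

# Usage sketch (assumed API, not part of this diff). Requires an OpenAI
# model with function-calling support.
from langchain.chains import create_citation_fuzzy_match_chain
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
chain = create_citation_fuzzy_match_chain(llm)

context = "... in university I studied Computational Mathematics and physics ..."
result = chain.run(question="What did the author do during college?", context=context)
# result is a QuestionAnswer object; each FactWithEvidence carries
# substring_quote entries that can be matched back into the context,
# as in the notebook output above.

Forcing function_call to the schema's own name, which get_llm_kwargs does, makes the model return arguments for that schema rather than free-form text.
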
@@ -8,6 +8,7 @@ from langchain.chains.llm import LLMChain
 from langchain.chains.openai_functions.utils import (
     _convert_schema,
     _resolve_schema_references,
+    get_llm_kwargs,
 )
 from langchain.output_parsers.openai_functions import (
     JsonKeyOutputFunctionsParser,
@@ -15,14 +16,10 @@ from langchain.output_parsers.openai_functions import (
 )
 from langchain.prompts import ChatPromptTemplate

-EXTRACTION_NAME = "information_extraction"
-EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
-

-def _get_extraction_functions(entity_schema: dict) -> List[dict]:
-    return [
-        {
-            "name": EXTRACTION_NAME,
+def _get_extraction_function(entity_schema: dict) -> dict:
+    return {
+        "name": "information_extraction",
         "description": "Extracts the relevant information from the passage.",
         "parameters": {
             "type": "object",
@@ -32,7 +29,6 @@ def _get_extraction_functions(entity_schema: dict) -> List[dict]:
             "required": ["info"],
         },
     }
-    ]


 _EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
@@ -44,13 +40,14 @@ Passage:


 def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
-    functions = _get_extraction_functions(schema)
+    function = _get_extraction_function(schema)
     prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
     output_parser = JsonKeyOutputFunctionsParser(key_name="info")
+    llm_kwargs = get_llm_kwargs(function)
     chain = LLMChain(
         llm=llm,
         prompt=prompt,
-        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
+        llm_kwargs=llm_kwargs,
         output_parser=output_parser,
     )
     return chain
@@ -67,15 +64,16 @@ def create_extraction_chain_pydantic(
         openai_schema, openai_schema["definitions"]
     )

-    functions = _get_extraction_functions(openai_schema)
+    function = _get_extraction_function(openai_schema)
     prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
     output_parser = PydanticAttrOutputFunctionsParser(
         pydantic_schema=PydanticSchema, attr_name="info"
     )
+    llm_kwargs = get_llm_kwargs(function)
     chain = LLMChain(
         llm=llm,
         prompt=prompt,
-        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
+        llm_kwargs=llm_kwargs,
         output_parser=output_parser,
     )
     return chain
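
The extraction.py hunks above apply the same refactor: the module-level EXTRACTION_NAME and EXTRACTION_KWARGS constants and _get_extraction_functions (which returned a one-element list) collapse into _get_extraction_function, and both create_extraction_chain and create_extraction_chain_pydantic now obtain their llm_kwargs from get_llm_kwargs. A hedged sketch of calling the resulting chain, mirroring the Properties output in the notebook hunk above; the schema keys, model name, and import path are assumptions.

# Usage sketch (assumed API, not part of this diff).
from langchain.chains import create_extraction_chain
from langchain.chat_models import ChatOpenAI

schema = {
    "properties": {
        "person_name": {"type": "string"},
        "person_height": {"type": "integer"},
        "person_hair_color": {"type": "string"},
        "dog_breed": {"type": "string"},
        "dog_name": {"type": "string"},
    },
    "required": ["person_name"],
}

llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
chain = create_extraction_chain(schema, llm)
chain.run("Claudia is a 6 foot tall brunette.")
# -> roughly [{'person_name': 'Claudia', 'person_height': 6,
#              'person_hair_color': 'brunette'}]
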
@@ -1,27 +1,22 @@
-from typing import Any, List
+from typing import Any

 from langchain.base_language import BaseLanguageModel
 from langchain.chains.base import Chain
 from langchain.chains.llm import LLMChain
-from langchain.chains.openai_functions.utils import _convert_schema
+from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
 from langchain.output_parsers.openai_functions import (
     JsonOutputFunctionsParser,
     PydanticOutputFunctionsParser,
 )
 from langchain.prompts import ChatPromptTemplate

-EXTRACTION_NAME = "information_extraction"
-EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
-

-def _get_tagging_functions(schema: dict) -> List[dict]:
-    return [
-        {
-            "name": EXTRACTION_NAME,
+def _get_tagging_function(schema: dict) -> dict:
+    return {
+        "name": "information_extraction",
         "description": "Extracts the relevant information from the passage.",
         "parameters": _convert_schema(schema),
     }
-    ]


 _TAGGING_TEMPLATE = """Extract the desired information from the following passage.
@@ -32,13 +27,14 @@ Passage:


 def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
-    functions = _get_tagging_functions(schema)
+    function = _get_tagging_function(schema)
     prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
     output_parser = JsonOutputFunctionsParser()
+    llm_kwargs = get_llm_kwargs(function)
     chain = LLMChain(
         llm=llm,
         prompt=prompt,
-        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
+        llm_kwargs=llm_kwargs,
         output_parser=output_parser,
     )
     return chain
@@ -48,14 +44,14 @@ def create_tagging_chain_pydantic(
     pydantic_schema: Any, llm: BaseLanguageModel
 ) -> Chain:
     openai_schema = pydantic_schema.schema()
-
-    functions = _get_tagging_functions(openai_schema)
+    function = _get_tagging_function(openai_schema)
     prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
     output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
+    llm_kwargs = get_llm_kwargs(function)
     chain = LLMChain(
         llm=llm,
         prompt=prompt,
-        llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS},
+        llm_kwargs=llm_kwargs,
         output_parser=output_parser,
     )
     return chain
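
The same pattern lands in tagging.py. A sketch of the tagging chain matching the {'sentiment': 'enojado', ...} notebook output above; the input text, schema keys, model name, and import path are assumptions.

# Usage sketch (assumed API, not part of this diff).
from langchain.chains import create_tagging_chain
from langchain.chat_models import ChatOpenAI

schema = {
    "properties": {
        "sentiment": {"type": "string"},
        "aggressiveness": {"type": "integer"},
        "language": {"type": "string"},
    }
}

llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-0613")
chain = create_tagging_chain(schema, llm)
chain.run("Estoy muy enojado con vos!")
# -> roughly {'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}
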
@@ -26,3 +26,7 @@ def _convert_schema(schema: dict) -> dict:
         "properties": props,
         "required": schema.get("required", []),
     }
+
+
+def get_llm_kwargs(function: dict) -> dict:
+    return {"functions": [function], "function_call": {"name": function["name"]}}
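
The new get_llm_kwargs helper in langchain/chains/openai_functions/utils.py is the heart of the refactor: each create_*_chain previously assembled {"functions": [...], "function_call": ...} by hand, and all of them now delegate to this one function. Its behavior, shown with a function definition taken from this diff:

from langchain.chains.openai_functions.utils import get_llm_kwargs

function = {
    "name": "information_extraction",
    "description": "Extracts the relevant information from the passage.",
    "parameters": {"type": "object", "properties": {}, "required": ["info"]},
}
get_llm_kwargs(function)
# -> {'functions': [function],
#     'function_call': {'name': 'information_extraction'}}
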