diff --git a/docs/extras/modules/chains/additional/extraction.ipynb b/docs/extras/modules/chains/additional/extraction.ipynb index 9ede44dd..7e2a0258 100644 --- a/docs/extras/modules/chains/additional/extraction.ipynb +++ b/docs/extras/modules/chains/additional/extraction.ipynb @@ -218,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "id": "f771df58", "metadata": {}, "outputs": [ @@ -229,7 +229,7 @@ " Properties(person_name='Claudia', person_height=6, person_hair_color='brunette', dog_breed=None, dog_name=None)]" ] }, - "execution_count": 10, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/extras/modules/chains/additional/qa_citations.ipynb b/docs/extras/modules/chains/additional/qa_citations.ipynb index b53c3405..5eaf9e5d 100644 --- a/docs/extras/modules/chains/additional/qa_citations.ipynb +++ b/docs/extras/modules/chains/additional/qa_citations.ipynb @@ -86,7 +86,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo.', substring_quote=['I also started the Data Science club at the University of Waterloo']), FactWithEvidence(fact='The author was the president of the Data Science club for 2 years.', substring_quote=['I was the president of the club for 2 years'])]\n" + "question='What did the author do during college?' answer=[FactWithEvidence(fact='The author studied Computational Mathematics and physics in university.', substring_quote=['in university I studied Computational Mathematics and physics']), FactWithEvidence(fact='The author started the Data Science club at the University of Waterloo and was the president of the club for 2 years.', substring_quote=['started the Data Science club at the University of Waterloo', 'president of the club for 2 years'])]\n" ] } ], @@ -129,12 +129,10 @@ "Citation: ...arts highschool but *\u001b[91min university I studied Computational Mathematics and physics\u001b[0m*. \n", "As part of coop I...\n", "\n", - "Statement: The author started the Data Science club at the University of Waterloo.\n", - "Citation: ...titchfix, Facebook.\n", - "*\u001b[91mI also started the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n", - "\n", - "Statement: The author was the president of the Data Science club for 2 years.\n", - "Citation: ...ity of Waterloo and *\u001b[91mI was the president of the club for 2 years\u001b[0m*.\n", + "Statement: The author started the Data Science club at the University of Waterloo and was the president of the club for 2 years.\n", + "Citation: ...x, Facebook.\n", + "I also *\u001b[91mstarted the Data Science club at the University of Waterloo\u001b[0m* and I was the presi...\n", + "Citation: ...erloo and I was the *\u001b[91mpresident of the club for 2 years\u001b[0m*.\n", "...\n", "\n" ] diff --git a/docs/extras/modules/chains/additional/tagging.ipynb b/docs/extras/modules/chains/additional/tagging.ipynb index 7c6487fe..b51e3f6d 100644 --- a/docs/extras/modules/chains/additional/tagging.ipynb +++ b/docs/extras/modules/chains/additional/tagging.ipynb @@ -126,7 +126,7 @@ { "data": { "text/plain": [ - "{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'es'}" + "{'sentiment': 'enojado', 'aggressiveness': 1, 'language': 'Spanish'}" ] }, "execution_count": 6, diff --git a/langchain/chains/openai_functions/citation_fuzzy_match.py b/langchain/chains/openai_functions/citation_fuzzy_match.py index 9f4ba6f9..43a8c197 100644 --- a/langchain/chains/openai_functions/citation_fuzzy_match.py +++ b/langchain/chains/openai_functions/citation_fuzzy_match.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from langchain.base_language import BaseLanguageModel from langchain.chains.llm import LLMChain +from langchain.chains.openai_functions.utils import get_llm_kwargs from langchain.output_parsers.openai_functions import ( PydanticOutputFunctionsParser, ) @@ -65,14 +66,12 @@ class QuestionAnswer(BaseModel): def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain: output_parser = PydanticOutputFunctionsParser(pydantic_schema=QuestionAnswer) schema = QuestionAnswer.schema() - functions = [ - { - "name": schema["title"], - "description": schema["description"], - "parameters": schema, - } - ] - kwargs = {"function_call": {"name": schema["title"]}} + function = { + "name": schema["title"], + "description": schema["description"], + "parameters": schema, + } + llm_kwargs = get_llm_kwargs(function) messages = [ SystemMessage( content=( @@ -95,7 +94,7 @@ def create_citation_fuzzy_match_chain(llm: BaseLanguageModel) -> LLMChain: chain = LLMChain( llm=llm, prompt=prompt, - llm_kwargs={**{"functions": functions}, **kwargs}, + llm_kwargs=llm_kwargs, output_parser=output_parser, ) return chain diff --git a/langchain/chains/openai_functions/extraction.py b/langchain/chains/openai_functions/extraction.py index 0db5ddd6..fd195175 100644 --- a/langchain/chains/openai_functions/extraction.py +++ b/langchain/chains/openai_functions/extraction.py @@ -8,6 +8,7 @@ from langchain.chains.llm import LLMChain from langchain.chains.openai_functions.utils import ( _convert_schema, _resolve_schema_references, + get_llm_kwargs, ) from langchain.output_parsers.openai_functions import ( JsonKeyOutputFunctionsParser, @@ -15,24 +16,19 @@ from langchain.output_parsers.openai_functions import ( ) from langchain.prompts import ChatPromptTemplate -EXTRACTION_NAME = "information_extraction" -EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}} - -def _get_extraction_functions(entity_schema: dict) -> List[dict]: - return [ - { - "name": EXTRACTION_NAME, - "description": "Extracts the relevant information from the passage.", - "parameters": { - "type": "object", - "properties": { - "info": {"type": "array", "items": _convert_schema(entity_schema)} - }, - "required": ["info"], +def _get_extraction_function(entity_schema: dict) -> dict: + return { + "name": "information_extraction", + "description": "Extracts the relevant information from the passage.", + "parameters": { + "type": "object", + "properties": { + "info": {"type": "array", "items": _convert_schema(entity_schema)} }, - } - ] + "required": ["info"], + }, + } _EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\ @@ -44,13 +40,14 @@ Passage: def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain: - functions = _get_extraction_functions(schema) + function = _get_extraction_function(schema) prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) output_parser = JsonKeyOutputFunctionsParser(key_name="info") + llm_kwargs = get_llm_kwargs(function) chain = LLMChain( llm=llm, prompt=prompt, - llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, + llm_kwargs=llm_kwargs, output_parser=output_parser, ) return chain @@ -67,15 +64,16 @@ def create_extraction_chain_pydantic( openai_schema, openai_schema["definitions"] ) - functions = _get_extraction_functions(openai_schema) + function = _get_extraction_function(openai_schema) prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) output_parser = PydanticAttrOutputFunctionsParser( pydantic_schema=PydanticSchema, attr_name="info" ) + llm_kwargs = get_llm_kwargs(function) chain = LLMChain( llm=llm, prompt=prompt, - llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, + llm_kwargs=llm_kwargs, output_parser=output_parser, ) return chain diff --git a/langchain/chains/openai_functions/tagging.py b/langchain/chains/openai_functions/tagging.py index 61c2fbfa..e568eb68 100644 --- a/langchain/chains/openai_functions/tagging.py +++ b/langchain/chains/openai_functions/tagging.py @@ -1,27 +1,22 @@ -from typing import Any, List +from typing import Any from langchain.base_language import BaseLanguageModel from langchain.chains.base import Chain from langchain.chains.llm import LLMChain -from langchain.chains.openai_functions.utils import _convert_schema +from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs from langchain.output_parsers.openai_functions import ( JsonOutputFunctionsParser, PydanticOutputFunctionsParser, ) from langchain.prompts import ChatPromptTemplate -EXTRACTION_NAME = "information_extraction" -EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}} - -def _get_tagging_functions(schema: dict) -> List[dict]: - return [ - { - "name": EXTRACTION_NAME, - "description": "Extracts the relevant information from the passage.", - "parameters": _convert_schema(schema), - } - ] +def _get_tagging_function(schema: dict) -> dict: + return { + "name": "information_extraction", + "description": "Extracts the relevant information from the passage.", + "parameters": _convert_schema(schema), + } _TAGGING_TEMPLATE = """Extract the desired information from the following passage. @@ -32,13 +27,14 @@ Passage: def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain: - functions = _get_tagging_functions(schema) + function = _get_tagging_function(schema) prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) output_parser = JsonOutputFunctionsParser() + llm_kwargs = get_llm_kwargs(function) chain = LLMChain( llm=llm, prompt=prompt, - llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, + llm_kwargs=llm_kwargs, output_parser=output_parser, ) return chain @@ -48,14 +44,14 @@ def create_tagging_chain_pydantic( pydantic_schema: Any, llm: BaseLanguageModel ) -> Chain: openai_schema = pydantic_schema.schema() - - functions = _get_tagging_functions(openai_schema) + function = _get_tagging_function(openai_schema) prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) output_parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema) + llm_kwargs = get_llm_kwargs(function) chain = LLMChain( llm=llm, prompt=prompt, - llm_kwargs={**{"functions": functions}, **EXTRACTION_KWARGS}, + llm_kwargs=llm_kwargs, output_parser=output_parser, ) return chain diff --git a/langchain/chains/openai_functions/utils.py b/langchain/chains/openai_functions/utils.py index 4ad2e3e6..9f5a0591 100644 --- a/langchain/chains/openai_functions/utils.py +++ b/langchain/chains/openai_functions/utils.py @@ -26,3 +26,7 @@ def _convert_schema(schema: dict) -> dict: "properties": props, "required": schema.get("required", []), } + + +def get_llm_kwargs(function: dict) -> dict: + return {"functions": [function], "function_call": {"name": function["name"]}}