From e67b26eee9902e38d0890e5c115ae345524eebf3 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Thu, 15 Jun 2023 21:54:39 -0700 Subject: [PATCH] Harrison/openai functions (#6261) Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com> --- langchain/chains/openai_functions.py | 30 +++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/langchain/chains/openai_functions.py b/langchain/chains/openai_functions.py index e558289c..5b0a7f2d 100644 --- a/langchain/chains/openai_functions.py +++ b/langchain/chains/openai_functions.py @@ -2,7 +2,7 @@ import json from functools import partial from typing import Any, Dict, List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from langchain.base_language import BaseLanguageModel from langchain.callbacks.manager import ( @@ -15,6 +15,9 @@ from langchain.chains.transform import TransformChain from langchain.prompts.base import BasePromptTemplate from langchain.prompts.chat import ChatPromptTemplate +EXTRACTION_NAME = "information_extraction" +EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}} + def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any: """ @@ -70,6 +73,7 @@ class OpenAIFunctionsChain(Chain): prompt: BasePromptTemplate llm: BaseLanguageModel functions: List[Dict] + kwargs: Dict = Field(default_factory=dict) @property def input_keys(self) -> List[str]: @@ -90,7 +94,7 @@ class OpenAIFunctionsChain(Chain): _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() callbacks = _run_manager.get_child() predicted_message = self.llm.predict_messages( - messages, functions=self.functions, callbacks=callbacks + messages, functions=self.functions, callbacks=callbacks, **self.kwargs ) return {"output": predicted_message} @@ -105,7 +109,7 @@ class OpenAIFunctionsChain(Chain): _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager() callbacks = _run_manager.get_child() predicted_message = await self.llm.apredict_messages( - messages, functions=self.functions, callbacks=callbacks + messages, functions=self.functions, callbacks=callbacks, **self.kwargs ) return {"output": predicted_message} @@ -122,7 +126,7 @@ def _convert_schema(schema: dict) -> dict: def _get_extraction_functions(entity_schema: dict) -> List[dict]: return [ { - "name": "information_extraction", + "name": EXTRACTION_NAME, "description": "Extracts the relevant information from the passage.", "parameters": { "type": "object", @@ -138,7 +142,7 @@ def _get_extraction_functions(entity_schema: dict) -> List[dict]: def _get_tagging_functions(schema: dict) -> List[dict]: return [ { - "name": "information_extraction", + "name": EXTRACTION_NAME, "description": "Extracts the relevant information from the passage.", "parameters": _convert_schema(schema), } @@ -156,7 +160,9 @@ Passage: def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain: functions = _get_extraction_functions(schema) prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) - chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions) + chain = OpenAIFunctionsChain( + llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS + ) parsing_chain = TransformChain( transform=_parse_entities, input_variables=["input"], @@ -178,7 +184,9 @@ def create_extraction_chain_pydantic( functions = _get_extraction_functions(openai_schema) prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) - chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions) + chain = OpenAIFunctionsChain( + llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS + ) pydantic_parsing_chain = TransformChain( transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema), input_variables=["input"], @@ -197,7 +205,9 @@ Passage: def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain: functions = _get_tagging_functions(schema) prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) - chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions) + chain = OpenAIFunctionsChain( + llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS + ) parsing_chain = TransformChain( transform=_parse_tag, input_variables=["input"], output_variables=["output"] ) @@ -211,7 +221,9 @@ def create_tagging_chain_pydantic( functions = _get_tagging_functions(openai_schema) prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE) - chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions) + chain = OpenAIFunctionsChain( + llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS + ) pydantic_parsing_chain = TransformChain( transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema), input_variables=["input"],