Harrison/openai functions (#6261)

Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com>
Harrison Chase committed by GitHub
parent 6aafb46807
commit e67b26eee9

@@ -2,7 +2,7 @@ import json
 from functools import partial
 from typing import Any, Dict, List, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import (
@@ -15,6 +15,9 @@ from langchain.chains.transform import TransformChain
 from langchain.prompts.base import BasePromptTemplate
 from langchain.prompts.chat import ChatPromptTemplate
 
+EXTRACTION_NAME = "information_extraction"
+EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
+
 
 def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
     """
@@ -70,6 +73,7 @@ class OpenAIFunctionsChain(Chain):
     prompt: BasePromptTemplate
     llm: BaseLanguageModel
     functions: List[Dict]
+    kwargs: Dict = Field(default_factory=dict)
 
     @property
     def input_keys(self) -> List[str]:
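kwargs is a new optional field holding extra keyword arguments that the chain forwards to the model call. A small self-contained sketch of the Field(default_factory=dict) idiom (the Example class below is hypothetical, not from this file): each instance gets its own fresh dict, so mutating one chain's kwargs never affects another's.

    from pydantic import BaseModel, Field

    class Example(BaseModel):
        kwargs: dict = Field(default_factory=dict)

    a = Example()
    b = Example()
    a.kwargs["function_call"] = {"name": "information_extraction"}
    assert b.kwargs == {}  # b still has its own, empty dict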
@@ -90,7 +94,7 @@
         _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         predicted_message = self.llm.predict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **self.kwargs
         )
         return {"output": predicted_message}
@@ -105,7 +109,7 @@
         _run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
         callbacks = _run_manager.get_child()
         predicted_message = await self.llm.apredict_messages(
-            messages, functions=self.functions, callbacks=callbacks
+            messages, functions=self.functions, callbacks=callbacks, **self.kwargs
         )
         return {"output": predicted_message}
@@ -122,7 +126,7 @@ def _convert_schema(schema: dict) -> dict:
 def _get_extraction_functions(entity_schema: dict) -> List[dict]:
     return [
         {
-            "name": "information_extraction",
+            "name": EXTRACTION_NAME,
             "description": "Extracts the relevant information from the passage.",
             "parameters": {
                 "type": "object",
@@ -138,7 +142,7 @@ def _get_extraction_functions(entity_schema: dict) -> List[dict]:
 def _get_tagging_functions(schema: dict) -> List[dict]:
     return [
         {
-            "name": "information_extraction",
+            "name": EXTRACTION_NAME,
             "description": "Extracts the relevant information from the passage.",
             "parameters": _convert_schema(schema),
         }
@@ -156,7 +160,9 @@ Passage:
 def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
     functions = _get_extraction_functions(schema)
     prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
-    chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
+    chain = OpenAIFunctionsChain(
+        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
+    )
     parsing_chain = TransformChain(
         transform=_parse_entities,
         input_variables=["input"],
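A usage sketch for the dict-schema entry point, assuming the helpers are importable from langchain.chains.openai_functions (the module this diff edits) and that an OpenAI function-calling model such as gpt-3.5-turbo-0613 is configured:

    from langchain.chains.openai_functions import create_extraction_chain
    from langchain.chat_models import ChatOpenAI

    schema = {
        "properties": {
            "person_name": {"type": "string"},
            "person_height": {"type": "integer"},
        },
        "required": ["person_name"],
    }
    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)
    chain = create_extraction_chain(schema, llm)
    chain.run("Alex is 5 feet tall. Claudia is one foot taller than Alex.")
    # expected shape: a list of dicts like [{"person_name": "Alex", "person_height": 5}, ...]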
@@ -178,7 +184,9 @@ def create_extraction_chain_pydantic(
     functions = _get_extraction_functions(openai_schema)
     prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
-    chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
+    chain = OpenAIFunctionsChain(
+        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
+    )
     pydantic_parsing_chain = TransformChain(
         transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema),
         input_variables=["input"],
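The pydantic variant takes a model class instead of a dict schema and parses the function-call arguments back into instances of it. Sketch under the same import assumptions; the Person class and the pydantic_schema keyword are illustrative of the signature in this file at the time:

    from typing import Optional
    from pydantic import BaseModel
    from langchain.chains.openai_functions import create_extraction_chain_pydantic
    from langchain.chat_models import ChatOpenAI

    class Person(BaseModel):
        person_name: str
        person_height: Optional[int] = None

    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)
    chain = create_extraction_chain_pydantic(pydantic_schema=Person, llm=llm)
    chain.run("Alex is 5 feet tall.")
    # expected shape: Person objects rather than raw dicts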
@@ -197,7 +205,9 @@ Passage:
 def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
     functions = _get_tagging_functions(schema)
     prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
-    chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
+    chain = OpenAIFunctionsChain(
+        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
+    )
     parsing_chain = TransformChain(
         transform=_parse_tag, input_variables=["input"], output_variables=["output"]
     )
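create_tagging_chain works the same way, but its function schema describes a single tag object per passage rather than a list of extracted entities. Usage sketch, same assumptions as above:

    from langchain.chains.openai_functions import create_tagging_chain
    from langchain.chat_models import ChatOpenAI

    schema = {
        "properties": {
            "sentiment": {"type": "string"},
            "language": {"type": "string"},
        }
    }
    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)
    chain = create_tagging_chain(schema, llm)
    chain.run("Estoy muy contento de haber conocido LangChain!")
    # expected shape: a dict such as {"sentiment": "positive", "language": "Spanish"}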
@@ -211,7 +221,9 @@ def create_tagging_chain_pydantic(
     functions = _get_tagging_functions(openai_schema)
     prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
-    chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
+    chain = OpenAIFunctionsChain(
+        llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
+    )
     pydantic_parsing_chain = TransformChain(
         transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema),
         input_variables=["input"],
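All four constructors pass the same EXTRACTION_KWARGS, so the tagging chains also force a call to the function named information_extraction; only the function's parameters differ. For completeness, a sketch of the pydantic tagging variant under the same assumptions (Tags is illustrative):

    from pydantic import BaseModel
    from langchain.chains.openai_functions import create_tagging_chain_pydantic
    from langchain.chat_models import ChatOpenAI

    class Tags(BaseModel):
        sentiment: str
        language: str

    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)
    chain = create_tagging_chain_pydantic(pydantic_schema=Tags, llm=llm)
    chain.run("Estoy muy contento de haber conocido LangChain!")
    # expected shape: a Tags instance instead of a raw dict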
