Harrison/openai functions (#6261)

Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-06-15 21:54:39 -07:00 committed by GitHub
parent 6aafb46807
commit e67b26eee9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@ import json
from functools import partial
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
@ -15,6 +15,9 @@ from langchain.chains.transform import TransformChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
EXTRACTION_NAME = "information_extraction"
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
"""
@ -70,6 +73,7 @@ class OpenAIFunctionsChain(Chain):
prompt: BasePromptTemplate
llm: BaseLanguageModel
functions: List[Dict]
kwargs: Dict = Field(default_factory=dict)
@property
def input_keys(self) -> List[str]:
@ -90,7 +94,7 @@ class OpenAIFunctionsChain(Chain):
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
messages, functions=self.functions, callbacks=callbacks, **self.kwargs
)
return {"output": predicted_message}
@ -105,7 +109,7 @@ class OpenAIFunctionsChain(Chain):
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
callbacks = _run_manager.get_child()
predicted_message = await self.llm.apredict_messages(
messages, functions=self.functions, callbacks=callbacks
messages, functions=self.functions, callbacks=callbacks, **self.kwargs
)
return {"output": predicted_message}
@ -122,7 +126,7 @@ def _convert_schema(schema: dict) -> dict:
def _get_extraction_functions(entity_schema: dict) -> List[dict]:
return [
{
"name": "information_extraction",
"name": EXTRACTION_NAME,
"description": "Extracts the relevant information from the passage.",
"parameters": {
"type": "object",
@ -138,7 +142,7 @@ def _get_extraction_functions(entity_schema: dict) -> List[dict]:
def _get_tagging_functions(schema: dict) -> List[dict]:
return [
{
"name": "information_extraction",
"name": EXTRACTION_NAME,
"description": "Extracts the relevant information from the passage.",
"parameters": _convert_schema(schema),
}
@ -156,7 +160,9 @@ Passage:
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_extraction_functions(schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
chain = OpenAIFunctionsChain(
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
)
parsing_chain = TransformChain(
transform=_parse_entities,
input_variables=["input"],
@ -178,7 +184,9 @@ def create_extraction_chain_pydantic(
functions = _get_extraction_functions(openai_schema)
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
chain = OpenAIFunctionsChain(
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
)
pydantic_parsing_chain = TransformChain(
transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema),
input_variables=["input"],
@ -197,7 +205,9 @@ Passage:
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
functions = _get_tagging_functions(schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
chain = OpenAIFunctionsChain(
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
)
parsing_chain = TransformChain(
transform=_parse_tag, input_variables=["input"], output_variables=["output"]
)
@ -211,7 +221,9 @@ def create_tagging_chain_pydantic(
functions = _get_tagging_functions(openai_schema)
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
chain = OpenAIFunctionsChain(
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
)
pydantic_parsing_chain = TransformChain(
transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema),
input_variables=["input"],