forked from Archives/langchain
Harrison/openai functions (#6261)
Co-authored-by: Francisco Ingham <24279597+fpingham@users.noreply.github.com>
This commit is contained in:
parent
6aafb46807
commit
e67b26eee9
@ -2,7 +2,7 @@ import json
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
@ -15,6 +15,9 @@ from langchain.chains.transform import TransformChain
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
|
||||
EXTRACTION_NAME = "information_extraction"
|
||||
EXTRACTION_KWARGS = {"function_call": {"name": "information_extraction"}}
|
||||
|
||||
|
||||
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
|
||||
"""
|
||||
@ -70,6 +73,7 @@ class OpenAIFunctionsChain(Chain):
|
||||
prompt: BasePromptTemplate
|
||||
llm: BaseLanguageModel
|
||||
functions: List[Dict]
|
||||
kwargs: Dict = Field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
@ -90,7 +94,7 @@ class OpenAIFunctionsChain(Chain):
|
||||
_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
predicted_message = self.llm.predict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
messages, functions=self.functions, callbacks=callbacks, **self.kwargs
|
||||
)
|
||||
return {"output": predicted_message}
|
||||
|
||||
@ -105,7 +109,7 @@ class OpenAIFunctionsChain(Chain):
|
||||
_run_manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
|
||||
callbacks = _run_manager.get_child()
|
||||
predicted_message = await self.llm.apredict_messages(
|
||||
messages, functions=self.functions, callbacks=callbacks
|
||||
messages, functions=self.functions, callbacks=callbacks, **self.kwargs
|
||||
)
|
||||
return {"output": predicted_message}
|
||||
|
||||
@ -122,7 +126,7 @@ def _convert_schema(schema: dict) -> dict:
|
||||
def _get_extraction_functions(entity_schema: dict) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"name": "information_extraction",
|
||||
"name": EXTRACTION_NAME,
|
||||
"description": "Extracts the relevant information from the passage.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
@ -138,7 +142,7 @@ def _get_extraction_functions(entity_schema: dict) -> List[dict]:
|
||||
def _get_tagging_functions(schema: dict) -> List[dict]:
|
||||
return [
|
||||
{
|
||||
"name": "information_extraction",
|
||||
"name": EXTRACTION_NAME,
|
||||
"description": "Extracts the relevant information from the passage.",
|
||||
"parameters": _convert_schema(schema),
|
||||
}
|
||||
@ -156,7 +160,9 @@ Passage:
|
||||
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
functions = _get_extraction_functions(schema)
|
||||
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
||||
chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
|
||||
chain = OpenAIFunctionsChain(
|
||||
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
|
||||
)
|
||||
parsing_chain = TransformChain(
|
||||
transform=_parse_entities,
|
||||
input_variables=["input"],
|
||||
@ -178,7 +184,9 @@ def create_extraction_chain_pydantic(
|
||||
|
||||
functions = _get_extraction_functions(openai_schema)
|
||||
prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
|
||||
chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
|
||||
chain = OpenAIFunctionsChain(
|
||||
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
|
||||
)
|
||||
pydantic_parsing_chain = TransformChain(
|
||||
transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema),
|
||||
input_variables=["input"],
|
||||
@ -197,7 +205,9 @@ Passage:
|
||||
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
|
||||
functions = _get_tagging_functions(schema)
|
||||
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||
chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
|
||||
chain = OpenAIFunctionsChain(
|
||||
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
|
||||
)
|
||||
parsing_chain = TransformChain(
|
||||
transform=_parse_tag, input_variables=["input"], output_variables=["output"]
|
||||
)
|
||||
@ -211,7 +221,9 @@ def create_tagging_chain_pydantic(
|
||||
|
||||
functions = _get_tagging_functions(openai_schema)
|
||||
prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
|
||||
chain = OpenAIFunctionsChain(llm=llm, prompt=prompt, functions=functions)
|
||||
chain = OpenAIFunctionsChain(
|
||||
llm=llm, prompt=prompt, functions=functions, kwargs=EXTRACTION_KWARGS
|
||||
)
|
||||
pydantic_parsing_chain = TransformChain(
|
||||
transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema),
|
||||
input_variables=["input"],
|
||||
|
Loading…
Reference in New Issue
Block a user