forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
234 lines
7.6 KiB
Python
234 lines
7.6 KiB
Python
import json
|
|
from functools import partial
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from langchain.base_language import BaseLanguageModel
|
|
from langchain.callbacks.manager import (
|
|
AsyncCallbackManagerForChainRun,
|
|
CallbackManagerForChainRun,
|
|
)
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.sequential import SimpleSequentialChain
|
|
from langchain.chains.transform import TransformChain
|
|
from langchain.prompts.base import BasePromptTemplate
|
|
from langchain.prompts.chat import ChatPromptTemplate
|
|
|
|
# Name of the single OpenAI function exposed to the model by both the
# extraction and the tagging chains.
EXTRACTION_NAME = "information_extraction"
# Forces the model to call that function instead of replying free-form.
# Reuses EXTRACTION_NAME so the function name and the forced call cannot
# drift apart if the name is ever changed.
EXTRACTION_KWARGS = {"function_call": {"name": EXTRACTION_NAME}}
|
|
|
|
|
|
def _resolve_schema_references(schema: Any, definitions: Dict[str, Any]) -> Any:
|
|
"""
|
|
Resolves the $ref keys in a JSON schema object using the provided definitions.
|
|
"""
|
|
if isinstance(schema, list):
|
|
for i, item in enumerate(schema):
|
|
schema[i] = _resolve_schema_references(item, definitions)
|
|
elif isinstance(schema, dict):
|
|
if "$ref" in schema:
|
|
ref_key = schema.pop("$ref").split("/")[-1]
|
|
ref = definitions.get(ref_key, {})
|
|
schema.update(ref)
|
|
else:
|
|
for key, value in schema.items():
|
|
schema[key] = _resolve_schema_references(value, definitions)
|
|
return schema
|
|
|
|
|
|
def _get_function_arguments(inputs: dict) -> str:
|
|
message = inputs["input"]
|
|
try:
|
|
func_call = message.additional_kwargs["function_call"]
|
|
except ValueError as exc:
|
|
raise ValueError(f"Could not parse function call: {exc}")
|
|
|
|
return func_call["arguments"]
|
|
|
|
|
|
def _parse_tag(inputs: dict) -> dict:
    """Transform step: decode the model's function-call arguments as JSON.

    Returns a single-key dict (``{"output": ...}``) suitable for a
    ``TransformChain``.
    """
    raw_args = _get_function_arguments(inputs)
    parsed = json.loads(raw_args)
    return {"output": parsed}
|
|
|
|
|
|
def _parse_tag_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
    """Transform step: validate the function-call arguments with a pydantic model.

    Returns ``{"output": <pydantic instance>}``.
    """
    raw_args = _get_function_arguments(inputs)
    return {"output": pydantic_schema.parse_raw(raw_args)}
|
|
|
|
|
|
def _parse_entities(inputs: dict) -> dict:
    """Transform step: decode the function-call arguments and unwrap ``"info"``.

    The extraction function schema nests the entity list under an ``info``
    key; this returns that list as ``{"output": [...]}``.
    """
    payload = json.loads(_get_function_arguments(inputs))
    return {"output": payload["info"]}
|
|
|
|
|
|
def _parse_entities_pydantic(inputs: dict, pydantic_schema: Any) -> dict:
    """Transform step: validate the arguments with *pydantic_schema*, return ``.info``.

    Returns ``{"output": <validated .info attribute>}``.
    """
    validated = pydantic_schema.parse_raw(_get_function_arguments(inputs))
    return {"output": validated.info}
|
|
|
|
|
|
class OpenAIFunctionsChain(Chain):
    """Chain that formats a prompt, queries an LLM with OpenAI function
    definitions, and returns the raw predicted message under ``"output"``.
    """

    # Template used to build the messages sent to the model.
    prompt: BasePromptTemplate
    # Model queried with the formatted messages and function definitions.
    llm: BaseLanguageModel
    # OpenAI function schemas passed through on every call.
    functions: List[Dict]
    # Extra keyword arguments forwarded to the predict call
    # (e.g. a forced ``function_call``).
    kwargs: Dict = Field(default_factory=dict)

    @property
    def input_keys(self) -> List[str]:
        """Inputs are exactly the prompt's input variables."""
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Single output key: ``"output"``."""
        return ["output"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Drop any extra inputs the prompt does not declare.
        relevant = {
            key: value
            for key, value in inputs.items()
            if key in self.prompt.input_variables
        }
        messages = self.prompt.format_prompt(**relevant).to_messages()
        prediction = self.llm.predict_messages(
            messages,
            functions=self.functions,
            callbacks=manager.get_child(),
            **self.kwargs,
        )
        return {"output": prediction}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        manager = run_manager or AsyncCallbackManagerForChainRun.get_noop_manager()
        # Drop any extra inputs the prompt does not declare.
        relevant = {
            key: value
            for key, value in inputs.items()
            if key in self.prompt.input_variables
        }
        messages = self.prompt.format_prompt(**relevant).to_messages()
        prediction = await self.llm.apredict_messages(
            messages,
            functions=self.functions,
            callbacks=manager.get_child(),
            **self.kwargs,
        )
        return {"output": prediction}
|
|
|
|
|
|
def _convert_schema(schema: dict) -> dict:
|
|
props = {k: {"title": k, **v} for k, v in schema["properties"].items()}
|
|
return {
|
|
"type": "object",
|
|
"properties": props,
|
|
"required": schema.get("required", []),
|
|
}
|
|
|
|
|
|
def _get_extraction_functions(entity_schema: dict) -> List[dict]:
    """Build the single OpenAI function spec used for entity extraction.

    The function takes one required ``info`` parameter: an array whose
    items follow *entity_schema*.
    """
    parameters = {
        "type": "object",
        "properties": {
            "info": {
                "type": "array",
                "items": _convert_schema(entity_schema),
            }
        },
        "required": ["info"],
    }
    return [
        {
            "name": EXTRACTION_NAME,
            "description": "Extracts the relevant information from the passage.",
            "parameters": parameters,
        }
    ]
|
|
|
|
|
|
def _get_tagging_functions(schema: dict) -> List[dict]:
    """Build the single OpenAI function spec used for passage tagging.

    Unlike extraction, the converted *schema* is the parameter object
    itself (one tag result, not an array of entities).
    """
    function = {
        "name": EXTRACTION_NAME,
        "description": "Extracts the relevant information from the passage.",
        "parameters": _convert_schema(schema),
    }
    return [function]
|
|
|
|
|
|
_EXTRACTION_TEMPLATE = """Extract and save the relevant entities mentioned\
|
|
in the following passage together with their properties.
|
|
|
|
Passage:
|
|
{input}
|
|
"""
|
|
|
|
|
|
def create_extraction_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    """Build a chain that extracts entities matching *schema* from a passage.

    The chain prompts *llm* with the extraction template, forces a call to
    the extraction function, then parses the returned arguments into the
    entity list stored under ``"info"``.
    """
    extraction_chain = OpenAIFunctionsChain(
        llm=llm,
        prompt=ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE),
        functions=_get_extraction_functions(schema),
        kwargs=EXTRACTION_KWARGS,
    )
    parser = TransformChain(
        transform=_parse_entities,
        input_variables=["input"],
        output_variables=["output"],
    )
    return SimpleSequentialChain(chains=[extraction_chain, parser])
|
|
|
|
|
|
def create_extraction_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    """Build an extraction chain whose output is a list of *pydantic_schema* objects.

    The user's model is wrapped in a container with a single
    ``info: List[...]`` field so the LLM returns an array; ``$ref`` entries
    produced by pydantic's ``.schema()`` are inlined before being handed to
    OpenAI.
    """

    class PydanticSchema(BaseModel):
        info: List[pydantic_schema]  # type: ignore

    raw_schema = PydanticSchema.schema()
    # .schema() emits nested models as "$ref"s into "definitions"; OpenAI
    # needs them inlined.
    openai_schema = _resolve_schema_references(raw_schema, raw_schema["definitions"])

    extraction_chain = OpenAIFunctionsChain(
        llm=llm,
        prompt=ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE),
        functions=_get_extraction_functions(openai_schema),
        kwargs=EXTRACTION_KWARGS,
    )
    parser = TransformChain(
        transform=partial(_parse_entities_pydantic, pydantic_schema=PydanticSchema),
        input_variables=["input"],
        output_variables=["output"],
    )
    return SimpleSequentialChain(chains=[extraction_chain, parser])
|
|
|
|
|
|
_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
|
|
|
|
Passage:
|
|
{input}
|
|
"""
|
|
|
|
|
|
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    """Build a chain that tags a passage with the fields described by *schema*.

    Prompts *llm* with the tagging template, forces the function call, and
    decodes the returned arguments as a plain dict.
    """
    tagging_chain = OpenAIFunctionsChain(
        llm=llm,
        prompt=ChatPromptTemplate.from_template(_TAGGING_TEMPLATE),
        functions=_get_tagging_functions(schema),
        kwargs=EXTRACTION_KWARGS,
    )
    parser = TransformChain(
        transform=_parse_tag,
        input_variables=["input"],
        output_variables=["output"],
    )
    return SimpleSequentialChain(chains=[tagging_chain, parser])
|
|
|
|
|
|
def create_tagging_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    """Build a tagging chain that validates the model's output with *pydantic_schema*.

    Same pipeline as :func:`create_tagging_chain`, but the function schema is
    derived from the pydantic model and the result is parsed back into a
    model instance.
    """
    tagging_chain = OpenAIFunctionsChain(
        llm=llm,
        prompt=ChatPromptTemplate.from_template(_TAGGING_TEMPLATE),
        functions=_get_tagging_functions(pydantic_schema.schema()),
        kwargs=EXTRACTION_KWARGS,
    )
    parser = TransformChain(
        transform=partial(_parse_tag_pydantic, pydantic_schema=pydantic_schema),
        input_variables=["input"],
        output_variables=["output"],
    )
    return SimpleSequentialChain(chains=[tagging_chain, parser])
|