forked from Archives/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
from typing import Any
|
|
|
|
from langchain.base_language import BaseLanguageModel
|
|
from langchain.chains.base import Chain
|
|
from langchain.chains.llm import LLMChain
|
|
from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
|
|
from langchain.output_parsers.openai_functions import (
|
|
JsonOutputFunctionsParser,
|
|
PydanticOutputFunctionsParser,
|
|
)
|
|
from langchain.prompts import ChatPromptTemplate
|
|
|
|
|
|
def _get_tagging_function(schema: dict) -> dict:
    """Build the OpenAI function-calling spec used for tagging.

    Args:
        schema: JSON-schema-like dict describing the attributes to extract.

    Returns:
        A function definition dict whose ``parameters`` field is *schema*
        converted via ``_convert_schema``.
    """
    function_spec = {
        "name": "information_extraction",
        "description": "Extracts the relevant information from the passage.",
        "parameters": _convert_schema(schema),
    }
    return function_spec
|
|
|
|
|
|
_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
|
|
|
|
Passage:
|
|
{input}
|
|
"""
|
|
|
|
|
|
def create_tagging_chain(schema: dict, llm: BaseLanguageModel) -> Chain:
    """Create a chain that extracts tags from a passage per *schema*.

    Args:
        schema: JSON-schema-like dict describing the attributes to extract.
        llm: Language model to drive the chain (one that supports OpenAI
            function calling, given the llm_kwargs wired in below).

    Returns:
        An ``LLMChain`` whose output is parsed into a plain ``dict`` by
        ``JsonOutputFunctionsParser``.
    """
    tagging_function = _get_tagging_function(schema)
    # Bind the function spec via llm_kwargs so the model is forced to call it.
    return LLMChain(
        llm=llm,
        prompt=ChatPromptTemplate.from_template(_TAGGING_TEMPLATE),
        llm_kwargs=get_llm_kwargs(tagging_function),
        output_parser=JsonOutputFunctionsParser(),
    )
|
|
|
|
|
|
def create_tagging_chain_pydantic(
    pydantic_schema: Any, llm: BaseLanguageModel
) -> Chain:
    """Create a tagging chain whose output is a pydantic object.

    Args:
        pydantic_schema: Pydantic model class describing the attributes to
            extract; its ``.schema()`` supplies the function parameters.
        llm: Language model to drive the chain (one that supports OpenAI
            function calling, given the llm_kwargs wired in below).

    Returns:
        An ``LLMChain`` whose output is parsed into an instance of
        *pydantic_schema* by ``PydanticOutputFunctionsParser``.
    """
    tagging_function = _get_tagging_function(pydantic_schema.schema())
    parser = PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema)
    # Bind the function spec via llm_kwargs so the model is forced to call it.
    return LLMChain(
        llm=llm,
        prompt=ChatPromptTemplate.from_template(_TAGGING_TEMPLATE),
        llm_kwargs=get_llm_kwargs(tagging_function),
        output_parser=parser,
    )
|