You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/langchain/chains/openai_functions/tagging.py

58 lines
1.7 KiB
Python

from typing import Any, Optional

from langchain.base_language import BaseLanguageModel
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
from langchain.output_parsers.openai_functions import (
    JsonOutputFunctionsParser,
    PydanticOutputFunctionsParser,
)
from langchain.prompts import ChatPromptTemplate
def _get_tagging_function(schema: dict) -> dict:
    """Build the function definition used for tagging a passage.

    Args:
        schema: Schema dictionary; it is passed through ``_convert_schema``
            and the result becomes the function's ``parameters`` entry.

    Returns:
        A dict with the keys ``name``, ``description`` and ``parameters``.
    """
    converted_parameters = _convert_schema(schema)
    function_spec = dict(
        name="information_extraction",
        description="Extracts the relevant information from the passage.",
        parameters=converted_parameters,
    )
    return function_spec
_TAGGING_TEMPLATE = """Extract the desired information from the following passage.
Passage:
{input}
"""
def create_tagging_chain(
    schema: dict,
    llm: BaseLanguageModel,
    prompt: Optional[ChatPromptTemplate] = None,
    **kwargs: Any,
) -> Chain:
    """Create a chain that extracts information from a passage as JSON.

    Args:
        schema: Schema dictionary describing what to extract; it is turned
            into the ``information_extraction`` function definition.
        llm: Language model the chain runs against.
        prompt: Prompt to use for the chain. When ``None`` (the default), a
            prompt is built from ``_TAGGING_TEMPLATE``, whose only
            placeholder is ``{input}``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``LLMChain``. They must not repeat ``llm``, ``prompt``,
            ``llm_kwargs`` or ``output_parser``, which are set here.

    Returns:
        An ``LLMChain`` configured with a ``JsonOutputFunctionsParser``.
    """
    function = _get_tagging_function(schema)
    # Compare against None explicitly so a caller-supplied prompt is never
    # replaced because of its truthiness.
    if prompt is None:
        prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    return LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs=get_llm_kwargs(function),
        output_parser=JsonOutputFunctionsParser(),
        **kwargs,
    )
def create_tagging_chain_pydantic(
    pydantic_schema: Any,
    llm: BaseLanguageModel,
    prompt: Optional[ChatPromptTemplate] = None,
    **kwargs: Any,
) -> Chain:
    """Create a chain that extracts information into a pydantic object.

    Args:
        pydantic_schema: Object exposing a ``schema()`` method (presumably a
            pydantic model class -- confirm against callers). Its schema is
            used for the function definition and the object itself is handed
            to ``PydanticOutputFunctionsParser``.
        llm: Language model the chain runs against.
        prompt: Prompt to use for the chain. When ``None`` (the default), a
            prompt is built from ``_TAGGING_TEMPLATE``, whose only
            placeholder is ``{input}``.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``LLMChain``. They must not repeat ``llm``, ``prompt``,
            ``llm_kwargs`` or ``output_parser``, which are set here.

    Returns:
        An ``LLMChain`` configured with a ``PydanticOutputFunctionsParser``.
    """
    openai_schema = pydantic_schema.schema()
    function = _get_tagging_function(openai_schema)
    # Compare against None explicitly so a caller-supplied prompt is never
    # replaced because of its truthiness.
    if prompt is None:
        prompt = ChatPromptTemplate.from_template(_TAGGING_TEMPLATE)
    return LLMChain(
        llm=llm,
        prompt=prompt,
        llm_kwargs=get_llm_kwargs(function),
        output_parser=PydanticOutputFunctionsParser(pydantic_schema=pydantic_schema),
        **kwargs,
    )