Extraction Chain - Custom Prompt (#9828)

# Description

This change allows you to customize the prompt used in
`create_extraction_chain` as well as `create_extraction_chain_pydantic`.

It also adds the `verbose` argument to
`create_extraction_chain_pydantic` - because `create_extraction_chain`
had it already and `create_extraction_chain_pydantic` did not.

# Issue
N/A

# Dependencies
N/A

# Twitter
https://twitter.com/CamAHutchison
Commit 7d8bb78e5c, authored by Cameron Hutchison on 2023-09-03 16:01:55 -07:00
and committed by GitHub (parent commit: 33f43cc1b0).
Signature: no known key found for this signature in database
(GPG Key ID: 4AEE18F83AFDEB23).

View File

@ -1,4 +1,4 @@
from typing import Any, List
from typing import Any, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
@ -13,6 +13,7 @@ from langchain.output_parsers.openai_functions import (
)
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
@ -43,13 +44,17 @@ Passage:
def create_extraction_chain(
    schema: dict,
    llm: BaseLanguageModel,
    prompt: Optional[BasePromptTemplate] = None,
    verbose: bool = False,
) -> Chain:
    """Creates a chain that extracts information from a passage.

    Args:
        schema: The schema of the entities to extract.
        llm: The language model to use.
        prompt: The prompt to use for extraction. Defaults to the built-in
            extraction template when not provided.
        verbose: Whether to run in verbose mode. In verbose mode, some intermediate
            logs will be printed to the console. Defaults to `langchain.verbose` value.

    Returns:
        Chain that can be used to extract information from a passage.
    """
    # Build the OpenAI function-calling spec from the user-supplied schema.
    function = _get_extraction_function(schema)
    # Fall back to the default extraction prompt when the caller does not
    # supply a custom one.
    extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    # The function call returns its payload under the "info" key.
    output_parser = JsonKeyOutputFunctionsParser(key_name="info")
    llm_kwargs = get_llm_kwargs(function)
    chain = LLMChain(
        llm=llm,
        prompt=extraction_prompt,
        llm_kwargs=llm_kwargs,
        output_parser=output_parser,
        verbose=verbose,
    )
    return chain
@ -71,13 +76,19 @@ def create_extraction_chain(
def create_extraction_chain_pydantic(
    pydantic_schema: Any,
    llm: BaseLanguageModel,
    prompt: Optional[BasePromptTemplate] = None,
    verbose: bool = False,
) -> Chain:
    """Creates a chain that extracts information from a passage using pydantic schema.

    Args:
        pydantic_schema: The pydantic schema of the entities to extract.
        llm: The language model to use.
        prompt: The prompt to use for extraction. Defaults to the built-in
            extraction template when not provided.
        verbose: Whether to run in verbose mode. In verbose mode, some intermediate
            logs will be printed to the console. Defaults to `langchain.verbose` value.

    Returns:
        Chain that can be used to extract information from a passage.
    """

    # NOTE(review): the diff view elides the lines between the docstring and
    # the `function = ...` assignment; this section is reconstructed from the
    # unchanged upstream source — confirm against the repository.
    # Wrapper model so the function call returns a list of entities under "info".
    class PydanticSchema(BaseModel):
        info: List[pydantic_schema]  # type: ignore

    # Convert the pydantic model to an OpenAI-compatible JSON schema,
    # inlining any `$ref` definitions.
    openai_schema = pydantic_schema.schema()
    openai_schema = _resolve_schema_references(
        openai_schema, openai_schema.get("definitions", {})
    )

    function = _get_extraction_function(openai_schema)
    # Fall back to the default extraction prompt when the caller does not
    # supply a custom one.
    extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE)
    # Parse the function output back into the wrapper model's "info" attribute.
    output_parser = PydanticAttrOutputFunctionsParser(
        pydantic_schema=PydanticSchema, attr_name="info"
    )
    llm_kwargs = get_llm_kwargs(function)
    chain = LLMChain(
        llm=llm,
        prompt=extraction_prompt,
        llm_kwargs=llm_kwargs,
        output_parser=output_parser,
        verbose=verbose,
    )
    return chain