From 7d8bb78e5c54dcd80ae76acfe9cff11f96f3b8b0 Mon Sep 17 00:00:00 2001 From: Cameron Hutchison Date: Sun, 3 Sep 2023 16:01:55 -0700 Subject: [PATCH] Extraction Chain - Custom Prompt (#9828) # Description This change allows you to customize the prompt used in `create_extraction_chain` as well as `create_extraction_chain_pydantic`. It also adds the `verbose` argument to `create_extraction_chain_pydantic` - because `create_extraction_chain` had it already and `create_extraction_chain_pydantic` did not. # Issue N/A # Dependencies N/A # Twitter https://twitter.com/CamAHutchison --- .../chains/openai_functions/extraction.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/chains/openai_functions/extraction.py b/libs/langchain/langchain/chains/openai_functions/extraction.py index c881537866..284b97c416 100644 --- a/libs/langchain/langchain/chains/openai_functions/extraction.py +++ b/libs/langchain/langchain/chains/openai_functions/extraction.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Optional from langchain.chains.base import Chain from langchain.chains.llm import LLMChain @@ -13,6 +13,7 @@ from langchain.output_parsers.openai_functions import ( ) from langchain.prompts import ChatPromptTemplate from langchain.pydantic_v1 import BaseModel +from langchain.schema import BasePromptTemplate from langchain.schema.language_model import BaseLanguageModel @@ -43,13 +44,17 @@ Passage: def create_extraction_chain( - schema: dict, llm: BaseLanguageModel, verbose: bool = False + schema: dict, + llm: BaseLanguageModel, + prompt: Optional[BasePromptTemplate] = None, + verbose: bool = False, ) -> Chain: """Creates a chain that extracts information from a passage. Args: schema: The schema of the entities to extract. llm: The language model to use. + prompt: The prompt to use for extraction. verbose: Whether to run in verbose mode. In verbose mode, some intermediate logs will be printed to the console. Defaults to `langchain.verbose` value. @@ -57,12 +62,12 @@ def create_extraction_chain( Chain that can be used to extract information from a passage. """ function = _get_extraction_function(schema) - prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) + extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) output_parser = JsonKeyOutputFunctionsParser(key_name="info") llm_kwargs = get_llm_kwargs(function) chain = LLMChain( llm=llm, - prompt=prompt, + prompt=extraction_prompt, llm_kwargs=llm_kwargs, output_parser=output_parser, verbose=verbose, @@ -71,13 +76,19 @@ def create_extraction_chain( def create_extraction_chain_pydantic( - pydantic_schema: Any, llm: BaseLanguageModel + pydantic_schema: Any, + llm: BaseLanguageModel, + prompt: Optional[BasePromptTemplate] = None, + verbose: bool = False, ) -> Chain: """Creates a chain that extracts information from a passage using pydantic schema. Args: pydantic_schema: The pydantic schema of the entities to extract. llm: The language model to use. + prompt: The prompt to use for extraction. + verbose: Whether to run in verbose mode. In verbose mode, some intermediate + logs will be printed to the console. Defaults to `langchain.verbose` value. Returns: Chain that can be used to extract information from a passage. @@ -92,15 +103,16 @@ def create_extraction_chain_pydantic( ) function = _get_extraction_function(openai_schema) - prompt = ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) + extraction_prompt = prompt or ChatPromptTemplate.from_template(_EXTRACTION_TEMPLATE) output_parser = PydanticAttrOutputFunctionsParser( pydantic_schema=PydanticSchema, attr_name="info" ) llm_kwargs = get_llm_kwargs(function) chain = LLMChain( llm=llm, - prompt=prompt, + prompt=extraction_prompt, llm_kwargs=llm_kwargs, output_parser=output_parser, + verbose=verbose, ) return chain