diff --git a/libs/langchain/langchain/chains/__init__.py b/libs/langchain/langchain/chains/__init__.py index 3bd6b4e4cb..fdde1cce6d 100644 --- a/libs/langchain/langchain/chains/__init__.py +++ b/libs/langchain/langchain/chains/__init__.py @@ -26,6 +26,7 @@ from langchain.chains.conversational_retrieval.base import ( ChatVectorDBChain, ConversationalRetrievalChain, ) +from langchain.chains.example_generator import generate_example from langchain.chains.flare.base import FlareChain from langchain.chains.graph_qa.base import GraphQAChain from langchain.chains.graph_qa.cypher import GraphCypherQAChain @@ -84,9 +85,9 @@ __all__ = [ "GraphCypherQAChain", "GraphQAChain", "GraphSparqlQAChain", + "HugeGraphQAChain", "HypotheticalDocumentEmbedder", "KuzuQAChain", - "HugeGraphQAChain", "LLMBashChain", "LLMChain", "LLMCheckerChain", @@ -95,6 +96,8 @@ __all__ = [ "LLMRouterChain", "LLMSummarizationCheckerChain", "MapReduceChain", + "MapReduceDocumentsChain", + "MapRerankDocumentsChain", "MultiPromptChain", "MultiRetrievalQAChain", "MultiRouteChain", @@ -105,6 +108,8 @@ __all__ = [ "PALChain", "QAGenerationChain", "QAWithSourcesChain", + "ReduceDocumentsChain", + "RefineDocumentsChain", "RetrievalQA", "RetrievalQAWithSourcesChain", "RouterChain", @@ -112,20 +117,17 @@ __all__ = [ "SQLDatabaseSequentialChain", "SequentialChain", "SimpleSequentialChain", + "StuffDocumentsChain", "TransformChain", "VectorDBQA", "VectorDBQAWithSourcesChain", + "create_citation_fuzzy_match_chain", "create_extraction_chain", "create_extraction_chain_pydantic", + "create_qa_with_sources_chain", + "create_qa_with_structure_chain", "create_tagging_chain", "create_tagging_chain_pydantic", + "generate_example", "load_chain", - "create_citation_fuzzy_match_chain", - "create_qa_with_structure_chain", - "create_qa_with_sources_chain", - "StuffDocumentsChain", - "MapRerankDocumentsChain", - "MapReduceDocumentsChain", - "RefineDocumentsChain", - "ReduceDocumentsChain", ] diff --git a/libs/langchain/langchain/chains/example_generator.py b/libs/langchain/langchain/chains/example_generator.py new file mode 100644 index 0000000000..c01cba4667 --- /dev/null +++ b/libs/langchain/langchain/chains/example_generator.py @@ -0,0 +1,22 @@ +from typing import List + +from langchain.chains.llm import LLMChain +from langchain.prompts.few_shot import FewShotPromptTemplate +from langchain.prompts.prompt import PromptTemplate +from langchain.schema.language_model import BaseLanguageModel + +TEST_GEN_TEMPLATE_SUFFIX = "Add another example." + + +def generate_example( + examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate +) -> str: + """Return another example given a list of examples for a prompt.""" + prompt = FewShotPromptTemplate( + examples=examples, + suffix=TEST_GEN_TEMPLATE_SUFFIX, + input_variables=[], + example_prompt=prompt_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + return chain.predict() diff --git a/libs/langchain/langchain/example_generator.py b/libs/langchain/langchain/example_generator.py index e825b9eb98..2d11be8bf1 100644 --- a/libs/langchain/langchain/example_generator.py +++ b/libs/langchain/langchain/example_generator.py @@ -1,23 +1,4 @@ -"""Utility functions for working with prompts.""" -from typing import List +"""Keep here for backwards compatibility.""" +from langchain.chains.example_generator import generate_example -from langchain.chains.llm import LLMChain -from langchain.prompts.few_shot import FewShotPromptTemplate -from langchain.prompts.prompt import PromptTemplate -from langchain.schema.language_model import BaseLanguageModel - -TEST_GEN_TEMPLATE_SUFFIX = "Add another example." - - -def generate_example( - examples: List[dict], llm: BaseLanguageModel, prompt_template: PromptTemplate -) -> str: - """Return another example given a list of examples for a prompt.""" - prompt = FewShotPromptTemplate( - examples=examples, - suffix=TEST_GEN_TEMPLATE_SUFFIX, - input_variables=[], - example_prompt=prompt_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - return chain.predict() +__all__ = ["generate_example"]