move base prompt to schema (#6995)

Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/7316/head
Harrison Chase 1 year ago committed by GitHub
parent 200be43da6
commit 60b05511d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -38,11 +38,11 @@ from langchain.llms import (
)
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import (
BasePromptTemplate,
FewShotPromptTemplate,
Prompt,
PromptTemplate,
)
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.sql_database import SQLDatabase
from langchain.utilities.arxiv import ArxivAPIWrapper
from langchain.utilities.google_search import GoogleSearchAPIWrapper

@ -26,13 +26,13 @@ from langchain.callbacks.manager import (
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
AgentAction,
AgentFinish,
BaseOutputParser,
BasePromptTemplate,
OutputParserException,
)
from langchain.schema.messages import BaseMessage

@ -34,8 +34,8 @@ from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BasePromptTemplate
from langchain.tools.base import BaseTool
from langchain.tools.requests.tool import BaseRequestsTool

@ -19,7 +19,7 @@ from langchain.agents.types import AgentType
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.messages import SystemMessage
from langchain.tools.python.tool import PythonAstREPLTool

@ -14,13 +14,12 @@ from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction
from langchain.schema import AgentAction, BasePromptTemplate
from langchain.tools.base import BaseTool

@ -16,17 +16,13 @@ from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
)
from langchain.schema import (
AgentAction,
BaseOutputParser,
)
from langchain.schema import AgentAction, BaseOutputParser, BasePromptTemplate
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.tools.base import BaseTool

@ -11,7 +11,6 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
BaseMessagePromptTemplate,
ChatPromptTemplate,
@ -21,6 +20,7 @@ from langchain.prompts.chat import (
from langchain.schema import (
AgentAction,
AgentFinish,
BasePromptTemplate,
OutputParserException,
)
from langchain.schema.messages import (

@ -11,7 +11,6 @@ from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import ChatOpenAI
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
BaseMessagePromptTemplate,
ChatPromptTemplate,
@ -21,6 +20,7 @@ from langchain.prompts.chat import (
from langchain.schema import (
AgentAction,
AgentFinish,
BasePromptTemplate,
OutputParserException,
)
from langchain.schema.messages import (

@ -13,7 +13,7 @@ from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.tools.base import BaseTool

@ -10,7 +10,7 @@ from langchain.agents.self_ask_with_search.prompt import PROMPT
from langchain.agents.tools import Tool
from langchain.agents.utils import validate_tools_single_input
from langchain.base_language import BaseLanguageModel
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.tools.base import BaseTool
from langchain.utilities.google_serper import GoogleSerperAPIWrapper
from langchain.utilities.serpapi import SerpAPIWrapper

@ -11,13 +11,12 @@ from langchain.agents.structured_chat.prompt import FORMAT_INSTRUCTIONS, PREFIX,
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema import AgentAction
from langchain.schema import AgentAction, BasePromptTemplate
from langchain.tools import BaseTool
HUMAN_MESSAGE_TEMPLATE = "{input}\n\n{agent_scratchpad}"

@ -1,5 +1,4 @@
"""A tracer that runs evaluators over completed runs."""
import logging
from concurrent.futures import Future, ThreadPoolExecutor, wait
from typing import Any, Optional, Sequence, Set, Union
from uuid import UUID
@ -9,8 +8,6 @@ from langchainplus_sdk import LangChainPlusClient, RunEvaluator
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
logger = logging.getLogger(__name__)
class EvaluatorCallbackHandler(BaseTracer):
"""A tracer that runs a run evaluator whenever a run is persisted.
@ -50,7 +47,7 @@ class EvaluatorCallbackHandler(BaseTracer):
max_workers: Optional[int] = None,
client: Optional[LangChainPlusClient] = None,
example_id: Optional[Union[UUID, str]] = None,
**kwargs: Any,
**kwargs: Any
) -> None:
super().__init__(**kwargs)
self.example_id = (
@ -63,17 +60,6 @@ class EvaluatorCallbackHandler(BaseTracer):
)
self.futures: Set[Future] = set()
def _evaluate_run(self, run: Run, evaluator: RunEvaluator) -> None:
try:
self.client.evaluate_run(run, evaluator)
except Exception as e:
logger.error(
f"Error evaluating run {run.id} with "
f"{evaluator.__class__.__name__}: {e}",
exc_info=True,
)
raise e
def _persist_run(self, run: Run) -> None:
"""Run the evaluator on the run.
@ -86,7 +72,9 @@ class EvaluatorCallbackHandler(BaseTracer):
run_ = run.copy()
run_.reference_example_id = self.example_id
for evaluator in self.evaluators:
self.futures.add(self.executor.submit(self._evaluate_run, run_, evaluator))
self.futures.add(
self.executor.submit(self.client.evaluate_run, run_, evaluator)
)
def wait_for_futures(self) -> None:
"""Wait for all futures to complete."""

@ -13,8 +13,8 @@ from langchain.callbacks.manager import (
from langchain.chains.api.prompt import API_RESPONSE_PROMPT, API_URL_PROMPT
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.requests import TextRequestsWrapper
from langchain.schema import BasePromptTemplate
class APIChain(Chain):

@ -11,7 +11,7 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

@ -13,8 +13,8 @@ from langchain.chains.combine_documents.base import (
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
def _get_default_document_prompt() -> PromptTemplate:

@ -11,8 +11,8 @@ from langchain.chains.combine_documents.base import (
)
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
def _get_default_document_prompt() -> PromptTemplate:

@ -8,7 +8,7 @@ from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.constitutional_ai.principles import PRINCIPLES
from langchain.chains.constitutional_ai.prompts import CRITIQUE_PROMPT, REVISION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
class ConstitutionalChain(Chain):

@ -6,8 +6,7 @@ from pydantic import Extra, Field, root_validator
from langchain.chains.conversation.prompt import PROMPT
from langchain.chains.llm import LLMChain
from langchain.memory.buffer import ConversationBufferMemory
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseMemory
from langchain.schema import BaseMemory, BasePromptTemplate
class ConversationChain(LLMChain):

@ -21,8 +21,7 @@ from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.chains.question_answering import load_qa_chain
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseRetriever, Document
from langchain.schema import BasePromptTemplate, BaseRetriever, Document
from langchain.schema.messages import BaseMessage
from langchain.vectorstores.base import VectorStore

@ -19,8 +19,7 @@ from langchain.chains.flare.prompts import (
)
from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseRetriever, Generation
from langchain.schema import BasePromptTemplate, BaseRetriever, Generation
class _ResponseChain(LLMChain):

@ -11,7 +11,7 @@ from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import ENTITY_EXTRACTION_PROMPT, PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.networkx_graph import NetworkxEntityGraph, get_entities
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
class GraphQAChain(Chain):

@ -12,7 +12,7 @@ from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_GENERATION_PROMPT, CYPHER_QA_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.neo4j_graph import Neo4jGraph
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
INTERMEDIATE_STEPS_KEY = "intermediate_steps"

@ -11,7 +11,7 @@ from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, KUZU_GENERATION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.kuzu_graph import KuzuGraph
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
class KuzuQAChain(Chain):

@ -11,7 +11,7 @@ from langchain.chains.base import Chain
from langchain.chains.graph_qa.prompts import CYPHER_QA_PROMPT, NGQL_GENERATION_PROMPT
from langchain.chains.llm import LLMChain
from langchain.graphs.nebula_graph import NebulaGraph
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
class NebulaGraphQAChain(Chain):

@ -17,10 +17,10 @@ from langchain.callbacks.manager import (
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseLLMOutputParser,
BasePromptTemplate,
LLMResult,
NoOpOutputParser,
PromptValue,

@ -12,8 +12,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_bash.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import OutputParserException
from langchain.schema import BasePromptTemplate, OutputParserException
from langchain.utilities.bash import BashProcess
logger = logging.getLogger(__name__)

@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
class LLMMathChain(Chain):

@ -17,7 +17,7 @@ from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChai
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.text_splitter import TextSplitter

@ -7,7 +7,7 @@ import requests
from openapi_schema_pydantic import Parameter
from requests import Response
from langchain import BasePromptTemplate, LLMChain
from langchain import LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
@ -16,6 +16,7 @@ from langchain.chat_models import ChatOpenAI
from langchain.input import get_colored_text
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.tools import APIOperation
from langchain.utilities.openapi import OpenAPISpec

@ -15,7 +15,7 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.utilities import PythonREPL

@ -6,7 +6,7 @@ from pydantic import BaseModel, Field
from langchain.base_language import BaseLanguageModel
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
class BasePromptSelector(BaseModel, ABC):

@ -10,7 +10,7 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter

@ -26,7 +26,7 @@ from langchain.chains.qa_with_sources.map_reduce_prompt import (
QUESTION_PROMPT,
)
from langchain.docstore.document import Document
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
class BaseQAWithSourcesChain(Chain, ABC):

@ -18,7 +18,7 @@ from langchain.chains.qa_with_sources import (
from langchain.chains.question_answering.map_rerank_prompt import (
PROMPT as MAP_RERANK_PROMPT,
)
from langchain.prompts.base import BasePromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate
class LoadingCallable(Protocol):

@ -4,7 +4,7 @@ from __future__ import annotations
import json
from typing import Any, Callable, List, Optional, Sequence
from langchain import BasePromptTemplate, FewShotPromptTemplate, LLMChain
from langchain import FewShotPromptTemplate, LLMChain
from langchain.base_language import BaseLanguageModel
from langchain.chains.query_constructor.ir import (
Comparator,
@ -23,7 +23,7 @@ from langchain.chains.query_constructor.prompt import (
)
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
class StructuredQueryOutputParser(BaseOutputParser[StructuredQuery]):

@ -18,7 +18,7 @@ from langchain.chains.question_answering import (
from langchain.chains.question_answering.map_rerank_prompt import (
PROMPT as MAP_RERANK_PROMPT,
)
from langchain.prompts.base import BasePromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate
class LoadingCallable(Protocol):

@ -13,8 +13,7 @@ from langchain.callbacks.manager import (
from langchain.chains import LLMChain
from langchain.chains.router.base import RouterChain
from langchain.output_parsers.json import parse_and_check_json_markdown
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
class LLMRouterChain(RouterChain):

@ -11,8 +11,8 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.sql_database import SQLDatabase
from langchain.tools.sql_database.prompt import QUERY_CHECKER

@ -8,7 +8,7 @@ from langchain.chains.combine_documents.refine import RefineDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
class LoadingCallable(Protocol):

@ -8,8 +8,7 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
from langchain.chains.llm import LLMChain
from langchain.evaluation.criteria.prompt import PROMPT, PROMPT_WITH_REFERENCES
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseOutputParser
from langchain.schema import BaseOutputParser, BasePromptTemplate
_SUPPORTED_CRITERIA = {
"conciseness": "Is the submission concise and to the point?",

@ -19,20 +19,14 @@ def _parse_string_eval_output(text: str) -> dict:
Returns:
Any: The parsed output.
"""
splits = text.strip().rsplit("\n", maxsplit=1)
if len(splits) == 1:
verdict = splits[0]
reasoning = None
else:
reasoning, verdict = splits
reasoning = reasoning.strip()
reasoning, verdict = text.strip().rsplit("\n", maxsplit=1)
score = (
1
if verdict.upper() == "CORRECT"
else (0 if verdict.upper() == "INCORRECT" else None)
)
return {
"reasoning": reasoning,
"reasoning": reasoning.strip(),
"value": verdict,
"score": score,
}

@ -23,9 +23,8 @@ from langchain.evaluation.run_evaluators.base import (
RunEvaluatorInputMapper,
RunEvaluatorOutputParser,
)
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import OutputParserException
from langchain.schema import BasePromptTemplate, OutputParserException
from langchain.tools.base import BaseTool
_QA_PROMPTS = {

@ -13,7 +13,7 @@ from langchain.memory.prompt import (
ENTITY_SUMMARIZATION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.messages import BaseMessage, get_buffer_string
logger = logging.getLogger(__name__)

@ -12,7 +12,7 @@ from langchain.memory.prompt import (
KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string

@ -8,9 +8,9 @@ from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import (
BaseChatMessageHistory,
BasePromptTemplate,
)
from langchain.schema.messages import BaseMessage, SystemMessage, get_buffer_string

@ -5,8 +5,7 @@ from typing import TypeVar
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.output_parsers.prompts import NAIVE_FIX_PROMPT
from langchain.prompts.base import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException
from langchain.schema import BaseOutputParser, BasePromptTemplate, OutputParserException
T = TypeVar("T")

@ -4,10 +4,10 @@ from typing import TypeVar
from langchain.base_language import BaseLanguageModel
from langchain.chains.llm import LLMChain
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BaseOutputParser,
BasePromptTemplate,
OutputParserException,
PromptValue,
)

@ -1,5 +1,5 @@
"""Prompt template classes."""
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
from langchain.prompts.base import StringPromptTemplate
from langchain.prompts.chat import (
AIMessagePromptTemplate,
BaseChatPromptTemplate,
@ -20,6 +20,7 @@ from langchain.prompts.few_shot_with_templates import FewShotPromptWithTemplates
from langchain.prompts.loading import load_prompt
from langchain.prompts.pipeline import PipelinePromptTemplate
from langchain.prompts.prompt import Prompt, PromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate
__all__ = [
"AIMessagePromptTemplate",

@ -1,18 +1,13 @@
"""BasePrompt schema definition."""
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Set, Union
import yaml
from pydantic import Field, root_validator
from abc import ABC
from typing import Any, Callable, Dict, List, Set
from langchain.formatting import formatter
from langchain.load.serializable import Serializable
from langchain.schema import BaseOutputParser, PromptValue
from langchain.schema import BasePromptTemplate
from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.schema.prompt import PromptValue
def jinja2_formatter(template: str, **kwargs: Any) -> str:
@ -110,133 +105,6 @@ class StringPromptValue(PromptValue):
return [HumanMessage(content=self.text)]
class BasePromptTemplate(Serializable, ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
@property
def lc_serializable(self) -> bool:
return True
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
prompt_dict["_type"] = self._prompt_type
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the prompt.
Args:
file_path: Path to directory to save prompt to.
Example:
.. code-block:: python
prompt.save(file_path="path/prompt.yaml")
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt should expose the format method, returning a prompt."""

@ -8,9 +8,10 @@ from typing import Any, Callable, List, Sequence, Tuple, Type, TypeVar, Union
from pydantic import Field, root_validator
from langchain.load.serializable import Serializable
from langchain.prompts.base import BasePromptTemplate, StringPromptTemplate
from langchain.prompts.base import StringPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
BasePromptTemplate,
PromptValue,
)
from langchain.schema.messages import (

@ -8,10 +8,9 @@ from typing import Union
import yaml
from langchain.output_parsers.regex import RegexParser
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BaseLLMOutputParser, NoOpOutputParser
from langchain.schema import BaseLLMOutputParser, BasePromptTemplate, NoOpOutputParser
from langchain.utilities.loading import try_load_from_hub
URL_BASE = "https://raw.githubusercontent.com/hwchase17/langchain-hub/master/prompts/"

@ -2,9 +2,8 @@ from typing import Any, Dict, List, Tuple
from pydantic import root_validator
from langchain.prompts.base import BasePromptTemplate
from langchain.prompts.chat import BaseChatPromptTemplate
from langchain.schema import PromptValue
from langchain.schema import BasePromptTemplate, PromptValue
def _get_inputs(inputs: dict, input_variables: List[str]) -> dict:

@ -1,7 +1,7 @@
"""Filter that uses an LLM to drop documents that aren't relevant to the query."""
from typing import Any, Callable, Dict, Optional, Sequence
from langchain import BasePromptTemplate, LLMChain, PromptTemplate
from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import Callbacks
from langchain.output_parsers.boolean import BooleanOutputParser
@ -9,7 +9,7 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompresso
from langchain.retrievers.document_compressors.chain_filter_prompt import (
prompt_template,
)
from langchain.schema import Document
from langchain.schema import BasePromptTemplate, Document
def _get_default_chain_prompt() -> PromptTemplate:

@ -28,6 +28,7 @@ from langchain.schema.output_parser import (
OutputParserException,
)
from langchain.schema.prompt import PromptValue
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema.retriever import BaseRetriever
RUN_KEY = "__run"
@ -64,4 +65,5 @@ __all__ = [
"NoOpOutputParser",
"BaseOutputParser",
"BaseLLMOutputParser",
"BasePromptTemplate",
]

@ -0,0 +1,139 @@
from __future__ import annotations
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
import yaml
from pydantic import Field, root_validator
from langchain.load.serializable import Serializable
from langchain.schema import BaseOutputParser, PromptValue
class BasePromptTemplate(Serializable, ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
"""A list of the names of the variables the prompt template expects."""
output_parser: Optional[BaseOutputParser] = None
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Union[str, Callable[[], str]]] = Field(
default_factory=dict
)
@property
def lc_serializable(self) -> bool:
return True
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""
@root_validator()
def validate_variable_names(cls, values: Dict) -> Dict:
"""Validate variable names do not include restricted names."""
if "stop" in values["input_variables"]:
raise ValueError(
"Cannot have an input variable named 'stop', as it is used internally,"
" please rename."
)
if "stop" in values["partial_variables"]:
raise ValueError(
"Cannot have an partial variable named 'stop', as it is used "
"internally, please rename."
)
overall = set(values["input_variables"]).intersection(
values["partial_variables"]
)
if overall:
raise ValueError(
f"Found overlapping input and partial variables: {overall}"
)
return values
def partial(self, **kwargs: Union[str, Callable[[], str]]) -> BasePromptTemplate:
"""Return a partial of the prompt template."""
prompt_dict = self.__dict__.copy()
prompt_dict["input_variables"] = list(
set(self.input_variables).difference(kwargs)
)
prompt_dict["partial_variables"] = {**self.partial_variables, **kwargs}
return type(self)(**prompt_dict)
def _merge_partial_and_user_variables(self, **kwargs: Any) -> Dict[str, Any]:
# Get partial params:
partial_kwargs = {
k: v if isinstance(v, str) else v()
for k, v in self.partial_variables.items()
}
return {**partial_kwargs, **kwargs}
@abstractmethod
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
@property
def _prompt_type(self) -> str:
"""Return the prompt type key."""
raise NotImplementedError
def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of prompt."""
prompt_dict = super().dict(**kwargs)
prompt_dict["_type"] = self._prompt_type
return prompt_dict
def save(self, file_path: Union[Path, str]) -> None:
"""Save the prompt.
Args:
file_path: Path to directory to save prompt to.
Example:
.. code-block:: python
prompt.save(file_path="path/prompt.yaml")
"""
if self.partial_variables:
raise ValueError("Cannot save prompt with partial variables.")
# Convert file to Path object.
if isinstance(file_path, str):
save_path = Path(file_path)
else:
save_path = file_path
directory_path = save_path.parent
directory_path.mkdir(parents=True, exist_ok=True)
# Fetch dictionary to save
prompt_dict = self.dict()
if save_path.suffix == ".json":
with open(file_path, "w") as f:
json.dump(prompt_dict, f, indent=4)
elif save_path.suffix == ".yaml":
with open(file_path, "w") as f:
yaml.dump(prompt_dict, f, default_flow_style=False)
else:
raise ValueError(f"{save_path} must be json or yaml")
Loading…
Cancel
Save