forked from Archives/langchain
39 lines · 1.2 KiB · Python
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple

from pydantic import BaseModel, Field

from langchain.base_language import BaseLanguageModel
from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.prompts.base import BasePromptTemplate


class BasePromptSelector(BaseModel, ABC):
    @abstractmethod
    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        """Get default prompt for a language model."""


class ConditionalPromptSelector(BasePromptSelector):
    """Prompt collection that goes through conditionals."""

    default_prompt: BasePromptTemplate
    conditionals: List[
        Tuple[Callable[[BaseLanguageModel], bool], BasePromptTemplate]
    ] = Field(default_factory=list)

    def get_prompt(self, llm: BaseLanguageModel) -> BasePromptTemplate:
        # Return the prompt paired with the first predicate that matches
        # the given model; fall back to the default prompt otherwise.
        for condition, prompt in self.conditionals:
            if condition(llm):
                return prompt
        return self.default_prompt


def is_llm(llm: BaseLanguageModel) -> bool:
    """Check whether the language model is a (completion-style) LLM."""
    return isinstance(llm, BaseLLM)


def is_chat_model(llm: BaseLanguageModel) -> bool:
    """Check whether the language model is a chat model."""
    return isinstance(llm, BaseChatModel)