core[patch]: move some attr/methods to BaseLanguageModel (#18936)

Cleans up some shared code between `BaseLLM` and `BaseChatModel` by moving the duplicated attributes and methods onto `BaseLanguageModel`. There is one functional difference, made to keep the two classes consistent (see the inline comment); a short sketch of it follows the changed-files summary below.
Erick Friis 2024-03-11 14:59:45 -07:00 committed by GitHub
parent 4ff6aa5c78
commit 0d888a65cb
3 changed files with 37 additions and 56 deletions
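
The "one functional difference" appears to be the `verbose` handling: the `set_verbose` validator is only removed from `llms.py` below (there is no corresponding removal in `chat_models.py`), so once it lives on `BaseLanguageModel`, chat models also accept `verbose=None` and fall back to the global setting. A minimal sketch of that behavior, assuming the `FakeListChatModel` test helper is present in your `langchain_core` version (substitute any concrete chat model otherwise):

```python
from langchain_core.globals import set_verbose
from langchain_core.language_models.fake_chat_models import FakeListChatModel

set_verbose(True)

# Previously only BaseLLM resolved verbose=None via its own validator; with the
# validator moved onto BaseLanguageModel, chat models now resolve it the same way.
model = FakeListChatModel(responses=["hi"], verbose=None)
assert model.verbose is True  # None fell back to the global verbosity setting
```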

View File

@@ -7,6 +7,7 @@ from typing import (
Any,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
@@ -25,7 +26,7 @@ from langchain_core.messages import (
get_buffer_string,
)
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_core.runnables import Runnable, RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@@ -63,6 +64,12 @@ LanguageModelLike = Runnable[LanguageModelInput, LanguageModelOutput]
LanguageModelOutputVar = TypeVar("LanguageModelOutputVar", BaseMessage, str)
def _get_verbosity() -> bool:
from langchain_core.globals import get_verbose
return get_verbose()
class BaseLanguageModel(
RunnableSerializable[LanguageModelInput, LanguageModelOutputVar], ABC
):
@@ -71,6 +78,28 @@ class BaseLanguageModel(
All language model wrappers inherit from BaseLanguageModel.
"""
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
@@ -257,6 +286,11 @@ class BaseLanguageModel(
Top model prediction as a message.
"""
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {}
def get_token_ids(self, text: str) -> List[int]:
"""Return the ordered ids of the tokens in a text.

View File

@@ -54,12 +54,6 @@ if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
def _get_verbosity() -> bool:
from langchain_core.globals import get_verbose
return get_verbose()
def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
"""Generate from a stream."""
@@ -125,18 +119,8 @@ def _as_async_iterator(sync_iterator: Callable) -> Callable:
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
"""Base class for Chat models."""
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED] Callback manager to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
@root_validator()
def raise_deprecation(cls, values: Dict) -> Dict:
@@ -816,11 +800,6 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
_stop = list(stop)
return await self._call_async(messages, stop=_stop, **kwargs)
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return {}
@property
@abstractmethod
def _llm_type(self) -> str:
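
With the duplicated block removed from `BaseChatModel`, a custom chat model only declares what it actually customizes; `cache`, `verbose`, `callbacks`, `tags`, `metadata`, and the default `_identifying_params` all come from `BaseLanguageModel`. A sketch under that assumption; `EchoChatModel` is illustrative and not part of `langchain_core`:

```python
from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult


class EchoChatModel(BaseChatModel):
    """Toy model that echoes the last message back."""

    @property
    def _llm_type(self) -> str:
        return "echo-chat"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        # Overrides the empty default that BaseLanguageModel now provides.
        return {"model_name": "echo"}

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        reply = AIMessage(content=str(messages[-1].content))
        return ChatResult(generations=[ChatGeneration(message=reply)])


# tags/metadata are accepted without redeclaring them on the subclass.
model = EchoChatModel(tags=["demo"], metadata={"team": "docs"})
print(model.invoke("hello").content)  # -> "hello"
```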

View File

@@ -1,4 +1,5 @@
"""Base interface for large language models to expose."""
from __future__ import annotations
import asyncio
@@ -16,7 +17,6 @@ from typing import (
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
@@ -56,19 +56,13 @@ from langchain_core.messages import (
)
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator, validator
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
from langchain_core.runnables.config import run_in_executor
logger = logging.getLogger(__name__)
def _get_verbosity() -> bool:
from langchain_core.globals import get_verbose
return get_verbose()
@functools.lru_cache
def _log_error_once(msg: str) -> None:
"""Log an error once."""
@@ -200,16 +194,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
It should take in a prompt and return a string."""
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)
"""Whether to print out response text."""
callbacks: Callbacks = Field(default=None, exclude=True)
"""Callbacks to add to the run trace."""
tags: Optional[List[str]] = Field(default=None, exclude=True)
"""Tags to add to the run trace."""
metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
"""Metadata to add to the run trace."""
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
"""[DEPRECATED]"""
@@ -229,17 +213,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
values["callbacks"] = values.pop("callback_manager", None)
return values
@validator("verbose", pre=True, always=True)
def set_verbose(cls, verbose: Optional[bool]) -> bool:
"""If verbose is None, set it.
This allows users to pass in None as verbose to access the global setting.
"""
if verbose is None:
return _get_verbosity()
else:
return verbose
# --- Runnable methods ---
@property
@@ -1081,11 +1054,6 @@ class BaseLLM(BaseLanguageModel[str], ABC):
content = await self._call_async(text, stop=_stop, **kwargs)
return AIMessage(content=content)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {}
def __str__(self) -> str:
"""Get a string representation of the object for printing."""
cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
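
One piece that stays per-class is the `raise_deprecation` root validator shown above, which remaps the deprecated `callback_manager` kwarg onto the shared `callbacks` field (`values["callbacks"] = values.pop("callback_manager", None)`). A hedged sketch of that behavior, again assuming the `FakeListLLM` test helper is available:

```python
from langchain_core.callbacks import CallbackManager, StdOutCallbackHandler
from langchain_core.language_models.fake import FakeListLLM

manager = CallbackManager(handlers=[StdOutCallbackHandler()])

# Emits a DeprecationWarning; raise_deprecation moves the manager onto `callbacks`.
llm = FakeListLLM(responses=["pong"], callback_manager=manager)

assert llm.callbacks is manager  # remapped onto the field inherited from BaseLanguageModel
print(llm.invoke("ping"))        # -> "pong"
```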