langchain/langchain/chat_models/base.py


import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Sequence

from pydantic import Extra, Field, root_validator

import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForLLMRun,
    CallbackManager,
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    LLMResult,
    PromptValue,
)


def _get_verbosity() -> bool:
    return langchain.verbose


class BaseChatModel(BaseLanguageModel, ABC):
    """Base class for chat models."""

    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callbacks: Callbacks = Field(default=None, exclude=True)
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values.pop("callback_manager", None)
        return values
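    # Example (hypothetical subclass name, for illustration only): constructing
    # MyChatModel(callback_manager=manager) now emits a DeprecationWarning and
    # is treated the same as MyChatModel(callbacks=manager).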

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        # Subclasses may override this to merge the per-batch llm_output dicts
        # (e.g. token usage); the base implementation discards them.
        return {}

    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Top-level call: run the model on each message list and collect the results."""
        params = self.dict()
        params["stop"] = stop
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = callback_manager.on_chat_model_start(
            {"name": self.__class__.__name__}, messages, invocation_params=params
        )
        # Older _generate implementations may not accept the run_manager argument.
        new_arg_supported = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
        try:
            results = [
                self._generate(m, stop=stop, run_manager=run_manager)
                if new_arg_supported
                else self._generate(m, stop=stop)
                for m in messages
            ]
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_llm_error(e)
            raise e
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
        generations = [res.generations for res in results]
        output = LLMResult(generations=generations, llm_output=llm_output)
        run_manager.on_llm_end(output)
        return output

    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        """Top-level async call: like generate, but runs the batches concurrently."""
        params = self.dict()
        params["stop"] = stop
        callback_manager = AsyncCallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = await callback_manager.on_chat_model_start(
            {"name": self.__class__.__name__}, messages, invocation_params=params
        )
        # Older _agenerate implementations may not accept the run_manager argument.
        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
        )
        try:
            results = await asyncio.gather(
                *[
                    self._agenerate(m, stop=stop, run_manager=run_manager)
                    if new_arg_supported
                    else self._agenerate(m, stop=stop)
                    for m in messages
                ]
            )
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_llm_error(e)
            raise e
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
        generations = [res.generations for res in results]
        output = LLMResult(generations=generations, llm_output=llm_output)
        await run_manager.on_llm_end(output)
        return output

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
        return self.generate(prompt_messages, stop=stop, callbacks=callbacks)

    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
        return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
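    # For example (hypothetical usage, not part of this module): given a
    # PromptValue built from a prompt template elsewhere,
    #     result = model.generate_prompt([some_prompt.format_prompt(topic="x")])
    # converts each prompt to messages and defers to generate.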

    @abstractmethod
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        """Generate the chat result for one list of messages; implemented by subclasses."""

    @abstractmethod
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        """Async counterpart of _generate; implemented by subclasses."""

    def __call__(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
    ) -> BaseMessage:
        generation = self.generate(
            [messages], stop=stop, callbacks=callbacks
        ).generations[0][0]
        if isinstance(generation, ChatGeneration):
            return generation.message
        else:
            raise ValueError("Unexpected generation type")

    def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
        return self.predict(message, stop=stop)

    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        result = self([HumanMessage(content=text)], stop=_stop)
        return result.content

    def predict_messages(
        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
    ) -> BaseMessage:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        return self(messages, stop=_stop)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {}

    @property
    @abstractmethod
    def _llm_type(self) -> str:
        """Return type of chat model."""

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary of the LLM."""
        starter_dict = dict(self._identifying_params)
        starter_dict["_type"] = self._llm_type
        return starter_dict


class SimpleChatModel(BaseChatModel):
    """Chat model that only requires subclasses to return a response string."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @abstractmethod
    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Simpler interface: return the model's reply as a plain string."""