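"""Base interfaces for chat models."""
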
import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Mapping, Optional, Sequence

from pydantic import Extra, Field, root_validator

import langchain
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
    AsyncCallbackManager,
    AsyncCallbackManagerForLLMRun,
    CallbackManager,
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain.load.dump import dumpd
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatResult,
    HumanMessage,
    LLMResult,
    PromptValue,
    RunInfo,
)


def _get_verbosity() -> bool:
    return langchain.verbose


class BaseChatModel(BaseLanguageModel, ABC):
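    """Base class for chat models: wraps provider-specific generation with
    callback configuration, tracing, and convenience entry points."""
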
|
verbose: bool = Field(default_factory=_get_verbosity)
|
|
"""Whether to print out response text."""
|
|
callbacks: Callbacks = Field(default=None, exclude=True)
|
|
callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
|
|
tags: Optional[List[str]] = Field(default=None, exclude=True)
|
|
"""Tags to add to the run trace."""
|
|
|
|
    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

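    # Hook for subclasses to merge per-call llm_output dicts (e.g. token
    # usage) across a batch; the base implementation discards them.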
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        return {}

    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Top-level call: generate a response for each list of messages."""
        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}

        callback_manager = CallbackManager.configure(
            callbacks,
            self.callbacks,
            self.verbose,
            tags,
            self.tags,
        )
        run_manager = callback_manager.on_chat_model_start(
            dumpd(self), messages, invocation_params=params, options=options
        )

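        # Older _generate implementations may not accept run_manager, so only
        # pass it when the signature declares that parameter.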
        new_arg_supported = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
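        # KeyboardInterrupt does not subclass Exception, so it is listed
        # explicitly to ensure callbacks are notified before re-raising.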
        try:
            results = [
                self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
                if new_arg_supported
                else self._generate(m, stop=stop)
                for m in messages
            ]
        except (KeyboardInterrupt, Exception) as e:
            run_manager.on_llm_error(e)
            raise e
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
        generations = [res.generations for res in results]
        output = LLMResult(generations=generations, llm_output=llm_output)
        run_manager.on_llm_end(output)
        if run_manager:
            output.run = RunInfo(run_id=run_manager.run_id)
        return output

    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Top-level async call: generate a response for each list of messages."""
        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}

        callback_manager = AsyncCallbackManager.configure(
            callbacks,
            self.callbacks,
            self.verbose,
            tags,
            self.tags,
        )
        run_manager = await callback_manager.on_chat_model_start(
            dumpd(self), messages, invocation_params=params, options=options
        )

        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
        )
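        # All message lists are generated concurrently.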
        try:
            results = await asyncio.gather(
                *[
                    self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
                    if new_arg_supported
                    else self._agenerate(m, stop=stop)
                    for m in messages
                ]
            )
        except (KeyboardInterrupt, Exception) as e:
            await run_manager.on_llm_error(e)
            raise e
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
        generations = [res.generations for res in results]
        output = LLMResult(generations=generations, llm_output=llm_output)
        await run_manager.on_llm_end(output)
        if run_manager:
            output.run = RunInfo(run_id=run_manager.run_id)
        return output

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
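        """Convert each PromptValue to messages and delegate to generate()."""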
        prompt_messages = [p.to_messages() for p in prompts]
        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
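        """Async counterpart of generate_prompt()."""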
        prompt_messages = [p.to_messages() for p in prompts]
        return await self.agenerate(
            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
        )

    @abstractmethod
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Generate a chat result for a single list of messages."""

    @abstractmethod
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async version of _generate."""

    def __call__(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseMessage:
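        """Run the model on one conversation and return the reply message."""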
        generation = self.generate(
            [messages], stop=stop, callbacks=callbacks, **kwargs
        ).generations[0][0]
        if isinstance(generation, ChatGeneration):
            return generation.message
        else:
            raise ValueError("Unexpected generation type")

    async def _call_async(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseMessage:
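        """Async counterpart of __call__."""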
        result = await self.agenerate(
            [messages], stop=stop, callbacks=callbacks, **kwargs
        )
        generation = result.generations[0][0]
        if isinstance(generation, ChatGeneration):
            return generation.message
        else:
            raise ValueError("Unexpected generation type")

    def call_as_llm(
        self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> str:
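        """Expose the chat model through a plain text-in, text-out interface."""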
        return self.predict(message, stop=stop, **kwargs)

    def predict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
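        """Wrap the text in a HumanMessage and return the reply's content."""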
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
        return result.content

    def predict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        return self(messages, stop=_stop, **kwargs)

    async def apredict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        result = await self._call_async(
            [HumanMessage(content=text)], stop=_stop, **kwargs
        )
        return result.content

    async def apredict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        return await self._call_async(messages, stop=_stop, **kwargs)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {}

    @property
    @abstractmethod
    def _llm_type(self) -> str:
        """Return type of chat model."""

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary representation of the chat model."""
        starter_dict = dict(self._identifying_params)
        starter_dict["_type"] = self._llm_type
        return starter_dict


class SimpleChatModel(BaseChatModel):
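    """Chat model that only requires a single string-returning _call method."""
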
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @abstractmethod
    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Simpler interface: return the model's reply as a plain string."""

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
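        # Default async implementation: run the synchronous _generate in the
        # event loop's default thread pool executor.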
        func = partial(
            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return await asyncio.get_event_loop().run_in_executor(None, func)
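

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module): the minimum a
# concrete subclass of SimpleChatModel must supply is `_llm_type` and `_call`.
# `EchoChatModel` is a hypothetical name used purely for this example.
#
#     class EchoChatModel(SimpleChatModel):
#         """Toy model that echoes the last message back."""
#
#         @property
#         def _llm_type(self) -> str:
#             return "echo-chat"
#
#         def _call(
#             self,
#             messages: List[BaseMessage],
#             stop: Optional[List[str]] = None,
#             run_manager: Optional[CallbackManagerForLLMRun] = None,
#             **kwargs: Any,
#         ) -> str:
#             return messages[-1].content
#
#     chat = EchoChatModel()
#     reply = chat([HumanMessage(content="hello")])  # AIMessage(content="hello")
# ---------------------------------------------------------------------------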