# mirror of https://github.com/hwchase17/langchain
from __future__ import annotations

import asyncio
import inspect
import uuid
import warnings
from abc import ABC, abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Optional,
    Sequence,
    Type,
    Union,
    cast,
)

from langchain_core._api import deprecated
from langchain_core.caches import BaseCache
from langchain_core.callbacks import (
    AsyncCallbackManager,
    AsyncCallbackManagerForLLMRun,
    BaseCallbackManager,
    CallbackManager,
    CallbackManagerForLLMRun,
    Callbacks,
)
from langchain_core.globals import get_llm_cache
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.load import dumpd, dumps
from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    BaseMessage,
    BaseMessageChunk,
    HumanMessage,
    convert_to_messages,
    message_chunk_to_message,
)
from langchain_core.outputs import (
    ChatGeneration,
    ChatGenerationChunk,
    ChatResult,
    LLMResult,
    RunInfo,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables.config import ensure_config, run_in_executor
from langchain_core.tracers.log_stream import LogStreamCallbackHandler

if TYPE_CHECKING:
    from langchain_core.pydantic_v1 import BaseModel
    from langchain_core.runnables import Runnable, RunnableConfig
    from langchain_core.tools import BaseTool


def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
    """Generate from a stream."""

    generation: Optional[ChatGenerationChunk] = None
    for chunk in stream:
        if generation is None:
            generation = chunk
        else:
            generation += chunk
    assert generation is not None
    return ChatResult(
        generations=[
            ChatGeneration(
                message=message_chunk_to_message(generation.message),
                generation_info=generation.generation_info,
            )
        ]
    )


async def agenerate_from_stream(
    stream: AsyncIterator[ChatGenerationChunk],
) -> ChatResult:
    """Async generate from a stream."""

    generation: Optional[ChatGenerationChunk] = None
    async for chunk in stream:
        if generation is None:
            generation = chunk
        else:
            generation += chunk
    assert generation is not None
    return ChatResult(
        generations=[
            ChatGeneration(
                message=message_chunk_to_message(generation.message),
                generation_info=generation.generation_info,
            )
        ]
    )
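

# Usage sketch for the helpers above: they fold streamed chunks into a single
# ChatResult. The chunk values below are illustrative only, and `AIMessageChunk`
# lives in `langchain_core.messages` (not imported by this module), so the snippet
# assumes you import it yourself:
#
#     from langchain_core.messages import AIMessageChunk
#
#     chunks = iter(
#         [
#             ChatGenerationChunk(message=AIMessageChunk(content="Hello")),
#             ChatGenerationChunk(message=AIMessageChunk(content=" world")),
#         ]
#     )
#     result = generate_from_stream(chunks)
#     # result.generations[0].message.content == "Hello world"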


class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
    """Base class for Chat models."""

    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """[DEPRECATED] Callback manager to add to the run trace."""

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    # --- Runnable methods ---

    @property
    def OutputType(self) -> Any:
        """Get the output type for this runnable."""
        return AnyMessage

    def _convert_input(self, input: LanguageModelInput) -> PromptValue:
        if isinstance(input, PromptValue):
            return input
        elif isinstance(input, str):
            return StringPromptValue(text=input)
        elif isinstance(input, Sequence):
            return ChatPromptValue(messages=convert_to_messages(input))
        else:
            raise ValueError(
                f"Invalid input type {type(input)}. "
                "Must be a PromptValue, str, or list of BaseMessages."
            )
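
    # `_convert_input` is what lets `invoke`/`stream` accept several input shapes.
    # An illustrative sketch of the three accepted forms, assuming `model` is some
    # concrete BaseChatModel subclass (not defined in this module):
    #
    #     model.invoke("What is 2 + 2?")                           # plain string
    #     model.invoke([HumanMessage(content="What is 2 + 2?")])   # list of messages
    #     model.invoke(StringPromptValue(text="What is 2 + 2?"))   # PromptValue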

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        config = ensure_config(config)
        return cast(
            ChatGeneration,
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            ).generations[0][0],
        ).message

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        config = ensure_config(config)
        llm_result = await self.agenerate_prompt(
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            run_id=config.pop("run_id", None),
            **kwargs,
        )
        return cast(ChatGeneration, llm_result.generations[0][0]).message

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[BaseMessageChunk]:
        if type(self)._stream == BaseChatModel._stream:
            # model doesn't implement streaming, so use default implementation
            yield cast(
                BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
            )
        else:
            config = ensure_config(config)
            messages = self._convert_input(input).to_messages()
            params = self._get_invocation_params(stop=stop, **kwargs)
            options = {"stop": stop, **kwargs}
            callback_manager = CallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = callback_manager.on_chat_model_start(
                dumpd(self),
                [messages],
                invocation_params=params,
                options=options,
                name=config.get("run_name"),
                run_id=config.pop("run_id", None),
            )
            generation: Optional[ChatGenerationChunk] = None
            try:
                for chunk in self._stream(messages, stop=stop, **kwargs):
                    if chunk.message.id is None:
                        chunk.message.id = f"run-{run_manager.run_id}"
                    chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                    run_manager.on_llm_new_token(
                        cast(str, chunk.message.content), chunk=chunk
                    )
                    yield chunk.message
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
            except BaseException as e:
                run_manager.on_llm_error(
                    e,
                    response=LLMResult(
                        generations=[[generation]] if generation else []
                    ),
                )
                raise e
            else:
                run_manager.on_llm_end(LLMResult(generations=[[generation]]))
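
    # A minimal usage sketch for `stream` (the `model` object stands in for any
    # concrete subclass; it is not defined in this module):
    #
    #     for chunk in model.stream("Tell me a joke"):
    #         print(chunk.content, end="", flush=True)
    #
    # Subclasses that do not override `_stream` still work here: the full response
    # from `invoke` is yielded as a single chunk.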

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[BaseMessageChunk]:
        if (
            type(self)._astream is BaseChatModel._astream
            and type(self)._stream is BaseChatModel._stream
        ):
            # No async or sync stream is implemented, so fall back to ainvoke
            yield cast(
                BaseMessageChunk,
                await self.ainvoke(input, config=config, stop=stop, **kwargs),
            )
            return

        config = ensure_config(config)
        messages = self._convert_input(input).to_messages()
        params = self._get_invocation_params(stop=stop, **kwargs)
        options = {"stop": stop, **kwargs}
        callback_manager = AsyncCallbackManager.configure(
            config.get("callbacks"),
            self.callbacks,
            self.verbose,
            config.get("tags"),
            self.tags,
            config.get("metadata"),
            self.metadata,
        )
        (run_manager,) = await callback_manager.on_chat_model_start(
            dumpd(self),
            [messages],
            invocation_params=params,
            options=options,
            name=config.get("run_name"),
            run_id=config.pop("run_id", None),
        )

        generation: Optional[ChatGenerationChunk] = None
        try:
            async for chunk in self._astream(
                messages,
                stop=stop,
                **kwargs,
            ):
                if chunk.message.id is None:
                    chunk.message.id = f"run-{run_manager.run_id}"
                chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                await run_manager.on_llm_new_token(
                    cast(str, chunk.message.content), chunk=chunk
                )
                yield chunk.message
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
        except BaseException as e:
            await run_manager.on_llm_error(
                e,
                response=LLMResult(generations=[[generation]] if generation else []),
            )
            raise e
        else:
            await run_manager.on_llm_end(
                LLMResult(generations=[[generation]]),
            )

    # --- Custom methods ---

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        return {}

    def _get_invocation_params(
        self,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> dict:
        params = self.dict()
        params["stop"] = stop
        return {**params, **kwargs}

    def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        if self.is_lc_serializable():
            params = {**kwargs, **{"stop": stop}}
            param_string = str(sorted([(k, v) for k, v in params.items()]))
            llm_string = dumps(self)
            return llm_string + "---" + param_string
        else:
            params = self._get_invocation_params(stop=stop, **kwargs)
            params = {**params, **kwargs}
            return str(sorted([(k, v) for k, v in params.items()]))

    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Pass a sequence of prompts to the model and return model generations.

        This method should make use of batched calls for models that expose a batched
        API.

        Use this method when you:
            1. want to take advantage of batched calls,
            2. need more output from the model than just the top generated value,
            3. are building chains that are agnostic to the underlying language model
                type (e.g., pure text completion models vs chat models).

        Args:
            messages: List of list of messages.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            callbacks: Callbacks to pass through. Used for executing additional
                functionality, such as logging or streaming, throughout generation.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            An LLMResult, which contains a list of candidate Generations for each input
                prompt and additional model provider-specific output.
        """
        params = self._get_invocation_params(stop=stop, **kwargs)
        options = {"stop": stop}

        callback_manager = CallbackManager.configure(
            callbacks,
            self.callbacks,
            self.verbose,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )
        run_managers = callback_manager.on_chat_model_start(
            dumpd(self),
            messages,
            invocation_params=params,
            options=options,
            name=run_name,
            run_id=run_id,
        )
        results = []
        for i, m in enumerate(messages):
            try:
                results.append(
                    self._generate_with_cache(
                        m,
                        stop=stop,
                        run_manager=run_managers[i] if run_managers else None,
                        **kwargs,
                    )
                )
            except BaseException as e:
                if run_managers:
                    run_managers[i].on_llm_error(e, response=LLMResult(generations=[]))
                raise e
        flattened_outputs = [
            LLMResult(generations=[res.generations], llm_output=res.llm_output)  # type: ignore[list-item]
            for res in results
        ]
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])
        generations = [res.generations for res in results]
        output = LLMResult(generations=generations, llm_output=llm_output)  # type: ignore[arg-type]
        if run_managers:
            run_infos = []
            for manager, flattened_output in zip(run_managers, flattened_outputs):
                manager.on_llm_end(flattened_output)
                run_infos.append(RunInfo(run_id=manager.run_id))
            output.run = run_infos
        return output
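
    # Batch-usage sketch for `generate` (illustrative only; `model` stands in for
    # any concrete subclass and is not defined here):
    #
    #     result = model.generate(
    #         [
    #             [HumanMessage(content="Translate 'hello' to French")],
    #             [HumanMessage(content="Translate 'hello' to German")],
    #         ]
    #     )
    #     # result.generations[0][0].message is the reply to the first prompt,
    #     # result.generations[1][0].message the reply to the second.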

    async def agenerate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        run_name: Optional[str] = None,
        run_id: Optional[uuid.UUID] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Asynchronously pass a sequence of prompts to a model and return generations.

        This method should make use of batched calls for models that expose a batched
        API.

        Use this method when you:
            1. want to take advantage of batched calls,
            2. need more output from the model than just the top generated value,
            3. are building chains that are agnostic to the underlying language model
                type (e.g., pure text completion models vs chat models).

        Args:
            messages: List of list of messages.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of these substrings.
            callbacks: Callbacks to pass through. Used for executing additional
                functionality, such as logging or streaming, throughout generation.
            **kwargs: Arbitrary additional keyword arguments. These are usually passed
                to the model provider API call.

        Returns:
            An LLMResult, which contains a list of candidate Generations for each input
                prompt and additional model provider-specific output.
        """
        params = self._get_invocation_params(stop=stop, **kwargs)
        options = {"stop": stop}

        callback_manager = AsyncCallbackManager.configure(
            callbacks,
            self.callbacks,
            self.verbose,
            tags,
            self.tags,
            metadata,
            self.metadata,
        )

        run_managers = await callback_manager.on_chat_model_start(
            dumpd(self),
            messages,
            invocation_params=params,
            options=options,
            name=run_name,
            run_id=run_id,
        )

        results = await asyncio.gather(
            *[
                self._agenerate_with_cache(
                    m,
                    stop=stop,
                    run_manager=run_managers[i] if run_managers else None,
                    **kwargs,
                )
                for i, m in enumerate(messages)
            ],
            return_exceptions=True,
        )
        exceptions = []
        for i, res in enumerate(results):
            if isinstance(res, BaseException):
                if run_managers:
                    await run_managers[i].on_llm_error(
                        res, response=LLMResult(generations=[])
                    )
                exceptions.append(res)
        if exceptions:
            if run_managers:
                await asyncio.gather(
                    *[
                        run_manager.on_llm_end(
                            LLMResult(
                                generations=[res.generations],  # type: ignore[list-item, union-attr]
                                llm_output=res.llm_output,  # type: ignore[list-item, union-attr]
                            )
                        )
                        for run_manager, res in zip(run_managers, results)
                        if not isinstance(res, Exception)
                    ]
                )
            raise exceptions[0]
        flattened_outputs = [
            LLMResult(generations=[res.generations], llm_output=res.llm_output)  # type: ignore[list-item, union-attr]
            for res in results
        ]
        llm_output = self._combine_llm_outputs([res.llm_output for res in results])  # type: ignore[union-attr]
        generations = [res.generations for res in results]  # type: ignore[union-attr]
        output = LLMResult(generations=generations, llm_output=llm_output)  # type: ignore[arg-type]
        await asyncio.gather(
            *[
                run_manager.on_llm_end(flattened_output)
                for run_manager, flattened_output in zip(
                    run_managers, flattened_outputs
                )
            ]
        )
        if run_managers:
            output.run = [
                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
            ]
        return output

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_messages = [p.to_messages() for p in prompts]
        return await self.agenerate(
            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
        )

    def _generate_with_cache(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if isinstance(self.cache, BaseCache):
            llm_cache = self.cache
        else:
            llm_cache = get_llm_cache()
        # We should check the cache unless it's explicitly set to False
        # A None cache means we should use the default global cache
        # if it's configured.
        check_cache = self.cache or self.cache is None
        if check_cache:
            if llm_cache:
                llm_string = self._get_llm_string(stop=stop, **kwargs)
                prompt = dumps(messages)
                cache_val = llm_cache.lookup(prompt, llm_string)
                if isinstance(cache_val, list):
                    return ChatResult(generations=cache_val)
            elif self.cache is None:
                pass
            else:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
        # If stream is not explicitly set, check if implicitly requested by
        # astream_events() or astream_log(). Bail out if _stream not implemented
        if type(self)._stream != BaseChatModel._stream and kwargs.pop(
            "stream",
            (
                next(
                    (
                        True
                        for h in run_manager.handlers
                        if isinstance(h, LogStreamCallbackHandler)
                    ),
                    False,
                )
                if run_manager
                else False
            ),
        ):
            chunks: List[ChatGenerationChunk] = []
            for chunk in self._stream(messages, stop=stop, **kwargs):
                chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                if run_manager:
                    if chunk.message.id is None:
                        chunk.message.id = f"run-{run_manager.run_id}"
                    run_manager.on_llm_new_token(
                        cast(str, chunk.message.content), chunk=chunk
                    )
                chunks.append(chunk)
            result = generate_from_stream(iter(chunks))
        else:
            if inspect.signature(self._generate).parameters.get("run_manager"):
                result = self._generate(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                result = self._generate(messages, stop=stop, **kwargs)

        # Add response metadata to each generation
        for idx, generation in enumerate(result.generations):
            if run_manager and generation.message.id is None:
                generation.message.id = f"run-{run_manager.run_id}-{idx}"
            generation.message.response_metadata = _gen_info_and_msg_metadata(
                generation
            )
        if len(result.generations) == 1 and result.llm_output is not None:
            result.generations[0].message.response_metadata = {
                **result.llm_output,
                **result.generations[0].message.response_metadata,
            }
        if check_cache and llm_cache:
            llm_cache.update(prompt, llm_string, result.generations)
        return result
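
    # Caching sketch: `_generate_with_cache` consults a per-instance `cache` first
    # and otherwise falls back to the global LLM cache. Assuming `InMemoryCache`
    # from `langchain_core.caches`, `set_llm_cache` from `langchain_core.globals`,
    # and a concrete `model`, enabling the global cache looks roughly like:
    #
    #     from langchain_core.caches import InMemoryCache
    #     from langchain_core.globals import set_llm_cache
    #
    #     set_llm_cache(InMemoryCache())
    #     model.invoke("What is 2 + 2?")  # hits the provider
    #     model.invoke("What is 2 + 2?")  # served from the cache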

    async def _agenerate_with_cache(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if isinstance(self.cache, BaseCache):
            llm_cache = self.cache
        else:
            llm_cache = get_llm_cache()
        # We should check the cache unless it's explicitly set to False
        # A None cache means we should use the default global cache
        # if it's configured.
        check_cache = self.cache or self.cache is None
        if check_cache:
            if llm_cache:
                llm_string = self._get_llm_string(stop=stop, **kwargs)
                prompt = dumps(messages)
                cache_val = await llm_cache.alookup(prompt, llm_string)
                if isinstance(cache_val, list):
                    return ChatResult(generations=cache_val)
            elif self.cache is None:
                pass
            else:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
        # If stream is not explicitly set, check if implicitly requested by
        # astream_events() or astream_log(). Bail out if _astream not implemented
        if (
            type(self)._astream != BaseChatModel._astream
            or type(self)._stream != BaseChatModel._stream
        ) and kwargs.pop(
            "stream",
            (
                next(
                    (
                        True
                        for h in run_manager.handlers
                        if isinstance(h, LogStreamCallbackHandler)
                    ),
                    False,
                )
                if run_manager
                else False
            ),
        ):
            chunks: List[ChatGenerationChunk] = []
            async for chunk in self._astream(messages, stop=stop, **kwargs):
                chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                if run_manager:
                    if chunk.message.id is None:
                        chunk.message.id = f"run-{run_manager.run_id}"
                    await run_manager.on_llm_new_token(
                        cast(str, chunk.message.content), chunk=chunk
                    )
                chunks.append(chunk)
            result = generate_from_stream(iter(chunks))
        else:
            if inspect.signature(self._agenerate).parameters.get("run_manager"):
                result = await self._agenerate(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
            else:
                result = await self._agenerate(messages, stop=stop, **kwargs)

        # Add response metadata to each generation
        for idx, generation in enumerate(result.generations):
            if run_manager and generation.message.id is None:
                generation.message.id = f"run-{run_manager.run_id}-{idx}"
            generation.message.response_metadata = _gen_info_and_msg_metadata(
                generation
            )
        if len(result.generations) == 1 and result.llm_output is not None:
            result.generations[0].message.response_metadata = {
                **result.llm_output,
                **result.generations[0].message.response_metadata,
            }
        if check_cache and llm_cache:
            await llm_cache.aupdate(prompt, llm_string, result.generations)
        return result

    @abstractmethod
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Top Level call"""
        return await run_in_executor(
            None,
            self._generate,
            messages,
            stop,
            run_manager.get_sync() if run_manager else None,
            **kwargs,
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        raise NotImplementedError()

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        iterator = await run_in_executor(
            None,
            self._stream,
            messages,
            stop,
            run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
        done = object()
        while True:
            item = await run_in_executor(
                None,
                next,
                iterator,
                done,  # type: ignore[call-arg, arg-type]
            )
            if item is done:
                break
            yield item  # type: ignore[misc]

    @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
    def __call__(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseMessage:
        generation = self.generate(
            [messages], stop=stop, callbacks=callbacks, **kwargs
        ).generations[0][0]
        if isinstance(generation, ChatGeneration):
            return generation.message
        else:
            raise ValueError("Unexpected generation type")

    async def _call_async(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> BaseMessage:
        result = await self.agenerate(
            [messages], stop=stop, callbacks=callbacks, **kwargs
        )
        generation = result.generations[0][0]
        if isinstance(generation, ChatGeneration):
            return generation.message
        else:
            raise ValueError("Unexpected generation type")

    @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
    def call_as_llm(
        self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> str:
        return self.predict(message, stop=stop, **kwargs)

    @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
    def predict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
        if isinstance(result.content, str):
            return result.content
        else:
            raise ValueError("Cannot use predict when output is not a string.")

    @deprecated("0.1.7", alternative="invoke", removal="0.2.0")
    def predict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        return self(messages, stop=_stop, **kwargs)

    @deprecated("0.1.7", alternative="ainvoke", removal="0.2.0")
    async def apredict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        result = await self._call_async(
            [HumanMessage(content=text)], stop=_stop, **kwargs
        )
        if isinstance(result.content, str):
            return result.content
        else:
            raise ValueError("Cannot use predict when output is not a string.")

    @deprecated("0.1.7", alternative="ainvoke", removal="0.2.0")
    async def apredict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        return await self._call_async(messages, stop=_stop, **kwargs)

    @property
    @abstractmethod
    def _llm_type(self) -> str:
        """Return type of chat model."""

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary of the LLM."""
        starter_dict = dict(self._identifying_params)
        starter_dict["_type"] = self._llm_type
        return starter_dict

    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        raise NotImplementedError()


class SimpleChatModel(BaseChatModel):
    """A simplified implementation for a chat model to inherit from."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    @abstractmethod
    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Simpler interface."""

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        return await run_in_executor(
            None,
            self._generate,
            messages,
            stop=stop,
            run_manager=run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
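

# Subclassing sketch for SimpleChatModel. The class name `ParrotChatModel` and its
# behaviour are illustrative only; a real implementation would call a provider API
# inside `_call`:
#
#     class ParrotChatModel(SimpleChatModel):
#         def _call(self, messages, stop=None, run_manager=None, **kwargs):
#             # Echo the last message back, which is enough to exercise the base
#             # class plumbing (invoke, generate, callbacks, caching).
#             return str(messages[-1].content)
#
#         @property
#         def _llm_type(self) -> str:
#             return "parrot-chat-model"
#
#     ParrotChatModel().invoke("hello")  # -> AIMessage(content="hello")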


def _gen_info_and_msg_metadata(
    generation: Union[ChatGeneration, ChatGenerationChunk],
) -> dict:
    return {
        **(generation.generation_info or {}),
        **generation.message.response_metadata,
    }