Runnable single protocol (#7800)

Objects implementing Runnable: BasePromptTemplate, LLM, ChatModel,
Chain, Retriever, OutputParser

- [x] Implement Runnable in base Retriever
- [x] Raise TypeError in operator methods for unsupported operand types
- [x] Implement dict support: a dict of Runnables invokes its values in parallel and
outputs a dict of their results (see the first sketch after this list)
- [x] Merge in `+` support for prompts (see the prompt sketch after this list)
- [x] Confirm operator precedence: ideally `+` binds tighter than `|`, which it does per
https://docs.python.org/3/reference/expressions.html#operator-precedence
- [x] Add support for OpenAI functions, i.e. chat models must return messages
- [x] Implement a BaseMessageChunk return type for BaseChatModel: a subclass of
BaseMessage whose `__add__` returns a BaseMessageChunk, concatenating all string
fields (see the streaming sketch after this list)
- [x] Update the stream/astream implementations for LLMs and chat models to use the
new optional `_stream`/`_astream` methods, whose default base-class implementations
raise NotImplementedError; use https://stackoverflow.com/a/59762827 to check whether
a subclass has overridden them
- [x] Delete the IteratorCallbackHandler (keep the async one because people are
using it)
- [x] Make BaseLLMOutputParser implement Runnable, accepting either str
or BaseMessage
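
The dict item and the TypeError item above boil down to a small coercion layer around the protocol. Below is a minimal, self-contained sketch of the idea in plain Python; the names (`Runnable`, `RunnableMap`, `RunnableSequence`, `coerce_to_runnable`) are illustrative stand-ins, not the exact LangChain implementation.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Union


class Runnable:
    """Toy runnable: a single invoke() plus `|` composition."""

    def invoke(self, input: Any) -> Any:
        raise NotImplementedError()

    def __or__(self, other: "RunnableLike") -> "Runnable":
        return RunnableSequence([self, coerce_to_runnable(other)])


class RunnableLambda(Runnable):
    """Wrap a plain function as a Runnable."""

    def __init__(self, func: Callable[[Any], Any]) -> None:
        self.func = func

    def invoke(self, input: Any) -> Any:
        return self.func(input)


class RunnableMap(Runnable):
    """A dict of runnables: invoke every value in parallel, return a dict of results."""

    def __init__(self, steps: Dict[str, Runnable]) -> None:
        self.steps = steps

    def invoke(self, input: Any) -> Dict[str, Any]:
        with ThreadPoolExecutor() as pool:
            futures = {key: pool.submit(step.invoke, input) for key, step in self.steps.items()}
            return {key: future.result() for key, future in futures.items()}


class RunnableSequence(Runnable):
    """Pipe the output of each step into the next."""

    def __init__(self, steps: List[Runnable]) -> None:
        self.steps = steps

    def invoke(self, input: Any) -> Any:
        for step in self.steps:
            input = step.invoke(input)
        return input


RunnableLike = Union[Runnable, Callable[[Any], Any], Dict[str, Any]]


def coerce_to_runnable(thing: RunnableLike) -> Runnable:
    if isinstance(thing, Runnable):
        return thing
    if isinstance(thing, dict):
        return RunnableMap({key: coerce_to_runnable(value) for key, value in thing.items()})
    if callable(thing):
        return RunnableLambda(thing)
    # Unsupported operand types raise TypeError, per the checklist above.
    raise TypeError(f"Cannot compose a Runnable with {type(thing).__name__}")


# A dict on the right-hand side of `|` is coerced to a parallel map.
chain = RunnableLambda(str.strip) | {"upper": RunnableLambda(str.upper), "length": len}
print(chain.invoke("  hello  "))  # {'upper': 'HELLO', 'length': 5}
```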
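
The `+` merge and the precedence question fit together: Python's `+` binds tighter than `|`, so a merged prompt pipes into a model without extra parentheses. A toy sketch under that assumption; `Prompt` and `fake_model` are hypothetical stand-ins for the real prompt template and model classes.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Prompt:
    """Hypothetical stand-in for a prompt template."""

    template: str

    def __add__(self, other: Any) -> "Prompt":
        if isinstance(other, Prompt):
            return Prompt(template=self.template + other.template)
        if isinstance(other, str):
            return Prompt(template=self.template + other)
        # Unsupported operand types raise TypeError, per the checklist above.
        raise TypeError(f"Cannot add Prompt and {type(other).__name__}")

    def __or__(self, model: Callable[[str], str]) -> Callable[[dict], str]:
        # Piping a prompt into a model yields a callable "chain".
        return lambda values: model(self.template.format(**values))


def fake_model(text: str) -> str:
    return text.upper()


# `+` binds tighter than `|`, so this parses as (prompt + prompt) | model.
chain = Prompt("Hello {name}. ") + Prompt("How are you?") | fake_model
print(chain({"name": "Ada"}))  # HELLO ADA. HOW ARE YOU?
```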
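
The two streaming items are easiest to see together: message chunks that concatenate with `+`, and the base-class comparison from the linked Stack Overflow answer that lets `stream` fall back to a single `invoke` when a subclass never overrode the optional `_stream` hook (the same `type(self)._stream == BaseChatModel._stream` check appears in the diff below). The class names in this sketch are hypothetical.

```python
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class MessageChunk:
    """Hypothetical stand-in for BaseMessageChunk."""

    content: str

    def __add__(self, other: "MessageChunk") -> "MessageChunk":
        # Concatenate the string fields; the real class would also merge
        # additional kwargs such as streamed function_call arguments.
        return MessageChunk(content=self.content + other.content)


class ChatModel:
    def invoke(self, messages: List[str]) -> MessageChunk:
        return MessageChunk(content="(full, non-streamed response)")

    def _stream(self, messages: List[str]) -> Iterator[MessageChunk]:
        raise NotImplementedError()  # optional hook; default is "not implemented"

    def stream(self, messages: List[str]) -> Iterator[MessageChunk]:
        if type(self)._stream == ChatModel._stream:
            # Subclass never overrode _stream: fall back to one invoke() result.
            yield self.invoke(messages)
            return
        final: Optional[MessageChunk] = None
        for chunk in self._stream(messages):
            yield chunk
            final = chunk if final is None else final + chunk
        # `final` now holds the concatenated message, e.g. for end-of-run callbacks.


class WordStreamingModel(ChatModel):
    def _stream(self, messages: List[str]) -> Iterator[MessageChunk]:
        for word in messages[-1].split():
            yield MessageChunk(content=word + " ")


print([c.content for c in ChatModel().stream(["hi"])])                  # falls back to invoke
print([c.content for c in WordStreamingModel().stream(["hi there"])])   # ['hi ', 'there ']
```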
---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
Nuno Campos, commit a612800ef0 (parent 04a4d3e312), branch pull/8275/head^2, committed via GitHub

@ -9,10 +9,7 @@ from typing import (
Tuple,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.schema import (
ChatGeneration,
@ -116,15 +113,6 @@ class ChatLlamaAPI(BaseChatModel):
generations.append(gen)
return ChatResult(generations=generations)
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
raise NotImplementedError
@property
def _client_params(self) -> Mapping[str, Any]:
"""Get the parameters used for the client."""

@ -1,13 +1,14 @@
"""Base callback handler that can be used to handle callbacks in langchain."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from uuid import UUID
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import Document
from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult
if TYPE_CHECKING:
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import Document
from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult
class RetrieverManagerMixin:
@ -543,3 +544,6 @@ class BaseCallbackManager(CallbackManagerMixin):
for key in keys:
self.metadata.pop(key)
self.inheritable_metadata.pop(key)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]

@ -4,6 +4,7 @@ import asyncio
import functools
import logging
import os
import uuid
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from typing import (
@ -20,12 +21,13 @@ from typing import (
Union,
cast,
)
from uuid import UUID, uuid4
from uuid import UUID
import langchain
from langchain.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
Callbacks,
ChainManagerMixin,
LLMManagerMixin,
RetrieverManagerMixin,
@ -50,7 +52,6 @@ if TYPE_CHECKING:
from langsmith import Client as LangSmithClient
logger = logging.getLogger(__name__)
Callbacks = Optional[Union[List[BaseCallbackHandler], BaseCallbackManager]]
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
@ -437,7 +438,7 @@ class BaseRunManager(RunManagerMixin):
BaseRunManager: The noop manager.
"""
return cls(
run_id=uuid4(),
run_id=uuid.uuid4(),
handlers=[],
inheritable_handlers=[],
tags=[],
@ -1024,7 +1025,7 @@ class CallbackManager(BaseCallbackManager):
"""
managers = []
for prompt in prompts:
run_id_ = uuid4()
run_id_ = uuid.uuid4()
_handle_event(
self.handlers,
"on_llm_start",
@ -1073,7 +1074,7 @@ class CallbackManager(BaseCallbackManager):
managers = []
for message_list in messages:
run_id_ = uuid4()
run_id_ = uuid.uuid4()
_handle_event(
self.handlers,
"on_chat_model_start",
@ -1120,7 +1121,7 @@ class CallbackManager(BaseCallbackManager):
CallbackManagerForChainRun: The callback manager for the chain run.
"""
if run_id is None:
run_id = uuid4()
run_id = uuid.uuid4()
_handle_event(
self.handlers,
@ -1166,7 +1167,7 @@ class CallbackManager(BaseCallbackManager):
CallbackManagerForToolRun: The callback manager for the tool run.
"""
if run_id is None:
run_id = uuid4()
run_id = uuid.uuid4()
_handle_event(
self.handlers,
@ -1202,7 +1203,7 @@ class CallbackManager(BaseCallbackManager):
) -> CallbackManagerForRetrieverRun:
"""Run when retriever starts running."""
if run_id is None:
run_id = uuid4()
run_id = uuid.uuid4()
_handle_event(
self.handlers,
@ -1302,7 +1303,7 @@ class AsyncCallbackManager(BaseCallbackManager):
managers = []
for prompt in prompts:
run_id_ = uuid4()
run_id_ = uuid.uuid4()
tasks.append(
_ahandle_event(
@ -1341,7 +1342,7 @@ class AsyncCallbackManager(BaseCallbackManager):
serialized: Dict[str, Any],
messages: List[List[BaseMessage]],
**kwargs: Any,
) -> Any:
) -> List[AsyncCallbackManagerForLLMRun]:
"""Run when LLM starts running.
Args:
@ -1358,7 +1359,7 @@ class AsyncCallbackManager(BaseCallbackManager):
managers = []
for message_list in messages:
run_id_ = uuid4()
run_id_ = uuid.uuid4()
tasks.append(
_ahandle_event(
@ -1410,7 +1411,7 @@ class AsyncCallbackManager(BaseCallbackManager):
for the chain run.
"""
if run_id is None:
run_id = uuid4()
run_id = uuid.uuid4()
await _ahandle_event(
self.handlers,
@ -1458,7 +1459,7 @@ class AsyncCallbackManager(BaseCallbackManager):
for the tool run.
"""
if run_id is None:
run_id = uuid4()
run_id = uuid.uuid4()
await _ahandle_event(
self.handlers,
@ -1494,7 +1495,7 @@ class AsyncCallbackManager(BaseCallbackManager):
) -> AsyncCallbackManagerForRetrieverRun:
"""Run when retriever starts running."""
if run_id is None:
run_id = uuid4()
run_id = uuid.uuid4()
await _ahandle_event(
self.handlers,

@ -4,7 +4,7 @@ import asyncio
from typing import Any, AsyncIterator, Dict, List, Literal, Union, cast
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema import LLMResult
from langchain.schema.output import LLMResult
# TODO If used by two LLM runs in parallel this won't work as expected

@ -22,6 +22,7 @@ from langchain.callbacks.manager import (
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
from langchain.schema.runnable import Runnable, RunnableConfig
logger = logging.getLogger(__name__)
@ -30,7 +31,7 @@ def _get_verbosity() -> bool:
return langchain.verbose
class Chain(Serializable, ABC):
class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
"""Abstract base class for creating structured sequences of calls to components.
Chains should be used to encode a sequence of calls to components like
@ -53,6 +54,20 @@ class Chain(Serializable, ABC):
chains and cannot return as rich of an output as `__call__`.
"""
def invoke(
self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
return self(input, **(config or {}))
async def ainvoke(
self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
if type(self)._acall == Chain._acall:
# If the chain does not implement async, fall back to default implementation
return await super().ainvoke(input, config)
return await self.acall(input, **(config or {}))
memory: Optional[BaseMemory] = None
"""Optional memory object. Defaults to None.
Memory is a class that gets called at the start

@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@ -12,11 +12,13 @@ from langchain.schema import (
)
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
class ChatAnthropic(BaseChatModel, _AnthropicCommon):
@ -94,29 +96,64 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
text.rstrip()
) # trim off the trailing ' ' that might come from the "Assistant: "
def _generate(
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
) -> Iterator[ChatGenerationChunk]:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
stream_resp = self.client.completions.create(**params, stream=True)
for data in stream_resp:
delta = data.completion
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
stream_resp = await self.async_client.completions.create(**params, stream=True)
async for data in stream_resp:
delta = data.completion
yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
await run_manager.on_llm_new_token(delta)
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
completion = ""
stream_resp = self.client.completions.create(**params, stream=True)
for data in stream_resp:
delta = data.completion
completion += delta
if run_manager:
run_manager.on_llm_new_token(
delta,
)
for chunk in self._stream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = self.client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
@ -129,24 +166,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
if self.streaming:
completion = ""
stream_resp = await self.async_client.completions.create(
**params, stream=True
)
async for data in stream_resp:
delta = data.completion
completion += delta
if run_manager:
await run_manager.on_llm_new_token(
delta,
)
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
completion += chunk.text
else:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {
"prompt": prompt,
**self._default_params,
**kwargs,
}
if stop:
params["stop_sequences"] = stop
response = await self.async_client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)

@ -3,7 +3,16 @@ import inspect
import warnings
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Sequence
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
cast,
)
from pydantic import Field, root_validator
@ -17,6 +26,8 @@ from langchain.callbacks.manager import (
Callbacks,
)
from langchain.load.dump import dumpd, dumps
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
ChatGeneration,
ChatResult,
@ -24,17 +35,22 @@ from langchain.schema import (
PromptValue,
RunInfo,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import AIMessage, BaseMessage, HumanMessage
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
def _get_verbosity() -> bool:
return langchain.verbose
class BaseChatModel(BaseLanguageModel, ABC):
"""Base class for chat models."""
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)
@ -64,6 +80,154 @@ class BaseChatModel(BaseLanguageModel, ABC):
arbitrary_types_allowed = True
# --- Runnable methods ---
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, list):
return ChatPromptValue(messages=input)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> BaseMessageChunk:
return cast(
BaseMessageChunk,
cast(
ChatGeneration,
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
).generations[0][0],
).message,
)
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> BaseMessageChunk:
if type(self)._agenerate == BaseChatModel._agenerate:
# model doesn't implement async generation, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop)
)
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
)
return cast(
BaseMessageChunk, cast(ChatGeneration, llm_result.generations[0][0]).message
)
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
dumpd(self), [messages], invocation_params=params, options=options
)
try:
message: Optional[BaseMessageChunk] = None
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if message is None:
message = chunk.message
else:
message += chunk.message
assert message is not None
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
else:
run_manager.on_llm_end(
LLMResult(generations=[[ChatGeneration(message=message)]]),
)
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
if type(self)._astream == BaseChatModel._astream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop, **kwargs)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = await callback_manager.on_chat_model_start(
dumpd(self), [messages], invocation_params=params, options=options
)
try:
message: Optional[BaseMessageChunk] = None
async for chunk in self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if message is None:
message = chunk.message
else:
message += chunk.message
assert message is not None
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
else:
await run_manager.on_llm_end(
LLMResult(generations=[[ChatGeneration(message=message)]]),
)
# --- Custom methods ---
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
return {}
@ -334,7 +498,6 @@ class BaseChatModel(BaseLanguageModel, ABC):
) -> ChatResult:
"""Top Level call"""
@abstractmethod
async def _agenerate(
self,
messages: List[BaseMessage],
@ -343,6 +506,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
raise NotImplementedError()
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
raise NotImplementedError()
def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
raise NotImplementedError()
def __call__(
self,

@ -25,7 +25,10 @@ class FakeListChatModel(SimpleChatModel):
) -> str:
"""First try to lookup in queries, else return 'foo' or 'bar'."""
response = self.responses[self.i]
self.i += 1
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
return response
@property

@ -4,8 +4,10 @@ from __future__ import annotations
import logging
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
@ -36,6 +38,14 @@ from langchain.schema import (
HumanMessage,
SystemMessage,
)
from langchain.schema.messages import (
AIMessageChunk,
BaseMessageChunk,
ChatMessageChunk,
HumanMessageChunk,
SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__)
@ -75,6 +85,24 @@ async def acompletion_with_retry(llm: JinaChat, **kwargs: Any) -> Any:
return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
@ -258,6 +286,25 @@ class JinaChat(BaseChatModel):
overall_token_usage[k] = v
return {"token_usage": overall_token_usage}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(messages=message_dicts, **params):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
def _generate(
self,
messages: List[BaseMessage],
@ -265,27 +312,20 @@ class JinaChat(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content") or ""
inner_completion += token
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
@ -309,6 +349,27 @@ class JinaChat(BaseChatModel):
llm_output = {"token_usage": response["usage"]}
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, **params
):
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
async def _agenerate(
self,
messages: List[BaseMessage],
@ -316,32 +377,22 @@ class JinaChat(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
async for stream_resp in await acompletion_with_retry(
self, messages=message_dicts, **params
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token or ""
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
else:
response = await acompletion_with_retry(
self, messages=message_dicts, **params
)
return self._create_chat_result(response)
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(self, messages=message_dicts, **params)
return self._create_chat_result(response)
@property
def _invocation_params(self) -> Mapping[str, Any]:

@ -6,8 +6,10 @@ import sys
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
@ -35,12 +37,19 @@ from langchain.schema import (
)
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessage,
ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
SystemMessageChunk,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
if TYPE_CHECKING:
@ -95,6 +104,30 @@ async def acompletion_with_retry(llm: ChatOpenAI, **kwargs: Any) -> Any:
return await _completion_with_retry(**kwargs)
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _dict.get("role")
content = _dict.get("content") or ""
if _dict.get("function_call"):
additional_kwargs = {"function_call": dict(_dict["function_call"])}
else:
additional_kwargs = {}
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
@ -313,6 +346,27 @@ class ChatOpenAI(BaseChatModel):
overall_token_usage[k] = v
return {"token_usage": overall_token_usage, "model_name": self.model_name}
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
for chunk in self.completion_with_retry(messages=message_dicts, **params):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
run_manager.on_llm_new_token(chunk.content)
def _generate(
self,
messages: List[BaseMessage],
@ -320,40 +374,20 @@ class ChatOpenAI(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
function_call: Optional[dict] = None
for stream_resp in self.completion_with_retry(
messages=message_dicts, **params
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if len(stream_resp["choices"]) > 0:
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content") or ""
inner_completion += token
_function_call = stream_resp["choices"][0]["delta"].get(
"function_call"
)
if _function_call:
if function_call is None:
function_call = _function_call
elif "arguments" in function_call:
function_call["arguments"] += _function_call["arguments"]
else:
function_call["arguments"] = _function_call["arguments"]
if run_manager:
run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
"function_call": function_call,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = self.completion_with_retry(messages=message_dicts, **params)
return self._create_chat_result(response)
@ -381,6 +415,29 @@ class ChatOpenAI(BaseChatModel):
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return ChatResult(generations=generations, llm_output=llm_output)
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs, "stream": True}
default_chunk_class = AIMessageChunk
async for chunk in await acompletion_with_retry(
self, messages=message_dicts, **params
):
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
default_chunk_class = chunk.__class__
yield ChatGenerationChunk(message=chunk)
if run_manager:
await run_manager.on_llm_new_token(chunk.content)
async def _agenerate(
self,
messages: List[BaseMessage],
@ -388,45 +445,22 @@ class ChatOpenAI(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
params["stream"] = True
function_call: Optional[dict] = None
async for stream_resp in await acompletion_with_retry(
self, messages=message_dicts, **params
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if len(stream_resp["choices"]) > 0:
role = stream_resp["choices"][0]["delta"].get("role", role)
token = stream_resp["choices"][0]["delta"].get("content", "")
inner_completion += token or ""
_function_call = stream_resp["choices"][0]["delta"].get(
"function_call"
)
if _function_call:
if function_call is None:
function_call = _function_call
elif "arguments" in function_call:
function_call["arguments"] += _function_call["arguments"]
else:
function_call["arguments"] = _function_call["arguments"]
if run_manager:
await run_manager.on_llm_new_token(token)
message = _convert_dict_to_message(
{
"content": inner_completion,
"role": role,
"function_call": function_call,
}
)
return ChatResult(generations=[ChatGeneration(message=message)])
else:
response = await acompletion_with_retry(
self, messages=message_dicts, **params
)
return self._create_chat_result(response)
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(generations=[generation])
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
response = await acompletion_with_retry(self, messages=message_dicts, **params)
return self._create_chat_result(response)
@property
def _identifying_params(self) -> Dict[str, Any]:

@ -4,10 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.vertexai import _VertexAICommon, is_codey_model
from langchain.schema import (
@ -162,14 +159,3 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
response = chat.send_message(question.content)
text = self._enforce_stop_words(response.text, stop)
return ChatResult(generations=[ChatGeneration(message=AIMessage(content=text))])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
raise NotImplementedError(
"""Vertex AI doesn't support async requests at the moment."""
)

@ -6,7 +6,8 @@ from pydantic import BaseModel
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions import create_tagging_chain
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BaseDocumentTransformer, BaseLanguageModel, Document
from langchain.schema import BaseDocumentTransformer, Document
from langchain.schema.language_model import BaseLanguageModel
class OpenAIMetadataTagger(BaseDocumentTransformer, BaseModel):

@ -1,18 +1,20 @@
import re
import warnings
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
from pydantic import BaseModel, root_validator
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk
from langchain.utils import check_package_version, get_from_dict_or_env
class _AnthropicCommon(BaseModel):
class _AnthropicCommon(BaseLanguageModel):
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model: str = "claude-2"
@ -193,24 +195,16 @@ class Anthropic(LLM, _AnthropicCommon):
response = model(prompt)
"""
if self.streaming:
completion = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
if self.streaming:
stream_resp = self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
)
current_completion = ""
for data in stream_resp:
delta = data.completion
current_completion += delta
if run_manager:
run_manager.on_llm_new_token(
delta,
)
return current_completion
response = self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
@ -226,22 +220,17 @@ class Anthropic(LLM, _AnthropicCommon):
**kwargs: Any,
) -> str:
"""Call out to Anthropic's completion endpoint asynchronously."""
if self.streaming:
completion = ""
async for chunk in self._astream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
if self.streaming:
stream_resp = await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
)
current_completion = ""
async for data in stream_resp:
delta = data.completion
current_completion += delta
if run_manager:
await run_manager.on_llm_new_token(delta)
return current_completion
response = await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
@ -249,23 +238,55 @@ class Anthropic(LLM, _AnthropicCommon):
)
return response.completion
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
r"""Call Anthropic completion_stream and return the resulting generator.
BETA: this is a beta feature while we figure out the right abstraction.
Once that happens, this interface could change.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Anthropic.
Example:
.. code-block:: python
prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = anthropic.stream(prompt)
for token in generator:
yield token
"""
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
for token in self.client.completions.create(
prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
):
yield GenerationChunk(text=token.completion)
if run_manager:
run_manager.on_llm_new_token(token.completion)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
r"""Call Anthropic completion_stream and return the resulting generator.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Anthropic.
Example:
.. code-block:: python
prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = anthropic.stream(prompt)
@ -273,12 +294,17 @@ class Anthropic(LLM, _AnthropicCommon):
yield token
"""
stop = self._get_anthropic_stop(stop)
return self.client.completions.create(
params = {**self._default_params, **kwargs}
async for token in await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**self._default_params,
)
**params,
):
yield GenerationChunk(text=token.completion)
if run_manager:
await run_manager.on_llm_new_token(token.completion)
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""

@ -7,11 +7,14 @@ import json
import logging
import warnings
from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
@ -19,6 +22,7 @@ from typing import (
Tuple,
Type,
Union,
cast,
)
import yaml
@ -42,14 +46,18 @@ from langchain.callbacks.manager import (
Callbacks,
)
from langchain.load.dump import dumpd
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import (
Generation,
LLMResult,
PromptValue,
RunInfo,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.language_model import BaseLanguageModel, LanguageModelInput
from langchain.schema.messages import AIMessage, BaseMessage, get_buffer_string
from langchain.schema.output import GenerationChunk
from langchain.schema.runnable import RunnableConfig
logger = logging.getLogger(__name__)
@ -115,7 +123,7 @@ def update_cache(
return llm_output
class BaseLLM(BaseLanguageModel, ABC):
class BaseLLM(BaseLanguageModel[str], ABC):
"""Base LLM abstract interface.
It should take in a prompt and return a string."""
@ -157,6 +165,204 @@ class BaseLLM(BaseLanguageModel, ABC):
else:
return verbose
# --- Runnable methods ---
def _convert_input(self, input: LanguageModelInput) -> PromptValue:
if isinstance(input, PromptValue):
return input
elif isinstance(input, str):
return StringPromptValue(text=input)
elif isinstance(input, list):
return ChatPromptValue(messages=input)
else:
raise ValueError(
f"Invalid input type {type(input)}. "
"Must be a PromptValue, str, or list of BaseMessages."
)
def invoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> str:
return (
self.generate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
)
.generations[0][0]
.text
)
async def ainvoke(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> str:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async invoke, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.invoke, input, config, stop=stop)
)
llm_result = await self.agenerate_prompt(
[self._convert_input(input)], stop=stop, **(config or {})
)
return llm_result.generations[0][0].text
def batch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
) -> List[str]:
config = self._get_config_list(config, len(inputs))
if max_concurrency is None:
llm_result = self.generate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
)
return [g[0].text for g in llm_result.generations]
else:
batches = [
inputs[i : i + max_concurrency]
for i in range(0, len(inputs), max_concurrency)
]
return [
output
for batch in batches
for output in self.batch(batch, config=config)
]
async def abatch(
self,
inputs: List[LanguageModelInput],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
max_concurrency: Optional[int] = None,
) -> List[str]:
if type(self)._agenerate == BaseLLM._agenerate:
# model doesn't implement async batch, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, self.batch, inputs, config, max_concurrency
)
config = self._get_config_list(config, len(inputs))
if max_concurrency is None:
llm_result = await self.agenerate_prompt(
[self._convert_input(input) for input in inputs],
callbacks=[c.get("callbacks") for c in config],
tags=[c.get("tags") for c in config],
metadata=[c.get("metadata") for c in config],
)
return [g[0].text for g in llm_result.generations]
else:
batches = [
inputs[i : i + max_concurrency]
for i in range(0, len(inputs), max_concurrency)
]
return [
output
for batch in batches
for output in await self.abatch(batch, config=config)
]
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> Iterator[str]:
if type(self)._stream == BaseLLM._stream:
# model doesn't implement streaming, so use default implementation
yield self.invoke(input, config=config, stop=stop)
else:
prompt = self._convert_input(input).to_string()
config = config or {}
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
)
try:
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompt, stop=stop, run_manager=run_manager):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
raise e
else:
run_manager.on_llm_end(LLMResult(generations=[[generation]]))
async def astream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
) -> AsyncIterator[str]:
if type(self)._astream == BaseLLM._astream:
# model doesn't implement streaming, so use default implementation
yield await self.ainvoke(input, config=config, stop=stop)
else:
prompt = self._convert_input(input).to_string()
config = config or {}
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
callback_manager = AsyncCallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = await callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
)
try:
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(
prompt, stop=stop, run_manager=run_manager
):
yield chunk.text
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
raise e
else:
await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
# --- Custom methods ---
@abstractmethod
def _generate(
self,
@ -167,7 +373,6 @@ class BaseLLM(BaseLanguageModel, ABC):
) -> LLMResult:
"""Run the LLM on the given prompts."""
@abstractmethod
async def _agenerate(
self,
prompts: List[str],
@ -176,12 +381,31 @@ class BaseLLM(BaseLanguageModel, ABC):
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompts."""
raise NotImplementedError()
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
raise NotImplementedError()
def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
raise NotImplementedError()
def generate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@ -191,7 +415,7 @@ class BaseLLM(BaseLanguageModel, ABC):
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
@ -236,10 +460,10 @@ class BaseLLM(BaseLanguageModel, ABC):
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
@ -248,6 +472,50 @@ class BaseLLM(BaseLanguageModel, ABC):
"Argument 'prompts' is expected to be of type List[str], received"
f" argument of type {type(prompts)}."
)
# Create callback managers
if isinstance(callbacks, list) and (
isinstance(callbacks[0], (list, BaseCallbackManager))
or callbacks[0] is None
):
# We've received a list of callbacks args to apply to each input
assert len(callbacks) == len(prompts)
assert tags is None or (
isinstance(tags, list) and len(tags) == len(prompts)
)
assert metadata is None or (
isinstance(metadata, list) and len(metadata) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
)
callback_managers = [
CallbackManager.configure(
callback,
self.callbacks,
self.verbose,
tag,
self.tags,
meta,
self.metadata,
)
for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
]
else:
# We've received a single callbacks arg to apply to all inputs
callback_managers = [
CallbackManager.configure(
cast(Callbacks, callbacks),
self.callbacks,
self.verbose,
cast(List[str], tags),
self.tags,
cast(Dict[str, Any], metadata),
self.metadata,
)
] * len(prompts)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
@ -258,15 +526,6 @@ class BaseLLM(BaseLanguageModel, ABC):
missing_prompts,
) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache
callback_manager = CallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._generate).parameters.get(
"run_manager"
)
@ -275,17 +534,26 @@ class BaseLLM(BaseLanguageModel, ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_managers = callback_manager.on_llm_start(
dumpd(self), prompts, invocation_params=params, options=options
)
run_managers = [
callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
)[0]
for callback_manager, prompt in zip(callback_managers, prompts)
]
output = self._generate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
return output
if len(missing_prompts) > 0:
run_managers = callback_manager.on_llm_start(
dumpd(self), missing_prompts, invocation_params=params, options=options
)
run_managers = [
callback_managers[idx].on_llm_start(
dumpd(self),
[prompts[idx]],
invocation_params=params,
options=options,
)[0]
for idx in missing_prompt_idxs
]
new_results = self._generate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
@ -346,13 +614,57 @@ class BaseLLM(BaseLanguageModel, ABC):
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
tags: Optional[Union[List[str], List[List[str]]]] = None,
metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# Create callback managers
if isinstance(callbacks, list) and (
isinstance(callbacks[0], (list, BaseCallbackManager))
or callbacks[0] is None
):
# We've received a list of callbacks args to apply to each input
assert len(callbacks) == len(prompts)
assert tags is None or (
isinstance(tags, list) and len(tags) == len(prompts)
)
assert metadata is None or (
isinstance(metadata, list) and len(metadata) == len(prompts)
)
callbacks = cast(List[Callbacks], callbacks)
tags_list = cast(List[Optional[List[str]]], tags or ([None] * len(prompts)))
metadata_list = cast(
List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
)
callback_managers = [
AsyncCallbackManager.configure(
callback,
self.callbacks,
self.verbose,
tag,
self.tags,
meta,
self.metadata,
)
for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
]
else:
# We've received a single callbacks arg to apply to all inputs
callback_managers = [
AsyncCallbackManager.configure(
cast(Callbacks, callbacks),
self.callbacks,
self.verbose,
cast(List[str], tags),
self.tags,
cast(Dict[str, Any], metadata),
self.metadata,
)
] * len(prompts)
params = self.dict()
params["stop"] = stop
options = {"stop": stop}
@ -363,15 +675,6 @@ class BaseLLM(BaseLanguageModel, ABC):
missing_prompts,
) = get_prompts(params, prompts)
disregard_cache = self.cache is not None and not self.cache
callback_manager = AsyncCallbackManager.configure(
callbacks,
self.callbacks,
self.verbose,
tags,
self.tags,
metadata,
self.metadata,
)
new_arg_supported = inspect.signature(self._agenerate).parameters.get(
"run_manager"
)
@ -380,17 +683,32 @@ class BaseLLM(BaseLanguageModel, ABC):
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
run_managers = await callback_manager.on_llm_start(
dumpd(self), prompts, invocation_params=params, options=options
run_managers = await asyncio.gather(
*[
callback_manager.on_llm_start(
dumpd(self), [prompt], invocation_params=params, options=options
)
for callback_manager, prompt in zip(callback_managers, prompts)
]
)
run_managers = [r[0] for r in run_managers]
output = await self._agenerate_helper(
prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
return output
if len(missing_prompts) > 0:
run_managers = await callback_manager.on_llm_start(
dumpd(self), missing_prompts, invocation_params=params, options=options
run_managers = await asyncio.gather(
*[
callback_managers[idx].on_llm_start(
dumpd(self),
[prompts[idx]],
invocation_params=params,
options=options,
)
for idx in missing_prompt_idxs
]
)
run_managers = [r[0] for r in run_managers]
new_results = await self._agenerate_helper(
missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
)
@ -586,7 +904,7 @@ class LLM(BaseLLM):
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
raise NotImplementedError("Async generation not implemented for this LLM.")
raise NotImplementedError()
def _generate(
self,
@ -615,6 +933,12 @@ class LLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
if type(self)._acall == LLM._acall:
# model doesn't implement async call, so use default implementation
return await asyncio.get_running_loop().run_in_executor(
None, partial(self._generate, prompts, stop, run_manager, **kwargs)
)
"""Run the LLM on the given prompt and input."""
generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")

@ -27,7 +27,10 @@ class FakeListLLM(LLM):
) -> str:
"""Return next response"""
response = self.responses[self.i]
self.i += 1
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
return response
async def _acall(
@ -39,7 +42,10 @@ class FakeListLLM(LLM):
) -> str:
"""Return next response"""
response = self.responses[self.i]
self.i += 1
if self.i < len(self.responses) - 1:
self.i += 1
else:
self.i = 0
return response
@property

@ -12,10 +12,7 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
@ -161,15 +158,6 @@ class GooglePalm(BaseLLM, BaseModel):
return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
raise NotImplementedError()
@property
def _llm_type(self) -> str:
"""Return type of llm."""

@ -1,5 +1,4 @@
from functools import partial
from typing import Any, Dict, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from pydantic import Extra, Field, root_validator
@ -8,6 +7,7 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk
class HuggingFaceTextGenInference(LLM):
@ -69,7 +69,7 @@ class HuggingFaceTextGenInference(LLM):
temperature = 0.01,
repetition_penalty = 1.03,
callbacks = callbacks,
stream = True
streaming = True
)
print(llm("What is Deep Learning?"))
@ -87,7 +87,7 @@ class HuggingFaceTextGenInference(LLM):
inference_server_url: str = ""
timeout: int = 120
server_kwargs: Dict[str, Any] = Field(default_factory=dict)
stream: bool = False
streaming: bool = False
client: Any
async_client: Any
@ -154,37 +154,21 @@ class HuggingFaceTextGenInference(LLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.streaming:
completion = ""
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion
invocation_params = self._invocation_params(stop, **kwargs)
if not self.stream:
res = self.client.generate(prompt, **invocation_params)
# remove stop sequences from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
: res.generated_text.index(stop_seq)
]
text = res.generated_text
else:
text_callback = None
if run_manager:
text_callback = partial(
run_manager.on_llm_new_token, verbose=self.verbose
)
text = ""
for res in self.client.generate_stream(prompt, **invocation_params):
token = res.token
is_stop = False
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in token.text:
is_stop = True
break
if is_stop:
break
if not token.special:
if text_callback:
text_callback(token.text)
text += token.text
return text
res = self.client.generate(prompt, **invocation_params)
# remove stop sequences from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
: res.generated_text.index(stop_seq)
]
return res.generated_text
async def _acall(
self,
@ -193,39 +177,90 @@ class HuggingFaceTextGenInference(LLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.streaming:
completion = ""
async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
completion += chunk.text
return completion
invocation_params = self._invocation_params(stop, **kwargs)
if not self.stream:
res = await self.async_client.generate(
prompt,
**invocation_params,
)
# remove stop sequences from the end of the generated text
res = await self.async_client.generate(prompt, **invocation_params)
# remove stop sequences from the end of the generated text
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
: res.generated_text.index(stop_seq)
]
return res.generated_text
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
invocation_params = self._invocation_params(stop, **kwargs)
for res in self.client.generate_stream(prompt, **invocation_params):
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in res.generated_text:
res.generated_text = res.generated_text[
: res.generated_text.index(stop_seq)
]
text: str = res.generated_text
else:
text_callback = None
if run_manager:
text_callback = partial(
run_manager.on_llm_new_token, verbose=self.verbose
)
text = ""
async for res in self.async_client.generate_stream(
prompt, **invocation_params
):
token = res.token
is_stop = False
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in token.text:
is_stop = True
break
if is_stop:
break
if not token.special:
if text_callback:
await text_callback(token.text)
text += token.text
return text
if stop_seq in res.token.text:
stop_seq_found = stop_seq
# identify text to yield
text: Optional[str] = None
if res.token.special:
text = None
elif stop_seq_found:
text = res.token.text[: res.token.text.index(stop_seq_found)]
else:
text = res.token.text
# yield text, if any
if text:
chunk = GenerationChunk(text=text)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text)
# break if stop sequence found
if stop_seq_found:
break
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
invocation_params = self._invocation_params(stop, **kwargs)
async for res in self.async_client.generate_stream(prompt, **invocation_params):
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop_sequences"]:
if stop_seq in res.token.text:
stop_seq_found = stop_seq
# identify text to yield
text: Optional[str] = None
if res.token.special:
text = None
elif stop_seq_found:
text = res.token.text[: res.token.text.index(stop_seq_found)]
else:
text = res.token.text
# yield text, if any
if text:
chunk = GenerationChunk(text=text)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text)
# break if stop sequence found
if stop_seq_found:
break
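With streaming now expressed through `_stream`/`_astream`, the non-streaming `_call`/`_acall` paths above simply fold the yielded `GenerationChunk`s into a string. A minimal sketch of that folding, using a hypothetical stand-in for the token stream rather than a live text-generation-inference server:

```python
from langchain.schema.output import GenerationChunk

# Hypothetical stand-in for what _stream would yield from a running
# text-generation-inference server; only the accumulation logic is the point.
fake_stream = [GenerationChunk(text=t) for t in ["Deep", " learning", " is", " ..."]]

# Mirrors the `if self.streaming:` branch of _call above.
completion = ""
for chunk in fake_stream:
    completion += chunk.text

print(completion)  # "Deep learning is ..."
```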

@ -1,10 +1,11 @@
import logging
from typing import Any, Dict, Generator, List, Optional
from typing import Any, Dict, Iterator, List, Optional
from pydantic import Field, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema.output import GenerationChunk
logger = logging.getLogger(__name__)
@ -226,8 +227,10 @@ class LlamaCpp(LLM):
# method that yields as they are generated
# and return the combined string from the first choice's text:
combined_text_output = ""
for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
combined_text_output += token["choices"][0]["text"]
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
combined_text_output += chunk.text
return combined_text_output
else:
params = self._get_parameters(stop)
@ -235,17 +238,15 @@ class LlamaCpp(LLM):
result = self.client(prompt=prompt, **params)
return result["choices"][0]["text"]
def stream(
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
) -> Generator[Dict, None, None]:
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Yields results objects as they are generated in real time.
BETA: this is a beta feature while we figure out the right abstraction.
Once that happens, this interface could change.
It also calls the callback manager's on_llm_new_token event with
similar parameters to the OpenAI LLM class method of the same name.
@ -274,16 +275,19 @@ class LlamaCpp(LLM):
print(result["text"], end='', flush=True)
"""
params = self._get_parameters(stop)
params = {**self._get_parameters(stop), **kwargs}
result = self.client(prompt=prompt, stream=True, **params)
for chunk in result:
token = chunk["choices"][0]["text"]
log_probs = chunk["choices"][0].get("logprobs", None)
for part in result:
logprobs = part["choices"][0].get("logprobs", None)
chunk = GenerationChunk(
text=part["choices"][0]["text"],
generation_info={"logprobs": logprobs},
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
token=token, verbose=self.verbose, log_probs=log_probs
token=chunk.text, verbose=self.verbose, log_probs=logprobs
)
yield chunk
def get_num_tokens(self, text: str) -> int:
tokenized_text = self.client.tokenize(text.encode("utf-8"))

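The reworked LlamaCpp `_stream` wraps each completion part in a `GenerationChunk` that carries the logprobs, and only afterwards notifies the callback manager with `chunk.text`. A short sketch of that per-part conversion, with the llama-cpp output replaced by hypothetical dicts of the same shape:

```python
from langchain.schema.output import GenerationChunk

# Hypothetical parts shaped like what llama-cpp-python yields with stream=True.
parts = [
    {"choices": [{"text": "Hello", "logprobs": None}]},
    {"choices": [{"text": " world", "logprobs": None}]},
]

chunks = []
for part in parts:
    logprobs = part["choices"][0].get("logprobs", None)
    # Same construction as in _stream above: text plus logprobs metadata.
    chunks.append(
        GenerationChunk(
            text=part["choices"][0]["text"],
            generation_info={"logprobs": logprobs},
        )
    )

print("".join(c.text for c in chunks))  # "Hello world"
```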
@ -6,10 +6,11 @@ import warnings
from typing import (
AbstractSet,
Any,
AsyncIterator,
Callable,
Collection,
Dict,
Generator,
Iterator,
List,
Literal,
Mapping,
@ -27,6 +28,7 @@ from langchain.callbacks.manager import (
)
from langchain.llms.base import BaseLLM, create_base_retry_decorator
from langchain.schema import Generation, LLMResult
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env, get_pydantic_field_names
logger = logging.getLogger(__name__)
@ -44,6 +46,19 @@ def update_token_usage(
token_usage[_key] += response["usage"][_key]
def _stream_response_to_generation_chunk(
stream_response: Dict[str, Any],
) -> GenerationChunk:
"""Convert a stream response to a generation chunk."""
return GenerationChunk(
text=stream_response["choices"][0]["text"],
generation_info=dict(
finish_reason=stream_response["choices"][0].get("finish_reason", None),
logprobs=stream_response["choices"][0].get("logprobs", None),
),
)
def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
"""Update response from the stream response."""
response["choices"][0]["text"] += stream_response["choices"][0]["text"]
@ -268,6 +283,50 @@ class BaseOpenAI(BaseLLM):
return {**normal_params, **self.model_kwargs}
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop)  # this mutates params
for stream_resp in completion_with_retry(self, prompt=prompt, **params):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
params = {**self._invocation_params, **kwargs, "stream": True}
self.get_sub_prompts(params, [prompt], stop)  # this mutates params
async for stream_resp in await acompletion_with_retry(
self, prompt=prompt, **params
):
chunk = _stream_response_to_generation_chunk(stream_resp)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(
chunk.text,
verbose=self.verbose,
logprobs=chunk.generation_info["logprobs"]
if chunk.generation_info
else None,
)
def _generate(
self,
prompts: List[str],
@ -302,24 +361,28 @@ class BaseOpenAI(BaseLLM):
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
params["stream"] = True
response = _streaming_response_template()
for stream_resp in completion_with_retry(
self, prompt=_prompts, **params
):
if run_manager:
run_manager.on_llm_new_token(
stream_resp["choices"][0]["text"],
verbose=self.verbose,
logprobs=stream_resp["choices"][0]["logprobs"],
)
_update_response(response, stream_resp)
choices.extend(response["choices"])
generation: Optional[GenerationChunk] = None
for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
choices.append(
{
"text": generation.text,
"finish_reason": generation.generation_info.get("finish_reason")
if generation.generation_info
else None,
"logprobs": generation.generation_info.get("logprobs")
if generation.generation_info
else None,
}
)
else:
response = completion_with_retry(self, prompt=_prompts, **params)
choices.extend(response["choices"])
if not self.streaming:
# Can't update token usage if streaming
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
@ -343,24 +406,30 @@ class BaseOpenAI(BaseLLM):
if self.streaming:
if len(_prompts) > 1:
raise ValueError("Cannot stream results with multiple prompts.")
params["stream"] = True
response = _streaming_response_template()
async for stream_resp in await acompletion_with_retry(
self, prompt=_prompts, **params
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(
_prompts[0], stop, run_manager, **kwargs
):
if run_manager:
await run_manager.on_llm_new_token(
stream_resp["choices"][0]["text"],
verbose=self.verbose,
logprobs=stream_resp["choices"][0]["logprobs"],
)
_update_response(response, stream_resp)
choices.extend(response["choices"])
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
choices.append(
{
"text": generation.text,
"finish_reason": generation.generation_info.get("finish_reason")
if generation.generation_info
else None,
"logprobs": generation.generation_info.get("logprobs")
if generation.generation_info
else None,
}
)
else:
response = await acompletion_with_retry(self, prompt=_prompts, **params)
choices.extend(response["choices"])
if not self.streaming:
# Can't update token usage if streaming
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
@ -409,43 +478,6 @@ class BaseOpenAI(BaseLLM):
llm_output = {"token_usage": token_usage, "model_name": self.model_name}
return LLMResult(generations=generations, llm_output=llm_output)
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
"""Call OpenAI with streaming flag and return the resulting generator.
BETA: this is a beta feature while we figure out the right abstraction.
Once that happens, this interface could change.
Args:
prompt: The prompts to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from OpenAI.
Example:
.. code-block:: python
generator = openai.stream("Tell me a joke.")
for token in generator:
yield token
"""
params = self.prep_streaming_params(stop)
generator = self.client.create(prompt=prompt, **params)
return generator
def prep_streaming_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""Prepare the params for streaming."""
params = self._invocation_params
if "best_of" in params and params["best_of"] != 1:
raise ValueError("OpenAI only supports best_of == 1 for streaming")
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
params["stream"] = True
return params
@property
def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
@ -777,6 +809,38 @@ class OpenAIChat(BaseLLM):
del params["max_tokens"]
return messages, params
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
for stream_resp in completion_with_retry(self, messages=messages, **params):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
if run_manager:
run_manager.on_llm_new_token(token)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
messages, params = self._get_chat_params([prompt], stop)
params = {**params, **kwargs, "stream": True}
async for stream_resp in await acompletion_with_retry(
self, messages=messages, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
yield GenerationChunk(text=token)
if run_manager:
await run_manager.on_llm_new_token(token)
def _generate(
self,
prompts: List[str],
@ -784,33 +848,29 @@ class OpenAIChat(BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
if self.streaming:
generation: Optional[GenerationChunk] = None
for chunk in self._stream(prompts[0], stop, run_manager, **kwargs):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return LLMResult(generations=[[generation]])
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
if self.streaming:
response = ""
params["stream"] = True
for stream_resp in completion_with_retry(self, messages=messages, **params):
token = stream_resp["choices"][0]["delta"].get("content", "")
response += token
if run_manager:
run_manager.on_llm_new_token(
token,
)
return LLMResult(
generations=[[Generation(text=response)]],
)
else:
full_response = completion_with_retry(self, messages=messages, **params)
llm_output = {
"token_usage": full_response["usage"],
"model_name": self.model_name,
}
return LLMResult(
generations=[
[Generation(text=full_response["choices"][0]["message"]["content"])]
],
llm_output=llm_output,
)
full_response = completion_with_retry(self, messages=messages, **params)
llm_output = {
"token_usage": full_response["usage"],
"model_name": self.model_name,
}
return LLMResult(
generations=[
[Generation(text=full_response["choices"][0]["message"]["content"])]
],
llm_output=llm_output,
)
async def _agenerate(
self,
@ -819,37 +879,29 @@ class OpenAIChat(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
if self.streaming:
generation: Optional[GenerationChunk] = None
async for chunk in self._astream(prompts[0], stop, run_manager, **kwargs):
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return LLMResult(generations=[[generation]])
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
if self.streaming:
response = ""
params["stream"] = True
async for stream_resp in await acompletion_with_retry(
self, messages=messages, **params
):
token = stream_resp["choices"][0]["delta"].get("content", "")
response += token
if run_manager:
await run_manager.on_llm_new_token(
token,
)
return LLMResult(
generations=[[Generation(text=response)]],
)
else:
full_response = await acompletion_with_retry(
self, messages=messages, **params
)
llm_output = {
"token_usage": full_response["usage"],
"model_name": self.model_name,
}
return LLMResult(
generations=[
[Generation(text=full_response["choices"][0]["message"]["content"])]
],
llm_output=llm_output,
)
full_response = await acompletion_with_retry(self, messages=messages, **params)
llm_output = {
"token_usage": full_response["usage"],
"model_name": self.model_name,
}
return LLMResult(
generations=[
[Generation(text=full_response["choices"][0]["message"]["content"])]
],
llm_output=llm_output,
)
@property
def _identifying_params(self) -> Mapping[str, Any]:

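In the streaming branch of `_generate`/`_agenerate`, the chunks produced by `_stream` are merged with `+` (defined on `GenerationChunk` later in this PR), so the final choice keeps the concatenated text plus the last reported `finish_reason`/`logprobs`. A sketch of that merge over hypothetical Completion-style stream payloads:

```python
from langchain.schema.output import GenerationChunk

# Hypothetical payloads shaped like the OpenAI Completion streaming API.
stream_resps = [
    {"choices": [{"text": "Hi", "finish_reason": None, "logprobs": None}]},
    {"choices": [{"text": " there", "finish_reason": "stop", "logprobs": None}]},
]

generation = None
for resp in stream_resps:
    # Same conversion as _stream_response_to_generation_chunk above.
    chunk = GenerationChunk(
        text=resp["choices"][0]["text"],
        generation_info={
            "finish_reason": resp["choices"][0].get("finish_reason"),
            "logprobs": resp["choices"][0].get("logprobs"),
        },
    )
    generation = chunk if generation is None else generation + chunk

assert generation is not None
print(generation.text)                                  # "Hi there"
print(generation.generation_info.get("finish_reason"))  # "stop"
```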
@ -13,10 +13,7 @@ from tenacity import (
wait_exponential,
)
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
@ -250,12 +247,3 @@ class Tongyi(LLM):
]
)
return LLMResult(generations=generations)
async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
raise NotImplementedError()

@ -6,7 +6,7 @@ from typing import List
from langchain.schema import BaseOutputParser
class ListOutputParser(BaseOutputParser):
class ListOutputParser(BaseOutputParser[List[str]]):
"""Parse the output of an LLM call to a list."""
@property

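Parameterizing `ListOutputParser` as `BaseOutputParser[List[str]]` pins its Runnable output type to a list of strings. A quick sketch with the comma-separated subclass defined alongside it (assuming it is importable from the same module, as it is elsewhere in the package):

```python
from langchain.output_parsers.list import CommaSeparatedListOutputParser

parser = CommaSeparatedListOutputParser()

# With the Runnable protocol, a plain string goes straight into invoke();
# the declared output type is List[str].
items = parser.invoke("red, green, blue")
print(items)  # ['red', 'green', 'blue']
```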
@ -4,14 +4,14 @@ from typing import Any, Dict, List, Type, Union
from pydantic import BaseModel, root_validator
from langchain.schema import (
BaseLLMOutputParser,
ChatGeneration,
Generation,
OutputParserException,
)
from langchain.schema.output_parser import BaseGenerationOutputParser
class OutputFunctionsParser(BaseLLMOutputParser[Any]):
class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Parse an output that is one of sets of values."""
args_only: bool = True

@ -5,10 +5,10 @@ import warnings
from abc import ABC
from typing import Any, Callable, Dict, List, Set
from langchain.schema import BasePromptTemplate
from langchain.formatting import formatter
from langchain.schema.messages import BaseMessage, HumanMessage
from langchain.schema.prompt import PromptValue
from langchain.utils import formatter
from langchain.schema.prompt_template import BasePromptTemplate
def jinja2_formatter(template: str, **kwargs: Any) -> str:

@ -446,7 +446,7 @@ class ChatPromptTemplate(BaseChatPromptTemplate, ABC):
for message in messages:
if isinstance(message, BaseMessagePromptTemplate):
input_vars.update(message.input_variables)
return cls(input_variables=list(input_vars), messages=messages)
return cls(input_variables=sorted(input_vars), messages=messages)
def format(self, **kwargs: Any) -> str:
"""Format the chat template into a string.

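Switching `from_messages` from `list(input_vars)` to `sorted(input_vars)` makes the inferred `input_variables` deterministic, since iteration order over a `set` of strings can differ between interpreter runs. A tiny illustration (the variable names are placeholders):

```python
# Placeholder variable names; only the ordering behaviour matters here.
input_vars = {"question", "context", "history"}

print(list(input_vars))    # hash-dependent order, may change between runs
print(sorted(input_vars))  # always ['context', 'history', 'question']
```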
@ -1,9 +1,6 @@
from typing import List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
from langchain.utilities.arxiv import ArxivAPIWrapper
@ -20,8 +17,3 @@ class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load(query=query)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -7,10 +7,7 @@ from __future__ import annotations
from typing import Any, Callable, Dict, Iterable, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
@ -108,8 +105,3 @@ class BM25Retriever(BaseRetriever):
processed_query = self.preprocess_func(query)
return_docs = self.vectorizer.get_top_n(processed_query, self.docs, n=self.k)
return return_docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional, Union
import numpy as np
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.utils import maximal_marginal_relevance
@ -208,11 +205,3 @@ class DocArrayRetriever(BaseRetriever):
lc_doc.metadata[name] = value
return lc_doc
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError

@ -5,10 +5,7 @@ from __future__ import annotations
import uuid
from typing import Any, Iterable, List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
@ -138,8 +135,3 @@ class ElasticSearchBM25Retriever(BaseRetriever):
for r in res["hits"]["hits"]:
docs.append(Document(page_content=r["_source"]["content"]))
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -5,10 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
from pydantic import Extra, Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
from langchain.utils import get_from_dict_or_env
@ -184,8 +181,3 @@ class GoogleCloudEnterpriseSearchRetriever(BaseRetriever):
documents = self._convert_search_response(response.results)
return documents
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -4,10 +4,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
@ -411,11 +408,3 @@ class AmazonKendraRetriever(BaseRetriever):
"""
docs = self._kendra_query(query, self.top_k, self.attribute_filter)
return docs
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError("Async version is not implemented for Kendra yet.")

@ -9,10 +9,7 @@ from typing import Any, List, Optional
import numpy as np
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
@ -82,8 +79,3 @@ class KNNRetriever(BaseRetriever):
)
]
return top_k_results
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -2,10 +2,7 @@ from typing import Any, Dict, List, cast
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
@ -42,11 +39,6 @@ class LlamaIndexRetriever(BaseRetriever):
)
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError("LlamaIndexRetriever does not support async")
class LlamaIndexGraphRetriever(BaseRetriever):
"""Retriever for question-answering with sources over an LlamaIndex
@ -88,8 +80,3 @@ class LlamaIndexGraphRetriever(BaseRetriever):
Document(page_content=source_node.source_text, metadata=metadata)
)
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError("LlamaIndexGraphRetriever does not support async")

@ -2,10 +2,7 @@ from typing import Any, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
@ -43,8 +40,3 @@ class MetalRetriever(BaseRetriever):
metadata = {k: v for k, v in r.items() if k != "text"}
final_results.append(Document(page_content=r["text"], metadata=metadata))
return final_results
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.milvus import Milvus
@ -63,15 +60,6 @@ class MilvusRetriever(BaseRetriever):
query, run_manager=run_manager.get_child(), **kwargs
)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def MilvusRetreiver(*args: Any, **kwargs: Any) -> MilvusRetriever:
"""Deprecated MilvusRetreiver. Please use MilvusRetriever ('i' before 'e') instead.

@ -3,10 +3,7 @@ from typing import List
from pydantic import BaseModel, Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.output_parsers.pydantic import PydanticOutputParser
@ -101,14 +98,6 @@ class MultiQueryRetriever(BaseRetriever):
unique_documents = self.unique_union(documents)
return unique_documents
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
) -> List[Document]:
raise NotImplementedError
def generate_queries(
self, question: str, run_manager: CallbackManagerForRetrieverRun
) -> List[str]:

@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
@ -175,8 +172,3 @@ class PineconeHybridSearchRetriever(BaseRetriever):
)
# return search results as json
return final_result
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -1,9 +1,6 @@
from typing import List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
from langchain.utilities.pupmed import PubMedAPIWrapper
@ -19,8 +16,3 @@ class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load_docs(query=query)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -5,10 +5,7 @@ from typing import Any, Dict, List, Optional, Type, cast
from pydantic import BaseModel, Field, root_validator
from langchain import LLMChain
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.chains.query_constructor.base import load_query_constructor_chain
from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
from langchain.chains.query_constructor.schema import AttributeInfo
@ -119,11 +116,6 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
docs = self.vectorstore.search(new_query, self.search_type, **search_kwargs)
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError
@classmethod
def from_llm(
cls,

@ -5,10 +5,7 @@ from typing import Any, Iterable, List, Optional
import numpy as np
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
@ -113,8 +110,3 @@ class SVMRetriever(BaseRetriever):
):
top_k_results.append(Document(page_content=self.texts[row - 1]))
return top_k_results
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -2,10 +2,7 @@ from __future__ import annotations
from typing import Any, Dict, Iterable, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
@ -79,8 +76,3 @@ class TFIDFRetriever(BaseRetriever):
) # Op -- (n_docs,1) -- Cosine Sim with each doc
return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
return return_docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -4,10 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
from pydantic import Field
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.base import VectorStore
@ -109,12 +106,6 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
result.append(buffered_doc)
return result
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
"""Return documents that are relevant to the query."""
raise NotImplementedError
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore."""
current_time = kwargs.get("current_time")

@ -3,10 +3,7 @@ from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
if TYPE_CHECKING:
@ -57,11 +54,6 @@ class VespaRetriever(BaseRetriever):
body["query"] = query
return self._query(body)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError
def get_relevant_documents_with_filter(
self, query: str, *, _filter: Optional[str] = None
) -> List[Document]:

@ -5,10 +5,7 @@ from uuid import uuid4
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.schema import BaseRetriever
@ -118,8 +115,3 @@ class WeaviateHybridSearchRetriever(BaseRetriever):
text = res.pop(self.text_key)
docs.append(Document(page_content=text, metadata=res))
return docs
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -1,9 +1,6 @@
from typing import List
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document
from langchain.utilities.wikipedia import WikipediaAPIWrapper
@ -19,8 +16,3 @@ class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
return self.load(query=query)
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

@ -3,10 +3,7 @@ from typing import Any, Dict, List, Optional
from pydantic import root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever, Document
from langchain.vectorstores.zilliz import Zilliz
@ -67,15 +64,6 @@ class ZillizRetriever(BaseRetriever):
query, run_manager=run_manager.get_child(), **kwargs
)
async def _aget_relevant_documents(
self,
query: str,
*,
run_manager: AsyncCallbackManagerForRetrieverRun,
**kwargs: Any,
) -> List[Document]:
raise NotImplementedError
def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
"""Deprecated ZillizRetreiver.

@ -1,6 +1,5 @@
from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.document import BaseDocumentTransformer, Document
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.memory import BaseChatMessageHistory, BaseMemory
from langchain.schema.messages import (
AIMessage,
@ -67,6 +66,5 @@ __all__ = [
"BaseOutputParser",
"BaseLLMOutputParser",
"BasePromptTemplate",
"BaseLanguageModel",
"format_document",
]

@ -1,12 +1,22 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Set
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
Sequence,
Set,
TypeVar,
Union,
)
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage, get_buffer_string
from langchain.schema.output import LLMResult
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable
from langchain.utils import get_pydantic_field_names
if TYPE_CHECKING:
@ -32,7 +42,13 @@ def _get_token_ids_default_method(text: str) -> List[int]:
return tokenizer.encode(text)
class BaseLanguageModel(Serializable, ABC):
LanguageModelInput = Union[PromptValue, str, List[BaseMessage]]
LanguageModelOutput = TypeVar("LanguageModelOutput")
class BaseLanguageModel(
Serializable, Runnable[LanguageModelInput, LanguageModelOutput], ABC
):
"""Abstract base class for interfacing with language models.
All language model wrappers inherit from BaseLanguageModel.

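With `BaseLanguageModel` now a `Runnable[LanguageModelInput, LanguageModelOutput]`, a model accepts any of the `LanguageModelInput` forms (a `PromptValue`, a plain string, or a list of messages) and composes with the `|` operator from `langchain.schema.runnable`. A hedged sketch, assuming a chat model such as `ChatOpenAI` is available and `OPENAI_API_KEY` is configured:

```python
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate
from langchain.schema.messages import HumanMessage

model = ChatOpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Any LanguageModelInput form should work with invoke():
model.invoke("Tell me a joke")                           # str
model.invoke([HumanMessage(content="Tell me a joke")])   # List[BaseMessage]

# Runnables compose into a RunnableSequence with `|`:
prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}")
chain = prompt | model
print(chain.invoke({"topic": "bears"}).content)
```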
@ -1,7 +1,7 @@
from __future__ import annotations
from abc import abstractmethod
from typing import List, Sequence
from typing import Any, Dict, List, Sequence
from pydantic import Field
@ -78,6 +78,49 @@ class BaseMessage(Serializable):
return True
class BaseMessageChunk(BaseMessage):
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge additional_kwargs from another BaseMessageChunk into this one."""
merged = left.copy()
for k, v in right.items():
if k not in merged:
merged[k] = v
elif type(merged[k]) != type(v):
raise ValueError(
f'additional_kwargs["{k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[k], str):
merged[k] += v
elif isinstance(merged[k], dict):
merged[k] = self._merge_kwargs_dict(merged[k], v)
else:
raise ValueError(
f"Additional kwargs key {k} already exists in this message."
)
return merged
def __add__(self, other: Any) -> BaseMessageChunk:
if isinstance(other, BaseMessageChunk):
# If both are (subclasses of) BaseMessageChunk,
# concat into a single BaseMessageChunk
return self.__class__(
content=self.content + other.content,
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)
else:
raise TypeError(
'unsupported operand type(s) for +: "'
f"{self.__class__.__name__}"
f'" and "{other.__class__.__name__}"'
)
class HumanMessage(BaseMessage):
"""A Message from a human."""
@ -92,6 +135,10 @@ class HumanMessage(BaseMessage):
return "human"
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
pass
class AIMessage(BaseMessage):
"""A Message from an AI."""
@ -106,6 +153,10 @@ class AIMessage(BaseMessage):
return "ai"
class AIMessageChunk(AIMessage, BaseMessageChunk):
pass
class SystemMessage(BaseMessage):
"""A Message for priming AI behavior, usually passed in as the first of a sequence
of input messages.
@ -117,6 +168,10 @@ class SystemMessage(BaseMessage):
return "system"
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
pass
class FunctionMessage(BaseMessage):
"""A Message for passing the result of executing a function back to a model."""
@ -129,6 +184,10 @@ class FunctionMessage(BaseMessage):
return "function"
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
pass
class ChatMessage(BaseMessage):
"""A Message that can be assigned an arbitrary speaker (i.e. role)."""
@ -141,6 +200,10 @@ class ChatMessage(BaseMessage):
return "chat"
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
pass
def _message_to_dict(message: BaseMessage) -> dict:
return {"type": message.type, "data": message.dict()}

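`BaseMessageChunk.__add__` is what lets streamed chat output be folded back into one message: `content` strings concatenate and `additional_kwargs` merge key by key (string values concatenate, nested dicts merge recursively, mismatched types raise). A small sketch with `AIMessageChunk`:

```python
from langchain.schema.messages import AIMessageChunk

first = AIMessageChunk(
    content="Hello", additional_kwargs={"function_call": {"arguments": '{"a": 1'}}
)
second = AIMessageChunk(
    content=" world", additional_kwargs={"function_call": {"arguments": ', "b": 2}'}}
)

merged = first + second
print(merged.content)            # "Hello world"
print(merged.additional_kwargs)  # {'function_call': {'arguments': '{"a": 1, "b": 2}'}}

# Adding anything that is not a BaseMessageChunk raises TypeError, per __add__ above.
```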
@ -7,7 +7,7 @@ from uuid import UUID
from pydantic import BaseModel, root_validator
from langchain.load.serializable import Serializable
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import BaseMessage, BaseMessageChunk
class Generation(Serializable):
@ -28,6 +28,24 @@ class Generation(Serializable):
return True
class GenerationChunk(Generation):
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return GenerationChunk(
text=self.text + other.text,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
class ChatGeneration(Generation):
"""A single chat generation output."""
@ -43,6 +61,26 @@ class ChatGeneration(Generation):
return values
class ChatGenerationChunk(ChatGeneration):
message: BaseMessageChunk
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = (
{**(self.generation_info or {}), **(other.generation_info or {})}
if self.generation_info is not None or other.generation_info is not None
else None
)
return ChatGenerationChunk(
message=self.message + other.message,
generation_info=generation_info,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"
)
class RunInfo(BaseModel):
"""Class that contains metadata for a single execution of a Chain or model."""

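The chunk `__add__` overloads mirror the message one: text concatenates and `generation_info` dicts are shallow-merged (the right-hand side wins on key collisions), and anything that is not a chunk of the same kind raises `TypeError`. A brief sketch:

```python
from langchain.schema.output import GenerationChunk

a = GenerationChunk(text="foo", generation_info={"finish_reason": None})
b = GenerationChunk(text="bar", generation_info={"finish_reason": "stop"})

merged = a + b
print(merged.text)             # "foobar"
print(merged.generation_info)  # {'finish_reason': 'stop'}

try:
    a + "bar"  # not a GenerationChunk
except TypeError as err:
    print(err)
```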
@ -1,16 +1,18 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from langchain.load.serializable import Serializable
from langchain.schema.output import Generation
from langchain.schema.messages import BaseMessage
from langchain.schema.output import ChatGeneration, Generation
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
T = TypeVar("T")
class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
class BaseLLMOutputParser(Serializable, Generic[T], ABC):
"""Abstract base class for parsing the outputs of a model."""
@abstractmethod
@ -26,7 +28,19 @@ class BaseLLMOutputParser(Serializable, ABC, Generic[T]):
"""
class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
):
def invoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
) -> T:
if isinstance(input, BaseMessage):
return self.parse_result([ChatGeneration(message=input)])
else:
return self.parse_result([Generation(text=input)])
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
"""Base class to parse the output of an LLM call.
Output parsers help structure language model responses.
@ -53,6 +67,14 @@ class BaseOutputParser(BaseLLMOutputParser, ABC, Generic[T]):
return "boolean_output_parser"
""" # noqa: E501
def invoke(
self, input: str | BaseMessage, config: RunnableConfig | None = None
) -> T:
if isinstance(input, BaseMessage):
return self.parse_result([ChatGeneration(message=input)])
else:
return self.parse_result([Generation(text=input)])
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.

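The new `invoke` on both parser base classes normalizes its two input kinds: a `BaseMessage` is wrapped in a `ChatGeneration` and a plain string in a `Generation` before `parse_result` runs, so one parser can sit behind either an LLM or a chat model. A sketch with a hypothetical minimal parser that just upper-cases the text:

```python
from langchain.schema import AIMessage
from langchain.schema.output_parser import BaseOutputParser


class UpperCaseParser(BaseOutputParser[str]):
    """Hypothetical parser, used only to show the two invoke() input paths."""

    def parse(self, text: str) -> str:
        return text.upper()


parser = UpperCaseParser()
print(parser.invoke("hello"))                     # "HELLO" (wrapped in a Generation)
print(parser.invoke(AIMessage(content="hello")))  # "HELLO" (wrapped in a ChatGeneration)
```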
@ -12,9 +12,10 @@ from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.output_parser import BaseOutputParser
from langchain.schema.prompt import PromptValue
from langchain.schema.runnable import Runnable, RunnableConfig
class BasePromptTemplate(Serializable, ABC):
class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
"""Base class for all prompt templates, returning a prompt."""
input_variables: List[str]
@ -34,6 +35,11 @@ class BasePromptTemplate(Serializable, ABC):
arbitrary_types_allowed = True
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
return self._call_with_config(
lambda inner_input: self.format_prompt(**inner_input), input, config
)
@abstractmethod
def format_prompt(self, **kwargs: Any) -> PromptValue:
"""Create Chat Messages."""

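`BasePromptTemplate.invoke` routes a dict of variables through `_call_with_config`, which emits the chain callbacks and then delegates to `format_prompt`. A minimal sketch with `PromptTemplate`:

```python
from langchain.prompts import PromptTemplate

prompt = PromptTemplate.from_template("Tell me a joke about {topic}")

# invoke() takes the template variables as a dict and returns a PromptValue.
value = prompt.invoke({"topic": "bears"})
print(value.to_string())  # "Tell me a joke about bears"
```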
@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
from langchain.schema.document import Document
from langchain.schema.runnable import Runnable, RunnableConfig
if TYPE_CHECKING:
from langchain.callbacks.manager import (
@ -17,7 +18,7 @@ if TYPE_CHECKING:
)
class BaseRetriever(Serializable, ABC):
class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
"""Abstract base class for a Document retrieval system.
A retrieval system is defined as something that can take string queries and return
@ -43,9 +44,6 @@ class BaseRetriever(Serializable, ABC):
# Op -- (n_docs,1) -- Cosine Sim with each doc
results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
return [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
async def aget_relevant_documents(self, query: str) -> List[Document]:
raise NotImplementedError
""" # noqa: E501
class Config:
@ -106,6 +104,20 @@ class BaseRetriever(Serializable, ABC):
len(set(parameters.keys()) - {"self", "query", "run_manager"}) > 0
)
def invoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
return self.get_relevant_documents(input, **(config or {}))
async def ainvoke(
self, input: str, config: Optional[RunnableConfig] = None
) -> List[Document]:
if type(self).aget_relevant_documents == BaseRetriever.aget_relevant_documents:
# If the retriever doesn't implement async, use default implementation
return await super().ainvoke(input, config)
return await self.aget_relevant_documents(input, **(config or {}))
@abstractmethod
def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@ -118,7 +130,6 @@ class BaseRetriever(Serializable, ABC):
List of relevant documents
"""
@abstractmethod
async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
@ -129,6 +140,7 @@ class BaseRetriever(Serializable, ABC):
Returns:
List of relevant documents
"""
raise NotImplementedError()
def get_relevant_documents(
self,

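With retrievers on the Runnable protocol, `invoke` maps a query string to a list of documents, and because `_aget_relevant_documents` is no longer abstract, sync-only retrievers fall back to the default `Runnable.ainvoke` executor instead of having to stub out the async method. A hedged sketch with a hypothetical minimal retriever (the real retrievers above need their respective backends):

```python
from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document


class StaticRetriever(BaseRetriever):
    """Hypothetical retriever returning a canned document for any query."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=f"stub result for: {query}")]


retriever = StaticRetriever()
print(retriever.invoke("what is a runnable?"))
```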
@ -0,0 +1,705 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
Callable,
Coroutine,
Dict,
Generic,
Iterator,
List,
Optional,
TypedDict,
TypeVar,
Union,
cast,
)
from pydantic import Field
from langchain.callbacks.base import BaseCallbackManager, Callbacks
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
async def _gated_coro(semaphore: asyncio.Semaphore, coro: Coroutine) -> Any:
async with semaphore:
return await coro
async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> list:
if n is None:
return await asyncio.gather(*coros)
semaphore = asyncio.Semaphore(n)
return await asyncio.gather(*(_gated_coro(semaphore, c) for c in coros))
class RunnableConfig(TypedDict, total=False):
tags: List[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
You can use these to filter calls.
"""
metadata: Dict[str, Any]
"""
Metadata for this call and any sub-calls (eg. a Chain calling an LLM).
Keys should be strings, values should be JSON-serializable.
"""
callbacks: Callbacks
"""
Callbacks for this call and any sub-calls (eg. a Chain calling an LLM).
Tags are passed to all callbacks, metadata is passed to handle*Start callbacks.
"""
Input = TypeVar("Input")
# Output type should implement __concat__, as eg str, list, dict do
Output = TypeVar("Output")
Other = TypeVar("Other")
class Runnable(Generic[Input, Output], ABC):
def __or__(
self,
other: Union[
Runnable[Any, Other],
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
return RunnableSequence(first=self, last=_coerce_to_runnable(other))
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:
return RunnableSequence(first=_coerce_to_runnable(other), last=self)
@abstractmethod
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
...
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
return await asyncio.get_running_loop().run_in_executor(
None, self.invoke, input, config
)
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
configs = self._get_config_list(config, len(inputs))
with ThreadPoolExecutor(max_workers=max_concurrency) as executor:
return list(executor.map(self.invoke, inputs, configs))
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
configs = self._get_config_list(config, len(inputs))
coros = map(self.ainvoke, inputs, configs)
return await _gather_with_concurrency(max_concurrency, *coros)
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
yield self.invoke(input, config)
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
yield await self.ainvoke(input, config)
def _get_config_list(
self, config: Optional[Union[RunnableConfig, List[RunnableConfig]]], length: int
) -> List[RunnableConfig]:
if isinstance(config, list) and len(config) != length:
raise ValueError(
f"config must be a list of the same length as inputs, "
f"but got {len(config)} configs for {length} inputs"
)
return (
config
if isinstance(config, list)
else [config.copy() if config is not None else {} for _ in range(length)]
)
def _call_with_config(
self,
func: Callable[[Input], Output],
input: Input,
config: Optional[RunnableConfig],
) -> Output:
from langchain.callbacks.manager import CallbackManager
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
inheritable_tags=config.get("tags"),
inheritable_metadata=config.get("metadata"),
)
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
try:
output = func(input)
except Exception as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(
output if isinstance(output, dict) else {"output": output}
)
return output
class RunnableSequence(Serializable, Runnable[Input, Output]):
first: Runnable[Input, Any]
middle: List[Runnable[Any, Any]] = Field(default_factory=list)
last: Runnable[Any, Output]
@property
def steps(self) -> List[Runnable[Any, Any]]:
return [self.first] + self.middle + [self.last]
@property
def lc_serializable(self) -> bool:
return True
class Config:
arbitrary_types_allowed = True
def __or__(
self,
other: Union[
Runnable[Any, Other],
Dict[str, Union[Runnable[Any, Other], Callable[[Any], Other]]],
],
) -> RunnableSequence[Input, Other]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=self.first,
middle=self.middle + [self.last] + other.middle,
last=other.last,
)
else:
return RunnableSequence(
first=self.first,
middle=self.middle + [self.last],
last=_coerce_to_runnable(other),
)
def __ror__(
self,
other: Union[
Runnable[Other, Any],
Dict[str, Union[Runnable[Other, Any], Callable[[Other], Any]]],
],
) -> RunnableSequence[Other, Output]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=other.first,
middle=other.middle + [other.last] + self.middle,
last=self.last,
)
else:
return RunnableSequence(
first=_coerce_to_runnable(other),
middle=[self.first] + self.middle,
last=self.last,
)
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# invoke all steps in sequence
try:
for step in self.steps:
input = step.invoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
return cast(Output, input)
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Output:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# invoke all steps in sequence
try:
for step in self.steps:
input = await step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
# finish the root run
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
return cast(Output, input)
def batch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
configs = self._get_config_list(config, len(inputs))
callback_managers = [
CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers = [
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
for cm, input in zip(callback_managers, inputs)
]
# invoke
try:
for step in self.steps:
inputs = step.batch(
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
for rm in run_managers:
rm.on_chain_error(e)
raise
else:
for rm, input in zip(run_managers, inputs):
rm.on_chain_end(input if isinstance(input, dict) else {"output": input})
return cast(List[Output], inputs)
async def abatch(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
max_concurrency: Optional[int] = None,
) -> List[Output]:
from langchain.callbacks.manager import (
AsyncCallbackManager,
AsyncCallbackManagerForChainRun,
)
# setup callbacks
configs = self._get_config_list(config, len(inputs))
callback_managers = [
AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
for config in configs
]
# start the root runs, one per input
run_managers: List[AsyncCallbackManagerForChainRun] = await asyncio.gather(
*(
cm.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
for cm, input in zip(callback_managers, inputs)
)
)
# invoke .batch() on each step
# this uses batching optimizations in Runnable subclasses, like LLM
try:
for step in self.steps:
inputs = await step.abatch(
inputs,
[
# each step a child run of the corresponding root run
_patch_config(config, rm.get_child())
for rm, config in zip(run_managers, configs)
],
max_concurrency=max_concurrency,
)
# finish the root runs
except (KeyboardInterrupt, Exception) as e:
await asyncio.gather(*(rm.on_chain_error(e) for rm in run_managers))
raise
else:
await asyncio.gather(
*(
rm.on_chain_end(
input if isinstance(input, dict) else {"output": input}
)
for rm, input in zip(run_managers, inputs)
)
)
return cast(List[Output], inputs)
def stream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Iterator[Output]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# invoke the first steps
try:
for step in [self.first] + self.middle:
input = step.invoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
# stream the last step
final: Union[Output, None] = None
final_supported = True
try:
for output in self.last.stream(
input,
# mark the last step as a child run
_patch_config(config, run_manager.get_child()),
):
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
if final is None:
final = output
else:
try:
final += output # type: ignore[operator]
except TypeError:
final = None
final_supported = False
pass
# finish the root run
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(
final if isinstance(final, dict) else {"output": final}
)
async def astream(
self, input: Input, config: Optional[RunnableConfig] = None
) -> AsyncIterator[Output]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input}
)
# invoke the first steps
try:
for step in [self.first] + self.middle:
input = await step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
# stream the last step
final: Union[Output, None] = None
final_supported = True
try:
async for output in self.last.astream(
input,
# mark the last step as a child run
_patch_config(config, run_manager.get_child()),
):
yield output
# Accumulate output if possible, otherwise disable accumulation
if final_supported:
if final is None:
final = output
else:
try:
final += output # type: ignore[operator]
except TypeError:
final = None
final_supported = False
# finish the root run
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(
final if isinstance(final, dict) else {"output": final}
)
class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
steps: Dict[str, Runnable[Input, Any]]
@property
def lc_serializable(self) -> bool:
return True
class Config:
arbitrary_types_allowed = True
def invoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
from langchain.callbacks.manager import CallbackManager
# setup callbacks
config = config or {}
callback_manager = CallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = callback_manager.on_chain_start(dumpd(self), {"input": input})
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during invoke()
steps = self.steps.copy()
with ThreadPoolExecutor() as executor:
futures = [
executor.submit(
step.invoke,
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
for step in steps.values()
]
output = {key: future.result() for key, future in zip(steps, futures)}
# finish the root run
except (KeyboardInterrupt, Exception) as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(output)
return output
async def ainvoke(
self, input: Input, config: Optional[RunnableConfig] = None
) -> Dict[str, Any]:
from langchain.callbacks.manager import AsyncCallbackManager
# setup callbacks
config = config or {}
callback_manager = AsyncCallbackManager.configure(
inheritable_callbacks=config.get("callbacks"),
local_callbacks=None,
verbose=False,
inheritable_tags=config.get("tags"),
local_tags=None,
inheritable_metadata=config.get("metadata"),
local_metadata=None,
)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), {"input": input}
)
# gather results from all steps
try:
# copy to avoid issues from the caller mutating the steps during ainvoke()
steps = self.steps.copy()
results = await asyncio.gather(
*(
step.ainvoke(
input,
# mark each step as a child run
_patch_config(config, run_manager.get_child()),
)
for step in steps.values()
)
)
output = {key: value for key, value in zip(steps, results)}
# finish the root run
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(output)
return output
class RunnableLambda(Runnable[Input, Output]):
def __init__(self, func: Callable[[Input], Output]) -> None:
if callable(func):
self.func = func
else:
raise TypeError(
"Expected a callable type for `func`."
f"Instead got an unsupported type: {type(func)}"
)
def __eq__(self, other: Any) -> bool:
if isinstance(other, RunnableLambda):
return self.func == other.func
else:
return False
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
return self._call_with_config(self.func, input, config)
class RunnablePassthrough(Serializable, Runnable[Input, Input]):
@property
def lc_serializable(self) -> bool:
return True
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Input:
return self._call_with_config(lambda x: x, input, config)
def _patch_config(
config: RunnableConfig, callback_manager: BaseCallbackManager
) -> RunnableConfig:
config = config.copy()
config["callbacks"] = callback_manager
return config
def _coerce_to_runnable(
thing: Union[
Runnable[Input, Output],
Callable[[Input], Output],
Dict[str, Union[Runnable[Input, Output], Callable[[Input], Output]]],
]
) -> Runnable[Input, Output]:
if isinstance(thing, Runnable):
return thing
elif callable(thing):
return RunnableLambda(thing)
elif isinstance(thing, dict):
runnables = {key: _coerce_to_runnable(r) for key, r in thing.items()}
return cast(Runnable[Input, Output], RunnableMap(steps=runnables))
else:
raise TypeError(
f"Expected a Runnable, callable or dict."
f"Instead got an unsupported type: {type(thing)}"
)
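To illustrate how these pieces fit together, a minimal usage sketch (not part of the diff, using only the public import path exercised by the unit tests below): a plain dict is coerced into a RunnableMap whose values run in parallel, a bare callable becomes a RunnableLambda, and `|` chains the steps into a RunnableSequence.
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough

# dict -> RunnableMap (values run in parallel), callable -> RunnableLambda,
# `|` -> RunnableSequence
chain = (
    {
        "original": RunnablePassthrough(),
        "doubled": lambda x: x * 2,
    }
    | RunnableLambda(lambda d: f"{d['original']} doubled is {d['doubled']}")
)

assert chain.invoke(3) == "3 doubled is 6"
assert chain.batch([1, 2]) == ["1 doubled is 2", "2 doubled is 4"]
# stream() invokes every step except the last and streams only the final step
assert list(chain.stream(3)) == ["3 doubled is 6"]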

@ -177,3 +177,68 @@ def test_chat_openai_extra_kwargs() -> None:
# Test that "model" cannot be specified in kwargs
with pytest.raises(ValueError):
ChatOpenAI(model_kwargs={"model": "text-davinci-003"})
def test_openai_streaming() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI(max_tokens=10)
for token in llm.stream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_openai_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = ChatOpenAI(max_tokens=10)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_openai_abatch() -> None:
"""Test streaming tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_openai_abatch_tags() -> None:
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token.content, str)
def test_openai_batch() -> None:
"""Test batch tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token.content, str)
@pytest.mark.asyncio
async def test_openai_ainvoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result.content, str)
def test_openai_invoke() -> None:
"""Test invoke tokens from ChatOpenAI."""
llm = ChatOpenAI(max_tokens=10)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result.content, str)
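Because chat models now implement the same Runnable protocol exercised above, a ChatOpenAI instance can be dropped directly into a `|` pipeline; a minimal sketch (not part of the diff, requires a valid OpenAI API key):
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import ChatPromptTemplate

chain = ChatPromptTemplate.from_template("Say {word} and nothing else.") | ChatOpenAI(max_tokens=10)
print(chain.invoke({"word": "hello"}).content)  # full message in one call
for chunk in chain.stream({"word": "hello"}):   # token-by-token message chunks
    print(chunk.content, end="", flush=True)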

@ -93,7 +93,64 @@ def test_openai_streaming() -> None:
assert isinstance(generator, Generator)
for token in generator:
assert isinstance(token["choices"][0]["text"], str)
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_openai_astream() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
async for token in llm.astream("I'm Pickle Rick"):
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_openai_abatch() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_openai_abatch_tags() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = await llm.abatch(
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
)
for token in result:
assert isinstance(token, str)
def test_openai_batch() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for token in result:
assert isinstance(token, str)
@pytest.mark.asyncio
async def test_openai_ainvoke() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
assert isinstance(result, str)
def test_openai_invoke() -> None:
"""Test streaming tokens from OpenAI."""
llm = OpenAI(max_tokens=10)
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
assert isinstance(result, str)
def test_openai_multiple_prompts() -> None:
@ -105,13 +162,6 @@ def test_openai_multiple_prompts() -> None:
assert len(output.generations) == 2
def test_openai_streaming_error() -> None:
"""Test error handling in stream."""
llm = OpenAI(best_of=2)
with pytest.raises(ValueError):
llm.stream("I'm Pickle Rick")
def test_openai_streaming_best_of_error() -> None:
"""Test validation for streaming fails if best_of is not 1."""
with pytest.raises(ValueError):

@ -67,10 +67,3 @@ def test_promptlayer_openai_streaming() -> None:
for token in generator:
assert isinstance(token["choices"][0]["text"], str)
def test_promptlayer_openai_streaming_error() -> None:
"""Test error handling in stream."""
llm = PromptLayerOpenAI(best_of=2)
with pytest.raises(ValueError):
llm.stream("I'm Pickle Rick")

File diff suppressed because one or more lines are too long

@ -0,0 +1,38 @@
from langchain.schema.messages import AIMessageChunk, HumanMessageChunk
def test_message_chunks() -> None:
assert AIMessageChunk(content="I am") + AIMessageChunk(
content=" indeed."
) == AIMessageChunk(
content="I am indeed."
), "MessageChunk + MessageChunk should be a MessageChunk"
assert AIMessageChunk(content="I am") + HumanMessageChunk(
content=" indeed."
) == AIMessageChunk(
content="I am indeed."
), "MessageChunk + MessageChunk should be a MessageChunk of same class as the left side" # noqa: E501
assert AIMessageChunk(
content="", additional_kwargs={"foo": "bar"}
) + AIMessageChunk(content="", additional_kwargs={"baz": "foo"}) == AIMessageChunk(
content="", additional_kwargs={"foo": "bar", "baz": "foo"}
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501
assert AIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "web_search"}}
) + AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
) + AIMessageChunk(
content="",
additional_kwargs={"function_call": {"arguments": ' "query": "turtles"\n}'}},
) == AIMessageChunk(
content="",
additional_kwargs={
"function_call": {
"name": "web_search",
"arguments": '{\n "query": "turtles"\n}',
}
},
), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs" # noqa: E501

@ -0,0 +1,52 @@
from langchain.schema.messages import HumanMessageChunk
from langchain.schema.output import ChatGenerationChunk, GenerationChunk
def test_generation_chunk() -> None:
assert GenerationChunk(text="Hello, ") + GenerationChunk(
text="world!"
) == GenerationChunk(
text="Hello, world!"
), "GenerationChunk + GenerationChunk should be a GenerationChunk"
assert GenerationChunk(text="Hello, ") + GenerationChunk(
text="world!", generation_info={"foo": "bar"}
) == GenerationChunk(
text="Hello, world!", generation_info={"foo": "bar"}
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
assert GenerationChunk(text="Hello, ") + GenerationChunk(
text="world!", generation_info={"foo": "bar"}
) + GenerationChunk(text="!", generation_info={"baz": "foo"}) == GenerationChunk(
text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
def test_chat_generation_chunk() -> None:
assert ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, ")
) + ChatGenerationChunk(
message=HumanMessageChunk(content="world!")
) == ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, world!")
), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"
assert ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, ")
) + ChatGenerationChunk(
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
) == ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, world!"),
generation_info={"foo": "bar"},
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501
assert ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, ")
) + ChatGenerationChunk(
message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
) + ChatGenerationChunk(
message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
) == ChatGenerationChunk(
message=HumanMessageChunk(content="Hello, world!!"),
generation_info={"foo": "bar", "baz": "foo"},
), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info" # noqa: E501

@ -0,0 +1,547 @@
from typing import Any, Dict, List, Optional
from uuid import UUID
import pytest
from freezegun import freeze_time
from pytest_mock import MockerFixture
from syrupy import SnapshotAssertion
from langchain.callbacks.manager import Callbacks
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.chat_models.fake import FakeListChatModel
from langchain.llms.fake import FakeListLLM
from langchain.load.dump import dumps
from langchain.output_parsers.list import CommaSeparatedListOutputParser
from langchain.prompts.chat import (
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
)
from langchain.schema.document import Document
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnableMap,
RunnablePassthrough,
RunnableSequence,
)
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution."""
def __init__(self) -> None:
"""Initialize the tracer."""
super().__init__()
self.runs: List[Run] = []
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
self.runs.append(run)
class FakeRunnable(Runnable[str, int]):
def invoke(
self,
input: str,
config: Optional[RunnableConfig] = None,
) -> int:
return len(input)
class FakeRetriever(BaseRetriever):
def _get_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
return [Document(page_content="foo"), Document(page_content="bar")]
async def _aget_relevant_documents(
self,
query: str,
*,
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> List[Document]:
return [Document(page_content="foo"), Document(page_content="bar")]
@pytest.fixture()
def fixed_uuids(mocker: MockerFixture) -> MockerFixture._Patcher:
"""Note this mock only works with `import uuid; uuid.uuid4()`,
it does not work with `from uuid import uuid4; uuid4()`."""
# Disable tracing to avoid fixed UUIDs causing tracing errors.
mocker.patch.dict("os.environ", {"LANGCHAIN_TRACING_V2": "false"})
side_effect = (
UUID(f"00000000-0000-4000-8000-{i:012}", version=4) for i in range(10000)
)
return mocker.patch("uuid.uuid4", side_effect=side_effect)
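# Illustrative sketch (not part of the diff): the fixture's docstring holds because
# `mocker.patch("uuid.uuid4", ...)` replaces the attribute on the `uuid` module,
# while `from uuid import uuid4` binds the original function at import time.
def test_fixed_uuid_patch_sketch(mocker: MockerFixture) -> None:
    import uuid
    from uuid import uuid4 as real_uuid4  # early binding keeps the real function

    fixed = UUID("00000000-0000-4000-8000-000000000000", version=4)
    mocker.patch("uuid.uuid4", return_value=fixed)
    assert uuid.uuid4() == fixed  # module-attribute lookup sees the patch
    assert real_uuid4() != fixed  # the imported alias still produces real UUIDs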
@pytest.mark.asyncio
async def test_default_method_implementations(mocker: MockerFixture) -> None:
fake = FakeRunnable()
spy = mocker.spy(fake, "invoke")
assert fake.invoke("hello", dict(tags=["a-tag"])) == 5
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
]
spy.reset_mock()
assert [*fake.stream("hello", dict(metadata={"key": "value"}))] == [5]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"key": "value"})),
]
spy.reset_mock()
assert fake.batch(
["hello", "wooorld"], [dict(tags=["a-tag"]), dict(metadata={"key": "value"})]
) == [5, 7]
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
mocker.call("wooorld", dict(metadata={"key": "value"})),
]
spy.reset_mock()
assert fake.batch(["hello", "wooorld"], dict(tags=["a-tag"])) == [5, 7]
assert spy.call_args_list == [
mocker.call("hello", dict(tags=["a-tag"])),
mocker.call("wooorld", dict(tags=["a-tag"])),
]
spy.reset_mock()
assert await fake.ainvoke("hello", config={"callbacks": []}) == 5
assert spy.call_args_list == [
mocker.call("hello", dict(callbacks=[])),
]
spy.reset_mock()
assert [part async for part in fake.astream("hello")] == [5]
assert spy.call_args_list == [
mocker.call("hello", None),
]
spy.reset_mock()
assert await fake.abatch(["hello", "wooorld"], dict(metadata={"key": "value"})) == [
5,
7,
]
assert spy.call_args_list == [
mocker.call("hello", dict(metadata={"key": "value"})),
mocker.call("wooorld", dict(metadata={"key": "value"})),
]
@pytest.mark.asyncio
async def test_prompt() -> None:
prompt = ChatPromptTemplate.from_messages(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessagePromptTemplate.from_template("{question}"),
]
)
expected = ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert prompt.invoke({"question": "What is your name?"}) == expected
assert prompt.batch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
) == [
expected,
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert [*prompt.stream({"question": "What is your name?"})] == [expected]
assert await prompt.ainvoke({"question": "What is your name?"}) == expected
assert await prompt.abatch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
) == [
expected,
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert [
part async for part in prompt.astream({"question": "What is your name?"})
] == [expected]
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_chat_model(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["foo", "bar"])
chain = prompt | chat
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == []
assert chain.last == chat
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == AIMessage(content="foo")
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(chat_spy)
# Test batch
prompt_spy = mocker.spy(prompt.__class__, "batch")
chat_spy = mocker.spy(chat.__class__, "batch")
tracer = FakeTracer()
assert chain.batch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
) == [
AIMessage(content="bar"),
AIMessage(content="foo"),
]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
assert chat_spy.call_args.args[1] == [
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
),
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(chat_spy)
# Test stream
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "stream")
tracer = FakeTracer()
assert [
*chain.stream({"question": "What is your name?"}, dict(callbacks=[tracer]))
] == [AIMessage(content="bar")]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
@pytest.mark.asyncio
@freeze_time("2023-01-01")
async def test_prompt_with_llm(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
llm = FakeListLLM(responses=["foo", "bar"])
chain = prompt | llm
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == []
assert chain.last == llm
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
llm_spy = mocker.spy(llm.__class__, "ainvoke")
tracer = FakeTracer()
assert (
await chain.ainvoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
== "foo"
)
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(llm_spy)
# Test batch
prompt_spy = mocker.spy(prompt.__class__, "abatch")
llm_spy = mocker.spy(llm.__class__, "abatch")
tracer = FakeTracer()
assert await chain.abatch(
[
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
],
dict(callbacks=[tracer]),
) == ["bar", "foo"]
assert prompt_spy.call_args.args[1] == [
{"question": "What is your name?"},
{"question": "What is your favorite color?"},
]
assert llm_spy.call_args.args[1] == [
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
),
ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your favorite color?"),
]
),
]
assert tracer.runs == snapshot
mocker.stop(prompt_spy)
mocker.stop(llm_spy)
# Test stream
prompt_spy = mocker.spy(prompt.__class__, "ainvoke")
llm_spy = mocker.spy(llm.__class__, "astream")
tracer = FakeTracer()
assert [
token
async for token in chain.astream(
{"question": "What is your name?"}, dict(callbacks=[tracer])
)
] == ["bar"]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
@freeze_time("2023-01-01")
def test_prompt_with_chat_model_and_parser(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["foo, bar"])
parser = CommaSeparatedListOutputParser()
chain = prompt | chat | parser
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [chat]
assert chain.last == parser
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
parser_spy = mocker.spy(parser.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == ["foo", "bar"]
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
assert tracer.runs == snapshot
@freeze_time("2023-01-01")
def test_seq_dict_prompt_llm(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x)
retriever = FakeRetriever()
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ """Context:
{documents}
Question:
{question}"""
)
chat = FakeListChatModel(responses=["foo, bar"])
parser = CommaSeparatedListOutputParser()
chain = (
{
"question": RunnablePassthrough[str]() | passthrough,
"documents": passthrough | retriever,
"just_to_test_lambda": passthrough,
}
| prompt
| chat
| parser
)
assert isinstance(chain, RunnableSequence)
assert isinstance(chain.first, RunnableMap)
assert chain.middle == [prompt, chat]
assert chain.last == parser
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
parser_spy = mocker.spy(parser.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke("What is your name?", dict(callbacks=[tracer])) == [
"foo",
"bar",
]
assert prompt_spy.call_args.args[1] == {
"documents": [Document(page_content="foo"), Document(page_content="bar")],
"question": "What is your name?",
"just_to_test_lambda": "What is your name?",
}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(
content="""Context:
[Document(page_content='foo', metadata={}), Document(page_content='bar', metadata={})]
Question:
What is your name?"""
),
]
)
assert parser_spy.call_args.args[1] == AIMessage(content="foo, bar")
assert tracer.runs == snapshot
@freeze_time("2023-01-01")
def test_seq_prompt_dict(
mocker: MockerFixture, snapshot: SnapshotAssertion, fixed_uuids: None
) -> None:
passthrough = mocker.Mock(side_effect=lambda x: x)
prompt = (
SystemMessagePromptTemplate.from_template("You are a nice assistant.")
+ "{question}"
)
chat = FakeListChatModel(responses=["i'm a chatbot"])
llm = FakeListLLM(responses=["i'm a textbot"])
chain = (
prompt
| passthrough
| { # type: ignore
"chat": chat,
"llm": llm,
}
)
assert isinstance(chain, RunnableSequence)
assert chain.first == prompt
assert chain.middle == [RunnableLambda(passthrough)]
assert isinstance(chain.last, RunnableMap)
assert dumps(chain, pretty=True) == snapshot
# Test invoke
prompt_spy = mocker.spy(prompt.__class__, "invoke")
chat_spy = mocker.spy(chat.__class__, "invoke")
llm_spy = mocker.spy(llm.__class__, "invoke")
tracer = FakeTracer()
assert chain.invoke(
{"question": "What is your name?"}, dict(callbacks=[tracer])
) == {
"chat": AIMessage(content="i'm a chatbot"),
"llm": "i'm a textbot",
}
assert prompt_spy.call_args.args[1] == {"question": "What is your name?"}
assert chat_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert llm_spy.call_args.args[1] == ChatPromptValue(
messages=[
SystemMessage(content="You are a nice assistant."),
HumanMessage(content="What is your name?"),
]
)
assert tracer.runs == snapshot