|
|
|
@ -8,7 +8,11 @@ from langchain_core.callbacks import (
|
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
|
|
|
from langchain_core.language_models.chat_models import (
|
|
|
|
|
BaseChatModel,
|
|
|
|
|
agenerate_from_stream,
|
|
|
|
|
generate_from_stream,
|
|
|
|
|
)
|
|
|
|
|
from langchain_core.messages import (
|
|
|
|
|
AIMessage,
|
|
|
|
|
AIMessageChunk,
|
|
|
|
@ -174,6 +178,9 @@ class ChatAnthropic(BaseChatModel):
|
|
|
|
|
|
|
|
|
|
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
|
|
|
|
streaming: bool = False
|
|
|
|
|
"""Whether to use streaming or not."""
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def _llm_type(self) -> str:
|
|
|
|
|
"""Return type of chat model."""
|
|
|
|
@ -271,6 +278,11 @@ class ChatAnthropic(BaseChatModel):
|
|
|
|
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
if self.streaming:
|
|
|
|
|
stream_iter = self._stream(
|
|
|
|
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
|
|
|
)
|
|
|
|
|
return generate_from_stream(stream_iter)
|
|
|
|
|
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
|
|
|
|
data = self._client.messages.create(**params)
|
|
|
|
|
return self._format_output(data)
|
|
|
|
@ -282,6 +294,11 @@ class ChatAnthropic(BaseChatModel):
|
|
|
|
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
|
|
|
|
**kwargs: Any,
|
|
|
|
|
) -> ChatResult:
|
|
|
|
|
if self.streaming:
|
|
|
|
|
stream_iter = self._astream(
|
|
|
|
|
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
|
|
|
)
|
|
|
|
|
return await agenerate_from_stream(stream_iter)
|
|
|
|
|
params = self._format_params(messages=messages, stop=stop, **kwargs)
|
|
|
|
|
data = await self._async_client.messages.create(**params)
|
|
|
|
|
return self._format_output(data)
|
|
|
|
|