You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
langchain/libs/partners/anthropic/langchain_anthropic/chat_models.py

188 lines
6.3 KiB
Python

import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple
import anthropic
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
# Maps LangChain message ``type`` values to Anthropic Messages API roles.
# LangChain AI messages have type "ai"; Anthropic expects the role "assistant".
_message_type_lookups = {"human": "user", "ai": "assistant"}
def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
    """Format LangChain messages for the Anthropic Messages API.

    Args:
        messages: Conversation to send. A system message is allowed only as
            the first element; every other message must have a type listed in
            ``_message_type_lookups``.

    Returns:
        A ``(system, formatted_messages)`` tuple. ``system`` is the system
        message's text, or ``None`` if there is none. ``formatted_messages`` is
        a list of ``{"role": ..., "content": ...}`` dicts for the remaining
        messages, in order.

    Raises:
        ValueError: If a message has non-string content, a system message is
            not first, or a message type has no corresponding Anthropic role.
    """
    system: Optional[str] = None
    formatted_messages: List[Dict] = []
    for i, message in enumerate(messages):
        if not isinstance(message.content, str):
            raise ValueError("Anthropic Messages API only supports text generation.")
        if message.type == "system":
            # The API takes the system prompt as a separate parameter, so it is
            # only accepted at the start of the conversation.
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            system = message.content
            continue
        role = _message_type_lookups.get(message.type)
        if role is None:
            # Fail with a descriptive error instead of an opaque KeyError.
            raise ValueError(
                f"Unsupported message type for Anthropic Messages API: "
                f"{message.type!r}."
            )
        formatted_messages.append({"role": role, "content": message.content})
    return system, formatted_messages
class ChatAnthropicMessages(BaseChatModel):
    """Beta ChatAnthropicMessages chat model.

    Sends conversations to the Anthropic Messages API through the SDK's
    ``beta.messages`` endpoints, with sync/async generation and streaming.

    Example:
        .. code-block:: python

            from langchain_anthropic import ChatAnthropicMessages

            model = ChatAnthropicMessages(model_name="claude-2.1")
    """

    class Config:
        """Configuration for this pydantic object."""

        # Accept ``model=...`` in addition to the ``model_name=...`` alias.
        allow_population_by_field_name = True

    # Leading-underscore attributes are not pydantic fields; both clients are
    # built in ``validate_environment`` from the resolved key, URL and timeout.
    _client: anthropic.Client = Field(default=None)
    _async_client: anthropic.AsyncClient = Field(default=None)

    model: str = Field(alias="model_name")
    """Model name to use."""

    max_tokens: int = Field(default=256)
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = None
    """A non-negative float that tunes the degree of randomness in generation."""

    top_k: Optional[int] = None
    """Number of most likely tokens to consider at each step."""

    top_p: Optional[float] = None
    """Total probability mass of tokens to consider at each step."""

    default_request_timeout: Optional[float] = None
    """Timeout in seconds for requests to the Anthropic API.

    If None, the Anthropic SDK's own default is used (600 seconds)."""

    anthropic_api_url: str = "https://api.anthropic.com"
    """Base URL the Anthropic clients send requests to."""

    anthropic_api_key: Optional[SecretStr] = None
    """Anthropic API key; falls back to the ``ANTHROPIC_API_KEY`` env var."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Extra parameters sent with every API call (per-call kwargs take precedence)."""

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "chat-anthropic-messages"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and build the sync and async Anthropic clients.

        The key comes from ``anthropic_api_key`` if set, otherwise from the
        ``ANTHROPIC_API_KEY`` environment variable (empty string if neither).
        """
        anthropic_api_key = convert_to_secret_str(
            values.get("anthropic_api_key") or os.environ.get("ANTHROPIC_API_KEY") or ""
        )
        values["anthropic_api_key"] = anthropic_api_key
        client_params: Dict[str, Any] = {
            "api_key": anthropic_api_key.get_secret_value(),
            # Honour the configured URL; previously it was silently ignored.
            "base_url": values.get("anthropic_api_url"),
        }
        # Forward the timeout only when one is configured so that, otherwise,
        # the SDK keeps its own default rather than receiving an explicit None.
        timeout = values.get("default_request_timeout")
        if timeout is not None:
            client_params["timeout"] = timeout
        values["_client"] = anthropic.Client(**client_params)
        values["_async_client"] = anthropic.AsyncClient(**client_params)
        return values

    def _format_params(
        self,
        *,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict:
        """Build the keyword arguments for a Messages API call.

        Per-call ``kwargs`` override ``model_kwargs``, which override the
        model's explicit settings. Entries whose value is ``None`` are dropped
        so they are not sent to the API.
        """
        # Split off the system prompt, if any.
        system, formatted_messages = _format_messages(messages)
        params: Dict[str, Any] = {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "messages": formatted_messages,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "stop_sequences": stop,
            "system": system,
            **self.model_kwargs,
            **kwargs,
        }
        return {k: v for k, v in params.items() if v is not None}

    @staticmethod
    def _format_output(data: Any) -> ChatResult:
        """Convert a Messages API response into a ``ChatResult``."""
        # Join every text block instead of indexing ``content[0]``, which
        # raises IndexError when the response contains no content blocks.
        text = "".join(getattr(block, "text", "") for block in data.content)
        return ChatResult(
            generations=[ChatGeneration(message=AIMessage(content=text))],
            # ``llm_output`` is a dict; dump the SDK model instead of passing it raw.
            llm_output=data.model_dump(),
        )

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream text chunks from the API, notifying callbacks per token."""
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        with self._client.beta.messages.stream(**params) as stream:
            for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async variant of ``_stream``."""
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        async with self._async_client.beta.messages.stream(**params) as stream:
            async for text in stream.text_stream:
                chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
                if run_manager:
                    await run_manager.on_llm_new_token(text, chunk=chunk)
                yield chunk

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Make a single (non-streaming) Messages API call."""
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        data = self._client.beta.messages.create(**params)
        return self._format_output(data)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async variant of ``_generate``."""
        params = self._format_params(messages=messages, stop=stop, **kwargs)
        data = await self._async_client.beta.messages.create(**params)
        return self._format_output(data)