Add Anthropic ChatModel to langchain (#2293)

* Adds an Anthropic ChatModel
* Factors out common code in our LLMModel and ChatModel
* Supports streaming llm-tokens to the callbacks on a delta basis (until
a future V2 API does that for us); see the usage sketch below this list
* Some fixes
Mike Lambert 1 year ago committed by GitHub
parent 66bef1d7ed
commit 392f1b3218
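For orientation, a minimal usage sketch of the new ChatModel with delta-based streaming. The prompt and the StreamingStdOutCallbackHandler choice are illustrative assumptions, not part of this diff; it also assumes the anthropic package is installed and ANTHROPIC_API_KEY is set.

    # Minimal streaming sketch; any handler implementing on_llm_new_token would do.
    from langchain.callbacks.base import CallbackManager
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.chat_models.anthropic import ChatAnthropic
    from langchain.schema import HumanMessage

    chat = ChatAnthropic(
        streaming=True,
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
        verbose=True,
    )
    # Each streamed delta is forwarded to on_llm_new_token as it arrives.
    response = chat([HumanMessage(content="Write me a sentence about sparrows.")])
    print(response.content)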

@@ -0,0 +1,145 @@
from typing import List, Optional

from pydantic import Extra

from langchain.chat_models.base import BaseChatModel
from langchain.llms.anthropic import _AnthropicCommon
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    ChatMessage,
    ChatResult,
    HumanMessage,
    SystemMessage,
)


class ChatAnthropic(BaseChatModel, _AnthropicCommon):
    r"""Wrapper around Anthropic's large language model.

    To use, you should have the ``anthropic`` python package installed, and the
    environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain.chat_models.anthropic import ChatAnthropic
            from langchain.schema import HumanMessage

            model = ChatAnthropic(model="<model_name>", anthropic_api_key="my-api-key")
            response = model([HumanMessage(content="What are the biggest risks facing humanity?")])
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "anthropic-chat"

    def _convert_one_message_to_text(self, message: BaseMessage) -> str:
        if isinstance(message, ChatMessage):
            message_text = f"\n\n{message.role.capitalize()}: {message.content}"
        elif isinstance(message, HumanMessage):
            message_text = f"{self.HUMAN_PROMPT} {message.content}"
        elif isinstance(message, AIMessage):
            message_text = f"{self.AI_PROMPT} {message.content}"
        elif isinstance(message, SystemMessage):
            message_text = f"{self.HUMAN_PROMPT} <admin>{message.content}</admin>"
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_text

    def _convert_messages_to_text(self, messages: List[BaseMessage]) -> str:
        """Format a list of messages into a single string with necessary newlines.

        Args:
            messages (List[BaseMessage]): List of BaseMessage to combine.

        Returns:
            str: Combined string with necessary newlines.
        """
        return "".join(
            self._convert_one_message_to_text(message) for message in messages
        )

    def _convert_messages_to_prompt(self, messages: List[BaseMessage]) -> str:
        """Format a list of messages into a full prompt for the Anthropic model.

        Args:
            messages (List[BaseMessage]): List of BaseMessage to combine.

        Returns:
            str: Combined string with necessary HUMAN_PROMPT and AI_PROMPT tags.
        """
        if not self.AI_PROMPT:
            raise NameError("Please ensure the anthropic package is loaded")
        if not isinstance(messages[-1], AIMessage):
            # Copy rather than append so the caller's message list is not mutated.
            messages = messages + [AIMessage(content="")]
        text = self._convert_messages_to_text(messages)
        # Trim off the trailing ' ' that might come from the "Assistant: " tag.
        return text.rstrip()

    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        prompt = self._convert_messages_to_prompt(messages)
        params = {"prompt": prompt, "stop_sequences": stop, **self._default_params}

        if self.streaming:
            completion = ""
            stream_resp = self.client.completion_stream(**params)
            for data in stream_resp:
                # The stream returns the full completion so far; emit only the
                # new suffix (the delta) to the callbacks.
                delta = data["completion"][len(completion) :]
                completion = data["completion"]
                self.callback_manager.on_llm_new_token(
                    delta,
                    verbose=self.verbose,
                )
        else:
            response = self.client.completion(**params)
            completion = response["completion"]
        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        prompt = self._convert_messages_to_prompt(messages)
        params = {"prompt": prompt, "stop_sequences": stop, **self._default_params}

        if self.streaming:
            completion = ""
            stream_resp = await self.client.acompletion_stream(**params)
            async for data in stream_resp:
                delta = data["completion"][len(completion) :]
                completion = data["completion"]
                if self.callback_manager.is_async:
                    await self.callback_manager.on_llm_new_token(
                        delta,
                        verbose=self.verbose,
                    )
                else:
                    self.callback_manager.on_llm_new_token(
                        delta,
                        verbose=self.verbose,
                    )
        else:
            response = await self.client.acompletion(**params)
            completion = response["completion"]
        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])
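To make the conversion rules above concrete, here is what the flattened prompt looks like for two message lists. The behavior follows from _convert_one_message_to_text and matches the tests later in this diff; constructing the model still assumes the anthropic package and an API key in the environment.

    from langchain.chat_models.anthropic import ChatAnthropic
    from langchain.schema import HumanMessage, SystemMessage

    chat = ChatAnthropic()

    # A lone human turn; an empty Assistant turn is appended so the prompt
    # ends with the AI_PROMPT tag for the model to complete.
    chat._convert_messages_to_prompt([HumanMessage(content="Hello")])
    # -> '\n\nHuman: Hello\n\nAssistant:'

    # SystemMessages are folded into a Human turn wrapped in <admin> tags.
    chat._convert_messages_to_prompt(
        [SystemMessage(content="Be terse."), HumanMessage(content="Hello")]
    )
    # -> '\n\nHuman: <admin>Be terse.</admin>\n\nHuman: Hello\n\nAssistant:'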

@@ -1,51 +1,28 @@
"""Wrapper around Anthropic APIs.""" """Wrapper around Anthropic APIs."""
import re import re
from typing import Any, Dict, Generator, List, Mapping, Optional from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
from pydantic import Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class Anthropic(LLM): class _AnthropicCommon(BaseModel):
r"""Wrapper around Anthropic large language models. client: Any = None #: :meta private:
model: str = "claude-latest"
To use, you should have the ``anthropic`` python package installed, and the
environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
import anthropic
from langchain.llms import Anthropic
model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
# Simplest invocation, automatically wrapped with HUMAN_PROMPT
# and AI_PROMPT.
response = model("What are the biggest risks facing humanity?")
# Or if you want to use the chat mode, build a few-shot-prompt, or
# put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT:
raw_prompt = "What are the biggest risks facing humanity?"
prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
response = model(prompt)
"""
client: Any #: :meta private:
model: str = "claude-v1"
"""Model name to use.""" """Model name to use."""
max_tokens_to_sample: int = 256 max_tokens_to_sample: int = 256
"""Denotes the number of tokens to predict per generation.""" """Denotes the number of tokens to predict per generation."""
temperature: float = 1.0 temperature: Optional[float] = None
"""A non-negative float that tunes the degree of randomness in generation.""" """A non-negative float that tunes the degree of randomness in generation."""
top_k: int = 0 top_k: Optional[int] = None
"""Number of most likely tokens to consider at each step.""" """Number of most likely tokens to consider at each step."""
top_p: float = 1 top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step.""" """Total probability mass of tokens to consider at each step."""
streaming: bool = False streaming: bool = False
@@ -55,11 +32,7 @@ class Anthropic(LLM):
     HUMAN_PROMPT: Optional[str] = None
     AI_PROMPT: Optional[str] = None
+    count_tokens: Optional[Callable[[str], int]] = None

-    class Config:
-        """Configuration for this pydantic object."""
-        extra = Extra.forbid

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -73,32 +46,86 @@ class Anthropic(LLM):
values["client"] = anthropic.Client(anthropic_api_key) values["client"] = anthropic.Client(anthropic_api_key)
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = anthropic.count_tokens
except ImportError: except ImportError:
raise ValueError( raise ValueError(
"Could not import anthropic python package. " "Could not import anthropic python package. "
"Please install it with `pip install anthropic`." "Please it install it with `pip install anthropic`."
) )
return values return values
@property @property
def _default_params(self) -> Mapping[str, Any]: def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling Anthropic API.""" """Get the default parameters for calling Anthropic API."""
return { d = {
"max_tokens_to_sample": self.max_tokens_to_sample, "max_tokens_to_sample": self.max_tokens_to_sample,
"temperature": self.temperature, "model": self.model,
"top_k": self.top_k,
"top_p": self.top_p,
} }
if self.temperature is not None:
d["temperature"] = self.temperature
if self.top_k is not None:
d["top_k"] = self.top_k
if self.top_p is not None:
d["top_p"] = self.top_p
return d
@property @property
def _identifying_params(self) -> Mapping[str, Any]: def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters.""" """Get the identifying parameters."""
return {**{"model": self.model}, **self._default_params} return {**{}, **self._default_params}
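Worth noting: sampling knobs left at None are now omitted from the request instead of being sent as hard-coded defaults. A quick illustration with hypothetical values (assumes an API key in the environment):

    llm = Anthropic(temperature=0.7)  # top_k / top_p left unset
    llm._default_params
    # -> {'max_tokens_to_sample': 256, 'model': 'claude-latest', 'temperature': 0.7}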
+    def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
+        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
+            raise NameError("Please ensure the anthropic package is loaded")
+        if stop is None:
+            stop = []
+        # Never want model to invent new turns of Human / Assistant dialog.
+        stop.extend([self.HUMAN_PROMPT])
+        return stop
+
+    def get_num_tokens(self, text: str) -> int:
+        """Calculate number of tokens."""
+        if not self.count_tokens:
+            raise NameError("Please ensure the anthropic package is loaded")
+        return self.count_tokens(text)
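The two helpers above are now shared by both the LLM and chat wrappers; a brief sketch of what they do (same environment assumptions as above):

    llm = Anthropic()

    # HUMAN_PROMPT is always appended as a stop sequence so the model never
    # invents a new "\n\nHuman:" turn of dialog on its own.
    llm._get_anthropic_stop(["STOP"])
    # -> ['STOP', '\n\nHuman:']

    # Token counting now delegates to anthropic.count_tokens.
    llm.get_num_tokens("What are the biggest risks facing humanity?")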
+class Anthropic(LLM, _AnthropicCommon):
+    r"""Wrapper around Anthropic's large language models.
+
+    To use, you should have the ``anthropic`` python package installed, and the
+    environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
+    it as a named parameter to the constructor.
+
+    Example:
+        .. code-block:: python
+
+            import anthropic
+            from langchain.llms import Anthropic
+            model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
+
+            # Simplest invocation, automatically wrapped with HUMAN_PROMPT
+            # and AI_PROMPT.
+            response = model("What are the biggest risks facing humanity?")
+
+            # Or if you want to use the chat mode, build a few-shot-prompt, or
+            # put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT:
+            raw_prompt = "What are the biggest risks facing humanity?"
+            prompt = f"{anthropic.HUMAN_PROMPT} {raw_prompt}{anthropic.AI_PROMPT}"
+            response = model(prompt)
+    """
+
+    class Config:
+        """Configuration for this pydantic object."""
+        extra = Extra.forbid

     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
-        return "anthropic"
+        return "anthropic-llm"

     def _wrap_prompt(self, prompt: str) -> str:
         if not self.HUMAN_PROMPT or not self.AI_PROMPT:
@@ -115,18 +142,6 @@ class Anthropic(LLM):
         # As a last resort, wrap the prompt ourselves to emulate instruct-style.
         return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"

-    def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
-        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
-            raise NameError("Please ensure the anthropic package is loaded")
-        if stop is None:
-            stop = []
-        # Never want model to invent new turns of Human / Assistant dialog.
-        stop.extend([self.HUMAN_PROMPT, self.AI_PROMPT])
-        return stop

     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
         r"""Call out to Anthropic's completion endpoint.
@@ -148,10 +163,8 @@ class Anthropic(LLM):
         stop = self._get_anthropic_stop(stop)
         if self.streaming:
             stream_resp = self.client.completion_stream(
-                model=self.model,
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                stream=True,
                 **self._default_params,
             )
             current_completion = ""
@@ -163,7 +176,6 @@ class Anthropic(LLM):
                 )
             return current_completion
         response = self.client.completion(
-            model=self.model,
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
             **self._default_params,
@@ -175,10 +187,8 @@ class Anthropic(LLM):
         stop = self._get_anthropic_stop(stop)
         if self.streaming:
             stream_resp = await self.client.acompletion_stream(
-                model=self.model,
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                stream=True,
                 **self._default_params,
             )
             current_completion = ""
@@ -195,7 +205,6 @@ class Anthropic(LLM):
                 )
             return current_completion
         response = await self.client.acompletion(
-            model=self.model,
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
             **self._default_params,
@@ -227,7 +236,6 @@ class Anthropic(LLM):
""" """
stop = self._get_anthropic_stop(stop) stop = self._get_anthropic_stop(stop)
return self.client.completion_stream( return self.client.completion_stream(
model=self.model,
prompt=self._wrap_prompt(prompt), prompt=self._wrap_prompt(prompt),
stop_sequences=stop, stop_sequences=stop,
**self._default_params, **self._default_params,
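Since model now rides along in _default_params (and the client's *_stream methods imply streaming), the explicit model=self.model and stream=True arguments could be dropped from every call site above. A small sanity check (hypothetical; assumes an API key):

    llm = Anthropic(model="claude-v1")
    assert llm._identifying_params["model"] == "claude-v1"
    assert "model" in llm._default_params  # forwarded via **self._default_params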

@@ -0,0 +1,81 @@
"""Test Anthropic API wrapper."""
from typing import List
import pytest
from langchain.callbacks.base import CallbackManager
from langchain.chat_models.anthropic import ChatAnthropic
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
chat = ChatAnthropic(model="bare-nano-0")
message = HumanMessage(content="Hello")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_anthropic_streaming() -> None:
"""Test streaming tokens from anthropic."""
chat = ChatAnthropic(model="bare-nano-0", streaming=True)
message = HumanMessage(content="Hello")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_anthropic_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
streaming=True,
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Write me a sentence with 100 words.")
chat([message])
assert callback_handler.llm_streams > 1
@pytest.mark.asyncio
async def test_anthropic_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
streaming=True,
callback_manager=callback_manager,
verbose=True,
)
chat_messages: List[BaseMessage] = [
HumanMessage(content="How many toes do dogs have?")
]
result: LLMResult = await chat.agenerate([chat_messages])
assert callback_handler.llm_streams > 1
assert isinstance(result, LLMResult)
for response in result.generations[0]:
assert isinstance(response, ChatGeneration)
assert isinstance(response.text, str)
assert response.text == response.message.content
def test_formatting() -> None:
chat = ChatAnthropic()
chat_messages: List[BaseMessage] = [HumanMessage(content="Hello")]
result = chat._convert_messages_to_prompt(chat_messages)
assert result == "\n\nHuman: Hello\n\nAssistant:"
chat_messages = [HumanMessage(content="Hello"), AIMessage(content="Answer:")]
result = chat._convert_messages_to_prompt(chat_messages)
assert result == "\n\nHuman: Hello\n\nAssistant: Answer:"

@@ -32,7 +32,6 @@ def test_anthropic_streaming_callback() -> None:
     callback_handler = FakeCallbackHandler()
     callback_manager = CallbackManager([callback_handler])
     llm = Anthropic(
-        model="claude-v1",
         streaming=True,
         callback_manager=callback_manager,
         verbose=True,
@@ -55,7 +54,6 @@ async def test_anthropic_async_streaming_callback() -> None:
     callback_handler = FakeCallbackHandler()
     callback_manager = CallbackManager([callback_handler])
     llm = Anthropic(
-        model="claude-v1",
         streaming=True,
         callback_manager=callback_manager,
         verbose=True,
