anthropic[minor]: package move (#17974)

This commit is contained in:
Erick Friis 2024-02-25 21:57:26 -08:00 committed by GitHub
parent a2d5fa7649
commit 3b5bdbfee8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 886 additions and 57 deletions

View File

@ -1,3 +1,4 @@
from langchain_anthropic.chat_models import ChatAnthropicMessages
from langchain_anthropic.chat_models import ChatAnthropic, ChatAnthropicMessages
from langchain_anthropic.llms import Anthropic, AnthropicLLM
__all__ = ["ChatAnthropicMessages"]
__all__ = ["ChatAnthropicMessages", "ChatAnthropic", "Anthropic", "AnthropicLLM"]

View File

@ -2,6 +2,7 @@ import os
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple
import anthropic
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@ -14,7 +15,11 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils import (
build_extra_kwargs,
convert_to_secret_str,
get_pydantic_field_names,
)
_message_type_lookups = {"human": "user", "ai": "assistant"}
@ -50,7 +55,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
return system, formatted_messages
class ChatAnthropicMessages(BaseChatModel):
class ChatAnthropic(BaseChatModel):
"""ChatAnthropicMessages chat model.
Example:
@ -61,13 +66,18 @@ class ChatAnthropicMessages(BaseChatModel):
model = ChatAnthropicMessages()
"""
_client: anthropic.Client = Field(default_factory=anthropic.Client)
_async_client: anthropic.AsyncClient = Field(default_factory=anthropic.AsyncClient)
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
_client: anthropic.Client = Field(default=None)
_async_client: anthropic.AsyncClient = Field(default=None)
model: str = Field(alias="model_name")
"""Model name to use."""
max_tokens: int = Field(default=256)
max_tokens: int = Field(default=256, alias="max_tokens_to_sample")
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None
@ -88,16 +98,20 @@ class ChatAnthropicMessages(BaseChatModel):
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "chat-anthropic-messages"
@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
anthropic_api_key = convert_to_secret_str(
@ -130,6 +144,7 @@ class ChatAnthropicMessages(BaseChatModel):
"top_p": self.top_p,
"stop_sequences": stop,
"system": system,
**self.model_kwargs,
}
rtn = {k: v for k, v in rtn.items() if v is not None}
@ -145,7 +160,10 @@ class ChatAnthropicMessages(BaseChatModel):
params = self._format_params(messages=messages, stop=stop, **kwargs)
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
async def _astream(
self,
@ -157,7 +175,10 @@ class ChatAnthropicMessages(BaseChatModel):
params = self._format_params(messages=messages, stop=stop, **kwargs)
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
await run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk
def _generate(
self,
@ -190,3 +211,8 @@ class ChatAnthropicMessages(BaseChatModel):
],
llm_output=data,
)
@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
pass

View File

@ -0,0 +1,352 @@
import re
import warnings
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
)
import anthropic
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
class _AnthropicCommon(BaseLanguageModel):
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model: str = Field(default="claude-2", alias="model_name")
"""Model name to use."""
max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
"""Denotes the number of tokens to predict per generation."""
temperature: Optional[float] = None
"""A non-negative float that tunes the degree of randomness in generation."""
top_k: Optional[int] = None
"""Number of most likely tokens to consider at each step."""
top_p: Optional[float] = None
"""Total probability mass of tokens to consider at each step."""
streaming: bool = False
"""Whether to stream the results."""
default_request_timeout: Optional[float] = None
"""Timeout for requests to Anthropic Completion API. Default is 600 seconds."""
max_retries: int = 2
"""Number of retries allowed for requests sent to the Anthropic Completion API."""
anthropic_api_url: Optional[str] = None
anthropic_api_key: Optional[SecretStr] = None
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["anthropic_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
)
# Get custom api url from environment.
values["anthropic_api_url"] = get_from_dict_or_env(
values,
"anthropic_api_url",
"ANTHROPIC_API_URL",
default="https://api.anthropic.com",
)
values["client"] = anthropic.Anthropic(
base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"].get_secret_value(),
timeout=values["default_request_timeout"],
max_retries=values["max_retries"],
)
values["async_client"] = anthropic.AsyncAnthropic(
base_url=values["anthropic_api_url"],
api_key=values["anthropic_api_key"].get_secret_value(),
timeout=values["default_request_timeout"],
max_retries=values["max_retries"],
)
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = values["client"].count_tokens
return values
@property
def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling Anthropic API."""
d = {
"max_tokens_to_sample": self.max_tokens_to_sample,
"model": self.model,
}
if self.temperature is not None:
d["temperature"] = self.temperature
if self.top_k is not None:
d["top_k"] = self.top_k
if self.top_p is not None:
d["top_p"] = self.top_p
return {**d, **self.model_kwargs}
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {**{}, **self._default_params}
def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded")
if stop is None:
stop = []
# Never want model to invent new turns of Human / Assistant dialog.
stop.extend([self.HUMAN_PROMPT])
return stop
class AnthropicLLM(LLM, _AnthropicCommon):
"""Anthropic large language models.
To use, you should have the ``anthropic`` python package installed, and the
environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
import anthropic
from langchain_community.llms import Anthropic
model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")
# Simplest invocation, automatically wrapped with HUMAN_PROMPT
# and AI_PROMPT.
response = model("What are the biggest risks facing humanity?")
# Or if you want to use the chat mode, build a few-shot-prompt, or
# put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT:
raw_prompt = "What are the biggest risks facing humanity?"
prompt = f"{anthropic.HUMAN_PROMPT} {prompt}{anthropic.AI_PROMPT}"
response = model(prompt)
"""
class Config:
"""Configuration for this pydantic object."""
allow_population_by_field_name = True
arbitrary_types_allowed = True
@root_validator()
def raise_warning(cls, values: Dict) -> Dict:
"""Raise warning that this class is deprecated."""
warnings.warn(
"This Anthropic LLM is deprecated. "
"Please use `from langchain_community.chat_models import ChatAnthropic` "
"instead"
)
return values
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "anthropic-llm"
def _wrap_prompt(self, prompt: str) -> str:
if not self.HUMAN_PROMPT or not self.AI_PROMPT:
raise NameError("Please ensure the anthropic package is loaded")
if prompt.startswith(self.HUMAN_PROMPT):
return prompt # Already wrapped.
# Guard against common errors in specifying wrong number of newlines.
corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt)
if n_subs == 1:
return corrected_prompt
# As a last resort, wrap the prompt ourselves to emulate instruct-style.
return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to Anthropic's completion endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
prompt = "What are the biggest risks facing humanity?"
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
response = model(prompt)
"""
if self.streaming:
completion = ""
for chunk in self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
response = self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
**params,
)
return response.completion
def convert_prompt(self, prompt: PromptValue) -> str:
return self._wrap_prompt(prompt.to_string())
async def _acall(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Anthropic's completion endpoint asynchronously."""
if self.streaming:
completion = ""
async for chunk in self._astream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
):
completion += chunk.text
return completion
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
response = await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
**params,
)
return response.completion
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
r"""Call Anthropic completion_stream and return the resulting generator.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Anthropic.
Example:
.. code-block:: python
prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = anthropic.stream(prompt)
for token in generator:
yield token
"""
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
for token in self.client.completions.create(
prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
):
chunk = GenerationChunk(text=token.completion)
yield chunk
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
r"""Call Anthropic completion_stream and return the resulting generator.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens from Anthropic.
Example:
.. code-block:: python
prompt = "Write a poem about a stream."
prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
generator = anthropic.stream(prompt)
for token in generator:
yield token
"""
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
async for token in await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
):
chunk = GenerationChunk(text=token.completion)
yield chunk
if run_manager:
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)
@deprecated(since="0.1.0", removal="0.2.0", alternative="AnthropicLLM")
class Anthropic(AnthropicLLM):
pass

View File

@ -40,13 +40,13 @@ vertex = ["google-auth (>=2,<3)"]
[[package]]
name = "anyio"
version = "4.2.0"
version = "4.3.0"
description = "High level compatibility layer for multiple asynchronous event loop implementations"
optional = false
python-versions = ">=3.8"
files = [
{file = "anyio-4.2.0-py3-none-any.whl", hash = "sha256:745843b39e829e108e518c489b31dc757de7d2131d53fac32bd8df268227bfee"},
{file = "anyio-4.2.0.tar.gz", hash = "sha256:e1875bb4b4e2de1669f4bc7869b6d3f54231cdced71605e6e64c9be77e3be50f"},
{file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"},
{file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"},
]
[package.dependencies]
@ -301,13 +301,13 @@ files = [
[[package]]
name = "httpcore"
version = "1.0.3"
version = "1.0.4"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpcore-1.0.3-py3-none-any.whl", hash = "sha256:9a6a501c3099307d9fd76ac244e08503427679b1e81ceb1d922485e2f2462ad2"},
{file = "httpcore-1.0.3.tar.gz", hash = "sha256:5c0f9546ad17dac4d0772b0808856eb616eb8b48ce94f49ed819fd6982a8a544"},
{file = "httpcore-1.0.4-py3-none-any.whl", hash = "sha256:ac418c1db41bade2ad53ae2f3834a3a0f5ae76b56cf5aa497d2d033384fc7d73"},
{file = "httpcore-1.0.4.tar.gz", hash = "sha256:cb2839ccfcba0d2d3c1131d3c3e26dfc327326fbe7a5dc0dbfe9f6c9151bb022"},
]
[package.dependencies]
@ -318,17 +318,17 @@ h11 = ">=0.13,<0.15"
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.24.0)"]
trio = ["trio (>=0.22.0,<0.25.0)"]
[[package]]
name = "httpx"
version = "0.26.0"
version = "0.27.0"
description = "The next generation HTTP client."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpx-0.26.0-py3-none-any.whl", hash = "sha256:8915f5a3627c4d47b73e8202457cb28f1266982d1159bd5779d86a80c0eab1cd"},
{file = "httpx-0.26.0.tar.gz", hash = "sha256:451b55c30d5185ea6b23c2c793abf9bb237d2a7dfb901ced6ff69ad37ec1dfaf"},
{file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"},
{file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"},
]
[package.dependencies]
@ -425,7 +425,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.23"
version = "0.1.25"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -435,7 +435,7 @@ develop = true
[package.dependencies]
anyio = ">=3,<5"
jsonpatch = "^1.33"
langsmith = "^0.0.87"
langsmith = "^0.1.0"
packaging = "^23.2"
pydantic = ">=1,<3"
PyYAML = ">=5.3"
@ -451,13 +451,13 @@ url = "../../core"
[[package]]
name = "langsmith"
version = "0.0.87"
version = "0.1.5"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langsmith-0.0.87-py3-none-any.whl", hash = "sha256:8903d3811b9fc89eb18f5961c8e6935fbd2d0f119884fbf30dc70b8f8f4121fc"},
{file = "langsmith-0.0.87.tar.gz", hash = "sha256:36c4cc47e5b54be57d038036a30fb19ce6e4c73048cd7a464b8f25b459694d34"},
{file = "langsmith-0.1.5-py3-none-any.whl", hash = "sha256:a1811821a923d90e53bcbacdd0988c3c366aff8f4c120d8777e7af8ecda06268"},
{file = "langsmith-0.1.5.tar.gz", hash = "sha256:aa7a2861aa3d9ae563a077c622953533800466c4e2e539b0d567b84d5fd5b157"},
]
[package.dependencies]
@ -830,28 +830,28 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "ruff"
version = "0.1.15"
version = "0.2.2"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.1.15-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5fe8d54df166ecc24106db7dd6a68d44852d14eb0729ea4672bb4d96c320b7df"},
{file = "ruff-0.1.15-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6f0bfbb53c4b4de117ac4d6ddfd33aa5fc31beeaa21d23c45c6dd249faf9126f"},
{file = "ruff-0.1.15-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0d432aec35bfc0d800d4f70eba26e23a352386be3a6cf157083d18f6f5881c8"},
{file = "ruff-0.1.15-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9405fa9ac0e97f35aaddf185a1be194a589424b8713e3b97b762336ec79ff807"},
{file = "ruff-0.1.15-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c66ec24fe36841636e814b8f90f572a8c0cb0e54d8b5c2d0e300d28a0d7bffec"},
{file = "ruff-0.1.15-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6f8ad828f01e8dd32cc58bc28375150171d198491fc901f6f98d2a39ba8e3ff5"},
{file = "ruff-0.1.15-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86811954eec63e9ea162af0ffa9f8d09088bab51b7438e8b6488b9401863c25e"},
{file = "ruff-0.1.15-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd4025ac5e87d9b80e1f300207eb2fd099ff8200fa2320d7dc066a3f4622dc6b"},
{file = "ruff-0.1.15-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b17b93c02cdb6aeb696effecea1095ac93f3884a49a554a9afa76bb125c114c1"},
{file = "ruff-0.1.15-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ddb87643be40f034e97e97f5bc2ef7ce39de20e34608f3f829db727a93fb82c5"},
{file = "ruff-0.1.15-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:abf4822129ed3a5ce54383d5f0e964e7fef74a41e48eb1dfad404151efc130a2"},
{file = "ruff-0.1.15-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6c629cf64bacfd136c07c78ac10a54578ec9d1bd2a9d395efbee0935868bf852"},
{file = "ruff-0.1.15-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1bab866aafb53da39c2cadfb8e1c4550ac5340bb40300083eb8967ba25481447"},
{file = "ruff-0.1.15-py3-none-win32.whl", hash = "sha256:2417e1cb6e2068389b07e6fa74c306b2810fe3ee3476d5b8a96616633f40d14f"},
{file = "ruff-0.1.15-py3-none-win_amd64.whl", hash = "sha256:3837ac73d869efc4182d9036b1405ef4c73d9b1f88da2413875e34e0d6919587"},
{file = "ruff-0.1.15-py3-none-win_arm64.whl", hash = "sha256:9a933dfb1c14ec7a33cceb1e49ec4a16b51ce3c20fd42663198746efc0427360"},
{file = "ruff-0.1.15.tar.gz", hash = "sha256:f6dfa8c1b21c913c326919056c390966648b680966febcb796cc9d1aaab8564e"},
{file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:0a9efb032855ffb3c21f6405751d5e147b0c6b631e3ca3f6b20f917572b97eb6"},
{file = "ruff-0.2.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d450b7fbff85913f866a5384d8912710936e2b96da74541c82c1b458472ddb39"},
{file = "ruff-0.2.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ecd46e3106850a5c26aee114e562c329f9a1fbe9e4821b008c4404f64ff9ce73"},
{file = "ruff-0.2.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5e22676a5b875bd72acd3d11d5fa9075d3a5f53b877fe7b4793e4673499318ba"},
{file = "ruff-0.2.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1695700d1e25a99d28f7a1636d85bafcc5030bba9d0578c0781ba1790dbcf51c"},
{file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:b0c232af3d0bd8f521806223723456ffebf8e323bd1e4e82b0befb20ba18388e"},
{file = "ruff-0.2.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f63d96494eeec2fc70d909393bcd76c69f35334cdbd9e20d089fb3f0640216ca"},
{file = "ruff-0.2.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a61ea0ff048e06de273b2e45bd72629f470f5da8f71daf09fe481278b175001"},
{file = "ruff-0.2.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e1439c8f407e4f356470e54cdecdca1bd5439a0673792dbe34a2b0a551a2fe3"},
{file = "ruff-0.2.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:940de32dc8853eba0f67f7198b3e79bc6ba95c2edbfdfac2144c8235114d6726"},
{file = "ruff-0.2.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0c126da55c38dd917621552ab430213bdb3273bb10ddb67bc4b761989210eb6e"},
{file = "ruff-0.2.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3b65494f7e4bed2e74110dac1f0d17dc8e1f42faaa784e7c58a98e335ec83d7e"},
{file = "ruff-0.2.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:1ec49be4fe6ddac0503833f3ed8930528e26d1e60ad35c2446da372d16651ce9"},
{file = "ruff-0.2.2-py3-none-win32.whl", hash = "sha256:d920499b576f6c68295bc04e7b17b6544d9d05f196bb3aac4358792ef6f34325"},
{file = "ruff-0.2.2-py3-none-win_amd64.whl", hash = "sha256:cc9a91ae137d687f43a44c900e5d95e9617cb37d4c989e462980ba27039d239d"},
{file = "ruff-0.2.2-py3-none-win_arm64.whl", hash = "sha256:c9d15fc41e6054bfc7200478720570078f0b41c9ae4f010bcc16bd6f4d1aacdd"},
{file = "ruff-0.2.2.tar.gz", hash = "sha256:e62ed7f36b3068a30ba39193a14274cd706bc486fad521276458022f7bccb31d"},
]
[[package]]
@ -1075,13 +1075,13 @@ files = [
[[package]]
name = "urllib3"
version = "2.2.0"
version = "2.2.1"
description = "HTTP library with thread-safe connection pooling, file post, and more."
optional = false
python-versions = ">=3.8"
files = [
{file = "urllib3-2.2.0-py3-none-any.whl", hash = "sha256:ce3711610ddce217e6d113a2732fafad960a03fd0318c91faa79481e35c11224"},
{file = "urllib3-2.2.0.tar.gz", hash = "sha256:051d961ad0c62a94e50ecf1af379c3aba230c66c710493493560c0c223c49f20"},
{file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"},
{file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"},
]
[package.extras]
@ -1134,4 +1134,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "c4d03a1586b121b905ea4c0f86d04427cbb3e155e60d67c4b3351186de0d540a"
content-hash = "e88b90ae60758bdab0fe844948c5b9a45bb7f3a96e9ab31ee6d56a7ebd24bfde"

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-anthropic"
version = "0.0.2"
version = "0.1.0"
description = "An integration package connecting AnthropicMessages and LangChain"
authors = []
readme = "README.md"
@ -37,7 +37,8 @@ codespell = "^2.2.0"
optional = true
[tool.poetry.group.lint.dependencies]
ruff = "^0.1.5"
ruff = ">=0.2.2,<1"
mypy = "^0.991"
[tool.poetry.group.typing.dependencies]
mypy = "^0.991"

View File

@ -1,7 +1,14 @@
"""Test ChatAnthropicMessages chat model."""
"""Test ChatAnthropic chat model."""
from typing import List
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_anthropic.chat_models import ChatAnthropicMessages
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from tests.unit_tests._utils import FakeCallbackHandler
def test_stream() -> None:
@ -84,3 +91,72 @@ def test_system_invoke() -> None:
result = chain.invoke({})
assert isinstance(result.content, str)
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
chat = ChatAnthropic(model="test")
message = HumanMessage(content="Hello")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
def test_anthropic_generate() -> None:
"""Test generate method of anthropic."""
chat = ChatAnthropic(model="test")
chat_messages: List[List[BaseMessage]] = [
[HumanMessage(content="How many toes do dogs have?")]
]
messages_copy = [messages.copy() for messages in chat_messages]
result: LLMResult = chat.generate(chat_messages)
assert isinstance(result, LLMResult)
for response in result.generations[0]:
assert isinstance(response, ChatGeneration)
assert isinstance(response.text, str)
assert response.text == response.message.content
assert chat_messages == messages_copy
def test_anthropic_streaming() -> None:
"""Test streaming tokens from anthropic."""
chat = ChatAnthropic(model="test")
message = HumanMessage(content="Hello")
response = chat.stream([message])
for token in response:
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)
def test_anthropic_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
model="test",
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Write me a sentence with 10 words.")
for token in chat.stream([message]):
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)
assert callback_handler.llm_streams > 1
async def test_anthropic_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
model="test",
callback_manager=callback_manager,
verbose=True,
)
chat_messages: List[BaseMessage] = [
HumanMessage(content="How many toes do dogs have?")
]
async for token in chat.astream(chat_messages):
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)
assert callback_handler.llm_streams > 1

View File

@ -0,0 +1,74 @@
"""Test Anthropic API wrapper."""
from typing import Generator
import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.outputs import LLMResult
from langchain_anthropic import Anthropic
from tests.unit_tests._utils import FakeCallbackHandler
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
llm = Anthropic(model_name="foo")
assert llm.model == "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_param() -> None:
llm = Anthropic(model="foo")
assert llm.model == "foo"
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
llm = Anthropic(model="claude-instant-1")
output = llm("Say foo:")
assert isinstance(output, str)
def test_anthropic_streaming() -> None:
"""Test streaming tokens from anthropic."""
llm = Anthropic(model="claude-instant-1")
generator = llm.stream("I'm Pickle Rick")
assert isinstance(generator, Generator)
for token in generator:
assert isinstance(token, str)
def test_anthropic_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = Anthropic(
streaming=True,
callback_manager=callback_manager,
verbose=True,
)
llm("Write me a sentence with 100 words.")
assert callback_handler.llm_streams > 1
async def test_anthropic_async_generate() -> None:
"""Test async generate."""
llm = Anthropic()
output = await llm.agenerate(["How many toes do dogs have?"])
assert isinstance(output, LLMResult)
async def test_anthropic_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = Anthropic(
streaming=True,
callback_manager=callback_manager,
verbose=True,
)
result = await llm.agenerate(["How many toes do dogs have?"])
assert callback_handler.llm_streams > 1
assert isinstance(result, LLMResult)

View File

@ -0,0 +1,255 @@
"""A fake callback handler for testing purposes."""
from typing import Any, Union
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.pydantic_v1 import BaseModel
class BaseFakeCallbackHandler(BaseModel):
"""Base fake callback handler for testing."""
starts: int = 0
ends: int = 0
errors: int = 0
text: int = 0
ignore_llm_: bool = False
ignore_chain_: bool = False
ignore_agent_: bool = False
ignore_retriever_: bool = False
ignore_chat_model_: bool = False
# to allow for similar callback handlers that are not technicall equal
fake_id: Union[str, None] = None
# add finer-grained counters for easier debugging of failing tests
chain_starts: int = 0
chain_ends: int = 0
llm_starts: int = 0
llm_ends: int = 0
llm_streams: int = 0
tool_starts: int = 0
tool_ends: int = 0
agent_actions: int = 0
agent_ends: int = 0
chat_model_starts: int = 0
retriever_starts: int = 0
retriever_ends: int = 0
retriever_errors: int = 0
retries: int = 0
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
"""Base fake callback handler mixin for testing."""
def on_llm_start_common(self) -> None:
self.llm_starts += 1
self.starts += 1
def on_llm_end_common(self) -> None:
self.llm_ends += 1
self.ends += 1
def on_llm_error_common(self) -> None:
self.errors += 1
def on_llm_new_token_common(self) -> None:
self.llm_streams += 1
def on_retry_common(self) -> None:
self.retries += 1
def on_chain_start_common(self) -> None:
self.chain_starts += 1
self.starts += 1
def on_chain_end_common(self) -> None:
self.chain_ends += 1
self.ends += 1
def on_chain_error_common(self) -> None:
self.errors += 1
def on_tool_start_common(self) -> None:
self.tool_starts += 1
self.starts += 1
def on_tool_end_common(self) -> None:
self.tool_ends += 1
self.ends += 1
def on_tool_error_common(self) -> None:
self.errors += 1
def on_agent_action_common(self) -> None:
self.agent_actions += 1
self.starts += 1
def on_agent_finish_common(self) -> None:
self.agent_ends += 1
self.ends += 1
def on_chat_model_start_common(self) -> None:
self.chat_model_starts += 1
self.starts += 1
def on_text_common(self) -> None:
self.text += 1
def on_retriever_start_common(self) -> None:
self.starts += 1
self.retriever_starts += 1
def on_retriever_end_common(self) -> None:
self.ends += 1
self.retriever_ends += 1
def on_retriever_error_common(self) -> None:
self.errors += 1
self.retriever_errors += 1
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Fake callback handler for testing."""
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
return self.ignore_llm_
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return self.ignore_chain_
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@property
def ignore_retriever(self) -> bool:
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
def on_llm_start(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_start_common()
def on_llm_new_token(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_new_token_common()
def on_llm_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_end_common()
def on_llm_error(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_error_common()
def on_retry(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retry_common()
def on_chain_start(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_start_common()
def on_chain_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_end_common()
def on_chain_error(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_error_common()
def on_tool_start(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_start_common()
def on_tool_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_end_common()
def on_tool_error(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_error_common()
def on_agent_action(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_agent_action_common()
def on_agent_finish(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_agent_finish_common()
def on_text(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_text_common()
def on_retriever_start(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retriever_start_common()
def on_retriever_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retriever_end_common()
def on_retriever_error(
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
return self

View File

@ -1,10 +1,54 @@
"""Test chat model integration."""
import os
from langchain_anthropic.chat_models import ChatAnthropicMessages
import pytest
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
os.environ["ANTHROPIC_API_KEY"] = "foo"
def test_initialization() -> None:
"""Test chat model initialization."""
ChatAnthropicMessages(model_name="claude-instant-1.2", anthropic_api_key="xyz")
ChatAnthropicMessages(model="claude-instant-1.2", anthropic_api_key="xyz")
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
llm = ChatAnthropic(model_name="foo")
assert llm.model == "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_param() -> None:
llm = ChatAnthropic(model="foo")
assert llm.model == "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_kwargs() -> None:
llm = ChatAnthropic(model_name="foo", model_kwargs={"foo": "bar"})
assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("anthropic")
def test_anthropic_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):
ChatAnthropic(model="foo", model_kwargs={"max_tokens_to_sample": 5})
@pytest.mark.requires("anthropic")
def test_anthropic_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = ChatAnthropic(model="foo", foo="bar")
assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("anthropic")
def test_anthropic_initialization() -> None:
"""Test anthropic initialization."""
# Verify that chat anthropic can be initialized using a secret key provided
# as a parameter rather than an environment variable.
ChatAnthropic(model="test", anthropic_api_key="test")

View File

@ -1,6 +1,6 @@
from langchain_anthropic import __all__
EXPECTED_ALL = ["ChatAnthropicMessages"]
EXPECTED_ALL = ["ChatAnthropicMessages", "ChatAnthropic", "Anthropic", "AnthropicLLM"]
def test_all_imports() -> None: