mirror of https://github.com/hwchase17/langchain
anthropic[minor]: package move (#17974)
parent
a2d5fa7649
commit
3b5bdbfee8
@ -1,3 +1,4 @@
|
|||||||
from langchain_anthropic.chat_models import ChatAnthropicMessages
|
from langchain_anthropic.chat_models import ChatAnthropic, ChatAnthropicMessages
|
||||||
|
from langchain_anthropic.llms import Anthropic, AnthropicLLM
|
||||||
|
|
||||||
__all__ = ["ChatAnthropicMessages"]
|
__all__ = ["ChatAnthropicMessages", "ChatAnthropic", "Anthropic", "AnthropicLLM"]
|
||||||
|
@ -0,0 +1,352 @@
|
|||||||
|
import re
|
||||||
|
import warnings
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterator,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
)
|
||||||
|
|
||||||
|
import anthropic
|
||||||
|
from langchain_core._api.deprecation import deprecated
|
||||||
|
from langchain_core.callbacks import (
|
||||||
|
AsyncCallbackManagerForLLMRun,
|
||||||
|
CallbackManagerForLLMRun,
|
||||||
|
)
|
||||||
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
|
from langchain_core.language_models.llms import LLM
|
||||||
|
from langchain_core.outputs import GenerationChunk
|
||||||
|
from langchain_core.prompt_values import PromptValue
|
||||||
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||||
|
from langchain_core.utils import (
|
||||||
|
get_from_dict_or_env,
|
||||||
|
get_pydantic_field_names,
|
||||||
|
)
|
||||||
|
from langchain_core.utils.utils import build_extra_kwargs, convert_to_secret_str
|
||||||
|
|
||||||
|
|
||||||
|
class _AnthropicCommon(BaseLanguageModel):
    """Shared configuration and client plumbing for Anthropic completion models.

    Holds the model parameters, resolves the API key/URL from arguments or the
    environment, and constructs the sync/async ``anthropic`` SDK clients.
    """

    client: Any = None  #: :meta private:
    async_client: Any = None  #: :meta private:
    model: str = Field(default="claude-2", alias="model_name")
    """Model name to use."""

    max_tokens_to_sample: int = Field(default=256, alias="max_tokens")
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = None
    """A non-negative float that tunes the degree of randomness in generation."""

    top_k: Optional[int] = None
    """Number of most likely tokens to consider at each step."""

    top_p: Optional[float] = None
    """Total probability mass of tokens to consider at each step."""

    streaming: bool = False
    """Whether to stream the results."""

    default_request_timeout: Optional[float] = None
    """Timeout for requests to Anthropic Completion API. Default is 600 seconds."""

    max_retries: int = 2
    """Number of retries allowed for requests sent to the Anthropic Completion API."""

    anthropic_api_url: Optional[str] = None

    anthropic_api_key: Optional[SecretStr] = None

    # Populated from the anthropic package during validation; None until then.
    HUMAN_PROMPT: Optional[str] = None
    AI_PROMPT: Optional[str] = None
    count_tokens: Optional[Callable[[str], int]] = None
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @root_validator(pre=True)
    def build_extra(cls, values: Dict) -> Dict:
        """Collect constructor kwargs that are not declared fields into
        ``model_kwargs`` (warning on each one)."""
        extra = values.get("model_kwargs", {})
        all_required_field_names = get_pydantic_field_names(cls)
        values["model_kwargs"] = build_extra_kwargs(
            extra, values, all_required_field_names
        )
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        values["anthropic_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "anthropic_api_key", "ANTHROPIC_API_KEY")
        )
        # Get custom api url from environment.
        values["anthropic_api_url"] = get_from_dict_or_env(
            values,
            "anthropic_api_url",
            "ANTHROPIC_API_URL",
            default="https://api.anthropic.com",
        )

        values["client"] = anthropic.Anthropic(
            base_url=values["anthropic_api_url"],
            api_key=values["anthropic_api_key"].get_secret_value(),
            timeout=values["default_request_timeout"],
            max_retries=values["max_retries"],
        )
        values["async_client"] = anthropic.AsyncAnthropic(
            base_url=values["anthropic_api_url"],
            api_key=values["anthropic_api_key"].get_secret_value(),
            timeout=values["default_request_timeout"],
            max_retries=values["max_retries"],
        )
        values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
        values["AI_PROMPT"] = anthropic.AI_PROMPT
        values["count_tokens"] = values["client"].count_tokens
        return values

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling Anthropic API."""
        d: Dict[str, Any] = {
            "max_tokens_to_sample": self.max_tokens_to_sample,
            "model": self.model,
        }
        # Only forward optional sampling knobs the user actually set.
        if self.temperature is not None:
            d["temperature"] = self.temperature
        if self.top_k is not None:
            d["top_k"] = self.top_k
        if self.top_p is not None:
            d["top_p"] = self.top_p
        return {**d, **self.model_kwargs}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**self._default_params}

    def _get_anthropic_stop(self, stop: Optional[List[str]] = None) -> List[str]:
        """Return the stop sequences to send, always including HUMAN_PROMPT.

        A copy of ``stop`` is made so the caller's list is never mutated
        (the previous implementation extended the caller's list in place).
        """
        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
            raise NameError("Please ensure the anthropic package is loaded")

        stop = [] if stop is None else list(stop)

        # Never want model to invent new turns of Human / Assistant dialog.
        stop.append(self.HUMAN_PROMPT)

        return stop
|
||||||
|
|
||||||
|
|
||||||
|
class AnthropicLLM(LLM, _AnthropicCommon):
    """Anthropic large language models.

    To use, you should have the ``anthropic`` python package installed, and the
    environment variable ``ANTHROPIC_API_KEY`` set with your API key, or pass
    it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            import anthropic
            from langchain_anthropic import Anthropic

            model = Anthropic(model="<model_name>", anthropic_api_key="my-api-key")

            # Simplest invocation, automatically wrapped with HUMAN_PROMPT
            # and AI_PROMPT.
            response = model("What are the biggest risks facing humanity?")

            # Or if you want to use the chat mode, build a few-shot-prompt, or
            # put words in the Assistant's mouth, use HUMAN_PROMPT and AI_PROMPT:
            raw_prompt = "What are the biggest risks facing humanity?"
            prompt = f"{anthropic.HUMAN_PROMPT} {raw_prompt}{anthropic.AI_PROMPT}"
            response = model(prompt)
    """

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True
        arbitrary_types_allowed = True

    @root_validator()
    def raise_warning(cls, values: Dict) -> Dict:
        """Raise warning that this class is deprecated."""
        # Point users at this package's ChatAnthropic; the previous message
        # referenced langchain_community, which this package supersedes.
        warnings.warn(
            "This Anthropic LLM is deprecated. "
            "Please use `from langchain_anthropic import ChatAnthropic` "
            "instead"
        )
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "anthropic-llm"

    def _wrap_prompt(self, prompt: str) -> str:
        """Return ``prompt`` wrapped with the HUMAN_PROMPT / AI_PROMPT markers.

        Raises:
            NameError: If the anthropic package constants were not loaded.
        """
        if not self.HUMAN_PROMPT or not self.AI_PROMPT:
            raise NameError("Please ensure the anthropic package is loaded")

        if prompt.startswith(self.HUMAN_PROMPT):
            return prompt  # Already wrapped.

        # Guard against common errors in specifying wrong number of newlines.
        corrected_prompt, n_subs = re.subn(r"^\n*Human:", self.HUMAN_PROMPT, prompt)
        if n_subs == 1:
            return corrected_prompt

        # As a last resort, wrap the prompt ourselves to emulate instruct-style.
        return f"{self.HUMAN_PROMPT} {prompt}{self.AI_PROMPT} Sure, here you go:\n"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""Call out to Anthropic's completion endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "What are the biggest risks facing humanity?"
                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                response = model(prompt)

        """
        if self.streaming:
            # Accumulate the streamed chunks into one completion string.
            completion = ""
            for chunk in self._stream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
            return completion

        stop = self._get_anthropic_stop(stop)
        params = {**self._default_params, **kwargs}
        response = self.client.completions.create(
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
            **params,
        )
        return response.completion

    def convert_prompt(self, prompt: PromptValue) -> str:
        """Render a PromptValue as a fully wrapped Anthropic prompt string."""
        return self._wrap_prompt(prompt.to_string())

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Anthropic's completion endpoint asynchronously."""
        if self.streaming:
            completion = ""
            async for chunk in self._astream(
                prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
            ):
                completion += chunk.text
            return completion

        stop = self._get_anthropic_stop(stop)
        params = {**self._default_params, **kwargs}

        response = await self.async_client.completions.create(
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
            **params,
        )
        return response.completion

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        r"""Call Anthropic completion_stream and return the resulting generator.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
        Returns:
            A generator representing the stream of tokens from Anthropic.
        Example:
            .. code-block:: python

                prompt = "Write a poem about a stream."
                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                generator = anthropic.stream(prompt)
                for token in generator:
                    yield token
        """
        stop = self._get_anthropic_stop(stop)
        params = {**self._default_params, **kwargs}

        for token in self.client.completions.create(
            prompt=self._wrap_prompt(prompt), stop_sequences=stop, stream=True, **params
        ):
            chunk = GenerationChunk(text=token.completion)
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        r"""Call Anthropic completion_stream and return the resulting generator.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
        Returns:
            A generator representing the stream of tokens from Anthropic.
        Example:
            .. code-block:: python

                prompt = "Write a poem about a stream."
                prompt = f"\n\nHuman: {prompt}\n\nAssistant:"
                generator = anthropic.stream(prompt)
                for token in generator:
                    yield token
        """
        stop = self._get_anthropic_stop(stop)
        params = {**self._default_params, **kwargs}

        async for token in await self.async_client.completions.create(
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
            stream=True,
            **params,
        ):
            chunk = GenerationChunk(text=token.completion)
            yield chunk
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)

    def get_num_tokens(self, text: str) -> int:
        """Calculate number of tokens."""
        if not self.count_tokens:
            raise NameError("Please ensure the anthropic package is loaded")
        return self.count_tokens(text)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated(since="0.1.0", removal="0.2.0", alternative="AnthropicLLM")
class Anthropic(AnthropicLLM):
    """Deprecated alias kept for backwards compatibility; use ``AnthropicLLM``."""
|
@ -0,0 +1,74 @@
|
|||||||
|
"""Test Anthropic API wrapper."""
|
||||||
|
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.callbacks import CallbackManager
|
||||||
|
from langchain_core.outputs import LLMResult
|
||||||
|
|
||||||
|
from langchain_anthropic import Anthropic
|
||||||
|
from tests.unit_tests._utils import FakeCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
    """The ``model_name`` alias populates the ``model`` field."""
    model = Anthropic(model_name="foo")
    assert model.model == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
def test_anthropic_model_param() -> None:
    """The ``model`` kwarg populates the ``model`` field directly."""
    model = Anthropic(model="foo")
    assert model.model == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_call() -> None:
    """Test valid call to anthropic."""
    model = Anthropic(model="claude-instant-1")
    result = model("Say foo:")
    assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_streaming() -> None:
    """Test streaming tokens from anthropic."""
    model = Anthropic(model="claude-instant-1")
    token_stream = model.stream("I'm Pickle Rick")

    # ``stream`` must be lazy: a generator, each item a string token.
    assert isinstance(token_stream, Generator)
    for token in token_stream:
        assert isinstance(token, str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_anthropic_streaming_callback() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    handler = FakeCallbackHandler()
    model = Anthropic(
        streaming=True,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    model("Write me a sentence with 100 words.")
    # Streaming must emit more than one token event.
    assert handler.llm_streams > 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_anthropic_async_generate() -> None:
    """Test async generate."""
    model = Anthropic()
    result = await model.agenerate(["How many toes do dogs have?"])
    assert isinstance(result, LLMResult)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_anthropic_async_streaming_callback() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    handler = FakeCallbackHandler()
    model = Anthropic(
        streaming=True,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    result = await model.agenerate(["How many toes do dogs have?"])
    # Streaming must emit more than one token event and still return a result.
    assert handler.llm_streams > 1
    assert isinstance(result, LLMResult)
|
@ -0,0 +1,255 @@
|
|||||||
|
"""A fake callback handler for testing purposes."""
|
||||||
|
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFakeCallbackHandler(BaseModel):
    """Base fake callback handler for testing.

    Pure counter state: every field tallies how often a callback family fired.
    """

    # Aggregate counters.
    starts: int = 0
    ends: int = 0
    errors: int = 0
    text: int = 0
    # Ignore flags mirrored by the handler's ``ignore_*`` properties.
    ignore_llm_: bool = False
    ignore_chain_: bool = False
    ignore_agent_: bool = False
    ignore_retriever_: bool = False
    ignore_chat_model_: bool = False

    # to allow for similar callback handlers that are not technically equal
    fake_id: Union[str, None] = None

    # add finer-grained counters for easier debugging of failing tests
    chain_starts: int = 0
    chain_ends: int = 0
    llm_starts: int = 0
    llm_ends: int = 0
    llm_streams: int = 0
    tool_starts: int = 0
    tool_ends: int = 0
    agent_actions: int = 0
    agent_ends: int = 0
    chat_model_starts: int = 0
    retriever_starts: int = 0
    retriever_ends: int = 0
    retriever_errors: int = 0
    retries: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFakeCallbackHandlerMixin(BaseFakeCallbackHandler):
    """Base fake callback handler mixin for testing.

    Each ``*_common`` method bumps the fine-grained counter for its event and,
    where applicable, the aggregate ``starts`` / ``ends`` / ``errors`` tally.
    """

    def on_llm_start_common(self) -> None:
        """Record an LLM start."""
        self.starts += 1
        self.llm_starts += 1

    def on_llm_end_common(self) -> None:
        """Record an LLM end."""
        self.ends += 1
        self.llm_ends += 1

    def on_llm_error_common(self) -> None:
        """Record an LLM error."""
        self.errors += 1

    def on_llm_new_token_common(self) -> None:
        """Record a streamed LLM token."""
        self.llm_streams += 1

    def on_retry_common(self) -> None:
        """Record a retry."""
        self.retries += 1

    def on_chain_start_common(self) -> None:
        """Record a chain start."""
        self.starts += 1
        self.chain_starts += 1

    def on_chain_end_common(self) -> None:
        """Record a chain end."""
        self.ends += 1
        self.chain_ends += 1

    def on_chain_error_common(self) -> None:
        """Record a chain error."""
        self.errors += 1

    def on_tool_start_common(self) -> None:
        """Record a tool start."""
        self.starts += 1
        self.tool_starts += 1

    def on_tool_end_common(self) -> None:
        """Record a tool end."""
        self.ends += 1
        self.tool_ends += 1

    def on_tool_error_common(self) -> None:
        """Record a tool error."""
        self.errors += 1

    def on_agent_action_common(self) -> None:
        """Record an agent action."""
        self.starts += 1
        self.agent_actions += 1

    def on_agent_finish_common(self) -> None:
        """Record an agent finish."""
        self.ends += 1
        self.agent_ends += 1

    def on_chat_model_start_common(self) -> None:
        """Record a chat-model start."""
        self.starts += 1
        self.chat_model_starts += 1

    def on_text_common(self) -> None:
        """Record a text event."""
        self.text += 1

    def on_retriever_start_common(self) -> None:
        """Record a retriever start."""
        self.starts += 1
        self.retriever_starts += 1

    def on_retriever_end_common(self) -> None:
        """Record a retriever end."""
        self.ends += 1
        self.retriever_ends += 1

    def on_retriever_error_common(self) -> None:
        """Record a retriever error."""
        self.errors += 1
        self.retriever_errors += 1
|
||||||
|
|
||||||
|
|
||||||
|
class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
    """Fake callback handler for testing.

    Forwards every callback to the matching ``*_common`` counter method from
    the mixin, ignoring all callback arguments.
    """

    @property
    def ignore_llm(self) -> bool:
        """Whether to ignore LLM callbacks."""
        return self.ignore_llm_

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return self.ignore_chain_

    @property
    def ignore_agent(self) -> bool:
        """Whether to ignore agent callbacks."""
        return self.ignore_agent_

    @property
    def ignore_retriever(self) -> bool:
        """Whether to ignore retriever callbacks."""
        return self.ignore_retriever_

    def on_llm_start(self, *args: Any, **kwargs: Any) -> Any:
        self.on_llm_start_common()

    def on_llm_new_token(self, *args: Any, **kwargs: Any) -> Any:
        self.on_llm_new_token_common()

    def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
        self.on_llm_end_common()

    def on_llm_error(self, *args: Any, **kwargs: Any) -> Any:
        self.on_llm_error_common()

    def on_retry(self, *args: Any, **kwargs: Any) -> Any:
        self.on_retry_common()

    def on_chain_start(self, *args: Any, **kwargs: Any) -> Any:
        self.on_chain_start_common()

    def on_chain_end(self, *args: Any, **kwargs: Any) -> Any:
        self.on_chain_end_common()

    def on_chain_error(self, *args: Any, **kwargs: Any) -> Any:
        self.on_chain_error_common()

    def on_tool_start(self, *args: Any, **kwargs: Any) -> Any:
        self.on_tool_start_common()

    def on_tool_end(self, *args: Any, **kwargs: Any) -> Any:
        self.on_tool_end_common()

    def on_tool_error(self, *args: Any, **kwargs: Any) -> Any:
        self.on_tool_error_common()

    def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
        self.on_agent_action_common()

    def on_agent_finish(self, *args: Any, **kwargs: Any) -> Any:
        self.on_agent_finish_common()

    def on_text(self, *args: Any, **kwargs: Any) -> Any:
        self.on_text_common()

    def on_retriever_start(self, *args: Any, **kwargs: Any) -> Any:
        self.on_retriever_start_common()

    def on_retriever_end(self, *args: Any, **kwargs: Any) -> Any:
        self.on_retriever_end_common()

    def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any:
        self.on_retriever_error_common()

    def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
        # The handler is a stateful counter: deep copies must alias the same
        # object so tests still observe the accumulated counts.
        return self
|
@ -1,10 +1,54 @@
|
|||||||
"""Test chat model integration."""
|
"""Test chat model integration."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from langchain_anthropic.chat_models import ChatAnthropicMessages
|
import pytest
|
||||||
|
|
||||||
|
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
|
||||||
|
|
||||||
|
os.environ["ANTHROPIC_API_KEY"] = "foo"
|
||||||
|
|
||||||
|
|
||||||
def test_initialization() -> None:
    """Test chat model initialization."""
    # Both the field name and its alias must be accepted by the constructor.
    for model_field in ("model_name", "model"):
        ChatAnthropicMessages(
            **{model_field: "claude-instant-1.2"}, anthropic_api_key="xyz"
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
def test_anthropic_model_name_param() -> None:
    """The ``model_name`` alias populates the ``model`` field."""
    chat = ChatAnthropic(model_name="foo")
    assert chat.model == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
def test_anthropic_model_param() -> None:
    """The ``model`` kwarg populates the ``model`` field directly."""
    chat = ChatAnthropic(model="foo")
    assert chat.model == "foo"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
def test_anthropic_model_kwargs() -> None:
    """Explicit ``model_kwargs`` are stored verbatim."""
    chat = ChatAnthropic(model_name="foo", model_kwargs={"foo": "bar"})
    assert chat.model_kwargs == {"foo": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
def test_anthropic_invalid_model_kwargs() -> None:
    """Passing a declared field inside ``model_kwargs`` is rejected."""
    with pytest.raises(ValueError):
        ChatAnthropic(model="foo", model_kwargs={"max_tokens_to_sample": 5})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
def test_anthropic_incorrect_field() -> None:
    """Unknown constructor kwargs are routed to ``model_kwargs`` with a warning."""
    with pytest.warns(match="not default parameter"):
        chat = ChatAnthropic(model="foo", foo="bar")
    assert chat.model_kwargs == {"foo": "bar"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("anthropic")
def test_anthropic_initialization() -> None:
    """Test anthropic initialization."""
    # Verify that chat anthropic can be initialized using a secret key provided
    # as a parameter rather than an environment variable.
    _ = ChatAnthropic(model="test", anthropic_api_key="test")
|
||||||
|
Loading…
Reference in New Issue