|
|
|
@ -1,8 +1,10 @@
|
|
|
|
|
"""Wrapper around Anthropic APIs."""
|
|
|
|
|
import re
|
|
|
|
|
import warnings
|
|
|
|
|
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union
|
|
|
|
|
from importlib.metadata import version
|
|
|
|
|
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
|
|
|
|
|
|
|
|
|
|
import packaging
|
|
|
|
|
from pydantic import BaseModel, root_validator
|
|
|
|
|
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
@ -15,6 +17,7 @@ from langchain.utils import get_from_dict_or_env
|
|
|
|
|
|
|
|
|
|
class _AnthropicCommon(BaseModel):
|
|
|
|
|
client: Any = None #: :meta private:
|
|
|
|
|
async_client: Any = None #: :meta private:
|
|
|
|
|
model: str = "claude-v1"
|
|
|
|
|
"""Model name to use."""
|
|
|
|
|
|
|
|
|
@ -33,7 +36,7 @@ class _AnthropicCommon(BaseModel):
|
|
|
|
|
streaming: bool = False
|
|
|
|
|
"""Whether to stream the results."""
|
|
|
|
|
|
|
|
|
|
default_request_timeout: Optional[Union[float, Tuple[float, float]]] = None
|
|
|
|
|
default_request_timeout: Optional[float] = None
|
|
|
|
|
"""Timeout for requests to Anthropic Completion API. Default is 600 seconds."""
|
|
|
|
|
|
|
|
|
|
anthropic_api_url: Optional[str] = None
|
|
|
|
@ -50,7 +53,7 @@ class _AnthropicCommon(BaseModel):
|
|
|
|
|
anthropic_api_key = get_from_dict_or_env(
|
|
|
|
|
values, "anthropic_api_key", "ANTHROPIC_API_KEY"
|
|
|
|
|
)
|
|
|
|
|
"""Get custom api url from environment."""
|
|
|
|
|
# Get custom api url from environment.
|
|
|
|
|
anthropic_api_url = get_from_dict_or_env(
|
|
|
|
|
values,
|
|
|
|
|
"anthropic_api_url",
|
|
|
|
@ -61,14 +64,26 @@ class _AnthropicCommon(BaseModel):
|
|
|
|
|
try:
|
|
|
|
|
import anthropic
|
|
|
|
|
|
|
|
|
|
values["client"] = anthropic.Client(
|
|
|
|
|
api_url=anthropic_api_url,
|
|
|
|
|
anthropic_version = packaging.version.parse(version("anthropic"))
|
|
|
|
|
if anthropic_version < packaging.version.parse("0.3"):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
f"Anthropic client version must be > 0.3, got {anthropic_version}. "
|
|
|
|
|
f"To update the client, please run "
|
|
|
|
|
f"`pip install -U anthropic`"
|
|
|
|
|
)
|
|
|
|
|
values["client"] = anthropic.Anthropic(
|
|
|
|
|
base_url=anthropic_api_url,
|
|
|
|
|
api_key=anthropic_api_key,
|
|
|
|
|
default_request_timeout=values["default_request_timeout"],
|
|
|
|
|
timeout=values["default_request_timeout"],
|
|
|
|
|
)
|
|
|
|
|
values["async_client"] = anthropic.AsyncAnthropic(
|
|
|
|
|
base_url=anthropic_api_url,
|
|
|
|
|
api_key=anthropic_api_key,
|
|
|
|
|
timeout=values["default_request_timeout"],
|
|
|
|
|
)
|
|
|
|
|
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
|
|
|
|
|
values["AI_PROMPT"] = anthropic.AI_PROMPT
|
|
|
|
|
values["count_tokens"] = anthropic.count_tokens
|
|
|
|
|
values["count_tokens"] = values["client"].count_tokens
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"Could not import anthropic python package. "
|
|
|
|
@ -190,24 +205,27 @@ class Anthropic(LLM, _AnthropicCommon):
|
|
|
|
|
stop = self._get_anthropic_stop(stop)
|
|
|
|
|
params = {**self._default_params, **kwargs}
|
|
|
|
|
if self.streaming:
|
|
|
|
|
stream_resp = self.client.completion_stream(
|
|
|
|
|
stream_resp = self.client.completions.create(
|
|
|
|
|
prompt=self._wrap_prompt(prompt),
|
|
|
|
|
stop_sequences=stop,
|
|
|
|
|
stream=True,
|
|
|
|
|
**params,
|
|
|
|
|
)
|
|
|
|
|
current_completion = ""
|
|
|
|
|
for data in stream_resp:
|
|
|
|
|
delta = data["completion"][len(current_completion) :]
|
|
|
|
|
current_completion = data["completion"]
|
|
|
|
|
delta = data.completion
|
|
|
|
|
current_completion += delta
|
|
|
|
|
if run_manager:
|
|
|
|
|
run_manager.on_llm_new_token(delta, **data)
|
|
|
|
|
run_manager.on_llm_new_token(
|
|
|
|
|
delta,
|
|
|
|
|
)
|
|
|
|
|
return current_completion
|
|
|
|
|
response = self.client.completion(
|
|
|
|
|
response = self.client.completions.create(
|
|
|
|
|
prompt=self._wrap_prompt(prompt),
|
|
|
|
|
stop_sequences=stop,
|
|
|
|
|
**params,
|
|
|
|
|
)
|
|
|
|
|
return response["completion"]
|
|
|
|
|
return response.completion
|
|
|
|
|
|
|
|
|
|
async def _acall(
|
|
|
|
|
self,
|
|
|
|
@ -220,24 +238,25 @@ class Anthropic(LLM, _AnthropicCommon):
|
|
|
|
|
stop = self._get_anthropic_stop(stop)
|
|
|
|
|
params = {**self._default_params, **kwargs}
|
|
|
|
|
if self.streaming:
|
|
|
|
|
stream_resp = await self.client.acompletion_stream(
|
|
|
|
|
stream_resp = await self.async_client.completions.create(
|
|
|
|
|
prompt=self._wrap_prompt(prompt),
|
|
|
|
|
stop_sequences=stop,
|
|
|
|
|
stream=True,
|
|
|
|
|
**params,
|
|
|
|
|
)
|
|
|
|
|
current_completion = ""
|
|
|
|
|
async for data in stream_resp:
|
|
|
|
|
delta = data["completion"][len(current_completion) :]
|
|
|
|
|
current_completion = data["completion"]
|
|
|
|
|
delta = data.completion
|
|
|
|
|
current_completion += delta
|
|
|
|
|
if run_manager:
|
|
|
|
|
await run_manager.on_llm_new_token(delta, **data)
|
|
|
|
|
await run_manager.on_llm_new_token(delta)
|
|
|
|
|
return current_completion
|
|
|
|
|
response = await self.client.acompletion(
|
|
|
|
|
response = await self.async_client.completions.create(
|
|
|
|
|
prompt=self._wrap_prompt(prompt),
|
|
|
|
|
stop_sequences=stop,
|
|
|
|
|
**params,
|
|
|
|
|
)
|
|
|
|
|
return response["completion"]
|
|
|
|
|
return response.completion
|
|
|
|
|
|
|
|
|
|
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
|
|
|
|
|
r"""Call Anthropic completion_stream and return the resulting generator.
|
|
|
|
@ -263,9 +282,10 @@ class Anthropic(LLM, _AnthropicCommon):
|
|
|
|
|
yield token
|
|
|
|
|
"""
|
|
|
|
|
stop = self._get_anthropic_stop(stop)
|
|
|
|
|
return self.client.completion_stream(
|
|
|
|
|
return self.client.completions.create(
|
|
|
|
|
prompt=self._wrap_prompt(prompt),
|
|
|
|
|
stop_sequences=stop,
|
|
|
|
|
stream=True,
|
|
|
|
|
**self._default_params,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|