Harrison/update anthropic (#7237)

Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
Harrison Chase 2023-07-05 21:02:35 -04:00 committed by GitHub
parent 695e7027e6
commit 52b016920c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 83 additions and 50 deletions

View File

@ -52,7 +52,7 @@
{
"data": {
"text/plain": [
"AIMessage(content=\" J'aime programmer. \", additional_kwargs={})"
"AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)"
]
},
"execution_count": 3,
@ -101,7 +101,7 @@
{
"data": {
"text/plain": [
"LLMResult(generations=[[ChatGeneration(text=\" J'aime la programmation.\", generation_info=None, message=AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}))]], llm_output={})"
"LLMResult(generations=[[ChatGeneration(text=\" J'aime programmer.\", generation_info=None, message=AIMessage(content=\" J'aime programmer.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('8cc8fb68-1c35-439c-96a0-695036a93652'))])"
]
},
"execution_count": 5,
@ -125,13 +125,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
" J'adore programmer."
" J'aime la programmation."
]
},
{
"data": {
"text/plain": [
"AIMessage(content=\" J'adore programmer.\", additional_kwargs={})"
"AIMessage(content=\" J'aime la programmation.\", additional_kwargs={}, example=False)"
]
},
"execution_count": 6,
@ -151,7 +151,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "df45f59f",
"id": "c253883f",
"metadata": {},
"outputs": [],
"source": []
@ -173,7 +173,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.9.1"
}
},
"nbformat": 4,

View File

@ -104,17 +104,17 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
if self.streaming:
completion = ""
stream_resp = self.client.completion_stream(**params)
stream_resp = self.client.completions.create(**params, stream=True)
for data in stream_resp:
delta = data["completion"][len(completion) :]
completion = data["completion"]
delta = data.completion
completion += delta
if run_manager:
run_manager.on_llm_new_token(
delta,
)
else:
response = self.client.completion(**params)
completion = response["completion"]
response = self.client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])
@ -132,17 +132,19 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
if self.streaming:
completion = ""
stream_resp = await self.client.acompletion_stream(**params)
stream_resp = await self.async_client.completions.create(
**params, stream=True
)
async for data in stream_resp:
delta = data["completion"][len(completion) :]
completion = data["completion"]
delta = data.completion
completion += delta
if run_manager:
await run_manager.on_llm_new_token(
delta,
)
else:
response = await self.client.acompletion(**params)
completion = response["completion"]
response = await self.async_client.completions.create(**params)
completion = response.completion
message = AIMessage(content=completion)
return ChatResult(generations=[ChatGeneration(message=message)])

View File

@ -1,8 +1,10 @@
"""Wrapper around Anthropic APIs."""
import re
import warnings
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union
from importlib.metadata import version
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional
import packaging
from pydantic import BaseModel, root_validator
from langchain.callbacks.manager import (
@ -15,6 +17,7 @@ from langchain.utils import get_from_dict_or_env
class _AnthropicCommon(BaseModel):
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model: str = "claude-v1"
"""Model name to use."""
@ -33,7 +36,7 @@ class _AnthropicCommon(BaseModel):
streaming: bool = False
"""Whether to stream the results."""
default_request_timeout: Optional[Union[float, Tuple[float, float]]] = None
default_request_timeout: Optional[float] = None
"""Timeout for requests to Anthropic Completion API. Default is 600 seconds."""
anthropic_api_url: Optional[str] = None
@ -50,7 +53,7 @@ class _AnthropicCommon(BaseModel):
anthropic_api_key = get_from_dict_or_env(
values, "anthropic_api_key", "ANTHROPIC_API_KEY"
)
"""Get custom api url from environment."""
# Get custom api url from environment.
anthropic_api_url = get_from_dict_or_env(
values,
"anthropic_api_url",
@ -61,14 +64,26 @@ class _AnthropicCommon(BaseModel):
try:
import anthropic
values["client"] = anthropic.Client(
api_url=anthropic_api_url,
anthropic_version = packaging.version.parse(version("anthropic"))
if anthropic_version < packaging.version.parse("0.3"):
raise ValueError(
f"Anthropic client version must be > 0.3, got {anthropic_version}. "
f"To update the client, please run "
f"`pip install -U anthropic`"
)
values["client"] = anthropic.Anthropic(
base_url=anthropic_api_url,
api_key=anthropic_api_key,
default_request_timeout=values["default_request_timeout"],
timeout=values["default_request_timeout"],
)
values["async_client"] = anthropic.AsyncAnthropic(
base_url=anthropic_api_url,
api_key=anthropic_api_key,
timeout=values["default_request_timeout"],
)
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = anthropic.count_tokens
values["count_tokens"] = values["client"].count_tokens
except ImportError:
raise ImportError(
"Could not import anthropic python package. "
@ -190,24 +205,27 @@ class Anthropic(LLM, _AnthropicCommon):
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
if self.streaming:
stream_resp = self.client.completion_stream(
stream_resp = self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
)
current_completion = ""
for data in stream_resp:
delta = data["completion"][len(current_completion) :]
current_completion = data["completion"]
delta = data.completion
current_completion += delta
if run_manager:
run_manager.on_llm_new_token(delta, **data)
run_manager.on_llm_new_token(
delta,
)
return current_completion
response = self.client.completion(
response = self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
**params,
)
return response["completion"]
return response.completion
async def _acall(
self,
@ -220,24 +238,25 @@ class Anthropic(LLM, _AnthropicCommon):
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
if self.streaming:
stream_resp = await self.client.acompletion_stream(
stream_resp = await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**params,
)
current_completion = ""
async for data in stream_resp:
delta = data["completion"][len(current_completion) :]
current_completion = data["completion"]
delta = data.completion
current_completion += delta
if run_manager:
await run_manager.on_llm_new_token(delta, **data)
await run_manager.on_llm_new_token(delta)
return current_completion
response = await self.client.acompletion(
response = await self.async_client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
**params,
)
return response["completion"]
return response.completion
def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
r"""Call Anthropic completion_stream and return the resulting generator.
@ -263,9 +282,10 @@ class Anthropic(LLM, _AnthropicCommon):
yield token
"""
stop = self._get_anthropic_stop(stop)
return self.client.completion_stream(
return self.client.completions.create(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
stream=True,
**self._default_params,
)

37
poetry.lock generated
View File

@ -341,24 +341,23 @@ dev = ["black", "docutils", "flake8", "ipython", "m2r", "mistune (<2.0.0)", "pyt
[[package]]
name = "anthropic"
version = "0.2.10"
description = "Library for accessing the anthropic API"
version = "0.3.2"
description = "Client library for the anthropic API"
category = "main"
optional = true
python-versions = ">=3.8"
python-versions = ">=3.7,<4.0"
files = [
{file = "anthropic-0.2.10-py3-none-any.whl", hash = "sha256:a007496207fd186b0bcb9592b00ca130069d2a427f3d6f602a61dbbd1ac6316e"},
{file = "anthropic-0.2.10.tar.gz", hash = "sha256:e4da061a86d8ffb86072c0b0feaf219a3a4f7dfddd4224df9ba769e469498c19"},
{file = "anthropic-0.3.2-py3-none-any.whl", hash = "sha256:43ad86df406bf91419e3c651e20dcc69ae273c932c92c26973a1621a72ff1d86"},
{file = "anthropic-0.3.2.tar.gz", hash = "sha256:f968e970bb0dfa38b1ec59db7bb4162fd1e0f2bef95c3203e926effe62bfcf38"},
]
[package.dependencies]
aiohttp = "*"
httpx = "*"
requests = "*"
tokenizers = "*"
[package.extras]
dev = ["black (>=22.3.0)", "pytest"]
anyio = ">=3.5.0"
distro = ">=1.7.0"
httpx = ">=0.23.0"
pydantic = ">=1.9.0,<2.0.0"
tokenizers = ">=0.13.0"
typing-extensions = ">=4.1.1"
[[package]]
name = "anyio"
@ -2285,6 +2284,18 @@ files = [
[package.extras]
graph = ["objgraph (>=1.7.2)"]
[[package]]
name = "distro"
version = "1.8.0"
description = "Distro - an OS platform information API"
category = "main"
optional = true
python-versions = ">=3.6"
files = [
{file = "distro-1.8.0-py3-none-any.whl", hash = "sha256:99522ca3e365cac527b44bde033f64c6945d90eb9f769703caaec52b09bbd3ff"},
{file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"},
]
[[package]]
name = "dnspython"
version = "2.3.0"
@ -12409,4 +12420,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "ce9bfa2954a3b468d925410fe836c6db92040e95cb9720227e12abef1f4c11ca"
content-hash = "e1da6f8f88f3410d6ac4b1babb2a8aa9e01263deb52da423419f5362c5ddfc1f"

View File

@ -42,7 +42,7 @@ marqo = {version = "^0.9.1", optional=true}
google-api-python-client = {version = "2.70.0", optional = true}
google-auth = {version = "^2.18.1", optional = true}
wolframalpha = {version = "5.0.0", optional = true}
anthropic = {version = "^0.2.6", optional = true}
anthropic = {version = "^0.3", optional = true}
qdrant-client = {version = "^1.1.2", optional = true, python = ">=3.8.1,<3.12"}
dataclasses-json = "^0.5.7"
tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}