add async support for anthropic (#2114)

Should not be merged before
https://github.com/anthropics/anthropic-sdk-python/pull/11 is released.
Ankush Gola 2023-03-28 22:49:14 -04:00 committed by GitHub
parent e2c26909f2
commit ccee1aedd2
5 changed files with 71 additions and 12 deletions

docs notebook: async LLM support

@@ -9,7 +9,7 @@
"\n",
"LangChain provides async support for LLMs by leveraging the [asyncio](https://docs.python.org/3/library/asyncio.html) library.\n",
"\n",
"Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, only `OpenAI` and `PromptLayerOpenAI` are supported, but async support for other LLMs is on the roadmap.\n",
"Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI` and `Anthropic` are supported, but async support for other LLMs is on the roadmap.\n",
"\n",
"You can use the `agenerate` method to call an OpenAI LLM asynchronously."
]
@@ -151,7 +151,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.10.9"
}
},
"nbformat": 4,

langchain/llms/anthropic.py

@@ -170,6 +170,38 @@ class Anthropic(LLM, BaseModel):
        )
        return response["completion"]

    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to Anthropic's completion endpoint asynchronously."""
        stop = self._get_anthropic_stop(stop)
        if self.streaming:
            stream_resp = await self.client.acompletion_stream(
                model=self.model,
                prompt=self._wrap_prompt(prompt),
                stop_sequences=stop,
                stream=True,
                **self._default_params,
            )
            current_completion = ""
            async for data in stream_resp:
                delta = data["completion"][len(current_completion) :]
                current_completion = data["completion"]
                if self.callback_manager.is_async:
                    await self.callback_manager.on_llm_new_token(
                        delta, verbose=self.verbose, **data
                    )
                else:
                    self.callback_manager.on_llm_new_token(
                        delta, verbose=self.verbose, **data
                    )
            return current_completion
        response = await self.client.acompletion(
            model=self.model,
            prompt=self._wrap_prompt(prompt),
            stop_sequences=stop,
            **self._default_params,
        )
        return response["completion"]

    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
        r"""Call Anthropic completion_stream and return the resulting generator.

poetry.lock (generated)

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
# This file is automatically @generated by Poetry and should not be changed by hand.
[[package]]
name = "absl-py"
@@ -287,17 +287,18 @@ types = ["mypy", "types-requests"]
[[package]]
name = "anthropic"
version = "0.2.3"
version = "0.2.4"
description = "Library for accessing the anthropic API"
category = "main"
optional = true
python-versions = ">=3.8"
files = [
{file = "anthropic-0.2.3-py3-none-any.whl", hash = "sha256:51cc9e3c5c0fc39b62af64b0607fd0da1622c7815fed89d0a52d80ebe0e60f3a"},
{file = "anthropic-0.2.3.tar.gz", hash = "sha256:3d4f8d21c54d23d476d5ef72510b50126108f9b0bdc45b9d5d2e2b34204d56ad"},
{file = "anthropic-0.2.4-py3-none-any.whl", hash = "sha256:2f955435bfdecc94e5432d72492c38fd22a55647026de1518461b707dd3ae808"},
{file = "anthropic-0.2.4.tar.gz", hash = "sha256:2041546470a9a2e897d6317627fdd7ea585245a3ff1ed1b054a47f63b4daa893"},
]
[package.dependencies]
aiohttp = "*"
httpx = "*"
requests = "*"
tokenizers = "*"
@@ -6868,7 +6869,7 @@ files = [
]
[package.dependencies]
greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and platform_machine == \"aarch64\" or python_version >= \"3\" and platform_machine == \"ppc64le\" or python_version >= \"3\" and platform_machine == \"x86_64\" or python_version >= \"3\" and platform_machine == \"amd64\" or python_version >= \"3\" and platform_machine == \"AMD64\" or python_version >= \"3\" and platform_machine == \"win32\" or python_version >= \"3\" and platform_machine == \"WIN32\""}
greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"}
[package.extras]
aiomysql = ["aiomysql", "greenlet (!=0.4.17)"]
@@ -8528,10 +8529,10 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
[extras]
all = ["aleph-alpha-client", "anthropic", "beautifulsoup4", "cohere", "deeplake", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "psycopg2-binary", "pypdf", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary"]
llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "fec488d52fc1a46ae34a643e5951e00251232d6d38c50c844556041df6067572"
content-hash = "b8dd776b9c9bfda2413bdc0c58eb9cc8975a0ec770a830af5568cc80e4a1925d"

pyproject.toml

@@ -34,7 +34,7 @@ pinecone-client = {version = "^2", optional = true}
weaviate-client = {version = "^3", optional = true}
google-api-python-client = {version = "2.70.0", optional = true}
wolframalpha = {version = "5.0.0", optional = true}
anthropic = {version = "^0.2.2", optional = true}
anthropic = {version = "^0.2.4", optional = true}
qdrant-client = {version = "^1.0.4", optional = true, python = ">=3.8.1,<3.12"}
dataclasses-json = "^0.5.7"
tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}

test_anthropic.py

@@ -1,9 +1,11 @@
"""Test Anthropic API wrapper."""
from typing import Generator

import pytest

from langchain.callbacks.base import CallbackManager
from langchain.llms.anthropic import Anthropic
from langchain.schema import LLMResult
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -37,3 +39,27 @@ def test_anthropic_streaming_callback() -> None:
    )
    llm("Write me a sentence with 100 words.")
    assert callback_handler.llm_streams > 1


@pytest.mark.asyncio
async def test_anthropic_async_generate() -> None:
    """Test async generate."""
    llm = Anthropic()
    output = await llm.agenerate(["How many toes do dogs have?"])
    assert isinstance(output, LLMResult)


@pytest.mark.asyncio
async def test_anthropic_async_streaming_callback() -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    llm = Anthropic(
        model="claude-v1",
        streaming=True,
        callback_manager=callback_manager,
        verbose=True,
    )
    result = await llm.agenerate(["How many toes do dogs have?"])
    assert callback_handler.llm_streams > 1
    assert isinstance(result, LLMResult)