add async support for anthropic (#2114)

Should not be merged before
https://github.com/anthropics/anthropic-sdk-python/pull/11 is released.
doc
Ankush Gola 1 year ago committed by GitHub
parent e2c26909f2
commit ccee1aedd2

@@ -9,7 +9,7 @@
     "\n",
     "LangChain provides async support for LLMs by leveraging the [asyncio](https://docs.python.org/3/library/asyncio.html) library.\n",
     "\n",
-    "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, only `OpenAI` and `PromptLayerOpenAI` are supported, but async support for other LLMs is on the roadmap.\n",
+    "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI` and `Anthropic` are supported, but async support for other LLMs is on the roadmap.\n",
     "\n",
     "You can use the `agenerate` method to call an OpenAI LLM asynchronously."
    ]
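
For reference, a minimal sketch of the concurrent-call pattern this doc describes, assuming the `anthropic` extra is installed and `ANTHROPIC_API_KEY` is set in the environment; the prompts are hypothetical and the `claude-v1` model name is taken from the tests further down:

```python
import asyncio

from langchain.llms.anthropic import Anthropic


async def main() -> None:
    # Hypothetical prompts; "claude-v1" is the model used in the tests below.
    llm = Anthropic(model="claude-v1")
    # The calls are network-bound, so they can overlap on one event loop.
    results = await asyncio.gather(
        llm.agenerate(["How many toes do dogs have?"]),
        llm.agenerate(["How many legs do spiders have?"]),
    )
    for result in results:
        print(result.generations[0][0].text)


asyncio.run(main())
```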
@@ -151,7 +151,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.1"
+  "version": "3.10.9"
  }
 },
 "nbformat": 4,

@@ -170,6 +170,38 @@ class Anthropic(LLM, BaseModel):
         )
         return response["completion"]

+    async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+        """Call out to Anthropic's completion endpoint asynchronously."""
+        stop = self._get_anthropic_stop(stop)
+        if self.streaming:
+            stream_resp = await self.client.acompletion_stream(
+                model=self.model,
+                prompt=self._wrap_prompt(prompt),
+                stop_sequences=stop,
+                stream=True,
+                **self._default_params,
+            )
+            current_completion = ""
+            async for data in stream_resp:
+                delta = data["completion"][len(current_completion) :]
+                current_completion = data["completion"]
+                if self.callback_manager.is_async:
+                    await self.callback_manager.on_llm_new_token(
+                        delta, verbose=self.verbose, **data
+                    )
+                else:
+                    self.callback_manager.on_llm_new_token(
+                        delta, verbose=self.verbose, **data
+                    )
+            return current_completion
+        response = await self.client.acompletion(
+            model=self.model,
+            prompt=self._wrap_prompt(prompt),
+            stop_sequences=stop,
+            **self._default_params,
+        )
+        return response["completion"]
+
     def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
         r"""Call Anthropic completion_stream and return the resulting generator.

poetry.lock

@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry and should not be changed by hand.

 [[package]]
 name = "absl-py"

@@ -287,17 +287,18 @@ types = ["mypy", "types-requests"]

 [[package]]
 name = "anthropic"
-version = "0.2.3"
+version = "0.2.4"
 description = "Library for accessing the anthropic API"
 category = "main"
 optional = true
 python-versions = ">=3.8"
 files = [
-    {file = "anthropic-0.2.3-py3-none-any.whl", hash = "sha256:51cc9e3c5c0fc39b62af64b0607fd0da1622c7815fed89d0a52d80ebe0e60f3a"},
-    {file = "anthropic-0.2.3.tar.gz", hash = "sha256:3d4f8d21c54d23d476d5ef72510b50126108f9b0bdc45b9d5d2e2b34204d56ad"},
+    {file = "anthropic-0.2.4-py3-none-any.whl", hash = "sha256:2f955435bfdecc94e5432d72492c38fd22a55647026de1518461b707dd3ae808"},
+    {file = "anthropic-0.2.4.tar.gz", hash = "sha256:2041546470a9a2e897d6317627fdd7ea585245a3ff1ed1b054a47f63b4daa893"},
 ]

 [package.dependencies]
+aiohttp = "*"
 httpx = "*"
 requests = "*"
 tokenizers = "*"

@@ -6868,7 +6869,7 @@ files = [
 ]

 [package.dependencies]
-greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and platform_machine == \"aarch64\" or python_version >= \"3\" and platform_machine == \"ppc64le\" or python_version >= \"3\" and platform_machine == \"x86_64\" or python_version >= \"3\" and platform_machine == \"amd64\" or python_version >= \"3\" and platform_machine == \"AMD64\" or python_version >= \"3\" and platform_machine == \"win32\" or python_version >= \"3\" and platform_machine == \"WIN32\""}
+greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"}

 [package.extras]
 aiomysql = ["aiomysql", "greenlet (!=0.4.17)"]

@@ -8528,10 +8529,10 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
 testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]

 [extras]
-all = ["aleph-alpha-client", "anthropic", "beautifulsoup4", "cohere", "deeplake", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "psycopg2-binary", "pypdf", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
+all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary"]
-llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"]
+llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]

 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "fec488d52fc1a46ae34a643e5951e00251232d6d38c50c844556041df6067572"
+content-hash = "b8dd776b9c9bfda2413bdc0c58eb9cc8975a0ec770a830af5568cc80e4a1925d"
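
Note the new `aiohttp = "*"` entry under anthropic's dependencies: the 0.2.4 SDK presumably uses aiohttp for the `acompletion`/`acompletion_stream` client that `_acall` relies on.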

@@ -34,7 +34,7 @@ pinecone-client = {version = "^2", optional = true}
 weaviate-client = {version = "^3", optional = true}
 google-api-python-client = {version = "2.70.0", optional = true}
 wolframalpha = {version = "5.0.0", optional = true}
-anthropic = {version = "^0.2.2", optional = true}
+anthropic = {version = "^0.2.4", optional = true}
 qdrant-client = {version = "^1.0.4", optional = true, python = ">=3.8.1,<3.12"}
 dataclasses-json = "^0.5.7"
 tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}

@@ -1,9 +1,11 @@
 """Test Anthropic API wrapper."""
 from typing import Generator

+import pytest
+
 from langchain.callbacks.base import CallbackManager
 from langchain.llms.anthropic import Anthropic
+from langchain.schema import LLMResult
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -37,3 +39,27 @@ def test_anthropic_streaming_callback() -> None:
     )
     llm("Write me a sentence with 100 words.")
     assert callback_handler.llm_streams > 1
+
+
+@pytest.mark.asyncio
+async def test_anthropic_async_generate() -> None:
+    """Test async generate."""
+    llm = Anthropic()
+    output = await llm.agenerate(["How many toes do dogs have?"])
+    assert isinstance(output, LLMResult)
+
+
+@pytest.mark.asyncio
+async def test_anthropic_async_streaming_callback() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    llm = Anthropic(
+        model="claude-v1",
+        streaming=True,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    result = await llm.agenerate(["How many toes do dogs have?"])
+    assert callback_handler.llm_streams > 1
+    assert isinstance(result, LLMResult)
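
Both new tests are decorated with `@pytest.mark.asyncio`, which comes from the pytest-asyncio plugin; running them presumably requires that plugin plus a valid `ANTHROPIC_API_KEY`, since these are integration tests against the live API.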
