From ccee1aedd2c846e8c181e00befbcc004c8df405e Mon Sep 17 00:00:00 2001 From: Ankush Gola <9536492+agola11@users.noreply.github.com> Date: Tue, 28 Mar 2023 22:49:14 -0400 Subject: [PATCH] add async support for anthropic (#2114) should not be merged in before https://github.com/anthropics/anthropic-sdk-python/pull/11 gets released --- .../models/llms/examples/async_llm.ipynb | 4 +-- langchain/llms/anthropic.py | 32 +++++++++++++++++++ poetry.lock | 17 +++++----- pyproject.toml | 2 +- .../integration_tests/llms/test_anthropic.py | 28 +++++++++++++++- 5 files changed, 71 insertions(+), 12 deletions(-) diff --git a/docs/modules/models/llms/examples/async_llm.ipynb b/docs/modules/models/llms/examples/async_llm.ipynb index f3ccfc60..dad68d5a 100644 --- a/docs/modules/models/llms/examples/async_llm.ipynb +++ b/docs/modules/models/llms/examples/async_llm.ipynb @@ -9,7 +9,7 @@ "\n", "LangChain provides async support for LLMs by leveraging the [asyncio](https://docs.python.org/3/library/asyncio.html) library.\n", "\n", - "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, only `OpenAI` and `PromptLayerOpenAI` are supported, but async support for other LLMs is on the roadmap.\n", + "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI` and `Anthropic` are supported, but async support for other LLMs is on the roadmap.\n", "\n", "You can use the `agenerate` method to call an OpenAI LLM asynchronously." ] @@ -151,7 +151,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.10.9" } }, "nbformat": 4, diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py index 9927af91..d877da52 100644 --- a/langchain/llms/anthropic.py +++ b/langchain/llms/anthropic.py @@ -170,6 +170,38 @@ class Anthropic(LLM, BaseModel): ) return response["completion"] + async def _acall(self, prompt: str, stop: Optional[List[str]] = None) -> str: + """Call out to Anthropic's completion endpoint asynchronously.""" + stop = self._get_anthropic_stop(stop) + if self.streaming: + stream_resp = await self.client.acompletion_stream( + model=self.model, + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + stream=True, + **self._default_params, + ) + current_completion = "" + async for data in stream_resp: + delta = data["completion"][len(current_completion) :] + current_completion = data["completion"] + if self.callback_manager.is_async: + await self.callback_manager.on_llm_new_token( + delta, verbose=self.verbose, **data + ) + else: + self.callback_manager.on_llm_new_token( + delta, verbose=self.verbose, **data + ) + return current_completion + response = await self.client.acompletion( + model=self.model, + prompt=self._wrap_prompt(prompt), + stop_sequences=stop, + **self._default_params, + ) + return response["completion"] + def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator: r"""Call Anthropic completion_stream and return the resulting generator. diff --git a/poetry.lock b/poetry.lock index eddd25e3..6bc3a0aa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. +# This file is automatically @generated by Poetry and should not be changed by hand. [[package]] name = "absl-py" @@ -287,17 +287,18 @@ types = ["mypy", "types-requests"] [[package]] name = "anthropic" -version = "0.2.3" +version = "0.2.4" description = "Library for accessing the anthropic API" category = "main" optional = true python-versions = ">=3.8" files = [ - {file = "anthropic-0.2.3-py3-none-any.whl", hash = "sha256:51cc9e3c5c0fc39b62af64b0607fd0da1622c7815fed89d0a52d80ebe0e60f3a"}, - {file = "anthropic-0.2.3.tar.gz", hash = "sha256:3d4f8d21c54d23d476d5ef72510b50126108f9b0bdc45b9d5d2e2b34204d56ad"}, + {file = "anthropic-0.2.4-py3-none-any.whl", hash = "sha256:2f955435bfdecc94e5432d72492c38fd22a55647026de1518461b707dd3ae808"}, + {file = "anthropic-0.2.4.tar.gz", hash = "sha256:2041546470a9a2e897d6317627fdd7ea585245a3ff1ed1b054a47f63b4daa893"}, ] [package.dependencies] +aiohttp = "*" httpx = "*" requests = "*" tokenizers = "*" @@ -6868,7 +6869,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and platform_machine == \"aarch64\" or python_version >= \"3\" and platform_machine == \"ppc64le\" or python_version >= \"3\" and platform_machine == \"x86_64\" or python_version >= \"3\" and platform_machine == \"amd64\" or python_version >= \"3\" and platform_machine == \"AMD64\" or python_version >= \"3\" and platform_machine == \"win32\" or python_version >= \"3\" and platform_machine == \"WIN32\""} +greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} [package.extras] aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] @@ -8528,10 +8529,10 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -all = ["aleph-alpha-client", "anthropic", "beautifulsoup4", "cohere", "deeplake", "elasticsearch", "faiss-cpu", "google-api-python-client", "google-search-results", "huggingface_hub", "jina", "jinja2", "manifest-ml", "networkx", "nlpcloud", "nltk", "nomic", "openai", "opensearch-py", "pgvector", "pinecone-client", "psycopg2-binary", "pypdf", "qdrant-client", "redis", "sentence-transformers", "spacy", "tensorflow-text", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"] -llms = ["anthropic", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "torch", "transformers"] +all = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "weaviate-client", "redis", "google-api-python-client", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary"] +llms = ["anthropic", "cohere", "openai", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "fec488d52fc1a46ae34a643e5951e00251232d6d38c50c844556041df6067572" +content-hash = "b8dd776b9c9bfda2413bdc0c58eb9cc8975a0ec770a830af5568cc80e4a1925d" diff --git a/pyproject.toml b/pyproject.toml index 44577f9a..1374f5ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ pinecone-client = {version = "^2", optional = true} weaviate-client = {version = "^3", optional = true} google-api-python-client = {version = "2.70.0", optional = true} wolframalpha = {version = "5.0.0", optional = true} -anthropic = {version = "^0.2.2", optional = true} +anthropic = {version = "^0.2.4", optional = true} qdrant-client = {version = "^1.0.4", optional = true, python = ">=3.8.1,<3.12"} dataclasses-json = "^0.5.7" tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"} diff --git a/tests/integration_tests/llms/test_anthropic.py b/tests/integration_tests/llms/test_anthropic.py index 325a098d..eaa509bf 100644 --- a/tests/integration_tests/llms/test_anthropic.py +++ b/tests/integration_tests/llms/test_anthropic.py @@ -1,9 +1,11 @@ """Test Anthropic API wrapper.""" - from typing import Generator +import pytest + from langchain.callbacks.base import CallbackManager from langchain.llms.anthropic import Anthropic +from langchain.schema import LLMResult from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -37,3 +39,27 @@ def test_anthropic_streaming_callback() -> None: ) llm("Write me a sentence with 100 words.") assert callback_handler.llm_streams > 1 + + +@pytest.mark.asyncio +async def test_anthropic_async_generate() -> None: + """Test async generate.""" + llm = Anthropic() + output = await llm.agenerate(["How many toes do dogs have?"]) + assert isinstance(output, LLMResult) + + +@pytest.mark.asyncio +async def test_anthropic_async_streaming_callback() -> None: + """Test that streaming correctly invokes on_llm_new_token callback.""" + callback_handler = FakeCallbackHandler() + callback_manager = CallbackManager([callback_handler]) + llm = Anthropic( + model="claude-v1", + streaming=True, + callback_manager=callback_manager, + verbose=True, + ) + result = await llm.agenerate(["How many toes do dogs have?"]) + assert callback_handler.llm_streams > 1 + assert isinstance(result, LLMResult)