From c5988c1d4b17fcc1a89ee0bfed6e1c44edb92efe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kacper=20=C5=81ukawski?=
Date: Thu, 27 Jul 2023 00:51:18 +0200
Subject: [PATCH] Implement async support for Cohere (#8237)

This PR introduces async API support for Cohere, covering both the LLM and
the embeddings. It requires updating the `cohere` package to `^4`.

Tagging @hwchase17, @baskaryan, @agola11

---------

Co-authored-by: Bagatur
---
 .../model_io/models/llms/async_llm.ipynb      |  6 +-
 libs/langchain/langchain/embeddings/cohere.py | 34 +++++++-
 libs/langchain/langchain/llms/cohere.py       | 81 +++++++++++++++----
 libs/langchain/poetry.lock                    | 66 +++++++++++++--
 libs/langchain/pyproject.toml                 |  2 +-
 5 files changed, 159 insertions(+), 30 deletions(-)

diff --git a/docs/extras/modules/model_io/models/llms/async_llm.ipynb b/docs/extras/modules/model_io/models/llms/async_llm.ipynb
index 56972c9481..585fe26cd1 100644
--- a/docs/extras/modules/model_io/models/llms/async_llm.ipynb
+++ b/docs/extras/modules/model_io/models/llms/async_llm.ipynb
@@ -9,7 +9,7 @@
     "\n",
     "LangChain provides async support for LLMs by leveraging the [asyncio](https://docs.python.org/3/library/asyncio.html) library.\n",
     "\n",
-    "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI` and `Anthropic` are supported, but async support for other LLMs is on the roadmap.\n",
+    "Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI`, `Anthropic` and `Cohere` are supported, but async support for other LLMs is on the roadmap.\n",
     "\n",
     "You can use the `agenerate` method to call an OpenAI LLM asynchronously."
    ]
   },
@@ -56,7 +56,7 @@
     "\n",
     "\n",
     "I'm doing well, thank you. How about you?\n",
-    "\u001b[1mConcurrent executed in 1.39 seconds.\u001b[0m\n",
+    "\u001B[1mConcurrent executed in 1.39 seconds.\u001B[0m\n",
     "\n",
     "\n",
     "I'm doing well, thank you. How about you?\n",
@@ -86,7 +86,7 @@
     "\n",
     "\n",
     "I'm doing well, thanks for asking. How about you?\n",
-    "\u001b[1mSerial executed in 5.77 seconds.\u001b[0m\n"
+    "\u001B[1mSerial executed in 5.77 seconds.\u001B[0m\n"
    ]
   }
  ],
diff --git a/libs/langchain/langchain/embeddings/cohere.py b/libs/langchain/langchain/embeddings/cohere.py
index ba4b472635..28a3b6cc2a 100644
--- a/libs/langchain/langchain/embeddings/cohere.py
+++ b/libs/langchain/langchain/embeddings/cohere.py
@@ -24,6 +24,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
 
     client: Any  #: :meta private:
     """Cohere client."""
+    async_client: Any  #: :meta private:
+    """Cohere async client."""
     model: str = "embed-english-v2.0"
     """Model name to use."""
 
@@ -47,6 +49,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
             import cohere
 
             values["client"] = cohere.Client(cohere_api_key)
+            values["async_client"] = cohere.AsyncClient(cohere_api_key)
         except ImportError:
             raise ValueError(
                 "Could not import cohere python package. "
@@ -68,6 +71,20 @@ class CohereEmbeddings(BaseModel, Embeddings):
         ).embeddings
         return [list(map(float, e)) for e in embeddings]
 
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Async call out to Cohere's embedding endpoint.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        embeddings = await self.async_client.embed(
+            model=self.model, texts=texts, truncate=self.truncate
+        )
+        return [list(map(float, e)) for e in embeddings.embeddings]
+
     def embed_query(self, text: str) -> List[float]:
         """Call out to Cohere's embedding endpoint.
 
@@ -77,7 +94,16 @@ class CohereEmbeddings(BaseModel, Embeddings):
         Args:
             text: The text to embed.
 
         Returns:
             Embeddings for the text.
         """
-        embedding = self.client.embed(
-            model=self.model, texts=[text], truncate=self.truncate
-        ).embeddings[0]
-        return list(map(float, embedding))
+        return self.embed_documents([text])[0]
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Async call out to Cohere's embedding endpoint.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        embeddings = await self.aembed_documents([text])
+        return embeddings[0]
diff --git a/libs/langchain/langchain/llms/cohere.py b/libs/langchain/langchain/llms/cohere.py
index 2e37697990..bad5a15651 100644
--- a/libs/langchain/langchain/llms/cohere.py
+++ b/libs/langchain/langchain/llms/cohere.py
@@ -12,7 +12,10 @@ from tenacity import (
     wait_exponential,
 )
 
-from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain.llms.base import LLM
 from langchain.llms.utils import enforce_stop_tokens
 from langchain.utils import get_from_dict_or_env
@@ -47,6 +50,17 @@ def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
     return _completion_with_retry(**kwargs)
 
 
+def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
+    """Use tenacity to retry the completion call."""
+    retry_decorator = _create_retry_decorator(llm)
+
+    @retry_decorator
+    async def _completion_with_retry(**kwargs: Any) -> Any:
+        return await llm.async_client.generate(**kwargs)
+
+    return _completion_with_retry(**kwargs)
+
+
 class Cohere(LLM):
     """Cohere large language models.
 
@@ -62,6 +76,7 @@ class Cohere(LLM):
     """
 
     client: Any  #: :meta private:
+    async_client: Any  #: :meta private:
 
     model: Optional[str] = None
     """Model name to use."""
@@ -109,6 +124,7 @@ class Cohere(LLM):
             import cohere
 
             values["client"] = cohere.Client(cohere_api_key)
+            values["async_client"] = cohere.AsyncClient(cohere_api_key)
         except ImportError:
             raise ImportError(
                 "Could not import cohere python package. "
@@ -139,6 +155,24 @@ class Cohere(LLM):
         """Return type of llm."""
         return "cohere"
 
+    def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
+        params = self._default_params
+        if self.stop is not None and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+        elif self.stop is not None:
+            params["stop_sequences"] = self.stop
+        else:
+            params["stop_sequences"] = stop
+        return {**params, **kwargs}
+
+    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
+        text = response.generations[0].text
+        # If stop tokens are provided, Cohere's endpoint returns them.
+        # In order to make this consistent with other endpoints, we strip them.
+        if stop:
+            text = enforce_stop_tokens(text, stop)
+        return text
+
     def _call(
         self,
         prompt: str,
@@ -160,20 +194,37 @@ class Cohere(LLM):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
         """Call out to Cohere's generate endpoint.
 
         Args:
             prompt: The prompt to pass into the model.
             stop: Optional list of stop words to use when generating.
 
         Returns:
             The string generated by the model.
 
         Example:
             .. code-block:: python
 
                 response = cohere("Tell me a joke.")
         """
-        params = self._default_params
-        if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
-        elif self.stop is not None:
-            params["stop_sequences"] = self.stop
-        else:
-            params["stop_sequences"] = stop
-        params = {**params, **kwargs}
+        params = self._invocation_params(stop, **kwargs)
         response = completion_with_retry(
             self, model=self.model, prompt=prompt, **params
         )
-        text = response.generations[0].text
-        # If stop tokens are provided, Cohere's endpoint returns them.
-        # In order to make this consistent with other endpoints, we strip them.
-        if stop is not None or self.stop is not None:
-            text = enforce_stop_tokens(text, params["stop_sequences"])
-        return text
+        _stop = params.get("stop_sequences")
+        return self._process_response(response, _stop)
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Async call out to Cohere's generate endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = await cohere("Tell me a joke.")
+        """
+        params = self._invocation_params(stop, **kwargs)
+        response = await acompletion_with_retry(
+            self, model=self.model, prompt=prompt, **params
+        )
+        _stop = params.get("stop_sequences")
+        return self._process_response(response, _stop)
diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock
index 88615250fb..cd3d220df5 100644
--- a/libs/langchain/poetry.lock
+++ b/libs/langchain/poetry.lock
@@ -1863,18 +1863,23 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
 
 [[package]]
 name = "cohere"
-version = "3.10.0"
-description = "A Python library for the Cohere API"
+version = "4.18.0"
+description = ""
 category = "main"
 optional = true
-python-versions = ">=3.6"
+python-versions = ">=3.7,<4.0"
 files = [
-    {file = "cohere-3.10.0.tar.gz", hash = "sha256:8c06a87a47aa9521051eeba130ce391d84ab578148c4ea5b62f6dcc41bd3a274"},
+    {file = "cohere-4.18.0-py3-none-any.whl", hash = "sha256:26b5be3f93c0046be7fd89b2e724190e10f9fceac8bcf8f22581368a1f3af2e4"},
+    {file = "cohere-4.18.0.tar.gz", hash = "sha256:ed3d5703384412312fd827e669364b2f0eb3678a1206987cb3e1d98b88409c31"},
 ]
 
 [package.dependencies]
-requests = "*"
-urllib3 = ">=1.26,<2.0"
+aiohttp = ">=3.0,<4.0"
+backoff = ">=2.0,<3.0"
+fastavro = "1.7.4"
+importlib_metadata = ">=6.0,<7.0"
+requests = ">=2.25.0,<3.0.0"
+urllib3 = ">=1.26,<3"
 
 [[package]]
 name = "colorama"
@@ -2689,6 +2694,53 @@ dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>
 doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"]
 test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]
 
+[[package]]
+name = "fastavro"
+version = "1.7.4"
+description = "Fast read/write of AVRO files"
+category = "main"
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "fastavro-1.7.4-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7568e621b94e061974b2a96d70670d09910e0a71482dd8610b153c07bd768497"},
+    {file = "fastavro-1.7.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4ec994faf64b743647f0027fcc56b01dc15d46c0e48fa15828277cb02dbdcd6"},
+    {file = "fastavro-1.7.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:727fdc1ddd12fcc6addab0b6df12ef999a6babe4b753db891f78aa2ee33edc77"},
+    {file = "fastavro-1.7.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b2f0cb3f7795fcb0042e0bbbe51204c28338a455986d68409b26dcbde64dd69a"},
+    {file = "fastavro-1.7.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bb0a8b5016a99be4b8ce3550889a1bd968c0fb3f521bcfbae24210c6342aee0c"},
+    {file = "fastavro-1.7.4-cp310-cp310-win_amd64.whl", hash = "sha256:1d2040b2bf3dc1a75170ea44d1e7e09f84fb77f40ef2e6c6b9f2eaf710557083"},
+    {file = "fastavro-1.7.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5542423f46bb7fc9699c467cbf151c2713aa6976ef14f4f5ec3532d80d0bb616"},
+    {file = "fastavro-1.7.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec396e6ab6b272708c8b9a0142df01fff4c7a1f168050f292ab92fdaee0b0257"},
+    {file = "fastavro-1.7.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b10d68c03371b79f461feca1c6c7e9d3f6aea2e9c7472b25cd749c57562aa1"},
+    {file = "fastavro-1.7.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f94d5168ec72f3cfcf2181df1c46ad240dc1fcf361717447d2c5237121b9df55"},
+    {file = "fastavro-1.7.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bad3dc279ed4ce747989259035cb3607f189ef7aff40339202f9321ca7f83d0b"},
+    {file = "fastavro-1.7.4-cp311-cp311-win_amd64.whl", hash = "sha256:8480ff444d9c7abd0bf121dd68656bd2115caca8ed28e71936eff348fde706e0"},
+    {file = "fastavro-1.7.4-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:bd3d669f4ec6915c88bb80b7c14e01d2c3ceb93a61de5dcf33ff13972bba505e"},
+    {file = "fastavro-1.7.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a312b128536b81bdb79f27076f513b998abe7d13ee6fe52e99bc01f7ad9b06a"},
+    {file = "fastavro-1.7.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:487054d1419f1bfa41e7f19c718cbdbbb254319d3fd5b9ac411054d6432b9d40"},
+    {file = "fastavro-1.7.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d2897fe7d1d5b27dcd33c43d68480de36e55a0e651d7731004a36162cd3eed9e"},
+    {file = "fastavro-1.7.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6d318b49fd648a1fd93394411fe23761b486ac65dadea7c52dbeb0d0bef30221"},
+    {file = "fastavro-1.7.4-cp37-cp37m-win_amd64.whl", hash = "sha256:a117c3b122a8110c6ab99b3e66736790b4be19ceefb1edf0e732c33b3dc411c8"},
+    {file = "fastavro-1.7.4-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:0cca15e1a1f829e40524004342e425acfb594cefbd3388b0a5d13542750623ac"},
+    {file = "fastavro-1.7.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9211ec7a18a46a2aee01a2a979fd79f05f36b11fdb1bc469c9d9fd8cec32579"},
+    {file = "fastavro-1.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f16bde6b5fb51e15233bfcee0378f48d4221201ba45e497a8063f6d216b7aad7"},
+    {file = "fastavro-1.7.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aeca55c905ff4c667f2158564654a778918988811ae3eb28592767edcf5f5c4a"},
+    {file = "fastavro-1.7.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b244f3abc024fc043d6637284ba2ffee5a1291c08a0f361ea1af4d829f66f303"},
+    {file = "fastavro-1.7.4-cp38-cp38-win_amd64.whl", hash = "sha256:b64e394c87cb99d0681727e1ae5d3633906a72abeab5ea0c692394aeb5a56607"},
+    {file = "fastavro-1.7.4-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:8c8115bdb1c862354d9abd0ea23eab85793bbff139087f2607bd4b83e8ae07ab"},
+    {file = "fastavro-1.7.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b27dd08f2338a478185c6ba23308002f334642ce83a6aeaf8308271efef88062"},
+    {file = "fastavro-1.7.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f087c246afab8bac08d86ef21be87cbf4f3779348fb960c081863fc3d570412c"},
+    {file = "fastavro-1.7.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b4077e17a2bab37af96e5ca52e61b6f2b85e4577e7a2903f6814642eb6a834f7"},
+    {file = "fastavro-1.7.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:776511cecf2ea9da4edd0de5015c1562cd9063683cf94f79bc9e20bab8f06923"},
+    {file = "fastavro-1.7.4-cp39-cp39-win_amd64.whl", hash = "sha256:a7ea5565fe2c145e074ce9ba75fafd5479a86b34a8dbd00dd1835cf192290e14"},
+    {file = "fastavro-1.7.4.tar.gz", hash = "sha256:6450f47ac4db95ec3a9e6434fec1f8a3c4c8c941de16205832ca8c67dd23d0d2"},
+]
+
+[package.extras]
+codecs = ["lz4", "python-snappy", "zstandard"]
+lz4 = ["lz4"]
+snappy = ["python-snappy"]
+zstandard = ["zstandard"]
+
 [[package]]
 name = "fastjsonschema"
 version = "2.17.1"
@@ -12500,4 +12552,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "7a8847de4dd88e71b423ff148823523220a5649340178e8ab1f7bafb03a290d2"
+content-hash = "4f5d91f450555bb3a039c3aef4a7996d1322f25608ec17a7b0c1ad92813d6a63"
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index 513fc0e2ff..fa6c220492 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -47,7 +47,7 @@ qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"}
 dataclasses-json = "^0.5.7"
 tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}
 tenacity = "^8.1.0"
-cohere = {version = "^3", optional = true}
+cohere = {version = "^4", optional = true}
 openai = {version = "^0", optional = true}
 nlpcloud = {version = "^1", optional = true}
 nomic = {version = "^1.0.43", optional = true}