Implement async support for Cohere (#8237)

This PR introduces async API support for Cohere, covering both the LLM and the embeddings. It requires updating the `cohere` package to `^4`.
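
For reference, a minimal sketch of the new async surface (assumes `COHERE_API_KEY` is exported; the prompts are placeholders):

```python
import asyncio

from langchain.embeddings import CohereEmbeddings
from langchain.llms import Cohere


async def main() -> None:
    llm = Cohere()
    embeddings = CohereEmbeddings()
    # Both the LLM and the embeddings now expose native async entry points.
    result = await llm.agenerate(["Tell me a joke."])
    vector = await embeddings.aembed_query("Hello, world!")
    print(result.generations[0][0].text, len(vector))


asyncio.run(main())
```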

Tagging @hwchase17, @baskaryan, @agola11

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Kacper Łukawski committed 1 year ago via GitHub
parent bf1357f584
commit c5988c1d4b

@@ -9,7 +9,7 @@
"\n",
"LangChain provides async support for LLMs by leveraging the [asyncio](https://docs.python.org/3/library/asyncio.html) library.\n",
"\n",
"Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI` and `Anthropic` are supported, but async support for other LLMs is on the roadmap.\n",
"Async support is particularly useful for calling multiple LLMs concurrently, as these calls are network-bound. Currently, `OpenAI`, `PromptLayerOpenAI`, `ChatOpenAI`, `Anthropic` and `Cohere` are supported, but async support for other LLMs is on the roadmap.\n",
"\n",
"You can use the `agenerate` method to call an OpenAI LLM asynchronously."
]
@@ -56,7 +56,7 @@
"\n",
"\n",
"I'm doing well, thank you. How about you?\n",
"\u001b[1mConcurrent executed in 1.39 seconds.\u001b[0m\n",
"\u001B[1mConcurrent executed in 1.39 seconds.\u001B[0m\n",
"\n",
"\n",
"I'm doing well, thank you. How about you?\n",
@@ -86,7 +86,7 @@
"\n",
"\n",
"I'm doing well, thanks for asking. How about you?\n",
"\u001b[1mSerial executed in 5.77 seconds.\u001b[0m\n"
"\u001B[1mSerial executed in 5.77 seconds.\u001B[0m\n"
]
}
],
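
The timed output above comes from the notebook's OpenAI example. A minimal sketch of the same concurrent fan-out using the newly supported `Cohere` LLM (the ten-call loop mirrors the notebook; absolute timings will differ):

```python
import asyncio
import time

from langchain.llms import Cohere


async def async_generate(llm: Cohere) -> None:
    resp = await llm.agenerate(["Hello, how are you?"])
    print(resp.generations[0][0].text)


async def generate_concurrently() -> None:
    llm = Cohere(temperature=0.9)
    # Ten network-bound calls run concurrently instead of back to back.
    await asyncio.gather(*[async_generate(llm) for _ in range(10)])


s = time.perf_counter()
asyncio.run(generate_concurrently())
print(f"Concurrent executed in {time.perf_counter() - s:0.2f} seconds.")
```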

@@ -24,6 +24,8 @@ class CohereEmbeddings(BaseModel, Embeddings):
    client: Any  #: :meta private:
    """Cohere client."""
    async_client: Any  #: :meta private:
    """Cohere async client."""
    model: str = "embed-english-v2.0"
    """Model name to use."""
@@ -47,6 +49,7 @@ class CohereEmbeddings(BaseModel, Embeddings):
            import cohere

            values["client"] = cohere.Client(cohere_api_key)
            values["async_client"] = cohere.AsyncClient(cohere_api_key)
        except ImportError:
            raise ValueError(
                "Could not import cohere python package. "
@@ -68,6 +71,20 @@ class CohereEmbeddings(BaseModel, Embeddings):
        ).embeddings
        return [list(map(float, e)) for e in embeddings]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async call out to Cohere's embedding endpoint.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = await self.async_client.embed(
            model=self.model, texts=texts, truncate=self.truncate
        )
        return [list(map(float, e)) for e in embeddings.embeddings]

    def embed_query(self, text: str) -> List[float]:
        """Call out to Cohere's embedding endpoint.
@@ -77,7 +94,16 @@ class CohereEmbeddings(BaseModel, Embeddings):
        Returns:
            Embeddings for the text.
        """
        embedding = self.client.embed(
            model=self.model, texts=[text], truncate=self.truncate
        ).embeddings[0]
        return list(map(float, embedding))
        return self.embed_documents([text])[0]

    async def aembed_query(self, text: str) -> List[float]:
        """Async call out to Cohere's embedding endpoint.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embeddings = await self.aembed_documents([text])
        return embeddings[0]
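
A short usage sketch for the two new coroutines (sample texts are invented; assumes `COHERE_API_KEY` is set):

```python
import asyncio

from langchain.embeddings import CohereEmbeddings


async def main() -> None:
    embeddings = CohereEmbeddings(model="embed-english-v2.0")
    docs = [
        "LangChain now calls Cohere asynchronously.",
        "Embedding requests are network-bound.",
    ]
    # The batch embedding and the query embedding run as concurrent requests.
    doc_vectors, query_vector = await asyncio.gather(
        embeddings.aembed_documents(docs),
        embeddings.aembed_query("Which calls are async?"),
    )
    print(len(doc_vectors), len(query_vector))


asyncio.run(main())
```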

@@ -12,7 +12,10 @@ from tenacity import (
    wait_exponential,
)

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
@@ -47,6 +50,17 @@ def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
    return _completion_with_retry(**kwargs)


def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
    """Use tenacity to retry the async completion call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    async def _completion_with_retry(**kwargs: Any) -> Any:
        return await llm.async_client.generate(**kwargs)

    return _completion_with_retry(**kwargs)


class Cohere(LLM):
    """Cohere large language models.
@@ -62,6 +76,7 @@ class Cohere(LLM):
    """

    client: Any  #: :meta private:
    async_client: Any  #: :meta private:
    model: Optional[str] = None
    """Model name to use."""
@@ -109,6 +124,7 @@ class Cohere(LLM):
            import cohere

            values["client"] = cohere.Client(cohere_api_key)
            values["async_client"] = cohere.AsyncClient(cohere_api_key)
        except ImportError:
            raise ImportError(
                "Could not import cohere python package. "
@@ -139,6 +155,24 @@ class Cohere(LLM):
        """Return type of llm."""
        return "cohere"

    def _invocation_params(self, stop: Optional[List[str]], **kwargs: Any) -> dict:
        params = self._default_params
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            params["stop_sequences"] = self.stop
        else:
            params["stop_sequences"] = stop
        return {**params, **kwargs}

    def _process_response(self, response: Any, stop: Optional[List[str]]) -> str:
        text = response.generations[0].text
        # If stop tokens are provided, Cohere's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        if stop:
            text = enforce_stop_tokens(text, stop)
        return text

    def _call(
        self,
        prompt: str,
@@ -160,20 +194,37 @@ class Cohere(LLM):
                response = cohere("Tell me a joke.")
        """
        params = self._default_params
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            params["stop_sequences"] = self.stop
        else:
            params["stop_sequences"] = stop
        params = {**params, **kwargs}
        params = self._invocation_params(stop, **kwargs)
        response = completion_with_retry(
            self, model=self.model, prompt=prompt, **params
        )
        text = response.generations[0].text
        # If stop tokens are provided, Cohere's endpoint returns them.
        # In order to make this consistent with other endpoints, we strip them.
        if stop is not None or self.stop is not None:
            text = enforce_stop_tokens(text, params["stop_sequences"])
        return text
        _stop = params.get("stop_sequences")
        return self._process_response(response, _stop)

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Async call out to Cohere's generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = await cohere("Tell me a joke.")
        """
        params = self._invocation_params(stop, **kwargs)
        response = await acompletion_with_retry(
            self, model=self.model, prompt=prompt, **params
        )
        _stop = params.get("stop_sequences")
        return self._process_response(response, _stop)
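
`acompletion_with_retry` reuses the module-level `_create_retry_decorator`, which is outside this diff. A sketch of what such a tenacity-based factory typically looks like; the wait bounds, the `max_retries` attribute, and the `cohere.error.CohereError` exception type are assumptions, not taken from this PR:

```python
import logging
from typing import Any, Callable

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


def _create_retry_decorator(llm: "Cohere") -> Callable[[Any], Any]:
    import cohere

    # Assumed policy: exponential backoff between 4 and 10 seconds, giving up
    # after llm.max_retries attempts; only Cohere API errors trigger a retry.
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(cohere.error.CohereError),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )
```

Because tenacity's `retry` dispatches to `AsyncRetrying` when it wraps a coroutine function, a single decorator factory can serve both the sync `completion_with_retry` and the new async variant.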

@@ -1863,18 +1863,23 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
[[package]]
name = "cohere"
version = "3.10.0"
description = "A Python library for the Cohere API"
version = "4.18.0"
description = ""
category = "main"
optional = true
python-versions = ">=3.6"
python-versions = ">=3.7,<4.0"
files = [
{file = "cohere-3.10.0.tar.gz", hash = "sha256:8c06a87a47aa9521051eeba130ce391d84ab578148c4ea5b62f6dcc41bd3a274"},
{file = "cohere-4.18.0-py3-none-any.whl", hash = "sha256:26b5be3f93c0046be7fd89b2e724190e10f9fceac8bcf8f22581368a1f3af2e4"},
{file = "cohere-4.18.0.tar.gz", hash = "sha256:ed3d5703384412312fd827e669364b2f0eb3678a1206987cb3e1d98b88409c31"},
]
[package.dependencies]
requests = "*"
urllib3 = ">=1.26,<2.0"
aiohttp = ">=3.0,<4.0"
backoff = ">=2.0,<3.0"
fastavro = "1.7.4"
importlib_metadata = ">=6.0,<7.0"
requests = ">=2.25.0,<3.0.0"
urllib3 = ">=1.26,<3"
[[package]]
name = "colorama"
@@ -2689,6 +2694,53 @@ dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>
doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"]
test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]
[[package]]
name = "fastavro"
version = "1.7.4"
description = "Fast read/write of AVRO files"
category = "main"
optional = true
python-versions = ">=3.7"
files = [
{file = "fastavro-1.7.4-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:7568e621b94e061974b2a96d70670d09910e0a71482dd8610b153c07bd768497"},
{file = "fastavro-1.7.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4ec994faf64b743647f0027fcc56b01dc15d46c0e48fa15828277cb02dbdcd6"},
{file = "fastavro-1.7.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:727fdc1ddd12fcc6addab0b6df12ef999a6babe4b753db891f78aa2ee33edc77"},
{file = "fastavro-1.7.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b2f0cb3f7795fcb0042e0bbbe51204c28338a455986d68409b26dcbde64dd69a"},
{file = "fastavro-1.7.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bb0a8b5016a99be4b8ce3550889a1bd968c0fb3f521bcfbae24210c6342aee0c"},
{file = "fastavro-1.7.4-cp310-cp310-win_amd64.whl", hash = "sha256:1d2040b2bf3dc1a75170ea44d1e7e09f84fb77f40ef2e6c6b9f2eaf710557083"},
{file = "fastavro-1.7.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5542423f46bb7fc9699c467cbf151c2713aa6976ef14f4f5ec3532d80d0bb616"},
{file = "fastavro-1.7.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec396e6ab6b272708c8b9a0142df01fff4c7a1f168050f292ab92fdaee0b0257"},
{file = "fastavro-1.7.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b10d68c03371b79f461feca1c6c7e9d3f6aea2e9c7472b25cd749c57562aa1"},
{file = "fastavro-1.7.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:f94d5168ec72f3cfcf2181df1c46ad240dc1fcf361717447d2c5237121b9df55"},
{file = "fastavro-1.7.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:bad3dc279ed4ce747989259035cb3607f189ef7aff40339202f9321ca7f83d0b"},
{file = "fastavro-1.7.4-cp311-cp311-win_amd64.whl", hash = "sha256:8480ff444d9c7abd0bf121dd68656bd2115caca8ed28e71936eff348fde706e0"},
{file = "fastavro-1.7.4-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:bd3d669f4ec6915c88bb80b7c14e01d2c3ceb93a61de5dcf33ff13972bba505e"},
{file = "fastavro-1.7.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1a312b128536b81bdb79f27076f513b998abe7d13ee6fe52e99bc01f7ad9b06a"},
{file = "fastavro-1.7.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:487054d1419f1bfa41e7f19c718cbdbbb254319d3fd5b9ac411054d6432b9d40"},
{file = "fastavro-1.7.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d2897fe7d1d5b27dcd33c43d68480de36e55a0e651d7731004a36162cd3eed9e"},
{file = "fastavro-1.7.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6d318b49fd648a1fd93394411fe23761b486ac65dadea7c52dbeb0d0bef30221"},
{file = "fastavro-1.7.4-cp37-cp37m-win_amd64.whl", hash = "sha256:a117c3b122a8110c6ab99b3e66736790b4be19ceefb1edf0e732c33b3dc411c8"},
{file = "fastavro-1.7.4-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:0cca15e1a1f829e40524004342e425acfb594cefbd3388b0a5d13542750623ac"},
{file = "fastavro-1.7.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9211ec7a18a46a2aee01a2a979fd79f05f36b11fdb1bc469c9d9fd8cec32579"},
{file = "fastavro-1.7.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f16bde6b5fb51e15233bfcee0378f48d4221201ba45e497a8063f6d216b7aad7"},
{file = "fastavro-1.7.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:aeca55c905ff4c667f2158564654a778918988811ae3eb28592767edcf5f5c4a"},
{file = "fastavro-1.7.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b244f3abc024fc043d6637284ba2ffee5a1291c08a0f361ea1af4d829f66f303"},
{file = "fastavro-1.7.4-cp38-cp38-win_amd64.whl", hash = "sha256:b64e394c87cb99d0681727e1ae5d3633906a72abeab5ea0c692394aeb5a56607"},
{file = "fastavro-1.7.4-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:8c8115bdb1c862354d9abd0ea23eab85793bbff139087f2607bd4b83e8ae07ab"},
{file = "fastavro-1.7.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b27dd08f2338a478185c6ba23308002f334642ce83a6aeaf8308271efef88062"},
{file = "fastavro-1.7.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f087c246afab8bac08d86ef21be87cbf4f3779348fb960c081863fc3d570412c"},
{file = "fastavro-1.7.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b4077e17a2bab37af96e5ca52e61b6f2b85e4577e7a2903f6814642eb6a834f7"},
{file = "fastavro-1.7.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:776511cecf2ea9da4edd0de5015c1562cd9063683cf94f79bc9e20bab8f06923"},
{file = "fastavro-1.7.4-cp39-cp39-win_amd64.whl", hash = "sha256:a7ea5565fe2c145e074ce9ba75fafd5479a86b34a8dbd00dd1835cf192290e14"},
{file = "fastavro-1.7.4.tar.gz", hash = "sha256:6450f47ac4db95ec3a9e6434fec1f8a3c4c8c941de16205832ca8c67dd23d0d2"},
]
[package.extras]
codecs = ["lz4", "python-snappy", "zstandard"]
lz4 = ["lz4"]
snappy = ["python-snappy"]
zstandard = ["zstandard"]
[[package]]
name = "fastjsonschema"
version = "2.17.1"
@@ -12500,4 +12552,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "7a8847de4dd88e71b423ff148823523220a5649340178e8ab1f7bafb03a290d2"
content-hash = "4f5d91f450555bb3a039c3aef4a7996d1322f25608ec17a7b0c1ad92813d6a63"

@@ -47,7 +47,7 @@ qdrant-client = {version = "^1.3.1", optional = true, python = ">=3.8.1,<3.12"}
dataclasses-json = "^0.5.7"
tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}
tenacity = "^8.1.0"
cohere = {version = "^3", optional = true}
cohere = {version = "^4", optional = true}
openai = {version = "^0", optional = true}
nlpcloud = {version = "^1", optional = true}
nomic = {version = "^1.0.43", optional = true}
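
With the pin bumped to `^4`, existing environments need the optional dependency upgraded explicitly, e.g. `pip install --upgrade "cohere>=4,<5"` (an illustrative pip command; adjust for poetry or your package manager).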
