Compare commits

...

9 Commits

Author SHA1 Message Date
Ankush Gola bc559ee76b fix some lint 1 year ago
Ankush Gola 1b53bbf76c add integration test 1 year ago
Ankush Gola 930edd8e77 use agenerate 1 year ago
Ankush Gola 738bf977ab Merge branch 'ankush/retry-openai' into ankush/async-llm 1 year ago
Ankush Gola f4cb9ea42b lint 1 year ago
Ankush Gola 89e0cd5bb2 add retry logic for openai 1 year ago
Ankush Gola a9af287297 generate async 1 year ago
Ankush Gola 54bf243e36 add async_generate 1 year ago
Ankush Gola 0b211f0394 add new package 1 year ago

@ -2,7 +2,7 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union, Tuple
import yaml
from pydantic import BaseModel, Extra, Field, validator
@ -17,6 +17,34 @@ def _get_verbosity() -> bool:
return langchain.verbose
def get_prompts(params: Dict[str, Any], prompts: List[str]) -> tuple[Dict[int, list], str, list[int], list[str]]:
"""Get prompts that are already cached."""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}
for i, prompt in enumerate(prompts):
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
existing_prompts[i] = cache_val
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
def get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
):
"""Get the LLM output."""
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
prompt = prompts[missing_prompt_idxs[i]]
langchain.llm_cache.update(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output
class BaseLLM(BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string."""
@ -58,6 +86,12 @@ class BaseLLM(BaseModel, ABC):
) -> LLMResult:
"""Run the LLM on the given prompts."""
@abstractmethod
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
"""Run the LLM on the given prompts."""
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
@ -81,17 +115,12 @@ class BaseLLM(BaseModel, ABC):
return output
params = self.dict()
params["stop"] = stop
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}
for i, prompt in enumerate(prompts):
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
existing_prompts[i] = cache_val
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
(
existing_prompts,
llm_string,
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@ -102,11 +131,55 @@ class BaseLLM(BaseModel, ABC):
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
prompt = prompts[missing_prompt_idxs[i]]
langchain.llm_cache.update(prompt, llm_string, result)
llm_output = new_results.llm_output
llm_output = get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
else:
llm_output = {}
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output)
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = await self._agenerate(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
params = self.dict()
params["stop"] = stop
(
existing_prompts,
llm_string,
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
)
try:
new_results = await self._agenerate(missing_prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
llm_output = get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
else:
llm_output = {}
generations = [existing_prompts[i] for i in range(len(prompts))]
@ -212,3 +285,9 @@ class LLM(BaseLLM):
text = self._call(prompt, stop=stop)
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
raise NotImplementedError("Async generation not implemented for this LLM.")

@ -1,9 +1,16 @@
"""Wrapper around OpenAI APIs."""
import logging
import sys
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union, Set
from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
@ -12,6 +19,16 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def update_token_usage(keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]) -> None:
"""Update token usage."""
_keys_to_use = keys.intersection(response["usage"])
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
class BaseOpenAI(BaseLLM, BaseModel):
"""Wrapper around OpenAI large language models.
@ -56,6 +73,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
"""Adjust the probability of specific tokens being generated."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
class Config:
"""Configuration for this pydantic object."""
@ -115,6 +134,32 @@ class BaseOpenAI(BaseLLM, BaseModel):
}
return {**normal_params, **self.model_kwargs}
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
@retry(
reraise=True,
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
),
after=after_log(logger, logging.DEBUG),
)
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
@ -134,11 +179,41 @@ class BaseOpenAI(BaseLLM, BaseModel):
"""
# TODO: write a unit test for this
params = self._invocation_params
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = self.completion_with_retry(prompt=_prompts, **params)
choices.extend(response["choices"])
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
"""Call out to OpenAI's endpoint async with k unique prompts."""
params = self._invocation_params
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = await self.client.acreate(prompt=_prompts, **params)
choices.extend(response["choices"])
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
def get_sub_prompts(self, params, prompts, stop):
"""Get the sub prompts for llm call."""
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
if params["max_tokens"] == -1:
if len(prompts) != 1:
raise ValueError(
@ -146,26 +221,15 @@ class BaseOpenAI(BaseLLM, BaseModel):
)
params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
sub_prompts = [
prompts[i : i + self.batch_size]
prompts[i: i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
choices = []
token_usage = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = self.client.create(prompt=_prompts, **params)
choices.extend(response["choices"])
_keys_to_use = _keys.intersection(response["usage"])
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
return sub_prompts
def create_llm_result(self, choices, prompts, token_usage):
generations = []
for i, prompt in enumerate(prompts):
sub_choices = choices[i * self.n : (i + 1) * self.n]
sub_choices = choices[i * self.n: (i + 1) * self.n]
generations.append(
[
Generation(

43
poetry.lock generated

@ -851,7 +851,6 @@ files = [
{file = "debugpy-1.6.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b5d1b13d7c7bf5d7cf700e33c0b8ddb7baf030fcf502f76fc061ddd9405d16c"},
{file = "debugpy-1.6.6-cp38-cp38-win32.whl", hash = "sha256:70ab53918fd907a3ade01909b3ed783287ede362c80c75f41e79596d5ccacd32"},
{file = "debugpy-1.6.6-cp38-cp38-win_amd64.whl", hash = "sha256:c05349890804d846eca32ce0623ab66c06f8800db881af7a876dc073ac1c2225"},
{file = "debugpy-1.6.6-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:11a0f3a106f69901e4a9a5683ce943a7a5605696024134b522aa1bfda25b5fec"},
{file = "debugpy-1.6.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a771739902b1ae22a120dbbb6bd91b2cae6696c0e318b5007c5348519a4211c6"},
{file = "debugpy-1.6.6-cp39-cp39-win32.whl", hash = "sha256:549ae0cb2d34fc09d1675f9b01942499751d174381b6082279cf19cdb3c47cbe"},
{file = "debugpy-1.6.6-cp39-cp39-win_amd64.whl", hash = "sha256:de4a045fbf388e120bb6ec66501458d3134f4729faed26ff95de52a754abddb1"},
@ -2411,6 +2410,7 @@ files = [
{file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"},
{file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"},
{file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"},
{file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"},
{file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"},
{file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"},
{file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"},
@ -2420,6 +2420,7 @@ files = [
{file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"},
{file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"},
{file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"},
{file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"},
{file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"},
{file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"},
{file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"},
@ -3899,6 +3900,25 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "0.20.3"
description = "Pytest support for asyncio"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-asyncio-0.20.3.tar.gz", hash = "sha256:83cbf01169ce3e8eb71c6c278ccb0574d1a7a3bb8eaaf5e50e0ad342afb33b36"},
{file = "pytest_asyncio-0.20.3-py3-none-any.whl", hash = "sha256:f129998b209d04fcc65c96fc85c11e5316738358909a8399e93be553d7656442"},
]
[package.dependencies]
pytest = ">=6.1.0"
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
name = "pytest-cov"
version = "4.0.0"
@ -5092,6 +5112,21 @@ files = [
[package.extras]
widechars = ["wcwidth"]
[[package]]
name = "tenacity"
version = "8.1.0"
description = "Retry code until it succeeds"
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "tenacity-8.1.0-py3-none-any.whl", hash = "sha256:35525cd47f82830069f0d6b73f7eb83bc5b73ee2fff0437952cedf98b27653ac"},
{file = "tenacity-8.1.0.tar.gz", hash = "sha256:e48c437fdf9340f5666b92cd7990e96bc5fc955e1298baf4a907e3972067a445"},
]
[package.extras]
doc = ["reno", "sphinx", "tornado (>=4.5)"]
[[package]]
name = "tensorboard"
version = "2.11.2"
@ -5505,11 +5540,14 @@ files = [
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"},
{file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"},
{file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"},
{file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"},
{file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"},
{file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"},
{file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"},
{file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"},
@ -5518,6 +5556,7 @@ files = [
{file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"},
{file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"},
{file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"},
{file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"},
@ -6254,4 +6293,4 @@ llms = ["manifest-ml", "torch", "transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "fdf51cf2f138653a8e32f1c43a2dcbbcca76f74534afc9dc4a7879c58282100a"
content-hash = "b4470de82ffcc2fab1aa0bdb6bdadd1b647cebe9194f7f47429ff257616e607f"

@ -36,6 +36,7 @@ wolframalpha = {version = "5.0.0", optional = true}
qdrant-client = {version = "^0.11.7", optional = true}
dataclasses-json = "^0.5.7"
tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}
tenacity = "^8.1.0"
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"
@ -59,6 +60,7 @@ duckdb-engine = "^0.6.6"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
[tool.poetry.group.lint.dependencies]
flake8-docstrings = "^1.6.0"

@ -7,6 +7,7 @@ import pytest
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI
from langchain.schema import LLMResult
def test_openai_call() -> None:
@ -74,3 +75,11 @@ def test_openai_streaming_error() -> None:
llm = OpenAI(best_of=2)
with pytest.raises(ValueError):
llm.stream("I'm Pickle Rick")
@pytest.mark.asyncio
async def test_openai_async_generate() -> None:
"""Test async generation."""
llm = OpenAI(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)

@ -33,6 +33,11 @@ class FakeLLM(BaseLLM, BaseModel):
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
@property
def _llm_type(self) -> str:
"""Return type of llm."""

@ -0,0 +1,35 @@
import asyncio
from langchain.llms import OpenAI
def generate_serially():
llm = OpenAI(temperature=0)
for _ in range(10):
resp = llm.generate(["Hello, how are you?"])
# print(resp)
async def async_generate(llm):
resp = await llm.agenerate(["Hello, how are you?"])
# print(resp)
async def generate_concurrently():
llm = OpenAI(temperature=0)
tasks = [async_generate(llm) for _ in range(10)]
await asyncio.gather(*tasks)
if __name__ == "__main__":
import time
s = time.perf_counter()
asyncio.run(generate_concurrently())
elapsed = time.perf_counter() - s
print(f"Concurrent executed in {elapsed:0.2f} seconds.")
s = time.perf_counter()
generate_serially()
elapsed = time.perf_counter() - s
print(f"Serial executed in {elapsed:0.2f} seconds.")
Loading…
Cancel
Save