Compare commits

...

12 Commits

Author SHA1 Message Date
Ankush Gola 496ee53c6c cr 1 year ago
Ankush Gola 2611fdd03e Merge branch 'ankush/async-llm' into ankush/async-llmchain 1 year ago
Ankush Gola bc559ee76b fix some lint 1 year ago
Ankush Gola 1b53bbf76c add integration test 1 year ago
Ankush Gola 930edd8e77 use agenerate 1 year ago
Ankush Gola 738bf977ab Merge branch 'ankush/retry-openai' into ankush/async-llm 1 year ago
Ankush Gola f4cb9ea42b lint 1 year ago
Ankush Gola 89e0cd5bb2 add retry logic for openai 1 year ago
Ankush Gola 103c3046e8 refactor chain 1 year ago
Ankush Gola a9af287297 generate async 1 year ago
Ankush Gola 54bf243e36 add async_generate 1 year ago
Ankush Gola 0b211f0394 add new package 1 year ago

@@ -111,6 +111,10 @@ class Chain(BaseModel, ABC):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output."""
raise NotImplementedError("Async call not supported for this chain type.")
def __call__(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]:
@@ -125,24 +129,7 @@ class Chain(BaseModel, ABC):
chain will be returned. Defaults to False.
"""
inputs = self.prep_inputs(inputs)
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
@@ -154,6 +141,37 @@ class Chain(BaseModel, ABC):
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)
async def acall(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
"""
inputs = self.prep_inputs(inputs)
self.callback_manager.on_chain_start(
{"name": self.__class__.__name__},
inputs,
verbose=self.verbose,
)
try:
outputs = await self._acall(inputs)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs)
def prep_outputs(self, inputs, outputs, return_only_outputs):
self._validate_outputs(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
@@ -162,6 +180,27 @@ class Chain(BaseModel, ABC):
else:
return {**inputs, **outputs}
def prep_inputs(self, inputs):
if not isinstance(inputs, dict):
_input_keys = set(self.input_keys)
if self.memory is not None:
# If there are multiple input keys, but some get set by memory so that
# only one is not set, we can still figure out which key it is.
_input_keys = _input_keys.difference(self.memory.memory_variables)
if len(_input_keys) != 1:
raise ValueError(
f"A single string input was passed in, but this chain expects "
f"multiple inputs ({_input_keys}). When a chain expects "
f"multiple inputs, please call it by passing in a dictionary, "
"eg `chain({'foo': 1, 'bar': 2})`"
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Call the chain on all inputs in the list."""
return [self(inputs) for inputs in input_list]
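The hunk above (apparently from langchain/chains/base.py) adds the async entry points to the base Chain: subclasses override _acall, callers await acall, and both paths share the callback, memory, and validation handling through prep_inputs/prep_outputs. A minimal sketch of how a custom chain might use this; ReverseChain and its logic are invented for illustration, only the Chain interface comes from the diff:

import asyncio
from typing import Dict, List

from langchain.chains.base import Chain


class ReverseChain(Chain):
    """Toy chain that reverses its single input string."""

    @property
    def input_keys(self) -> List[str]:
        return ["text"]

    @property
    def output_keys(self) -> List[str]:
        return ["reversed"]

    def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
        return {"reversed": inputs["text"][::-1]}

    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
        # No real I/O in this toy example, so reuse the sync logic.
        return self._call(inputs)


async def main() -> None:
    chain = ReverseChain()
    result = await chain.acall({"text": "hello"})
    print(result["reversed"])  # "olleh"


asyncio.run(main())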

@@ -56,6 +56,17 @@ class LLMChain(Chain, BaseModel):
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
response = self.llm.generate(prompts, stop=stop)
return response
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
response = await self.llm.agenerate(prompts, stop=stop)
return response
def prep_prompts(self, input_list):
stop = None
if "stop" in input_list[0]:
stop = input_list[0]["stop"]
@@ -71,12 +82,19 @@ class LLMChain(Chain, BaseModel):
"If `stop` is present in any inputs, should be present in all."
)
prompts.append(prompt)
response = self.llm.generate(prompts, stop=stop)
return response
return prompts, stop
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = self.generate(input_list)
return self.create_outputs(response)
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = await self.agenerate(input_list)
return self.create_outputs(response)
def create_outputs(self, response):
outputs = []
for generation in response.generations:
# Get the text of the top generated string.
@@ -87,6 +105,9 @@ class LLMChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0]
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return (await self.aapply([inputs]))[0]
def predict(self, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
@@ -103,6 +124,22 @@ class LLMChain(Chain, BaseModel):
"""
return self(kwargs)[self.output_key]
async def apredict(self, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = await llm.apredict(adjective="funny")
"""
return (await self.acall(kwargs))[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
"""Call predict and then parse the results."""
result = self.predict(**kwargs)
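With agenerate, aapply, apredict, and _acall added to LLMChain (apparently langchain/chains/llm.py), a chain can now be awaited end to end. A rough usage sketch, assuming an OpenAI API key is configured in the environment; the prompt text and temperature are arbitrary:

import asyncio

from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate


async def main() -> None:
    prompt = PromptTemplate(
        input_variables=["adjective"],
        template="Tell me a {adjective} joke.",
    )
    chain = LLMChain(llm=OpenAI(temperature=0.9), prompt=prompt)

    # Async counterpart of chain.predict(...)
    joke = await chain.apredict(adjective="funny")

    # Async counterpart of chain.apply(...), one agenerate call under the hood
    results = await chain.aapply([{"adjective": "sad"}, {"adjective": "dry"}])
    print(joke, results)


asyncio.run(main())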

@@ -2,7 +2,7 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union, Tuple
import yaml import yaml
from pydantic import BaseModel, Extra, Field, validator from pydantic import BaseModel, Extra, Field, validator
@@ -17,6 +17,34 @@ def _get_verbosity() -> bool:
return langchain.verbose
def get_prompts(params: Dict[str, Any], prompts: List[str]) -> Tuple[Dict[int, List], str, List[int], List[str]]:
"""Get prompts that are already cached."""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}
for i, prompt in enumerate(prompts):
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
existing_prompts[i] = cache_val
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
def get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
):
"""Get the LLM output."""
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
prompt = prompts[missing_prompt_idxs[i]]
langchain.llm_cache.update(prompt, llm_string, result)
llm_output = new_results.llm_output
return llm_output
class BaseLLM(BaseModel, ABC):
"""LLM wrapper should take in a prompt and return a string."""
@@ -58,6 +86,12 @@ class BaseLLM(BaseModel, ABC):
) -> LLMResult:
"""Run the LLM on the given prompts."""
@abstractmethod
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
"""Run the LLM on the given prompts."""
def generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
@@ -81,17 +115,12 @@ class BaseLLM(BaseModel, ABC):
return output
params = self.dict()
params["stop"] = stop
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
existing_prompts = {}
for i, prompt in enumerate(prompts):
cache_val = langchain.llm_cache.lookup(prompt, llm_string)
if isinstance(cache_val, list):
existing_prompts[i] = cache_val
else:
missing_prompts.append(prompt)
missing_prompt_idxs.append(i)
(
existing_prompts,
llm_string,
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@@ -102,11 +131,55 @@ class BaseLLM(BaseModel, ABC):
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
prompt = prompts[missing_prompt_idxs[i]]
langchain.llm_cache.update(prompt, llm_string, result)
llm_output = new_results.llm_output
llm_output = get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
else:
llm_output = {}
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output)
async def agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
disregard_cache = self.cache is not None and not self.cache
if langchain.llm_cache is None or disregard_cache:
# This happens when langchain.cache is None, but self.cache is True
if self.cache is not None and self.cache:
raise ValueError(
"Asked to cache, but no cache found at `langchain.cache`."
)
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, prompts, verbose=self.verbose
)
try:
output = await self._agenerate(prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(output, verbose=self.verbose)
return output
params = self.dict()
params["stop"] = stop
(
existing_prompts,
llm_string,
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
)
try:
new_results = await self._agenerate(missing_prompts, stop=stop)
except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
llm_output = get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
else:
llm_output = {}
generations = [existing_prompts[i] for i in range(len(prompts))]
@@ -212,3 +285,9 @@ class LLM(BaseLLM):
text = self._call(prompt, stop=stop)
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
raise NotImplementedError("Async generation not implemented for this LLM.")

@@ -1,9 +1,16 @@
"""Wrapper around OpenAI APIs."""
import logging
import sys
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union, Set
from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
@@ -12,6 +19,16 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def update_token_usage(keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]) -> None:
"""Update token usage."""
_keys_to_use = keys.intersection(response["usage"])
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
class BaseOpenAI(BaseLLM, BaseModel):
"""Wrapper around OpenAI large language models.
@@ -56,6 +73,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
"""Adjust the probability of specific tokens being generated."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
class Config:
"""Configuration for this pydantic object."""
@@ -115,6 +134,32 @@ class BaseOpenAI(BaseLLM, BaseModel):
}
return {**normal_params, **self.model_kwargs}
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
import openai
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
@retry(
reraise=True,
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
),
after=after_log(logger, logging.DEBUG),
)
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _generate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
@@ -134,11 +179,41 @@ class BaseOpenAI(BaseLLM, BaseModel):
"""
# TODO: write a unit test for this
params = self._invocation_params
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = self.completion_with_retry(prompt=_prompts, **params)
choices.extend(response["choices"])
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
"""Call out to OpenAI's endpoint async with k unique prompts."""
params = self._invocation_params
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = await self.client.acreate(prompt=_prompts, **params)
choices.extend(response["choices"])
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
def get_sub_prompts(self, params, prompts, stop):
"""Get the sub prompts for llm call."""
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
if params["max_tokens"] == -1:
if len(prompts) != 1:
raise ValueError(
@@ -146,26 +221,15 @@ class BaseOpenAI(BaseLLM, BaseModel):
)
params["max_tokens"] = self.max_tokens_for_prompt(prompts[0])
sub_prompts = [
prompts[i: i + self.batch_size]
for i in range(0, len(prompts), self.batch_size)
]
choices = []
token_usage = {}
# Get the token usage from the response.
# Includes prompt, completion, and total tokens used.
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
for _prompts in sub_prompts:
response = self.client.create(prompt=_prompts, **params)
choices.extend(response["choices"])
_keys_to_use = _keys.intersection(response["usage"])
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
return sub_prompts
def create_llm_result(self, choices, prompts, token_usage):
generations = []
for i, prompt in enumerate(prompts):
sub_choices = choices[i * self.n: (i + 1) * self.n]
generations.append(
[
Generation(
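completion_with_retry (in what looks like langchain/llms/openai.py) wraps self.client.create in a tenacity policy: at most max_retries attempts (new field, default 6), exponential backoff clamped between 4 and 10 seconds, and retries only on the listed OpenAI error types. A self-contained sketch of the same policy applied to an invented flaky function, just to show what the decorator arguments do:

import logging

from tenacity import (
    after_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


class TransientError(Exception):
    """Stand-in for openai.error.RateLimitError and friends."""


attempts = {"count": 0}


@retry(
    reraise=True,
    stop=stop_after_attempt(6),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(TransientError),
    after=after_log(logger, logging.DEBUG),
)
def flaky_call() -> str:
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise TransientError("temporarily unavailable")
    return "ok"


print(flaky_call())  # succeeds on the third attempt after two backoff waits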

poetry.lock (generated)

@@ -851,7 +851,6 @@ files = [
{file = "debugpy-1.6.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9b5d1b13d7c7bf5d7cf700e33c0b8ddb7baf030fcf502f76fc061ddd9405d16c"},
{file = "debugpy-1.6.6-cp38-cp38-win32.whl", hash = "sha256:70ab53918fd907a3ade01909b3ed783287ede362c80c75f41e79596d5ccacd32"},
{file = "debugpy-1.6.6-cp38-cp38-win_amd64.whl", hash = "sha256:c05349890804d846eca32ce0623ab66c06f8800db881af7a876dc073ac1c2225"},
{file = "debugpy-1.6.6-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:11a0f3a106f69901e4a9a5683ce943a7a5605696024134b522aa1bfda25b5fec"},
{file = "debugpy-1.6.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a771739902b1ae22a120dbbb6bd91b2cae6696c0e318b5007c5348519a4211c6"}, {file = "debugpy-1.6.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a771739902b1ae22a120dbbb6bd91b2cae6696c0e318b5007c5348519a4211c6"},
{file = "debugpy-1.6.6-cp39-cp39-win32.whl", hash = "sha256:549ae0cb2d34fc09d1675f9b01942499751d174381b6082279cf19cdb3c47cbe"}, {file = "debugpy-1.6.6-cp39-cp39-win32.whl", hash = "sha256:549ae0cb2d34fc09d1675f9b01942499751d174381b6082279cf19cdb3c47cbe"},
{file = "debugpy-1.6.6-cp39-cp39-win_amd64.whl", hash = "sha256:de4a045fbf388e120bb6ec66501458d3134f4729faed26ff95de52a754abddb1"}, {file = "debugpy-1.6.6-cp39-cp39-win_amd64.whl", hash = "sha256:de4a045fbf388e120bb6ec66501458d3134f4729faed26ff95de52a754abddb1"},
@@ -2411,6 +2410,7 @@ files = [
{file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ca989b91cf3a3ba28930a9fc1e9aeafc2a395448641df1f387a2d394638943b0"},
{file = "lxml-4.9.2-cp35-cp35m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:822068f85e12a6e292803e112ab876bc03ed1f03dddb80154c395f891ca6b31e"},
{file = "lxml-4.9.2-cp35-cp35m-win32.whl", hash = "sha256:be7292c55101e22f2a3d4d8913944cbea71eea90792bf914add27454a13905df"},
{file = "lxml-4.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:998c7c41910666d2976928c38ea96a70d1aa43be6fe502f21a651e17483a43c5"},
{file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"}, {file = "lxml-4.9.2-cp36-cp36m-macosx_10_15_x86_64.whl", hash = "sha256:b26a29f0b7fc6f0897f043ca366142d2b609dc60756ee6e4e90b5f762c6adc53"},
{file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:ab323679b8b3030000f2be63e22cdeea5b47ee0abd2d6a1dc0c8103ddaa56cd7"},
{file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"}, {file = "lxml-4.9.2-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:689bb688a1db722485e4610a503e3e9210dcc20c520b45ac8f7533c837be76fe"},
@@ -2420,6 +2420,7 @@ files = [
{file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:58bfa3aa19ca4c0f28c5dde0ff56c520fbac6f0daf4fac66ed4c8d2fb7f22e74"},
{file = "lxml-4.9.2-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:bc718cd47b765e790eecb74d044cc8d37d58562f6c314ee9484df26276d36a38"},
{file = "lxml-4.9.2-cp36-cp36m-win32.whl", hash = "sha256:d5bf6545cd27aaa8a13033ce56354ed9e25ab0e4ac3b5392b763d8d04b08e0c5"},
{file = "lxml-4.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:3ab9fa9d6dc2a7f29d7affdf3edebf6ece6fb28a6d80b14c3b2fb9d39b9322c3"},
{file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"}, {file = "lxml-4.9.2-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:05ca3f6abf5cf78fe053da9b1166e062ade3fa5d4f92b4ed688127ea7d7b1d03"},
{file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_24_i686.whl", hash = "sha256:a5da296eb617d18e497bcf0a5c528f5d3b18dadb3619fbdadf4ed2356ef8d941"},
{file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"}, {file = "lxml-4.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:04876580c050a8c5341d706dd464ff04fd597095cc8c023252566a8826505726"},
@@ -3899,6 +3900,25 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
[[package]]
name = "pytest-asyncio"
version = "0.20.3"
description = "Pytest support for asyncio"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pytest-asyncio-0.20.3.tar.gz", hash = "sha256:83cbf01169ce3e8eb71c6c278ccb0574d1a7a3bb8eaaf5e50e0ad342afb33b36"},
{file = "pytest_asyncio-0.20.3-py3-none-any.whl", hash = "sha256:f129998b209d04fcc65c96fc85c11e5316738358909a8399e93be553d7656442"},
]
[package.dependencies]
pytest = ">=6.1.0"
[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (>=0.931)", "pytest-trio (>=0.7.0)"]
[[package]]
name = "pytest-cov"
version = "4.0.0"
@@ -5092,6 +5112,21 @@ files = [
[package.extras]
widechars = ["wcwidth"]
[[package]]
name = "tenacity"
version = "8.1.0"
description = "Retry code until it succeeds"
category = "main"
optional = false
python-versions = ">=3.6"
files = [
{file = "tenacity-8.1.0-py3-none-any.whl", hash = "sha256:35525cd47f82830069f0d6b73f7eb83bc5b73ee2fff0437952cedf98b27653ac"},
{file = "tenacity-8.1.0.tar.gz", hash = "sha256:e48c437fdf9340f5666b92cd7990e96bc5fc955e1298baf4a907e3972067a445"},
]
[package.extras]
doc = ["reno", "sphinx", "tornado (>=4.5)"]
[[package]]
name = "tensorboard"
version = "2.11.2"
@@ -5505,11 +5540,14 @@ files = [
{file = "tokenizers-0.13.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47ef745dbf9f49281e900e9e72915356d69de3a4e4d8a475bda26bfdb5047736"},
{file = "tokenizers-0.13.2-cp310-cp310-win32.whl", hash = "sha256:96cedf83864bcc15a3ffd088a6f81a8a8f55b8b188eabd7a7f2a4469477036df"},
{file = "tokenizers-0.13.2-cp310-cp310-win_amd64.whl", hash = "sha256:eda77de40a0262690c666134baf19ec5c4f5b8bde213055911d9f5a718c506e1"},
{file = "tokenizers-0.13.2-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:9eee037bb5aa14daeb56b4c39956164b2bebbe6ab4ca7779d88aa16b79bd4e17"},
{file = "tokenizers-0.13.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:d1b079c4c9332048fec4cb9c2055c2373c74fbb336716a5524c9a720206d787e"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a689654fc745135cce4eea3b15e29c372c3e0b01717c6978b563de5c38af9811"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3606528c07cda0566cff6cbfbda2b167f923661be595feac95701ffcdcbdbb21"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:41291d0160946084cbd53c8ec3d029df3dc2af2673d46b25ff1a7f31a9d55d51"},
{file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"}, {file = "tokenizers-0.13.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7892325f9ca1cc5fca0333d5bfd96a19044ce9b092ce2df625652109a3de16b8"},
{file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"}, {file = "tokenizers-0.13.2-cp311-cp311-win32.whl", hash = "sha256:93714958d4ebe5362d3de7a6bd73dc86c36b5af5941ebef6c325ac900fa58865"},
{file = "tokenizers-0.13.2-cp311-cp311-win_amd64.whl", hash = "sha256:fa7ef7ee380b1f49211bbcfac8a006b1a3fa2fa4c7f4ee134ae384eb4ea5e453"},
{file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"}, {file = "tokenizers-0.13.2-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:da521bfa94df6a08a6254bb8214ea04854bb9044d61063ae2529361688b5440a"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a739d4d973d422e1073989769723f3b6ad8b11e59e635a63de99aea4b2208188"},
{file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"}, {file = "tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cac01fc0b868e4d0a3aa7c5c53396da0a0a63136e81475d32fcf5c348fcb2866"},
@@ -5518,6 +5556,7 @@ files = [
{file = "tokenizers-0.13.2-cp37-cp37m-win32.whl", hash = "sha256:a537061ee18ba104b7f3daa735060c39db3a22c8a9595845c55b6c01d36c5e87"},
{file = "tokenizers-0.13.2-cp37-cp37m-win_amd64.whl", hash = "sha256:c82fb87b1cbfa984d8f05b2b3c3c73e428b216c1d4f0e286d0a3b27f521b32eb"},
{file = "tokenizers-0.13.2-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:ce298605a833ac7f81b8062d3102a42dcd9fa890493e8f756112c346339fe5c5"},
{file = "tokenizers-0.13.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f44d59bafe3d61e8a56b9e0a963075187c0f0091023120b13fbe37a87936f171"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51b93932daba12ed07060935978a6779593a59709deab04a0d10e6fd5c29e60"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6969e5ea7ccb909ce7d6d4dfd009115dc72799b0362a2ea353267168667408c4"},
{file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"}, {file = "tokenizers-0.13.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:92f040c4d938ea64683526b45dfc81c580e3b35aaebe847e7eec374961231734"},
@@ -6254,4 +6293,4 @@ llms = ["manifest-ml", "torch", "transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "fdf51cf2f138653a8e32f1c43a2dcbbcca76f74534afc9dc4a7879c58282100a"
content-hash = "b4470de82ffcc2fab1aa0bdb6bdadd1b647cebe9194f7f47429ff257616e607f"

@@ -36,6 +36,7 @@ wolframalpha = {version = "5.0.0", optional = true}
qdrant-client = {version = "^0.11.7", optional = true}
dataclasses-json = "^0.5.7"
tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}
tenacity = "^8.1.0"
[tool.poetry.group.docs.dependencies]
autodoc_pydantic = "^1.8.0"
@@ -59,6 +60,7 @@ duckdb-engine = "^0.6.6"
pytest-watcher = "^0.2.6"
freezegun = "^1.2.2"
responses = "^0.22.0"
pytest-asyncio = "^0.20.3"
[tool.poetry.group.lint.dependencies]
flake8-docstrings = "^1.6.0"

@@ -7,6 +7,7 @@ import pytest
from langchain.llms.loading import load_llm
from langchain.llms.openai import OpenAI
from langchain.schema import LLMResult
def test_openai_call() -> None:
@@ -74,3 +75,11 @@ def test_openai_streaming_error() -> None:
llm = OpenAI(best_of=2)
with pytest.raises(ValueError):
llm.stream("I'm Pickle Rick")
@pytest.mark.asyncio
async def test_openai_async_generate() -> None:
"""Test async generation."""
llm = OpenAI(max_tokens=10)
output = await llm.agenerate(["Hello, how are you?"])
assert isinstance(output, LLMResult)

@@ -33,6 +33,11 @@ class FakeLLM(BaseLLM, BaseModel):
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
@property
def _llm_type(self) -> str:
"""Return type of llm."""

@@ -0,0 +1,35 @@
import asyncio
from langchain.llms import OpenAI
def generate_serially():
llm = OpenAI(temperature=0)
for _ in range(10):
resp = llm.generate(["Hello, how are you?"])
# print(resp)
async def async_generate(llm):
resp = await llm.agenerate(["Hello, how are you?"])
# print(resp)
async def generate_concurrently():
llm = OpenAI(temperature=0)
tasks = [async_generate(llm) for _ in range(10)]
await asyncio.gather(*tasks)
if __name__ == "__main__":
import time
s = time.perf_counter()
asyncio.run(generate_concurrently())
elapsed = time.perf_counter() - s
print(f"Concurrent executed in {elapsed:0.2f} seconds.")
s = time.perf_counter()
generate_serially()
elapsed = time.perf_counter() - s
print(f"Serial executed in {elapsed:0.2f} seconds.")