fix some lint

ankush/async-llm
Ankush Gola 1 year ago
parent 1b53bbf76c
commit bc559ee76b

@@ -2,7 +2,7 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union, Tuple
 
 import yaml
 from pydantic import BaseModel, Extra, Field, validator
@@ -17,7 +17,8 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
-def get_prompts(params, prompts):
+def get_prompts(params: Dict[str, Any], prompts: List[str]) -> Tuple[Dict[int, List], str, List[int], List[str]]:
+    """Get prompts that are already cached."""
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     missing_prompts = []
     missing_prompt_idxs = []
@@ -32,7 +33,10 @@ def get_prompts(params, prompts):
     return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
 
 
-def get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts):
+def get_llm_output(
+    existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
+):
+    """Get the LLM output."""
     for i, result in enumerate(new_results.generations):
         existing_prompts[missing_prompt_idxs[i]] = result
         prompt = prompts[missing_prompt_idxs[i]]
@@ -111,7 +115,12 @@ class BaseLLM(BaseModel, ABC):
             return output
         params = self.dict()
         params["stop"] = stop
-        existing_prompts, llm_string, missing_prompt_idxs, missing_prompts = get_prompts(params, prompts)
+        (
+            existing_prompts,
+            llm_string,
+            missing_prompt_idxs,
+            missing_prompts,
+        ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
             self.callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@@ -122,7 +131,9 @@ class BaseLLM(BaseModel, ABC):
                 self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
-            llm_output = get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts)
+            llm_output = get_llm_output(
+                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
+            )
         else:
             llm_output = {}
         generations = [existing_prompts[i] for i in range(len(prompts))]
@@ -150,7 +161,12 @@ class BaseLLM(BaseModel, ABC):
             return output
         params = self.dict()
         params["stop"] = stop
-        existing_prompts, llm_string, missing_prompt_idxs, missing_prompts = get_prompts(params, prompts)
+        (
+            existing_prompts,
+            llm_string,
+            missing_prompt_idxs,
+            missing_prompts,
+        ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
             self.callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@@ -161,7 +177,9 @@ class BaseLLM(BaseModel, ABC):
                 self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
-            llm_output = get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts)
+            llm_output = get_llm_output(
+                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
+            )
         else:
             llm_output = {}
         generations = [existing_prompts[i] for i in range(len(prompts))]

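The changes above split the cached-generation path in `BaseLLM` into two module-level helpers: `get_prompts` separates prompts that already have cached generations from those that still need an LLM call, and `get_llm_output` writes the fresh results back and merges them with the cached ones. Below is a minimal, self-contained sketch of that round trip, with a plain dict standing in for `langchain.llm_cache` and a hard-coded list standing in for the LLM call; `fake_cache`, the `params` dict, and the printed values are illustrative assumptions, not langchain's API.

```python
from typing import Any, Dict, List, Tuple

# Params that identify the LLM configuration; the cache key pairs the prompt
# text with a string form of these params, as in the diff above.
params = {"temperature": 0}
llm_string = str(sorted([(k, v) for k, v in params.items()]))

# Stand-in cache keyed by (prompt, llm_string); illustrative, not langchain.llm_cache.
fake_cache: Dict[Tuple[str, str], str] = {("b", llm_string): "cached-b"}


def get_prompts(
    params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, Any], str, List[int], List[str]]:
    """Split prompts into already-cached results and ones that still need a call."""
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    existing_prompts: Dict[int, Any] = {}
    missing_prompts: List[str] = []
    missing_prompt_idxs: List[int] = []
    for i, prompt in enumerate(prompts):
        cached = fake_cache.get((prompt, llm_string))
        if cached is not None:
            existing_prompts[i] = cached
        else:
            missing_prompts.append(prompt)
            missing_prompt_idxs.append(i)
    return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts


def get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts):
    """Write fresh results back to the cache and merge them with the cached ones."""
    for i, result in enumerate(new_results):
        existing_prompts[missing_prompt_idxs[i]] = result
        fake_cache[(prompts[missing_prompt_idxs[i]], llm_string)] = result


prompts = ["a", "b", "c"]
existing, llm_string, missing_idxs, missing = get_prompts(params, prompts)
new_results = ["generated-" + p for p in missing]  # pretend this came from the LLM
get_llm_output(existing, llm_string, missing_idxs, new_results, prompts)
print([existing[i] for i in range(len(prompts))])  # ['generated-a', 'cached-b', 'generated-c']
```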
@@ -1,7 +1,7 @@
 """Wrapper around OpenAI APIs."""
 import logging
 import sys
-from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union, Set
 
 from pydantic import BaseModel, Extra, Field, root_validator
 from tenacity import (
@@ -19,6 +19,16 @@ from langchain.utils import get_from_dict_or_env
 logger = logging.getLogger(__name__)
 
 
+def update_token_usage(keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]) -> None:
+    """Update token usage."""
+    _keys_to_use = keys.intersection(response["usage"])
+    for _key in _keys_to_use:
+        if _key not in token_usage:
+            token_usage[_key] = response["usage"][_key]
+        else:
+            token_usage[_key] += response["usage"][_key]
+
+
 class BaseOpenAI(BaseLLM, BaseModel):
     """Wrapper around OpenAI large language models.
 
@@ -178,17 +188,13 @@ class BaseOpenAI(BaseLLM, BaseModel):
         for _prompts in sub_prompts:
             response = self.completion_with_retry(prompt=_prompts, **params)
             choices.extend(response["choices"])
-            _keys_to_use = _keys.intersection(response["usage"])
-            for _key in _keys_to_use:
-                if _key not in token_usage:
-                    token_usage[_key] = response["usage"][_key]
-                else:
-                    token_usage[_key] += response["usage"][_key]
+            update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
 
     async def _agenerate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
+        """Call out to OpenAI's endpoint async with k unique prompts."""
         params = self._invocation_params
         sub_prompts = self.get_sub_prompts(params, prompts, stop)
         choices = []
@@ -199,15 +205,11 @@ class BaseOpenAI(BaseLLM, BaseModel):
         for _prompts in sub_prompts:
             response = await self.client.acreate(prompt=_prompts, **params)
             choices.extend(response["choices"])
-            _keys_to_use = _keys.intersection(response["usage"])
-            for _key in _keys_to_use:
-                if _key not in token_usage:
-                    token_usage[_key] = response["usage"][_key]
-                else:
-                    token_usage[_key] += response["usage"][_key]
+            update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
 
     def get_sub_prompts(self, params, prompts, stop):
+        """Get the sub prompts for llm call."""
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")

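The changes above fold the duplicated token-usage bookkeeping in `_generate` and `_agenerate` into the new `update_token_usage` helper, which accumulates the usage counters each batched OpenAI response reports. A small sketch of how the counters add up across calls, using made-up response dicts that only mimic the "usage" block of a real completion payload:

```python
from typing import Any, Dict, Set


def update_token_usage(
    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
    """Accumulate the usage counters from one response into a running total."""
    _keys_to_use = keys.intersection(response["usage"])
    for _key in _keys_to_use:
        if _key not in token_usage:
            token_usage[_key] = response["usage"][_key]
        else:
            token_usage[_key] += response["usage"][_key]


# Two stand-in responses shaped like the "usage" section of an OpenAI completion.
responses = [
    {"usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12}},
    {"usage": {"prompt_tokens": 3, "completion_tokens": 4, "total_tokens": 7}},
]

_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
token_usage: Dict[str, Any] = {}
for response in responses:
    update_token_usage(_keys, response, token_usage)
print(token_usage)  # totals: prompt 8, completion 11, total 19
```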
@@ -1,6 +1,7 @@
-from langchain.llms import OpenAI
 import asyncio
 
+from langchain.llms import OpenAI
+
 
 def generate_serially():
     llm = OpenAI(temperature=0)
@@ -22,6 +23,7 @@ async def generate_concurrently():
 
 if __name__ == "__main__":
     import time
+
     s = time.perf_counter()
     asyncio.run(generate_concurrently())
     elapsed = time.perf_counter() - s

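The example script's point is the wall-clock difference between awaiting each generation in turn and scheduling them all at once with asyncio.gather. Here is a sketch of the same timing pattern that runs without an API key; `fake_llm_call` is an assumed stand-in that sleeps instead of making the network-bound request, not part of langchain.

```python
import asyncio
import time


async def fake_llm_call(delay: float = 1.0) -> str:
    """Stand-in for an awaitable LLM request; sleeps instead of hitting an API."""
    await asyncio.sleep(delay)
    return "Hi there!"


async def generate_serially(n: int = 5) -> None:
    # Await each call before starting the next: total time is roughly n * delay.
    for _ in range(n):
        await fake_llm_call()


async def generate_concurrently(n: int = 5) -> None:
    # Schedule all calls at once and gather them: total time is roughly one delay.
    await asyncio.gather(*(fake_llm_call() for _ in range(n)))


if __name__ == "__main__":
    for runner in (generate_serially, generate_concurrently):
        s = time.perf_counter()
        asyncio.run(runner())
        print(f"{runner.__name__} took {time.perf_counter() - s:.2f} seconds")
```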