fix some lint

ankush/async-llm
Ankush Gola 1 year ago
parent 1b53bbf76c
commit bc559ee76b

@@ -2,7 +2,7 @@
import json
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Union, Tuple
import yaml
from pydantic import BaseModel, Extra, Field, validator
@@ -17,7 +17,8 @@ def _get_verbosity() -> bool:
return langchain.verbose
def get_prompts(params, prompts):
def get_prompts(params: Dict[str, Any], prompts: List[str]) -> Tuple[Dict[int, List], str, List[int], List[str]]:
"""Get prompts that are already cached."""
llm_string = str(sorted([(k, v) for k, v in params.items()]))
missing_prompts = []
missing_prompt_idxs = []
@@ -32,7 +33,10 @@ def get_prompts(params, prompts):
return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
def get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts):
def get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
):
"""Get the LLM output."""
for i, result in enumerate(new_results.generations):
existing_prompts[missing_prompt_idxs[i]] = result
prompt = prompts[missing_prompt_idxs[i]]
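
The two helpers above pull the caching logic out of the generate path: get_prompts looks each prompt up in the cache and reports which ones still need a real call, and get_llm_output slots the fresh generations back in by index. A minimal, self-contained sketch of that round-trip, using a plain dict in place of langchain.llm_cache (the cache API itself is not shown in this hunk, so the names here are illustrative only):

```python
from typing import Any, Dict, List, Tuple

# Illustrative stand-in for langchain.llm_cache; keyed by (prompt, llm_string).
FAKE_CACHE: Dict[Tuple[str, str], list] = {}


def split_cached(
    params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, list], str, List[int], List[str]]:
    """Return cached generations plus the prompts that still need an LLM call."""
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    existing: Dict[int, list] = {}
    missing_idxs: List[int] = []
    missing: List[str] = []
    for i, prompt in enumerate(prompts):
        hit = FAKE_CACHE.get((prompt, llm_string))
        if hit is not None:
            existing[i] = hit
        else:
            missing_idxs.append(i)
            missing.append(prompt)
    return existing, llm_string, missing_idxs, missing


def merge_new_results(
    existing: Dict[int, list],
    llm_string: str,
    missing_idxs: List[int],
    new_generations: List[list],
    prompts: List[str],
) -> None:
    """Slot fresh generations back in by index and cache them for next time."""
    for i, result in enumerate(new_generations):
        existing[missing_idxs[i]] = result
        FAKE_CACHE[(prompts[missing_idxs[i]], llm_string)] = result
```
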
@@ -111,7 +115,12 @@ class BaseLLM(BaseModel, ABC):
return output
params = self.dict()
params["stop"] = stop
existing_prompts, llm_string, missing_prompt_idxs, missing_prompts = get_prompts(params, prompts)
(
existing_prompts,
llm_string,
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@@ -122,7 +131,9 @@ class BaseLLM(BaseModel, ABC):
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
llm_output = get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts)
llm_output = get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
else:
llm_output = {}
generations = [existing_prompts[i] for i in range(len(prompts))]
@@ -150,7 +161,12 @@ class BaseLLM(BaseModel, ABC):
return output
params = self.dict()
params["stop"] = stop
existing_prompts, llm_string, missing_prompt_idxs, missing_prompts = get_prompts(params, prompts)
(
existing_prompts,
llm_string,
missing_prompt_idxs,
missing_prompts,
) = get_prompts(params, prompts)
if len(missing_prompts) > 0:
self.callback_manager.on_llm_start(
{"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@@ -161,7 +177,9 @@ class BaseLLM(BaseModel, ABC):
self.callback_manager.on_llm_error(e, verbose=self.verbose)
raise e
self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
llm_output = get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts)
llm_output = get_llm_output(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
else:
llm_output = {}
generations = [existing_prompts[i] for i in range(len(prompts))]
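
Both the sync and async paths above wrap the provider call in the same shape: announce the missing prompts via on_llm_start, report any provider error through on_llm_error and re-raise it, then hand the results to on_llm_end before merging with the cached generations. A rough sketch of just that control flow, with a toy callback object and a placeholder run_llm callable (neither is a langchain API):

```python
from typing import Callable, List


class PrintCallbacks:
    """Toy stand-in for the callback manager used above."""

    def on_llm_start(self, info: dict, prompts: List[str]) -> None:
        print(f"start: {info['name']} on {len(prompts)} missing prompt(s)")

    def on_llm_error(self, error: Exception) -> None:
        print(f"error: {error!r}")

    def on_llm_end(self, results: list) -> None:
        print(f"end: got {len(results)} result(s)")


def run_with_callbacks(run_llm: Callable[[List[str]], list], prompts: List[str]) -> list:
    callbacks = PrintCallbacks()
    callbacks.on_llm_start({"name": "FakeLLM"}, prompts)
    try:
        results = run_llm(prompts)
    except Exception as e:
        # Mirror the pattern above: report the error, then re-raise it.
        callbacks.on_llm_error(e)
        raise
    callbacks.on_llm_end(results)
    return results
```
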

@@ -1,7 +1,7 @@
"""Wrapper around OpenAI APIs."""
import logging
import sys
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union, Set
from pydantic import BaseModel, Extra, Field, root_validator
from tenacity import (
@@ -19,6 +19,16 @@ from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
def update_token_usage(keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]) -> None:
"""Update token usage."""
_keys_to_use = keys.intersection(response["usage"])
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
class BaseOpenAI(BaseLLM, BaseModel):
"""Wrapper around OpenAI large language models.
@@ -178,17 +188,13 @@ class BaseOpenAI(BaseLLM, BaseModel):
for _prompts in sub_prompts:
response = self.completion_with_retry(prompt=_prompts, **params)
choices.extend(response["choices"])
_keys_to_use = _keys.intersection(response["usage"])
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
async def _agenerate(
self, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
"""Call out to OpenAI's endpoint async with k unique prompts."""
params = self._invocation_params
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
@@ -199,15 +205,11 @@ class BaseOpenAI(BaseLLM, BaseModel):
for _prompts in sub_prompts:
response = await self.client.acreate(prompt=_prompts, **params)
choices.extend(response["choices"])
_keys_to_use = _keys.intersection(response["usage"])
for _key in _keys_to_use:
if _key not in token_usage:
token_usage[_key] = response["usage"][_key]
else:
token_usage[_key] += response["usage"][_key]
update_token_usage(_keys, response, token_usage)
return self.create_llm_result(choices, prompts, token_usage)
def get_sub_prompts(self, params, prompts, stop):
"""Get the sub prompts for llm call."""
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
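
With the duplicated accumulation replaced by update_token_usage, both loops above reduce to the same pattern: call the API once per sub-prompt batch, extend choices, fold the usage counts in. The async version only differs in awaiting the call. A compact sketch of that shared loop, where call_api is a placeholder for completion_with_retry / client.acreate rather than a real OpenAI API:

```python
from typing import Any, Callable, Dict, List, Set, Tuple


def collect_batches(
    call_api: Callable[[List[str]], Dict[str, Any]],  # placeholder callable, illustrative only
    sub_prompts: List[List[str]],
    keys: Set[str],
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    choices: List[Dict[str, Any]] = []
    token_usage: Dict[str, Any] = {}
    for batch in sub_prompts:
        response = call_api(batch)
        choices.extend(response["choices"])
        update_token_usage(keys, response, token_usage)  # helper added in this commit
    return choices, token_usage
```
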

@@ -1,6 +1,7 @@
from langchain.llms import OpenAI
import asyncio
from langchain.llms import OpenAI
def generate_serially():
llm = OpenAI(temperature=0)
@@ -22,6 +23,7 @@ async def generate_concurrently():
if __name__ == "__main__":
import time
s = time.perf_counter()
asyncio.run(generate_concurrently())
elapsed = time.perf_counter() - s
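
The bodies of generate_serially and generate_concurrently are mostly elided in this diff; a plausible sketch of the concurrent half, assuming the async agenerate entry point this branch adds and using asyncio.gather to issue the calls at once (the prompt text and number of LLM instances are illustrative):

```python
import asyncio
import time

from langchain.llms import OpenAI


async def generate_concurrently() -> None:
    llms = [OpenAI(temperature=0) for _ in range(3)]
    # agenerate is the async counterpart of generate added on this branch.
    await asyncio.gather(*(llm.agenerate(["Tell me a joke."]) for llm in llms))


if __name__ == "__main__":
    s = time.perf_counter()
    asyncio.run(generate_concurrently())
    print(f"Concurrent executed in {time.perf_counter() - s:0.2f} seconds.")
```
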
