diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 5d64e3a1..f6a31fb4 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -2,7 +2,7 @@
 import json
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
 
 import yaml
 from pydantic import BaseModel, Extra, Field, validator
@@ -17,7 +17,8 @@ def _get_verbosity() -> bool:
     return langchain.verbose
 
 
-def get_prompts(params, prompts):
+def get_prompts(params: Dict[str, Any], prompts: List[str]) -> Tuple[Dict[int, List], str, List[int], List[str]]:
+    """Get prompts that are already cached."""
     llm_string = str(sorted([(k, v) for k, v in params.items()]))
     missing_prompts = []
     missing_prompt_idxs = []
@@ -32,7 +33,10 @@ def get_prompts(params, prompts):
     return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts
 
 
-def get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts):
+def get_llm_output(
+    existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
+):
+    """Get the LLM output."""
     for i, result in enumerate(new_results.generations):
         existing_prompts[missing_prompt_idxs[i]] = result
         prompt = prompts[missing_prompt_idxs[i]]
@@ -111,7 +115,12 @@ class BaseLLM(BaseModel, ABC):
             return output
         params = self.dict()
         params["stop"] = stop
-        existing_prompts, llm_string, missing_prompt_idxs, missing_prompts = get_prompts(params, prompts)
+        (
+            existing_prompts,
+            llm_string,
+            missing_prompt_idxs,
+            missing_prompts,
+        ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
             self.callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@@ -122,7 +131,9 @@ class BaseLLM(BaseModel, ABC):
                 self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
-            llm_output = get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts)
+            llm_output = get_llm_output(
+                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
+            )
         else:
             llm_output = {}
         generations = [existing_prompts[i] for i in range(len(prompts))]
@@ -150,7 +161,12 @@ class BaseLLM(BaseModel, ABC):
             return output
         params = self.dict()
         params["stop"] = stop
-        existing_prompts, llm_string, missing_prompt_idxs, missing_prompts = get_prompts(params, prompts)
+        (
+            existing_prompts,
+            llm_string,
+            missing_prompt_idxs,
+            missing_prompts,
+        ) = get_prompts(params, prompts)
         if len(missing_prompts) > 0:
             self.callback_manager.on_llm_start(
                 {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@@ -161,7 +177,9 @@ class BaseLLM(BaseModel, ABC):
                 self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
-            llm_output = get_llm_output(existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts)
+            llm_output = get_llm_output(
+                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
+            )
         else:
             llm_output = {}
         generations = [existing_prompts[i] for i in range(len(prompts))]
diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py
index e5503c0b..ca640edc 100644
--- a/langchain/llms/openai.py
+++ b/langchain/llms/openai.py
@@ -1,7 +1,7 @@
 """Wrapper around OpenAI APIs."""
 import logging
 import sys
-from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union, Set
 
 from pydantic import BaseModel, Extra, Field, root_validator
 from tenacity import (
@@ -19,6 +19,16 @@ from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
 
+def update_token_usage(keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]) -> None:
+    """Update token usage."""
+    _keys_to_use = keys.intersection(response["usage"])
+    for _key in _keys_to_use:
+        if _key not in token_usage:
+            token_usage[_key] = response["usage"][_key]
+        else:
+            token_usage[_key] += response["usage"][_key]
+
+
 class BaseOpenAI(BaseLLM, BaseModel):
     """Wrapper around OpenAI large language models.
@@ -178,17 +188,13 @@ class BaseOpenAI(BaseLLM, BaseModel):
         for _prompts in sub_prompts:
             response = self.completion_with_retry(prompt=_prompts, **params)
             choices.extend(response["choices"])
-            _keys_to_use = _keys.intersection(response["usage"])
-            for _key in _keys_to_use:
-                if _key not in token_usage:
-                    token_usage[_key] = response["usage"][_key]
-                else:
-                    token_usage[_key] += response["usage"][_key]
+            update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
 
     async def _agenerate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
+        """Call out to OpenAI's endpoint async with k unique prompts."""
         params = self._invocation_params
         sub_prompts = self.get_sub_prompts(params, prompts, stop)
         choices = []
@@ -199,15 +205,11 @@ class BaseOpenAI(BaseLLM, BaseModel):
         for _prompts in sub_prompts:
             response = await self.client.acreate(prompt=_prompts, **params)
             choices.extend(response["choices"])
-            _keys_to_use = _keys.intersection(response["usage"])
-            for _key in _keys_to_use:
-                if _key not in token_usage:
-                    token_usage[_key] = response["usage"][_key]
-                else:
-                    token_usage[_key] += response["usage"][_key]
+            update_token_usage(_keys, response, token_usage)
         return self.create_llm_result(choices, prompts, token_usage)
 
     def get_sub_prompts(self, params, prompts, stop):
+        """Get the sub prompts for llm call."""
         if stop is not None:
             if "stop" in params:
                 raise ValueError("`stop` found in both the input and default params.")
diff --git a/tests/unit_tests/llms/llm_test.py b/tests/unit_tests/llms/llm_test.py
index 6816fc45..c610f828 100644
--- a/tests/unit_tests/llms/llm_test.py
+++ b/tests/unit_tests/llms/llm_test.py
@@ -1,6 +1,7 @@
-from langchain.llms import OpenAI
 import asyncio
 
+from langchain.llms import OpenAI
+
 
 def generate_serially():
     llm = OpenAI(temperature=0)
@@ -22,6 +23,7 @@
 
 if __name__ == "__main__":
     import time
+
     s = time.perf_counter()
     asyncio.run(generate_concurrently())
     elapsed = time.perf_counter() - s