add async_generate

ankush/async-llmchain
Ankush Gola 1 year ago
parent 0b211f0394
commit 54bf243e36

@@ -17,6 +17,30 @@ def _get_verbosity() -> bool:
    return langchain.verbose


def get_prompts(params, prompts):
    """Split prompts into those already in the cache and those still missing."""
    llm_string = str(sorted([(k, v) for k, v in params.items()]))
    missing_prompts = []
    missing_prompt_idxs = []
    existing_prompts = {}
    for i, prompt in enumerate(prompts):
        cache_val = langchain.llm_cache.lookup(prompt, llm_string)
        if isinstance(cache_val, list):
            existing_prompts[i] = cache_val
        else:
            missing_prompts.append(prompt)
            missing_prompt_idxs.append(i)
    return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts


def get_llm_output(
    existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
):
    """Write fresh results back to the cache and return the LLM output."""
    for i, result in enumerate(new_results.generations):
        existing_prompts[missing_prompt_idxs[i]] = result
        prompt = prompts[missing_prompt_idxs[i]]
        langchain.llm_cache.update(prompt, llm_string, result)
    llm_output = new_results.llm_output
    return llm_output


class BaseLLM(BaseModel, ABC):
    """LLM wrapper should take in a prompt and return a string."""
@@ -58,6 +82,12 @@ class BaseLLM(BaseModel, ABC):
    ) -> LLMResult:
        """Run the LLM on the given prompts."""

    @abstractmethod
    async def _async_generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompts."""

    def generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
@@ -81,17 +111,7 @@ class BaseLLM(BaseModel, ABC):
            return output
        params = self.dict()
        params["stop"] = stop
        llm_string = str(sorted([(k, v) for k, v in params.items()]))
        missing_prompts = []
        missing_prompt_idxs = []
        existing_prompts = {}
        for i, prompt in enumerate(prompts):
            cache_val = langchain.llm_cache.lookup(prompt, llm_string)
            if isinstance(cache_val, list):
                existing_prompts[i] = cache_val
            else:
                missing_prompts.append(prompt)
                missing_prompt_idxs.append(i)
        existing_prompts, llm_string, missing_prompt_idxs, missing_prompts = get_prompts(
            params, prompts
        )
        if len(missing_prompts) > 0:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
@@ -102,11 +122,46 @@ class BaseLLM(BaseModel, ABC):
                self.callback_manager.on_llm_error(e, verbose=self.verbose)
                raise e
            self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
            for i, result in enumerate(new_results.generations):
                existing_prompts[missing_prompt_idxs[i]] = result
                prompt = prompts[missing_prompt_idxs[i]]
                langchain.llm_cache.update(prompt, llm_string, result)
            llm_output = new_results.llm_output
            llm_output = get_llm_output(
                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
            )
        else:
            llm_output = {}
        generations = [existing_prompts[i] for i in range(len(prompts))]
        return LLMResult(generations=generations, llm_output=llm_output)

    async def async_generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        disregard_cache = self.cache is not None and not self.cache
        if langchain.llm_cache is None or disregard_cache:
            # This happens when langchain.cache is None, but self.cache is True
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, prompts, verbose=self.verbose
            )
            try:
                output = await self._async_generate(prompts, stop=stop)
            except (KeyboardInterrupt, Exception) as e:
                self.callback_manager.on_llm_error(e, verbose=self.verbose)
                raise e
            self.callback_manager.on_llm_end(output, verbose=self.verbose)
            return output
        params = self.dict()
        params["stop"] = stop
        existing_prompts, llm_string, missing_prompt_idxs, missing_prompts = get_prompts(
            params, prompts
        )
        if len(missing_prompts) > 0:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
            )
            try:
                new_results = await self._async_generate(missing_prompts, stop=stop)
            except (KeyboardInterrupt, Exception) as e:
                self.callback_manager.on_llm_error(e, verbose=self.verbose)
                raise e
            self.callback_manager.on_llm_end(new_results, verbose=self.verbose)
            llm_output = get_llm_output(
                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
            )
        else:
            llm_output = {}
        generations = [existing_prompts[i] for i in range(len(prompts))]
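The new async_generate mirrors the synchronous path but awaits _async_generate, so several batches can be in flight at once. Below is a minimal usage sketch under stated assumptions: the langchain.llms.base import path is assumed, and EchoLLM is a made-up stand-in, since no built-in LLM implements _async_generate in this commit.

import asyncio
from typing import List, Optional

from langchain.llms.base import BaseLLM  # assumed module path
from langchain.schema import Generation, LLMResult


class EchoLLM(BaseLLM):
    """Hypothetical LLM that echoes each prompt back."""

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return LLMResult(generations=[[Generation(text=p)] for p in prompts])

    async def _async_generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        await asyncio.sleep(0)  # stand-in for a non-blocking provider call
        return LLMResult(generations=[[Generation(text=p)] for p in prompts])

    @property
    def _llm_type(self) -> str:
        return "echo"


async def main() -> None:
    llm = EchoLLM()
    # Two batches awaited concurrently instead of one after the other.
    batch_a, batch_b = await asyncio.gather(
        llm.async_generate(["prompt one", "prompt two"]),
        llm.async_generate(["prompt three"]),
    )
    print([g[0].text for g in batch_a.generations + batch_b.generations])


asyncio.run(main())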
@@ -212,3 +267,9 @@ class LLM(BaseLLM):
            text = self._call(prompt, stop=stop)
            generations.append([Generation(text=text)])
        return LLMResult(generations=generations)

    async def _async_generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        raise NotImplementedError("Async generation not implemented for this LLM.")
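Because the base implementation raises, custom LLM subclasses have to override _async_generate themselves. One possible stopgap, sketched here purely for illustration (UppercaseLLM and the executor fallback are not part of this commit), is to push the existing synchronous path onto a worker thread so the event loop is not blocked:

import asyncio
from typing import List, Optional

from langchain.llms.base import LLM  # assumed module path
from langchain.schema import LLMResult


class UppercaseLLM(LLM):
    """Hypothetical single-prompt LLM, used only to demonstrate the override."""

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        return prompt.upper()  # stand-in for a blocking provider call

    async def _async_generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        loop = asyncio.get_running_loop()
        # Offload the blocking synchronous path to the default executor.
        return await loop.run_in_executor(None, lambda: self._generate(prompts, stop))

    @property
    def _llm_type(self) -> str:
        return "uppercase"

Calling UppercaseLLM().async_generate(...) then flows through the same cache-aware branches of async_generate above that the synchronous generate uses.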

@@ -115,6 +115,11 @@ class BaseOpenAI(BaseLLM, BaseModel):
        }
        return {**normal_params, **self.model_kwargs}

    async def _async_generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        raise NotImplementedError("Async generation not implemented for OpenAI.")

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
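BaseOpenAI only stubs the method out here, so awaiting async_generate on an OpenAI LLM in this commit raises NotImplementedError. Until a native async client is wired in, an async caller could fall back to the same executor trick; a small illustrative helper (the name and usage are assumptions, not part of this commit):

import asyncio
from typing import List, Optional

from langchain.llms import OpenAI
from langchain.schema import LLMResult


async def generate_off_loop(
    llm: OpenAI, prompts: List[str], stop: Optional[List[str]] = None
) -> LLMResult:
    """Run the blocking OpenAI generate call in the default thread pool."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: llm.generate(prompts, stop=stop))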

@@ -33,6 +33,11 @@ class FakeLLM(BaseLLM, BaseModel):
    ) -> LLMResult:
        return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

    async def _async_generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
