use agenerate

branch ankush/async-llm
Ankush Gola
parent 738bf977ab
commit 930edd8e77

@@ -83,7 +83,7 @@ class BaseLLM(BaseModel, ABC):
         """Run the LLM on the given prompts."""
 
     @abstractmethod
-    async def _async_generate(
+    async def _agenerate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
@@ -128,7 +128,7 @@ class BaseLLM(BaseModel, ABC):
         generations = [existing_prompts[i] for i in range(len(prompts))]
         return LLMResult(generations=generations, llm_output=llm_output)
 
-    async def async_generate(
+    async def agenerate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
         disregard_cache = self.cache is not None and not self.cache
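The disregard_cache flag lets a per-instance cache=False override a globally configured cache. An illustrative setup, with import paths assumed from the public API of this era:

# Illustrative per-instance cache override that makes disregard_cache True.
import langchain
from langchain.cache import InMemoryCache
from langchain.llms import OpenAI

langchain.llm_cache = InMemoryCache()  # global cache configured
llm = OpenAI(cache=False)              # this instance opts out
# In agenerate: self.cache is not None and not self.cache -> True,
# so the global cache is bypassed for this LLM's calls.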
@@ -142,7 +142,7 @@ class BaseLLM(BaseModel, ABC):
                 {"name": self.__class__.__name__}, prompts, verbose=self.verbose
             )
             try:
-                output = await self._async_generate(prompts, stop=stop)
+                output = await self._agenerate(prompts, stop=stop)
             except (KeyboardInterrupt, Exception) as e:
                 self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
@@ -156,7 +156,7 @@ class BaseLLM(BaseModel, ABC):
                 {"name": self.__class__.__name__}, missing_prompts, verbose=self.verbose
             )
             try:
-                new_results = await self._async_generate(missing_prompts, stop=stop)
+                new_results = await self._agenerate(missing_prompts, stop=stop)
             except (KeyboardInterrupt, Exception) as e:
                 self.callback_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
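Callers reach the rename through the public coroutine. A usage sketch, assuming an OpenAI API key in the environment; the model parameters are illustrative:

# Sketch of calling the renamed public coroutine.
import asyncio

from langchain.llms import OpenAI


async def main() -> None:
    llm = OpenAI(temperature=0.9)
    result = await llm.agenerate(["Tell me a joke.", "Tell me a fact."])
    for prompt_generations in result.generations:
        print(prompt_generations[0].text)


asyncio.run(main())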
@@ -268,7 +268,7 @@ class LLM(BaseLLM):
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)
 
-    async def _async_generate(
+    async def _agenerate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""

@@ -186,7 +186,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
                 token_usage[_key] += response["usage"][_key]
         return self.create_llm_result(choices, prompts, token_usage)
 
-    async def _async_generate(
+    async def _agenerate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
         params = self._invocation_params

@@ -33,7 +33,7 @@ class FakeLLM(BaseLLM, BaseModel):
     ) -> LLMResult:
         return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
 
-    async def _async_generate(
+    async def _agenerate(
         self, prompts: List[str], stop: Optional[List[str]] = None
     ) -> LLMResult:
         return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
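The test double above makes the rename easy to exercise in an async test. A hypothetical pytest sketch, assuming pytest-asyncio is installed and that this FakeLLM is importable from the test tree:

# Hypothetical async test against the FakeLLM above; the import path and
# the pytest-asyncio marker are assumptions, not part of the commit.
import pytest

from tests.unit_tests.llms.fake_llm import FakeLLM


@pytest.mark.asyncio
async def test_fake_llm_agenerate() -> None:
    llm = FakeLLM(n=2)
    result = await llm.agenerate(["hello"])
    assert [g.text for g in result.generations[0]] == ["foo", "foo"]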

@@ -10,7 +10,7 @@ def generate_serially():
 async def async_generate(llm):
-    resp = await llm.async_generate(["Hello, how are you?"])
+    resp = await llm.agenerate(["Hello, how are you?"])
     # print(resp)
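The point of this example file is the serial-versus-concurrent comparison. A sketch of the concurrent side using the renamed coroutine; names other than agenerate are illustrative:

# Sketch of the concurrent counterpart to generate_serially.
import asyncio
import time

from langchain.llms import OpenAI


async def async_generate(llm: OpenAI) -> None:
    resp = await llm.agenerate(["Hello, how are you?"])
    print(resp.generations[0][0].text)


async def generate_concurrently() -> None:
    llm = OpenAI(temperature=0.9)
    # Ten overlapping requests instead of ten sequential ones.
    await asyncio.gather(*(async_generate(llm) for _ in range(10)))


start = time.perf_counter()
asyncio.run(generate_concurrently())
print(f"Concurrent generation took {time.perf_counter() - start:.2f}s")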
