Ankush Gola 2023-02-02 12:52:05 -08:00
parent 2611fdd03e
commit 496ee53c6c
2 changed files with 42 additions and 5 deletions


@@ -111,7 +111,7 @@ class Chain(BaseModel, ABC):
     def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
         """Run the logic of this chain and return the output."""
 
-    async def _async_call(self, inputs: Dict[str, str]) -> Dict[str, str]:
+    async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
         """Run the logic of this chain and return the output."""
         raise NotImplementedError("Async call not supported for this chain type.")
@@ -143,7 +143,7 @@ class Chain(BaseModel, ABC):
         self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
         return self.prep_outputs(inputs, outputs, return_only_outputs)
 
-    async def async_call(
+    async def acall(
         self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.
@@ -164,7 +164,7 @@ class Chain(BaseModel, ABC):
             verbose=self.verbose,
         )
         try:
-            outputs = await self._async_call(inputs)
+            outputs = await self._acall(inputs)
        except (KeyboardInterrupt, Exception) as e:
             self.callback_manager.on_chain_error(e, verbose=self.verbose)
             raise e

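With this change, the base Chain gains an async entry point: subclasses override _acall (the default raises NotImplementedError, so existing sync-only chains are unaffected), and callers drive it through acall, which runs the same callback and output-preparation logic as the synchronous __call__. A minimal sketch of how a subclass might opt in; the UppercaseChain name and body are illustrative, not part of this commit:

    import asyncio
    from typing import Dict, List

    from langchain.chains.base import Chain


    class UppercaseChain(Chain):
        """Toy chain that upper-cases its single input."""

        @property
        def input_keys(self) -> List[str]:
            return ["text"]

        @property
        def output_keys(self) -> List[str]:
            return ["shout"]

        def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
            return {"shout": inputs["text"].upper()}

        async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
            # A real chain would await I/O here (e.g. an async LLM call).
            return self._call(inputs)


    outputs = asyncio.run(UppercaseChain().acall({"text": "hello"}))

Chains that never override _acall keep working synchronously; awaiting their acall surfaces the NotImplementedError shown above.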

@@ -56,6 +56,17 @@ class LLMChain(Chain, BaseModel):
     def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
         """Generate LLM result from inputs."""
+        prompts, stop = self.prep_prompts(input_list)
+        response = self.llm.generate(prompts, stop=stop)
+        return response
+
+    async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
+        """Generate LLM result from inputs."""
+        prompts, stop = self.prep_prompts(input_list)
+        response = await self.llm.agenerate(prompts, stop=stop)
+        return response
+
+    def prep_prompts(self, input_list):
         stop = None
         if "stop" in input_list[0]:
             stop = input_list[0]["stop"]
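This hunk splits the old generate body: prompt formatting and stop-sequence handling move into the new prep_prompts helper, which the synchronous generate and the new agenerate now share, so the two code paths cannot drift apart. A rough usage sketch; the prompt and model choice are illustrative, and OpenAI() assumes an OPENAI_API_KEY in the environment:

    from langchain import LLMChain, OpenAI, PromptTemplate

    chain = LLMChain(
        llm=OpenAI(),
        prompt=PromptTemplate(input_variables=["q"], template="Q: {q}\nA:"),
    )
    # Each input may carry a "stop" key; prep_prompts takes it from the
    # first input and, per the next hunk, requires it to match in all.
    prompts, stop = chain.prep_prompts(
        [{"q": "1 + 1?", "stop": ["\n"]}, {"q": "2 + 2?", "stop": ["\n"]}]
    )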
@@ -71,12 +82,19 @@ class LLMChain(Chain, BaseModel):
                     "If `stop` is present in any inputs, should be present in all."
                 )
             prompts.append(prompt)
-        response = self.llm.generate(prompts, stop=stop)
-        return response
+        return prompts, stop
 
     def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
         """Utilize the LLM generate method for speed gains."""
         response = self.generate(input_list)
         return self.create_outputs(response)
 
+    async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
+        """Utilize the LLM generate method for speed gains."""
+        response = await self.agenerate(input_list)
+        return self.create_outputs(response)
+
     def create_outputs(self, response):
         outputs = []
         for generation in response.generations:
             # Get the text of the top generated string.
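aapply mirrors apply: it formats every input in the batch, issues a single llm.agenerate call for all prompts, and unpacks the generations with create_outputs. Because it is a coroutine, independent batches can also run concurrently; a sketch reusing the hypothetical chain from the previous example:

    import asyncio


    async def run_batches() -> None:
        batch_a = [{"q": "1 + 1?"}, {"q": "2 + 2?"}]
        batch_b = [{"q": "3 + 3?"}]
        results_a, results_b = await asyncio.gather(
            chain.aapply(batch_a), chain.aapply(batch_b)
        )
        print(len(results_a), len(results_b))


    asyncio.run(run_batches())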
@@ -87,6 +105,9 @@ class LLMChain(Chain, BaseModel):
     def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
         return self.apply([inputs])[0]
 
+    async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
+        return (await self.aapply([inputs]))[0]
+
     def predict(self, **kwargs: Any) -> str:
         """Format prompt with kwargs and pass to LLM.
@@ -103,6 +124,22 @@ class LLMChain(Chain, BaseModel):
         """
         return self(kwargs)[self.output_key]
 
+    async def apredict(self, **kwargs: Any) -> str:
+        """Format prompt with kwargs and pass to LLM.
+
+        Args:
+            **kwargs: Keys to pass to prompt template.
+
+        Returns:
+            Completion from LLM.
+
+        Example:
+            .. code-block:: python
+
+                completion = llm.predict(adjective="funny")
+        """
+        return (await self.acall(kwargs))[self.output_key]
+
     def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
         """Call predict and then parse the results."""
         result = self.predict(**kwargs)
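apredict rounds out the API as the awaitable twin of predict, returning only the chain's output-key string. (The Example block in the new docstring still shows llm.predict(...), carried over verbatim from the synchronous method.) A minimal usage sketch, again assuming the chain defined earlier:

    import asyncio


    async def main() -> None:
        completion = await chain.apredict(q="What is the capital of France?")
        print(completion)


    asyncio.run(main())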