This commit is contained in:
Ankush Gola 2023-02-02 12:52:05 -08:00
parent 2611fdd03e
commit 496ee53c6c
2 changed files with 42 additions and 5 deletions

View File

@ -111,7 +111,7 @@ class Chain(BaseModel, ABC):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output.""" """Run the logic of this chain and return the output."""
async def _async_call(self, inputs: Dict[str, str]) -> Dict[str, str]: async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
"""Run the logic of this chain and return the output.""" """Run the logic of this chain and return the output."""
raise NotImplementedError("Async call not supported for this chain type.") raise NotImplementedError("Async call not supported for this chain type.")
@ -143,7 +143,7 @@ class Chain(BaseModel, ABC):
self.callback_manager.on_chain_end(outputs, verbose=self.verbose) self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
return self.prep_outputs(inputs, outputs, return_only_outputs) return self.prep_outputs(inputs, outputs, return_only_outputs)
async def async_call( async def acall(
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired. """Run the logic of this chain and add to output if desired.
@ -164,7 +164,7 @@ class Chain(BaseModel, ABC):
verbose=self.verbose, verbose=self.verbose,
) )
try: try:
outputs = await self._async_call(inputs) outputs = await self._acall(inputs)
except (KeyboardInterrupt, Exception) as e: except (KeyboardInterrupt, Exception) as e:
self.callback_manager.on_chain_error(e, verbose=self.verbose) self.callback_manager.on_chain_error(e, verbose=self.verbose)
raise e raise e

View File

@ -56,6 +56,17 @@ class LLMChain(Chain, BaseModel):
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult: def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
"""Generate LLM result from inputs.""" """Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
response = self.llm.generate(prompts, stop=stop)
return response
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list)
response = await self.llm.agenerate(prompts, stop=stop)
return response
def prep_prompts(self, input_list):
stop = None stop = None
if "stop" in input_list[0]: if "stop" in input_list[0]:
stop = input_list[0]["stop"] stop = input_list[0]["stop"]
@ -71,12 +82,19 @@ class LLMChain(Chain, BaseModel):
"If `stop` is present in any inputs, should be present in all." "If `stop` is present in any inputs, should be present in all."
) )
prompts.append(prompt) prompts.append(prompt)
response = self.llm.generate(prompts, stop=stop) return prompts, stop
return response
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains.""" """Utilize the LLM generate method for speed gains."""
response = self.generate(input_list) response = self.generate(input_list)
return self.create_outputs(response)
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
"""Utilize the LLM generate method for speed gains."""
response = await self.agenerate(input_list)
return self.create_outputs(response)
def create_outputs(self, response):
outputs = [] outputs = []
for generation in response.generations: for generation in response.generations:
# Get the text of the top generated string. # Get the text of the top generated string.
@ -87,6 +105,9 @@ class LLMChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0] return self.apply([inputs])[0]
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return (await self.aapply([inputs]))[0]
def predict(self, **kwargs: Any) -> str: def predict(self, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM. """Format prompt with kwargs and pass to LLM.
@ -103,6 +124,22 @@ class LLMChain(Chain, BaseModel):
""" """
return self(kwargs)[self.output_key] return self(kwargs)[self.output_key]
async def apredict(self, **kwargs: Any) -> str:
"""Format prompt with kwargs and pass to LLM.
Args:
**kwargs: Keys to pass to prompt template.
Returns:
Completion from LLM.
Example:
.. code-block:: python
completion = llm.predict(adjective="funny")
"""
return (await self.acall(kwargs))[self.output_key]
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]: def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
"""Call predict and then parse the results.""" """Call predict and then parse the results."""
result = self.predict(**kwargs) result = self.predict(**kwargs)