diff --git a/langchain/chains/base.py b/langchain/chains/base.py index cc1ae752..eacd395a 100644 --- a/langchain/chains/base.py +++ b/langchain/chains/base.py @@ -111,7 +111,7 @@ class Chain(BaseModel, ABC): def _call(self, inputs: Dict[str, str]) -> Dict[str, str]: """Run the logic of this chain and return the output.""" - async def _async_call(self, inputs: Dict[str, str]) -> Dict[str, str]: + async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]: """Run the logic of this chain and return the output.""" raise NotImplementedError("Async call not supported for this chain type.") @@ -143,7 +143,7 @@ class Chain(BaseModel, ABC): self.callback_manager.on_chain_end(outputs, verbose=self.verbose) return self.prep_outputs(inputs, outputs, return_only_outputs) - async def async_call( + async def acall( self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False ) -> Dict[str, Any]: """Run the logic of this chain and add to output if desired. @@ -164,7 +164,7 @@ class Chain(BaseModel, ABC): verbose=self.verbose, ) try: - outputs = await self._async_call(inputs) + outputs = await self._acall(inputs) except (KeyboardInterrupt, Exception) as e: self.callback_manager.on_chain_error(e, verbose=self.verbose) raise e diff --git a/langchain/chains/llm.py b/langchain/chains/llm.py index 33174bf4..f6fca3dd 100644 --- a/langchain/chains/llm.py +++ b/langchain/chains/llm.py @@ -56,6 +56,17 @@ class LLMChain(Chain, BaseModel): def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult: """Generate LLM result from inputs.""" + prompts, stop = self.prep_prompts(input_list) + response = self.llm.generate(prompts, stop=stop) + return response + + async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult: + """Generate LLM result from inputs.""" + prompts, stop = self.prep_prompts(input_list) + response = await self.llm.agenerate(prompts, stop=stop) + return response + + def prep_prompts(self, input_list): stop = None if "stop" in input_list[0]: stop = input_list[0]["stop"] @@ -71,12 +82,19 @@ class LLMChain(Chain, BaseModel): "If `stop` is present in any inputs, should be present in all." ) prompts.append(prompt) - response = self.llm.generate(prompts, stop=stop) - return response + return prompts, stop def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: """Utilize the LLM generate method for speed gains.""" response = self.generate(input_list) + return self.create_outputs(response) + + async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]: + """Utilize the LLM generate method for speed gains.""" + response = await self.agenerate(input_list) + return self.create_outputs(response) + + def create_outputs(self, response): outputs = [] for generation in response.generations: # Get the text of the top generated string. @@ -87,6 +105,9 @@ class LLMChain(Chain, BaseModel): def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]: return self.apply([inputs])[0] + async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]: + return (await self.aapply([inputs]))[0] + def predict(self, **kwargs: Any) -> str: """Format prompt with kwargs and pass to LLM. @@ -103,6 +124,22 @@ class LLMChain(Chain, BaseModel): """ return self(kwargs)[self.output_key] + async def apredict(self, **kwargs: Any) -> str: + """Format prompt with kwargs and pass to LLM. + + Args: + **kwargs: Keys to pass to prompt template. + + Returns: + Completion from LLM. + + Example: + .. code-block:: python + + completion = llm.predict(adjective="funny") + """ + return (await self.acall(kwargs))[self.output_key] + def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]: """Call predict and then parse the results.""" result = self.predict(**kwargs)