|
|
|
@ -56,6 +56,17 @@ class LLMChain(Chain, BaseModel):
|
|
|
|
|
|
|
|
|
|
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
|
|
|
|
"""Generate LLM result from inputs."""
|
|
|
|
|
prompts, stop = self.prep_prompts(input_list)
|
|
|
|
|
response = self.llm.generate(prompts, stop=stop)
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
|
|
|
|
"""Generate LLM result from inputs."""
|
|
|
|
|
prompts, stop = self.prep_prompts(input_list)
|
|
|
|
|
response = await self.llm.agenerate(prompts, stop=stop)
|
|
|
|
|
return response
|
|
|
|
|
|
|
|
|
|
def prep_prompts(self, input_list):
|
|
|
|
|
stop = None
|
|
|
|
|
if "stop" in input_list[0]:
|
|
|
|
|
stop = input_list[0]["stop"]
|
|
|
|
@ -71,12 +82,19 @@ class LLMChain(Chain, BaseModel):
|
|
|
|
|
"If `stop` is present in any inputs, should be present in all."
|
|
|
|
|
)
|
|
|
|
|
prompts.append(prompt)
|
|
|
|
|
response = self.llm.generate(prompts, stop=stop)
|
|
|
|
|
return response
|
|
|
|
|
return prompts, stop
|
|
|
|
|
|
|
|
|
|
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
|
|
|
|
"""Utilize the LLM generate method for speed gains."""
|
|
|
|
|
response = self.generate(input_list)
|
|
|
|
|
return self.create_outputs(response)
|
|
|
|
|
|
|
|
|
|
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
|
|
|
|
"""Utilize the LLM generate method for speed gains."""
|
|
|
|
|
response = await self.agenerate(input_list)
|
|
|
|
|
return self.create_outputs(response)
|
|
|
|
|
|
|
|
|
|
def create_outputs(self, response):
|
|
|
|
|
outputs = []
|
|
|
|
|
for generation in response.generations:
|
|
|
|
|
# Get the text of the top generated string.
|
|
|
|
@ -87,6 +105,9 @@ class LLMChain(Chain, BaseModel):
|
|
|
|
|
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
|
|
|
return self.apply([inputs])[0]
|
|
|
|
|
|
|
|
|
|
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
|
|
|
|
return (await self.aapply([inputs]))[0]
|
|
|
|
|
|
|
|
|
|
def predict(self, **kwargs: Any) -> str:
|
|
|
|
|
"""Format prompt with kwargs and pass to LLM.
|
|
|
|
|
|
|
|
|
@ -103,6 +124,22 @@ class LLMChain(Chain, BaseModel):
|
|
|
|
|
"""
|
|
|
|
|
return self(kwargs)[self.output_key]
|
|
|
|
|
|
|
|
|
|
async def apredict(self, **kwargs: Any) -> str:
|
|
|
|
|
"""Format prompt with kwargs and pass to LLM.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
**kwargs: Keys to pass to prompt template.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Completion from LLM.
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
completion = llm.predict(adjective="funny")
|
|
|
|
|
"""
|
|
|
|
|
return (await self.acall(kwargs))[self.output_key]
|
|
|
|
|
|
|
|
|
|
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
|
|
|
|
"""Call predict and then parse the results."""
|
|
|
|
|
result = self.predict(**kwargs)
|
|
|
|
|