forked from Archives/langchain
cr
This commit is contained in:
parent
2611fdd03e
commit
496ee53c6c
@ -111,7 +111,7 @@ class Chain(BaseModel, ABC):
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
|
||||
async def _async_call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
"""Run the logic of this chain and return the output."""
|
||||
raise NotImplementedError("Async call not supported for this chain type.")
|
||||
|
||||
@ -143,7 +143,7 @@ class Chain(BaseModel, ABC):
|
||||
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||
return self.prep_outputs(inputs, outputs, return_only_outputs)
|
||||
|
||||
async def async_call(
|
||||
async def acall(
|
||||
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
@ -164,7 +164,7 @@ class Chain(BaseModel, ABC):
|
||||
verbose=self.verbose,
|
||||
)
|
||||
try:
|
||||
outputs = await self._async_call(inputs)
|
||||
outputs = await self._acall(inputs)
|
||||
except (KeyboardInterrupt, Exception) as e:
|
||||
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||
raise e
|
||||
|
@ -56,6 +56,17 @@ class LLMChain(Chain, BaseModel):
|
||||
|
||||
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list)
|
||||
response = self.llm.generate(prompts, stop=stop)
|
||||
return response
|
||||
|
||||
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||
"""Generate LLM result from inputs."""
|
||||
prompts, stop = self.prep_prompts(input_list)
|
||||
response = await self.llm.agenerate(prompts, stop=stop)
|
||||
return response
|
||||
|
||||
def prep_prompts(self, input_list):
|
||||
stop = None
|
||||
if "stop" in input_list[0]:
|
||||
stop = input_list[0]["stop"]
|
||||
@ -71,12 +82,19 @@ class LLMChain(Chain, BaseModel):
|
||||
"If `stop` is present in any inputs, should be present in all."
|
||||
)
|
||||
prompts.append(prompt)
|
||||
response = self.llm.generate(prompts, stop=stop)
|
||||
return response
|
||||
return prompts, stop
|
||||
|
||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
response = self.generate(input_list)
|
||||
return self.create_outputs(response)
|
||||
|
||||
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||
"""Utilize the LLM generate method for speed gains."""
|
||||
response = await self.agenerate(input_list)
|
||||
return self.create_outputs(response)
|
||||
|
||||
def create_outputs(self, response):
|
||||
outputs = []
|
||||
for generation in response.generations:
|
||||
# Get the text of the top generated string.
|
||||
@ -87,6 +105,9 @@ class LLMChain(Chain, BaseModel):
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return self.apply([inputs])[0]
|
||||
|
||||
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
return (await self.aapply([inputs]))[0]
|
||||
|
||||
def predict(self, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
@ -103,6 +124,22 @@ class LLMChain(Chain, BaseModel):
|
||||
"""
|
||||
return self(kwargs)[self.output_key]
|
||||
|
||||
async def apredict(self, **kwargs: Any) -> str:
|
||||
"""Format prompt with kwargs and pass to LLM.
|
||||
|
||||
Args:
|
||||
**kwargs: Keys to pass to prompt template.
|
||||
|
||||
Returns:
|
||||
Completion from LLM.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
completion = llm.predict(adjective="funny")
|
||||
"""
|
||||
return (await self.acall(kwargs))[self.output_key]
|
||||
|
||||
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
||||
"""Call predict and then parse the results."""
|
||||
result = self.predict(**kwargs)
|
||||
|
Loading…
Reference in New Issue
Block a user