forked from Archives/langchain
cr
This commit is contained in:
parent
2611fdd03e
commit
496ee53c6c
@ -111,7 +111,7 @@ class Chain(BaseModel, ABC):
|
|||||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
"""Run the logic of this chain and return the output."""
|
"""Run the logic of this chain and return the output."""
|
||||||
|
|
||||||
async def _async_call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
async def _acall(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
"""Run the logic of this chain and return the output."""
|
"""Run the logic of this chain and return the output."""
|
||||||
raise NotImplementedError("Async call not supported for this chain type.")
|
raise NotImplementedError("Async call not supported for this chain type.")
|
||||||
|
|
||||||
@ -143,7 +143,7 @@ class Chain(BaseModel, ABC):
|
|||||||
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
self.callback_manager.on_chain_end(outputs, verbose=self.verbose)
|
||||||
return self.prep_outputs(inputs, outputs, return_only_outputs)
|
return self.prep_outputs(inputs, outputs, return_only_outputs)
|
||||||
|
|
||||||
async def async_call(
|
async def acall(
|
||||||
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
|
self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Run the logic of this chain and add to output if desired.
|
"""Run the logic of this chain and add to output if desired.
|
||||||
@ -164,7 +164,7 @@ class Chain(BaseModel, ABC):
|
|||||||
verbose=self.verbose,
|
verbose=self.verbose,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
outputs = await self._async_call(inputs)
|
outputs = await self._acall(inputs)
|
||||||
except (KeyboardInterrupt, Exception) as e:
|
except (KeyboardInterrupt, Exception) as e:
|
||||||
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
self.callback_manager.on_chain_error(e, verbose=self.verbose)
|
||||||
raise e
|
raise e
|
||||||
|
@ -56,6 +56,17 @@ class LLMChain(Chain, BaseModel):
|
|||||||
|
|
||||||
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
def generate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||||
"""Generate LLM result from inputs."""
|
"""Generate LLM result from inputs."""
|
||||||
|
prompts, stop = self.prep_prompts(input_list)
|
||||||
|
response = self.llm.generate(prompts, stop=stop)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def agenerate(self, input_list: List[Dict[str, Any]]) -> LLMResult:
|
||||||
|
"""Generate LLM result from inputs."""
|
||||||
|
prompts, stop = self.prep_prompts(input_list)
|
||||||
|
response = await self.llm.agenerate(prompts, stop=stop)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def prep_prompts(self, input_list):
|
||||||
stop = None
|
stop = None
|
||||||
if "stop" in input_list[0]:
|
if "stop" in input_list[0]:
|
||||||
stop = input_list[0]["stop"]
|
stop = input_list[0]["stop"]
|
||||||
@ -71,12 +82,19 @@ class LLMChain(Chain, BaseModel):
|
|||||||
"If `stop` is present in any inputs, should be present in all."
|
"If `stop` is present in any inputs, should be present in all."
|
||||||
)
|
)
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
response = self.llm.generate(prompts, stop=stop)
|
return prompts, stop
|
||||||
return response
|
|
||||||
|
|
||||||
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
def apply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||||
"""Utilize the LLM generate method for speed gains."""
|
"""Utilize the LLM generate method for speed gains."""
|
||||||
response = self.generate(input_list)
|
response = self.generate(input_list)
|
||||||
|
return self.create_outputs(response)
|
||||||
|
|
||||||
|
async def aapply(self, input_list: List[Dict[str, Any]]) -> List[Dict[str, str]]:
|
||||||
|
"""Utilize the LLM generate method for speed gains."""
|
||||||
|
response = await self.agenerate(input_list)
|
||||||
|
return self.create_outputs(response)
|
||||||
|
|
||||||
|
def create_outputs(self, response):
|
||||||
outputs = []
|
outputs = []
|
||||||
for generation in response.generations:
|
for generation in response.generations:
|
||||||
# Get the text of the top generated string.
|
# Get the text of the top generated string.
|
||||||
@ -87,6 +105,9 @@ class LLMChain(Chain, BaseModel):
|
|||||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||||
return self.apply([inputs])[0]
|
return self.apply([inputs])[0]
|
||||||
|
|
||||||
|
async def _acall(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||||
|
return (await self.aapply([inputs]))[0]
|
||||||
|
|
||||||
def predict(self, **kwargs: Any) -> str:
|
def predict(self, **kwargs: Any) -> str:
|
||||||
"""Format prompt with kwargs and pass to LLM.
|
"""Format prompt with kwargs and pass to LLM.
|
||||||
|
|
||||||
@ -103,6 +124,22 @@ class LLMChain(Chain, BaseModel):
|
|||||||
"""
|
"""
|
||||||
return self(kwargs)[self.output_key]
|
return self(kwargs)[self.output_key]
|
||||||
|
|
||||||
|
async def apredict(self, **kwargs: Any) -> str:
|
||||||
|
"""Format prompt with kwargs and pass to LLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Keys to pass to prompt template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Completion from LLM.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
completion = llm.predict(adjective="funny")
|
||||||
|
"""
|
||||||
|
return (await self.acall(kwargs))[self.output_key]
|
||||||
|
|
||||||
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
def predict_and_parse(self, **kwargs: Any) -> Union[str, List[str], Dict[str, str]]:
|
||||||
"""Call predict and then parse the results."""
|
"""Call predict and then parse the results."""
|
||||||
result = self.predict(**kwargs)
|
result = self.predict(**kwargs)
|
||||||
|
Loading…
Reference in New Issue
Block a user