async output parser (#8894)

Co-authored-by: Nuno Campos <nuno@boringbits.io>
pull/8954/head
Harrison Chase 1 year ago committed by GitHub
parent 3c6eccd701
commit 4d72288487
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -475,7 +475,8 @@ class Agent(BaseSingleActionAgent):
"""
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
full_output = await self.llm_chain.apredict(callbacks=callbacks, **full_inputs)
return self.output_parser.parse(full_output)
agent_output = await self.output_parser.aparse(full_output)
return agent_output
def get_full_inputs(
self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
@ -27,6 +28,20 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
Structured output.
"""
async def aparse_result(self, result: List[Generation]) -> T:
    """Asynchronously parse candidate model Generations into a specific format.

    Default implementation: delegate to the synchronous ``parse_result`` on
    the running loop's default thread-pool executor, so subclasses only need
    to override this when they have a genuinely async parsing path.

    Args:
        result: A list of Generations to be parsed. The Generations are assumed
            to be different candidate outputs for a single model input.

    Returns:
        Structured output.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.parse_result, result)
class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
@ -51,6 +66,26 @@ class BaseGenerationOutputParser(
run_type="parser",
)
async def ainvoke(
    self, input: str | BaseMessage, config: RunnableConfig | None = None
) -> T:
    """Asynchronously parse a single model output, with callback tracing.

    A ``BaseMessage`` input is wrapped in a ``ChatGeneration``; a plain
    string is wrapped in a ``Generation``. Either way the parse is routed
    through ``_acall_with_config`` so chain start/end callbacks fire.
    """
    if isinstance(input, BaseMessage):
        # Chat input: preserve the full message object for the parser.
        to_result = lambda inner_input: self.aparse_result(
            [ChatGeneration(message=inner_input)]
        )
    else:
        # String input: wrap as a plain text Generation.
        to_result = lambda inner_input: self.aparse_result(
            [Generation(text=inner_input)]
        )
    return await self._acall_with_config(
        to_result, input, config, run_type="parser"
    )
class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]):
"""Base class to parse the output of an LLM call.
@ -99,6 +134,26 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
run_type="parser",
)
async def ainvoke(
    self, input: str | BaseMessage, config: RunnableConfig | None = None
) -> T:
    """Asynchronously parse a single model output, with callback tracing.

    The input is wrapped in the appropriate Generation type and handed to
    ``aparse_result`` via ``_acall_with_config`` so chain callbacks fire.
    """

    def _wrap(inner_input: str | BaseMessage) -> list:
        # Chat messages become ChatGeneration; raw strings become Generation.
        if isinstance(inner_input, BaseMessage):
            return [ChatGeneration(message=inner_input)]
        return [Generation(text=inner_input)]

    return await self._acall_with_config(
        lambda inner_input: self.aparse_result(_wrap(inner_input)),
        input,
        config,
        run_type="parser",
    )
def parse_result(self, result: List[Generation]) -> T:
"""Parse a list of candidate model Generations into a specific format.
@ -125,6 +180,32 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
Structured output.
"""
async def aparse_result(self, result: List[Generation]) -> T:
    """Asynchronously parse candidate model Generations into a specific format.

    Only the first Generation in ``result`` is parsed; it is assumed to be
    the highest-likelihood candidate.

    Args:
        result: A list of Generations to be parsed. The Generations are assumed
            to be different candidate outputs for a single model input.

    Returns:
        Structured output.
    """
    best = result[0]
    return await self.aparse(best.text)
async def aparse(self, text: str) -> T:
    """Asynchronously parse a single string model output into some structure.

    Default implementation runs the synchronous ``parse`` in the running
    loop's default executor so a blocking parser does not stall the event
    loop.

    Args:
        text: String output of a language model.

    Returns:
        Structured output.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, self.parse, text)
# TODO: rename 'completion' -> 'text'.
def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
"""Parse the output of an LLM call with the input prompt for context.

@ -6,6 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Dict,
@ -192,6 +193,37 @@ class Runnable(Generic[Input, Output], ABC):
)
return output
async def _acall_with_config(
    self,
    func: Callable[[Input], Awaitable[Output]],
    input: Input,
    config: Optional[RunnableConfig],
    run_type: Optional[str] = None,
) -> Output:
    """Await ``func`` on ``input``, emitting chain start/end/error callbacks.

    Builds an async callback manager from ``config``, reports a chain start
    (wrapping non-dict inputs as ``{"input": ...}``), then either reports the
    error and re-raises, or reports the end (wrapping non-dict outputs as
    ``{"output": ...}``) and returns the result.
    """
    # Local import mirrors the file's convention and avoids an import cycle
    # with the callbacks package — TODO confirm against the rest of the file.
    from langchain.callbacks.manager import AsyncCallbackManager

    cfg = config or {}
    manager = AsyncCallbackManager.configure(
        inheritable_callbacks=cfg.get("callbacks"),
        inheritable_tags=cfg.get("tags"),
        inheritable_metadata=cfg.get("metadata"),
    )
    run_manager = await manager.on_chain_start(
        dumpd(self),
        input if isinstance(input, dict) else {"input": input},
        run_type=run_type,
    )
    try:
        result = await func(input)
    except Exception as exc:
        await run_manager.on_chain_error(exc)
        raise
    # Reached only on success: the except branch above always re-raises.
    await run_manager.on_chain_end(
        result if isinstance(result, dict) else {"output": result}
    )
    return result
def with_fallbacks(
self,
fallbacks: Sequence[Runnable[Input, Output]],

Loading…
Cancel
Save