|
|
|
@ -337,11 +337,17 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
|
|
|
|
|
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|
|
|
|
"""Base class for an output parser that can handle streaming input."""
|
|
|
|
|
|
|
|
|
|
diff: bool = False
|
|
|
|
|
|
|
|
|
|
def _diff(self, prev: Optional[T], next: T) -> T:
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
|
|
|
|
|
prev_parsed = None
|
|
|
|
|
acc_gen = None
|
|
|
|
|
for chunk in input:
|
|
|
|
|
if isinstance(chunk, BaseMessageChunk):
|
|
|
|
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
|
|
|
|
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
|
|
|
|
elif isinstance(chunk, BaseMessage):
|
|
|
|
|
chunk_gen = ChatGenerationChunk(
|
|
|
|
|
message=BaseMessageChunk(**chunk.dict())
|
|
|
|
@ -355,16 +361,21 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|
|
|
|
acc_gen += chunk_gen
|
|
|
|
|
|
|
|
|
|
parsed = self.parse_result([acc_gen])
|
|
|
|
|
if parsed is not None:
|
|
|
|
|
yield parsed
|
|
|
|
|
if parsed is not None and parsed != prev_parsed:
|
|
|
|
|
if self.diff:
|
|
|
|
|
yield self._diff(prev_parsed, parsed)
|
|
|
|
|
else:
|
|
|
|
|
yield parsed
|
|
|
|
|
prev_parsed = parsed
|
|
|
|
|
|
|
|
|
|
async def _atransform(
|
|
|
|
|
self, input: AsyncIterator[Union[str, BaseMessage]]
|
|
|
|
|
) -> AsyncIterator[T]:
|
|
|
|
|
prev_parsed = None
|
|
|
|
|
acc_gen = None
|
|
|
|
|
for chunk in input:
|
|
|
|
|
async for chunk in input:
|
|
|
|
|
if isinstance(chunk, BaseMessageChunk):
|
|
|
|
|
chunk_gen = ChatGenerationChunk(message=chunk)
|
|
|
|
|
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
|
|
|
|
|
elif isinstance(chunk, BaseMessage):
|
|
|
|
|
chunk_gen = ChatGenerationChunk(
|
|
|
|
|
message=BaseMessageChunk(**chunk.dict())
|
|
|
|
@ -378,8 +389,12 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
|
|
|
|
|
acc_gen += chunk_gen
|
|
|
|
|
|
|
|
|
|
parsed = self.parse_result([acc_gen])
|
|
|
|
|
if parsed is not None:
|
|
|
|
|
yield parsed
|
|
|
|
|
if parsed is not None and parsed != prev_parsed:
|
|
|
|
|
if self.diff:
|
|
|
|
|
yield self._diff(prev_parsed, parsed)
|
|
|
|
|
else:
|
|
|
|
|
yield parsed
|
|
|
|
|
prev_parsed = parsed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StrOutputParser(BaseTransformOutputParser[str]):
|
|
|
|
|