Implement diff

pull/11222/head
Nuno Campos 1 year ago
parent 5cbe2b7b6a
commit 4e28a7a513

@ -337,11 +337,17 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
"""Base class for an output parser that can handle streaming input."""
diff: bool = False
def _diff(self, prev: Optional[T], next: T) -> T:
raise NotImplementedError()
def _transform(self, input: Iterator[Union[str, BaseMessage]]) -> Iterator[Any]:
prev_parsed = None
acc_gen = None
for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen = ChatGenerationChunk(message=chunk)
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
@ -355,16 +361,21 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen])
if parsed is not None:
yield parsed
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[T]:
prev_parsed = None
acc_gen = None
for chunk in input:
async for chunk in input:
if isinstance(chunk, BaseMessageChunk):
chunk_gen = ChatGenerationChunk(message=chunk)
chunk_gen: Generation = ChatGenerationChunk(message=chunk)
elif isinstance(chunk, BaseMessage):
chunk_gen = ChatGenerationChunk(
message=BaseMessageChunk(**chunk.dict())
@ -378,8 +389,12 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
acc_gen += chunk_gen
parsed = self.parse_result([acc_gen])
if parsed is not None:
yield parsed
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed
class StrOutputParser(BaseTransformOutputParser[str]):

Loading…
Cancel
Save