Support .generate(..., return_dict_in_generate=True)

pull/464/head
Aleksandr Borzunov 10 months ago
parent bacdca0f5c
commit 4d1c228dd6

@@ -5,6 +5,7 @@ from typing import ContextManager, List, Optional
 import torch
 import transformers
 from hivemind.utils.logging import get_logger
+from transformers.generation.utils import ModelOutput
 
 from petals.client.inference_session import InferenceSession
 from petals.client.remote_sequential import RemoteSequential
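
For context: with return_dict_in_generate=True, Hugging Face generate() returns a ModelOutput subclass whose .sequences field holds the token ids, instead of a bare LongTensor. That is why ModelOutput is imported here and branched on in the hunk below. A minimal sketch using plain transformers (the "gpt2" checkpoint is only illustrative, not part of this commit):

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.utils import ModelOutput

# Illustrative checkpoint; any causal LM behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tokenizer("Hello", return_tensors="pt")["input_ids"]

plain = model.generate(ids, max_new_tokens=3)  # torch.LongTensor of token ids
rich = model.generate(ids, max_new_tokens=3, return_dict_in_generate=True)

assert not isinstance(plain, ModelOutput)
assert isinstance(rich, ModelOutput)  # token ids live in rich.sequences
assert rich.sequences.tolist() == plain.tolist()  # same greedy output either way
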
@@ -86,11 +87,16 @@ class RemoteGenerationMixin:
 
             result = super().generate(inputs, *args, **kwargs)
 
+            sequences = result.sequences if isinstance(result, ModelOutput) else result
             # Crop the last tokens from the previous call
             if session.last_token_id is not None:
-                result = result[:, 1:]
+                sequences = sequences[:, 1:]
+                if isinstance(result, ModelOutput):
+                    result.sequences = sequences
+                else:
+                    result = sequences
 
             # Save the last tokens from this call
-            session.last_token_id = result[:, -1:]
+            session.last_token_id = sequences[:, -1:]
             return result
 
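
For reference, a client-side usage sketch of what this change enables. The checkpoint name and the session-handling details are illustrative and may vary across Petals versions; only the return_dict_in_generate / result.sequences behavior comes from this commit:

from transformers import AutoTokenizer
from petals import DistributedBloomForCausalLM

# Illustrative model name; any Petals-served causal LM should behave the same.
MODEL_NAME = "bigscience/bloom-petals"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)

prefix = tokenizer("A cat sat on", return_tensors="pt")["input_ids"]

with model.inference_session(max_length=32) as session:
    # With return_dict_in_generate=True, generate() now returns a ModelOutput
    # whose .sequences field holds the token ids.
    result = model.generate(
        prefix,
        max_new_tokens=8,
        session=session,
        return_dict_in_generate=True,
    )
    print(tokenizer.decode(result.sequences[0]))

    # On a follow-up call in the same session, the token generated last time is
    # cropped from result.sequences (and remembered via session.last_token_id),
    # as implemented in the diff above.
    result = model.generate(
        None,
        max_new_tokens=8,
        session=session,
        return_dict_in_generate=True,
    )
    print(tokenizer.decode(result.sequences[0]))
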
