|
|
|
@ -5,6 +5,7 @@ from typing import ContextManager, List, Optional
|
|
|
|
|
import torch
|
|
|
|
|
import transformers
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
from transformers.generation.utils import ModelOutput
|
|
|
|
|
|
|
|
|
|
from petals.client.inference_session import InferenceSession
|
|
|
|
|
from petals.client.remote_sequential import RemoteSequential
|
|
|
|
@ -86,11 +87,16 @@ class RemoteGenerationMixin:
|
|
|
|
|
|
|
|
|
|
result = super().generate(inputs, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
sequences = result.sequences if isinstance(result, ModelOutput) else result
|
|
|
|
|
# Crop the last tokens from the previous call
|
|
|
|
|
if session.last_token_id is not None:
|
|
|
|
|
result = result[:, 1:]
|
|
|
|
|
sequences = sequences[:, 1:]
|
|
|
|
|
if isinstance(result, ModelOutput):
|
|
|
|
|
result.sequences = sequences
|
|
|
|
|
else:
|
|
|
|
|
result = sequences
|
|
|
|
|
# Save the last tokens from this call
|
|
|
|
|
session.last_token_id = result[:, -1:]
|
|
|
|
|
session.last_token_id = sequences[:, -1:]
|
|
|
|
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|