|
|
|
@ -1,11 +1,12 @@
|
|
|
|
|
import contextlib
|
|
|
|
|
import dataclasses
|
|
|
|
|
from typing import ContextManager, Optional
|
|
|
|
|
from typing import ContextManager, List, Optional
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
|
|
|
|
|
from petals.client.inference_session import InferenceSession
|
|
|
|
|
from petals.utils.misc import DUMMY
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
@ -14,6 +15,9 @@ logger = get_logger(__name__)
|
|
|
|
|
class RemotePastKeyValues:
|
|
|
|
|
hypo_ids: Optional[torch.LongTensor] = None
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, _index: int) -> List[torch.Tensor]:
|
|
|
|
|
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RemoteGenerationMixin:
|
|
|
|
|
"""
|
|
|
|
|