Fix bloom

pull/464/head
Aleksandr Borzunov 10 months ago
parent 735db64994
commit 813090e4fa

@ -1,11 +1,12 @@
import contextlib
import dataclasses
from typing import ContextManager, Optional
from typing import ContextManager, List, Optional
import torch
from hivemind.utils.logging import get_logger
from petals.client.inference_session import InferenceSession
from petals.utils.misc import DUMMY
logger = get_logger(__name__)
@ -14,6 +15,9 @@ logger = get_logger(__name__)
class RemotePastKeyValues:
hypo_ids: Optional[torch.LongTensor] = None
def __getitem__(self, _index: int) -> List[torch.Tensor]:
return [DUMMY] # For compatibility with BloomForCausalLM.prepare_inputs_for_generation()
class RemoteGenerationMixin:
"""

Loading…
Cancel
Save