from __future__ import annotations

import asyncio
import itertools
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple

import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor

from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs

logger = get_logger(__name__)


class _ServerInferenceSession:
    """
    An interface to a single multi-step *inference* session for a set of blocks on a specific server.

    :note: This class is *not* fault-tolerant out of the box.
    """

    def __init__(
        self,
        config: ClientConfig,
        span: RemoteSpanInfo,
        uid: ModuleUID,
        rpc_info: RPCInfo,
        inputs_queue: asyncio.Queue,
        outputs_aiter: AsyncIterator,
        *,
        max_length: int,
        **metadata,
    ):
        self.config = config
        self.span, self.uid, self.rpc_info = span, uid, rpc_info
        self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
        self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
        self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
        self.session_id = str(uuid.uuid4())
        self.session_metadata = dict(max_length=max_length, **metadata)
        self.stepped = False
        self.closed = False

        self._position = 0
        self.history = None  # Used in case of server failures to regenerate attention caches on new servers
        self.next_session = None

    @classmethod
    async def create(
        cls,
        config: ClientConfig,
        p2p: P2P,
        span: RemoteSpanInfo,
        uid: ModuleUID,
        rpc_info: RPCInfo,
        **metadata,
    ) -> _ServerInferenceSession:
        """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
        stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
        inputs_queue = asyncio.Queue()
        outputs_stream = await asyncio.wait_for(
            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
            config.connect_timeout,
        )
        return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)

    @staticmethod
    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
        while True:
            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
            yield next_input_message
            if not next_input_message.uid and not next_input_message.tensors:
                break  # this message means "done sending"

    def step(
        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
    ) -> torch.Tensor:
        """
        Inference step: send a chunk of input tensors and receive a chunk of outputs
        :param prompts: optional DEEP prompts, added to a prefix of each layer's outputs;
            if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
        """
        if self.closed:
            raise Exception("Session is closed, cannot perform step")

        n_input_tokens = inputs.shape[1]
        if self.history is None:
            self.history = inputs
        elif self.history.shape[1] == self._position:
            self.history = torch.cat([self.history, inputs[:, -n_input_tokens:]], dim=1)
        assert self.history.shape[1] == self._position + n_input_tokens, (
            f"Broken input cache: span={self.span} shape={self.history.shape} "
            f"position={self._position} n_input_tokens={n_input_tokens}"
        )

        if not self.stepped:
            inputs = self.history  # Pass full inputs including prefix
        else:
            inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further

        # serialize inputs and put them into the queue
        input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)

        request_metadata = dict(session_id=self.session_id, step_id=step_id)
        if not self.stepped:
            request_metadata.update(self.session_metadata)
        elif self.config.use_server_to_server:
            next_servers = self._collect_next_servers()
            if next_servers:
                request_metadata["next_servers"] = next_servers

        request_metadata["args_structure"] = args_structure

        # TODO: make possible to use different compression method for different tensors
        server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
        compression = server_side_inference_schema[0].compression
        inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)

        # TODO: create more explicit way to check servers schema and client's structure
        assert len(input_tensors) >= len(
            server_side_inference_schema
        ), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"

        outputs_serialized = RemoteExpertWorker.run_coroutine(
            self._step(
                runtime_pb2.ExpertRequest(
                    uid=self.uid,
                    tensors=[
                        serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
                        for tensor, proto in zip(input_tensors, inference_schema)
                    ],
                    metadata=MSGPackSerializer.dumps(request_metadata),
                )
            )
        )
        outputs = list(map(deserialize_torch_tensor, outputs_serialized.tensors))
        assert (
            outputs[0].shape == inputs.shape
        ), f"output activation shape is different from input shape: {outputs[0].shape} != {inputs.shape}"

        self._position += n_input_tokens

        return outputs[0]
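
    # Each `next_servers` entry below is a (peer_id_base58, session_id, span_start, span_end) tuple; when
    # config.use_server_to_server is enabled, step() attaches this metadata so that a server can push its
    # outputs directly to the next server's inference session via rpc_push().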

    def _collect_next_servers(self) -> List[Tuple[str, str, int, int]]:
        next_servers = []
        session = self.next_session
        while session is not None and session.stepped:
            next_servers.append(
                (session.span.peer_id.to_base58(), session.session_id, session.span.start, session.span.end)
            )
            session = session.next_session
        return next_servers

    async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_pb2.ExpertResponse:
        """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
        await self._inputs_queue.put(inputs_serialized)
        self.stepped = True
        return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)

    def close(self):
        """Finish a given inference session, close the underlying connection"""
        if self._outputs_stream is None:
            return  # already closed
        RemoteExpertWorker.run_coroutine(self._aclose_stream())
        self._outputs_stream = self._inputs_queue = None
        self.closed = True

    async def _aclose_stream(self):
        """Close the inference session. This code is meant to be run inside RemoteExpertWorker"""
        if self._outputs_stream is None:
            return  # already closed
        if self.stepped:
            await self._inputs_queue.put(runtime_pb2.ExpertRequest())  # empty request will trigger end of session
            try:
                await anext(self._outputs_stream)
            except StopAsyncIteration:
                pass

    def __del__(self):
        self.close()

    def __enter__(self):
        assert not self.closed
        return self

    def __exit__(self, *exc_details):
        self.close()


class InferenceSession:
    """
    An interface to a multi-step *inference* session for a sequence of remote transformer blocks
    """

    def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
        self._sequence_manager = sequence_manager
        self._closed = False
        self._server_sessions = []
        self._position = 0
        self._max_length = max_length
        self.output_ids = None

    @property
    def num_blocks(self) -> int:
        return len(self._sequence_manager)

    @property
    def position(self) -> int:
        return self._position

    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
        server_sessions = []
        try:
            for span in chosen_spans:
                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
                metadata = self._sequence_manager.get_request_metadata(
                    "rpc_inference", span_uids, peer_id=span.peer_id
                )
                session = RemoteExpertWorker.run_coroutine(
                    _ServerInferenceSession.create(
                        self._sequence_manager.config,
                        self._sequence_manager.state.p2p,
                        span,
                        span_uids,
                        rpc_info=self._sequence_manager.rpc_info,
                        max_length=self._max_length,
                        **metadata,
                    )
                )
                server_sessions.append(session)
                session.__enter__()
            return server_sessions
        except:
            self._exit_server_sessions(server_sessions)
            raise

    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
        for session in reversed(server_sessions):
            try:
                session.__exit__(None, None, None)
            except Exception:
                logger.debug("Caught exception while closing connection to server:", exc_info=True)

    def __enter__(self) -> "InferenceSession":
        assert not self._closed and not self._server_sessions
        return self

    def step(
        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        assert not self._closed
        if torch.is_grad_enabled():
            logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")

        if prompts is None or is_dummy(prompts):
            prompts = DUMMY
        else:
            assert prompts.ndim == 4, "deep prompts should have shape [num_blocks, batch_size, prefix_len, hid_size]"
            assert prompts.shape[0] == self.num_blocks
            assert prompts.shape[1] in (inputs.shape[0], 1)
            assert prompts.shape[2] <= inputs.shape[1]
            assert prompts.shape[3] == inputs.shape[2]

        if hypo_ids is None or is_dummy(hypo_ids):
            hypo_ids = DUMMY_INT64
        else:
            assert len(hypo_ids) == len(inputs)
            assert hypo_ids.dtype == torch.int64

        inputs_device = inputs.device
        inputs_dtype = inputs.dtype
        inputs = inputs.cpu()
        prompts = prompts.cpu()
        hypo_ids = hypo_ids.cpu()
        step_id = str(uuid.uuid4())

        n_input_tokens = inputs.shape[1]
        if self._position + n_input_tokens > self._max_length:
            raise ValueError(
                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
            )

        server_idx = 0
        block_idx = 0
        while block_idx < self.num_blocks:
            for attempt_no in itertools.count():
                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
                server_session = None
                try:
                    if not self._server_sessions or attempt_no >= 1:
                        self._update_sequence(server_idx, block_idx, attempt_no)

                    server_session = self._server_sessions[server_idx]
                    inputs = server_session.step(
                        inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
                    )

                    server_idx += 1
                    block_idx = server_session.span.end
                    self._sequence_manager.on_request_success(server_session.span.peer_id)
                    break
                except Exception as e:
                    self._sequence_manager.on_request_failure(
                        server_session.span.peer_id if server_session is not None else None
                    )
                    if attempt_no + 1 == self._sequence_manager.config.max_retries:
                        raise
                    delay = self._sequence_manager.get_retry_delay(attempt_no)
                    logger.warning(
                        f"Caught exception when running inference via {server_session.span if server_session is not None else None} "
                        f"(retry in {delay:.0f} sec): {repr(e)}"
                    )
                    maybe_log_traceback(e)
                    time.sleep(delay)

        self._position += n_input_tokens
        outputs = inputs[:, -n_input_tokens:]
        outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
        return outputs

    def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> None:
        # If there is a failed server session, this code closes it
        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])

        n_prev_spans = len(self._server_sessions)
        update_end = self._server_sessions[server_idx].span.end if server_idx < n_prev_spans else self.num_blocks
        if attempt_no >= 1:
            logger.debug(
                f"Due to a server failure, remote attention caches "
                f"from block {block_idx} to {update_end} will be regenerated"
            )

        updated_spans = self._sequence_manager.make_sequence(
            block_idx, update_end, mode="min_latency", cache_tokens_needed=self._max_length
        )
        # make_sequence() could return a longer sequence
        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
        updated_sessions = self._enter_server_sessions(updated_spans)
        logger.debug(f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers")

        # If there is a failed span, this code replaces it, otherwise it just adds new ones
        if server_idx < n_prev_spans:
            updated_sessions[0].history = self._server_sessions[server_idx].history
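            # The replacement session has not stepped yet, so its first step() call re-sends the inherited
            # history, letting the new server rebuild the attention cache lost with the failed session.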
        self._server_sessions[server_idx : server_idx + 1] = updated_sessions

        # Update links to the next server session for direct server-to-server communication via rpc_push()
        for i in range(max(server_idx - 1, 0), min(server_idx + len(updated_spans), len(self._server_sessions) - 1)):
            self._server_sessions[i].next_session = self._server_sessions[i + 1]

    def close(self, *exc_details):
        """Finish a given inference session, close the underlying connection"""
        if not self._closed:
            self._exit_server_sessions(self._server_sessions)
            self._server_sessions.clear()
            self._closed = True

    def __exit__(self, *exc_details):
        self.close(*exc_details)

    def __del__(self):
        self.close()

    @property
    def last_token_id(self) -> Optional[torch.Tensor]:  # Backward compatibility with Petals < 2.1.0
        return self.output_ids[:, -1:] if self.output_ids is not None else None

    @last_token_id.setter
    def last_token_id(self, value: torch.Tensor):  # Backward compatibility with Petals < 2.1.0
        if self.output_ids is None:
            raise RuntimeError("Can't override `last_token_id` since the session has not stepped yet")
        self.output_ids[:, -1:] = value