Add option to rollback inference for a certain number of steps (#588)

* fix

* fix

* fix

* fix

* fix

* fix

* style
Anton Sinitsin committed via GitHub 3 months ago · commit c0a4d2e3d5 · parent 68585864ae
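In short, the commit threads a new optional `start_from_position` argument from the public `InferenceSession.step()` down to each per-server `_ServerInferenceSession` and into the RPC metadata, letting a client rewind a running session to an earlier token position and overwrite the cached continuation. A minimal usage sketch of the resulting API, modeled on the test added at the end of this diff (the model choice, tensor shapes, and block index are illustrative, not part of the commit):

    import torch
    from petals import AutoDistributedConfig, RemoteSequential

    # Illustrative setup: any Petals-served model works the same way.
    config = AutoDistributedConfig.from_pretrained("bigscience/bloom-560m")
    remote_block = RemoteSequential(config)[0]

    with torch.inference_mode():
        with remote_block.inference_session(max_length=16) as sess:
            hidden = torch.randn(1, 8, config.hidden_size)
            out = sess.step(hidden)  # positions 0..7 are now cached

            # Rewind to position 2: cached state for positions >= 2 is
            # discarded, and the new tail is computed in its place.
            new_tail = torch.randn(1, 6, config.hidden_size)
            out_tail = sess.step(new_tail, start_from_position=2)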

@@ -84,7 +84,13 @@ class _ServerInferenceSession:
                 break  # this message means "done sending"
 
     def step(
-        self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, step_id: str
+        self,
+        inputs: torch.Tensor,
+        prompts: torch.Tensor,
+        hypo_ids: torch.LongTensor,
+        *,
+        step_id: str,
+        start_from_position: int,
     ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -94,6 +100,12 @@ class _ServerInferenceSession:
         if self.closed:
             raise Exception("Session is closed, cannot perform step")
 
+        if start_from_position is not None:
+            assert start_from_position <= self._position
+            self._position = start_from_position
+            if self.history is not None and self.history.shape[1] >= start_from_position:
+                self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None
+
         n_input_tokens = inputs.shape[1]
         if self.history is None:
             self.history = inputs
@@ -115,6 +127,8 @@ class _ServerInferenceSession:
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
         if not self.stepped:
             request_metadata.update(self.session_metadata)
+        if start_from_position is not None:
+            request_metadata["start_from_position"] = start_from_position
         elif self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
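The branch added two hunks above is the heart of the client-side rollback: `_position` is rewound and the locally retained `history` tensor of past inputs (shaped `[batch, seq_len, hidden]`) is truncated, so subsequent steps treat the rolled-back positions as never sent. A standalone sketch of that truncation rule, with a hypothetical `rollback_history` helper standing in for the session state:

    from typing import Optional

    import torch

    def rollback_history(history: Optional[torch.Tensor], start_from_position: int) -> Optional[torch.Tensor]:
        # Same logic as the new branch in _ServerInferenceSession.step().
        if history is not None and history.shape[1] >= start_from_position:
            history = history[:, :start_from_position, :] if start_from_position > 0 else None
        return history

    history = torch.randn(1, 8, 4096)            # 8 cached token positions
    history = rollback_history(history, 2)       # keep positions 0 and 1
    assert history.shape == (1, 2, 4096)
    assert rollback_history(history, 0) is None  # rolling back to 0 clears everything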
@@ -257,8 +271,16 @@ class InferenceSession:
         return self
 
     def step(
-        self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, hypo_ids: Optional[torch.Tensor] = None
+        self,
+        inputs: torch.Tensor,
+        prompts: Optional[torch.Tensor] = None,
+        hypo_ids: Optional[torch.Tensor] = None,
+        start_from_position: Optional[int] = None,
     ) -> torch.Tensor:
+        if start_from_position is not None:
+            self._position = start_from_position
+
         assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -303,7 +325,11 @@ class InferenceSession:
                 server_session = self._server_sessions[server_idx]
                 inputs = server_session.step(
-                    inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, step_id=step_id
+                    inputs,
+                    prompts[server_session.span.start : server_session.span.end],
+                    hypo_ids,
+                    step_id=step_id,
+                    start_from_position=start_from_position,
                 )
                 server_idx += 1
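A session spans several servers, each holding a contiguous span of blocks, and the loop above forwards the same rollback position to every per-server session in the chain. A toy sketch of that fan-out (the `SpanSession` class is hypothetical, standing in for `_ServerInferenceSession`):

    from dataclasses import dataclass

    @dataclass
    class SpanSession:
        position: int = 0

        def step(self, n_new_tokens: int, start_from_position=None) -> int:
            if start_from_position is not None:
                self.position = start_from_position
            self.position += n_new_tokens
            return self.position

    # Every server in the chain rewinds to the same position before stepping.
    chain = [SpanSession(position=8) for _ in range(3)]
    assert all(s.step(6, start_from_position=2) == 8 for s in chain)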

@@ -160,6 +160,13 @@ async def iterate_rpc_inference(
     point_per_piece = points / max_length if max_length > 0 else 0.0
 
     async for request, step_metadata in input_iterator:
+        if "start_from_position" in step_metadata:
+            start_from_position = step_metadata["start_from_position"]
+            assert (
+                prefix_length >= start_from_position
+            ), f"prefix_length={prefix_length}, start_from_position={start_from_position}"
+            prefix_length = start_from_position
+
         flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)
         if args_structure is not None:
             # TODO: kwargs currently is unused, it can be used later for peft-like adaptation
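On the server side, `start_from_position` arrives in the per-step metadata and simply rewinds `prefix_length`, so the attention-cache writes for the next chunk land on top of the rolled-back positions; the assert forbids skipping ahead of what the server has seen. A dict-level sketch of the resulting client/server contract (both helper functions are hypothetical, condensed from the two diffs above):

    def build_step_metadata(session_id: str, step_id: str, start_from_position=None) -> dict:
        # Client side, as in _ServerInferenceSession.step().
        metadata = dict(session_id=session_id, step_id=step_id)
        if start_from_position is not None:
            metadata["start_from_position"] = start_from_position
        return metadata

    def apply_step_metadata(prefix_length: int, step_metadata: dict) -> int:
        # Server side, as in iterate_rpc_inference(): only rewinding is allowed.
        if "start_from_position" in step_metadata:
            start_from_position = step_metadata["start_from_position"]
            assert prefix_length >= start_from_position
            prefix_length = start_from_position
        return prefix_length

    metadata = build_step_metadata("session-1", "step-3", start_from_position=2)
    assert apply_step_metadata(prefix_length=8, step_metadata=metadata) == 2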

@@ -0,0 +1,35 @@
+import random
+
+import pytest
+import torch
+
+from petals import AutoDistributedConfig, RemoteSequential
+from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
+from petals.server.from_pretrained import load_pretrained_block
+from test_utils import *
+
+
+@pytest.mark.forked
+def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, atol_inference=1e-3):
+    config = AutoDistributedConfig.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    remote_sequential = RemoteSequential(config)
+
+    block_index = random.randint(0, config.num_hidden_layers - 1)
+    remote_block = remote_sequential[block_index]
+
+    inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs = torch.randn(1, MAX_SHORT_INFERENCE_TOKENS - 50, config.hidden_size)
+    short_inputs[:, :2, :] = inputs[:, :2, :]
+
+    initial_outputs_inference = None
+    secondary_outputs_inference = None
+    with torch.inference_mode():
+        with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
+            initial_outputs_inference = sess.step(inputs)
+            secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
+            result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)
+
+    ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
+    (outputs_local,) = ref_block(short_inputs)
+
+    assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)
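The test checks an exact-match property: the two random inputs agree on their first two positions, so stepping through the full first input, rolling back with `start_from_position=2`, and stepping through the second input's tail must reproduce, within `atol_inference`, what a locally loaded reference block computes on the second input in a single pass. This is the kind of primitive speculative decoding relies on, where draft tokens rejected during verification must be recomputed from the last accepted position.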