Make inference, forward, and backward fully fault-tolerant (#91)

Alexander Borzunov authored 2 years ago, committed by GitHub
parent 695df826c2, commit 11d6ba683c

requirements.txt
@@ -4,5 +4,5 @@ accelerate==0.10.0
 huggingface-hub==0.7.0
 transformers==4.21.3
 protobuf>=3.12.2,<4.0.0
-git+https://github.com/learning-at-home/hivemind@94c985d2dc7a79a091e46c755e9f2f4469b164c7
+git+https://github.com/learning-at-home/hivemind@8f258b4b3688f671208bf323359cb967b25d640a
 humanfriendly

src/client/__init__.py
@@ -1,4 +1,4 @@
-from src.client.inference_session import RemoteSequentialInferenceSession, RemoteTransformerBlockInferenceSession
+from src.client.inference_session import InferenceSession
 from src.client.remote_model import DistributedBloomConfig, DistributedBloomForCausalLM, DistributedBloomModel
 from src.client.remote_sequential import RemoteSequential, RemoteTransformerBlock
 from src.client.sequence_manager import RemoteSequenceManager

src/client/inference_session.py
@@ -1,7 +1,8 @@
 from __future__ import annotations

 import asyncio
-import contextlib
+import itertools
+import time
 from typing import AsyncIterator, List, Optional

 import torch
@@ -13,26 +14,25 @@ from hivemind import (
     get_logger,
     nested_flatten,
     serialize_torch_tensor,
-    use_hivemind_log_handler,
 )
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from hivemind.p2p import StubBase
 from hivemind.proto import runtime_pb2
+from hivemind.utils.asyncio import aiter_with_timeout

 from src.client.sequence_manager import RemoteSequenceManager
 from src.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy

-use_hivemind_log_handler("in_root_logger")
 logger = get_logger(__file__)


-class RemoteTransformerBlockInferenceSession:
+class _ServerInferenceSession:
     """
-    An interface to a single multi-step *inference* session for a specific remote module on a specific server
+    An interface to a single multi-step *inference* session for a set of blocks on a specific server.

-    :note: this inference session is *not* fault-tolerant out of the box
+    :note: This class is *not* fault-tolerant out of the box.
     """

     def __init__(
@@ -42,32 +42,35 @@ class RemoteTransformerBlockInferenceSession:
         inputs_queue: asyncio.Queue,
         outputs_aiter: AsyncIterator,
         *,
+        timeout: float,
         max_length: int,
         points: int = 0,
     ):
         self.uid, self.rpc_info = uid, rpc_info
         self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
-        # warning: this code manages async objects that are only usable inside RemoteExpertWorker's background thread;
-        # using them in any other EventLoop may cause side-effects including, headaches, diarrhea, and loss of sleep
         self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
         self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
+        self.timeout = timeout
         self._serialized_metadata = MSGPackSerializer.dumps(dict(max_length=max_length, points=points))
         self.stepped = False
         self.closed = False

     @classmethod
-    async def _create(
-        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: Optional[float] = None, **metadata
-    ) -> RemoteTransformerBlockInferenceSession:
+    async def create(
+        cls, stub: StubBase, uid: ModuleUID, rpc_info: RPCInfo, timeout: float, **metadata
+    ) -> _ServerInferenceSession:
         """Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
         inputs_queue = asyncio.Queue()
-        outputs_stream = await stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue, timeout), timeout=timeout)
-        return cls(uid, rpc_info, inputs_queue, outputs_stream, **metadata)
+        outputs_stream = await asyncio.wait_for(
+            stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
+            timeout,
+        )
+        return cls(uid, rpc_info, inputs_queue, outputs_stream, timeout=timeout, **metadata)

     @staticmethod
-    async def _read_inputs_from_queue(queue: asyncio.Queue, timeout: Optional[float]) -> AsyncIterator:
+    async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
         while True:
-            next_input_message = await asyncio.wait_for(queue.get(), timeout)
+            next_input_message = await asyncio.wait_for(queue.get(), input_timeout)
             yield next_input_message
             if not next_input_message.uid and not next_input_message.tensors:
                 break  # this message means "done sending"
@@ -77,7 +80,7 @@ class RemoteTransformerBlockInferenceSession:
         new_hidden_states: torch.Tensor,
         prompts: Optional[torch.Tensor] = None,
         hypo_ids: Optional[torch.Tensor] = None,
-    ):
+    ) -> torch.Tensor:
         """
         Inference step: send a chunk of input tensors and receive a chunk of outputs
         :prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
@@ -122,7 +125,7 @@ class RemoteTransformerBlockInferenceSession:
         """Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
         await self._inputs_queue.put(inputs_serialized)
         self.stepped = True
-        return await anext(self._outputs_stream)
+        return await asyncio.wait_for(anext(self._outputs_stream), self.timeout)

     def close(self):
         """Finish a given inference session, close the underlying connection"""
@@ -154,60 +157,163 @@ class RemoteTransformerBlockInferenceSession:
         self.close()


-class RemoteSequentialInferenceSession:
+class InferenceSession:
     """
     An interface to a multi-step *inference* session for a sequence of remote transformer blocks
     """

-    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, timeout: Optional[float] = None, **metadata):
-        self.sequence_manager = sequence_manager
-        self.p2p = p2p
-        self.closed = False
-        self.chosen_spans: List[RemoteSpanInfo] = []
-        self.stack = contextlib.ExitStack()
-        self.inference_sessions: List[RemoteTransformerBlockInferenceSession] = []
-        self.metadata = metadata
-        self.timeout = timeout
+    def __init__(self, sequence_manager: RemoteSequenceManager, p2p: P2P, max_length: int, **metadata):
+        self._sequence_manager = sequence_manager
+        self._p2p = p2p
+        self._closed = False
+        self._chosen_spans = []
+        self._server_sessions = []
+        self._server_inputs = []  # Used in case of server failures to regenerate attention caches on new servers
+        self._position = 0
+        self._max_length = max_length
+        self._metadata = metadata

-    def __enter__(self):
-        assert not self.closed and not self.chosen_spans
-        self.stack.__enter__()
-        # TODO(yozh) replace this code with a fault-tolerant chain that can be reconstructed if some peers fail
-        self.chosen_spans.extend(self.sequence_manager.make_sequence())
-
-        for chosen_span in self.chosen_spans:
-            stub = TransformerConnectionHandler.get_stub(self.p2p, chosen_span.peer_id)
-            span_uids: str = CHAIN_DELIMITER.join(self.sequence_manager.block_uids[chosen_span.start : chosen_span.end])
-            inference_session = RemoteExpertWorker.run_coroutine(
-                RemoteTransformerBlockInferenceSession._create(
-                    stub, span_uids, rpc_info=self.sequence_manager.rpc_info, timeout=self.timeout, **self.metadata
+    def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
+        server_sessions = []
+        try:
+            for span in chosen_spans:
+                stub = TransformerConnectionHandler.get_stub(self._p2p, span.peer_id)
+                span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
+                session = RemoteExpertWorker.run_coroutine(
+                    _ServerInferenceSession.create(
+                        stub,
+                        span_uids,
+                        rpc_info=self._sequence_manager.rpc_info,
+                        timeout=self._sequence_manager.timeout,
+                        max_length=self._max_length,
+                        **self._metadata,
+                    )
                 )
-            )
-            self.inference_sessions.append(inference_session)
-            self.stack.enter_context(inference_session)
+                server_sessions.append(session)
+                session.__enter__()
+            return server_sessions
+        except:
+            self._exit_server_sessions(server_sessions)
+            raise
+
+    def _exit_server_sessions(self, server_sessions: List[_ServerInferenceSession]) -> None:
+        for session in reversed(server_sessions):
+            try:
+                session.__exit__(None, None, None)
+            except Exception:
+                logger.debug("Caught exception while closing connection to server:", exc_info=True)

+    def __enter__(self) -> "InferenceSession":
+        assert not self._closed and not self._chosen_spans
         return self

-    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs):
-        assert not self.closed
+    def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
+        assert not self._closed
         if torch.is_grad_enabled():
             logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
+
+        n_blocks = len(self._sequence_manager)
         if prompts is None or is_dummy(prompts):
             prompts = DUMMY
         else:
-            assert prompts.ndim == 4 and prompts.shape[0] == len(self.sequence_manager)
-        for session in self.inference_sessions:
-            outputs = session.step(inputs, prompts[self.chosen_spans[0].start : self.chosen_spans[0].end], **kwargs)
-            assert outputs.shape == inputs.shape, f"expected {inputs.shape}, got {outputs.shape}"
-            inputs = outputs
+            assert prompts.ndim == 4 and prompts.shape[0] == n_blocks
+
+        n_input_tokens = inputs.shape[1]
+        if self._position + n_input_tokens > self._max_length:
+            raise ValueError(
+                f"Maximum length exceeded: prefix {self._position} + current {n_input_tokens} exceeds pre-allocated maximum {self._max_length}"
+            )
+
+        server_idx = 0
+        block_idx = 0
+        recovery_until = -1  # Recovery mode is disabled until a failure happens
+        while block_idx < n_blocks:
+            for attempt_no in itertools.count():
+                logger.debug(f"Inference: block {block_idx}, attempt {attempt_no}")
+                try:
+                    if attempt_no >= 1:
+                        self._sequence_manager.update_()
+                    if not self._chosen_spans or not self._server_sessions or attempt_no >= 1:
+                        # If there is a failed server session, this code closes it
+                        self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
+
+                        n_prev_spans = len(self._chosen_spans)
+                        update_end = self._chosen_spans[server_idx].end if server_idx < n_prev_spans else n_blocks
+                        if attempt_no >= 1 and update_end > recovery_until:
+                            logger.info(
+                                f"Due to a server failure, remote attention caches "
+                                f"from block {block_idx} to {update_end} will be regenerated"
+                            )
+                        recovery_until = max(recovery_until, update_end)
+
+                        updated_spans = self._sequence_manager.make_sequence(block_idx, update_end)
+                        # make_sequence() could return a longer sequence
+                        updated_spans[-1].end = min(updated_spans[-1].end, update_end)
+                        updated_sessions = self._enter_server_sessions(updated_spans)
+                        logger.debug(
+                            f"Found path from block {block_idx} to {update_end} via {len(updated_spans)} servers"
+                        )
+
+                        # If there is a failed span, this code replaces it, otherwise it just adds new ones
+                        self._chosen_spans[server_idx : server_idx + 1] = updated_spans
+                        self._server_sessions[server_idx : server_idx + 1] = updated_sessions
+                        recovery_inputs = self._server_inputs[server_idx] if server_idx < n_prev_spans else None
+                        self._server_inputs[server_idx : server_idx + 1] = [recovery_inputs] + [None] * (
+                            len(updated_spans) - 1
+                        )
+                    assert len(self._chosen_spans) == len(self._server_sessions) == len(self._server_inputs), (
+                        f"Broken state: {len(self._chosen_spans)} spans, {len(self._server_sessions)} sessions, "
+                        f"{len(self._server_inputs)} inputs"
+                    )
+
+                    session = self._server_sessions[server_idx]
+                    span = self._chosen_spans[server_idx]
+
+                    if self._server_inputs[server_idx] is None:
+                        self._server_inputs[server_idx] = inputs
+                    elif self._server_inputs[server_idx].shape[1] == self._position:
+                        self._server_inputs[server_idx] = torch.cat(
+                            [self._server_inputs[server_idx], inputs[:, -n_input_tokens:]], dim=1
+                        )
+                    assert self._server_inputs[server_idx].shape[1] == self._position + n_input_tokens, (
+                        f"Broken input cache: server_idx={server_idx} shape={self._server_inputs[server_idx].shape} "
+                        f"position={self._position} n_input_tokens={n_input_tokens}"
+                    )
+
+                    if not session.stepped:
+                        inputs = self._server_inputs[server_idx]  # Pass full inputs including prefix
+                    else:
+                        inputs = inputs[:, -n_input_tokens:]  # No need to pass prefix further
+
+                    outputs = session.step(inputs, prompts[span.start : span.end], **kwargs)
+                    assert (
+                        inputs.shape == outputs.shape
+                    ), f"Shape mismatch: inputs.shape={inputs.shape}, outputs.shape={outputs.shape})"
+
+                    inputs = outputs
+                    server_idx += 1
+                    block_idx = span.end
+                    break
+                except Exception as e:
+                    delay = self._sequence_manager.get_retry_delay(attempt_no)
+                    logger.warning(
+                        f"Caught exception when running inference from block {block_idx} "
+                        f"(retry in {delay:.0f} sec): {repr(e)}"
+                    )
+                    logger.debug("See detailed traceback below:", exc_info=True)
+                    time.sleep(delay)
+
+        self._position += n_input_tokens
         return inputs

     def close(self, *exc_details):
         """Finish a given inference session, close the underlying connection"""
-        if not self.closed:
-            self.stack.__exit__(*exc_details or (None, None, None))
-            self.inference_sessions.clear()
-            self.closed = True
+        if not self._closed:
+            self._server_inputs.clear()
+            self._exit_server_sessions(self._server_sessions)
+            self._server_sessions.clear()
+            self._chosen_spans.clear()
+            self._closed = True

     def __exit__(self, *exc_details):
         self.close(*exc_details)

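The crux of the fault tolerance in InferenceSession.step() above is that the client keeps each server's full input prefix in _server_inputs. When a server fails, its replacement starts with an empty attention cache (its session has stepped == False), so the stored prefix is replayed to it before the new token is processed, while healthy servers keep receiving only the newest token. A toy sketch of that bookkeeping follows; the tensor shapes and variable names are illustrative, not taken from the codebase:

# Toy illustration of the per-server input cache kept by InferenceSession (shapes only).
import torch

position, n_input_tokens, hidden = 3, 1, 4096    # 3 tokens already processed, 1 new token, toy hidden size
prefix = torch.randn(1, position, hidden)        # inputs this span has already seen (kept client-side)
new_token = torch.randn(1, n_input_tokens, hidden)

cache = torch.cat([prefix, new_token], dim=1)    # what _server_inputs holds for this span after the update
fresh_server = True                              # a replacement server: its session has stepped == False
to_send = cache if fresh_server else new_token   # replay the whole prefix only to a fresh server
assert to_send.shape[1] == (position + n_input_tokens if fresh_server else n_input_tokens)
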
src/client/remote_sequential.py
@@ -8,7 +8,7 @@ from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
 from torch import nn

 import src
-from src.client.inference_session import RemoteSequentialInferenceSession
+from src.client.inference_session import InferenceSession
 from src.client.sequence_manager import RemoteSequenceManager
 from src.client.sequential_autograd import _RemoteSequentialAutogradFunction
 from src.data_structures import UID_DELIMITER
@@ -80,9 +80,9 @@ class RemoteSequential(nn.Module):
     def __len__(self):
         return len(self.sequence_manager)

-    def inference_session(self, **kwargs) -> RemoteSequentialInferenceSession:
+    def inference_session(self, **kwargs) -> InferenceSession:
         self.sequence_manager.update_()
-        return RemoteSequentialInferenceSession(self.sequence_manager, self.p2p, **kwargs)
+        return InferenceSession(self.sequence_manager, self.p2p, **kwargs)

     def extra_repr(self) -> str:
         return f"modules={self.sequence_manager.block_uids[0]}..{self.sequence_manager.block_uids[-1]}"

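For context, inference_session() above is the entry point client code uses to obtain the fault-tolerant session. A minimal usage sketch, where remote_blocks stands for any RemoteSequential instance and inputs for hidden states of shape (batch, seq_len, hidden_size); both names are illustrative:

# Illustrative only: drive the fault-tolerant InferenceSession one token at a time.
# Failed servers are detected and replaced inside step(); exceeding max_length raises ValueError.
outputs = []
with remote_blocks.inference_session(max_length=inputs.shape[1]) as sess:
    for i in range(inputs.shape[1]):
        outputs.append(sess.step(inputs[:, i : i + 1, :]))
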
src/client/sequence_manager.py
@@ -160,3 +160,8 @@ class RemoteSequenceManager:
         else:
             logger.warning(f"Tried to call rpc_info, but caught {repr(e)}", exc_info=True)
         return self._rpc_info
+
+    def get_retry_delay(self, attempt_no: int) -> float:
+        if attempt_no == 0:
+            return 0
+        return self.min_backoff * 2 ** (attempt_no - 1)

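get_retry_delay() is a plain exponential backoff with an immediate first retry. For intuition, here is the schedule it produces as a standalone re-statement; min_backoff = 1.0 is an assumed example value, not a documented default:

# Standalone re-statement of the schedule defined above.
min_backoff = 1.0

def get_retry_delay(attempt_no: int) -> float:
    if attempt_no == 0:
        return 0  # the first retry happens immediately
    return min_backoff * 2 ** (attempt_no - 1)

print([get_retry_delay(n) for n in range(5)])  # [0, 1.0, 2.0, 4.0, 8.0]
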
src/client/sequential_autograd.py
@@ -3,11 +3,12 @@ A PyTorch autograd function that runs forward/backward on a sequence of remote s
 """
 import asyncio
 import itertools
-import logging
+from collections import deque
 from typing import List, Optional, Sequence, Tuple

 import torch
 from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
+from hivemind.utils.logging import get_logger

 from src.client.remote_forward_backward import run_remote_backward, run_remote_forward
 from src.client.sequence_manager import RemoteSequenceManager
@@ -15,6 +16,8 @@ from src.data_structures import CHAIN_DELIMITER, RemoteSpanInfo
 from src.server.handler import TransformerConnectionHandler
 from src.utils.misc import DUMMY, is_dummy

+logger = get_logger(__file__)
+
 MAX_TOKENS_IN_BATCH = 1024
@@ -39,19 +42,30 @@ async def sequential_forward(
         sequence_manager.block_uids
     )  # should be n_layers - 1 but add extra prompts for convenience

-    sequences = sequence_manager.make_sequence(start_index, end_index)
+    sequences = deque()
     intermediate_inputs = []
     done_sequences = []
     outputs = inputs

-    while len(sequences) > 0:
+    block_idx = start_index
+    while block_idx < end_index:
         for attempt_no in itertools.count():
-            span = sequences.pop(0)
-            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            logger.debug(f"Forward: block {block_idx}, attempt {attempt_no}")
             try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                if not sequences or attempt_no >= 1:
+                    sequences = deque(sequence_manager.make_sequence(block_idx, end_index))
+                    # make_sequence() could return a longer sequence
+                    sequences[-1].end = min(sequences[-1].end, end_index)
+                    logger.debug(f"Found path from block {block_idx} to {end_index} via {len(sequences)} servers")
+
+                span = sequences.popleft()
+
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 inputs_and_prompts = [inputs, prompts[span.start : span.end]]

+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 (outputs,) = await run_remote_forward(
                     span_uids, stub, sequence_manager.rpc_info, *inputs_and_prompts, timeout=sequence_manager.timeout
                 )
@@ -64,14 +78,16 @@ async def sequential_forward(
                 done_sequences.append(span)
                 inputs = outputs
+
+                block_idx = span.end
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running forward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
-
-                backup_sequences = sequence_manager.make_sequence(span.start)
-                assert backup_sequences[0].start == span.start
-                sequences = backup_sequences
+                delay = sequence_manager.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when running forward from block {block_idx} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)

     return outputs, intermediate_inputs, done_sequences
@@ -91,11 +107,26 @@ async def sequential_backward(
     grad_prompts_reversed = []
     while len(forward_sequences) > 0 and len(intermediate_inputs) > 0:
+        inputs = intermediate_inputs.pop()
+        span = forward_sequences.pop()
         for attempt_no in itertools.count():
-            inputs = intermediate_inputs.pop(-1)
-            span = forward_sequences.pop(-1)
-            span_uids: str = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
+            logger.debug(f"Backward: block {span.end - 1}, attempt {attempt_no}")
             try:
+                if attempt_no >= 1:
+                    sequence_manager.update_()
+                    _, backup_inputs, backup_sequences = await sequential_forward(
+                        inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
+                    )
+                    assert len(backup_inputs) == len(backup_sequences)
+                    assert backup_sequences[0].start == span.start
+                    assert backup_sequences[-1].end == span.end
+
+                    intermediate_inputs.extend(backup_inputs)
+                    forward_sequences.extend(backup_sequences)
+                    inputs = intermediate_inputs.pop()
+                    span = forward_sequences.pop()
+
+                span_uids = CHAIN_DELIMITER.join(sequence_manager.block_uids[span.start : span.end])
                 stub = TransformerConnectionHandler.get_stub(sequence_manager.p2p, span.peer_id)
                 grad_outputs, *span_grad_prompts = await run_remote_backward(
                     span_uids,
@@ -110,18 +141,13 @@ async def sequential_backward(
                 grad_prompts_reversed.extend(span_grad_prompts)
                 break
             except Exception as e:
-                logging.warning(f"Caught {e} when running backward for chain {span.start}-{span.end}", exc_info=True)
-                await asyncio.sleep(sequence_manager.min_backoff * 2**attempt_no)
-
-                _, backup_intermediate_inputs, backup_forward_sequences = await sequential_forward(
-                    inputs, prompts, sequence_manager, start_index=span.start, end_index=span.end
-                )
-                assert len(intermediate_inputs) == len(forward_sequences)
-                assert backup_forward_sequences[0].start == span.start
-                assert backup_forward_sequences[-1].end == span.end
-
-                forward_sequences.extend(backup_forward_sequences)
-                intermediate_inputs.extend(backup_intermediate_inputs)
+                delay = sequence_manager.get_retry_delay(attempt_no)
+                logger.warning(
+                    f"Caught exception when running backward between blocks {span.start}-{span.end} "
+                    f"(retry in {delay:.0f} sec): {repr(e)}"
+                )
+                logger.debug("See detailed traceback below:", exc_info=True)
+                await asyncio.sleep(delay)

     # For now, we do not support mixed dummy and grad prompts
     # Concat in num_layer dimension

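The same retry skeleton now appears in three places: InferenceSession.step(), sequential_forward(), and sequential_backward(). Stripped of the span bookkeeping, it reduces to the sketch below. The helper definitions (Span, make_sequence, update_, get_retry_delay, run_span) are simplified stand-ins for the RemoteSequenceManager and server-call machinery shown above, not real project APIs:

# Distilled sketch of the shared fault-tolerance loop (not a drop-in implementation).
import itertools
import time
from dataclasses import dataclass

@dataclass
class Span:
    start: int
    end: int

def update_() -> None: ...                           # would refresh routing info from the DHT
def make_sequence(start: int, end: int) -> list:     # would pick servers covering blocks [start, end)
    return [Span(start, min(start + 8, end))]
def get_retry_delay(attempt_no: int, min_backoff: float = 1.0) -> float:
    return 0 if attempt_no == 0 else min_backoff * 2 ** (attempt_no - 1)
def run_span(span: Span) -> None: ...                # would send the actual request to one server

block_idx, n_blocks = 0, 24
while block_idx < n_blocks:
    for attempt_no in itertools.count():
        try:
            if attempt_no >= 1:
                update_()                            # re-discover servers after a failure
            span = make_sequence(block_idx, n_blocks)[0]  # re-plan the remaining tail of blocks
            run_span(span)                           # may raise if the chosen server fails
            block_idx = span.end                     # advance only once the span has succeeded
            break
        except Exception:
            time.sleep(get_retry_delay(attempt_no))  # immediate first retry, then exponential backoff
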
tests/test_block_exact_match.py
@@ -33,7 +33,7 @@ def test_remote_block_exact_match(atol_forward=1e-5, atol_inference=1e-3):
                 outputs_inference.append(sess.step(inputs[:, i : i + 1, :]))

             # test that max length is respected
-            with pytest.raises(P2PHandlerError) as exc_info:
+            with pytest.raises(ValueError, match=r"Maximum length exceeded") as exc_info:
                 sess.step(inputs[:, -1:, :])
             assert "Maximum length exceeded" in repr(exc_info.value)
