Add various server timeouts, lower --max_batch_size and --inference_max_length defaults (#97)

fix-ptune
Alexander Borzunov authored 1 year ago, committed by GitHub
parent d8ef09146e
commit c6e1b5a8e5

@@ -39,13 +39,13 @@ def main():
help='server will use this many processes to handle incoming requests')
parser.add_argument('--min_batch_size', type=int, default=1,
help='Minimum required batch size for all operations (in total tokens)')
parser.add_argument('--max_batch_size', type=int, default=16384,
parser.add_argument('--max_batch_size', type=int, default=2048,
help='The total number of tokens in the same batch will not exceed this value')
parser.add_argument('--prefetch_batches', type=int, default=1, required=False,
help='Pre-form this many subsequent batches while GPU is processing the current one')
parser.add_argument('--sender_threads', type=int, default=1, required=False,
help='Use this many threads to pass results/exceptions from Runtime to Pools')
parser.add_argument('--inference_max_length', type=int, default=16384,
parser.add_argument('--inference_max_length', type=int, default=2048,
help='Maximum total sequence length permitted per inference (in tokens)')
parser.add_argument('--cache_dir', type=str, default=None,
help='Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used.')
@@ -57,6 +57,9 @@ def main():
parser.add_argument('--attn_cache_size', type=str, default=None,
help='The size of GPU memory allocated for storing past attention keys/values between inference'
' steps; examples: 500MB or 1.2GB or 1073741824 (bytes); be warned: 1KB != 1KiB')
parser.add_argument('--alloc_timeout', type=float, default=60,
help='If the cache is full, the server will wait for this number of seconds hoping that some memory will be freed '
'before rejecting the request')
parser.add_argument('--revision', type=str, default='main',
help="The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models"
"and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.")
@@ -72,6 +75,12 @@ def main():
help='Server will report blocks to DHT once in this many seconds')
parser.add_argument('--expiration', type=float, required=False, default=None,
help='DHT entries will expire after this many seconds')
parser.add_argument('--request_timeout', type=float, required=False, default=3 * 60,
help='Timeout for the whole rpc_forward/rpc_backward/rpc_forward_stream/rpc_backward_stream request')
parser.add_argument('--session_timeout', type=float, required=False, default=30 * 60,
help='Timeout for the whole inference session')
parser.add_argument('--step_timeout', type=float, required=False, default=5 * 60,
help="Timeout for waiting the next step's inputs inside an inference session")
group = parser.add_mutually_exclusive_group()
group.add_argument('--initial_peers', type=str, nargs='*', required=False, default=PUBLIC_INITIAL_PEERS,

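For reference, here is a minimal, self-contained sketch of just the flags added or changed in the hunks above and their new defaults. This is an illustrative parser, not the project's actual server CLI:

    import argparse

    # Illustrative parser: only the arguments touched by this commit.
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_batch_size", type=int, default=2048)
    parser.add_argument("--inference_max_length", type=int, default=2048)
    parser.add_argument("--alloc_timeout", type=float, default=60)
    parser.add_argument("--request_timeout", type=float, default=3 * 60)
    parser.add_argument("--session_timeout", type=float, default=30 * 60)
    parser.add_argument("--step_timeout", type=float, default=5 * 60)

    args = parser.parse_args([])  # empty argv -> the new defaults
    print(vars(args))
    # {'max_batch_size': 2048, 'inference_max_length': 2048, 'alloc_timeout': 60,
    #  'request_timeout': 180, 'session_timeout': 1800, 'step_timeout': 300}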
@@ -6,3 +6,4 @@ transformers==4.21.3
protobuf>=3.20.3,<4.0dev
git+https://github.com/learning-at-home/hivemind@be88b4280cdd87432168e1da238e532f1364078b
humanfriendly
async-timeout>=4.0.2
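The new async-timeout dependency provides the timeout async context manager used in the connection handler changes below. A minimal standalone usage sketch (toy coroutine and a short budget, independent of Petals):

    import asyncio
    from async_timeout import timeout  # provided by async-timeout>=4.0.2

    async def slow_step():
        await asyncio.sleep(0.1)
        return "ok"

    async def main():
        # If the body takes longer than 1 second, asyncio.TimeoutError is raised
        # and the enclosed work is cancelled.
        async with timeout(1.0):
            print(await slow_step())

    asyncio.run(main())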

@@ -26,8 +26,9 @@ Handle = int
class MemoryCache:
"""A shared cache for storing tensors that persist across calls. Main use case: storing past attention KVs"""
def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int]):
def __init__(self, device: Union[str, torch.device], max_size_bytes: Optional[int], alloc_timeout: float):
self.max_size_bytes = max_size_bytes if max_size_bytes is not None else (2**64 - 1)
self.alloc_timeout = alloc_timeout
self.device = device
self._lock_metadata, self.size_decreased_event = mp.Lock(), mp.Event()
self._current_size = mp.Value(ctypes.c_int64, 0, lock=False)
@@ -75,7 +76,9 @@ class MemoryCache:
try:
async with hivemind.utils.enter_asynchronously(self._lock_acquire_memory):
if self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
await loop.run_in_executor(None, self._wait_until_available, allocated_size_bytes)
await loop.run_in_executor(
None, self._wait_until_available, allocated_size_bytes, self.alloc_timeout
)
async with hivemind.utils.enter_asynchronously(self._lock_metadata):
allocated_handle = int(self.handle_counter)
self.current_size_bytes += allocated_size_bytes
@@ -92,17 +95,19 @@ class MemoryCache:
self.current_size_bytes -= allocated_size_bytes
self._memory_freed_event.set()
def _wait_until_available(self, allocated_size_bytes: int, timeout: Optional[float] = None):
def _wait_until_available(self, allocated_size: int, timeout: Optional[float] = None):
# note: this function should only be called inside _lock_acquire_memory!
if allocated_size_bytes > self.max_size_bytes:
if allocated_size > self.max_size_bytes:
raise AllocationFailed(
f"Could not allocate {allocated_size_bytes} bytes, max cache size = {self.max_size_bytes} bytes"
f"Could not allocate {allocated_size} bytes, max cache size = {self.max_size_bytes} bytes"
)
deadline = None if timeout is None else time.perf_counter() + timeout
while self.current_size_bytes + allocated_size_bytes > self.max_size_bytes:
while self.current_size_bytes + allocated_size > self.max_size_bytes:
remaining_time = deadline - time.perf_counter() if timeout is not None else None
if not self._memory_freed_event.wait(remaining_time):
raise AllocationFailed(f"Could not allocate {allocated_size_bytes} bytes in {timeout} seconds")
raise AllocationFailed(
f"Server's attention cache is full, failed to allocate {allocated_size} bytes in {timeout} seconds"
)
self._memory_freed_event.clear()
@contextlib.contextmanager

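The new alloc_timeout turns the previously unbounded wait in _wait_until_available into a wait with a deadline: the remaining time is recomputed before each Event.wait, and the request is rejected once the deadline passes. A simplified standalone sketch of that pattern (not the MemoryCache class itself; the Event and the byte counter below are stand-ins for the cache's shared state):

    import threading
    import time

    memory_freed = threading.Event()   # stands in for MemoryCache._memory_freed_event
    available_bytes = [0]              # stands in for the cache's free space

    def wait_until_available(needed: int, timeout: float) -> None:
        deadline = time.perf_counter() + timeout
        while available_bytes[0] < needed:
            remaining = deadline - time.perf_counter()
            if not memory_freed.wait(remaining):  # False means the deadline expired
                raise TimeoutError(f"failed to allocate {needed} bytes in {timeout} seconds")
            memory_freed.clear()

    def free_memory_soon():
        time.sleep(0.2)
        available_bytes[0] = 4096
        memory_freed.set()

    threading.Thread(target=free_memory_soon).start()
    wait_until_available(2048, timeout=2.0)
    print("allocation can proceed")

Note that loop.run_in_executor forwards only positional arguments to the target callable, which is why alloc_timeout is passed positionally in the allocation hunk above.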
@@ -1,7 +1,9 @@
import asyncio
import contextlib
from typing import AsyncIterator, Dict, Iterable, List, Sequence, Tuple, Union
import torch
from async_timeout import timeout
from hivemind import (
DHT,
MSGPackSerializer,
@@ -37,13 +39,19 @@ class TransformerConnectionHandler(ConnectionHandler):
self,
dht: DHT,
module_backends: Dict[str, TransformerBackend],
*,
inference_max_length: int,
request_timeout: float,
session_timeout: float,
step_timeout: float,
task_prioritizer: TaskPrioritizerBase = DummyTaskPrioritizer(),
):
super().__init__(dht, module_backends)
for module_backend in self.module_backends.values():
assert isinstance(module_backend, TransformerBackend)
self.inference_max_length = inference_max_length
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self._prioritizer = task_prioritizer
async def _gather_inputs(
@@ -76,227 +84,240 @@ class TransformerConnectionHandler(ConnectionHandler):
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
"""Compute a single step of inference using attention cache; update attention cache accordingly."""
request = await anext(requests)
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_inference.open", requested_uids, context)
try:
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
points = metadata.get("points", 0)
async with timeout(self.session_timeout):
request = await asyncio.wait_for(anext(requests), self.step_timeout)
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_inference.open", requested_uids, context)
try:
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
max_length = metadata.get("max_length")
points = metadata.get("points", 0)
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(
max_length, int
), f"rpc_inference metadata must contain int max_length, got {max_length}"
assert isinstance(
points, (float, int)
), f"rpc_inference should have number of points as a number or None, got {points}"
if not 0 <= max_length <= self.inference_max_length:
raise ValueError(
f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}"
)
if not requested_uids:
raise ValueError("User must specify at least one block for inference, but got none")
assert isinstance(max_length, int), f"rpc_inference metadata must contain int max_length, got {max_length}"
assert isinstance(
points, (float, int)
), f"rpc_inference should have number of points as a number or None, got {points}"
if not 0 <= max_length <= self.inference_max_length:
raise ValueError(f"Cannot allocate KV cache for {max_length} tokens, max = {self.inference_max_length}")
point_per_piece = points / max_length if max_length > 0 else 0.0
batch_size = request.tensors[0].size[0] if request.tensors else 1
cache_metadata = torch.tensor(
[[-1, -1] for _ in range(batch_size)], dtype=torch.int64
) # [cache_handle, prefix_length]
prefix_length = 0
async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
hidden_states, prompts, hypo_ids = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
point_per_piece = points / max_length if max_length > 0 else 0.0
batch_size = request.tensors[0].size[0] if request.tensors else 1
# run request tensors through all requested modules, update caches
for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
assert isinstance(
backend.inference_pool, PrioritizedTaskPool
), "petals support only prioritized pools"
priority = self._prioritizer.prioritize(
cache_metadata,
hidden_states,
hypo_ids,
points=point_per_piece / len(requested_backends),
backend=backend,
type="inference",
)
(hidden_states,) = await backend.inference_pool.submit_task(
cache_metadata, hidden_states, hypo_ids, priority=priority
)
cache_metadata = torch.tensor(
[[-1, -1] for _ in range(batch_size)], dtype=torch.int64
) # [cache_handle, prefix_length]
prefix_length = 0
# serialize and send last layer outputs
yield runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
async with self._allocate_caches(requested_backends, batch_size, max_length) as cache_handles:
assert len(cache_handles) == len(requested_backends)
while request.tensors: # iterate while user is willing to supply tensors
hidden_states, prompts, hypo_ids = [
deserialize_torch_tensor(tensor) for tensor in request.tensors
]
)
# prepare for next step
prefix_length += hidden_states.shape[1]
request = await (anext(requests))
finally:
self._log_request("rpc_inference.close", requested_uids, context)
# Cast inputs to backend dtype
hidden_states = hidden_states.to(requested_backends[0].dtype)
assert hypo_ids.dtype == torch.int64, f"hypo ids must be int64, got {hypo_ids.dtype}"
# parse deep prompts (optional argument)
if prompts is None or is_dummy(prompts):
prompts = [DUMMY] * len(requested_backends)
else:
prompts = [p.squeeze(0) for p in prompts.to(requested_backends[0].dtype).split(1, dim=0)]
if not (len(requested_backends) == len(prompts)):
raise ValueError(f"Received {len(prompts)} prompts for {len(requested_backends)} backends")
length_increment = hidden_states.shape[1] # how many tokens are added this step (in each seq)
if prefix_length + length_increment > max_length:
raise ValueError(
f"Maximum length exceeded: prefix {prefix_length} + current {length_increment}"
f" exceeds pre-allocated maximum {max_length}"
)
# run request tensors through all requested modules, update caches
for backend, prompt, cache_handle in zip(requested_backends, prompts, cache_handles):
if not is_dummy(prompt):
hidden_states[:, : prompt.shape[1]] += prompt
cache_metadata[:, 0], cache_metadata[:, 1] = cache_handle, prefix_length
assert isinstance(
hidden_states, torch.Tensor
), f"hidden states must be tensor, got {type(hidden_states)}"
assert (
hidden_states.ndim == 3
), f"inputs to {type(backend)} must be a list with a single 3d tensor of hidden states"
assert isinstance(
backend.inference_pool, PrioritizedTaskPool
), "petals support only prioritized pools"
priority = self._prioritizer.prioritize(
cache_metadata,
hidden_states,
hypo_ids,
points=point_per_piece / len(requested_backends),
backend=backend,
type="inference",
)
(hidden_states,) = await backend.inference_pool.submit_task(
cache_metadata, hidden_states, hypo_ids, priority=priority
)
# serialize and send last layer outputs
yield runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(
(hidden_states,), nested_flatten(requested_backends[-1].outputs_schema)
)
]
)
# prepare for next step
prefix_length += hidden_states.shape[1]
request = await asyncio.wait_for(anext(requests), self.step_timeout)
finally:
self._log_request("rpc_inference.close", requested_uids, context)
async def rpc_forward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
# Parse request and prepare backends
flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_forward", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
async with timeout(self.request_timeout):
# Parse request and prepare backends
flat_inputs = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_forward", requested_uids, context)
# Serialize output and respond to client
return runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_forward should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
# Serialize output and respond to client
return runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
)
async def rpc_forward_stream(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertRequest]:
# Parse requests and prepare backends
uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uid_str)
self._log_request("rpc_forward_stream", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
assert isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3, "hidden_states must be a 3d tensor"
async with timeout(self.request_timeout):
# Parse requests and prepare backends
uid_str, flat_inputs, metadata = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uid_str)
self._log_request("rpc_forward_stream", requested_uids, context)
# Serialize the overall output
serialized_output = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_forward_stream should have number of points as number or None, got {points}"
# Split the serialized_output for streaming and respond to client
output_split = [
part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
]
async for part in as_aiter(*output_split):
yield runtime_pb2.ExpertResponse(tensors=[part])
hidden_states = await _rpc_forward(
*flat_inputs, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
assert (
isinstance(hidden_states, torch.Tensor) and hidden_states.ndim == 3
), "hidden_states must be a 3d tensor"
# Serialize the overall output
serialized_output = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip((hidden_states,), nested_flatten(requested_backends[-1].outputs_schema))
]
# Split the serialized_output for streaming and respond to client
output_split = [
part for tensor in serialized_output for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
]
async for part in as_aiter(*output_split):
yield runtime_pb2.ExpertResponse(tensors=[part])
async def rpc_backward(self, request: runtime_pb2.ExpertRequest, context: P2PContext) -> runtime_pb2.ExpertResponse:
# Parse requests and prepare backends
flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_backward", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
async with timeout(self.request_timeout):
# Parse requests and prepare backends
flat_tensors = [deserialize_torch_tensor(tensor) for tensor in request.tensors]
requested_uids = self._check_uids(request.uid)
self._log_request("rpc_backward", requested_uids, context)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
metadata = MSGPackSerializer.loads(request.metadata) if request.metadata else {}
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_backward should have number of points as number or None, got {points}"
grad_inputs_schema_with_prompts = (
requested_backends[0].args_schema * len(grads),
requested_backends[0].kwargs_schema,
) # TODO generalize
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
# Serialize the overall grad_input and respond
return runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
]
)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
grad_inputs_schema_with_prompts = (
requested_backends[0].args_schema * len(grads),
requested_backends[0].kwargs_schema,
) # TODO generalize
# Serialize the overall grad_input and respond
return runtime_pb2.ExpertResponse(
tensors=[
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
]
)
async def rpc_backward_stream(
self, requests: AsyncIterator[runtime_pb2.ExpertRequest], context: P2PContext
) -> AsyncIterator[runtime_pb2.ExpertResponse]:
uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uids_header)
self._log_request("rpc_backward_stream", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
async with timeout(self.request_timeout):
uids_header, flat_tensors, metadata = await self._gather_inputs(requests, context)
requested_uids = self._check_uids(uids_header)
self._log_request("rpc_backward_stream", requested_uids, context)
requested_backends = tuple(self.module_backends[uid] for uid in requested_uids)
points = metadata.get("points", 0)
assert isinstance(
points, (float, int)
), f"rpc_backward_stream should have number of points as number or None, got {points}"
grads = await _rpc_backward(
*flat_tensors, requested_backends=requested_backends, prioritizer=self._prioritizer, points=points
)
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
grad_inputs_schema_with_prompts = (
requested_backends[0].args_schema * len(grads),
requested_backends[0].kwargs_schema,
) # TODO generalize
# Serialize the overall grad_inputs
serialized_grad_inputs = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
]
# Split the serialized_grad_inputs for streaming and respond
output_split = [
part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
]
# Modify grad_inputs_schema to support grad_prompts
assert len(requested_backends[0].args_schema) == 1 and len(grads) in (1, 2) # TODO generalize
grad_inputs_schema_with_prompts = (
requested_backends[0].args_schema * len(grads),
requested_backends[0].kwargs_schema,
) # TODO generalize
# Serialize the overall grad_inputs
serialized_grad_inputs = [
serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
for result, proto in zip(grads, nested_flatten(grad_inputs_schema_with_prompts))
]
# Split the serialized_grad_inputs for streaming and respond
output_split = [
part for tensor in serialized_grad_inputs for part in split_for_streaming(tensor, DEFAULT_MAX_MSG_SIZE)
]
async for part in as_aiter(*output_split):
yield runtime_pb2.ExpertResponse(tensors=[part])
async for part in as_aiter(*output_split):
yield runtime_pb2.ExpertResponse(tensors=[part])
def _check_uids(self, uids: str) -> Sequence[ModuleUID]:
"""Check that the first request to rpc_inference is valid"""

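The inference path now enforces two nested time budgets: async with timeout(self.session_timeout) bounds the whole rpc_inference session, while asyncio.wait_for(anext(requests), self.step_timeout) bounds each wait for the client's next input. A standalone sketch of that two-level pattern (toy request stream and short budgets; requires Python 3.10+ for the built-in anext):

    import asyncio
    from async_timeout import timeout

    SESSION_TIMEOUT = 5.0  # the server default above is 30 * 60 seconds
    STEP_TIMEOUT = 1.0     # the server default above is 5 * 60 seconds

    async def fake_requests():
        for step in range(3):
            await asyncio.sleep(0.1)  # the client "thinks" between steps
            yield step
        yield None                    # the client closes the session

    async def handle_session():
        requests = fake_requests()
        async with timeout(SESSION_TIMEOUT):  # budget for the whole session
            # each wait for the next client message gets its own, smaller budget
            request = await asyncio.wait_for(anext(requests), STEP_TIMEOUT)
            while request is not None:
                print("processed step", request)
                request = await asyncio.wait_for(anext(requests), STEP_TIMEOUT)

    asyncio.run(handle_session())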
@@ -49,18 +49,22 @@ class Server:
block_indices: Optional[str] = None,
num_handlers: int = 8,
min_batch_size: int = 1,
max_batch_size: int = 4096,
inference_max_length: int = 4096,
max_batch_size: int = 2048,
inference_max_length: int = 2048,
torch_dtype: str = "auto",
revision: str = "main",
cache_dir: Optional[str] = None,
attn_cache_size: Optional[int] = None,
alloc_timeout: float = 60,
device: Optional[Union[str, torch.device]] = None,
compression=CompressionType.NONE,
stats_report_interval: Optional[int] = None,
custom_module_path=None,
update_period: float = 30,
expiration: Optional[float] = None,
request_timeout: float = 3 * 60,
session_timeout: float = 30 * 60,
step_timeout: float = 5 * 60,
prefetch_batches: int = 1,
sender_threads: int = 1,
balance_quality: float = 0.75,
@@ -100,6 +104,9 @@ class Server:
expiration = max(2 * update_period, MAX_DHT_TIME_DISCREPANCY_SECONDS)
self.expiration = expiration
self.request_timeout = request_timeout
self.session_timeout, self.step_timeout = session_timeout, step_timeout
self.dht = DHT(initial_peers=initial_peers, start=True, **kwargs)
visible_maddrs_str = [str(a) for a in self.dht.get_visible_maddrs()]
if initial_peers == PUBLIC_INITIAL_PEERS:
@@ -110,7 +117,7 @@ class Server:
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.device = device
self.memory_cache = MemoryCache(device, attn_cache_size)
self.memory_cache = MemoryCache(device, attn_cache_size, alloc_timeout)
assert isinstance(throughput, float) or throughput in ["auto", "eval"]
if throughput in ["auto", "eval"]:
@@ -167,6 +174,9 @@ class Server:
stats_report_interval=self.stats_report_interval,
update_period=self.update_period,
expiration=self.expiration,
request_timeout=self.request_timeout,
session_timeout=self.session_timeout,
step_timeout=self.step_timeout,
prefetch_batches=self.prefetch_batches,
sender_threads=self.sender_threads,
use_auth_token=self.use_auth_token,
@@ -238,22 +248,17 @@ class ModuleContainer(threading.Thread):
memory_cache: MemoryCache,
throughput: float,
block_indices: List[int],
num_handlers: Optional[int],
min_batch_size: int,
max_batch_size: int,
inference_max_length: int,
torch_dtype: torch.dtype,
cache_dir: Optional[str],
device: Union[str, torch.device],
compression: CompressionType,
stats_report_interval: Optional[int],
update_period: float,
expiration: Optional[float],
prefetch_batches: int,
sender_threads: int,
use_auth_token: Optional[str],
load_in_8bit: bool,
start: bool,
**kwargs,
) -> ModuleContainer:
module_uids = [f"{prefix}.{block_index}" for block_index in block_indices]
joining_announcer = ModuleAnnouncerThread(
@@ -327,15 +332,10 @@ class ModuleContainer(threading.Thread):
dht,
blocks,
throughput=throughput,
num_connection_handlers=num_handlers,
inference_max_length=inference_max_length,
device=device,
stats_report_interval=stats_report_interval,
update_period=update_period,
expiration=expiration,
prefetch_batches=prefetch_batches,
sender_threads=sender_threads,
start=start,
**kwargs,
)
def __init__(
@@ -344,10 +344,13 @@ class ModuleContainer(threading.Thread):
module_backends: Dict[str, TransformerBackend],
*,
inference_max_length: int,
num_connection_handlers: int,
num_handlers: int,
throughput: float,
update_period: float,
expiration: Optional[float] = None,
request_timeout: float,
session_timeout: float,
step_timeout: float,
start: bool,
**kwargs,
):
@@ -356,8 +359,15 @@ class ModuleContainer(threading.Thread):
self.dht, self.module_backends = dht, module_backends
self.throughput, self.update_period, self.expiration = throughput, update_period, expiration
self.conn_handlers = [
TransformerConnectionHandler(dht, self.module_backends, inference_max_length)
for _ in range(num_connection_handlers)
TransformerConnectionHandler(
dht,
self.module_backends,
inference_max_length=inference_max_length,
request_timeout=request_timeout,
session_timeout=session_timeout,
step_timeout=step_timeout,
)
for _ in range(num_handlers)
]
self.runtime = Runtime(self.module_backends, **kwargs)
self.online_announcer = ModuleAnnouncerThread(

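Note that the handler and container constructors now take the new settings strictly as keyword-only arguments (the bare * in the signatures above). A minimal illustration of that pattern with placeholder names (Handler here is not the real TransformerConnectionHandler):

    class Handler:
        # Everything after the bare * must be passed by keyword, so call sites stay
        # readable and cannot silently swap the three timeout values.
        def __init__(self, dht, module_backends, *, inference_max_length: int,
                     request_timeout: float, session_timeout: float, step_timeout: float):
            self.inference_max_length = inference_max_length
            self.request_timeout = request_timeout
            self.session_timeout, self.step_timeout = session_timeout, step_timeout

    h = Handler("dht", {}, inference_max_length=2048,
                request_timeout=3 * 60, session_timeout=30 * 60, step_timeout=5 * 60)
    # Handler("dht", {}, 2048, 180, 1800, 300) would raise a TypeError:
    # these parameters cannot be passed positionally.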
@@ -3,7 +3,7 @@ import os
import bitsandbytes as bnb
import torch
PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 0)))
PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 1)))
def replace_8bit_linear(model, threshold=6.0):

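With the changed default, 8-bit backward passes are now enabled unless explicitly disabled. A standalone restatement of how the flag is read (same expression as the line above):

    import os

    # Unset -> default 1 -> enabled; set PETALS_8BIT_BACKWARD=0 to opt out.
    PETALS_8BIT_BACKWARD = bool(int(os.environ.get("PETALS_8BIT_BACKWARD", 1)))
    print(PETALS_8BIT_BACKWARD)  # True unless the environment variable is "0"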