diff --git a/src/petals/__init__.py b/src/petals/__init__.py
index 4e4a9d0..27076ba 100644
--- a/src/petals/__init__.py
+++ b/src/petals/__init__.py
@@ -24,6 +24,9 @@ if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
         version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
     ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
+    assert version.parse("1.1.10") <= version.parse(
+        hivemind.__version__
+    ), "Please install a proper hivemind version: pip install hivemind>=1.1.10"
 
 
 def _override_bfloat16_mode_default():
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 9ae672b..db4faac 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -221,7 +221,7 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
     first_pool = next(iter(backends.values())).inference_pool
     merged_inference_func = _MergedInferenceStep(backends)
     merged_pool = PrioritizedTaskPool(
-        lambda args, kwargs: merged_inference_func(*args, **kwargs),
+        merged_inference_func,
         max_batch_size=first_pool.max_batch_size,
         device=first_pool.device,
         name=f"merged_inference",
diff --git a/src/petals/server/server.py b/src/petals/server/server.py
index c85108a..e8e3d59 100644
--- a/src/petals/server/server.py
+++ b/src/petals/server/server.py
@@ -8,7 +8,7 @@ import random
 import sys
 import threading
 import time
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import hivemind
 import psutil
@@ -17,6 +17,7 @@ import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
+from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
@@ -778,7 +779,7 @@ class RuntimeWithDeduplicatedPools(Runtime):
         outputs = pool.process_func(*args, **kwargs)
         batch_size = 1
         for arg in args:
-            if isintance(arg, torch.Tensor) and arg.ndim > 2:
+            if isinstance(arg, torch.Tensor) and arg.ndim > 2:
                 batch_size = arg.shape[0] * arg.shape[1]
                 break
         return outputs, batch_size