minimize diff

partial_rollback
Your Name 8 months ago
parent 6c7f762379
commit cc4fe17a99

@ -24,6 +24,9 @@ if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert (
version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
assert version.parse("1.1.10") <= version.parse(
hivemind.__version__
), "Please install a proper hivemind version: pip install hivemind>=1.1.10"
def _override_bfloat16_mode_default():

@ -221,7 +221,7 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
first_pool = next(iter(backends.values())).inference_pool
merged_inference_func = _MergedInferenceStep(backends)
merged_pool = PrioritizedTaskPool(
lambda args, kwargs: merged_inference_func(*args, **kwargs),
merged_inference_func,
max_batch_size=first_pool.max_batch_size,
device=first_pool.device,
name=f"merged_inference",

@ -8,7 +8,7 @@ import random
import sys
import threading
import time
from typing import Dict, List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import hivemind
import psutil
@ -17,6 +17,7 @@ import torch.mps
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig
@ -778,7 +779,7 @@ class RuntimeWithDeduplicatedPools(Runtime):
outputs = pool.process_func(*args, **kwargs)
batch_size = 1
for arg in args:
if isintance(arg, torch.Tensor) and arg.ndim > 2:
if isinstance(arg, torch.Tensor) and arg.ndim > 2:
batch_size = arg.shape[0] * arg.shape[1]
break
return outputs, batch_size

Loading…
Cancel
Save