minimize diff

This commit is contained in:
Your Name 2023-09-05 14:05:41 +03:00
parent 6c7f762379
commit cc4fe17a99
3 changed files with 7 additions and 3 deletions

View File

@ -24,6 +24,9 @@ if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert ( assert (
version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0") version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0" ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
assert version.parse("1.1.10") <= version.parse(
hivemind.__version__
), "Please install a proper hivemind version: pip install hivemind>=1.1.10"
def _override_bfloat16_mode_default(): def _override_bfloat16_mode_default():

View File

@ -221,7 +221,7 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
first_pool = next(iter(backends.values())).inference_pool first_pool = next(iter(backends.values())).inference_pool
merged_inference_func = _MergedInferenceStep(backends) merged_inference_func = _MergedInferenceStep(backends)
merged_pool = PrioritizedTaskPool( merged_pool = PrioritizedTaskPool(
lambda args, kwargs: merged_inference_func(*args, **kwargs), merged_inference_func,
max_batch_size=first_pool.max_batch_size, max_batch_size=first_pool.max_batch_size,
device=first_pool.device, device=first_pool.device,
name=f"merged_inference", name=f"merged_inference",

View File

@ -8,7 +8,7 @@ import random
import sys import sys
import threading import threading
import time import time
from typing import Dict, List, Optional, Sequence, Union from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import hivemind import hivemind
import psutil import psutil
@ -17,6 +17,7 @@ import torch.mps
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
from hivemind.moe.server.layers import add_custom_models_from_file from hivemind.moe.server.layers import add_custom_models_from_file
from hivemind.moe.server.runtime import Runtime from hivemind.moe.server.runtime import Runtime
from hivemind.moe.server.task_pool import TaskPoolBase
from hivemind.proto.runtime_pb2 import CompressionType from hivemind.proto.runtime_pb2 import CompressionType
from hivemind.utils.logging import get_logger from hivemind.utils.logging import get_logger
from transformers import PretrainedConfig from transformers import PretrainedConfig
@ -778,7 +779,7 @@ class RuntimeWithDeduplicatedPools(Runtime):
outputs = pool.process_func(*args, **kwargs) outputs = pool.process_func(*args, **kwargs)
batch_size = 1 batch_size = 1
for arg in args: for arg in args:
if isintance(arg, torch.Tensor) and arg.ndim > 2: if isinstance(arg, torch.Tensor) and arg.ndim > 2:
batch_size = arg.shape[0] * arg.shape[1] batch_size = arg.shape[0] * arg.shape[1]
break break
return outputs, batch_size return outputs, batch_size