mirror of https://github.com/bigscience-workshop/petals
synced 2024-10-31 09:20:41 +00:00

commit cc4fe17a99 (parent 6c7f762379): minimize diff
```diff
@@ -24,6 +24,9 @@ if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
     assert (
         version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
     ), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
+    assert version.parse("1.1.10") <= version.parse(
+        hivemind.__version__
+    ), "Please install a proper hivemind version: pip install hivemind>=1.1.10"
 
 
 def _override_bfloat16_mode_default():
```
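Aside on the guard above: `version.parse` (presumably `from packaging import version`, given the call style) is what makes these range checks safe; plain string comparison would mis-order multi-digit components. A quick illustrative check:

```python
from packaging import version

# Lexicographic string comparison mis-orders versions with multi-digit parts...
assert not ("4.9.0" <= "4.10.0")
# ...while parsed versions compare numerically, which the asserts above rely on.
assert version.parse("4.9.0") <= version.parse("4.10.0")
assert version.parse("4.32.0") <= version.parse("4.41.2") < version.parse("5.0.0")
```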
```diff
@@ -221,7 +221,7 @@ def merge_inference_pools_inplace(backends: Dict[ExpertUID, TransformerBackend])
     first_pool = next(iter(backends.values())).inference_pool
     merged_inference_func = _MergedInferenceStep(backends)
     merged_pool = PrioritizedTaskPool(
-        lambda args, kwargs: merged_inference_func(*args, **kwargs),
+        merged_inference_func,
         max_batch_size=first_pool.max_batch_size,
         device=first_pool.device,
         name=f"merged_inference",
```
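The replaced lambda and the bare callable are not drop-in equivalents: the lambda accepted a packed `(args, kwargs)` pair and unpacked it, whereas passing `merged_inference_func` directly assumes `PrioritizedTaskPool` now invokes `process_func(*args, **kwargs)` itself. A minimal sketch of the two calling conventions (all names are hypothetical stand-ins, not the actual pool API):

```python
def call_packed(process_func, args, kwargs):
    # Old convention: the pool hands its callable one packed (args, kwargs) pair.
    return process_func(args, kwargs)

def call_unpacked(process_func, args, kwargs):
    # New convention: the pool unpacks the pair before calling.
    return process_func(*args, **kwargs)

def step(x, *, scale=1):  # stand-in for _MergedInferenceStep.__call__
    return x * scale

wrapped = lambda args, kwargs: step(*args, **kwargs)
assert call_packed(wrapped, (2,), {"scale": 3}) == 6
assert call_unpacked(step, (2,), {"scale": 3}) == 6
```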
```diff
@@ -8,7 +8,7 @@ import random
 import sys
 import threading
 import time
-from typing import Dict, List, Optional, Sequence, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import hivemind
 import psutil
```
```diff
@@ -17,6 +17,7 @@ import torch.mps
 from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
 from hivemind.moe.server.layers import add_custom_models_from_file
 from hivemind.moe.server.runtime import Runtime
+from hivemind.moe.server.task_pool import TaskPoolBase
 from hivemind.proto.runtime_pb2 import CompressionType
 from hivemind.utils.logging import get_logger
 from transformers import PretrainedConfig
```
```diff
@@ -778,7 +779,7 @@ class RuntimeWithDeduplicatedPools(Runtime):
         outputs = pool.process_func(*args, **kwargs)
         batch_size = 1
         for arg in args:
-            if isintance(arg, torch.Tensor) and arg.ndim > 2:
+            if isinstance(arg, torch.Tensor) and arg.ndim > 2:
                 batch_size = arg.shape[0] * arg.shape[1]
                 break
         return outputs, batch_size
```
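Besides minimizing the diff, this hunk fixes a real bug: `isintance` on the removed line is an undefined name, so that branch would raise `NameError` the first time it ran. The surrounding loop estimates batch size as `batch * seq_len` from the first tensor argument with more than two dimensions. A small self-contained sketch of that heuristic (the function name is illustrative, not from the source):

```python
import torch

def estimate_batch_size(args) -> int:
    # Mirrors the loop above: take batch * seq_len from the first tensor
    # shaped like (batch, seq_len, ...); fall back to 1 otherwise.
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.ndim > 2:
            return arg.shape[0] * arg.shape[1]
    return 1

hidden_states = torch.zeros(4, 16, 1024)  # (batch=4, seq_len=16, hidden_dim)
assert estimate_batch_size((hidden_states, "metadata")) == 4 * 16
assert estimate_batch_size(("no", "tensors")) == 1
```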