|
|
|
@ -8,7 +8,7 @@ import random
|
|
|
|
|
import sys
|
|
|
|
|
import threading
|
|
|
|
|
import time
|
|
|
|
|
from typing import Dict, List, Optional, Sequence, Union
|
|
|
|
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import hivemind
|
|
|
|
|
import psutil
|
|
|
|
@ -17,6 +17,7 @@ import torch.mps
|
|
|
|
|
from hivemind import DHT, MAX_DHT_TIME_DISCREPANCY_SECONDS, BatchTensorDescriptor, get_dht_time
|
|
|
|
|
from hivemind.moe.server.layers import add_custom_models_from_file
|
|
|
|
|
from hivemind.moe.server.runtime import Runtime
|
|
|
|
|
from hivemind.moe.server.task_pool import TaskPoolBase
|
|
|
|
|
from hivemind.proto.runtime_pb2 import CompressionType
|
|
|
|
|
from hivemind.utils.logging import get_logger
|
|
|
|
|
from transformers import PretrainedConfig
|
|
|
|
@ -778,7 +779,7 @@ class RuntimeWithDeduplicatedPools(Runtime):
|
|
|
|
|
outputs = pool.process_func(*args, **kwargs)
|
|
|
|
|
batch_size = 1
|
|
|
|
|
for arg in args:
|
|
|
|
|
if isintance(arg, torch.Tensor) and arg.ndim > 2:
|
|
|
|
|
if isinstance(arg, torch.Tensor) and arg.ndim > 2:
|
|
|
|
|
batch_size = arg.shape[0] * arg.shape[1]
|
|
|
|
|
break
|
|
|
|
|
return outputs, batch_size
|
|
|
|
|