From 9faf08b8984a07f0180be617a10bab4e767d1ac5 Mon Sep 17 00:00:00 2001
From: Max Ryabinin
Date: Thu, 1 Dec 2022 04:47:05 +0100
Subject: [PATCH] Remove unused imports, add missing arguments to docstrings
 (#108)

* Remove unused imports, add missing arguments to docstrings
---
 src/petals/bloom/model.py                | 1 -
 src/petals/bloom/ops.py                  | 4 ----
 src/petals/client/remote_generation.py   | 1 +
 src/petals/client/sequential_autograd.py | 1 -
 src/petals/server/backend.py             | 2 +-
 src/petals/server/handler.py             | 2 +-
 src/petals/server/task_pool.py           | 2 +-
 src/petals/server/task_prioritizer.py    | 1 -
 8 files changed, 4 insertions(+), 10 deletions(-)

diff --git a/src/petals/bloom/model.py b/src/petals/bloom/model.py
index 08f7713..687d765 100644
--- a/src/petals/bloom/model.py
+++ b/src/petals/bloom/model.py
@@ -21,7 +21,6 @@ from transformers.modeling_outputs import (
     CausalLMOutputWithCrossAttentions,
     SequenceClassifierOutputWithPast,
 )
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
 from transformers.utils import logging
diff --git a/src/petals/bloom/ops.py b/src/petals/bloom/ops.py
index b84c7c1..8e8f138 100644
--- a/src/petals/bloom/ops.py
+++ b/src/petals/bloom/ops.py
@@ -196,10 +196,6 @@ class BloomScaledSoftmax(nn.Module):
     fused operation: scaling + mask + softmax

     Args:
-        input_in_fp16 (`bool`, *required*):
-            flag to indicate if input in fp16 data format.
-        input_in_bf16 (`bool`, *required*):
-            flag to indicate if input in bf16 data format.
         scaled_masked_softmax_fusion (`bool`, *required*):
             flag to indicate user want to use softmax fusion
         mask_func (`function`, *required*):
diff --git a/src/petals/client/remote_generation.py b/src/petals/client/remote_generation.py
index c33ab0b..92c8a03 100644
--- a/src/petals/client/remote_generation.py
+++ b/src/petals/client/remote_generation.py
@@ -57,6 +57,7 @@ class RemoteGenerationMixin:
         :param bos_token_id: The id of the beginning of sentence token.
         :param eos_token_id: The id of the end of sentence token.
         :param pad_token_id: The id of the padding token.
+        :param max_length: The maximum number of tokens in the output (including input tokens).
         :param max_new_tokens: The maximum number of tokens to generate.
         :param decoding_algorithm: The decoding algorithm to use.
         :param provided_constraints: A list of constraints to use.
diff --git a/src/petals/client/sequential_autograd.py b/src/petals/client/sequential_autograd.py
index 7dc7116..49a090f 100644
--- a/src/petals/client/sequential_autograd.py
+++ b/src/petals/client/sequential_autograd.py
@@ -51,7 +51,6 @@ async def sequential_forward(
     sequences = deque()
     intermediate_inputs = []
     done_sequences = []
-    outputs = inputs

     block_idx = start_index
     while block_idx < end_index:
diff --git a/src/petals/server/backend.py b/src/petals/server/backend.py
index 2f7ace9..c29851c 100644
--- a/src/petals/server/backend.py
+++ b/src/petals/server/backend.py
@@ -1,5 +1,5 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, Sequence, Tuple

 import torch
 from hivemind import BatchTensorDescriptor, use_hivemind_log_handler
diff --git a/src/petals/server/handler.py b/src/petals/server/handler.py
index 3c57fff..1613c19 100644
--- a/src/petals/server/handler.py
+++ b/src/petals/server/handler.py
@@ -17,7 +17,7 @@ from hivemind import (
 )
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.asyncio import amap_in_executor, anext
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming
diff --git a/src/petals/server/task_pool.py b/src/petals/server/task_pool.py
index 672248f..41c9c15 100644
--- a/src/petals/server/task_pool.py
+++ b/src/petals/server/task_pool.py
@@ -4,7 +4,7 @@ import threading
 import time
 from dataclasses import dataclass, field
 from queue import PriorityQueue
-from typing import Any, Generator, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple

 import torch
 from hivemind import MPFuture, get_logger, use_hivemind_log_handler
diff --git a/src/petals/server/task_prioritizer.py b/src/petals/server/task_prioritizer.py
index 6e3b886..5aba88c 100644
--- a/src/petals/server/task_prioritizer.py
+++ b/src/petals/server/task_prioritizer.py
@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod

 import torch
-from hivemind.moe.server.task_pool import Task


 class TaskPrioritizerBase(ABC):