Remove unused imports, add missing arguments to docstrings (#108)

* Remove unused imports, add missing arguments to docstrings
Max Ryabinin committed 2 years ago (via GitHub)
parent b3115dac58
commit 9faf08b898

@@ -21,7 +21,6 @@ from transformers.modeling_outputs import (
     CausalLMOutputWithCrossAttentions,
     SequenceClassifierOutputWithPast,
 )
-from transformers.modeling_utils import PreTrainedModel
 from transformers.models.bloom.configuration_bloom import BloomConfig
 from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
 from transformers.utils import logging

@@ -196,10 +196,6 @@ class BloomScaledSoftmax(nn.Module):
     fused operation: scaling + mask + softmax
 
     Args:
-        input_in_fp16 (`bool`, *required*):
-            flag to indicate if input in fp16 data format.
-        input_in_bf16 (`bool`, *required*):
-            flag to indicate if input in bf16 data format.
         scaled_masked_softmax_fusion (`bool`, *required*):
            flag to indicate user want to use softmax fusion
        mask_func (`function`, *required*):
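
For reference, the scaling + mask + softmax sequence this docstring describes can be written out unfused in plain PyTorch. This is only an illustrative sketch: the function name and the assumed mask_func contract (merging the mask into the scores, e.g. by filling masked positions with a large negative value) are not taken from the repository.

import torch

def scaled_masked_softmax(scores: torch.Tensor, mask, scale: float, mask_func):
    # Illustrative only: scale the raw attention scores first.
    scores = scores * scale
    # mask_func is assumed to merge the attention mask into the scores,
    # e.g. by filling masked positions with a large negative value.
    if mask is not None:
        scores = mask_func(scores, mask)
    # Softmax over the last (key) dimension yields attention probabilities.
    return torch.nn.functional.softmax(scores, dim=-1)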

@@ -57,6 +57,7 @@ class RemoteGenerationMixin:
         :param bos_token_id: The id of the beginning of sentence token.
         :param eos_token_id: The id of the end of sentence token.
         :param pad_token_id: The id of the padding token.
+        :param max_length: The maximum number of tokens in the output (including input tokens).
         :param max_new_tokens: The maximum number of tokens to generate.
         :param decoding_algorithm: The decoding algorithm to use.
         :param provided_constraints: A list of constraints to use.
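
The newly documented max_length bounds the whole output, prompt included, while max_new_tokens bounds only the generated continuation. A hedged usage sketch follows; the helper and its argument values are illustrative, and any model mixing in RemoteGenerationMixin is assumed to expose generate() with these keyword arguments.

def generate_reply(model, tokenizer, prompt: str) -> str:
    # Illustrative helper: `model` is assumed to mix in RemoteGenerationMixin.
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
    outputs = model.generate(
        input_ids,
        max_length=128,  # cap on total output length, input tokens included
        # max_new_tokens=64,  # alternatively, cap only the newly generated tokens
        eos_token_id=tokenizer.eos_token_id,  # stop at the end-of-sentence token
    )
    return tokenizer.decode(outputs[0])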

@@ -51,7 +51,6 @@ async def sequential_forward(
     sequences = deque()
     intermediate_inputs = []
     done_sequences = []
-    outputs = inputs
 
     block_idx = start_index
     while block_idx < end_index:
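
The removed assignment appears to be dead code dropped along with the unused imports: the loop itself threads each block's output into the next block's input. A schematic of that loop shape, where run_remote_block is a hypothetical stand-in for the actual remote call and retries/error handling are omitted:

def sequential_forward_sketch(inputs, start_index: int, end_index: int, run_remote_block):
    # Hypothetical sketch of the loop pattern only: each remote block consumes
    # the previous block's output; intermediates are kept for the backward pass.
    intermediate_inputs = []
    block_idx = start_index
    while block_idx < end_index:
        intermediate_inputs.append(inputs)
        inputs = run_remote_block(block_idx, inputs)
        block_idx += 1
    return inputs, intermediate_inputs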

@@ -1,5 +1,5 @@
 """Code for serving bloom blocks via hivemind-server"""
-from typing import Any, Dict, Optional, Sequence, Tuple
+from typing import Any, Dict, Sequence, Tuple
 
 import torch
 from hivemind import BatchTensorDescriptor, use_hivemind_log_handler

@@ -17,7 +17,7 @@ from hivemind import (
 from hivemind.moe.server.connection_handler import ConnectionHandler
 from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
 from hivemind.proto import runtime_pb2
-from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
+from hivemind.utils.asyncio import amap_in_executor, anext
 from hivemind.utils.logging import get_logger
 from hivemind.utils.streaming import split_for_streaming

@@ -4,7 +4,7 @@ import threading
 import time
 from dataclasses import dataclass, field
 from queue import PriorityQueue
-from typing import Any, Generator, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 from hivemind import MPFuture, get_logger, use_hivemind_log_handler

@@ -1,7 +1,6 @@
 from abc import ABC, abstractmethod
 
 import torch
-from hivemind.moe.server.task_pool import Task
 
 
 class TaskPrioritizerBase(ABC):
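
For context, a minimal sketch of how an abstract prioritizer like this can be defined and subclassed. The class names, the prioritize method, and its signature below are assumptions for illustration, not necessarily the interface actually used in this file.

from abc import ABC, abstractmethod

import torch


class TaskPrioritizerSketch(ABC):
    # Assumed interface: return a priority for a task given its input tensors.
    @abstractmethod
    def prioritize(self, *inputs: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        ...


class DummyPrioritizerSketch(TaskPrioritizerSketch):
    # Trivial policy: every task gets the same priority.
    def prioritize(self, *inputs: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        return 0.0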
