Remove unused imports, add missing arguments to docstrings (#108)

* Remove unused imports, add missing arguments to docstrings
pull/105/head^2
Max Ryabinin 1 year ago committed by GitHub
parent b3115dac58
commit 9faf08b898
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -21,7 +21,6 @@ from transformers.modeling_outputs import (
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomPreTrainedModel
from transformers.utils import logging

@ -196,10 +196,6 @@ class BloomScaledSoftmax(nn.Module):
fused operation: scaling + mask + softmax
Args:
input_in_fp16 (`bool`, *required*):
flag to indicate if the input is in fp16 data format.
input_in_bf16 (`bool`, *required*):
flag to indicate if the input is in bf16 data format.
scaled_masked_softmax_fusion (`bool`, *required*):
flag to indicate that the user wants to use softmax fusion
mask_func (`function`, *required*):

@ -57,6 +57,7 @@ class RemoteGenerationMixin:
:param bos_token_id: The id of the beginning of sentence token.
:param eos_token_id: The id of the end of sentence token.
:param pad_token_id: The id of the padding token.
:param max_length: The maximum number of tokens in the output (including input tokens).
:param max_new_tokens: The maximum number of tokens to generate.
:param decoding_algorithm: The decoding algorithm to use.
:param provided_constraints: A list of constraints to use.

@ -51,7 +51,6 @@ async def sequential_forward(
sequences = deque()
intermediate_inputs = []
done_sequences = []
outputs = inputs
block_idx = start_index
while block_idx < end_index:

@ -1,5 +1,5 @@
"""Code for serving bloom blocks via hivemind-server"""
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, Sequence, Tuple
import torch
from hivemind import BatchTensorDescriptor, use_hivemind_log_handler

@ -17,7 +17,7 @@ from hivemind import (
from hivemind.moe.server.connection_handler import ConnectionHandler
from hivemind.p2p.p2p_daemon import DEFAULT_MAX_MSG_SIZE
from hivemind.proto import runtime_pb2
from hivemind.utils.asyncio import amap_in_executor, anext, as_aiter
from hivemind.utils.asyncio import amap_in_executor, anext
from hivemind.utils.logging import get_logger
from hivemind.utils.streaming import split_for_streaming

@ -4,7 +4,7 @@ import threading
import time
from dataclasses import dataclass, field
from queue import PriorityQueue
from typing import Any, Generator, List, Optional, Sequence, Tuple
from typing import Any, List, Optional, Sequence, Tuple
import torch
from hivemind import MPFuture, get_logger, use_hivemind_log_handler

@ -1,7 +1,6 @@
from abc import ABC, abstractmethod
import torch
from hivemind.moe.server.task_pool import Task
class TaskPrioritizerBase(ABC):

Loading…
Cancel
Save