Allow disabling chunked forward (#176)

Alexander Borzunov authored 1 year ago, committed by GitHub
parent 356e099c3d
commit 6948a0c5ee

@@ -45,8 +45,11 @@ class LMHead(nn.Module):
     def forward(self, hidden_states):
         word_embeddings = self.word_embeddings.weight
         # We use 'chunked_forward' only when embeddings are in half-precision on CPU.
-        if word_embeddings.dtype in [torch.float16, torch.bfloat16] and word_embeddings.device.type == "cpu":
+        if (
+            self.chunk_size is not None
+            and word_embeddings.dtype in [torch.float16, torch.bfloat16]
+            and word_embeddings.device.type == "cpu"
+        ):
             lm_logits = self.chunked_forward(hidden_states)
         else:
             # Switch dtype in case word_embeddings are fp16/bf16
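For context on what this condition guards: half-precision matmuls are slow on many CPUs, so the LM head computes logits against the embedding matrix in float32 chunks rather than in one large fp16/bf16 matmul. Below is a minimal sketch of that chunking idea only; the function name chunked_lm_head and the exact upcasting strategy are assumptions, not this repository's chunked_forward.

import torch
import torch.nn.functional as F

def chunked_lm_head(hidden_states: torch.Tensor, word_embeddings: torch.Tensor, chunk_size: int = 10000) -> torch.Tensor:
    # Sketch only: compute logits over the vocabulary in float32 chunks of
    # `chunk_size` embedding rows, avoiding one big half-precision matmul on CPU.
    hidden_states = hidden_states.float()
    logits = []
    for i in range(0, word_embeddings.shape[0], chunk_size):
        chunk = word_embeddings[i : i + chunk_size].float()  # upcast a slice of the vocab
        logits.append(F.linear(hidden_states, chunk))        # (..., chunk) partial logits
    return torch.cat(logits, dim=-1)

With chunk_size set to None, the new condition above falls through to the plain-matmul branch instead of taking this chunked path.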

@@ -34,7 +34,8 @@ class DistributedBloomConfig(BloomConfig):
     dht_prefix: str  # a prefix for all dht keys that correspond to this model (usually equal to model name)
     daemon_startup_timeout: int = 30
     dht: Optional[hivemind.DHT] = None  # a running DHT instance, e.g. when using the same DHT for multiple models
-    chunk_size_for_efficient_fp16_on_cpu: int = 10000  # a chunk size for a LM head for efficient half-precision on CPU
+    chunk_size_for_efficient_fp16_on_cpu: Optional[int] = 10000
+    # Chunk size for efficient half-precision on CPU in the LM head. Set to None if your CPU works fast with bfloat16.
     pre_seq_len: int = 0  # a number of tokens for prompt tuning.
     tuning_mode: Optional[str] = None  # One of the finetune options: [None, 'shallow_ptune', 'deep_ptune', 'adapters']
     request_timeout: int = 30  # a number of seconds for waiting result from each node
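Since chunk_size_for_efficient_fp16_on_cpu is now Optional, the chunked path can be switched off from user code by passing None. A minimal usage sketch, assuming the petals import path and model name shown here (neither appears in this diff) and assuming from_pretrained forwards config overrides the way standard Hugging Face configs do:

from petals import DistributedBloomForCausalLM  # assumed import path

# Assumed model name; passing None skips chunked_forward and uses a plain matmul,
# e.g. if your CPU already handles bfloat16 matmuls quickly.
model = DistributedBloomForCausalLM.from_pretrained(
    "bigscience/bloom-petals",
    chunk_size_for_efficient_fp16_on_cpu=None,
)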
