petals/src/petals/bloom/ops.py


"""
Utility operations used in the the BLOOM model
Based on https://github.com/huggingface/transformers/commit/ca2a55e9dfb245527b5e1c954fec6ffbb7aef07b
See commit history for authorship.
"""
import math
import torch
import torch.autograd
import torch.nn.functional as F
2 years ago
from torch import nn


def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Split a tensor along its last dimension.

    Args:
        tensor ([`torch.tensor`], *required*):
            input tensor to split
        num_partitions ([`int`], *required*):
            number of partitions to split the tensor into
        contiguous_split_chunks ([`bool`], *optional*, default=`False`):
            If True, make each chunk contiguous in memory.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    numerator, denominator = tensor.size()[last_dim], num_partitions
    if numerator % denominator != 0:
        raise ValueError(f"{numerator} is not divisible by {denominator}")
    last_dim_size = numerator // denominator
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)
    return tensor_list
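
# Illustrative usage (editor's sketch, not part of the original module): splitting
# the hidden dimension of a [batch, seq, hidden] tensor into equal per-head chunks.
# >>> t = torch.randn(2, 4, 12)
# >>> parts = split_tensor_along_last_dim(t, 3, contiguous_split_chunks=True)
# >>> [tuple(p.shape) for p in parts]
# [(2, 4, 4), (2, 4, 4), (2, 4, 4)]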


def attention_mask_func(attention_scores, attention_mask, causal_mask):
    if attention_mask.dtype == torch.bool:
        attention_mask_bool = ~attention_mask
    else:
        attention_mask_bool = (1 - attention_mask).bool()

    query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
    padded_causal_mask = (
        attention_mask_bool[:, None, key_length - query_length : key_length, None]
        + ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
    ).bool()
    padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
    # Fill masked positions with a large negative float so that softmax assigns them ~zero weight
    return (
        attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
        padded_causal_mask,
    )
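
# Illustrative usage (editor's sketch, not part of the original module): with no
# padding, only the non-causal (upper-triangular) positions are filled with -10000.
# >>> scores = torch.zeros(1, 2, 3, 3)                          # [batch, n_heads, q_len, k_len]
# >>> mask = torch.ones(1, 3, dtype=torch.bool)                 # no padded positions
# >>> causal = torch.tril(torch.ones(3, 3, dtype=torch.bool)).view(1, 1, 3, 3)
# >>> masked_scores, pad_mask = attention_mask_func(scores, mask, causal)
# >>> masked_scores[0, 0]
# tensor([[     0., -10000., -10000.],
#         [     0.,      0., -10000.],
#         [     0.,      0.,      0.]])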


def build_alibi_tensor(
    max_seq_len: int, n_head: int, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cpu")
) -> torch.Tensor:
    """
    Link to paper: https://arxiv.org/abs/2108.12409
    Unlike in the original paper, this alibi tensor is not causal: it relies on the translation invariance of
    softmax for a quick implementation. For a tensor l and a fixed value a, `softmax(l + a) = softmax(l)`. Based on
    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742

    Args:
        max_seq_len: (`int`, *required*):
            max sequence length
        n_head: (`int`, *required*):
            number of heads
        dtype: (`torch.dtype`, *optional*, default=`torch.bfloat16`):
            dtype of the output tensor
        device: (`torch.device`, *optional*, default=`torch.device('cpu')`):
            device of the output alibi tensor

    Returns:
        A tensor of shape (n_head, 1, max_seq_len)
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(n_head))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != n_head:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32
        )
        num_remaining_heads = min(closest_power_of_2, n_head - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)

    lengths = torch.arange(max_seq_len, device=device, dtype=torch.int32)
    return (slopes.view(-1, 1, 1) * lengths.view(1, 1, -1)).to(dtype)
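
# Illustrative usage (editor's sketch, not part of the original module): for 4 heads
# the per-head slopes are powers of 2**-2, and each head's row is slope * [0, 1, ..., L-1].
# >>> alibi = build_alibi_tensor(max_seq_len=5, n_head=4, dtype=torch.float32)
# >>> tuple(alibi.shape)
# (4, 1, 5)
# >>> alibi[0, 0]   # first head, slope 0.25
# tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])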


def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor):
    """
    Pre-process the alibi tensor for padding.

    Args:
        alibi ([`torch.tensor`], *required*):
            alibi tensor to pre-process
        attention_mask ([`torch.tensor`], *required*):
            attention mask to pre-process
    """
    assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"

    unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
    # ^-- [batch, max_len], values correspond to element indices after removing padding
    # We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
    alibi = alibi.take_along_dim(unpadded_indices.unsqueeze(0), -1) * attention_mask.unsqueeze(0)
    return alibi.reshape(alibi.shape[0] * alibi.shape[1], 1, -1)
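
# Illustrative usage (editor's sketch, not part of the original module): with one
# left-padded sequence, the alibi values are shifted so that position 0 lands on the
# first real token, and padded positions are zeroed out.
# >>> alibi = torch.tensor([[[0., 1., 2.]]])       # (n_head=1, 1, max_seq_len=3)
# >>> mask = torch.tensor([[0, 1, 1]])             # batch of 1, first token is padding
# >>> pre_process_alibi_for_pad(alibi, mask)
# tensor([[[0., 0., 1.]]])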


def dropout_add(x, residual, prob, training):
    """
    Dropout add function

    Args:
        x (`torch.tensor`, *required*):
            input tensor
        residual (`torch.tensor`, *required*):
            residual tensor
        prob (`float`, *required*):
            dropout probability
        training (`bool`, *required*):
            training mode
    """
    out = nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out
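
# Illustrative usage (editor's sketch, not part of the original module): with
# prob=0.0 the dropout is a no-op, so the result is simply residual + x.
# >>> x, residual = torch.ones(2, 2), torch.full((2, 2), 10.0)
# >>> dropout_add(x, residual, prob=0.0, training=True)
# tensor([[11., 11.],
#         [11., 11.]])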


def bloom_gelu_forward(x):
    """
    Custom bias GELU function. Adapted from Megatron-DeepSpeed code. Here we use a simple implementation (inference)
    to make the model jittable.

    Args:
        x (`torch.tensor`, *required*):
            input hidden states
    """
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
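
# Illustrative check (editor's sketch, not part of the original module): this matches
# PyTorch's built-in tanh approximation of GELU (the `approximate` kwarg is available
# in PyTorch >= 1.12) up to the precision of the hard-coded constants.
# >>> x = torch.linspace(-3, 3, 7)
# >>> torch.allclose(bloom_gelu_forward(x), F.gelu(x, approximate="tanh"), atol=1e-6)
# True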


def bloom_gelu_back(g, x):
    """
    Gradient of the tanh approximation of GELU. For reference, the gradient of the exact GELU is:
    0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)

    Args:
        g (`torch.tensor`, *required*):
            gradient output tensor
        x (`torch.tensor`, *required*):
            input tensor
    """
    x = x[0]  # x is a tuple of 1 element, needs to unpack it first
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return bloom_gelu_forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors  # a tuple of one tensor; bloom_gelu_back unpacks it
        tmp = bloom_gelu_back(grad_output, input)
        return tmp
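
# Illustrative check (editor's sketch, not part of the original module): the analytic
# backward matches autograd's numerical gradient in double precision.
# >>> x = torch.randn(4, dtype=torch.double, requires_grad=True)
# >>> torch.autograd.gradcheck(GeLUFunction.apply, (x,))
# True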


class BloomGelu(nn.Module):
    """
    BloomBiasGelu wrapper that uses the simple (jittable) function in inference mode, so the model stays
    torchscriptable, and the custom autograd function in training mode, so gradients are computed exactly.
    Partly copied from Megatron-DeepSpeed code and adapted for our needs.

    See here why autograd functions are not torchscriptable: https://github.com/pytorch/pytorch/issues/22329
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        if self.training:
            return GeLUFunction.apply(x)
        else:
            return bloom_gelu_forward(x)
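
# Illustrative usage (editor's sketch, not part of the original module): both the
# training and inference paths produce the same forward values.
# >>> gelu = BloomGelu()
# >>> x = torch.randn(2, 3)
# >>> _ = gelu.train(); y_train = gelu(x)
# >>> _ = gelu.eval(); y_eval = gelu(x)
# >>> torch.allclose(y_train, y_eval)
# True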


class BloomScaledSoftmax(nn.Module):
    """
    Fused operation: scaling + mask + softmax

    Args:
        scaled_masked_softmax_fusion (`bool`, *required*):
            whether to use fused (scaled, masked) softmax
        mask_func (`function`, *required*):
            mask function to be applied
        softmax_in_fp32 (`bool`, *required*):
            if True, softmax is performed at fp32 precision
        scale (`float`, *required*):
            scaling factor used in input tensor scaling
    """

    def __init__(self, scaled_masked_softmax_fusion, mask_func, softmax_in_fp32, scale):
        super().__init__()
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale

        if not (self.scale is None or softmax_in_fp32):
            raise ValueError("softmax should be in fp32 when scaled")

    def forward(self, input, mask, max_positions):
        input_dtype = input.dtype
        input_in_16bit = input_dtype in [torch.float16, torch.bfloat16]
        softmax_dtype = torch.float32 if self.softmax_in_fp32 else input_dtype

        if self.scale is not None:
            input = input * self.scale

        if mask is None:
            mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)

        mask = mask.to(input.device)
        causal_mask = (
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
            .view(1, 1, max_positions, max_positions)
            .to(input.device)
        )
        mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
        probs = F.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)

        if input_in_16bit and self.softmax_in_fp32:
            probs = probs.to(dtype=input_dtype)

        return probs
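
# Illustrative usage (editor's sketch, not part of the original module): applying the
# causal-masked softmax to a batch of random attention scores; each row of probabilities
# sums to ~1 over the visible (non-masked) positions.
# >>> softmax = BloomScaledSoftmax(
# ...     scaled_masked_softmax_fusion=False, mask_func=attention_mask_func,
# ...     softmax_in_fp32=True, scale=1.0,
# ... )
# >>> scores = torch.randn(1, 2, 4, 4)              # [batch, n_heads, q_len, k_len]
# >>> probs = softmax(scores, mask=None, max_positions=4)
# >>> tuple(probs.shape), bool(torch.allclose(probs[0, 0].sum(-1), torch.ones(4), atol=1e-3))
# ((1, 2, 4, 4), True)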