mirror of
https://github.com/bigscience-workshop/petals
synced 2024-11-16 06:12:50 +00:00
Miscellaneous fixes to automatic tests (#35)
1. __Reduce memory usage in in test_full_model__ - previously, loading the full model would consistently fail IF github is enforcing memory limit [example](https://github.com/bigscience-workshop/distributed-bloom/runs/7473920049?check_suite_focus=true) - the new version uses accelerate to save 2GB of peak memory, that was previously used when loading both reference model AND its state dict at the same time - only to load that state dict :) 2. __Safer delays when creating servers__ - run-tests will now wait for a few seconds after creating the first server - and before creating the second one, so as to make sure that the first server creates a DHT instance that subsequent servers can connect to. - also increased the wait time after creating servers by 30 seconds to make sure we load the model in time even when bumping into slow remotes on HF side 3. __Fix environment variables in CI to avoid build conflicts__ - the previous code was using a wrong environment variable that was always "main". The current one will correctly resolve branch name, both in main and on pull request. - For reference, below you can find sample environments when running CI in both cases: on pull request and on push to main. <details> <summary> Environment variables when building this branch (on pull request) </summary> SELENIUM_JAR_PATH=/usr/share/java/selenium-server.jar GOROOT_1_17_X64=/opt/hostedtoolcache/go/1.17.12/x64 CONDA=/usr/share/miniconda GITHUB_WORKSPACE=/home/runner/work/distributed-bloom/distributed-bloom JAVA_HOME_11_X64=/usr/lib/jvm/temurin-11-jdk-amd64 GITHUB_PATH=/home/runner/work/_temp/_runner_file_commands/add_path_0aba811a-a04b-40a2-ba42-79efb2723e9e GITHUB_ACTION=__run_2 JAVA_HOME=/usr/lib/jvm/temurin-11-jdk-amd64 GITHUB_RUN_NUMBER=98 RUNNER_NAME=GitHub Actions 3 GRADLE_HOME=/usr/share/gradle-7.5 XDG_CONFIG_HOME=/home/runner/.config DOTNET_SKIP_FIRST_TIME_EXPERIENCE=1 ANT_HOME=/usr/share/ant JAVA_HOME_8_X64=/usr/lib/jvm/temurin-8-jdk-amd64 HOMEBREW_PREFIX=/home/linuxbrew/.linuxbrew pythonLocation=/opt/hostedtoolcache/Python/3.9.13/x64 GITHUB_REF_TYPE=branch HOMEBREW_CLEANUP_PERIODIC_FULL_DAYS=3650 BOOTSTRAP_HASKELL_NONINTERACTIVE=1 *** PIPX_BIN_DIR=/opt/pipx_bin DEPLOYMENT_BASEPATH=/opt/runner GITHUB_ACTIONS=true ANDROID_NDK_LATEST_HOME=/usr/local/lib/android/sdk/ndk/24.0.8215888 GITHUB_SHA=3b457e8a14e5ecb0d65d6e4c0e9161f7756a8861 POWERSHELL_DISTRIBUTION_CHANNEL=GitHub-Actions-ubuntu20 DOTNET_MULTILEVEL_LOOKUP=0 GITHUB_REF=refs/pull/35/merge RUNNER_OS=Linux GITHUB_REF_PROTECTED=false HOME=/home/runner GITHUB_API_URL=https://api.github.com/ LANG=C.UTF-8 BLOOM_TESTING_WRITE_TOKEN=*** RUNNER_TRACKING_ID=github_cc9b46e4-56a1-40c5-ba08-5a91e21f0f95 STATS_KEEPALIVE=false RUNNER_ARCH=X64 RUNNER_TEMP=/home/runner/work/_temp EDGEWEBDRIVER=/usr/local/share/edge_driver GITHUB_ENV=/home/runner/work/_temp/_runner_file_commands/set_env_0aba811a-a04b-40a2-ba42-79efb2723e9e GITHUB_EVENT_PATH=/home/runner/work/_temp/_github_workflow/event.json INVOCATION_ID=8f0072e74f2847c0851e7ff9b5e4af7c GITHUB_EVENT_NAME=pull_request GITHUB_RUN_ID=2720198689 JAVA_HOME_17_X64=/usr/lib/jvm/temurin-17-jdk-amd64 ANDROID_NDK_HOME=/usr/local/lib/android/sdk/ndk-bundle GITHUB_STEP_SUMMARY=/home/runner/work/_temp/_runner_file_commands/step_summary_0aba811a-a04b-40a2-ba42-79efb2723e9e HOMEBREW_NO_AUTO_UPDATE=1 GITHUB_ACTOR=justheuristic NVM_DIR=/home/runner/.nvm SGX_AESM_ADDR=1 GITHUB_RUN_ATTEMPT=1 ANDROID_HOME=/usr/local/lib/android/sdk GITHUB_GRAPHQL_URL=https://api.github.com/graphql ACCEPT_EULA=Y RUNNER_USER=runner USER=runner GITHUB_SERVER_URL=https://github.com/ HOMEBREW_CELLAR=/home/linuxbrew/.linuxbrew/Cellar PIPX_HOME=/opt/pipx GECKOWEBDRIVER=/usr/local/share/gecko_driver CHROMEWEBDRIVER=/usr/local/share/chrome_driver SHLVL=0 ANDROID_SDK_ROOT=/usr/local/lib/android/sdk VCPKG_INSTALLATION_ROOT=/usr/local/share/vcpkg HOMEBREW_REPOSITORY=/home/linuxbrew/.linuxbrew/Homebrew RUNNER_TOOL_CACHE=/opt/hostedtoolcache ImageVersion=20220717.1 DOTNET_NOLOGO=1 GITHUB_REF_NAME=35/merge STATS_PFS=true GRAALVM_11_ROOT=/usr/local/graalvm/graalvm-ce-java11-22.1.0 GITHUB_JOB=convert-model LD_LIBRARY_PATH=/opt/hostedtoolcache/Python/3.9.13/x64/lib XDG_RUNTIME_DIR=/run/user/1001 AZURE_EXTENSION_DIR=/opt/az/azcliextensions PERFLOG_LOCATION_SETTING=RUNNER_PERFLOG GITHUB_REPOSITORY=bigscience-workshop/distributed-bloom ANDROID_NDK_ROOT=/usr/local/lib/android/sdk/ndk-bundle CHROME_BIN=/usr/bin/google-chrome GOROOT_1_18_X64=/opt/hostedtoolcache/go/1.18.4/x64 GITHUB_RETENTION_DAYS=90 JOURNAL_STREAM=8:23653 RUNNER_WORKSPACE=/home/runner/work/distributed-bloom LEIN_HOME=/usr/local/lib/lein LEIN_JAR=/usr/local/lib/lein/self-installs/leiningen-2.9.8-standalone.jar GITHUB_ACTION_REPOSITORY= PATH=/opt/hostedtoolcache/Python/3.9.13/x64/bin:/opt/hostedtoolcache/Python/3.9.13/x64:/home/linuxbrew/.linuxbrew/bin:/home/linuxbrew/.linuxbrew/sbin:/home/runner/.local/bin:/opt/pipx_bin:/home/runner/.cargo/bin:/home/runner/.config/composer/vendor/bin:/usr/local/.ghcup/bin:/home/runner/.dotnet/tools:/snap/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin RUNNER_PERFLOG=/home/runner/perflog GITHUB_BASE_REF=main CI=true SWIFT_PATH=/usr/share/swift/usr/bin ImageOS=ubuntu20 GITHUB_REPOSITORY_OWNER=bigscience-workshop GITHUB_HEAD_REF=fix-branch-name GITHUB_ACTION_REF= GITHUB_WORKFLOW=Tests DEBIAN_FRONTEND=noninteractive AGENT_TOOLSDIRECTORY=/opt/hostedtoolcache GOROOT_1_16_X64=/opt/hostedtoolcache/go/1.16.15/x64 _=/usr/bin/env </details> <details> <summary> Environment variables when building in main (on push) </summary> SELENIUM_JAR_PATH=/usr/share/java/selenium-server.jar GOROOT_1_17_X64=/opt/hostedtoolcache/go/1.17.11/x64 CONDA=/usr/share/miniconda GITHUB_WORKSPACE=/home/runner/work/distributed-bloom/distributed-bloom JAVA_HOME_11_X64=/usr/lib/jvm/temurin-11-jdk-amd64 GITHUB_PATH=/home/runner/work/_temp/_runner_file_commands/add_path_cd6c1ed2-0d0f-496d-b7a6-ffa476dcc144 GITHUB_ACTION=__run_2 JAVA_HOME=/usr/lib/jvm/temurin-11-jdk-amd64 GITHUB_RUN_NUMBER=53 RUNNER_NAME=GitHub Actions 3 GRADLE_HOME=/usr/share/gradle-7.4.2 XDG_CONFIG_HOME=/home/runner/.config DOTNET_SKIP_FIRST_TIME_EXPERIENCE=1 ANT_HOME=/usr/share/ant JAVA_HOME_8_X64=/usr/lib/jvm/temurin-8-jdk-amd64 HOMEBREW_PREFIX=/home/linuxbrew/.linuxbrew pythonLocation=/opt/hostedtoolcache/Python/3.9.13/x64 GITHUB_REF_TYPE=branch HOMEBREW_CLEANUP_PERIODIC_FULL_DAYS=3650 BOOTSTRAP_HASKELL_NONINTERACTIVE=1 *** PIPX_BIN_DIR=/opt/pipx_bin DEPLOYMENT_BASEPATH=/opt/runner GITHUB_ACTIONS=true ANDROID_NDK_LATEST_HOME=/usr/local/lib/android/sdk/ndk/24.0.8215888 GITHUB_SHA=49242d81006454d687ff3293c49f6bf234793627 POWERSHELL_DISTRIBUTION_CHANNEL=GitHub-Actions-ubuntu20 DOTNET_MULTILEVEL_LOOKUP=0 GITHUB_REF=refs/heads/main RUNNER_OS=Linux GITHUB_REF_PROTECTED=true HOME=/home/runner GITHUB_API_URL=https://api.github.com/ LANG=C.UTF-8 BLOOM_TESTING_WRITE_TOKEN=*** RUNNER_TRACKING_ID=github_7668f06a-99e1-4ed1-81e9-46d75fab3f33 STATS_KEEPALIVE=false RUNNER_ARCH=X64 RUNNER_TEMP=/home/runner/work/_temp EDGEWEBDRIVER=/usr/local/share/edge_driver GITHUB_ENV=/home/runner/work/_temp/_runner_file_commands/set_env_cd6c1ed2-0d0f-496d-b7a6-ffa476dcc144 GITHUB_EVENT_PATH=/home/runner/work/_temp/_github_workflow/event.json INVOCATION_ID=3dadac48981b4a679a33224db89be1ed GITHUB_EVENT_NAME=push GITHUB_RUN_ID=2680158280 JAVA_HOME_17_X64=/usr/lib/jvm/temurin-17-jdk-amd64 ANDROID_NDK_HOME=/usr/local/lib/android/sdk/ndk-bundle GITHUB_STEP_SUMMARY=/home/runner/work/_temp/_runner_file_commands/step_summary_cd6c1ed2-0d0f-496d-b7a6-ffa476dcc144 HOMEBREW_NO_AUTO_UPDATE=1 GITHUB_ACTOR=justheuristic NVM_DIR=/home/runner/.nvm SGX_AESM_ADDR=1 GITHUB_RUN_ATTEMPT=1 ANDROID_HOME=/usr/local/lib/android/sdk GITHUB_GRAPHQL_URL=https://api.github.com/graphql ACCEPT_EULA=Y RUNNER_USER=runner USER=runner GITHUB_SERVER_URL=https://github.com/ HOMEBREW_CELLAR=/home/linuxbrew/.linuxbrew/Cellar PIPX_HOME=/opt/pipx GECKOWEBDRIVER=/usr/local/share/gecko_driver CHROMEWEBDRIVER=/usr/local/share/chrome_driver SHLVL=0 ANDROID_SDK_ROOT=/usr/local/lib/android/sdk VCPKG_INSTALLATION_ROOT=/usr/local/share/vcpkg HOMEBREW_REPOSITORY=/home/linuxbrew/.linuxbrew/Homebrew RUNNER_TOOL_CACHE=/opt/hostedtoolcache ImageVersion=20220710.1 DOTNET_NOLOGO=1 GITHUB_REF_NAME=main STATS_PFS=true GRAALVM_11_ROOT=/usr/local/graalvm/graalvm-ce-java11-22.1.0 GITHUB_JOB=convert-model LD_LIBRARY_PATH=/opt/hostedtoolcache/Python/3.9.13/x64/lib XDG_RUNTIME_DIR=/run/user/1001 AZURE_EXTENSION_DIR=/opt/az/azcliextensions PERFLOG_LOCATION_SETTING=RUNNER_PERFLOG GITHUB_REPOSITORY=bigscience-workshop/distributed-bloom CHROME_BIN=/usr/bin/google-chrome ANDROID_NDK_ROOT=/usr/local/lib/android/sdk/ndk-bundle GOROOT_1_18_X64=/opt/hostedtoolcache/go/1.18.3/x64 GITHUB_RETENTION_DAYS=90 JOURNAL_STREAM=8:22000 RUNNER_WORKSPACE=/home/runner/work/distributed-bloom LEIN_HOME=/usr/local/lib/lein LEIN_JAR=/usr/local/lib/lein/self-installs/leiningen-2.9.8-standalone.jar GITHUB_ACTION_REPOSITORY= PATH=/opt/hostedtoolcache/Python/3.9.13/x64/bin:/opt/hostedtoolcache/Python/3.9.13/x64:/home/linuxbrew/.linuxbrew/bin:/home/linuxbrew/.linuxbrew/sbin:/home/runner/.local/bin:/opt/pipx_bin:/home/runner/.cargo/bin:/home/runner/.config/composer/vendor/bin:/usr/local/.ghcup/bin:/home/runner/.dotnet/tools:/snap/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin RUNNER_PERFLOG=/home/runner/perflog GITHUB_BASE_REF= CI=true SWIFT_PATH=/usr/share/swift/usr/bin ImageOS=ubuntu20 GITHUB_REPOSITORY_OWNER=bigscience-workshop GITHUB_HEAD_REF= GITHUB_ACTION_REF= GITHUB_WORKFLOW=Tests DEBIAN_FRONTEND=noninteractive AGENT_TOOLSDIRECTORY=/opt/hostedtoolcache GOROOT_1_16_X64=/opt/hostedtoolcache/go/1.16.15/x64 _=/usr/bin/env </details> Co-authored-by: Dmitry Baranchuk <dmitrybaranchuk@gmail.com>
This commit is contained in:
parent
7de3acf909
commit
f0cffbf67e
10
.github/workflows/run-tests.yaml
vendored
10
.github/workflows/run-tests.yaml
vendored
@ -28,12 +28,12 @@ jobs:
|
||||
pip install -r requirements.txt
|
||||
- name: Delete previous model, if exists
|
||||
run: |
|
||||
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
|
||||
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
|
||||
python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
|
||||
name='test-bloomd-350m-$HF_TAG', organization='bloom-testing')" || true
|
||||
- name: Convert model and push to hub
|
||||
run: |
|
||||
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
|
||||
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
|
||||
python -m cli.convert_model --model bigscience/bloom-350m --output_path ./converted_model \
|
||||
--output_repo bloom-testing/test-bloomd-350m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
|
||||
|
||||
@ -64,7 +64,7 @@ jobs:
|
||||
pip install -r requirements-dev.txt
|
||||
- name: Test
|
||||
run: |
|
||||
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_BASE_REF') or os.environ.get('GITHUB_REF_NAME'))")
|
||||
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
|
||||
export MODEL_NAME=bloom-testing/test-bloomd-350m-$HF_TAG
|
||||
export REF_NAME=bigscience/bloom-350m
|
||||
|
||||
@ -72,6 +72,8 @@ jobs:
|
||||
--torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
|
||||
SERVER1_PID=$!
|
||||
|
||||
sleep 5 # wait for the first server to initialize DHT
|
||||
|
||||
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
|
||||
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
|
||||
|
||||
@ -79,7 +81,7 @@ jobs:
|
||||
--torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
|
||||
SERVER2_PID=$!
|
||||
|
||||
sleep 30 # wait for server to download layers
|
||||
sleep 60 # wait for server to download layers
|
||||
|
||||
PYTHONPATH=. pytest tests
|
||||
|
||||
|
@ -101,7 +101,7 @@ def pre_process_alibi_for_pad(alibi: torch.Tensor, attention_mask: torch.Tensor)
|
||||
attention_mask: ([`torch.tensor`], *required*):
|
||||
attention mask to pre-process
|
||||
"""
|
||||
assert attention_mask.shape.ndim == 2, "mask should be [batch_size, seq_length]"
|
||||
assert attention_mask.ndim == 2, "mask should be [batch_size, seq_length]"
|
||||
unpadded_indices = torch.relu(attention_mask.cumsum(dim=1) - 1)
|
||||
# ^-- [batch, max_len], values correspond to element indices after removing padding
|
||||
# We shift the alibi tensor + replace all the values where attention_mask==0.0 by 0
|
||||
|
@ -13,13 +13,15 @@ logger = get_logger(__file__)
|
||||
@pytest.mark.forked
|
||||
def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
||||
tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
|
||||
model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
|
||||
model = DistributedBloomForCausalLM.from_pretrained(
|
||||
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
||||
)
|
||||
assert isinstance(model, DistributedBloomForCausalLM)
|
||||
assert len(model.transformer.h) == model.config.n_layer
|
||||
|
||||
test_inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
|
||||
|
||||
with torch.no_grad():
|
||||
with torch.inference_mode():
|
||||
parallel_outputs = model.forward(test_inputs).logits
|
||||
assert torch.all(torch.isfinite(parallel_outputs))
|
||||
logger.info("Forward outputs are finite")
|
||||
@ -32,21 +34,20 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
|
||||
recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
|
||||
recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
|
||||
recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
|
||||
|
||||
dictionary = model.transformer.word_embeddings.weight.t()
|
||||
recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
|
||||
recurrent_outputs = (recurrent_outputs @ dictionary).float()
|
||||
recurrent_outputs = model.lm_head(recurrent_outputs)
|
||||
assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
|
||||
logger.info("Inference is consistent with forward")
|
||||
|
||||
del model, recurrent_outputs
|
||||
del model, embs, recurrent_outputs
|
||||
|
||||
if REF_NAME:
|
||||
ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
|
||||
ref_model = transformers.BloomForCausalLM.from_pretrained(
|
||||
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
|
||||
)
|
||||
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
|
||||
# note: this creates a dummy mask to make the test compatible with older transformer versions
|
||||
# prior to https://github.com/huggingface/transformers/pull/17837
|
||||
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits
|
||||
ref_outputs = ref_model.forward(test_inputs, attention_mask=dummy_mask).logits.float()
|
||||
assert torch.allclose(ref_outputs, parallel_outputs, rtol=0, atol=atol_forward)
|
||||
logger.warning(f"Distributed forward is consistent with {type(ref_model)}.forward")
|
||||
del ref_model, ref_outputs, dummy_mask
|
||||
|
Loading…
Reference in New Issue
Block a user