Reduce vocabulary size in test model, fix bug in routing when servers overlap (#45)

This PR reduces the test model's vocabulary size to save memory during conversion, keeping only the first 50k tokens; a back-of-envelope estimate of the saving is sketched right after the list below.
As a result:

* tests that load client-side embeddings need significantly less RAM
* we can now run CI tests with 5 servers instead of 2, which is needed to test routing (and which uncovered the routing bug fixed here)
* some of the servers now use load balancing
* CI convert_model now takes 4-5 minutes (was 6-7)
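
For context, here is a hedged back-of-envelope estimate of why the resize helps, assuming the usual bloom-560m dimensions (hidden_size=1024, vocab_size=250880) and float32 weights during conversion:

```python
# Rough estimate of the word-embedding memory saved by the resize.
# Assumed bloom-560m shapes: hidden_size=1024, full vocab of 250880 tokens.
hidden_size = 1024
full_vocab, test_vocab = 250_880, 50_000
bytes_per_param = 4  # float32 during conversion


def embedding_gib(vocab: int) -> float:
    """Size of one word-embedding matrix in GiB."""
    return vocab * hidden_size * bytes_per_param / 2**30


print(f"full vocab:    {embedding_gib(full_vocab):.2f} GiB")  # ~0.96 GiB
print(f"reduced vocab: {embedding_gib(test_vocab):.2f} GiB")  # ~0.19 GiB
```

Since the client-side tests keep these embeddings (tied with the lm_head) in RAM, each test process saves roughly the difference between the two numbers.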

@@ -26,7 +26,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Delete previous model, if exists
- name: Delete any test models older than 1 week
run: |
python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
- name: Delete previous version of this model, if exists
run: |
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \
@@ -35,8 +38,8 @@ jobs:
run: |
export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))")
python -m cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \
--output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN
--output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \
--resize_token_embeddings 50000
run-tests:
runs-on: ubuntu-latest
@@ -66,6 +69,7 @@ jobs:
run: |
git clone https://github.com/TimDettmers/bitsandbytes.git
cd bitsandbytes
git checkout 4cd7ea62b2f51c68aacde2f62e7141765e476111
make cpuonly
pip install .
cd -
@@ -76,7 +80,8 @@ jobs:
export REF_NAME=bigscience/bloom-560m
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:12 \
--torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 --throughput 1 &
--torch_dtype float32 --identity tests/test.id --host_maddrs /ip4/127.0.0.1/tcp/31337 \
--throughput 1 &> server1.log &
SERVER1_PID=$!
sleep 5 # wait for the first server to initialize DHT
@@ -84,13 +89,33 @@ jobs:
export INITIAL_PEERS=/ip4/127.0.0.1/tcp/31337/p2p/QmS9KwZptnVdB9FFV7uGgaTq4sEKBwcYeKZDfSpyKDUd1g
# ^-- server 1 multiaddr is determined by --identity and --host_maddrs
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:24 \
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 12:22 \
--torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server2.log &
SERVER2_PID=$!
sleep 60 # wait for server to download layers
sleep 10 # wait for initial servers to declare blocks, then let server decide which blocks to serve
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 0:6 \
--torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server3.log &
SERVER3_PID=$!
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --block_indices 4:16 \
--torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server4.log &
SERVER4_PID=$!
python -m cli.run_server --converted_model_name_or_path $MODEL_NAME --num_blocks 3 \
--torch_dtype float32 --initial_peers $INITIAL_PEERS --throughput 1 &> server5.log &
SERVER5_PID=$!
tail -n 100 -f server*.log &
LOGGER_PID=$!
sleep 30 # wait for servers to download layers
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived init
PYTHONPATH=. pytest tests --durations=0 --durations-min=1.0 -v
kill -s SIGINT $SERVER1_PID $SERVER2_PID
kill -0 $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID # ensure all servers survived tests
kill -s SIGINT $SERVER1_PID $SERVER2_PID $SERVER3_PID $SERVER4_PID $SERVER5_PID $LOGGER_PID
echo "Done!"

@@ -35,6 +35,7 @@ if __name__ == "__main__":
"--commit_message", type=str, default="push-o-matic", help="Use this commit message for all parts"
)
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
parser.add_argument("--resize_token_embeddings", type=int, default=None, help="change the vocabulary size")
args = parser.parse_args()
free_ram_gb = psutil.virtual_memory().available / 2**30
@@ -56,6 +57,11 @@ if __name__ == "__main__":
model = BloomModel.from_pretrained(
args.model, use_auth_token=args.use_auth_token, revision=args.revision, torch_dtype=DTYPE_MAP[args.torch_dtype]
)
if args.resize_token_embeddings:
logger.info(f"Resizing token embeddings, new size = {args.resize_token_embeddings}")
model.resize_token_embeddings(args.resize_token_embeddings)
config.vocab_size = args.resize_token_embeddings
tokenizer = transformers.AutoTokenizer.from_pretrained(
args.model, use_auth_token=args.use_auth_token, revision=args.revision
)
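
A quick illustrative sketch (not the conversion code itself) of what the new --resize_token_embeddings flag amounts to when shrinking: only the first N rows of the word-embedding matrix are kept, which is what transformers' resize_token_embeddings does for the model (including the tied lm_head), while config.vocab_size is updated to match, as in the diff above. Toy dimensions are used so the snippet runs instantly:

```python
import torch

old_vocab, new_vocab, hidden = 12, 5, 4          # stand-ins for 250880, 50000, 1024
old_emb = torch.nn.Embedding(old_vocab, hidden)
new_emb = torch.nn.Embedding(new_vocab, hidden)
with torch.no_grad():
    # rows 0..new_vocab-1 survive, the remaining token embeddings are dropped
    new_emb.weight.copy_(old_emb.weight[:new_vocab])
print(old_emb.weight.shape, "->", new_emb.weight.shape)  # torch.Size([12, 4]) -> torch.Size([5, 4])
```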

@@ -54,7 +54,7 @@ class RemoteSequenceManager:
chosen_span = random.choice(candidate_spans) # TODO this should be replaced with proper load balancing
assert chosen_span.start <= current_index < chosen_span.end
span_sequence.append(chosen_span)
span_sequence.append(RemoteSpanInfo(start=current_index, end=chosen_span.end, peer_id=chosen_span.peer_id))
current_index = chosen_span.end
return span_sequence
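
The one-line change above is the routing fix from the PR title: when servers overlap, the randomly chosen span may start before current_index, and appending it unclipped recorded blocks that an earlier span in the sequence already covers. A standalone sketch of the fixed selection loop, with a hypothetical Span dataclass standing in for RemoteSpanInfo:

```python
import dataclasses
import random


@dataclasses.dataclass
class Span:  # stand-in for RemoteSpanInfo
    start: int
    end: int
    peer_id: str


def plan_route(spans, num_blocks):
    """Greedy span selection, clipping each chosen span to start at current_index."""
    sequence, current_index = [], 0
    while current_index < num_blocks:
        candidates = [s for s in spans if s.start <= current_index < s.end]
        chosen = random.choice(candidates)  # TODO: proper load balancing
        # the fix: record [current_index, chosen.end) instead of the raw span,
        # so the resulting sequence covers each block exactly once
        sequence.append(Span(start=current_index, end=chosen.end, peer_id=chosen.peer_id))
        current_index = chosen.end
    return sequence


# overlapping spans similar to the CI setup above
spans = [Span(0, 12, "A"), Span(12, 22, "B"), Span(0, 6, "C"), Span(4, 16, "D"), Span(16, 24, "E")]
print(plan_route(spans, num_blocks=24))
```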

@@ -0,0 +1,25 @@
import argparse
from datetime import datetime
from huggingface_hub import delete_repo, list_models
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Remove old testing models from HF hub")
parser.add_argument("--author", type=str, default="bloom-testing", help="auth token for from_pretrained")
parser.add_argument("--seconds_since_last_updated", type=int, default=7 * 24 * 60 * 60)
parser.add_argument("--use_auth_token", type=str, default=None, help="auth token for from_pretrained")
parser.add_argument("--dry_run", action="store_true")
args = parser.parse_args()
for model in list_models(author=args.author, full=True):
last_modified = datetime.strptime(model.lastModified, "%Y-%m-%dT%H:%M:%S.%fZ")
if model.modelId.endswith("-main") or "/test-" not in model.modelId:
continue # remove only test models
if (datetime.now() - last_modified).total_seconds() > args.seconds_since_last_updated:
if args.dry_run:
print(f"{model.modelId} can be deleted")
else:
delete_repo(token=args.use_auth_token, name=model.modelId, organization=args.author)

@@ -17,6 +17,7 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
model = DistributedBloomForCausalLM.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
config = model.config
assert isinstance(model, DistributedBloomForCausalLM)
assert len(model.transformer.h) == model.config.n_layer
@@ -45,6 +46,10 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
ref_model = transformers.BloomForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
if config.vocab_size < ref_model.config.vocab_size:
ref_model.resize_token_embeddings(config.vocab_size)
logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
# note: this creates a dummy mask to make the test compatible with older transformer versions
# prior to https://github.com/huggingface/transformers/pull/17837
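
For context on why the reference model is resized (a hedged note, not part of the test): the converted test model only produces logits over the first 50 000 tokens, so the reference logits must be restricted to the same vocabulary before the elementwise comparison. For inputs whose token ids stay below 50 000, shrinking the tied embeddings has the same effect as slicing the full-vocab logits, as the toy shapes below illustrate:

```python
import torch

# Toy tensors only (batch=1, seq=8); assumes the 50 000-token test vocabulary.
full_vocab, test_vocab = 250_880, 50_000
ref_logits_full = torch.randn(1, 8, full_vocab)          # unresized reference output
ref_logits_resized = ref_logits_full[..., :test_vocab]   # effect of resizing the tied lm_head
assert ref_logits_resized.shape == (1, 8, test_vocab)    # now comparable to the test model's logits
```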
