From 825f5dbf2d7109338755f8a31bf679cd185b4d10 Mon Sep 17 00:00:00 2001 From: Alexander Borzunov Date: Fri, 13 Jan 2023 19:53:57 +0400 Subject: [PATCH] CI: Convert model only when convert_model.py or setup.cfg change (#213) This halves the test running time unless convert_model.py or setup.cfg is changed. --- .github/workflows/run-tests.yaml | 26 ++++++++++++++++++++------ src/petals/cli/convert_model.py | 6 +++++- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/.github/workflows/run-tests.yaml b/.github/workflows/run-tests.yaml index eb9c988..54614f3 100644 --- a/.github/workflows/run-tests.yaml +++ b/.github/workflows/run-tests.yaml @@ -12,32 +12,45 @@ jobs: BLOOM_TESTING_WRITE_TOKEN: ${{ secrets.BLOOM_TESTING_WRITE_TOKEN }} timeout-minutes: 15 steps: - - uses: actions/checkout@v2 + - name: Checkout + uses: actions/checkout@v2 + - name: Check if the model is cached + id: cache-model + uses: actions/cache@v2 + with: + path: ~/.dummy + key: model-v1-${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} - name: Set up Python + if: steps.cache-model.outputs.cache-hit != 'true' uses: actions/setup-python@v2 with: python-version: 3.9 - name: Cache dependencies + if: steps.cache-model.outputs.cache-hit != 'true' uses: actions/cache@v2 with: path: ~/.cache/pip key: Key-v1-3.9-${{ hashFiles('setup.cfg') }} - name: Install dependencies + if: steps.cache-model.outputs.cache-hit != 'true' run: | python -m pip install --upgrade pip - pip install .[dev] + pip install . 
- name: Delete any test models older than 1 week + if: steps.cache-model.outputs.cache-hit != 'true' run: | python tests/scripts/remove_old_models.py --author bloom-testing --use_auth_token $BLOOM_TESTING_WRITE_TOKEN - name: Delete previous version of this model, if exists + if: steps.cache-model.outputs.cache-hit != 'true' run: | export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") python -c "from huggingface_hub import delete_repo; delete_repo(token='$BLOOM_TESTING_WRITE_TOKEN', \ repo_id='bloom-testing/test-bloomd-560m-$HF_TAG')" || true - name: Convert model and push to hub + if: steps.cache-model.outputs.cache-hit != 'true' run: | - export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") - python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \ + export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} + python -m petals.cli.convert_model --model bigscience/bloom-560m --output_path ./converted_model \ --output_repo bloom-testing/test-bloomd-560m-$HF_TAG --use_auth_token $BLOOM_TESTING_WRITE_TOKEN \ --resize_token_embeddings 50000 @@ -50,7 +63,8 @@ jobs: fail-fast: false timeout-minutes: 15 steps: - - uses: actions/checkout@v2 + - name: Checkout + uses: actions/checkout@v2 - name: Set up Python uses: actions/setup-python@v2 with: @@ -66,7 +80,7 @@ jobs: pip install .[dev] - name: Test run: | - export HF_TAG=$(python -c "import os; print(os.environ.get('GITHUB_HEAD_REF') or os.environ.get('GITHUB_REF_NAME'))") + export HF_TAG=${{ hashFiles('setup.cfg', 'src/petals/cli/convert_model.py') }} export MODEL_NAME=bloom-testing/test-bloomd-560m-$HF_TAG export REF_NAME=bigscience/bloom-560m diff --git a/src/petals/cli/convert_model.py b/src/petals/cli/convert_model.py index c4746fd..289c764 100644 --- a/src/petals/cli/convert_model.py +++ b/src/petals/cli/convert_model.py @@ -18,7 
+18,7 @@ logger = get_logger(__file__) DTYPE_MAP = dict(bfloat16=torch.bfloat16, float16=torch.float16, float32=torch.float32, auto="auto") -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser(description="Load bloom layers and convert to 8-bit using torch quantization.") parser.add_argument("--model", type=str, default="bigscience/bloom-6b3", help="Model name for from_pretrained") @@ -90,3 +90,7 @@ if __name__ == "__main__": config.save_pretrained(".") logger.info(f"Converted {args.model} and pushed to {args.output_repo}") + + +if __name__ == "__main__": + main()