Remove hotfix for ExLlamaV2

Issue is resolved upstream
pull/39/head
Atinoda 8 months ago
parent e704efd2fa
commit 123618d6bf

Dockerfile
@@ -40,10 +40,8 @@ RUN git clone https://github.com/oobabooga/GPTQ-for-LLaMa.git -b cuda /app/repos
 # Build and install default GPTQ ('quant_cuda')
 ARG TORCH_CUDA_ARCH_LIST="6.1;7.0;7.5;8.0;8.6+PTX"
 RUN cd /app/repositories/GPTQ-for-LLaMa/ && python3 setup_cuda.py install
+# Install exllamav2 and flash attention
+RUN pip install -U ninja exllamav2 && pip install flash-attn --no-build-isolation
-# TEMPORARY HOTFIX FOR EXLLAMAV2:
-RUN . /scripts/exllama_version_fix.sh
-# Install flash attention for exllamav2
-RUN pip install flash-attn --no-build-isolation
 FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 AS base
 # Runtime pre-reqs

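The combined RUN line installs ninja first (flash-attn compiles much faster with it available) and builds flash-attn with --no-build-isolation so the build can see the torch already present in the image. As an illustrative post-build sanity check, not part of this commit, something like the following confirms both packages landed; importlib.metadata.version works for any pip-installed distribution:

# Illustrative sanity check (not part of this commit): confirm both
# packages are importable and print their installed versions.
from importlib.metadata import version

import exllamav2  # import succeeds only if the wheel installed cleanly
import flash_attn

print("exllamav2:", version("exllamav2"))
print("flash_attn:", flash_attn.__version__)  # the attribute exllamav2/attn.py reads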
scripts/exllama_version_fix.patch
@@ -1,23 +0,0 @@
From ec5164b8a8e282b91aedb2af94dfeb89887656b7 Mon Sep 17 00:00:00 2001
From: turboderp <11859846+turboderp@users.noreply.github.com>
Date: Tue, 19 Sep 2023 11:56:52 +0200
Subject: [PATCH] Fix for non-numerical version number in flash-attn

---
 exllamav2/attn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/exllamav2/attn.py b/exllamav2/attn.py
index b8b1752..09ef6d6 100644
--- a/exllamav2/attn.py
+++ b/exllamav2/attn.py
@@ -16,7 +16,7 @@
 has_flash_attn = False
 try:
     import flash_attn
-    flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".")]
+    flash_attn_ver = [int(t) for t in flash_attn.__version__.split(".") if t.isdigit()]
     if flash_attn_ver >= [2, 2, 1]:
         from flash_attn import flash_attn_func
         has_flash_attn = True
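The deleted patch worked around a real parsing bug: flash-attn publishes post-release builds whose version strings contain non-numeric segments, which int() cannot parse. A minimal sketch of the failure and of the upstream fix (the version string below is an illustrative post-release tag, not a pinned release):

# Sketch of the bug the patch addressed; "2.2.1.post1" is an illustrative
# post-release tag, not a specific pinned version.
ver = "2.2.1.post1"

try:
    # Original parsing: int("post1") raises ValueError, so flash-attn
    # detection could fail even with a working flash-attn install.
    [int(t) for t in ver.split(".")]
except ValueError as exc:
    print("old parsing fails:", exc)

# Patched parsing: skip non-numeric segments before converting.
flash_attn_ver = [int(t) for t in ver.split(".") if t.isdigit()]
print(flash_attn_ver, flash_attn_ver >= [2, 2, 1])  # [2, 2, 1] True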

scripts/exllama_version_fix.sh
@@ -1,16 +0,0 @@
#!/bin/bash
set -x

# Get current directory
cur_dir=$(pwd)
exlv2_dir="/venv/lib/python3.10/site-packages/exllamav2/"

# Go to exl directory
cd $exlv2_dir

# Apply patch
patch -i /scripts/exllama_version_fix.patch

# Go back to the original directory
cd $cur_dir
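The script hardcoded the python3.10 site-packages path, which tied it to this particular image. A sketch of how the same directory could be resolved instead of hardcoded (assumes exllamav2 is importable in the target environment):

# Sketch: resolve the installed package directory instead of hardcoding
# /venv/lib/python3.10/site-packages/exllamav2/.
import os
import exllamav2

print(os.path.dirname(exllamav2.__file__))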