Merge pull request #425 from Archit-Kohli/patch-2

Update huggingface.py
This commit is contained in:
Alex 2023-10-05 11:03:27 +01:00 committed by GitHub
commit 4b629d20cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,13 +2,26 @@ from application.llm.base import BaseLLM
class HuggingFaceLLM(BaseLLM):
def __init__(self, api_key, llm_name='Arc53/DocsGPT-7B'):
def __init__(self, api_key, llm_name='Arc53/DocsGPT-7B',q=False):
global hf
from langchain.llms import HuggingFacePipeline
if q:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
tokenizer = AutoTokenizer.from_pretrained(llm_name)
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(llm_name,quantization_config=bnb_config)
else:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
tokenizer = AutoTokenizer.from_pretrained(llm_name)
model = AutoModelForCausalLM.from_pretrained(llm_name)
pipe = pipeline(
"text-generation", model=model,
tokenizer=tokenizer, max_new_tokens=2000,