diff --git a/application/llm/huggingface.py b/application/llm/huggingface.py
index e9cc47be..ef3b1fbc 100644
--- a/application/llm/huggingface.py
+++ b/application/llm/huggingface.py
@@ -2,13 +2,27 @@ from application.llm.base import BaseLLM
 
 class HuggingFaceLLM(BaseLLM):
 
-    def __init__(self, api_key, llm_name='Arc53/DocsGPT-7B'):
+    def __init__(self, api_key, llm_name='Arc53/DocsGPT-7B', q=False):
         global hf
-
+
         from langchain.llms import HuggingFacePipeline
-        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
-        tokenizer = AutoTokenizer.from_pretrained(llm_name)
-        model = AutoModelForCausalLM.from_pretrained(llm_name)
+        if q:
+            # Load the model with 4-bit NF4 quantization (bitsandbytes) to cut memory use
+            import torch
+            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
+            tokenizer = AutoTokenizer.from_pretrained(llm_name)
+            bnb_config = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_compute_dtype=torch.bfloat16
+            )
+            model = AutoModelForCausalLM.from_pretrained(llm_name, quantization_config=bnb_config)
+        else:
+            from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+            tokenizer = AutoTokenizer.from_pretrained(llm_name)
+            model = AutoModelForCausalLM.from_pretrained(llm_name)
+
         pipe = pipeline(
             "text-generation", model=model, tokenizer=tokenizer,
             max_new_tokens=2000,
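
A minimal usage sketch of the new flag, for trying the change locally. It assumes the repository's `application.llm.huggingface` module path, a CUDA-capable GPU, and the `bitsandbytes` and `accelerate` packages installed (transformers needs both to honor `quantization_config`); the `api_key` value here is a placeholder, since the local HuggingFace path does not use it.

```python
# Assumes: CUDA GPU available, bitsandbytes + accelerate installed.
from application.llm.huggingface import HuggingFaceLLM

# q=True loads the default Arc53/DocsGPT-7B with 4-bit NF4 weights,
# roughly a 4x smaller memory footprint than fp16.
llm = HuggingFaceLLM(api_key=None, q=True)

# q defaults to False, so existing callers keep full-precision loading.
```

Since `q` defaults to `False`, the quantized path is strictly opt-in and existing call sites are unaffected.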