black & isort

Cosmic Snow authored 12 months ago, committed by cosmic-snow
parent 19d6460282
commit e285ce91da

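The hunks below are mechanical formatting changes: isort alphabetizes import statements and the names inside each `from` import, and black collapses statements onto one line where they fit and normalizes spacing. Several of the reflowed lines are longer than black's default 88 columns, so the project presumably configures a wider limit; the 120 in the sketch below is an assumption, not taken from this repository. A minimal sketch of reproducing this kind of change programmatically:

import black
import isort

source = "from typing import Any, Dict, Iterable, List, Union, Optional\n"

# isort sorts the imported names alphabetically within the statement
sorted_source = isort.code(source)

# black normalizes spacing and reflows lines up to the configured width;
# line_length=120 is an assumption, not taken from this repository
formatted = black.format_str(sorted_source, mode=black.Mode(line_length=120))
print(formatted, end="")  # from typing import Any, Dict, Iterable, List, Optional, Union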
@@ -1,2 +1,2 @@
-from .gpt4all import GPT4All, Embed4All  # noqa
+from .gpt4all import Embed4All, GPT4All  # noqa
 from .pyllmodel import LLModel  # noqa

@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Union, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import requests
 from tqdm import tqdm
@@ -13,22 +13,22 @@ from tqdm import tqdm
 from . import pyllmodel
 
 # TODO: move to config
-DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
-    "\\", "\\\\"
-)
+DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
 
 DEFAULT_MODEL_CONFIG = {
     "systemPrompt": "",
     "promptTemplate": "### Human: \n{0}\n### Assistant:\n",
 }
 
-ConfigType = Dict[str,str]
+ConfigType = Dict[str, str]
 MessageType = Dict[str, str]
 
 
 class Embed4All:
     """
     Python class that handles embeddings for GPT4All.
     """
 
     def __init__(
         self,
         n_threads: Optional[int] = None,
@@ -41,10 +41,7 @@ class Embed4All:
         """
         self.gpt4all = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin', n_threads=n_threads)
 
-    def embed(
-        self,
-        text: str
-    ) -> List[float]:
+    def embed(self, text: str) -> List[float]:
         """
         Generate an embedding.
@@ -56,6 +53,7 @@ class Embed4All:
         """
         return self.gpt4all.model.generate_embedding(text)
 
+
 class GPT4All:
     """
     Python class that handles instantiation, downloading, generation and chat with GPT4All models.
@@ -84,9 +82,7 @@ class GPT4All:
         self.model_type = model_type
         self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
-        self.config: ConfigType = self.retrieve_model(
-            model_name, model_path=model_path, allow_download=allow_download
-        )
+        self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
         self.model.load_model(self.config["path"])
 
         # Set n_threads
         if n_threads is not None:
@@ -170,9 +166,7 @@ class GPT4All:
         elif allow_download:
             url = config.pop("url", None)
 
-            config["path"] = GPT4All.download_model(
-                model_filename, model_path, verbose=verbose, url=url
-            )
+            config["path"] = GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
         else:
             raise ValueError("Failed to retrieve model")
@ -281,19 +275,24 @@ class GPT4All:
)
if self._is_chat_session_activated:
generate_kwargs["reset_context"] = len(self.current_chat_session) == 1 # check if there is only one message, i.e. system prompt
# check if there is only one message, i.e. system prompt:
generate_kwargs["reset_context"] = len(self.current_chat_session) == 1
self.current_chat_session.append({"role": "user", "content": prompt})
prompt = self._format_chat_prompt_template(
messages = self.current_chat_session[-1:],
default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
messages=self.current_chat_session[-1:],
default_prompt_header=self.current_chat_session[0]["content"]
if generate_kwargs["reset_context"]
else "",
)
else:
generate_kwargs["reset_context"] = True
# Prepare the callback, process the model response
output_collector: List[MessageType]
output_collector = [{"content": ""}] # placeholder for the self.current_chat_session if chat session is not activated
output_collector = [
{"content": ""}
] # placeholder for the self.current_chat_session if chat session is not activated
if self._is_chat_session_activated:
self.current_chat_session.append({"role": "assistant", "content": ""})

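As a usage sketch for the module reformatted above (the model names are the ones appearing in this repository's tests; the files are downloaded on first use when allow_download is enabled):

from gpt4all import Embed4All, GPT4All

# generation: model name taken from the tests further down
model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
print(model.generate('hello', top_k=1))

# embeddings: Embed4All wraps GPT4All with ggml-all-MiniLM-L6-v2-f16.bin
embedder = Embed4All()
vector = embedder.embed('The quick brown fox')
print(len(vector))  # 384, per test_embedding below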
@@ -1,4 +1,5 @@
 import ctypes
+import logging
 import os
 import platform
 from queue import Queue
@@ -6,8 +7,7 @@ import re
 import subprocess
 import sys
 import threading
-import logging
-from typing import Iterable, Callable, List
+from typing import Callable, Iterable, List
 
 import pkg_resources
@@ -16,9 +16,7 @@ logger: logging.Logger = logging.getLogger(__name__)
 
 # TODO: provide a config file to make this more robust
 LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
-MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace(
-    "\\", "\\\\"
-)
+MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")
 
 
 def load_llmodel_library():
@@ -113,9 +111,7 @@ llmodel.llmodel_embedding.argtypes = [
 llmodel.llmodel_embedding.restype = ctypes.POINTER(ctypes.c_float)
 
-llmodel.llmodel_free_embedding.argtypes = [
-    ctypes.POINTER(ctypes.c_float)
-]
+llmodel.llmodel_free_embedding.argtypes = [ctypes.POINTER(ctypes.c_float)]
 llmodel.llmodel_free_embedding.restype = None
 
 llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
@@ -251,10 +247,7 @@ class LLModel:
         self.context.repeat_last_n = repeat_last_n
         self.context.context_erase = context_erase
 
-    def generate_embedding(
-        self,
-        text: str
-    ) -> List[float]:
+    def generate_embedding(self, text: str) -> List[float]:
         if not text:
             raise ValueError("Text must not be None or empty")
@@ -330,10 +323,7 @@ class LLModel:
     def prompt_model_streaming(
-        self,
-        prompt: str,
-        callback: ResponseCallbackType = empty_response_callback,
-        **kwargs
+        self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
         # Symbol to terminate from generator
         TERMINATING_SYMBOL = object()
@@ -361,10 +351,7 @@ class LLModel:
         # immediately
         thread = threading.Thread(
             target=run_llmodel_prompt,
-            args=(
-                prompt,
-                _generator_callback_wrapper(callback)
-            ),
+            args=(prompt, _generator_callback_wrapper(callback)),
             kwargs=kwargs,
         )
         thread.start()

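The argtypes/restype declarations being reflowed above are what make the ctypes bindings safe to call: without an explicit restype, ctypes assumes a C int and would mangle the returned float pointer. A self-contained illustration of the same pattern against libm on a POSIX system (not part of this diff):

import ctypes
import ctypes.util

# load the shared library and declare the C signature of cos(),
# mirroring how pyllmodel declares signatures on the llmodel library
libm = ctypes.CDLL(ctypes.util.find_library("m"))
libm.cos.argtypes = [ctypes.c_double]
libm.cos.restype = ctypes.c_double

print(libm.cos(0.0))  # 1.0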
@@ -1,8 +1,9 @@
 import sys
+import time
 from io import StringIO
-from gpt4all import GPT4All, Embed4All
-import time
+
+from gpt4all import Embed4All, GPT4All
 
 
 def time_embedding(i, embedder):
     text = 'foo bar ' * i
@@ -12,6 +13,7 @@ def time_embedding(i, embedder):
     elapsed_time = end_time - start_time
     print(f"Time report: {2 * i / elapsed_time} tokens/second with {2 * i} tokens taking {elapsed_time} seconds")
 
+
 if __name__ == "__main__":
     embedder = Embed4All(n_threads=8)
     for i in [2**n for n in range(6, 14)]:

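Pieced together from the two hunks above, the benchmark script reads roughly as follows; the embed call inside time_embedding is not shown in the diff and is an assumption:

import sys
import time
from io import StringIO

from gpt4all import Embed4All, GPT4All


def time_embedding(i, embedder):
    text = 'foo bar ' * i  # roughly 2 * i tokens
    start_time = time.time()
    embedder.embed(text)   # assumed body; the diff omits these lines
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Time report: {2 * i / elapsed_time} tokens/second with {2 * i} tokens taking {elapsed_time} seconds")


if __name__ == "__main__":
    embedder = Embed4All(n_threads=8)
    for i in [2**n for n in range(6, 14)]:
        time_embedding(i, embedder)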
@@ -6,6 +6,7 @@ from gpt4all import GPT4All, Embed4All
 import time
 import pytest
 
+
 def test_inference():
     model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
     output_1 = model.generate('hello', top_k=1)
@@ -102,6 +103,7 @@ def test_inference_mpt():
     assert isinstance(output, str)
     assert len(output) > 0
 
+
 def test_embedding():
     text = 'The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox'
     embedder = Embed4All()
@@ -110,6 +112,7 @@ def test_embedding():
         #print(f'Value at index {i}: {value}')
     assert len(output) == 384
 
+
 def test_empty_embedding():
     text = ''
     embedder = Embed4All()
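The hunk above stops before the body of test_empty_embedding; given that generate_embedding raises ValueError for empty input (see the pyllmodel hunk earlier), a plausible continuation would assert exactly that. The pytest.raises form is an assumption:

import pytest

from gpt4all import Embed4All


def test_empty_embedding():
    text = ''
    embedder = Embed4All()
    # empty input is rejected in pyllmodel.generate_embedding
    with pytest.raises(ValueError):
        embedder.embed(text)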
