black & isort

Cosmic Snow authored 1 year ago · committed by cosmic-snow
parent 19d6460282
commit e285ce91da
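
This commit applies the black formatter and the isort import sorter to the Python bindings; every hunk below is a mechanical reordering or re-wrapping with no behavior change. As a rough illustration, the same rewrites can be reproduced with the two tools' public Python APIs. The line length of 120 is an assumption inferred from the collapsed calls in this diff, not a setting shown here:

    import black
    import isort

    # Input resembling the pre-commit code: unsorted typing imports and a
    # signature wrapped across four lines.
    source = (
        "from typing import Any, Dict, Iterable, List, Union, Optional\n"
        "\n"
        "\n"
        "def embed(\n"
        "    self,\n"
        "    text: str\n"
        ") -> List[float]:\n"
        "    ...\n"
    )

    # isort alphabetizes the names within the from-import (Optional before Union).
    sorted_source = isort.code(source)

    # black collapses the signature onto one line once it fits the limit.
    # NOTE: line_length=120 is an assumption, not taken from this repo.
    formatted = black.format_str(sorted_source, mode=black.Mode(line_length=120))
    print(formatted)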

@@ -1,2 +1,2 @@
-from .gpt4all import GPT4All, Embed4All  # noqa
+from .gpt4all import Embed4All, GPT4All  # noqa
 from .pyllmodel import LLModel  # noqa

@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Union, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union

 import requests
 from tqdm import tqdm
@@ -13,9 +13,7 @@ from tqdm import tqdm
 from . import pyllmodel

 # TODO: move to config
-DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
-    "\\", "\\\\"
-)
+DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")

 DEFAULT_MODEL_CONFIG = {
     "systemPrompt": "",
@@ -25,10 +23,12 @@ DEFAULT_MODEL_CONFIG = {
 ConfigType = Dict[str, str]
 MessageType = Dict[str, str]

+
 class Embed4All:
     """
     Python class that handles embeddings for GPT4All.
     """
+
     def __init__(
         self,
         n_threads: Optional[int] = None,
@@ -41,10 +41,7 @@ class Embed4All:
         """
         self.gpt4all = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin', n_threads=n_threads)

-    def embed(
-        self,
-        text: str
-    ) -> List[float]:
+    def embed(self, text: str) -> List[float]:
         """
         Generate an embedding.

@@ -56,6 +53,7 @@
         """
         return self.gpt4all.model.generate_embedding(text)

+
 class GPT4All:
     """
     Python class that handles instantiation, downloading, generation and chat with GPT4All models.
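
For context, the Embed4All wrapper whose embed() signature is collapsed above is used roughly as follows. A usage sketch grounded in this diff: the model file name comes from the hunk above, and the 384-dimension note from the test_embedding assertion further down:

    from gpt4all import Embed4All

    # Embed4All wraps a GPT4All instance loaded with the MiniLM embedding
    # model and returns one vector per input text.
    embedder = Embed4All()  # fetches ggml-all-MiniLM-L6-v2-f16.bin if allowed
    vector = embedder.embed('The quick brown fox jumps over the lazy dog')
    print(len(vector))  # 384 for this model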
@@ -84,9 +82,7 @@ class GPT4All:
         self.model_type = model_type
         self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
-        self.config: ConfigType = self.retrieve_model(
-            model_name, model_path=model_path, allow_download=allow_download
-        )
+        self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
         self.model.load_model(self.config["path"])
         # Set n_threads
         if n_threads is not None:
@@ -170,9 +166,7 @@ class GPT4All:
         elif allow_download:
             url = config.pop("url", None)
-            config["path"] = GPT4All.download_model(
-                model_filename, model_path, verbose=verbose, url=url
-            )
+            config["path"] = GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
         else:
             raise ValueError("Failed to retrieve model")
@@ -281,19 +275,24 @@
             )

         if self._is_chat_session_activated:
-            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1  # check if there is only one message, i.e. system prompt
+            # check if there is only one message, i.e. system prompt:
+            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1
             self.current_chat_session.append({"role": "user", "content": prompt})
             prompt = self._format_chat_prompt_template(
                 messages=self.current_chat_session[-1:],
-                default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
+                default_prompt_header=self.current_chat_session[0]["content"]
+                if generate_kwargs["reset_context"]
+                else "",
             )
         else:
             generate_kwargs["reset_context"] = True

         # Prepare the callback, process the model response
         output_collector: List[MessageType]
-        output_collector = [{"content": ""}]  # placeholder for the self.current_chat_session if chat session is not activated
+        output_collector = [
+            {"content": ""}
+        ]  # placeholder for the self.current_chat_session if chat session is not activated

         if self._is_chat_session_activated:
             self.current_chat_session.append({"role": "assistant", "content": ""})
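
The reset_context bookkeeping above only applies inside a chat session: the session list is seeded with a single system message, so a length of 1 means no user turn has happened yet and the model context should be cleared. A minimal sketch, assuming the bindings' chat_session() context manager (not part of this diff) is what sets _is_chat_session_activated and seeds current_chat_session:

    from gpt4all import GPT4All

    model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')

    with model.chat_session():
        # Only the system message is in current_chat_session here, so the
        # first generate() runs with reset_context=True.
        first = model.generate('hello')
        # User/assistant turns have been appended; reset_context is False.
        second = model.generate('tell me more')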

@@ -1,4 +1,5 @@
 import ctypes
+import logging
 import os
 import platform
 from queue import Queue
@@ -6,8 +7,7 @@ import re
 import subprocess
 import sys
 import threading
-import logging
-from typing import Iterable, Callable, List
+from typing import Callable, Iterable, List

 import pkg_resources
@@ -16,9 +16,7 @@ logger: logging.Logger = logging.getLogger(__name__)
 # TODO: provide a config file to make this more robust
 LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
-MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace(
-    "\\", "\\\\"
-)
+MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")


 def load_llmodel_library():
@@ -113,9 +111,7 @@ llmodel.llmodel_embedding.argtypes = [
 llmodel.llmodel_embedding.restype = ctypes.POINTER(ctypes.c_float)

-llmodel.llmodel_free_embedding.argtypes = [
-    ctypes.POINTER(ctypes.c_float)
-]
+llmodel.llmodel_free_embedding.argtypes = [ctypes.POINTER(ctypes.c_float)]
 llmodel.llmodel_free_embedding.restype = None

 llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
@@ -251,10 +247,7 @@ class LLModel:
         self.context.repeat_last_n = repeat_last_n
         self.context.context_erase = context_erase

-    def generate_embedding(
-        self,
-        text: str
-    ) -> List[float]:
+    def generate_embedding(self, text: str) -> List[float]:
         if not text:
             raise ValueError("Text must not be None or empty")
@@ -330,10 +323,7 @@ class LLModel:
     def prompt_model_streaming(
-        self,
-        prompt: str,
-        callback: ResponseCallbackType = empty_response_callback,
-        **kwargs
+        self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
         # Symbol to terminate from generator
         TERMINATING_SYMBOL = object()
@@ -361,10 +351,7 @@ class LLModel:
             # immediately
             thread = threading.Thread(
                 target=run_llmodel_prompt,
-                args=(
-                    prompt,
-                    _generator_callback_wrapper(callback)
-                ),
+                args=(prompt, _generator_callback_wrapper(callback)),
                 kwargs=kwargs,
             )
             thread.start()
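
The hunk above sits inside prompt_model_streaming(), which runs the model on a worker thread and yields tokens as they arrive. A self-contained sketch of that queue-plus-sentinel generator pattern, with a stand-in producer in place of the real llmodel call:

    import threading
    from queue import Queue
    from typing import Callable, Iterable

    def stream(produce: Callable) -> Iterable[str]:
        TERMINATING_SYMBOL = object()  # sentinel marking end of output
        output_queue: Queue = Queue()

        def callback(token: str) -> bool:
            output_queue.put(token)
            return True  # True tells the producer to keep generating

        def run():
            produce(callback)
            output_queue.put(TERMINATING_SYMBOL)

        # Start producing immediately; the consumer drains the queue lazily.
        threading.Thread(target=run).start()
        while (item := output_queue.get()) is not TERMINATING_SYMBOL:
            yield item

    # Stand-in producer: invokes the callback once per token.
    for tok in stream(lambda cb: [cb(t) for t in ('Hello', ' ', 'world')]):
        print(tok, end='')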

@@ -1,8 +1,9 @@
 import sys
+import time
 from io import StringIO

-from gpt4all import GPT4All, Embed4All
-import time
+from gpt4all import Embed4All, GPT4All
+

 def time_embedding(i, embedder):
     text = 'foo bar ' * i
@@ -12,6 +13,7 @@ def time_embedding(i, embedder):
     elapsed_time = end_time - start_time
     print(f"Time report: {2 * i / elapsed_time} tokens/second with {2 * i} tokens taking {elapsed_time} seconds")

+
 if __name__ == "__main__":
     embedder = Embed4All(n_threads=8)
     for i in [2**n for n in range(6, 14)]:
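
The timing harness feeds 'foo bar ' * i into the embedder, so each run is roughly 2 * i tokens; that factor is where the tokens/second figure in the print above comes from. The loop body is cut off by the hunk boundary; presumably it simply calls the helper, along these lines (hypothetical completion):

    # Hypothetical completion of the truncated loop: time embeddings for
    # doubling input sizes, i = 64, 128, ..., 8192.
    for i in [2**n for n in range(6, 14)]:
        time_embedding(i, embedder)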

@@ -6,6 +6,7 @@ from gpt4all import GPT4All, Embed4All
 import time
 import pytest

+
 def test_inference():
     model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
     output_1 = model.generate('hello', top_k=1)
@@ -102,6 +103,7 @@ def test_inference_mpt():
     assert isinstance(output, str)
     assert len(output) > 0

+
 def test_embedding():
     text = 'The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox'
     embedder = Embed4All()
@@ -110,6 +112,7 @@ def test_embedding():
     #print(f'Value at index {i}: {value}')
     assert len(output) == 384

+
 def test_empty_embedding():
     text = ''
     embedder = Embed4All()
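
test_empty_embedding is likewise cut off by the hunk boundary. Since generate_embedding() raises ValueError for empty text (see the pyllmodel hunk above), a plausible completion, assuming the test asserts that failure:

    import pytest

    from gpt4all import Embed4All

    def test_empty_embedding():
        text = ''
        embedder = Embed4All()
        # generate_embedding() rejects empty input, so embed('') must raise.
        with pytest.raises(ValueError):
            embedder.embed(text)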
