diff --git a/gpt4all-bindings/python/gpt4all/__init__.py b/gpt4all-bindings/python/gpt4all/__init__.py
index 71480e93..f4dfa4bf 100644
--- a/gpt4all-bindings/python/gpt4all/__init__.py
+++ b/gpt4all-bindings/python/gpt4all/__init__.py
@@ -1,2 +1,2 @@
-from .gpt4all import GPT4All, Embed4All  # noqa
+from .gpt4all import Embed4All, GPT4All  # noqa
 from .pyllmodel import LLModel  # noqa
diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py
index 62af9503..e82e6ae4 100644
--- a/gpt4all-bindings/python/gpt4all/gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/gpt4all.py
@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, Iterable, List, Union, Optional
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import requests
 from tqdm import tqdm
@@ -13,22 +13,22 @@ from tqdm import tqdm
 from . import pyllmodel
 
 # TODO: move to config
-DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
-    "\\", "\\\\"
-)
+DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
 
 DEFAULT_MODEL_CONFIG = {
     "systemPrompt": "",
     "promptTemplate": "### Human: \n{0}\n### Assistant:\n",
 }
 
-ConfigType = Dict[str,str]
+ConfigType = Dict[str, str]
 MessageType = Dict[str, str]
 
+
 class Embed4All:
     """
     Python class that handles embeddings for GPT4All.
     """
+
     def __init__(
         self,
         n_threads: Optional[int] = None,
@@ -41,10 +41,7 @@ class Embed4All:
         """
         self.gpt4all = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin', n_threads=n_threads)
 
-    def embed(
-        self,
-        text: str
-    ) -> List[float]:
+    def embed(self, text: str) -> List[float]:
         """
         Generate an embedding.
 
@@ -56,6 +53,7 @@ class Embed4All:
         """
         return self.gpt4all.model.generate_embedding(text)
 
+
 class GPT4All:
     """
     Python class that handles instantiation, downloading, generation and chat with GPT4All models.
@@ -84,9 +82,7 @@ class GPT4All:
         self.model_type = model_type
         self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
-        self.config: ConfigType = self.retrieve_model(
-            model_name, model_path=model_path, allow_download=allow_download
-        )
+        self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
         self.model.load_model(self.config["path"])
         # Set n_threads
         if n_threads is not None:
@@ -170,9 +166,7 @@ class GPT4All:
 
         elif allow_download:
             url = config.pop("url", None)
-            config["path"] = GPT4All.download_model(
-                model_filename, model_path, verbose=verbose, url=url
-            )
+            config["path"] = GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
         else:
             raise ValueError("Failed to retrieve model")
 
@@ -281,19 +275,24 @@ class GPT4All:
         )
 
         if self._is_chat_session_activated:
-            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1  # check if there is only one message, i.e. system prompt:
+            # check if there is only one message, i.e. system prompt:
+            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1
             self.current_chat_session.append({"role": "user", "content": prompt})
             prompt = self._format_chat_prompt_template(
-                messages = self.current_chat_session[-1:],
-                default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
+                messages=self.current_chat_session[-1:],
+                default_prompt_header=self.current_chat_session[0]["content"]
+                if generate_kwargs["reset_context"]
+                else "",
             )
         else:
             generate_kwargs["reset_context"] = True
 
         # Prepare the callback, process the model response
         output_collector: List[MessageType]
-        output_collector = [{"content": ""}]  # placeholder for the self.current_chat_session if chat session is not activated
+        output_collector = [
+            {"content": ""}
+        ]  # placeholder for the self.current_chat_session if chat session is not activated
 
         if self._is_chat_session_activated:
             self.current_chat_session.append({"role": "assistant", "content": ""})
 
diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py
index 7bdcb194..6326c493 100644
--- a/gpt4all-bindings/python/gpt4all/pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py
@@ -1,4 +1,5 @@
 import ctypes
+import logging
 import os
 import platform
 from queue import Queue
@@ -6,8 +7,7 @@ import re
 import subprocess
 import sys
 import threading
-import logging
-from typing import Iterable, Callable, List
+from typing import Callable, Iterable, List
 
 import pkg_resources
 
@@ -16,9 +16,7 @@ logger: logging.Logger = logging.getLogger(__name__)
 
 # TODO: provide a config file to make this more robust
 LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
-MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace(
-    "\\", "\\\\"
-)
+MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")
 
 
 def load_llmodel_library():
@@ -113,9 +111,7 @@ llmodel.llmodel_embedding.argtypes = [
 ]
 llmodel.llmodel_embedding.restype = ctypes.POINTER(ctypes.c_float)
 
-llmodel.llmodel_free_embedding.argtypes = [
-    ctypes.POINTER(ctypes.c_float)
-]
+llmodel.llmodel_free_embedding.argtypes = [ctypes.POINTER(ctypes.c_float)]
 llmodel.llmodel_free_embedding.restype = None
 
 llmodel.llmodel_setThreadCount.argtypes = [ctypes.c_void_p, ctypes.c_int32]
@@ -251,10 +247,7 @@ class LLModel:
         self.context.repeat_last_n = repeat_last_n
         self.context.context_erase = context_erase
 
-    def generate_embedding(
-        self,
-        text: str
-    ) -> List[float]:
+    def generate_embedding(self, text: str) -> List[float]:
         if not text:
             raise ValueError("Text must not be None or empty")
 
@@ -330,10 +323,7 @@ class LLModel:
 
     def prompt_model_streaming(
-        self,
-        prompt: str,
-        callback: ResponseCallbackType = empty_response_callback,
-        **kwargs
+        self, prompt: str, callback: ResponseCallbackType = empty_response_callback, **kwargs
     ) -> Iterable[str]:
         # Symbol to terminate from generator
         TERMINATING_SYMBOL = object()
 
@@ -361,10 +351,7 @@ class LLModel:
         # immediately
         thread = threading.Thread(
             target=run_llmodel_prompt,
-            args=(
-                prompt,
-                _generator_callback_wrapper(callback)
-            ),
+            args=(prompt, _generator_callback_wrapper(callback)),
             kwargs=kwargs,
         )
         thread.start()
diff --git a/gpt4all-bindings/python/gpt4all/tests/test_embed_timings.py b/gpt4all-bindings/python/gpt4all/tests/test_embed_timings.py
index 01b3f666..9121d4be 100644
--- a/gpt4all-bindings/python/gpt4all/tests/test_embed_timings.py
+++ b/gpt4all-bindings/python/gpt4all/tests/test_embed_timings.py
@@ -1,8 +1,9 @@
 import sys
+import time
 from io import StringIO
 
-from gpt4all import GPT4All, Embed4All
-import time
+from gpt4all import Embed4All, GPT4All
+
 
 def time_embedding(i, embedder):
     text = 'foo bar ' * i
@@ -12,6 +13,7 @@ def time_embedding(i, embedder):
     elapsed_time = end_time - start_time
     print(f"Time report: {2 * i / elapsed_time} tokens/second with {2 * i} tokens taking {elapsed_time} seconds")
 
+
 if __name__ == "__main__":
     embedder = Embed4All(n_threads=8)
     for i in [2**n for n in range(6, 14)]:
diff --git a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
index 89e81086..74a3214d 100644
--- a/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
+++ b/gpt4all-bindings/python/gpt4all/tests/test_gpt4all.py
@@ -6,6 +6,7 @@ from gpt4all import GPT4All, Embed4All
 import time
 import pytest
 
+
 def test_inference():
     model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')
     output_1 = model.generate('hello', top_k=1)
@@ -102,6 +103,7 @@ def test_inference_mpt():
     assert isinstance(output, str)
     assert len(output) > 0
 
+
 def test_embedding():
     text = 'The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox jumps over the lazy dog The quick brown fox'
     embedder = Embed4All()
@@ -110,6 +112,7 @@ def test_embedding():
         #print(f'Value at index {i}: {value}')
     assert len(output) == 384
 
+
 def test_empty_embedding():
     text = ''
     embedder = Embed4All()
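Reviewer note: the hunks above are formatting-only (black-style line reflow plus isort import ordering), so runtime behavior should be unchanged. For anyone who wants to smoke-test the embedding path they touch (`Embed4All.embed` delegating to `LLModel.generate_embedding`), here is a minimal sketch; it assumes the default `ggml-all-MiniLM-L6-v2-f16.bin` model is already cached or downloadable, and the 384-dimension assertion mirrors `test_embedding` above:

```python
from gpt4all import Embed4All

# Embed4All wraps a GPT4All instance that loads the MiniLM embedding model;
# n_threads is optional and falls back to the library default when omitted.
embedder = Embed4All(n_threads=4)

# embed() forwards to LLModel.generate_embedding and returns a list of floats.
vector = embedder.embed('The quick brown fox jumps over the lazy dog')

# The MiniLM-L6 model emits 384-dimensional vectors, as test_embedding asserts.
assert len(vector) == 384
print(vector[:8])
```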
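The one hunk that rearranges code rather than just reflowing it is the `reset_context` change in `gpt4all.py`: the trailing comment moves onto its own line, which makes it easier to see that the context is reset exactly when `current_chat_session` still holds only the system prompt. A sketch of how that plays out through the public API, assuming this version's `chat_session()` context manager (implied by the `contextmanager` import above) and the `orca-mini-3b.ggmlv3.q4_0.bin` model used by `test_inference`:

```python
from gpt4all import GPT4All

model = GPT4All(model_name='orca-mini-3b.ggmlv3.q4_0.bin')

# Inside chat_session(), _is_chat_session_activated is True and
# current_chat_session starts with a single system-prompt message,
# so the first generate() runs with reset_context=True; later calls
# see a longer history and keep the existing context.
with model.chat_session():
    first = model.generate('hello', top_k=1)          # resets the context
    second = model.generate('tell me more', top_k=1)  # reuses the context
    # system prompt + 2 user turns + 2 assistant turns = 5 messages
    print(len(model.current_chat_session))
```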