From a10f3aea5e2a26ce35195093a712b1addced7d6f Mon Sep 17 00:00:00 2001 From: Aaron Miller Date: Thu, 5 Oct 2023 12:09:14 -0700 Subject: [PATCH] python/embed4all: use gguf model, allow passing kwargs/overriding model --- gpt4all-bindings/python/gpt4all/gpt4all.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index 9aa5794a..c6d5c9ba 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -30,17 +30,14 @@ class Embed4All: Python class that handles embeddings for GPT4All. """ - def __init__( - self, - n_threads: Optional[int] = None, - ): + def __init__(self, model_name: Optional[str] = None, n_threads: Optional[int] = None, **kwargs): """ Constructor Args: n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically. """ - self.gpt4all = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin', n_threads=n_threads) + self.gpt4all = GPT4All(model_name or 'ggml-all-MiniLM-L6-v2-f16.gguf', n_threads=n_threads, **kwargs) def embed(self, text: str) -> List[float]: """ @@ -315,7 +312,6 @@ class GPT4All: callback: pyllmodel.ResponseCallbackType, output_collector: List[MessageType], ) -> pyllmodel.ResponseCallbackType: - def _callback(token_id: int, response: str) -> bool: nonlocal callback, output_collector