python/embed4all: use gguf model, allow passing kwargs/overriding model

This commit is contained in:
Aaron Miller 2023-10-05 12:09:14 -07:00 committed by Adam Treat
parent 8bb6a6c201
commit a10f3aea5e

View File

@ -30,17 +30,14 @@ class Embed4All:
Python class that handles embeddings for GPT4All. Python class that handles embeddings for GPT4All.
""" """
def __init__( def __init__(self, model_name: Optional[str] = None, n_threads: Optional[int] = None, **kwargs):
self,
n_threads: Optional[int] = None,
):
""" """
Constructor Constructor
Args: Args:
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically. n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
""" """
self.gpt4all = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin', n_threads=n_threads) self.gpt4all = GPT4All(model_name or 'ggml-all-MiniLM-L6-v2-f16.gguf', n_threads=n_threads, **kwargs)
def embed(self, text: str) -> List[float]: def embed(self, text: str) -> List[float]:
""" """
@ -315,7 +312,6 @@ class GPT4All:
callback: pyllmodel.ResponseCallbackType, callback: pyllmodel.ResponseCallbackType,
output_collector: List[MessageType], output_collector: List[MessageType],
) -> pyllmodel.ResponseCallbackType: ) -> pyllmodel.ResponseCallbackType:
def _callback(token_id: int, response: str) -> bool: def _callback(token_id: int, response: str) -> bool:
nonlocal callback, output_collector nonlocal callback, output_collector