diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index c6d5c9ba..6821bcf4 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -1,6 +1,8 @@ """ Python only API for running all GPT4All models. """ +from __future__ import annotations + import os import sys import time @@ -60,7 +62,7 @@ class GPT4All: def __init__( self, model_name: str, - model_path: Optional[str] = None, + model_path: Optional[Union[str, os.PathLike[str]]] = None, model_type: Optional[str] = None, allow_download: bool = True, n_threads: Optional[int] = None, @@ -115,7 +117,7 @@ class GPT4All: @staticmethod def retrieve_model( model_name: str, - model_path: Optional[str] = None, + model_path: Optional[Union[str, os.PathLike[str]]] = None, allow_download: bool = True, verbose: bool = True, ) -> ConfigType: @@ -160,7 +162,7 @@ class GPT4All: ) model_path = DEFAULT_MODEL_DIRECTORY else: - model_path = model_path.replace("\\", "\\\\") + model_path = str(model_path).replace("\\", "\\\\") if not os.path.exists(model_path): raise ValueError(f"Invalid model directory: {model_path}") @@ -185,7 +187,7 @@ class GPT4All: @staticmethod def download_model( model_filename: str, - model_path: str, + model_path: Union[str, os.PathLike[str]], verbose: bool = True, url: Optional[str] = None, ) -> str: