From 5fbeeb1cb46eec4fdea1fb77b1dc2624d84eacfc Mon Sep 17 00:00:00 2001 From: cebtenzzre Date: Thu, 19 Oct 2023 12:06:38 -0400 Subject: [PATCH] python: connection resume and MSVC support (#1535) --- gpt4all-bindings/python/gpt4all/gpt4all.py | 64 ++++++++++++++------ gpt4all-bindings/python/gpt4all/pyllmodel.py | 30 ++++----- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/gpt4all-bindings/python/gpt4all/gpt4all.py b/gpt4all-bindings/python/gpt4all/gpt4all.py index e02c20ae..5bd7541a 100644 --- a/gpt4all-bindings/python/gpt4all/gpt4all.py +++ b/gpt4all-bindings/python/gpt4all/gpt4all.py @@ -11,7 +11,9 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Union import requests +from requests.exceptions import ChunkedEncodingError from tqdm import tqdm +from urllib3.exceptions import IncompleteRead, ProtocolError from . import pyllmodel @@ -217,35 +219,61 @@ class GPT4All: download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\") download_url = get_download_url(model_filename) - response = requests.get(download_url, stream=True) - if response.status_code != 200: - raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}') + def make_request(offset=None): + headers = {} + if offset: + print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr) + headers['Range'] = f'bytes={offset}-' # resume incomplete response + response = requests.get(download_url, stream=True, headers=headers) + if response.status_code not in (200, 206): + raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}') + if offset and (response.status_code != 206 or str(offset) not in response.headers.get('Content-Range', '')): + raise ValueError('Connection was interrupted and server does not support range requests') + return response + + response = make_request() total_size_in_bytes = int(response.headers.get("content-length", 0)) block_size = 2**20 # 1 MB - with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar: + with open(download_path, "wb") as file, \ + tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar: try: - with open(download_path, "wb") as file: - for data in response.iter_content(block_size): - progress_bar.update(len(data)) - file.write(data) + while True: + last_progress = progress_bar.n + try: + for data in response.iter_content(block_size): + file.write(data) + progress_bar.update(len(data)) + except ChunkedEncodingError as cee: + if cee.args and isinstance(pe := cee.args[0], ProtocolError): + if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead): + assert progress_bar.n <= ir.partial # urllib3 may be ahead of us but never behind + # the socket was closed during a read - retry + response = make_request(progress_bar.n) + continue + raise + if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes: + if progress_bar.n == last_progress: + raise RuntimeError('Download not making progress, aborting.') + # server closed connection prematurely - retry + response = make_request(progress_bar.n) + continue + break except Exception: - if os.path.exists(download_path): - if verbose: - print("Cleaning up the interrupted download...") + if verbose: + print("Cleaning up the interrupted download...", file=sys.stderr) + try: os.remove(download_path) + except OSError: + pass raise - # Validate download was successful - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - raise RuntimeError("An error occurred during download. Downloaded file may not work.") - - # Sleep for a little bit so OS can remove file lock - time.sleep(2) + if os.name == 'nt': + time.sleep(2) # Sleep for a little bit so Windows can remove file lock if verbose: - print("Model downloaded at: ", download_path) + print("Model downloaded at:", download_path, file=sys.stderr) return download_path def generate( diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py index 1018f65d..47cc5160 100644 --- a/gpt4all-bindings/python/gpt4all/pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py @@ -23,28 +23,20 @@ MODEL_LIB_PATH = file_manager.enter_context(importlib.resources.as_file( importlib.resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build", )) + def load_llmodel_library(): - system = platform.system() + ext = {"Darwin": "dylib", "Linux": "so", "Windows": "dll"}[platform.system()] - def get_c_shared_lib_extension(): - if system == "Darwin": - return "dylib" - elif system == "Linux": - return "so" - elif system == "Windows": - return "dll" - else: - raise Exception("Operating System not supported") + try: + # Linux, Windows, MinGW + lib = ctypes.CDLL(str(MODEL_LIB_PATH / f"libllmodel.{ext}")) + except FileNotFoundError: + if ext != 'dll': + raise + # MSVC + lib = ctypes.CDLL(str(MODEL_LIB_PATH / "llmodel.dll")) - c_lib_ext = get_c_shared_lib_extension() - - llmodel_file = "libllmodel" + "." + c_lib_ext - - llmodel_dir = str(MODEL_LIB_PATH / llmodel_file).replace("\\", r"\\") - - llmodel_lib = ctypes.CDLL(llmodel_dir) - - return llmodel_lib + return lib llmodel = load_llmodel_library()