python: connection resume and MSVC support (#1535)

This commit is contained in:
cebtenzzre 2023-10-19 12:06:38 -04:00 committed by GitHub
parent 017c3a9649
commit 5fbeeb1cb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 37 deletions

View File

@ -11,7 +11,9 @@ from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Union from typing import Any, Dict, Iterable, List, Optional, Union
import requests import requests
from requests.exceptions import ChunkedEncodingError
from tqdm import tqdm from tqdm import tqdm
from urllib3.exceptions import IncompleteRead, ProtocolError
from . import pyllmodel from . import pyllmodel
@ -217,35 +219,61 @@ class GPT4All:
download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\") download_path = os.path.join(model_path, model_filename).replace("\\", "\\\\")
download_url = get_download_url(model_filename) download_url = get_download_url(model_filename)
def make_request(offset=None):
    """Issue the (possibly resumed) streaming download request.

    Args:
        offset: Byte position to resume from, or None/0 for a fresh
            download. When set, a Range header is sent and the server
            MUST answer 206 with a matching Content-Range.

    Returns:
        The streaming ``requests.Response``.

    Raises:
        ValueError: On a non-2xx status, or when a resume was requested
            but the server did not honor the byte range.
    """
    headers = {}
    if offset:
        print(f"\nDownload interrupted, resuming from byte position {offset}", file=sys.stderr)
        headers['Range'] = f'bytes={offset}-'  # resume incomplete response
    response = requests.get(download_url, stream=True, headers=headers)
    if response.status_code not in (200, 206):
        raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
    # RFC 7233: a honored range reply is "206" with "Content-Range: bytes <offset>-<end>/<total>".
    # Check the prefix exactly — a substring test would false-match (e.g. offset 5 in "bytes 15-...").
    if offset and (
        response.status_code != 206
        or not response.headers.get('Content-Range', '').startswith(f'bytes {offset}-')
    ):
        raise ValueError('Connection was interrupted and server does not support range requests')
    return response
response = make_request()
total_size_in_bytes = int(response.headers.get("content-length", 0)) total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 2**20 # 1 MB block_size = 2**20 # 1 MB
with tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar: with open(download_path, "wb") as file, \
tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) as progress_bar:
try: try:
with open(download_path, "wb") as file: while True:
for data in response.iter_content(block_size): last_progress = progress_bar.n
progress_bar.update(len(data)) try:
file.write(data) for data in response.iter_content(block_size):
file.write(data)
progress_bar.update(len(data))
except ChunkedEncodingError as cee:
if cee.args and isinstance(pe := cee.args[0], ProtocolError):
if len(pe.args) >= 2 and isinstance(ir := pe.args[1], IncompleteRead):
assert progress_bar.n <= ir.partial # urllib3 may be ahead of us but never behind
# the socket was closed during a read - retry
response = make_request(progress_bar.n)
continue
raise
if total_size_in_bytes != 0 and progress_bar.n < total_size_in_bytes:
if progress_bar.n == last_progress:
raise RuntimeError('Download not making progress, aborting.')
# server closed connection prematurely - retry
response = make_request(progress_bar.n)
continue
break
except Exception: except Exception:
if os.path.exists(download_path): if verbose:
if verbose: print("Cleaning up the interrupted download...", file=sys.stderr)
print("Cleaning up the interrupted download...") try:
os.remove(download_path) os.remove(download_path)
except OSError:
pass
raise raise
# Validate download was successful if os.name == 'nt':
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: time.sleep(2) # Sleep for a little bit so Windows can remove file lock
raise RuntimeError("An error occurred during download. Downloaded file may not work.")
# Sleep for a little bit so OS can remove file lock
time.sleep(2)
if verbose: if verbose:
print("Model downloaded at: ", download_path) print("Model downloaded at:", download_path, file=sys.stderr)
return download_path return download_path
def generate( def generate(

View File

@ -23,28 +23,20 @@ MODEL_LIB_PATH = file_manager.enter_context(importlib.resources.as_file(
importlib.resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build", importlib.resources.files("gpt4all") / "llmodel_DO_NOT_MODIFY" / "build",
)) ))
def load_llmodel_library():
    """Load the llmodel shared library for the current platform.

    Tries the conventional ``libllmodel.<ext>`` filename first (Linux,
    macOS, and MinGW Windows builds); on Windows, falls back to the
    MSVC-style ``llmodel.dll`` (no ``lib`` prefix) if that is not found.

    Returns:
        ctypes.CDLL: Handle to the loaded llmodel library.

    Raises:
        RuntimeError: If the operating system is not supported.
        OSError: If no usable library file could be loaded.
    """
    extensions = {"Darwin": "dylib", "Linux": "so", "Windows": "dll"}
    system = platform.system()
    try:
        ext = extensions[system]
    except KeyError:
        # A bare KeyError here would be opaque to users on e.g. FreeBSD.
        raise RuntimeError(f"Operating system not supported: {system}") from None
    try:
        # Linux, macOS, and MinGW builds all use the 'lib' prefix.
        lib = ctypes.CDLL(str(MODEL_LIB_PATH / f"libllmodel.{ext}"))
    except FileNotFoundError:
        if ext != 'dll':
            raise
        # MSVC builds omit the 'lib' prefix.
        lib = ctypes.CDLL(str(MODEL_LIB_PATH / "llmodel.dll"))
    return lib
llmodel = load_llmodel_library() llmodel = load_llmodel_library()