Do not ignore explicitly passed 4 threads

This commit is contained in:
Konstantin Gukov 2023-05-25 17:53:39 +02:00 committed by Richard Guo
parent dcbdd369ad
commit 100c809f1e

View File

@ -51,12 +51,12 @@ def repl(
n_threads: Annotated[ n_threads: Annotated[
int, int,
typer.Option("--n-threads", "-t", help="Number of threads to use for chatbot"), typer.Option("--n-threads", "-t", help="Number of threads to use for chatbot"),
] = 4, ] = None,
): ):
gpt4all_instance = GPT4All(model) gpt4all_instance = GPT4All(model)
# if threads are passed, set them # if threads are passed, set them
if n_threads != 4: if n_threads is not None:
num_threads = gpt4all_instance.model.thread_count() num_threads = gpt4all_instance.model.thread_count()
print(f"\nAdjusted: {num_threads}", end="") print(f"\nAdjusted: {num_threads}", end="")
@ -65,7 +65,8 @@ def repl(
num_threads = gpt4all_instance.model.thread_count() num_threads = gpt4all_instance.model.thread_count()
print(f" {num_threads} threads", end="", flush=True) print(f" {num_threads} threads", end="", flush=True)
else:
print(f"\nUsing {gpt4all_instance.model.thread_count()} threads", end="")
# overwrite _response_callback on model # overwrite _response_callback on model
gpt4all_instance.model._response_callback = _cli_override_response_callback gpt4all_instance.model._response_callback = _cli_override_response_callback