Python bindings: Custom callbacks, chat session improvement, refactoring (#1145)

* Added the following features (a usage sketch follows the change list below):
  1) prompt_model now uses the positional argument callback to return the response tokens.
  2) Because prompt_model takes a callback, prompt_model_streaming now only manages the queue and threading, which removes duplicated code.
  3) Added an optional verbose argument to prompt_model that prints the prompt passed to the model.
  4) Chat sessions can now have a header, i.e. an instruction placed before the transcript of the conversation. The header is set when the chat session context is created.
  5) The generate function now accepts an optional callback.
  6) When streaming inside a chat session, users no longer need to save the assistant's messages themselves; this is done automatically.

* added _empty_response_callback so I don't have to check if callback is None

* added docs

* now if the callback stops generation, the last token is ignored

* fixed type hints, reimplemented chat session header as a system prompt, minor refactoring, docs: removed section about manual update of chat session for streaming

* forgot to add some type hints!

* keep the model config in the GPT4All class; it is taken from models.json if downloads are allowed

* During chat sessions, the model-specific systemPrompt and promptTemplate are applied.

* implemented the changes

* Fixed typing. Now the user can set a prompt template that will be applied even outside of a chat session. The template can also have multiple placeholders that are filled by passing a dictionary to the generate function.

* reverted some changes concerning the prompt templates and their functionality

* fixed some type hints, changed list[float] to List[Float]

* fixed type hints, changed List[Float] to List[float]

* fix typo in the comment: Pepare => Prepare
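The sketch below ties the listed changes together: it is a minimal, illustrative example based on the new signatures in this diff. The model filename is reused from the existing docs; the system prompt, prompt template, and stop condition are placeholders rather than recommended values.

```python
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")

# A response callback receives (token_id, response) for every generated token
# and can stop generation early by returning False.
def stop_at_sentence_end(token_id: int, response: str) -> bool:
    return "." not in response

# The session header (system prompt) and the per-message template are set when
# the chat session context is created; leaving them empty falls back to the
# model-specific values taken from models.json.
with model.chat_session(
    system_prompt="You are a helpful assistant.",
    prompt_template="### Human: \n{0}\n### Assistant:\n",
):
    model.generate("Name a city in France.", callback=stop_at_sentence_end)
    # The assistant's reply is appended to the transcript automatically.
    print(model.current_chat_session)
```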

---------

Signed-off-by: 385olt <385olt@gmail.com>

@@ -91,22 +91,4 @@ To interact with GPT4All responses as the model generates, use the `streaming =
[' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0']
```
#### Streaming and Chat Sessions
When streaming tokens in a chat session, you must manually handle collection and updating of the chat history.
```python
from gpt4all import GPT4All
model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
with model.chat_session():
    tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
    tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
    print(model.current_chat_session)
```
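With this change the manual bookkeeping shown above is no longer needed: the callback wrapper added to `generate` appends the assistant's reply to `current_chat_session` as the tokens arrive. A minimal sketch of the new behavior, assuming the streaming generator is fully consumed:
```python
from gpt4all import GPT4All
model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
with model.chat_session():
    # Consuming the stream is enough; the assistant message is recorded automatically.
    for token in model.generate(prompt='hello', top_k=1, streaming=True):
        print(token, end='', flush=True)
    print()
    print(model.current_chat_session)  # now includes the assistant reply
```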
### API documentation
::: gpt4all.gpt4all.GPT4All

@@ -5,7 +5,7 @@ import os
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterable, List, Union, Optional
from typing import Any, Dict, Iterable, List, Union, Optional
import requests
from tqdm import tqdm
@@ -13,7 +13,17 @@ from tqdm import tqdm
from . import pyllmodel
# TODO: move to config
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
"\\", "\\\\"
)
DEFAULT_MODEL_CONFIG = {
"systemPrompt": "",
"promptTemplate": "### Human: \n{0}\n### Assistant:\n",
}
ConfigType = Dict[str,str]
MessageType = Dict[str, str]
class Embed4All:
"""
@@ -34,7 +44,7 @@ class Embed4All:
def embed(
self,
text: str
) -> list[float]:
) -> List[float]:
"""
Generate an embedding.
@@ -74,17 +84,20 @@ class GPT4All:
self.model_type = model_type
self.model = pyllmodel.LLModel()
# Retrieve model and download if allowed
model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
self.model.load_model(model_dest)
self.config: ConfigType = self.retrieve_model(
model_name, model_path=model_path, allow_download=allow_download
)
self.model.load_model(self.config["path"])
# Set n_threads
if n_threads is not None:
self.model.set_thread_count(n_threads)
self._is_chat_session_activated = False
self.current_chat_session = []
self._is_chat_session_activated: bool = False
self.current_chat_session: List[MessageType] = empty_chat_session()
self._current_prompt_template: str = "{0}"
@staticmethod
def list_models() -> Dict:
def list_models() -> List[ConfigType]:
"""
Fetch model list from https://gpt4all.io/models/models.json.
@@ -95,8 +108,11 @@ class GPT4All:
@staticmethod
def retrieve_model(
model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
) -> str:
model_name: str,
model_path: Optional[str] = None,
allow_download: bool = True,
verbose: bool = True,
) -> ConfigType:
"""
Find model file, and if it doesn't exist, download the model.
@@ -108,11 +124,25 @@ class GPT4All:
verbose: If True (default), print debug messages.
Returns:
Model file destination.
Model config.
"""
model_filename = append_bin_suffix_if_missing(model_name)
# get the config for the model
config: ConfigType = DEFAULT_MODEL_CONFIG
if allow_download:
available_models = GPT4All.list_models()
for m in available_models:
if model_filename == m["filename"]:
config.update(m)
config["systemPrompt"] = config["systemPrompt"].strip()
config["promptTemplate"] = config["promptTemplate"].replace(
"%1", "{0}", 1
) # change to Python-style formatting
break
# Validate download directory
if model_path is None:
try:
@@ -131,31 +161,34 @@ class GPT4All:
model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
if os.path.exists(model_dest):
config.pop("url", None)
config["path"] = model_dest
if verbose:
print("Found model file at ", model_dest)
return model_dest
# If model file does not exist, download
elif allow_download:
# Make sure valid model filename before attempting download
available_models = GPT4All.list_models()
selected_model = None
for m in available_models:
if model_filename == m['filename']:
selected_model = m
break
if selected_model is None:
if "url" not in config:
raise ValueError(f"Model filename not in model list: {model_filename}")
url = selected_model.pop('url', None)
url = config.pop("url", None)
return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
config["path"] = GPT4All.download_model(
model_filename, model_path, verbose=verbose, url=url
)
else:
raise ValueError("Failed to retrieve model")
return config
@staticmethod
def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
def download_model(
model_filename: str,
model_path: str,
verbose: bool = True,
url: Optional[str] = None,
) -> str:
"""
Download model from https://gpt4all.io.
@@ -191,7 +224,7 @@ class GPT4All:
except Exception:
if os.path.exists(download_path):
if verbose:
print('Cleaning up the interrupted download...')
print("Cleaning up the interrupted download...")
os.remove(download_path)
raise
@@ -218,7 +251,8 @@ class GPT4All:
n_batch: int = 8,
n_predict: Optional[int] = None,
streaming: bool = False,
) -> Union[str, Iterable]:
callback: pyllmodel.ResponseCallbackType = pyllmodel.empty_response_callback,
) -> Union[str, Iterable[str]]:
"""
Generate outputs from any GPT4All model.
@@ -233,12 +267,14 @@ class GPT4All:
n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
n_predict: Equivalent to max_tokens, exists for backwards compatibility.
streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False.
Returns:
Either the entire completion or a generator that yields the completion token by token.
"""
generate_kwargs = dict(
prompt=prompt,
# Preparing the model request
generate_kwargs: Dict[str, Any] = dict(
temp=temp,
top_k=top_k,
top_p=top_p,
@@ -249,42 +285,87 @@
)
if self._is_chat_session_activated:
generate_kwargs["reset_context"] = len(self.current_chat_session) == 1 # check if there is only one message, i.e. system prompt
self.current_chat_session.append({"role": "user", "content": prompt})
generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session[-1:])
generate_kwargs['reset_context'] = len(self.current_chat_session) == 1
else:
generate_kwargs['reset_context'] = True
if streaming:
return self.model.prompt_model_streaming(**generate_kwargs)
prompt = self._format_chat_prompt_template(
messages = self.current_chat_session[-1:],
default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
)
else:
generate_kwargs["reset_context"] = True
output = self.model.prompt_model(**generate_kwargs)
# Prepare the callback, process the model response
output_collector: List[MessageType]
output_collector = [{"content": ""}] # placeholder for the self.current_chat_session if chat session is not activated
if self._is_chat_session_activated:
self.current_chat_session.append({"role": "assistant", "content": output})
self.current_chat_session.append({"role": "assistant", "content": ""})
output_collector = self.current_chat_session
def _callback_wrapper(
callback: pyllmodel.ResponseCallbackType,
output_collector: List[MessageType],
) -> pyllmodel.ResponseCallbackType:
return output
def _callback(token_id: int, response: str) -> bool:
nonlocal callback, output_collector
output_collector[-1]["content"] += response
return callback(token_id, response)
return _callback
# Send the request to the model
if streaming:
return self.model.prompt_model_streaming(
prompt=prompt,
callback=_callback_wrapper(callback, output_collector),
**generate_kwargs,
)
self.model.prompt_model(
prompt=prompt,
callback=_callback_wrapper(callback, output_collector),
**generate_kwargs,
)
return output_collector[-1]["content"]
@contextmanager
def chat_session(self):
'''
def chat_session(
self,
system_prompt: str = "",
prompt_template: str = "",
):
"""
Context manager to hold an inference optimized chat session with a GPT4All model.
'''
Args:
system_prompt: An initial instruction for the model.
prompt_template: Template for the prompts with {0} being replaced by the user message.
"""
# Code to acquire resource, e.g.:
self._is_chat_session_activated = True
self.current_chat_session = []
self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"])
self._current_prompt_template = prompt_template or self.config["promptTemplate"]
try:
yield self
finally:
# Code to release resource, e.g.:
self._is_chat_session_activated = False
self.current_chat_session = []
self.current_chat_session = empty_chat_session()
self._current_prompt_template = "{0}"
def _format_chat_prompt_template(
self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True
self,
messages: List[MessageType],
default_prompt_header: str = "",
default_prompt_footer: str = "",
) -> str:
"""
Helper method for building a prompt using template from list of messages.
Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.
Args:
messages: List of dictionaries. Each dictionary should have a "role" key
@@ -296,19 +377,44 @@ class GPT4All:
Returns:
Formatted prompt.
"""
full_prompt = ""
if isinstance(default_prompt_header, bool):
import warnings
warnings.warn(
"Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
DeprecationWarning,
)
default_prompt_header = ""
if isinstance(default_prompt_footer, bool):
import warnings
warnings.warn(
"Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
DeprecationWarning,
)
default_prompt_footer = ""
full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
for message in messages:
if message["role"] == "user":
user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n"
user_message = self._current_prompt_template.format(message["content"])
full_prompt += user_message
if message["role"] == "assistant":
assistant_message = message["content"] + '\n'
assistant_message = message["content"] + "\n"
full_prompt += assistant_message
full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
return full_prompt
def empty_chat_session(system_prompt: str = "") -> List[MessageType]:
return [{"role": "system", "content": system_prompt}]
def append_bin_suffix_if_missing(model_name):
if not model_name.endswith(".bin"):
model_name += ".bin"

@@ -6,26 +6,19 @@ import re
import subprocess
import sys
import threading
from typing import Iterable
import logging
from typing import Iterable, Callable, List
import pkg_resources
class DualStreamProcessor:
def __init__(self, stream=None):
self.stream = stream
self.output = ""
def write(self, text):
if self.stream is not None:
self.stream.write(text)
self.stream.flush()
self.output += text
logger: logging.Logger = logging.getLogger(__name__)
# TODO: provide a config file to make this more robust
LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")
MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace(
"\\", "\\\\"
)
def load_llmodel_library():
@@ -43,9 +36,9 @@ def load_llmodel_library():
c_lib_ext = get_c_shared_lib_extension()
llmodel_file = "libllmodel" + '.' + c_lib_ext
llmodel_file = "libllmodel" + "." + c_lib_ext
llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace(
llmodel_dir = str(pkg_resources.resource_filename("gpt4all", os.path.join(LLMODEL_PATH, llmodel_file))).replace(
"\\", "\\\\"
)
@@ -134,7 +127,15 @@ llmodel.llmodel_set_implementation_search_path.restype = None
llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
llmodel.llmodel_threadCount.restype = ctypes.c_int32
llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode('utf-8'))
llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode("utf-8"))
ResponseCallbackType = Callable[[int, str], bool]
RawResponseCallbackType = Callable[[int, bytes], bool]
def empty_response_callback(token_id: int, response: str) -> bool:
return True
class LLModel:
@@ -250,9 +251,10 @@ class LLModel:
def generate_embedding(
self,
text: str
) -> list[float]:
) -> List[float]:
if not text:
raise ValueError("Text must not be None or empty")
embedding_size = ctypes.c_size_t()
c_text = ctypes.c_char_p(text.encode('utf-8'))
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
@@ -263,6 +265,7 @@ class LLModel:
def prompt_model(
self,
prompt: str,
callback: ResponseCallbackType,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
@@ -272,8 +275,7 @@ class LLModel:
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
streaming=False,
) -> str:
):
"""
Generate response from model from a prompt.
@@ -281,25 +283,23 @@ class LLModel:
----------
prompt: str
Question, task, or conversation for model to respond to
streaming: bool
Stream response to stdout
callback(token_id:int, response:str): bool
The model sends response tokens to callback
Returns
-------
Model response str
None
"""
prompt_bytes = prompt.encode('utf-8')
prompt_ptr = ctypes.c_char_p(prompt_bytes)
old_stdout = sys.stdout
stream_processor = DualStreamProcessor()
if streaming:
stream_processor.stream = sys.stdout
logger.info(
"LLModel.prompt_model -- prompt:\n"
+ "%s\n"
+ "===/LLModel.prompt_model -- prompt/===",
prompt,
)
sys.stdout = stream_processor
prompt_bytes = prompt.encode("utf-8")
prompt_ptr = ctypes.c_char_p(prompt_bytes)
self._set_context(
n_predict=n_predict,
@@ -317,56 +317,37 @@ class LLModel:
self.model,
prompt_ptr,
PromptCallback(self._prompt_callback),
ResponseCallback(self._response_callback),
ResponseCallback(self._callback_decoder(callback)),
RecalculateCallback(self._recalculate_callback),
self.context,
)
# Revert to old stdout
sys.stdout = old_stdout
# Force new line
return stream_processor.output
def prompt_model_streaming(
self,
prompt: str,
n_predict: int = 4096,
top_k: int = 40,
top_p: float = 0.9,
temp: float = 0.1,
n_batch: int = 8,
repeat_penalty: float = 1.2,
repeat_last_n: int = 10,
context_erase: float = 0.75,
reset_context: bool = False,
) -> Iterable:
callback: ResponseCallbackType = empty_response_callback,
**kwargs
) -> Iterable[str]:
# Symbol to terminate from generator
TERMINATING_SYMBOL = object()
output_queue = queue.Queue()
prompt_bytes = prompt.encode('utf-8')
prompt_ptr = ctypes.c_char_p(prompt_bytes)
# Put response tokens into an output queue
def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
def _generator_callback(token_id: int, response: str):
nonlocal callback
self._set_context(
n_predict=n_predict,
top_k=top_k,
top_p=top_p,
temp=temp,
n_batch=n_batch,
repeat_penalty=repeat_penalty,
repeat_last_n=repeat_last_n,
context_erase=context_erase,
reset_context=reset_context,
)
if callback(token_id, response):
output_queue.put(response)
return True
# Put response tokens into an output queue
def _generator_response_callback(token_id, response):
output_queue.put(response.decode('utf-8', 'replace'))
return True
return False
return _generator_callback
def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context):
llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context)
def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
self.prompt_model(prompt, callback, **kwargs)
output_queue.put(TERMINATING_SYMBOL)
# Kick off llmodel_prompt in separate thread so we can return generator
@@ -374,13 +355,10 @@ class LLModel:
thread = threading.Thread(
target=run_llmodel_prompt,
args=(
self.model,
prompt_ptr,
PromptCallback(self._prompt_callback),
ResponseCallback(_generator_response_callback),
RecalculateCallback(self._recalculate_callback),
self.context,
prompt,
_generator_callback_wrapper(callback)
),
kwargs=kwargs,
)
thread.start()
@@ -391,18 +369,19 @@ class LLModel:
break
yield response
# Empty prompt callback
@staticmethod
def _prompt_callback(token_id):
return True
def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType:
def _raw_callback(token_id: int, response: bytes) -> bool:
nonlocal callback
return callback(token_id, response.decode("utf-8", "replace"))
# Empty response callback method that just prints response to be collected
return _raw_callback
# Empty prompt callback
@staticmethod
def _response_callback(token_id, response):
sys.stdout.write(response.decode('utf-8', 'replace'))
def _prompt_callback(token_id: int) -> bool:
return True
# Empty recalculate callback
@staticmethod
def _recalculate_callback(is_recalculating):
def _recalculate_callback(is_recalculating: bool) -> bool:
return is_recalculating
