From 3ed6d176a5597bb41ec05ad2da2f4ad409f6533c Mon Sep 17 00:00:00 2001 From: 385olt <385olt@gmail.com> Date: Sun, 30 Jul 2023 20:29:51 +0200 Subject: [PATCH] Python bindings: unicode decoding (#1281) * rewrote the unicode decoding using the structure of multi-byte unicode symbols. --- gpt4all-bindings/python/gpt4all/pyllmodel.py | 49 ++++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py index 14f35626..7bdcb194 100644 --- a/gpt4all-bindings/python/gpt4all/pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py @@ -1,7 +1,7 @@ import ctypes import os import platform -import queue +from queue import Queue import re import subprocess import sys @@ -157,6 +157,9 @@ class LLModel: self.context = None self.llmodel_lib = llmodel + self.buffer = bytearray() + self.buff_expecting_cont_bytes: int = 0 + def __del__(self): if self.model is not None: self.llmodel_lib.llmodel_model_destroy(self.model) @@ -291,6 +294,9 @@ class LLModel: None """ + self.buffer.clear() + self.buff_expecting_cont_bytes = 0 + logger.info( "LLModel.prompt_model -- prompt:\n" + "%s\n" @@ -322,6 +328,7 @@ class LLModel: self.context, ) + def prompt_model_streaming( self, prompt: str, @@ -331,7 +338,7 @@ class LLModel: # Symbol to terminate from generator TERMINATING_SYMBOL = object() - output_queue = queue.Queue() + output_queue: Queue = Queue() # Put response tokens into an output queue def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType: @@ -371,8 +378,42 @@ class LLModel: def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType: def _raw_callback(token_id: int, response: bytes) -> bool: - nonlocal callback - return callback(token_id, response.decode("utf-8", "replace")) + nonlocal self, callback + + decoded = [] + + for byte in response: + + bits = "{:08b}".format(byte) + (high_ones, _, _) = bits.partition('0') + + if len(high_ones) == 1: + # continuation byte + self.buffer.append(byte) + self.buff_expecting_cont_bytes -= 1 + + else: + # beginning of a byte sequence + if len(self.buffer) > 0: + decoded.append(self.buffer.decode('utf-8', 'replace')) + + self.buffer.clear() + + self.buffer.append(byte) + self.buff_expecting_cont_bytes = max(0, len(high_ones) - 1) + + if self.buff_expecting_cont_bytes <= 0: + # received the whole sequence or an out of place continuation byte + decoded.append(self.buffer.decode('utf-8', 'replace')) + + self.buffer.clear() + self.buff_expecting_cont_bytes = 0 + + if len(decoded) == 0 and self.buff_expecting_cont_bytes > 0: + # wait for more continuation bytes + return True + + return callback(token_id, ''.join(decoded)) return _raw_callback