From a99cc34efb0bbdaf1a1ecc299bad8106a2731d31 Mon Sep 17 00:00:00 2001
From: Richard Guo
Date: Mon, 12 Jun 2023 22:38:50 -0400
Subject: [PATCH] fix prompt context so it's preserved in class

---
 gpt4all-bindings/python/gpt4all/pyllmodel.py | 64 +++++++++++---------
 1 file changed, 34 insertions(+), 30 deletions(-)

diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py
index 0b24ac86..820122c1 100644
--- a/gpt4all-bindings/python/gpt4all/pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py
@@ -125,6 +125,7 @@ class LLModel:
     def __init__(self):
         self.model = None
         self.model_name = None
+        self.context = None
 
     def __del__(self):
         if self.model is not None:
@@ -211,27 +212,29 @@ class LLModel:
 
         sys.stdout = stream_processor
 
-        context = LLModelPromptContext(
-            logits_size=logits_size,
-            tokens_size=tokens_size,
-            n_past=n_past,
-            n_ctx=n_ctx,
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase
-        )
+
+        if self.context is None:
+            self.context = LLModelPromptContext(
+                logits_size=logits_size,
+                tokens_size=tokens_size,
+                n_past=n_past,
+                n_ctx=n_ctx,
+                n_predict=n_predict,
+                top_k=top_k,
+                top_p=top_p,
+                temp=temp,
+                n_batch=n_batch,
+                repeat_penalty=repeat_penalty,
+                repeat_last_n=repeat_last_n,
+                context_erase=context_erase
+            )
 
         llmodel.llmodel_prompt(self.model,
                                prompt,
                                PromptCallback(self._prompt_callback),
                                ResponseCallback(self._response_callback),
                                RecalculateCallback(self._recalculate_callback),
-                               context)
+                               self.context)
 
         # Revert to old stdout
         sys.stdout = old_stdout
@@ -262,20 +265,21 @@ class LLModel:
         prompt = prompt.encode('utf-8')
         prompt = ctypes.c_char_p(prompt)
 
-        context = LLModelPromptContext(
-            logits_size=logits_size,
-            tokens_size=tokens_size,
-            n_past=n_past,
-            n_ctx=n_ctx,
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase
-        )
+        if self.context is None:
+            self.context = LLModelPromptContext(
+                logits_size=logits_size,
+                tokens_size=tokens_size,
+                n_past=n_past,
+                n_ctx=n_ctx,
+                n_predict=n_predict,
+                top_k=top_k,
+                top_p=top_p,
+                temp=temp,
+                n_batch=n_batch,
+                repeat_penalty=repeat_penalty,
+                repeat_last_n=repeat_last_n,
+                context_erase=context_erase
+            )
 
         # Put response tokens into an output queue
         def _generator_response_callback(token_id, response):
@@ -305,7 +309,7 @@ class LLModel:
                                         PromptCallback(self._prompt_callback),
                                         ResponseCallback(_generator_response_callback),
                                         RecalculateCallback(self._recalculate_callback),
-                                        context))
+                                        self.context))
         thread.start()
 
         # Generator
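
Note (not part of the patch): the effect of caching the LLModelPromptContext on the instance is that successive prompt calls against the same model object reuse n_past and the accumulated token state, so a later prompt can refer back to an earlier one. A minimal usage sketch against the high-level Python binding of that era follows; the model filename is illustrative and the exact generate() keyword arguments are an assumption, not taken from this patch.

    from gpt4all import GPT4All

    # Illustrative model name; any locally available model file behaves the same way.
    model = GPT4All("ggml-gpt4all-j-v1.3-groovy")

    # First call: self.context is None, so a fresh LLModelPromptContext is created
    # and stored on the underlying LLModel instance.
    first = model.generate("My name is Ada. Please remember that.")

    # Second call: self.context is reused, so n_past and the token history carry over
    # and the model can refer back to the earlier exchange.
    second = model.generate("What is my name?")
    print(second)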