diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp
index 3fa9f9f8..e1a2638f 100644
--- a/gpt4all-backend/llmodel_c.cpp
+++ b/gpt4all-backend/llmodel_c.cpp
@@ -123,7 +123,7 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
 }
 
 void llmodel_prompt(llmodel_model model, const char *prompt,
-                    llmodel_response_callback prompt_callback,
+                    llmodel_prompt_callback prompt_callback,
                     llmodel_response_callback response_callback,
                     llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx)
diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h
index 9a3c52a0..0b6972b7 100644
--- a/gpt4all-backend/llmodel_c.h
+++ b/gpt4all-backend/llmodel_c.h
@@ -162,7 +162,7 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src);
  * @param ctx A pointer to the llmodel_prompt_context structure.
  */
 void llmodel_prompt(llmodel_model model, const char *prompt,
-                    llmodel_response_callback prompt_callback,
+                    llmodel_prompt_callback prompt_callback,
                     llmodel_response_callback response_callback,
                     llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx);
diff --git a/gpt4all-bindings/python/gpt4all/pyllmodel.py b/gpt4all-bindings/python/gpt4all/pyllmodel.py
index b6737fd1..6117c9fa 100644
--- a/gpt4all-bindings/python/gpt4all/pyllmodel.py
+++ b/gpt4all-bindings/python/gpt4all/pyllmodel.py
@@ -84,12 +84,13 @@ class LLModelPromptContext(ctypes.Structure):
                 ("repeat_last_n", ctypes.c_int32),
                 ("context_erase", ctypes.c_float)]
 
+PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
 ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
 RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)
 
 llmodel.llmodel_prompt.argtypes = [ctypes.c_void_p,
                                    ctypes.c_char_p,
-                                   ResponseCallback,
+                                   PromptCallback,
                                    ResponseCallback,
                                    RecalculateCallback,
                                    ctypes.POINTER(LLModelPromptContext)]
@@ -218,7 +219,7 @@ class LLModel:
 
         llmodel.llmodel_prompt(self.model,
                                prompt,
-                               ResponseCallback(self._prompt_callback),
+                               PromptCallback(self._prompt_callback),
                                ResponseCallback(self._response_callback),
                                RecalculateCallback(self._recalculate_callback),
                                context)
@@ -232,7 +233,7 @@
 
     # Empty prompt callback
     @staticmethod
-    def _prompt_callback(token_id, response):
+    def _prompt_callback(token_id):
         return True
 
     # Empty response callback method that just prints response to be collected
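For context, here is a minimal, self-contained sketch of how the three corrected prototypes line up after this patch. The prototype definitions come straight from the diff; the handler names (`on_prompt`, `on_response`, `on_recalculate`) are hypothetical, and the direct Python-side calls at the end are just a convenient way to exercise the prototypes without loading the native library:

```python
import ctypes

# Callback prototypes as defined in the patch: the prompt callback
# receives only a token id, while the response callback also receives
# the generated token bytes.
PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32)
ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p)
RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool)

# Hypothetical handlers; only their arity and return types matter here.
def on_prompt(token_id):
    # Return True to keep processing the prompt.
    return True

def on_response(token_id, response):
    # `response` arrives as bytes (c_char_p); decode before printing.
    print(response.decode("utf-8"), end="")
    return True

def on_recalculate(is_recalculating):
    return is_recalculating

# Wrapping a handler in the wrong prototype is the mismatch this patch
# fixes: before, the one-argument prompt handler was wrapped in the
# two-argument ResponseCallback, so the native side and Python disagreed
# about the argument list being passed through the function pointer.
prompt_cb = PromptCallback(on_prompt)
response_cb = ResponseCallback(on_response)
recalc_cb = RecalculateCallback(on_recalculate)

# CFUNCTYPE objects are also callable from Python, which makes the
# arity of each prototype easy to check in isolation:
assert prompt_cb(42)            # one argument: token id only
assert response_cb(42, b"hi")   # two arguments: token id and bytes
```

Note that the ctypes callback objects are bound to names before being passed to the native call; per the ctypes documentation, they must be kept alive for as long as foreign code may invoke them.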