#ifndef GPTJ_H #define GPTJ_H #include #include #include #include "llmodel.h" #include "tokenizer/bpe.h" class GPTJPrivate; class GPTJ : public LLModel { public: GPTJ(); ~GPTJ(); bool loadModel(const std::string &modelPath) override; bool isModelLoaded() const override; size_t stateSize() const override; size_t saveState(uint8_t *dest) const override; size_t restoreState(const uint8_t *src) override; void prompt(const std::string &prompt, std::function promptCallback, std::function responseCallback, std::function recalculateCallback, PromptContext &ctx) override; void setThreadCount(int32_t n_threads) override; int32_t threadCount() const override; protected: void recalculateContext(PromptContext &promptCtx, std::function recalculate) override; private: GPTJPrivate *d_ptr; std::unique_ptr m_tokav; std::unique_ptr m_bpe; }; #endif // GPTJ_H