From 7bf5b0ccd36a72395ac16ebafdfb3453d57c6e9d Mon Sep 17 00:00:00 2001 From: Alex Rad <37549748+lts-rad@users.noreply.github.com> Date: Sat, 8 Apr 2023 08:36:16 -0700 Subject: [PATCH] RWKV: do not propagate model_state between calls (#2565) RWKV is an RNN with a hidden state that is part of its inference. However, the model state should not be carried across uses and it's a bug to do so. This resets the state for multiple invocations --- langchain/llms/rwkv.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/langchain/llms/rwkv.py b/langchain/llms/rwkv.py index ee31e31db8..f4294f1c53 100644 --- a/langchain/llms/rwkv.py +++ b/langchain/llms/rwkv.py @@ -67,8 +67,6 @@ class RWKV(LLM, BaseModel): pipeline: Any = None #: :meta private: - model_state: Any = None #: :meta private: - model_tokens: Any = None #: :meta private: class Config: @@ -145,7 +143,7 @@ class RWKV(LLM, BaseModel): tokens = self.tokenizer.encode(prompt).ids logits = None - state = self.model_state + state = None occurrence = {} @@ -178,8 +176,6 @@ class RWKV(LLM, BaseModel): + occurrence[n] * self.penalty_alpha_frequency ) - # Update state for future invocations - self.model_state = state return decoded def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: