RWKV: do not propagate model_state between calls (#2565)

RWKV is an RNN whose hidden state is part of inference.
However, that state should not be carried over from one call to the
next; doing so is a bug.

This change resets the state at the start of every invocation, so each call begins from a fresh state.
This commit is contained in:
Alex Rad 2023-04-08 08:36:16 -07:00 committed by GitHub
parent 7a4e1b72a8
commit 7bf5b0ccd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -67,8 +67,6 @@ class RWKV(LLM, BaseModel):
pipeline: Any = None #: :meta private: pipeline: Any = None #: :meta private:
model_state: Any = None #: :meta private:
model_tokens: Any = None #: :meta private: model_tokens: Any = None #: :meta private:
class Config: class Config:
@ -145,7 +143,7 @@ class RWKV(LLM, BaseModel):
tokens = self.tokenizer.encode(prompt).ids tokens = self.tokenizer.encode(prompt).ids
logits = None logits = None
state = self.model_state state = None
occurrence = {} occurrence = {}
@ -178,8 +176,6 @@ class RWKV(LLM, BaseModel):
+ occurrence[n] * self.penalty_alpha_frequency + occurrence[n] * self.penalty_alpha_frequency
) )
# Update state for future invocations
self.model_state = state
return decoded return decoded
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: