RWKV: do not propagate model_state between calls (#2565)

RWKV is an RNN, so it keeps a hidden state while it runs inference.
However, that state should not be carried over from one call to the
next; doing so is a bug.

This change resets the state at the start of every invocation.
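
To make the bug concrete, here is a minimal sketch of why a carried-over state makes repeated calls disagree. Everything in it is hypothetical: `ToyRNN`, `LeakyWrapper`, and `FreshWrapper` are illustrative stand-ins, not the real RWKV or LangChain API, and the "hidden state" is reduced to a running sum of tokens.

```python
from typing import List, Optional


class ToyRNN:
    """Hypothetical stand-in for a stateful model; not the real RWKV API."""

    def forward(self, tokens: List[int], state: Optional[int]) -> int:
        # The "hidden state" is just a running sum of every token seen so far.
        total = 0 if state is None else state
        for t in tokens:
            total += t
        return total


class LeakyWrapper:
    """Carries state across calls, as the code did before this commit."""

    def __init__(self) -> None:
        self.model = ToyRNN()
        self.model_state: Optional[int] = None

    def call(self, tokens: List[int]) -> int:
        state = self.model.forward(tokens, self.model_state)
        self.model_state = state  # state leaks into the next call
        return state


class FreshWrapper:
    """Starts from an empty state on every call, as the code does after this commit."""

    def __init__(self) -> None:
        self.model = ToyRNN()

    def call(self, tokens: List[int]) -> int:
        return self.model.forward(tokens, None)


leaky, fresh = LeakyWrapper(), FreshWrapper()
leaky_results = [leaky.call([1, 2, 3]), leaky.call([1, 2, 3])]
fresh_results = [fresh.call([1, 2, 3]), fresh.call([1, 2, 3])]
print(leaky_results)  # [6, 12]
print(fresh_results)  # [6, 6]
```

With the leaky wrapper, the same prompt gives a different result the second time because the first call's state is still in play; with a fresh state per call, identical prompts start from the same state.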
Alex Rad 1 year ago committed by GitHub
parent 7a4e1b72a8
commit 7bf5b0ccd3

@@ -67,8 +67,6 @@ class RWKV(LLM, BaseModel):
     pipeline: Any = None #: :meta private:
-    model_state: Any = None #: :meta private:
     model_tokens: Any = None #: :meta private:
     class Config:
@@ -145,7 +143,7 @@ class RWKV(LLM, BaseModel):
         tokens = self.tokenizer.encode(prompt).ids
         logits = None
-        state = self.model_state
+        state = None
         occurrence = {}
@@ -178,8 +176,6 @@ class RWKV(LLM, BaseModel):
                     + occurrence[n] * self.penalty_alpha_frequency
                 )
-        # Update state for future invocations
-        self.model_state = state
         return decoded
     def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
