mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
RWKV: do not propagate model_state between calls (#2565)
RWKV is an RNN with a hidden state that is part of its inference. However, the model state should not be carried across uses and it's a bug to do so. This resets the state for multiple invocations
This commit is contained in:
parent
7a4e1b72a8
commit
7bf5b0ccd3
@ -67,8 +67,6 @@ class RWKV(LLM, BaseModel):
|
|||||||
|
|
||||||
pipeline: Any = None #: :meta private:
|
pipeline: Any = None #: :meta private:
|
||||||
|
|
||||||
model_state: Any = None #: :meta private:
|
|
||||||
|
|
||||||
model_tokens: Any = None #: :meta private:
|
model_tokens: Any = None #: :meta private:
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -145,7 +143,7 @@ class RWKV(LLM, BaseModel):
|
|||||||
tokens = self.tokenizer.encode(prompt).ids
|
tokens = self.tokenizer.encode(prompt).ids
|
||||||
|
|
||||||
logits = None
|
logits = None
|
||||||
state = self.model_state
|
state = None
|
||||||
|
|
||||||
occurrence = {}
|
occurrence = {}
|
||||||
|
|
||||||
@ -178,8 +176,6 @@ class RWKV(LLM, BaseModel):
|
|||||||
+ occurrence[n] * self.penalty_alpha_frequency
|
+ occurrence[n] * self.penalty_alpha_frequency
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update state for future invocations
|
|
||||||
self.model_state = state
|
|
||||||
return decoded
|
return decoded
|
||||||
|
|
||||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||||
|
Loading…
Reference in New Issue
Block a user