langchain/libs/community/langchain_community/llms/rwkv.py
Eugene Yurtsev bf5193bb99
community[patch]: Upgrade pydantic extra (#25185)
Upgrade to using a string literal to specify `extra`, which is the
recommended approach in pydantic 2.

This also works correctly in pydantic v1.

```python
from pydantic.v1 import BaseModel

class Foo(BaseModel, extra="forbid"):
    x: int

Foo(x=5, y=1)
```

And 


```python
from pydantic.v1 import BaseModel

class Foo(BaseModel):
    x: int

    class Config:
      extra = "forbid"

Foo(x=5, y=1)
```


## Enum -> literal using grit pattern:

```
engine marzano(0.1)
language python
or {
    `extra=Extra.allow` => `extra="allow"`,
    `extra=Extra.forbid` => `extra="forbid"`,
    `extra=Extra.ignore` => `extra="ignore"`
}
```

Re-sorted the attributes in `Config` and removed its doc-string, in case we
need to go back and forth between pydantic v1 and v2 during the 0.3
release. (This will reduce merge conflicts.)


## Sort attributes in Config:

```
engine marzano(0.1)
language python


function sort($values) js {
    return $values.text.split(',').sort().join("\n");
}


class_definition($name, $body) as $C where {
    $name <: `Config`,
    $body <: block($statements),
    $values = [],
    $statements <: some bubble($values) assignment() as $A where {
        $values += $A
    },
    $body => sort($values),
}

```
2024-08-08 17:20:39 +00:00

235 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""RWKV models.
Based on https://github.com/saharNooby/rwkv.cpp/blob/master/rwkv/chat_with_bot.py
https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
"""
from typing import Any, Dict, List, Mapping, Optional, Set
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import pre_init
from langchain_community.llms.utils import enforce_stop_tokens
class RWKV(LLM, BaseModel):
    """RWKV language models.

    To use, you should have the ``rwkv`` python package installed, the
    pre-trained model file, and the model's config information.

    Example:
        .. code-block:: python

            from langchain_community.llms import RWKV

            model = RWKV(
                model="./models/rwkv-3b-fp16.bin",
                tokens_path="./models/tokenizer.json",
                strategy="cpu fp32",
            )

            # Simplest invocation
            response = model.invoke("Once upon a time, ")
    """

    model: str
    """Path to the pre-trained RWKV model file."""

    tokens_path: str
    """Path to the RWKV tokens file."""

    strategy: str = "cpu fp32"
    """Strategy string passed to the ``rwkv`` model constructor."""

    rwkv_verbose: bool = True
    """Print debug information (forwarded to the ``rwkv`` model as ``verbose``)."""

    temperature: float = 1.0
    """The temperature to use for sampling."""

    top_p: float = 0.5
    """The top-p value to use for sampling."""

    penalty_alpha_frequency: float = 0.4
    """Positive values penalize new tokens based on their existing frequency
    in the text so far, decreasing the model's likelihood to repeat the same
    line verbatim."""

    penalty_alpha_presence: float = 0.4
    """Positive values penalize new tokens based on whether they appear
    in the text so far, increasing the model's likelihood to talk about
    new topics."""

    CHUNK_LEN: int = 256
    """Batch size for prompt processing."""

    max_tokens_per_generation: int = 256
    """Maximum number of tokens to generate."""

    client: Any = None  #: :meta private:

    tokenizer: Any = None  #: :meta private:

    pipeline: Any = None  #: :meta private:

    model_tokens: Any = None  #: :meta private:

    model_state: Any = None  #: :meta private:

    class Config:
        extra = "forbid"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default sampling and generation parameters."""
        return {
            "verbose": self.verbose,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_alpha_frequency": self.penalty_alpha_frequency,
            "penalty_alpha_presence": self.penalty_alpha_presence,
            "CHUNK_LEN": self.CHUNK_LEN,
            "max_tokens_per_generation": self.max_tokens_per_generation,
        }

    @staticmethod
    def _rwkv_param_names() -> Set[str]:
        """Get the field names that are forwarded to the ``rwkv`` model."""
        return {
            "verbose",
        }

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python packages exist and load the model.

        Populates ``values["tokenizer"]``, ``values["client"]`` and
        ``values["pipeline"]``.

        Raises:
            ImportError: If ``tokenizers`` or ``rwkv`` cannot be imported.
        """
        try:
            import tokenizers
        except ImportError as err:
            raise ImportError(
                "Could not import tokenizers python package. "
                "Please install it with `pip install tokenizers`."
            ) from err

        # Only the imports are guarded: a failure while loading the tokenizer
        # or the model must not be reported as "rwkv is not installed".
        try:
            from rwkv.model import RWKV as RWKVMODEL
            from rwkv.utils import PIPELINE
        except ImportError as err:
            raise ImportError(
                "Could not import rwkv python package. "
                "Please install it with `pip install rwkv`."
            ) from err

        values["tokenizer"] = tokenizers.Tokenizer.from_file(values["tokens_path"])

        rwkv_keys = cls._rwkv_param_names()
        model_kwargs = {k: v for k, v in values.items() if k in rwkv_keys}
        # ``rwkv_verbose`` always wins over any ``verbose`` picked up above.
        model_kwargs["verbose"] = values["rwkv_verbose"]
        values["client"] = RWKVMODEL(
            values["model"], strategy=values["strategy"], **model_kwargs
        )
        values["pipeline"] = PIPELINE(values["client"], values["tokens_path"])
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            **self._default_params,
            **{k: v for k, v in self.__dict__.items() if k in RWKV._rwkv_param_names()},
        }

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "rwkv"

    def run_rnn(self, _tokens: List[str], newline_adj: int = 0) -> Any:
        """Feed tokens through the model and return the resulting logits.

        The tokens are appended to ``self.model_tokens`` and passed to
        ``self.client.forward`` in slices of ``CHUNK_LEN``, threading
        ``self.model_state`` through every call.

        Args:
            _tokens: Tokens to feed. Each value is coerced with ``int()``;
                the call sites in this class pass integer token ids.
            newline_adj: Value added to the logit at index ``END_OF_LINE``.

        Returns:
            The output of the last ``forward`` call, adjusted in place.

        Raises:
            ValueError: If ``_tokens`` is empty, since no forward pass would
                run and there would be no logits to return.
        """
        AVOID_REPEAT_TOKENS = []
        # NOTE(review): this literal is empty as seen here, so the loop below
        # adds nothing and AVOID_REPEAT_TOKENS stays empty -- confirm that no
        # characters were lost from it.
        AVOID_REPEAT = ""
        for i in AVOID_REPEAT:
            dd = self.pipeline.encode(i)
            assert len(dd) == 1
            AVOID_REPEAT_TOKENS += dd

        tokens = [int(x) for x in _tokens]
        if not tokens:
            # Previously this fell through to ``out[END_OF_LINE]`` with
            # ``out`` still None and failed with an opaque TypeError.
            raise ValueError(
                "run_rnn requires at least one token, got an empty token list."
            )

        self.model_tokens += tokens

        out: Any = None
        while len(tokens) > 0:
            out, self.model_state = self.client.forward(
                tokens[: self.CHUNK_LEN], self.model_state
            )
            tokens = tokens[self.CHUNK_LEN :]

        END_OF_LINE = 187
        out[END_OF_LINE] += newline_adj  # adjust \n probability

        if self.model_tokens[-1] in AVOID_REPEAT_TOKENS:
            out[self.model_tokens[-1]] = -999999999
        return out

    def rwkv_generate(self, prompt: str) -> str:
        """Generate text for ``prompt`` by sampling one token at a time.

        Resets ``model_state`` and ``model_tokens``, feeds the encoded prompt
        through ``run_rnn``, then repeatedly samples a token, applies the
        presence/frequency penalties to tokens already produced, and decodes
        the newly generated tokens.

        Args:
            prompt: The text to condition generation on.

        Returns:
            The concatenated decoded text of the generated tokens.
        """
        self.model_state = None
        self.model_tokens = []
        logits = self.run_rnn(self.tokenizer.encode(prompt).ids)
        begin = len(self.model_tokens)
        out_last = begin

        END_OF_TEXT = 0
        occurrence: Dict = {}
        pieces: List[str] = []

        for i in range(self.max_tokens_per_generation):
            # Penalize every token generated so far, scaled by its count.
            for n in occurrence:
                logits[n] -= (
                    self.penalty_alpha_presence
                    + occurrence[n] * self.penalty_alpha_frequency
                )
            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p
            )

            if token == END_OF_TEXT:
                break
            occurrence[token] = occurrence.get(token, 0) + 1

            logits = self.run_rnn([token])
            pending = self.tokenizer.decode(self.model_tokens[out_last:])
            # Only emit text once it decodes without the replacement
            # character; otherwise keep accumulating tokens.
            if "\ufffd" not in pending:  # avoid utf-8 display issues
                pieces.append(pending)
                out_last = begin + i + 1
                # NOTE(review): generation stops once ``i`` reaches
                # ``max_tokens_per_generation - 100`` and the pending text
                # decodes cleanly, so values <= 100 stop after the first
                # cleanly decoded token -- confirm this offset is intended.
                if i >= self.max_tokens_per_generation - 100:
                    break

        return "".join(pieces)

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""RWKV generation

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.
                Applied to the generated text after generation finishes.
            run_manager: Accepted but not used by this implementation.
            **kwargs: Accepted but not used by this implementation.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "Once upon a time, "
                response = model.invoke(prompt)
        """
        text = self.rwkv_generate(prompt)

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text