"""RWKV models.

Based on https://github.com/saharNooby/rwkv.cpp/blob/master/rwkv/chat_with_bot.py
         https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
"""
from typing import Any, Dict, List, Mapping, Optional, Set
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils import pre_init
from langchain_community.llms.utils import enforce_stop_tokens


class RWKV(LLM, BaseModel):
    """RWKV language models.

    To use, you should have the ``rwkv`` python package installed, the
    pre-trained model file, and the model's config information.

    Example:
        .. code-block:: python

            from langchain_community.llms import RWKV

            model = RWKV(
                model="./models/rwkv-3b-fp16.bin",
                tokens_path="./models/20B_tokenizer.json",
                strategy="cpu fp32",
            )

            # Simplest invocation
            response = model.invoke("Once upon a time, ")
    """

    model: str
    """Path to the pre-trained RWKV model file."""

    tokens_path: str
    """Path to the RWKV tokens file."""

    strategy: str = "cpu fp32"
    """The rwkv strategy string, which selects the device and precision used
    to run the model, e.g. ``"cpu fp32"`` or ``"cuda fp16"``."""

    rwkv_verbose: bool = True
    """Print debug information."""

    temperature: float = 1.0
    """The temperature to use for sampling."""

    top_p: float = 0.5
    """The top-p value to use for sampling."""

    penalty_alpha_frequency: float = 0.4
    """Positive values penalize new tokens based on their existing frequency
    in the text so far, decreasing the model's likelihood to repeat the same
    line verbatim."""

    penalty_alpha_presence: float = 0.4
    """Positive values penalize new tokens based on whether they appear
    in the text so far, increasing the model's likelihood to talk about
    new topics."""
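
    # How the two penalties combine, as applied in ``rwkv_generate`` below:
    # a token ``n`` that has already been generated ``occurrence[n]`` times
    # has its logit reduced by
    #     penalty_alpha_presence + occurrence[n] * penalty_alpha_frequency
    # i.e. a flat presence penalty plus a count-scaled frequency penalty.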

    CHUNK_LEN: int = 256
    """Batch size for prompt processing."""

    max_tokens_per_generation: int = 256
    """Maximum number of tokens to generate."""

    client: Any = None  #: :meta private:
    tokenizer: Any = None  #: :meta private:
    pipeline: Any = None  #: :meta private:
    model_tokens: Any = None  #: :meta private:
    model_state: Any = None  #: :meta private:

class Config:
extra = "forbid"
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default generation parameters."""
        return {
            "verbose": self.verbose,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_alpha_frequency": self.penalty_alpha_frequency,
            "penalty_alpha_presence": self.penalty_alpha_presence,
            "CHUNK_LEN": self.CHUNK_LEN,
            "max_tokens_per_generation": self.max_tokens_per_generation,
        }

    @staticmethod
    def _rwkv_param_names() -> Set[str]:
        """Get the names of parameters that are passed through to the
        ``rwkv`` client."""
        return {
            "verbose",
        }
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment."""
        try:
            import tokenizers
        except ImportError:
            raise ImportError(
                "Could not import tokenizers python package. "
                "Please install it with `pip install tokenizers`."
            )
        try:
            from rwkv.model import RWKV as RWKVMODEL
            from rwkv.utils import PIPELINE

            values["tokenizer"] = tokenizers.Tokenizer.from_file(values["tokens_path"])

            rwkv_keys = cls._rwkv_param_names()
            model_kwargs = {k: v for k, v in values.items() if k in rwkv_keys}
            model_kwargs["verbose"] = values["rwkv_verbose"]
            values["client"] = RWKVMODEL(
                values["model"], strategy=values["strategy"], **model_kwargs
            )
            values["pipeline"] = PIPELINE(values["client"], values["tokens_path"])

        except ImportError:
            raise ImportError(
                "Could not import rwkv python package. "
                "Please install it with `pip install rwkv`."
            )
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            **self._default_params,
            **{k: v for k, v in self.__dict__.items() if k in RWKV._rwkv_param_names()},
        }

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "rwkv"

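    # ``run_rnn`` drives the underlying RWKV RNN: it feeds tokens through
    # ``client.forward`` in CHUNK_LEN-sized batches, carrying ``model_state``
    # between calls so that generation can proceed one token at a time
    # without re-reading the whole context.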
    def run_rnn(self, _tokens: List[int], newline_adj: int = 0) -> Any:
        # Punctuation tokens that should never be emitted twice in a row.
        AVOID_REPEAT_TOKENS = []
        AVOID_REPEAT = ",:?!"
        for i in AVOID_REPEAT:
            dd = self.pipeline.encode(i)
            assert len(dd) == 1
            AVOID_REPEAT_TOKENS += dd

        tokens = [int(x) for x in _tokens]
        self.model_tokens += tokens

        out: Any = None

        # Feed the tokens through the model in CHUNK_LEN-sized batches,
        # threading the recurrent state between calls.
        while len(tokens) > 0:
            out, self.model_state = self.client.forward(
                tokens[: self.CHUNK_LEN], self.model_state
            )
            tokens = tokens[self.CHUNK_LEN :]
        END_OF_LINE = 187
        out[END_OF_LINE] += newline_adj  # adjust \n probability

        if self.model_tokens[-1] in AVOID_REPEAT_TOKENS:
            out[self.model_tokens[-1]] = -999999999
        return out
def rwkv_generate(self, prompt: str) -> str:
        # Reset the recurrent state and feed the full prompt.
        self.model_state = None
        self.model_tokens = []
        logits = self.run_rnn(self.tokenizer.encode(prompt).ids)
        begin = len(self.model_tokens)
        out_last = begin

        occurrence: Dict = {}

        decoded = ""
        for i in range(self.max_tokens_per_generation):
            # Apply presence/frequency penalties to previously sampled tokens.
            for n in occurrence:
                logits[n] -= (
                    self.penalty_alpha_presence
                    + occurrence[n] * self.penalty_alpha_frequency
                )
            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p
            )

            END_OF_TEXT = 0
            if token == END_OF_TEXT:
                break
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            logits = self.run_rnn([token])
            xxx = self.tokenizer.decode(self.model_tokens[out_last:])
            if "\ufffd" not in xxx:  # avoid utf-8 display issues
                decoded += xxx
                out_last = begin + i + 1
                if i >= self.max_tokens_per_generation - 100:
                    break

        return decoded
def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""RWKV generation

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "Once upon a time, "
                response = model.invoke(prompt)
        """
        text = self.rwkv_generate(prompt)

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text
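

# A minimal end-to-end sketch of how this class is wired together, showing
# stop sequences (the file paths below are hypothetical; any local RWKV
# weights and matching HuggingFace-tokenizers JSON file will do):
#
#     from langchain_community.llms import RWKV
#
#     model = RWKV(
#         model="./models/rwkv-3b-fp16.bin",
#         tokens_path="./models/20B_tokenizer.json",
#         strategy="cpu fp32",
#         max_tokens_per_generation=128,
#     )
#     # Generation stops early at the first "\n\n" via enforce_stop_tokens.
#     print(model.invoke("Q: What is RWKV?\n\nA:", stop=["\n\n"]))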