# langchain/libs/experimental/langchain_experimental/rl_chain/model_repository.py
#
# Captured from a repository file view, where it was listed as
# 58 lines, 1.8 KiB, Python.

import datetime
import glob
import logging
import os
import shutil
from pathlib import Path
# 2023-08-28 10:58:33 +00:00 (timestamp left over from the repository file view)
from typing import TYPE_CHECKING, List, Union
if TYPE_CHECKING:
    # Type-checker-only import backing the "vw.Workspace" string annotations
    # used in this module; the guard keeps it from executing at runtime.
    import vowpal_wabbit_next as vw

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class ModelRepository:
    """Keep serialized VW workspaces in a folder on disk.

    The current model is stored as ``<folder>/latest.vw``. When
    ``with_history`` is enabled, every :meth:`save` also leaves a timestamped
    copy named ``model-YYYYMMDD-HHMMSS.vw`` in the same folder.
    """

    # Glob pattern (relative to the folder) matching the timestamped copies
    # written by ``save``.
    _HISTORY_PATTERN = "model-????????-??????.vw"

    def __init__(
        self,
        folder: Union[str, os.PathLike],
        with_history: bool = True,
        reset: bool = False,
    ) -> None:
        """Open (and create if needed) the repository folder.

        Args:
            folder: Directory holding the model files; created together with
                any missing parents.
            with_history: If true, ``save`` also writes a timestamped copy
                of each stored model.
            reset: If true, delete an existing ``latest.vw`` so the next
                ``load`` starts without stored model data. History copies
                are not deleted; a warning is logged when any exist.
        """
        self.folder = Path(folder)
        self.model_path = self.folder / "latest.vw"
        self.with_history = with_history

        if reset:
            if self.has_history():
                logger.warning(
                    "There is non empty history which is recommended to be cleaned up"
                )
            # A reset must drop the current model whether or not history
            # copies exist; otherwise it would silently do nothing for
            # repositories that never kept history.
            try:
                os.remove(self.model_path)
            except FileNotFoundError:
                # No stored model (or no folder yet): nothing to reset.
                pass

        self.folder.mkdir(parents=True, exist_ok=True)

    def get_tag(self) -> str:
        """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
        return datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    def has_history(self) -> bool:
        """Return True if the folder contains at least one history copy."""
        # Escape the folder part so glob metacharacters in the path itself
        # (e.g. "run[1]") are matched literally; only the file name is a
        # pattern.
        pattern = os.path.join(glob.escape(str(self.folder)), self._HISTORY_PATTERN)
        return bool(glob.glob(pattern))

    def save(self, workspace: "vw.Workspace") -> None:
        """Write the serialized workspace to ``latest.vw``.

        When ``with_history`` is set, the freshly written file is also copied
        to ``model-<tag>.vw``, where ``<tag>`` comes from :meth:`get_tag`.

        Args:
            workspace: Object whose ``serialize()`` result is written to the
                file in binary mode (presumably ``bytes`` -- confirm against
                the vowpal_wabbit_next API).
        """
        with open(self.model_path, "wb") as f:
            logger.info("storing rl_chain model in: %s", self.model_path)
            f.write(workspace.serialize())
        if self.with_history:  # write history
            shutil.copyfile(self.model_path, self.folder / f"model-{self.get_tag()}.vw")

    def load(self, commandline: List[str]) -> "vw.Workspace":
        """Build a workspace, seeded with the stored model when one exists.

        Args:
            commandline: Arguments passed through to ``vw.Workspace``.

        Returns:
            A workspace created with ``model_data`` read from ``latest.vw``
            if that file exists and is non-empty, otherwise a workspace
            created from ``commandline`` alone.
        """
        # Imported here so the module can be imported without the optional
        # vowpal_wabbit_next dependency installed.
        import vowpal_wabbit_next as vw

        model_data = None
        if self.model_path.exists():
            with open(self.model_path, "rb") as f:
                model_data = f.read()
        if model_data:
            logger.info("rl_chain model is loaded from: %s", self.model_path)
            return vw.Workspace(commandline, model_data=model_data)
        # Missing or empty model file: start from a fresh workspace.
        return vw.Workspace(commandline)