You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
imaginAIry/imaginairy/vendored/refiners/fluxion/context.py

54 lines
1.7 KiB
Python

from typing import Any
from torch import Tensor
# A single context: string keys mapped to arbitrary values.
Context = dict[str, Any]
# A collection of contexts, keyed by context name.
Contexts = dict[str, Context]


class ContextProvider:
    """Hold a mapping of named contexts with helpers to read, set and merge them."""

    def __init__(self) -> None:
        """Create a provider with no contexts registered."""
        self.contexts: Contexts = {}

    def set_context(self, key: str, value: Context) -> None:
        """Store `value` under `key`, replacing any context already stored there."""
        self.contexts[key] = value

    def get_context(self, key: str) -> Any:
        """Return the context stored under `key`, or `None` when the key is absent."""
        return self.contexts.get(key)

    def update_contexts(self, new_contexts: Contexts) -> None:
        """Merge `new_contexts` into this provider, one context at a time.

        A key that is not yet registered is stored as-is. For a key that
        already exists, the existing context dict is updated in place with the
        entries of the new one (new entries win on clashes).
        """
        for key, value in new_contexts.items():
            if key not in self.contexts:
                # The passed dict object itself is stored (no copy), so later
                # in-place updates to this context are visible through it.
                self.contexts[key] = value
            else:
                self.contexts[key].update(value)

    @staticmethod
    def create(contexts: Contexts) -> "ContextProvider":
        """Build a new provider pre-populated with `contexts`."""
        provider = ContextProvider()
        provider.update_contexts(contexts)
        return provider

    def __add__(self, other: "ContextProvider") -> "ContextProvider":
        """Copy `other`'s contexts into this provider and return this provider.

        This mutates `self` in place. Contexts sharing a key are replaced
        wholesale by `other`'s version rather than merged entry by entry.
        """
        self.contexts.update(other.contexts)
        return self

    def __lshift__(self, other: "ContextProvider") -> "ContextProvider":
        """Copy this provider's contexts into `other` and return `other`.

        This mutates `other` in place. Contexts sharing a key are replaced
        wholesale by this provider's version.
        """
        other.contexts.update(self.contexts)
        return other

    def __bool__(self) -> bool:
        """Return True when at least one context is registered."""
        return bool(self.contexts)

    def _get_repr_for_value(self, value: Any) -> str:
        """Return a display string for `value`.

        Tensors are summarized by shape, dtype and device instead of printing
        their contents; every other value uses its regular `repr`.
        """
        if isinstance(value, Tensor):
            return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
        return repr(value)

    def _get_repr_for_dict(self, context_dict: Context) -> dict[str, str]:
        """Map each key of `context_dict` to the display string of its value."""
        return {key: self._get_repr_for_value(value) for key, value in context_dict.items()}

    def __repr__(self) -> str:
        """Return the class name followed by a summary of every context."""
        contexts_repr = {key: self._get_repr_for_dict(value) for key, value in self.contexts.items()}
        return f"{self.__class__.__name__}(contexts={contexts_repr})"