# Mirror of https://github.com/brycedrennan/imaginAIry
# Synced 2024-10-31 03:20:40 +00:00 at 55e27160f5: "so we can still work in conda envs"
# 54 lines, 1.7 KiB, Python
from typing import Any
|
|
|
|
from torch import Tensor
|
|
|
|
# One named context: arbitrary values keyed by string.
Context = dict[str, Any]
# A collection of contexts, keyed by context name.
Contexts = dict[str, Context]
|
|
|
|
|
|
class ContextProvider:
    """Registry of named contexts, each one a ``dict[str, Any]``.

    Contexts live in ``self.contexts``, keyed by context name. Inner context
    dictionaries are stored by reference and never copied, so a dictionary
    passed to :meth:`set_context` or :meth:`update_contexts` remains shared
    with whoever passed it in.
    """

    def __init__(self) -> None:
        """Create a provider with no registered contexts."""
        self.contexts: Contexts = {}

    def set_context(self, key: str, value: Context) -> None:
        """Store ``value`` under ``key``, replacing any context already there."""
        self.contexts[key] = value

    def get_context(self, key: str) -> Any:
        """Return the context stored under ``key``, or ``None`` if there is none."""
        return self.contexts.get(key)

    def update_contexts(self, new_contexts: Contexts) -> None:
        """Merge ``new_contexts`` into this provider, one context at a time.

        A context name that is already registered has its existing dictionary
        updated in place with the incoming entries. An unknown name is
        registered with the incoming dictionary itself.
        """
        for name, incoming in new_contexts.items():
            if name in self.contexts:
                self.contexts[name].update(incoming)
            else:
                # Stored by reference on purpose: no copy is made, so the
                # provider and the caller keep seeing the same dictionary.
                self.contexts[name] = incoming

    @classmethod
    def create(cls, contexts: Contexts) -> "ContextProvider":
        """Build a provider pre-populated with ``contexts``.

        Instantiates ``cls`` rather than a hard-coded ``ContextProvider`` so
        that calling ``create`` on a subclass yields an instance of that
        subclass.
        """
        provider = cls()
        provider.update_contexts(contexts)
        return provider

    def __add__(self, other: "ContextProvider") -> "ContextProvider":
        """Copy ``other``'s contexts into this provider and return ``self``.

        The merge happens in place and at the top level only: a context name
        present on both sides ends up bound to ``other``'s dictionary.
        """
        self.contexts.update(other.contexts)
        return self

    def __lshift__(self, other: "ContextProvider") -> "ContextProvider":
        """Copy this provider's contexts into ``other`` and return ``other``.

        The merge happens in place on ``other`` and at the top level only: a
        context name present on both sides ends up bound to this provider's
        dictionary.
        """
        other.contexts.update(self.contexts)
        return other

    def __bool__(self) -> bool:
        """Return ``True`` when at least one context is registered."""
        return bool(self.contexts)

    def _get_repr_for_value(self, value: Any) -> str:
        """Return a short description of ``value`` for use in ``__repr__``.

        Tensors are summarised by shape, dtype and device instead of printing
        their contents; every other value falls back to ``repr``.
        """
        if isinstance(value, Tensor):
            return f"Tensor(shape={value.shape}, dtype={value.dtype}, device={value.device})"
        return repr(value)

    def _get_repr_for_dict(self, context_dict: Context) -> dict[str, str]:
        """Map each key of ``context_dict`` to the description of its value."""
        return {key: self._get_repr_for_value(value) for key, value in context_dict.items()}

    def __repr__(self) -> str:
        """Return ``ClassName(contexts=...)`` with every value summarised."""
        contexts_repr = {key: self._get_repr_for_dict(value) for key, value in self.contexts.items()}
        return f"{self.__class__.__name__}(contexts={contexts_repr})"
|