You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
21 lines · 731 B · Python
from abc import ABC, abstractmethod
|
|
|
|
import torch
|
|
from hivemind.moe.server.task_pool import Task
|
|
|
|
|
|
class TaskPrioritizerBase(ABC):
    """Abstract class for a TaskPrioritizer, whose responsibility is to evaluate task priority.

    Subclasses must override :meth:`prioritize`; because it is declared with
    ``@abstractmethod``, this base class itself cannot be instantiated.
    """

    @abstractmethod
    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        """Evaluate a task's priority from the amount of points given, the task input and additional kwargs.

        :param input: the task's input tensors, passed positionally
        :param points: amount of points given for the task (defaults to 0.0)
        :param kwargs: additional implementation-specific arguments
        :returns: the task's priority as a float; lower priority is better
        """
        pass
|
|
|
|
|
|
class DummyTaskPrioritizer(TaskPrioritizerBase):
    """Trivial TaskPrioritizer that assigns the same priority, zero, to every task.

    Since every task receives an identical value, this prioritizer never ranks
    one task ahead of another.
    """

    def prioritize(self, *input: torch.Tensor, points: float = 0.0, **kwargs) -> float:
        """Return 0.0 no matter what inputs, points, or extra kwargs are supplied."""
        # None of the arguments influence the result; drop them explicitly.
        del input, points, kwargs
        return 0.0
|