From 0164305e9dc3a22412e8b3018e82791791349498 Mon Sep 17 00:00:00 2001 From: Gustav von Zitzewitz Date: Wed, 24 May 2023 09:46:14 +0200 Subject: [PATCH] add function type --- datachad/models.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/datachad/models.py b/datachad/models.py index 31de4ec..9f3a883 100644 --- a/datachad/models.py +++ b/datachad/models.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Any, List import streamlit as st from langchain.base_language import BaseLanguageModel @@ -14,7 +15,7 @@ from datachad.utils import logger class Enum: @classmethod - def all(cls): + def all(cls) -> List[Any]: return [v for k, v in cls.__dict__.items() if not k.startswith("_")] @@ -25,7 +26,7 @@ class Model: embedding: str path: str = None # for local models only - def __str__(self): + def __str__(self) -> str: return self.name @@ -55,7 +56,7 @@ class MODELS(Enum): ) @classmethod - def for_mode(cls, mode): + def for_mode(cls, mode) -> List[Model]: return [m for m in cls.all() if isinstance(m, Model) and m.mode == mode]