Helm fixes (#71)

* handling helm response format

* Adding helm
This commit is contained in:
Simran Arora 2023-04-06 21:41:12 +05:30 committed by GitHub
parent 934a0bd5cd
commit 96648965ac
2 changed files with 9 additions and 3 deletions

View File

@ -13,7 +13,9 @@ from manifest.request import LMRequest, Request
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HELM_ENGINES = { HELM_ENGINES = {
"ai21/j1-jumbo" "ai21/j1-grande", "ai21/j1-jumbo",
"ai21/j2-jumbo",
"ai21/j1-grande",
"ai21/j1-grande-v2-beta", "ai21/j1-grande-v2-beta",
"ai21/j1-large", "ai21/j1-large",
"AlephAlpha/luminous-base", "AlephAlpha/luminous-base",
@ -22,6 +24,7 @@ HELM_ENGINES = {
"anthropic/stanford-online-all-v4-s3", "anthropic/stanford-online-all-v4-s3",
"together/bloom", "together/bloom",
"together/t0pp", "together/t0pp",
"cohere/command-xlarge-beta",
"cohere/xlarge-20220609", "cohere/xlarge-20220609",
"cohere/xlarge-20221108", "cohere/xlarge-20221108",
"cohere/large-20220720", "cohere/large-20220720",
@ -78,7 +81,7 @@ class HELMClient(Client):
"stop_sequences": ("stop_sequences", None), # HELM doesn't like empty lists "stop_sequences": ("stop_sequences", None), # HELM doesn't like empty lists
"presence_penalty": ("presence_penalty", 0.0), "presence_penalty": ("presence_penalty", 0.0),
"frequency_penalty": ("frequency_penalty", 0.0), "frequency_penalty": ("frequency_penalty", 0.0),
"client_timeout": ("client_timeout", 60), # seconds #"client_timeout": ("client_timeout", 60), # seconds
} }
REQUEST_CLS = LMRequest REQUEST_CLS = LMRequest
@ -168,6 +171,7 @@ class HELMClient(Client):
except Exception as e: except Exception as e:
logger.error(f"HELM error {e}.") logger.error(f"HELM error {e}.")
raise e raise e
return self.format_response(request_result.__dict__()) res_dict = {"choices": [{"text": com.text} for com in request_result.completions]}
return self.format_response(res_dict)
return _run_completion, request_params return _run_completion, request_params

View File

@ -15,6 +15,7 @@ from manifest.clients.dummy import DummyClient
from manifest.clients.huggingface import HuggingFaceClient from manifest.clients.huggingface import HuggingFaceClient
from manifest.clients.openai import OpenAIClient from manifest.clients.openai import OpenAIClient
from manifest.clients.toma import TOMAClient from manifest.clients.toma import TOMAClient
from manifest.clients.helm import HELMClient
from manifest.response import Response from manifest.response import Response
from manifest.session import Session from manifest.session import Session
@ -30,6 +31,7 @@ CLIENT_CONSTRUCTORS = {
"diffuser": DiffuserClient, "diffuser": DiffuserClient,
"dummy": DummyClient, "dummy": DummyClient,
"toma": TOMAClient, "toma": TOMAClient,
"helm": HELMClient,
} }
CACHE_CONSTRUCTORS = { CACHE_CONSTRUCTORS = {