From 96648965ac5e7ae425b88e7e02da6db325f0844e Mon Sep 17 00:00:00 2001 From: Simran Arora Date: Thu, 6 Apr 2023 21:41:12 +0530 Subject: [PATCH] Helm fixes (#71) * handling helm response format * Adding helm --- manifest/clients/helm.py | 10 +++++++--- manifest/manifest.py | 2 ++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/manifest/clients/helm.py b/manifest/clients/helm.py index 4356b6a..095b32d 100644 --- a/manifest/clients/helm.py +++ b/manifest/clients/helm.py @@ -13,7 +13,9 @@ from manifest.request import LMRequest, Request logger = logging.getLogger(__name__) HELM_ENGINES = { - "ai21/j1-jumbo" "ai21/j1-grande", + "ai21/j1-jumbo", + "ai21/j2-jumbo", + "ai21/j1-grande", "ai21/j1-grande-v2-beta", "ai21/j1-large", "AlephAlpha/luminous-base", @@ -22,6 +24,7 @@ HELM_ENGINES = { "anthropic/stanford-online-all-v4-s3", "together/bloom", "together/t0pp", + "cohere/command-xlarge-beta", "cohere/xlarge-20220609", "cohere/xlarge-20221108", "cohere/large-20220720", @@ -78,7 +81,7 @@ class HELMClient(Client): "stop_sequences": ("stop_sequences", None), # HELM doesn't like empty lists "presence_penalty": ("presence_penalty", 0.0), "frequency_penalty": ("frequency_penalty", 0.0), - "client_timeout": ("client_timeout", 60), # seconds + #"client_timeout": ("client_timeout", 60), # seconds } REQUEST_CLS = LMRequest @@ -168,6 +171,7 @@ class HELMClient(Client): except Exception as e: logger.error(f"HELM error {e}.") raise e - return self.format_response(request_result.__dict__()) + res_dict = {"choices": [{"text": com.text} for com in request_result.completions]} + return self.format_response(res_dict) return _run_completion, request_params diff --git a/manifest/manifest.py b/manifest/manifest.py index a740d35..4a86da3 100644 --- a/manifest/manifest.py +++ b/manifest/manifest.py @@ -15,6 +15,7 @@ from manifest.clients.dummy import DummyClient from manifest.clients.huggingface import HuggingFaceClient from manifest.clients.openai import OpenAIClient from manifest.clients.toma import TOMAClient +from manifest.clients.helm import HELMClient from manifest.response import Response from manifest.session import Session @@ -30,6 +31,7 @@ CLIENT_CONSTRUCTORS = { "diffuser": DiffuserClient, "dummy": DummyClient, "toma": TOMAClient, + "helm": HELMClient, } CACHE_CONSTRUCTORS = {