From 532481390e00846b6f360893d451b2496a835dd1 Mon Sep 17 00:00:00 2001
From: Laurel Orr <57237365+lorr1@users.noreply.github.com>
Date: Thu, 16 Jun 2022 22:58:18 -0700
Subject: [PATCH] fix: response log probs and max memory (#19)

---
 manifest/api/models/huggingface.py | 4 ++++
 manifest/response.py               | 5 ++++-
 pyproject.toml                     | 1 +
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/manifest/api/models/huggingface.py b/manifest/api/models/huggingface.py
index bd42d59..f6ab338 100644
--- a/manifest/api/models/huggingface.py
+++ b/manifest/api/models/huggingface.py
@@ -169,6 +169,7 @@ class HuggingFaceModel(Model):
             cache_dir: cache directory for model.
         """
         from accelerate import dispatch_model, infer_auto_device_map
+        from accelerate.utils.modeling import get_max_memory
 
         model.tie_weights()  # type: ignore
         # Get the model where we can infer devices from
@@ -180,8 +181,11 @@ class HuggingFaceModel(Model):
             # Eleuther Neo and J
             main_model = model
             model_getter = ""
+        # Decrease max mem
+        max_memory = {k: int(0.85 * v) for k, v in get_max_memory().items()}
         raw_device_map = infer_auto_device_map(
             main_model,
+            max_memory=max_memory,
             no_split_module_classes=[
                 "OPTDecoderLayer",
                 "GPTNeoBlock",
diff --git a/manifest/response.py b/manifest/response.py
index 87faa88..25b9e15 100644
--- a/manifest/response.py
+++ b/manifest/response.py
@@ -25,7 +25,10 @@ class Response:
                 "Response must be serialized to a dict with a "
                 "list of choices with text field"
             )
-        if "logprobs" in self._response["choices"][0]:
+        if (
+            "logprobs" in self._response["choices"][0]
+            and self._response["choices"][0]["logprobs"]
+        ):
             if not isinstance(self._response["choices"][0]["logprobs"], list):
                 raise ValueError(
                     "Response must be serialized to a dict with a "
diff --git a/pyproject.toml b/pyproject.toml
index 5cd612e..7d1f870 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,6 +67,7 @@ module = [
     "dill",
     "tqdm.auto",
     "accelerate",
+    "accelerate.utils.modeling",
 ]
 
 [tool.isort]