fix: response log probs and max memory (#19)

laurel/helm
Laurel Orr authored 2 years ago, committed by GitHub
parent 99bcc76fde
commit 532481390e

@@ -169,6 +169,7 @@ class HuggingFaceModel(Model):
             cache_dir: cache directory for model.
         """
         from accelerate import dispatch_model, infer_auto_device_map
+        from accelerate.utils.modeling import get_max_memory
 
         model.tie_weights()  # type: ignore
         # Get the model where we can infer devices from
@@ -180,8 +181,11 @@ class HuggingFaceModel(Model):
             # Eleuther Neo and J
             main_model = model
             model_getter = ""
+        # Decrease max mem
+        max_memory = {k: int(0.85 * v) for k, v in get_max_memory().items()}
         raw_device_map = infer_auto_device_map(
             main_model,
+            max_memory=max_memory,
             no_split_module_classes=[
                 "OPTDecoderLayer",
                 "GPTNeoBlock",

@@ -25,7 +25,10 @@ class Response:
                     "Response must be serialized to a dict with a "
                     "list of choices with text field"
                 )
-            if "logprobs" in self._response["choices"][0]:
+            if (
+                "logprobs" in self._response["choices"][0]
+                and self._response["choices"][0]["logprobs"]
+            ):
                 if not isinstance(self._response["choices"][0]["logprobs"], list):
                     raise ValueError(
                         "Response must be serialized to a dict with a "

@@ -67,6 +67,7 @@ module = [
     "dill",
     "tqdm.auto",
     "accelerate",
+    "accelerate.utils.modeling",
 ]
 
 [tool.isort]
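
The pyproject hunk mirrors the new import: that module list presumably sits under a mypy overrides table with ignore_missing_imports enabled, so accelerate.utils.modeling has to be listed alongside accelerate for type checking to stay green.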
