Add 'download_dir' argument to VLLM (#9754)

- Description:
Add a 'download_dir' argument to the VLLM model (to change the cache
download directory when retrieving a model from the HF hub)
- Issue:
On some remote machines, I want the cache dir to be in a volume where I
have space (models are heavy nowadays). Sometimes the default HF cache
dir might not be what we want.
- Dependencies:
None

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Louis 2023-09-04 18:53:48 +01:00 committed by GitHub
parent 8bba69ffd0
commit bb8c095127
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -62,6 +62,10 @@ class VLLM(BaseLLM):
dtype: str = "auto" dtype: str = "auto"
"""The data type for the model weights and activations.""" """The data type for the model weights and activations."""
download_dir: Optional[str] = None
"""Directory to download and load the weights. (Default to the default
cache dir of huggingface)"""
vllm_kwargs: Dict[str, Any] = Field(default_factory=dict) vllm_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `vllm.LLM` call not explicitly specified.""" """Holds any model parameters valid for `vllm.LLM` call not explicitly specified."""
@ -84,6 +88,7 @@ class VLLM(BaseLLM):
tensor_parallel_size=values["tensor_parallel_size"], tensor_parallel_size=values["tensor_parallel_size"],
trust_remote_code=values["trust_remote_code"], trust_remote_code=values["trust_remote_code"],
dtype=values["dtype"], dtype=values["dtype"],
download_dir=values["download_dir"],
**values["vllm_kwargs"], **values["vllm_kwargs"],
) )