mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Allow user to modify the GPU and language settings when using NLP Cloud (#7985)
--------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
483f6c2fe3
commit
73d5cba308
@ -20,12 +20,16 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
model_name: str # Define model_name as a class attribute
|
model_name: str # Define model_name as a class attribute
|
||||||
|
gpu: bool # Define gpu as a class attribute
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_name: str = "paraphrase-multilingual-mpnet-base-v2", **kwargs: Any
|
self,
|
||||||
|
model_name: str = "paraphrase-multilingual-mpnet-base-v2",
|
||||||
|
gpu: bool = False,
|
||||||
|
**kwargs: Any
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(model_name=model_name, **kwargs)
|
super().__init__(model_name=model_name, gpu=gpu, **kwargs)
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
@ -37,7 +41,7 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
|
|||||||
import nlpcloud
|
import nlpcloud
|
||||||
|
|
||||||
values["client"] = nlpcloud.Client(
|
values["client"] = nlpcloud.Client(
|
||||||
values["model_name"], nlpcloud_api_key, gpu=False, lang="en"
|
values["model_name"], nlpcloud_api_key, gpu=values["gpu"], lang="en"
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
|
@ -17,12 +17,16 @@ class NLPCloud(LLM):
|
|||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
from langchain.llms import NLPCloud
|
from langchain.llms import NLPCloud
|
||||||
nlpcloud = NLPCloud(model="gpt-neox-20b")
|
nlpcloud = NLPCloud(model="finetuned-gpt-neox-20b")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
client: Any #: :meta private:
|
client: Any #: :meta private:
|
||||||
model_name: str = "finetuned-gpt-neox-20b"
|
model_name: str = "finetuned-gpt-neox-20b"
|
||||||
"""Model name to use."""
|
"""Model name to use."""
|
||||||
|
gpu: bool = True
|
||||||
|
"""Whether to use a GPU or not"""
|
||||||
|
lang: str = "en"
|
||||||
|
"""Language to use (multilingual addon)"""
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
"""What sampling temperature to use."""
|
"""What sampling temperature to use."""
|
||||||
min_length: int = 1
|
min_length: int = 1
|
||||||
@ -71,7 +75,10 @@ class NLPCloud(LLM):
|
|||||||
import nlpcloud
|
import nlpcloud
|
||||||
|
|
||||||
values["client"] = nlpcloud.Client(
|
values["client"] = nlpcloud.Client(
|
||||||
values["model_name"], nlpcloud_api_key, gpu=True, lang="en"
|
values["model_name"],
|
||||||
|
nlpcloud_api_key,
|
||||||
|
gpu=values["gpu"],
|
||||||
|
lang=values["lang"],
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
@ -104,7 +111,12 @@ class NLPCloud(LLM):
|
|||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
return {**{"model_name": self.model_name}, **self._default_params}
|
return {
|
||||||
|
**{"model_name": self.model_name},
|
||||||
|
**{"gpu": self.gpu},
|
||||||
|
**{"lang": self.lang},
|
||||||
|
**self._default_params,
|
||||||
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
|
Loading…
Reference in New Issue
Block a user