diff --git a/libs/langchain/langchain/embeddings/nlpcloud.py b/libs/langchain/langchain/embeddings/nlpcloud.py index 2e80d8d232..5a65768a94 100644 --- a/libs/langchain/langchain/embeddings/nlpcloud.py +++ b/libs/langchain/langchain/embeddings/nlpcloud.py @@ -20,12 +20,16 @@ class NLPCloudEmbeddings(BaseModel, Embeddings): """ model_name: str # Define model_name as a class attribute + gpu: bool # Define gpu as a class attribute client: Any #: :meta private: def __init__( - self, model_name: str = "paraphrase-multilingual-mpnet-base-v2", **kwargs: Any + self, + model_name: str = "paraphrase-multilingual-mpnet-base-v2", + gpu: bool = False, + **kwargs: Any ) -> None: - super().__init__(model_name=model_name, **kwargs) + super().__init__(model_name=model_name, gpu=gpu, **kwargs) @root_validator() def validate_environment(cls, values: Dict) -> Dict: @@ -37,7 +41,7 @@ class NLPCloudEmbeddings(BaseModel, Embeddings): import nlpcloud values["client"] = nlpcloud.Client( - values["model_name"], nlpcloud_api_key, gpu=False, lang="en" + values["model_name"], nlpcloud_api_key, gpu=values["gpu"], lang="en" ) except ImportError: raise ImportError( diff --git a/libs/langchain/langchain/llms/nlpcloud.py b/libs/langchain/langchain/llms/nlpcloud.py index 7c1b0c0acb..9e5070acf7 100644 --- a/libs/langchain/langchain/llms/nlpcloud.py +++ b/libs/langchain/langchain/llms/nlpcloud.py @@ -17,12 +17,16 @@ class NLPCloud(LLM): .. code-block:: python from langchain.llms import NLPCloud - nlpcloud = NLPCloud(model="gpt-neox-20b") + nlpcloud = NLPCloud(model="finetuned-gpt-neox-20b") """ client: Any #: :meta private: model_name: str = "finetuned-gpt-neox-20b" """Model name to use.""" + gpu: bool = True + """Whether to use a GPU or not""" + lang: str = "en" + """Language to use (multilingual addon)""" temperature: float = 0.7 """What sampling temperature to use.""" min_length: int = 1 @@ -71,7 +75,10 @@ class NLPCloud(LLM): import nlpcloud values["client"] = nlpcloud.Client( - values["model_name"], nlpcloud_api_key, gpu=True, lang="en" + values["model_name"], + nlpcloud_api_key, + gpu=values["gpu"], + lang=values["lang"], ) except ImportError: raise ImportError( @@ -104,7 +111,12 @@ class NLPCloud(LLM): @property def _identifying_params(self) -> Mapping[str, Any]: """Get the identifying parameters.""" - return {**{"model_name": self.model_name}, **self._default_params} + return { + **{"model_name": self.model_name}, + **{"gpu": self.gpu}, + **{"lang": self.lang}, + **self._default_params, + } @property def _llm_type(self) -> str: