2024-04-19 20:56:24 +00:00
|
|
|
import os
|
|
|
|
from typing import Dict, List, Optional, Union
|
|
|
|
|
2024-05-06 16:48:26 +00:00
|
|
|
from ibm_watsonx_ai import APIClient, Credentials # type: ignore
|
2024-04-19 20:56:24 +00:00
|
|
|
from ibm_watsonx_ai.foundation_models.embeddings import Embeddings # type: ignore
|
|
|
|
from langchain_core.embeddings import Embeddings as LangChainEmbeddings
|
|
|
|
from langchain_core.pydantic_v1 import (
|
|
|
|
BaseModel,
|
|
|
|
Extra,
|
|
|
|
Field,
|
|
|
|
SecretStr,
|
|
|
|
root_validator,
|
|
|
|
)
|
|
|
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
|
|
|
|
|
|
|
|
|
|
|
class WatsonxEmbeddings(BaseModel, LangChainEmbeddings):
    """IBM WatsonX.ai embedding models.

    Exposes an ``ibm_watsonx_ai`` ``Embeddings`` object through the LangChain
    embeddings interface. The underlying object is built during validation,
    either from a ready-made ``APIClient`` passed as ``watsonx_client`` or from
    credentials given as named parameters / ``WATSONX_*`` environment
    variables.
    """

    model_id: str = ""
    """Type of model to use."""

    project_id: str = ""
    """ID of the Watson Studio project."""

    space_id: str = ""
    """ID of the Watson Studio space."""

    url: Optional[SecretStr] = None
    """Url to Watson Machine Learning or CPD instance"""

    apikey: Optional[SecretStr] = None
    """Apikey to Watson Machine Learning or CPD instance"""

    token: Optional[SecretStr] = None
    """Token to CPD instance"""

    password: Optional[SecretStr] = None
    """Password to CPD instance"""

    username: Optional[SecretStr] = None
    """Username to CPD instance"""

    instance_id: Optional[SecretStr] = None
    """Instance_id of CPD instance"""

    version: Optional[SecretStr] = None
    """Version of CPD instance"""

    params: Optional[dict] = None
    """Model parameters to use during generate requests."""

    verify: Union[str, bool, None] = None
    """User can pass as verify one of following:
        the path to a CA_BUNDLE file
        the path of directory with certificates of trusted CAs
        True - default path to truststore will be taken
        False - no verification will be made"""

    watsonx_embed: Embeddings = Field(default=None)  #: :meta private:

    watsonx_client: APIClient = Field(default=None)  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    # skip_on_failure=True: when a field already failed validation it is
    # missing from ``values``; running this validator anyway would raise a raw
    # KeyError that masks the real ValidationError.
    @root_validator(pre=False, skip_on_failure=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that credentials and python package exists in environment.

        Builds the ``watsonx_embed`` object and stores it in ``values``.

        * If ``watsonx_client`` is an ``APIClient`` it is used directly.
        * Otherwise ``url`` is read from the arguments or ``WATSONX_URL``.
          When the url contains ``cloud.ibm.com`` only ``apikey``
          (``WATSONX_APIKEY``) is required. For any other url one of
          ``token``, ``password`` (+ ``username``) or ``apikey``
          (+ ``username``) must be available, checked in that order, and
          ``instance_id`` is required as well.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The same mapping with credentials normalised to ``SecretStr`` and
            ``watsonx_embed`` populated.

        Raises:
            ValueError: If a required credential is neither passed as a named
                parameter nor present in the environment.
        """
        if isinstance(values.get("watsonx_client"), APIClient):
            values["watsonx_embed"] = Embeddings(
                model_id=values["model_id"],
                params=values["params"],
                api_client=values["watsonx_client"],
                project_id=values["project_id"],
                space_id=values["space_id"],
                verify=values["verify"],
            )
            return values

        def resolve(key: str, env_key: str) -> SecretStr:
            # Explicit argument wins; otherwise fall back to the environment.
            # Raises ValueError when neither source provides a value.
            return convert_to_secret_str(get_from_dict_or_env(values, key, env_key))

        def reveal(key: str) -> Optional[str]:
            # Unwrap an optional secret; unset or empty secrets become None.
            secret = values[key]
            return secret.get_secret_value() if secret else None

        values["url"] = resolve("url", "WATSONX_URL")

        if "cloud.ibm.com" in values["url"].get_secret_value():
            values["apikey"] = resolve("apikey", "WATSONX_APIKEY")
        else:
            has_token = bool(values["token"]) or "WATSONX_TOKEN" in os.environ
            has_password = (
                bool(values["password"]) or "WATSONX_PASSWORD" in os.environ
            )
            has_apikey = bool(values["apikey"]) or "WATSONX_APIKEY" in os.environ

            if not (has_token or has_password or has_apikey):
                raise ValueError(
                    "Did not find 'token', 'password' or 'apikey',"
                    " please add an environment variable"
                    " `WATSONX_TOKEN`, 'WATSONX_PASSWORD' or 'WATSONX_APIKEY' "
                    "which contains it,"
                    " or pass 'token', 'password' or 'apikey'"
                    " as a named parameter."
                )
            elif has_token:
                values["token"] = resolve("token", "WATSONX_TOKEN")
            elif has_password:
                values["password"] = resolve("password", "WATSONX_PASSWORD")
                values["username"] = resolve("username", "WATSONX_USERNAME")
            elif has_apikey:
                values["apikey"] = resolve("apikey", "WATSONX_APIKEY")
                values["username"] = resolve("username", "WATSONX_USERNAME")

            # The lookup already prefers an explicitly passed instance_id over
            # the environment, so no extra guard is needed around it.
            values["instance_id"] = resolve("instance_id", "WATSONX_INSTANCE_ID")

        credentials = Credentials(
            url=reveal("url"),
            api_key=reveal("apikey"),
            token=reveal("token"),
            password=reveal("password"),
            username=reveal("username"),
            instance_id=reveal("instance_id"),
            version=reveal("version"),
            verify=values["verify"],
        )

        values["watsonx_embed"] = Embeddings(
            model_id=values["model_id"],
            params=values["params"],
            credentials=credentials,
            project_id=values["project_id"],
            space_id=values["space_id"],
        )

        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed search docs.

        Args:
            texts: Documents to embed.

        Returns:
            One embedding vector per input text, as returned by the
            underlying ``watsonx_embed`` object.
        """
        return self.watsonx_embed.embed_documents(texts=texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed query text.

        Args:
            text: Query to embed.

        Returns:
            The embedding of ``text``, computed by embedding it as a
            single-element document list.
        """
        return self.embed_documents([text])[0]
|