From b504cd739fc04f1169316203c972f553c58826f0 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 14 Nov 2022 22:05:41 -0800 Subject: [PATCH] Harrison/cleanup env check (#144) --- langchain/chains/serpapi.py | 15 ++++++--------- langchain/embeddings/cohere.py | 9 +-------- langchain/embeddings/openai.py | 9 +-------- langchain/llms/ai21.py | 9 ++------- langchain/llms/cohere.py | 10 ++-------- langchain/llms/huggingface_hub.py | 9 ++------- langchain/llms/nlpcloud.py | 9 +-------- langchain/llms/openai.py | 9 +-------- langchain/llms/utils.py | 10 +--------- langchain/utils.py | 17 +++++++++++++++++ langchain/vectorstores/elastic_vector_search.py | 15 ++++----------- 11 files changed, 38 insertions(+), 83 deletions(-) create mode 100644 langchain/utils.py diff --git a/langchain/chains/serpapi.py b/langchain/chains/serpapi.py index dbac148b..50e086e1 100644 --- a/langchain/chains/serpapi.py +++ b/langchain/chains/serpapi.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.chains.base import Chain +from langchain.utils import get_from_dict_or_env class HiddenPrints: @@ -43,7 +44,7 @@ class SerpAPIChain(Chain, BaseModel): input_key: str = "search_query" #: :meta private: output_key: str = "search_result" #: :meta private: - serpapi_api_key: Optional[str] = os.environ.get("SERPAPI_API_KEY") + serpapi_api_key: Optional[str] = None class Config: """Configuration for this pydantic object.""" @@ -69,14 +70,10 @@ class SerpAPIChain(Chain, BaseModel): @root_validator() def validate_environment(cls, values: Dict) -> Dict: """Validate that api key and python package exists in environment.""" - serpapi_api_key = values.get("serpapi_api_key") - - if serpapi_api_key is None or serpapi_api_key == "": - raise ValueError( - "Did not find SerpAPI API key, please add an environment variable" - " `SERPAPI_API_KEY` which contains it, or pass `serpapi_api_key` " - "as a named parameter to the constructor." - ) + serpapi_api_key = get_from_dict_or_env( + values, "serpapi_api_key", "SERPAPI_API_KEY" + ) + values["serpapi_api_key"] = serpapi_api_key try: from serpapi import GoogleSearch diff --git a/langchain/embeddings/cohere.py b/langchain/embeddings/cohere.py index f9862249..9a4f2ffe 100644 --- a/langchain/embeddings/cohere.py +++ b/langchain/embeddings/cohere.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class CohereEmbeddings(BaseModel, Embeddings): @@ -38,13 +38,6 @@ class CohereEmbeddings(BaseModel, Embeddings): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) - - if cohere_api_key is None or cohere_api_key == "": - raise ValueError( - "Did not find Cohere API key, please add an environment variable" - " `COHERE_API_KEY` which contains it, or pass `cohere_api_key` as a" - " named parameter." - ) try: import cohere diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index a7366ab6..864e7758 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional from pydantic import BaseModel, Extra, root_validator from langchain.embeddings.base import Embeddings -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class OpenAIEmbeddings(BaseModel, Embeddings): @@ -38,13 +38,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings): openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) - - if openai_api_key is None or openai_api_key == "": - raise ValueError( - "Did not find OpenAI API key, please add an environment variable" - " `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a" - " named parameter." - ) try: import openai diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 3967ca65..a870d9e4 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -5,7 +5,7 @@ import requests from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class AI21PenaltyData(BaseModel): @@ -73,12 +73,7 @@ class AI21(BaseModel, LLM): def validate_environment(cls, values: Dict) -> Dict: """Validate that api key exists in environment.""" ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY") - if ai21_api_key is None or ai21_api_key == "": - raise ValueError( - "Did not find AI21 API key, please add an environment variable" - " `AI21_API_KEY` which contains it, or pass `ai21_api_key`" - " as a named parameter." - ) + values["ai21_api_key"] = ai21_api_key return values @property diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index fed3b56e..e051ba47 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env +from langchain.llms.utils import enforce_stop_tokens +from langchain.utils import get_from_dict_or_env class Cohere(LLM, BaseModel): @@ -56,13 +57,6 @@ class Cohere(LLM, BaseModel): cohere_api_key = get_from_dict_or_env( values, "cohere_api_key", "COHERE_API_KEY" ) - - if cohere_api_key is None or cohere_api_key == "": - raise ValueError( - "Did not find Cohere API key, please add an environment variable" - " `COHERE_API_KEY` which contains it, or pass `cohere_api_key`" - " as a named parameter." - ) try: import cohere diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index 1cea3d55..c67c9720 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env +from langchain.llms.utils import enforce_stop_tokens +from langchain.utils import get_from_dict_or_env DEFAULT_REPO_ID = "gpt2" VALID_TASKS = ("text2text-generation", "text-generation") @@ -47,12 +48,6 @@ class HuggingFaceHub(LLM, BaseModel): huggingfacehub_api_token = get_from_dict_or_env( values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" ) - if huggingfacehub_api_token is None or huggingfacehub_api_token == "": - raise ValueError( - "Did not find HuggingFace API token, please add an environment variable" - " `HUGGINGFACEHUB_API_TOKEN` which contains it, or pass" - " `huggingfacehub_api_token` as a named parameter." - ) try: from huggingface_hub.inference_api import InferenceApi diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index e4c37ff7..d9e4c54e 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class NLPCloud(LLM, BaseModel): @@ -67,13 +67,6 @@ class NLPCloud(LLM, BaseModel): nlpcloud_api_key = get_from_dict_or_env( values, "nlpcloud_api_key", "NLPCLOUD_API_KEY" ) - - if nlpcloud_api_key is None or nlpcloud_api_key == "": - raise ValueError( - "Did not find NLPCloud API key, please add an environment variable" - " `NLPCLOUD_API_KEY` which contains it, or pass `nlpcloud_api_key`" - " as a named parameter." - ) try: import nlpcloud diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index f0127efb..2affb86d 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional from pydantic import BaseModel, Extra, root_validator from langchain.llms.base import LLM -from langchain.llms.utils import get_from_dict_or_env +from langchain.utils import get_from_dict_or_env class OpenAI(LLM, BaseModel): @@ -51,13 +51,6 @@ class OpenAI(LLM, BaseModel): openai_api_key = get_from_dict_or_env( values, "openai_api_key", "OPENAI_API_KEY" ) - - if openai_api_key is None or openai_api_key == "": - raise ValueError( - "Did not find OpenAI API key, please add an environment variable" - " `OPENAI_API_KEY` which contains it, or pass `openai_api_key`" - " as a named parameter." - ) try: import openai diff --git a/langchain/llms/utils.py b/langchain/llms/utils.py index 29c69d92..a42fd130 100644 --- a/langchain/llms/utils.py +++ b/langchain/llms/utils.py @@ -1,16 +1,8 @@ """Common utility functions for working with LLM APIs.""" -import os import re -from typing import Any, Dict, List +from typing import List def enforce_stop_tokens(text: str, stop: List[str]) -> str: """Cut off the text as soon as any stop words occur.""" return re.split("|".join(stop), text)[0] - - -def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> Any: - """Get a value from a dictionary or an environment variable.""" - if key in data and data[key]: - return data[key] - return os.environ.get(env_key, None) diff --git a/langchain/utils.py b/langchain/utils.py new file mode 100644 index 00000000..8588f4e9 --- /dev/null +++ b/langchain/utils.py @@ -0,0 +1,17 @@ +"""Generic utility functions.""" +import os +from typing import Any, Dict + + +def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> str: + """Get a value from a dictionary or an environment variable.""" + if key in data and data[key]: + return data[key] + elif env_key in os.environ and os.environ[env_key]: + return os.environ[env_key] + else: + raise ValueError( + f"Did not find {key}, please add an environment variable" + f" `{env_key}` which contains it, or pass" + f" `{key}` as a named parameter." + ) diff --git a/langchain/vectorstores/elastic_vector_search.py b/langchain/vectorstores/elastic_vector_search.py index 32dd4843..549277b3 100644 --- a/langchain/vectorstores/elastic_vector_search.py +++ b/langchain/vectorstores/elastic_vector_search.py @@ -1,10 +1,10 @@ """Wrapper around Elasticsearch vector database.""" -import os import uuid from typing import Any, Callable, Dict, List from langchain.docstore.document import Document from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env from langchain.vectorstores.base import VectorStore @@ -107,16 +107,9 @@ class ElasticVectorSearch(VectorStore): elasticsearch_url="http://localhost:9200" ) """ - elasticsearch_url = kwargs.get("elasticsearch_url") - if not elasticsearch_url: - elasticsearch_url = os.environ.get("ELASTICSEARCH_URL") - - if elasticsearch_url is None or elasticsearch_url == "": - raise ValueError( - "Did not find Elasticsearch URL, please add an environment variable" - " `ELASTICSEARCH_URL` which contains it, or pass" - " `elasticsearch_url` as a named parameter." - ) + elasticsearch_url = get_from_dict_or_env( + kwargs, "elasticsearch_url", "ELASTICSEARCH_URL" + ) try: import elasticsearch from elasticsearch.helpers import bulk