Harrison/cleanup env check (#144)

pull/142/head
Harrison Chase 2 years ago committed by GitHub
parent a4b502d92f
commit b504cd739f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.utils import get_from_dict_or_env
class HiddenPrints: class HiddenPrints:
@ -43,7 +44,7 @@ class SerpAPIChain(Chain, BaseModel):
input_key: str = "search_query" #: :meta private: input_key: str = "search_query" #: :meta private:
output_key: str = "search_result" #: :meta private: output_key: str = "search_result" #: :meta private:
serpapi_api_key: Optional[str] = os.environ.get("SERPAPI_API_KEY") serpapi_api_key: Optional[str] = None
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -69,14 +70,10 @@ class SerpAPIChain(Chain, BaseModel):
@root_validator() @root_validator()
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment.""" """Validate that api key and python package exists in environment."""
serpapi_api_key = values.get("serpapi_api_key") serpapi_api_key = get_from_dict_or_env(
values, "serpapi_api_key", "SERPAPI_API_KEY"
if serpapi_api_key is None or serpapi_api_key == "": )
raise ValueError( values["serpapi_api_key"] = serpapi_api_key
"Did not find SerpAPI API key, please add an environment variable"
" `SERPAPI_API_KEY` which contains it, or pass `serpapi_api_key` "
"as a named parameter to the constructor."
)
try: try:
from serpapi import GoogleSearch from serpapi import GoogleSearch

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.llms.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class CohereEmbeddings(BaseModel, Embeddings): class CohereEmbeddings(BaseModel, Embeddings):
@ -38,13 +38,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key = get_from_dict_or_env( cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY" values, "cohere_api_key", "COHERE_API_KEY"
) )
if cohere_api_key is None or cohere_api_key == "":
raise ValueError(
"Did not find Cohere API key, please add an environment variable"
" `COHERE_API_KEY` which contains it, or pass `cohere_api_key` as a"
" named parameter."
)
try: try:
import cohere import cohere

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.llms.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class OpenAIEmbeddings(BaseModel, Embeddings): class OpenAIEmbeddings(BaseModel, Embeddings):
@ -38,13 +38,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
openai_api_key = get_from_dict_or_env( openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY" values, "openai_api_key", "OPENAI_API_KEY"
) )
if openai_api_key is None or openai_api_key == "":
raise ValueError(
"Did not find OpenAI API key, please add an environment variable"
" `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a"
" named parameter."
)
try: try:
import openai import openai

@ -5,7 +5,7 @@ import requests
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class AI21PenaltyData(BaseModel): class AI21PenaltyData(BaseModel):
@ -73,12 +73,7 @@ class AI21(BaseModel, LLM):
def validate_environment(cls, values: Dict) -> Dict: def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment.""" """Validate that api key exists in environment."""
ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY") ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
if ai21_api_key is None or ai21_api_key == "": values["ai21_api_key"] = ai21_api_key
raise ValueError(
"Did not find AI21 API key, please add an environment variable"
" `AI21_API_KEY` which contains it, or pass `ai21_api_key`"
" as a named parameter."
)
return values return values
@property @property

@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
class Cohere(LLM, BaseModel): class Cohere(LLM, BaseModel):
@ -56,13 +57,6 @@ class Cohere(LLM, BaseModel):
cohere_api_key = get_from_dict_or_env( cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY" values, "cohere_api_key", "COHERE_API_KEY"
) )
if cohere_api_key is None or cohere_api_key == "":
raise ValueError(
"Did not find Cohere API key, please add an environment variable"
" `COHERE_API_KEY` which contains it, or pass `cohere_api_key`"
" as a named parameter."
)
try: try:
import cohere import cohere

@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
DEFAULT_REPO_ID = "gpt2" DEFAULT_REPO_ID = "gpt2"
VALID_TASKS = ("text2text-generation", "text-generation") VALID_TASKS = ("text2text-generation", "text-generation")
@ -47,12 +48,6 @@ class HuggingFaceHub(LLM, BaseModel):
huggingfacehub_api_token = get_from_dict_or_env( huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN" values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
) )
if huggingfacehub_api_token is None or huggingfacehub_api_token == "":
raise ValueError(
"Did not find HuggingFace API token, please add an environment variable"
" `HUGGINGFACEHUB_API_TOKEN` which contains it, or pass"
" `huggingfacehub_api_token` as a named parameter."
)
try: try:
from huggingface_hub.inference_api import InferenceApi from huggingface_hub.inference_api import InferenceApi

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class NLPCloud(LLM, BaseModel): class NLPCloud(LLM, BaseModel):
@ -67,13 +67,6 @@ class NLPCloud(LLM, BaseModel):
nlpcloud_api_key = get_from_dict_or_env( nlpcloud_api_key = get_from_dict_or_env(
values, "nlpcloud_api_key", "NLPCLOUD_API_KEY" values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
) )
if nlpcloud_api_key is None or nlpcloud_api_key == "":
raise ValueError(
"Did not find NLPCloud API key, please add an environment variable"
" `NLPCLOUD_API_KEY` which contains it, or pass `nlpcloud_api_key`"
" as a named parameter."
)
try: try:
import nlpcloud import nlpcloud

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM from langchain.llms.base import LLM
from langchain.llms.utils import get_from_dict_or_env from langchain.utils import get_from_dict_or_env
class OpenAI(LLM, BaseModel): class OpenAI(LLM, BaseModel):
@ -51,13 +51,6 @@ class OpenAI(LLM, BaseModel):
openai_api_key = get_from_dict_or_env( openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY" values, "openai_api_key", "OPENAI_API_KEY"
) )
if openai_api_key is None or openai_api_key == "":
raise ValueError(
"Did not find OpenAI API key, please add an environment variable"
" `OPENAI_API_KEY` which contains it, or pass `openai_api_key`"
" as a named parameter."
)
try: try:
import openai import openai

@ -1,16 +1,8 @@
"""Common utility functions for working with LLM APIs.""" """Common utility functions for working with LLM APIs."""
import os
import re import re
from typing import Any, Dict, List from typing import List
def enforce_stop_tokens(text: str, stop: List[str]) -> str: def enforce_stop_tokens(text: str, stop: List[str]) -> str:
"""Cut off the text as soon as any stop words occur.""" """Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text)[0] return re.split("|".join(stop), text)[0]
def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> Any:
"""Get a value from a dictionary or an environment variable."""
if key in data and data[key]:
return data[key]
return os.environ.get(env_key, None)

@ -0,0 +1,17 @@
"""Generic utility functions."""
import os
from typing import Any, Dict
def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> str:
"""Get a value from a dictionary or an environment variable."""
if key in data and data[key]:
return data[key]
elif env_key in os.environ and os.environ[env_key]:
return os.environ[env_key]
else:
raise ValueError(
f"Did not find {key}, please add an environment variable"
f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter."
)

@ -1,10 +1,10 @@
"""Wrapper around Elasticsearch vector database.""" """Wrapper around Elasticsearch vector database."""
import os
import uuid import uuid
from typing import Any, Callable, Dict, List from typing import Any, Callable, Dict, List
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
@ -107,16 +107,9 @@ class ElasticVectorSearch(VectorStore):
elasticsearch_url="http://localhost:9200" elasticsearch_url="http://localhost:9200"
) )
""" """
elasticsearch_url = kwargs.get("elasticsearch_url") elasticsearch_url = get_from_dict_or_env(
if not elasticsearch_url: kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
elasticsearch_url = os.environ.get("ELASTICSEARCH_URL") )
if elasticsearch_url is None or elasticsearch_url == "":
raise ValueError(
"Did not find Elasticsearch URL, please add an environment variable"
" `ELASTICSEARCH_URL` which contains it, or pass"
" `elasticsearch_url` as a named parameter."
)
try: try:
import elasticsearch import elasticsearch
from elasticsearch.helpers import bulk from elasticsearch.helpers import bulk

Loading…
Cancel
Save