Harrison/cleanup env check (#144)

harrison/prompts_take_2
Harrison Chase 2 years ago committed by GitHub
parent a4b502d92f
commit b504cd739f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.utils import get_from_dict_or_env
class HiddenPrints:
@ -43,7 +44,7 @@ class SerpAPIChain(Chain, BaseModel):
input_key: str = "search_query" #: :meta private:
output_key: str = "search_result" #: :meta private:
serpapi_api_key: Optional[str] = os.environ.get("SERPAPI_API_KEY")
serpapi_api_key: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
@ -69,14 +70,10 @@ class SerpAPIChain(Chain, BaseModel):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
serpapi_api_key = values.get("serpapi_api_key")
if serpapi_api_key is None or serpapi_api_key == "":
raise ValueError(
"Did not find SerpAPI API key, please add an environment variable"
" `SERPAPI_API_KEY` which contains it, or pass `serpapi_api_key` "
"as a named parameter to the constructor."
)
serpapi_api_key = get_from_dict_or_env(
values, "serpapi_api_key", "SERPAPI_API_KEY"
)
values["serpapi_api_key"] = serpapi_api_key
try:
from serpapi import GoogleSearch

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.llms.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env
class CohereEmbeddings(BaseModel, Embeddings):
@ -38,13 +38,6 @@ class CohereEmbeddings(BaseModel, Embeddings):
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
if cohere_api_key is None or cohere_api_key == "":
raise ValueError(
"Did not find Cohere API key, please add an environment variable"
" `COHERE_API_KEY` which contains it, or pass `cohere_api_key` as a"
" named parameter."
)
try:
import cohere

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.embeddings.base import Embeddings
from langchain.llms.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env
class OpenAIEmbeddings(BaseModel, Embeddings):
@ -38,13 +38,6 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
if openai_api_key is None or openai_api_key == "":
raise ValueError(
"Did not find OpenAI API key, please add an environment variable"
" `OPENAI_API_KEY` which contains it, or pass `openai_api_key` as a"
" named parameter."
)
try:
import openai

@ -5,7 +5,7 @@ import requests
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env
class AI21PenaltyData(BaseModel):
@ -73,12 +73,7 @@ class AI21(BaseModel, LLM):
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key exists in environment."""
ai21_api_key = get_from_dict_or_env(values, "ai21_api_key", "AI21_API_KEY")
if ai21_api_key is None or ai21_api_key == "":
raise ValueError(
"Did not find AI21 API key, please add an environment variable"
" `AI21_API_KEY` which contains it, or pass `ai21_api_key`"
" as a named parameter."
)
values["ai21_api_key"] = ai21_api_key
return values
@property

@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
class Cohere(LLM, BaseModel):
@ -56,13 +57,6 @@ class Cohere(LLM, BaseModel):
cohere_api_key = get_from_dict_or_env(
values, "cohere_api_key", "COHERE_API_KEY"
)
if cohere_api_key is None or cohere_api_key == "":
raise ValueError(
"Did not find Cohere API key, please add an environment variable"
" `COHERE_API_KEY` which contains it, or pass `cohere_api_key`"
" as a named parameter."
)
try:
import cohere

@ -4,7 +4,8 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens, get_from_dict_or_env
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
DEFAULT_REPO_ID = "gpt2"
VALID_TASKS = ("text2text-generation", "text-generation")
@ -47,12 +48,6 @@ class HuggingFaceHub(LLM, BaseModel):
huggingfacehub_api_token = get_from_dict_or_env(
values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
)
if huggingfacehub_api_token is None or huggingfacehub_api_token == "":
raise ValueError(
"Did not find HuggingFace API token, please add an environment variable"
" `HUGGINGFACEHUB_API_TOKEN` which contains it, or pass"
" `huggingfacehub_api_token` as a named parameter."
)
try:
from huggingface_hub.inference_api import InferenceApi

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env
class NLPCloud(LLM, BaseModel):
@ -67,13 +67,6 @@ class NLPCloud(LLM, BaseModel):
nlpcloud_api_key = get_from_dict_or_env(
values, "nlpcloud_api_key", "NLPCLOUD_API_KEY"
)
if nlpcloud_api_key is None or nlpcloud_api_key == "":
raise ValueError(
"Did not find NLPCloud API key, please add an environment variable"
" `NLPCLOUD_API_KEY` which contains it, or pass `nlpcloud_api_key`"
" as a named parameter."
)
try:
import nlpcloud

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, root_validator
from langchain.llms.base import LLM
from langchain.llms.utils import get_from_dict_or_env
from langchain.utils import get_from_dict_or_env
class OpenAI(LLM, BaseModel):
@ -51,13 +51,6 @@ class OpenAI(LLM, BaseModel):
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
if openai_api_key is None or openai_api_key == "":
raise ValueError(
"Did not find OpenAI API key, please add an environment variable"
" `OPENAI_API_KEY` which contains it, or pass `openai_api_key`"
" as a named parameter."
)
try:
import openai

@ -1,16 +1,8 @@
"""Common utility functions for working with LLM APIs."""
import os
import re
from typing import Any, Dict, List
from typing import List
def enforce_stop_tokens(text: str, stop: List[str]) -> str:
"""Cut off the text as soon as any stop words occur."""
return re.split("|".join(stop), text)[0]
def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> Any:
"""Get a value from a dictionary or an environment variable."""
if key in data and data[key]:
return data[key]
return os.environ.get(env_key, None)

@ -0,0 +1,17 @@
"""Generic utility functions."""
import os
from typing import Any, Dict
def get_from_dict_or_env(data: Dict[str, Any], key: str, env_key: str) -> str:
"""Get a value from a dictionary or an environment variable."""
if key in data and data[key]:
return data[key]
elif env_key in os.environ and os.environ[env_key]:
return os.environ[env_key]
else:
raise ValueError(
f"Did not find {key}, please add an environment variable"
f" `{env_key}` which contains it, or pass"
f" `{key}` as a named parameter."
)

@ -1,10 +1,10 @@
"""Wrapper around Elasticsearch vector database."""
import os
import uuid
from typing import Any, Callable, Dict, List
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
from langchain.vectorstores.base import VectorStore
@ -107,16 +107,9 @@ class ElasticVectorSearch(VectorStore):
elasticsearch_url="http://localhost:9200"
)
"""
elasticsearch_url = kwargs.get("elasticsearch_url")
if not elasticsearch_url:
elasticsearch_url = os.environ.get("ELASTICSEARCH_URL")
if elasticsearch_url is None or elasticsearch_url == "":
raise ValueError(
"Did not find Elasticsearch URL, please add an environment variable"
" `ELASTICSEARCH_URL` which contains it, or pass"
" `elasticsearch_url` as a named parameter."
)
elasticsearch_url = get_from_dict_or_env(
kwargs, "elasticsearch_url", "ELASTICSEARCH_URL"
)
try:
import elasticsearch
from elasticsearch.helpers import bulk

Loading…
Cancel
Save