From b109cb031be390a029eec1657662f4b57f7e7656 Mon Sep 17 00:00:00 2001
From: Prabin Nepal <43682497+nepalprabin@users.noreply.github.com>
Date: Mon, 30 Oct 2023 20:17:53 -0500
Subject: [PATCH] SecretStr for fireworks api (#12475)

- **Description:** This pull request stops the Fireworks API key from being stored and displayed in raw format by wrapping it in a pydantic `SecretStr`.
- **Issue:** The Fireworks API key was exposed when printing out the LangChain object [#12165](https://github.com/langchain-ai/langchain/issues/12165)
- **Maintainer:** @eyurtsev

---------

Co-authored-by: Bagatur
---
 .../langchain/chat_models/fireworks.py        | 11 ++++---
 libs/langchain/langchain/llms/fireworks.py    | 11 ++++---
 libs/langchain/poetry.lock                    | 32 +++++++++++++++++--
 libs/langchain/pyproject.toml                 |  2 ++
 .../unit_tests/chat_models/test_fireworks.py  | 28 ++++++++++++++++
 .../tests/unit_tests/llms/test_fireworks.py   | 28 ++++++++++++++++
 6 files changed, 100 insertions(+), 12 deletions(-)
 create mode 100644 libs/langchain/tests/unit_tests/chat_models/test_fireworks.py
 create mode 100644 libs/langchain/tests/unit_tests/llms/test_fireworks.py

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index ce0ce83b59..1bc35ca42e 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -17,7 +17,7 @@ from langchain.callbacks.manager import (
 )
 from langchain.chat_models.base import BaseChatModel
 from langchain.llms.base import create_base_retry_decorator
-from langchain.pydantic_v1 import Field, root_validator
+from langchain.pydantic_v1 import Field, SecretStr, root_validator
 from langchain.schema.messages import (
     AIMessage,
     AIMessageChunk,
@@ -33,6 +33,7 @@ from langchain.schema.messages import (
     SystemMessageChunk,
 )
 from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain.utils import convert_to_secret_str
 from langchain.utils.env import get_from_dict_or_env
 
 
@@ -87,7 +88,7 @@ class ChatFireworks(BaseChatModel):
             "top_p": 1,
         }.copy()
     )
-    fireworks_api_key: Optional[str] = None
+    fireworks_api_key: Optional[SecretStr] = None
     max_retries: int = 20
     use_retry: bool = True
 
@@ -109,10 +110,10 @@ class ChatFireworks(BaseChatModel):
                 "Could not import fireworks-ai python package. "
                 "Please install it with `pip install fireworks-ai`."
             ) from e
-        fireworks_api_key = get_from_dict_or_env(
-            values, "fireworks_api_key", "FIREWORKS_API_KEY"
+        fireworks_api_key = convert_to_secret_str(
+            get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
         )
-        fireworks.client.api_key = fireworks_api_key
+        fireworks.client.api_key = fireworks_api_key.get_secret_value()
         return values
 
     @property
diff --git a/libs/langchain/langchain/llms/fireworks.py b/libs/langchain/langchain/llms/fireworks.py
index 676c8813b0..72167d83d4 100644
--- a/libs/langchain/langchain/llms/fireworks.py
+++ b/libs/langchain/langchain/llms/fireworks.py
@@ -7,8 +7,9 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
 )
 from langchain.llms.base import BaseLLM, create_base_retry_decorator
-from langchain.pydantic_v1 import Field, root_validator
+from langchain.pydantic_v1 import Field, SecretStr, root_validator
 from langchain.schema.output import Generation, GenerationChunk, LLMResult
+from langchain.utils import convert_to_secret_str
 from langchain.utils.env import get_from_dict_or_env
 
 
@@ -36,7 +37,7 @@ class Fireworks(BaseLLM):
             "top_p": 1,
         }.copy()
     )
-    fireworks_api_key: Optional[str] = None
+    fireworks_api_key: Optional[SecretStr] = None
     max_retries: int = 20
     batch_size: int = 20
     use_retry: bool = True
@@ -59,10 +60,10 @@ class Fireworks(BaseLLM):
                 "Could not import fireworks-ai python package. "
                 "Please install it with `pip install fireworks-ai`."
             ) from e
-        fireworks_api_key = get_from_dict_or_env(
-            values, "fireworks_api_key", "FIREWORKS_API_KEY"
+        fireworks_api_key = convert_to_secret_str(
+            get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
         )
-        fireworks.client.api_key = fireworks_api_key
+        fireworks.client.api_key = fireworks_api_key.get_secret_value()
         return values
 
     @property
diff --git a/libs/langchain/poetry.lock b/libs/langchain/poetry.lock
index 57a42cdf58..38f0807daf 100644
--- a/libs/langchain/poetry.lock
+++ b/libs/langchain/poetry.lock
@@ -2415,6 +2415,23 @@ calc = ["shapely"]
 s3 = ["boto3 (>=1.3.1)"]
 test = ["Fiona[s3]", "pytest (>=7)", "pytest-cov", "pytz"]
 
+[[package]]
+name = "fireworks-ai"
+version = "0.6.0"
+description = "Python client library for the Fireworks.ai Generative AI Platform"
+optional = true
+python-versions = ">=3.9"
+files = [
+    {file = "fireworks-ai-0.6.0.tar.gz", hash = "sha256:815b933f6236e3da9c85fea1c51a6b9dd3673110877f6c7f69ba978fdfb1f0f0"},
+    {file = "fireworks_ai-0.6.0-py3-none-any.whl", hash = "sha256:f7f32ab10131a7897c5e78dd69531de30af21a6899c10f5a7e405f01b57b6432"},
+]
+
+[package.dependencies]
+httpx = "*"
+httpx-sse = "*"
+Pillow = "*"
+pydantic = "*"
+
 [[package]]
 name = "flatbuffers"
 version = "23.5.26"
@@ -3268,6 +3285,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
 http2 = ["h2 (>=3,<5)"]
 socks = ["socksio (==1.*)"]
 
+[[package]]
+name = "httpx-sse"
+version = "0.3.1"
+description = "Consume Server-Sent Event (SSE) messages with HTTPX."
+optional = true
+python-versions = ">=3.7"
+files = [
+    {file = "httpx-sse-0.3.1.tar.gz", hash = "sha256:3bb3289b2867f50cbdb2fee3eeeefecb1e86653122e164faac0023f1ffc88aea"},
+    {file = "httpx_sse-0.3.1-py3-none-any.whl", hash = "sha256:7376dd88732892f9b6b549ac0ad05a8e2341172fe7dcf9f8f9c8050934297316"},
+]
+
 [[package]]
 name = "huggingface-hub"
 version = "0.18.0"
@@ -11042,7 +11070,7 @@ cli = ["typer"]
 cohere = ["cohere"]
 docarray = ["docarray"]
 embeddings = ["sentence-transformers"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "dashvector", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "html2text", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "openai", "openai", "openapi-pydantic", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict"]
 javascript = ["esprima"]
 llms = ["clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openlm", "torch", "transformers"]
 openai = ["openai", "tiktoken"]
@@ -11052,4 +11080,4 @@ text-helpers = ["chardet"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "9ffdcad5f675571917ffb0f222acdb578f406939695977da3c19e55192cac513"
+content-hash = "d9f38367a43153fda9eaea5e8fd03acfc0a06006dfecc3693dffb212249197bc"
diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml
index a42a25525a..86e4e79da1 100644
--- a/libs/langchain/pyproject.toml
+++ b/libs/langchain/pyproject.toml
@@ -139,6 +139,7 @@ aiosqlite = {version = "^0.19.0", optional = true}
 rspace_client = {version = "^2.5.0", optional = true}
 upstash-redis = {version = "^0.15.0", optional = true}
 google-cloud-documentai = {version = "^2.20.1", optional = true}
+fireworks-ai = {version = "^0.6.0", optional = true, python = ">=3.9,<4.0"}
 
 
 [tool.poetry.group.test.dependencies]
@@ -372,6 +373,7 @@ extended_testing = [
     "anthropic",
     "upstash-redis",
     "rspace_client",
+    "fireworks-ai",
 ]
 
 [tool.ruff]
diff --git a/libs/langchain/tests/unit_tests/chat_models/test_fireworks.py b/libs/langchain/tests/unit_tests/chat_models/test_fireworks.py
new file mode 100644
index 0000000000..8db4f9653a
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/chat_models/test_fireworks.py
@@ -0,0 +1,28 @@
+"""Test Fireworks chat model"""
+import sys
+
+import pytest
+from pytest import CaptureFixture
+
+from langchain.chat_models import ChatFireworks
+from langchain.pydantic_v1 import SecretStr
+
+if sys.version_info < (3, 9):
+    pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True)
+
+
+@pytest.mark.requires("fireworks")
+def test_api_key_is_string() -> None:
+    llm = ChatFireworks(fireworks_api_key="secret-api-key")
+    assert isinstance(llm.fireworks_api_key, SecretStr)
+
+
+@pytest.mark.requires("fireworks")
+def test_api_key_masked_when_passed_via_constructor(
+    capsys: CaptureFixture,
+) -> None:
+    llm = ChatFireworks(fireworks_api_key="secret-api-key")
+    print(llm.fireworks_api_key, end="")
+    captured = capsys.readouterr()
+
+    assert captured.out == "**********"
diff --git a/libs/langchain/tests/unit_tests/llms/test_fireworks.py b/libs/langchain/tests/unit_tests/llms/test_fireworks.py
new file mode 100644
index 0000000000..cdfe6cfe64
--- /dev/null
+++ b/libs/langchain/tests/unit_tests/llms/test_fireworks.py
@@ -0,0 +1,28 @@
+"""Test Fireworks LLM"""
+import sys
+
+import pytest
+from pytest import CaptureFixture
+
+from langchain.llms import Fireworks
+from langchain.pydantic_v1 import SecretStr
+
+if sys.version_info < (3, 9):
+    pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True)
+
+
+@pytest.mark.requires("fireworks")
+def test_api_key_is_string() -> None:
+    llm = Fireworks(fireworks_api_key="secret-api-key")
+    assert isinstance(llm.fireworks_api_key, SecretStr)
+
+
+@pytest.mark.requires("fireworks")
+def test_api_key_masked_when_passed_via_constructor(
+    capsys: CaptureFixture,
+) -> None:
+    llm = Fireworks(fireworks_api_key="secret-api-key")
+    print(llm.fireworks_api_key, end="")
+    captured = capsys.readouterr()
+
+    assert captured.out == "**********"
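
Note on the pattern (appended for clarity; not part of the patch above): both `ChatFireworks` and `Fireworks` now keep the key as a pydantic `SecretStr`, coerce whatever arrives from the constructor or the `FIREWORKS_API_KEY` environment variable with `convert_to_secret_str`, and only unwrap it with `.get_secret_value()` at the single point where `fireworks.client.api_key` is set. The sketch below is a minimal, self-contained approximation of that flow, assuming a pydantic v1 environment (which is what `langchain.pydantic_v1` re-exports); `FireworksLikeModel` and the local `convert_to_secret_str` helper are illustrative stand-ins, not the real LangChain classes.

```python
"""Minimal sketch of the SecretStr pattern this patch applies (illustrative only)."""
import os
from typing import Any, Dict, Optional

# Assumes pydantic v1; langchain.pydantic_v1 exposes the same names.
from pydantic import BaseModel, SecretStr, root_validator


def convert_to_secret_str(value: Any) -> SecretStr:
    """Wrap a plain string in SecretStr (mirrors langchain.utils.convert_to_secret_str)."""
    if isinstance(value, SecretStr):
        return value
    return SecretStr(value)


class FireworksLikeModel(BaseModel):
    """Stand-in for ChatFireworks / Fireworks, showing only the API-key handling."""

    fireworks_api_key: Optional[SecretStr] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        # Take the key from the constructor kwarg or the environment, then wrap it
        # so reprs and logs only ever show a masked value.
        raw = values.get("fireworks_api_key") or os.environ.get("FIREWORKS_API_KEY")
        if raw is None:
            raise ValueError("Provide fireworks_api_key or set FIREWORKS_API_KEY.")
        values["fireworks_api_key"] = convert_to_secret_str(raw)
        # Unwrap only at the boundary where the real client needs the raw key, e.g.:
        # fireworks.client.api_key = values["fireworks_api_key"].get_secret_value()
        return values


if __name__ == "__main__":
    llm = FireworksLikeModel(fireworks_api_key="secret-api-key")
    print(llm.fireworks_api_key)                     # **********  (what the new tests assert)
    print(llm.fireworks_api_key.get_secret_value())  # secret-api-key
```

Against the actual classes, this is exactly what the two new unit tests verify: `print(llm.fireworks_api_key)` writes only `**********`, while the Fireworks client still receives the real key via `get_secret_value()`.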