From cd563fb628f274e1fe95408b03d613922a7148fb Mon Sep 17 00:00:00 2001 From: rick-SOPTIM <148056088+rick-SOPTIM@users.noreply.github.com> Date: Thu, 25 Jul 2024 17:48:35 +0200 Subject: [PATCH] community[minor]: passthrough auth parameter on requests to Ollama-LLMs (#24068) Thank you for contributing to LangChain! **Description:** This PR allows users of `langchain_community.llms.ollama.Ollama` to specify the `auth` parameter, which is then forwarded to all internal calls of `requests.request`. This works in the same way as the existing `headers` parameters. The auth parameter enables the usage of the given class with Ollama instances, which are secured by more complex authentication mechanisms, that do not only rely on static headers. An example are AWS API Gateways secured by the IAM authorizer, which expects signatures dynamically calculated on the specific HTTP request. **Issue:** Integrating a remote LLM running through Ollama using `langchain_community.llms.ollama.Ollama` only allows setting static HTTP headers with the parameter `headers`. This does not work, if the given instance of Ollama is secured with an authentication mechanism that makes use of dynamically created HTTP headers which for example may depend on the content of a given request. **Dependencies:** None **Twitter handle:** None --------- Co-authored-by: Eugene Yurtsev --- .../langchain_community/llms/ollama.py | 21 ++++++++++- .../tests/unit_tests/llms/test_ollama.py | 35 ++++++++++++++++--- 2 files changed, 50 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py index 01b1ce37e1..73bf0d8fba 100644 --- a/libs/community/langchain_community/llms/ollama.py +++ b/libs/community/langchain_community/llms/ollama.py @@ -1,5 +1,18 @@ +from __future__ import annotations + import json -from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Union +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + Union, +) import aiohttp import requests @@ -132,6 +145,10 @@ class _OllamaCommon(BaseLanguageModel): tokens for authentication. """ + auth: Union[Callable, Tuple, None] = None + """Additional auth tuple or callable to enable Basic/Digest/Custom HTTP Auth. + Expects the same format, type and values as requests.request auth parameter.""" + @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling Ollama.""" @@ -237,6 +254,7 @@ class _OllamaCommon(BaseLanguageModel): "Content-Type": "application/json", **(self.headers if isinstance(self.headers, dict) else {}), }, + auth=self.auth, json=request_payload, stream=True, timeout=self.timeout, @@ -300,6 +318,7 @@ class _OllamaCommon(BaseLanguageModel): "Content-Type": "application/json", **(self.headers if isinstance(self.headers, dict) else {}), }, + auth=self.auth, json=request_payload, timeout=self.timeout, ) as response: diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py index 6d6ce632a8..04da4d137e 100644 --- a/libs/community/tests/unit_tests/llms/test_ollama.py +++ b/libs/community/tests/unit_tests/llms/test_ollama.py @@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: timeout=300, ) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", @@ -49,10 +49,35 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: llm.invoke("Test prompt") +def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None: + llm = Ollama( + base_url="https://ollama-hostname:8000", + model="foo", + auth=("Test-User", "Test-Password"), + timeout=300, + ) + + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] + assert url == "https://ollama-hostname:8000/api/generate" + assert headers == { + "Content-Type": "application/json", + } + assert json is not None + assert stream is True + assert timeout == 300 + assert auth == ("Test-User", "Test-Password") + + return mock_response_stream() + + monkeypatch.setattr(requests, "post", mock_post) + + llm.invoke("Test prompt") + + def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", @@ -72,7 +97,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None: """Test that top level params are sent to the endpoint as top level params""" llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", @@ -120,7 +145,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None: """ llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", @@ -169,7 +194,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None: """ llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json",