diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py index 01b1ce37e1..73bf0d8fba 100644 --- a/libs/community/langchain_community/llms/ollama.py +++ b/libs/community/langchain_community/llms/ollama.py @@ -1,5 +1,18 @@ +from __future__ import annotations + import json -from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Union +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Mapping, + Optional, + Tuple, + Union, +) import aiohttp import requests @@ -132,6 +145,10 @@ class _OllamaCommon(BaseLanguageModel): tokens for authentication. """ + auth: Union[Callable, Tuple, None] = None + """Additional auth tuple or callable to enable Basic/Digest/Custom HTTP Auth. + Expects the same format, type and values as requests.request auth parameter.""" + @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling Ollama.""" @@ -237,6 +254,7 @@ class _OllamaCommon(BaseLanguageModel): "Content-Type": "application/json", **(self.headers if isinstance(self.headers, dict) else {}), }, + auth=self.auth, json=request_payload, stream=True, timeout=self.timeout, @@ -300,6 +318,7 @@ class _OllamaCommon(BaseLanguageModel): "Content-Type": "application/json", **(self.headers if isinstance(self.headers, dict) else {}), }, + auth=self.auth, json=request_payload, timeout=self.timeout, ) as response: diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py index 6d6ce632a8..04da4d137e 100644 --- a/libs/community/tests/unit_tests/llms/test_ollama.py +++ b/libs/community/tests/unit_tests/llms/test_ollama.py @@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: timeout=300, ) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", @@ -49,10 +49,35 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None: llm.invoke("Test prompt") +def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None: + llm = Ollama( + base_url="https://ollama-hostname:8000", + model="foo", + auth=("Test-User", "Test-Password"), + timeout=300, + ) + + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] + assert url == "https://ollama-hostname:8000/api/generate" + assert headers == { + "Content-Type": "application/json", + } + assert json is not None + assert stream is True + assert timeout == 300 + assert auth == ("Test-User", "Test-Password") + + return mock_response_stream() + + monkeypatch.setattr(requests, "post", mock_post) + + llm.invoke("Test prompt") + + def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None: llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", @@ -72,7 +97,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None: """Test that top level params are sent to the endpoint as top level params""" llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", @@ -120,7 +145,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None: """ llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json", @@ -169,7 +194,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None: """ llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300) - def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def] + def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def] assert url == "https://ollama-hostname:8000/api/generate" assert headers == { "Content-Type": "application/json",