community[minor]: passthrough auth parameter on requests to Ollama-LLMs (#24068)

Thank you for contributing to LangChain!

**Description:**
This PR allows users of `langchain_community.llms.ollama.Ollama` to
specify the `auth` parameter, which is then forwarded to all internal
calls of `requests.request`. This works in the same way as the existing
`headers` parameters. The auth parameter enables the usage of the given
class with Ollama instances, which are secured by more complex
authentication mechanisms, that do not only rely on static headers. An
example are AWS API Gateways secured by the IAM authorizer, which
expects signatures dynamically calculated on the specific HTTP request.

**Issue:**

Integrating a remote LLM running through Ollama using
`langchain_community.llms.ollama.Ollama` only allows setting static HTTP
headers with the parameter `headers`. This does not work, if the given
instance of Ollama is secured with an authentication mechanism that
makes use of dynamically created HTTP headers which for example may
depend on the content of a given request.

**Dependencies:**

None

**Twitter handle:**

None

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
This commit is contained in:
rick-SOPTIM 2024-07-25 17:48:35 +02:00 committed by GitHub
parent 256bad3251
commit cd563fb628
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 50 additions and 6 deletions

View File

@ -1,5 +1,18 @@
from __future__ import annotations
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Union
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
Union,
)
import aiohttp
import requests
@ -132,6 +145,10 @@ class _OllamaCommon(BaseLanguageModel):
tokens for authentication.
"""
auth: Union[Callable, Tuple, None] = None
"""Additional auth tuple or callable to enable Basic/Digest/Custom HTTP Auth.
Expects the same format, type and values as requests.request auth parameter."""
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama."""
@ -237,6 +254,7 @@ class _OllamaCommon(BaseLanguageModel):
"Content-Type": "application/json",
**(self.headers if isinstance(self.headers, dict) else {}),
},
auth=self.auth,
json=request_payload,
stream=True,
timeout=self.timeout,
@ -300,6 +318,7 @@ class _OllamaCommon(BaseLanguageModel):
"Content-Type": "application/json",
**(self.headers if isinstance(self.headers, dict) else {}),
},
auth=self.auth,
json=request_payload,
timeout=self.timeout,
) as response:

View File

@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300,
)
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -49,10 +49,35 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
llm.invoke("Test prompt")
def test_pass_auth_if_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(
base_url="https://ollama-hostname:8000",
model="foo",
auth=("Test-User", "Test-Password"),
timeout=300,
)
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
}
assert json is not None
assert stream is True
assert timeout == 300
assert auth == ("Test-User", "Test-Password")
return mock_response_stream()
monkeypatch.setattr(requests, "post", mock_post)
llm.invoke("Test prompt")
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -72,7 +97,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
"""Test that top level params are sent to the endpoint as top level params"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -120,7 +145,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
@ -169,7 +194,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
def mock_post(url, headers, json, stream, timeout, auth): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",