community[patch]: OllamaEmbeddings - Pass headers to post request (#16880)

## Feature
- Set additional headers in constructor
- Headers will be sent in post request

This feature is useful if deploying Ollama on a cloud service such as
hugging face, which requires authentication tokens to be passed in the
request header.

## Tests
- Test if header is passed
- Test if header is not passed

Similar to https://github.com/langchain-ai/langchain/pull/15881

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
shahrin014 2024-03-30 03:44:52 +09:00 committed by GitHub
parent e0f137dbe0
commit f51e6a35ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 70 additions and 2 deletions

View File

@ -105,6 +105,12 @@ class OllamaEmbeddings(BaseModel, Embeddings):
show_progress: bool = False
"""Whether to show a tqdm progress bar. Must have `tqdm` installed."""
headers: Optional[dict] = None
"""Additional headers to pass to endpoint (e.g. Authorization, Referer).
This is useful when Ollama is hosted on cloud services that require
tokens for authentication.
"""
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Ollama."""
@ -151,6 +157,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
"""
headers = {
"Content-Type": "application/json",
**(self.headers or {}),
}
try:

View File

@ -0,0 +1,61 @@
import requests
from pytest import MonkeyPatch
from langchain_community.embeddings.ollama import OllamaEmbeddings
class MockResponse:
status_code = 200
def json(self) -> dict:
return {"embedding": [1, 2, 3]}
def mock_response() -> MockResponse:
return MockResponse()
def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
embedder = OllamaEmbeddings(
base_url="https://ollama-hostname:8000",
model="foo",
headers={
"Authorization": "Bearer TEST-TOKEN-VALUE",
"Referer": "https://application-host",
},
)
def mock_post(url: str, headers: dict, json: str) -> MockResponse:
assert url == "https://ollama-hostname:8000/api/embeddings"
assert headers == {
"Content-Type": "application/json",
"Authorization": "Bearer TEST-TOKEN-VALUE",
"Referer": "https://application-host",
}
assert json is not None
return mock_response()
monkeypatch.setattr(requests, "post", mock_post)
embedder.embed_query("Test prompt")
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
embedder = OllamaEmbeddings(
base_url="https://ollama-hostname:8000",
model="foo",
)
def mock_post(url: str, headers: dict, json: str) -> MockResponse:
assert url == "https://ollama-hostname:8000/api/embeddings"
assert headers == {
"Content-Type": "application/json",
}
assert json is not None
return mock_response()
monkeypatch.setattr(requests, "post", mock_post)
embedder.embed_query("Test prompt")

View File

@ -25,7 +25,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
base_url="https://ollama-hostname:8000",
model="foo",
headers={
"Authentication": "Bearer TEST-TOKEN-VALUE",
"Authorization": "Bearer TEST-TOKEN-VALUE",
"Referer": "https://application-host",
},
timeout=300,
@ -35,7 +35,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
assert url == "https://ollama-hostname:8000/api/generate"
assert headers == {
"Content-Type": "application/json",
"Authentication": "Bearer TEST-TOKEN-VALUE",
"Authorization": "Bearer TEST-TOKEN-VALUE",
"Referer": "https://application-host",
}
assert json is not None