From b82644078efcdcfce99e66bb6de79cf646ac2e0b Mon Sep 17 00:00:00 2001 From: gonvee Date: Tue, 19 Mar 2024 12:29:01 +0800 Subject: [PATCH] =?UTF-8?q?community:=20Add=20`keep=5Falive`=20parameter?= =?UTF-8?q?=20to=20control=20how=20long=20the=20model=20w=E2=80=A6=20(#190?= =?UTF-8?q?05)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `keep_alive` parameter to control how long the model will stay loaded into memory with Ollama. --------- Co-authored-by: Bagatur --- libs/community/langchain_community/llms/ollama.py | 15 ++++++++++++++- .../tests/unit_tests/llms/test_ollama.py | 3 +++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/libs/community/langchain_community/llms/ollama.py b/libs/community/langchain_community/llms/ollama.py index 2dc198fec1..c4747a4ebb 100644 --- a/libs/community/langchain_community/llms/ollama.py +++ b/libs/community/langchain_community/llms/ollama.py @@ -1,5 +1,5 @@ import json -from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional +from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, Union import aiohttp import requests @@ -111,6 +111,18 @@ class _OllamaCommon(BaseLanguageModel): timeout: Optional[int] = None """Timeout for the request stream""" + keep_alive: Optional[Union[int, str]] = None + """How long the model will stay loaded into memory. + + The parameter (Default: 5 minutes) can be set to: + 1. a duration string in Golang (such as "10m" or "24h"); + 2. a number in seconds (such as 3600); + 3. any negative number which will keep the model loaded \ + in memory (e.g. -1 or "-1m"); + 4. 0 which will unload the model immediately after generating a response; + + See the [Ollama documents](https://github.com/ollama/ollama/blob/main/docs/faq.md#how-do-i-keep-a-model-loaded-in-memory-or-make-it-unload-immediately)""" + headers: Optional[dict] = None """Additional headers to pass to endpoint (e.g. Authorization, Referer). 
This is useful when Ollama is hosted on cloud services that require @@ -141,6 +153,7 @@ class _OllamaCommon(BaseLanguageModel): }, "system": self.system, "template": self.template, + "keep_alive": self.keep_alive, } @property diff --git a/libs/community/tests/unit_tests/llms/test_ollama.py b/libs/community/tests/unit_tests/llms/test_ollama.py index 3b1798fd2e..8807aab826 100644 --- a/libs/community/tests/unit_tests/llms/test_ollama.py +++ b/libs/community/tests/unit_tests/llms/test_ollama.py @@ -100,6 +100,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None: "prompt": "Test prompt", "system": "Test system prompt", "template": None, + "keep_alive": None, } assert stream is True assert timeout == 300 @@ -147,6 +148,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None: "prompt": "Test prompt", "system": None, "template": None, + "keep_alive": None, } assert stream is True assert timeout == 300 @@ -178,6 +180,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None: "prompt": "Test prompt", "system": None, "template": None, + "keep_alive": None, } assert stream is True assert timeout == 300