import json
from collections import deque
from typing import Any, Dict

import pytest
import requests
from pytest import MonkeyPatch

from langchain_community.llms.llamafile import Llamafile
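
# NOTE: these values are assumed to mirror the `Llamafile` class attribute
# defaults (the tests below construct the LLM with only `base_url` and assert
# that the request payload equals this dict); if the class defaults change,
# this helper must be updated to match.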
def default_generation_params() -> Dict[str, Any]:
    return {
        "temperature": 0.8,
        "seed": -1,
        "top_k": 40,
        "top_p": 0.95,
        "min_p": 0.05,
        "n_predict": -1,
        "n_keep": 0,
        "tfs_z": 1.0,
        "typical_p": 1.0,
        "repeat_penalty": 1.1,
        "repeat_last_n": 64,
        "penalize_nl": True,
        "presence_penalty": 0.0,
        "frequency_penalty": 0.0,
        "mirostat": 0,
        "mirostat_tau": 5.0,
        "mirostat_eta": 0.1,
    }
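
# Canned non-streaming response: `requests.Response` backs `.content`/`.text`
# with the private `_content` attribute, so setting it directly is enough for
# the code under test to read the body.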
def mock_response() -> requests.Response:
    contents = json.dumps({"content": "the quick brown fox"})
    response = requests.Response()
    response.status_code = 200
    response._content = str.encode(contents)
    return response
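
# Canned streaming response: the llamafile server emits completions as
# server-sent-event lines (`data: {...}\n\n`). When a streamed body is
# consumed, `requests` falls back to calling `response.raw.read(chunk_size)`,
# so a minimal object whose `read()` pops one chunk per call (and returns
# None when exhausted) is enough to drive the client's line iteration.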
def mock_response_stream():  # type: ignore[no-untyped-def]
    # queue of pending SSE chunks, popped one per raw.read() call; named so it
    # does not shadow the `mock_response` helper above
    mock_chunks = deque(
        [
            b'data: {"content":"the","multimodal":false,"slot_id":0,"stop":false}\n\n',  # noqa
            b'data: {"content":" quick","multimodal":false,"slot_id":0,"stop":false}\n\n',  # noqa
        ]
    )

    class MockRaw:
        def read(self, chunk_size):  # type: ignore[no-untyped-def]
            try:
                return mock_chunks.popleft()
            except IndexError:
                return None

    response = requests.Response()
    response.status_code = 200
    response.raw = MockRaw()
    return response
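
# Each test below monkeypatches `requests.post` with a stub that asserts on
# the exact request the Llamafile wrapper sends, then returns one of the
# canned responses above.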
def test_call(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of the `invoke` method.
    """
    llm = Llamafile(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/completion"
        assert headers == {
            "Content-Type": "application/json",
        }
        # payload should be the prompt plus all default generation params
        assert json == {"prompt": "Test prompt", **default_generation_params()}
        assert stream is False
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = llm.invoke("Test prompt")
    assert out == "the quick brown fox"
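
# Call-time kwargs are merged over the class defaults when the payload is
# built; names that are not known generation parameters are dropped rather
# than forwarded to the server, as the stub below verifies.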
def test_call_with_kwargs(monkeypatch: MonkeyPatch) -> None:
    """
    Test kwargs passed to `invoke` override the default values and are passed
    to the endpoint correctly. Also test that any 'unknown' kwargs that are not
    present in the LLM class attrs are ignored.
    """
    llm = Llamafile(
        base_url="http://llamafile-host:8080",
    )

    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-host:8080/completion"
        assert headers == {
            "Content-Type": "application/json",
        }
        # 'unknown' kwarg should be ignored; 'seed' should override the default
        expected = {"prompt": "Test prompt", **default_generation_params()}
        expected["seed"] = 0
        assert json == expected
        assert stream is False
        assert timeout is None
        return mock_response()

    monkeypatch.setattr(requests, "post", mock_post)
    out = llm.invoke(
        "Test prompt",
        unknown="unknown option",  # should be ignored
        seed=0,  # should override the default
    )
    assert out == "the quick brown fox"
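
# No monkeypatching here: the request goes out for real, and since nothing is
# listening at the made-up base_url, `requests` raises a ConnectionError that
# the wrapper is expected to propagate.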
def test_call_raises_exception_on_missing_server(monkeypatch: MonkeyPatch) -> None:
    """
    Test that the LLM raises a ConnectionError when no llamafile server is
    listening at the base_url.
    """
    llm = Llamafile(
        # invalid url, nothing should actually be running here
        base_url="http://llamafile-host:8080",
    )
    with pytest.raises(requests.exceptions.ConnectionError):
        llm.invoke("Test prompt")
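
# With `streaming=True`, the wrapper sets `"stream": true` in the payload,
# passes `stream=True` to `requests.post`, and concatenates the `content`
# fields of the SSE chunks, so the two chunks from `mock_response_stream()`
# yield "the quick".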
def test_streaming(monkeypatch: MonkeyPatch) -> None:
    """
    Test basic functionality of `invoke` with streaming enabled.
    """
    llm = Llamafile(
        base_url="http://llamafile-hostname:8080",
        streaming=True,
    )

    def mock_post(url, headers, json, stream, timeout):  # type: ignore[no-untyped-def]
        assert url == "http://llamafile-hostname:8080/completion"
        assert headers == {
            "Content-Type": "application/json",
        }
        # no stray kwargs should be present in the payload
        assert "unknown" not in json
        expected = {"prompt": "Test prompt", **default_generation_params()}
        expected["stream"] = True
        assert json == expected
        assert stream is True
        assert timeout is None

        return mock_response_stream()

    monkeypatch.setattr(requests, "post", mock_post)
    out = llm.invoke("Test prompt")
    assert out == "the quick"