community: add standard chat model params to Ollama (#22446)

ccurme 4 months ago committed by GitHub
parent 5119ab2fb9
commit afe89a1411

@@ -6,7 +6,7 @@ from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain_core.language_models.chat_models import BaseChatModel
+from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -69,6 +69,23 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
         """Return whether this model can be serialized by Langchain."""
         return False
 
+    def _get_ls_params(
+        self, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> LangSmithParams:
+        """Get standard params for tracing."""
+        params = self._get_invocation_params(stop=stop, **kwargs)
+        ls_params = LangSmithParams(
+            ls_provider="ollama",
+            ls_model_name=self.model,
+            ls_model_type="chat",
+            ls_temperature=params.get("temperature", self.temperature),
+        )
+        if ls_max_tokens := params.get("num_predict", self.num_predict):
+            ls_params["ls_max_tokens"] = ls_max_tokens
+        if ls_stop := stop or params.get("stop", None) or self.stop:
+            ls_params["ls_stop"] = ls_stop
+        return ls_params
+
     @deprecated("0.0.3", alternative="_convert_messages_to_ollama_messages")
     def _format_message_as_text(self, message: BaseMessage) -> str:
         if isinstance(message, ChatMessage):

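For reference, a minimal sketch of what the new method returns (values inferred from the diff above; `_get_ls_params` only reads the model's configured fields, so no running Ollama server is required):

from langchain_community.chat_models import ChatOllama

# Hypothetical configuration for illustration; any model name works here,
# since _get_ls_params never contacts the Ollama server.
model = ChatOllama(model="llama3", temperature=0.5, num_predict=256, stop=["\n\n"])

print(model._get_ls_params())
# Expected shape (inferred from the method body above):
# {'ls_provider': 'ollama', 'ls_model_name': 'llama3', 'ls_model_type': 'chat',
#  'ls_temperature': 0.5, 'ls_max_tokens': 256, 'ls_stop': ['\n\n']}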
@@ -0,0 +1,35 @@
+from typing import List, Literal, Optional
+
+import pytest
+from langchain_core.pydantic_v1 import BaseModel, ValidationError
+
+from langchain_community.chat_models import ChatOllama
+
+
+def test_standard_params() -> None:
+    class ExpectedParams(BaseModel):
+        ls_provider: str
+        ls_model_name: str
+        ls_model_type: Literal["chat"]
+        ls_temperature: Optional[float]
+        ls_max_tokens: Optional[int]
+        ls_stop: Optional[List[str]]
+
+    model = ChatOllama(model="llama3")
+    ls_params = model._get_ls_params()
+    try:
+        ExpectedParams(**ls_params)
+    except ValidationError as e:
+        pytest.fail(f"Validation error: {e}")
+    assert ls_params["ls_model_name"] == "llama3"
+
+    # Test optional params
+    model = ChatOllama(num_predict=10, stop=["test"], temperature=0.33)
+    ls_params = model._get_ls_params()
+    try:
+        ExpectedParams(**ls_params)
+    except ValidationError as e:
+        pytest.fail(f"Validation error: {e}")
+    assert ls_params["ls_max_tokens"] == 10
+    assert ls_params["ls_stop"] == ["test"]
+    assert ls_params["ls_temperature"] == 0.33