multiple: implement ls_params (#22621)

implement ls_params for ai21, fireworks, groq.
pull/22087/head
ccurme 4 months ago committed by GitHub
parent f26ab93df8
commit b57aa89f34
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -6,7 +6,7 @@ from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.chat_models import BaseChatModel, LangSmithParams
from langchain_core.messages import (
BaseMessage,
)
@ -113,6 +113,23 @@ class ChatAI21(BaseChatModel, AI21Base):
return base_params
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="ai21",
ls_model_name=self.model,
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None) or self.stop:
ls_params["ls_stop"] = ls_stop
return ls_params
def _build_params_for_request(
self,
messages: List[BaseMessage],

@ -21,17 +21,6 @@ class TestAI21J2(ChatModelUnitTests):
"api_key": "test_api_key",
}
@pytest.mark.xfail(reason="Not implemented.")
def test_standard_params(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_standard_params(
chat_model_class,
chat_model_params,
)
class TestAI21Jamba(ChatModelUnitTests):
@pytest.fixture
@ -44,14 +33,3 @@ class TestAI21Jamba(ChatModelUnitTests):
"model": "jamba-instruct",
"api_key": "test_api_key",
}
@pytest.mark.xfail(reason="Not implemented.")
def test_standard_params(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_standard_params(
chat_model_class,
chat_model_params,
)

@ -31,6 +31,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
agenerate_from_stream,
generate_from_stream,
)
@ -363,6 +364,23 @@ class ChatFireworks(BaseChatModel):
params["max_tokens"] = self.max_tokens
return params
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="fireworks",
ls_model_name=self.model_name,
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None):
ls_params["ls_stop"] = ls_stop
return ls_params
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
overall_token_usage: dict = {}
system_fingerprint = None

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "aiohttp"
@ -572,7 +572,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.2.0rc1"
version = "0.2.4"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -581,7 +581,7 @@ develop = true
[package.dependencies]
jsonpatch = "^1.33"
langsmith = "^0.1.0"
langsmith = "^0.1.66"
packaging = "^23.2"
pydantic = ">=1,<3"
PyYAML = ">=5.3"
@ -613,13 +613,13 @@ url = "../../standard-tests"
[[package]]
name = "langsmith"
version = "0.1.59"
version = "0.1.73"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.59-py3-none-any.whl", hash = "sha256:445e3bc1d3baa1e5340cd979907a19483b9763a2ed37b863a01113d406f69345"},
{file = "langsmith-0.1.59.tar.gz", hash = "sha256:e748a89f4dd6aa441349143e49e546c03b5dfb43376a25bfef6a5ca792fe1437"},
{file = "langsmith-0.1.73-py3-none-any.whl", hash = "sha256:38bfcce2cfcf0b2da2e9628b903c9e768e1ce59d450e8a584514c1638c595e93"},
{file = "langsmith-0.1.73.tar.gz", hash = "sha256:0055471cb1fddb76ec65499716764ad0b0314affbdf33ff1f72ad5e2d6a3b224"},
]
[package.dependencies]
@ -1551,4 +1551,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "d1856a2ed3f8cd8e735458b0169d2144683b31ec48bef606f52276db54ad4fa0"
content-hash = "1bd993cb034f7eeb243d4c0861075008065d31c6c707aeb2e99c6214d72fb409"

@ -12,7 +12,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.1.52,<0.3"
langchain-core = ">=0.2.0,<0.3"
fireworks-ai = ">=0.13.0"
openai = "^1.10.0"
requests = "^2"

@ -19,14 +19,3 @@ class TestFireworksStandard(ChatModelUnitTests):
return {
"api_key": "test_api_key",
}
@pytest.mark.xfail(reason="Not implemented.")
def test_standard_params(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_standard_params(
chat_model_class,
chat_model_params,
)

@ -31,6 +31,7 @@ from langchain_core.callbacks import (
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
LangSmithParams,
agenerate_from_stream,
generate_from_stream,
)
@ -232,6 +233,23 @@ class ChatGroq(BaseChatModel):
"""Return type of model."""
return "groq-chat"
def _get_ls_params(
self, stop: Optional[List[str]] = None, **kwargs: Any
) -> LangSmithParams:
"""Get standard params for tracing."""
params = self._get_invocation_params(stop=stop, **kwargs)
ls_params = LangSmithParams(
ls_provider="groq",
ls_model_name=self.model_name,
ls_model_type="chat",
ls_temperature=params.get("temperature", self.temperature),
)
if ls_max_tokens := params.get("max_tokens", self.max_tokens):
ls_params["ls_max_tokens"] = ls_max_tokens
if ls_stop := stop or params.get("stop", None) or self.stop:
ls_params["ls_stop"] = ls_stop
return ls_params
def _generate(
self,
messages: List[BaseMessage],

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -323,7 +323,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.2.0rc1"
version = "0.2.4"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -332,7 +332,7 @@ develop = true
[package.dependencies]
jsonpatch = "^1.33"
langsmith = "^0.1.0"
langsmith = "^0.1.66"
packaging = "^23.2"
pydantic = ">=1,<3"
PyYAML = ">=5.3"
@ -364,13 +364,13 @@ url = "../../standard-tests"
[[package]]
name = "langsmith"
version = "0.1.58"
version = "0.1.73"
description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform."
optional = false
python-versions = "<4.0,>=3.8.1"
files = [
{file = "langsmith-0.1.58-py3-none-any.whl", hash = "sha256:1148cc836ec99d1b2f37cd2fa3014fcac213bb6bad798a2b21bb9111c18c9768"},
{file = "langsmith-0.1.58.tar.gz", hash = "sha256:a5060933c1fb3006b498ec849677993329d7e6138bdc2ec044068ab806e09c39"},
{file = "langsmith-0.1.73-py3-none-any.whl", hash = "sha256:38bfcce2cfcf0b2da2e9628b903c9e768e1ce59d450e8a584514c1638c595e93"},
{file = "langsmith-0.1.73.tar.gz", hash = "sha256:0055471cb1fddb76ec65499716764ad0b0314affbdf33ff1f72ad5e2d6a3b224"},
]
[package.dependencies]
@ -918,4 +918,4 @@ watchmedo = ["PyYAML (>=3.10)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "4b5bd20c32502ecdb953e16f06cb8c03b0a3e0755343bafd7e937d0b7a363343"
content-hash = "672ecb755a4d938d114d4ffa96455758ecc05943c06e49e9bad3dfe65ee3c810"

@ -12,7 +12,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.1.45,<0.3"
langchain-core = ">=0.2.0,<0.3"
groq = ">=0.4.1,<1"
[tool.poetry.group.test]

@ -13,14 +13,3 @@ class TestGroqStandard(ChatModelUnitTests):
@pytest.fixture
def chat_model_class(self) -> Type[BaseChatModel]:
return ChatGroq
@pytest.mark.xfail(reason="Not implemented.")
def test_standard_params(
self,
chat_model_class: Type[BaseChatModel],
chat_model_params: dict,
) -> None:
super().test_standard_params(
chat_model_class,
chat_model_params,
)

Loading…
Cancel
Save