langchain/libs/community/tests/integration_tests/llms/test_azureml_endpoint.py

177 lines
5.9 KiB
Python
Raw Normal View History

"""Test AzureML Endpoint wrapper."""
import json
import os
from pathlib import Path
from typing import Dict
from urllib.request import HTTPError
import pytest
from langchain_core.pydantic_v1 import ValidationError
community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463) Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv 
langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
2023-12-11 21:53:30 +00:00
from langchain_community.llms.azureml_endpoint import (
AzureMLOnlineEndpoint,
ContentFormatterBase,
DollyContentFormatter,
HFContentFormatter,
OSSContentFormatter,
)
community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463) Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv 
langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
2023-12-11 21:53:30 +00:00
from langchain_community.llms.loading import load_llm
def test_gpt2_call() -> None:
    """Test valid call to GPT2.

    Builds an endpoint client for the OSS (GPT-2) deployment from the
    ``OSS_*`` environment variables and checks that invoking it with a short
    prompt yields a string.
    """
    endpoint = AzureMLOnlineEndpoint(
        endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
        endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
        deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
        content_formatter=OSSContentFormatter(),
    )
    assert isinstance(endpoint.invoke("Foo"), str)
def test_hf_call() -> None:
    """Test valid call to HuggingFace Foundation Model.

    Builds an endpoint client for the HuggingFace deployment from the
    ``HF_*`` environment variables and checks that invoking it with a short
    prompt yields a string.
    """
    endpoint = AzureMLOnlineEndpoint(
        endpoint_url=os.getenv("HF_ENDPOINT_URL"),
        endpoint_api_key=os.getenv("HF_ENDPOINT_API_KEY"),
        deployment_name=os.getenv("HF_DEPLOYMENT_NAME"),
        content_formatter=HFContentFormatter(),
    )
    assert isinstance(endpoint.invoke("Foo"), str)
def test_dolly_call() -> None:
    """Test valid call to dolly-v2.

    Builds an endpoint client for the Dolly deployment from the
    ``DOLLY_*`` environment variables and checks that invoking it with a
    short prompt yields a string.
    """
    endpoint = AzureMLOnlineEndpoint(
        endpoint_url=os.getenv("DOLLY_ENDPOINT_URL"),
        endpoint_api_key=os.getenv("DOLLY_ENDPOINT_API_KEY"),
        deployment_name=os.getenv("DOLLY_DEPLOYMENT_NAME"),
        content_formatter=DollyContentFormatter(),
    )
    assert isinstance(endpoint.invoke("Foo"), str)
def test_custom_formatter() -> None:
    """Test ability to create a custom content formatter.

    Defines a minimal ``ContentFormatterBase`` subclass, plugs it into an
    endpoint client configured from the ``BART_*`` environment variables,
    and checks that invoking the client yields a string.
    """

    class CustomFormatter(ContentFormatterBase):
        content_type = "application/json"
        accepts = "application/json"

        # NOTE(review): both overrides need ``type: ignore[override]``, i.e.
        # their signatures differ from the base class; confirm they still
        # match how the installed endpoint invokes its formatter.
        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:  # type: ignore[override]
            payload = {
                "inputs": [prompt],
                "parameters": model_kwargs,
                "options": {"use_cache": False, "wait_for_model": True},
            }
            return json.dumps(payload).encode("utf-8")

        def format_response_payload(self, output: bytes) -> str:  # type: ignore[override]
            # Take element 0 of the decoded JSON body and return its
            # "summary_text" field.
            first_result = json.loads(output)[0]
            return first_result["summary_text"]

    summarizer = AzureMLOnlineEndpoint(
        endpoint_url=os.getenv("BART_ENDPOINT_URL"),
        endpoint_api_key=os.getenv("BART_ENDPOINT_API_KEY"),
        deployment_name=os.getenv("BART_DEPLOYMENT_NAME"),
        content_formatter=CustomFormatter(),
    )
    assert isinstance(summarizer.invoke("Foo"), str)
def test_missing_content_formatter() -> None:
    """Test AzureML LLM without a content_formatter attribute.

    Builds the OSS endpoint client without passing ``content_formatter`` and
    expects an ``AttributeError`` from either the construction or the first
    ``invoke``. Presumably it is raised when ``invoke`` reaches for the
    missing formatter; confirm against the installed library.
    """
    with pytest.raises(AttributeError):
        unformatted = AzureMLOnlineEndpoint(
            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
        )
        unformatted.invoke("Foo")
def test_invalid_request_format() -> None:
    """Test invalid request format.

    Uses a formatter that nests the prompt under an unexpected
    ``"incorrect_input"`` key and expects the OSS deployment call to fail
    with an ``HTTPError``.
    """

    class CustomContentFormatter(ContentFormatterBase):
        content_type = "application/json"
        accepts = "application/json"

        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:  # type: ignore[override]
            # Deliberately malformed body: the prompt is not sent under the
            # key the deployment is expected to read.
            bad_payload = {
                "incorrect_input": {"input_string": [prompt]},
                "parameters": model_kwargs,
            }
            return json.dumps(bad_payload).encode()

        def format_response_payload(self, output: bytes) -> str:  # type: ignore[override]
            decoded = json.loads(output)
            return decoded[0]["0"]

    with pytest.raises(HTTPError):
        endpoint = AzureMLOnlineEndpoint(
            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
            content_formatter=CustomContentFormatter(),
        )
        endpoint.invoke("Foo")
def test_incorrect_url() -> None:
    """Testing AzureML Endpoint for an incorrect URL.

    The test expects the hard-coded ``https://endpoint.inference.com`` URL to
    be rejected by model validation when the client is constructed.

    Only the constructor sits inside ``pytest.raises``: a pydantic
    ``ValidationError`` is raised at construction time. If validation ever
    stops rejecting this URL, the test fails immediately with "DID NOT
    RAISE". Previously it would first have attempted a request to an
    arbitrary external host via ``invoke``.
    """
    with pytest.raises(ValidationError):
        AzureMLOnlineEndpoint(
            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
            endpoint_url="https://endpoint.inference.com",
            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
            content_formatter=OSSContentFormatter(),
        )
def test_incorrect_api_type() -> None:
    """Testing AzureML Endpoint for an incompatible ``endpoint_api_type``.

    Pairs the OSS endpoint URL with ``endpoint_api_type="serverless"`` and
    expects model validation to reject the combination when the client is
    constructed. This presumably holds because the configured OSS URL is not
    a serverless endpoint; confirm against the deployment.

    Only the constructor sits inside ``pytest.raises``, so a missed
    validation fails fast instead of issuing a live request.
    """
    with pytest.raises(ValidationError):
        AzureMLOnlineEndpoint(
            endpoint_api_key=os.getenv("OSS_ENDPOINT_API_KEY"),
            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
            endpoint_api_type="serverless",
            content_formatter=OSSContentFormatter(),
        )
def test_incorrect_key() -> None:
    """Testing AzureML Endpoint for incorrect key.

    Targets the OSS deployment with a bogus API key and expects the call to
    fail with an ``HTTPError``.
    """
    with pytest.raises(HTTPError):
        unauthorized = AzureMLOnlineEndpoint(
            endpoint_url=os.getenv("OSS_ENDPOINT_URL"),
            endpoint_api_key="incorrect-key",
            deployment_name=os.getenv("OSS_DEPLOYMENT_NAME"),
            content_formatter=OSSContentFormatter(),
        )
        unauthorized.invoke("Foo")
def test_saving_loading_llm(tmp_path: Path) -> None:
    """Test saving/loading an AzureML Foundation Model LLM.

    Saves an endpoint client to ``azureml.yaml`` under pytest's ``tmp_path``,
    reloads it with ``load_llm`` and checks the round trip yields an object
    equal to the one that was saved.
    """
    config_file = tmp_path / "azureml.yaml"
    original = AzureMLOnlineEndpoint(
        deployment_name="databricks-dolly-v2-12b-4",
        model_kwargs={"temperature": 0.03, "top_p": 0.4, "max_tokens": 200},
    )
    original.save(file_path=config_file)
    assert load_llm(config_file) == original