langchain/libs/community/tests/integration_tests/llms/test_azureml_endpoint.py
Bagatur ed58eeb9c5
community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463)
Moved the following modules to new package langchain-community in a backwards compatible fashion:

```
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
```

Moved the following to core:
```
mv langchain/langchain/utils/json_schema.py core/langchain_core/utils
mv langchain/langchain/utils/html.py core/langchain_core/utils
mv langchain/langchain/utils/strings.py core/langchain_core/utils
cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py
rm langchain/langchain/utils/env.py
```

See .scripts/community_split/script_integrations.sh for all changes
2023-12-11 13:53:30 -08:00

152 lines
4.8 KiB
Python

"""Test AzureML Endpoint wrapper."""
import json
import os
from pathlib import Path
from typing import Dict
from urllib.request import HTTPError
import pytest
from langchain_community.llms.azureml_endpoint import (
AzureMLOnlineEndpoint,
ContentFormatterBase,
DollyContentFormatter,
HFContentFormatter,
OSSContentFormatter,
)
from langchain_community.llms.loading import load_llm
def test_gpt2_call() -> None:
    """Call the OSS endpoint through ``OSSContentFormatter``; expect a ``str``.

    The API key, URL and deployment name are read from the
    ``OSS_ENDPOINT_API_KEY``, ``OSS_ENDPOINT_URL`` and ``OSS_DEPLOYMENT_NAME``
    environment variables.
    """
    api_key = os.getenv("OSS_ENDPOINT_API_KEY")
    url = os.getenv("OSS_ENDPOINT_URL")
    deployment = os.getenv("OSS_DEPLOYMENT_NAME")

    endpoint = AzureMLOnlineEndpoint(
        endpoint_api_key=api_key,
        endpoint_url=url,
        deployment_name=deployment,
        content_formatter=OSSContentFormatter(),
    )

    completion = endpoint("Foo")
    assert isinstance(completion, str)
def test_hf_call() -> None:
    """Call the HF endpoint through ``HFContentFormatter``; expect a ``str``.

    The API key, URL and deployment name are read from the
    ``HF_ENDPOINT_API_KEY``, ``HF_ENDPOINT_URL`` and ``HF_DEPLOYMENT_NAME``
    environment variables.
    """
    api_key = os.getenv("HF_ENDPOINT_API_KEY")
    url = os.getenv("HF_ENDPOINT_URL")
    deployment = os.getenv("HF_DEPLOYMENT_NAME")

    endpoint = AzureMLOnlineEndpoint(
        endpoint_api_key=api_key,
        endpoint_url=url,
        deployment_name=deployment,
        content_formatter=HFContentFormatter(),
    )

    completion = endpoint("Foo")
    assert isinstance(completion, str)
def test_dolly_call() -> None:
    """Call the Dolly endpoint through ``DollyContentFormatter``; expect a ``str``.

    The API key, URL and deployment name are read from the
    ``DOLLY_ENDPOINT_API_KEY``, ``DOLLY_ENDPOINT_URL`` and
    ``DOLLY_DEPLOYMENT_NAME`` environment variables.
    """
    api_key = os.getenv("DOLLY_ENDPOINT_API_KEY")
    url = os.getenv("DOLLY_ENDPOINT_URL")
    deployment = os.getenv("DOLLY_DEPLOYMENT_NAME")

    endpoint = AzureMLOnlineEndpoint(
        endpoint_api_key=api_key,
        endpoint_url=url,
        deployment_name=deployment,
        content_formatter=DollyContentFormatter(),
    )

    completion = endpoint("Foo")
    assert isinstance(completion, str)
def test_custom_formatter() -> None:
    """Check that a user-defined ``ContentFormatterBase`` subclass can be used.

    A local formatter is defined, passed to ``AzureMLOnlineEndpoint`` together
    with the ``BART_*`` environment settings, and the result of calling the
    endpoint is asserted to be a ``str``.
    """

    class CustomFormatter(ContentFormatterBase):
        # Header values declared by this formatter.
        content_type = "application/json"
        accepts = "application/json"

        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
            """Serialize the prompt and model kwargs as UTF-8 encoded JSON."""
            body = {
                "inputs": [prompt],
                "parameters": model_kwargs,
                "options": {"use_cache": False, "wait_for_model": True},
            }
            return json.dumps(body).encode("utf-8")

        def format_response_payload(self, output: bytes) -> str:
            """Return ``summary_text`` from the first element of the JSON reply."""
            return json.loads(output)[0]["summary_text"]

    api_key = os.getenv("BART_ENDPOINT_API_KEY")
    url = os.getenv("BART_ENDPOINT_URL")
    deployment = os.getenv("BART_DEPLOYMENT_NAME")

    endpoint = AzureMLOnlineEndpoint(
        endpoint_api_key=api_key,
        endpoint_url=url,
        deployment_name=deployment,
        content_formatter=CustomFormatter(),
    )

    completion = endpoint("Foo")
    assert isinstance(completion, str)
def test_missing_content_formatter() -> None:
    """Expect ``AttributeError`` when no ``content_formatter`` is supplied.

    Both building the endpoint and calling it happen inside the
    ``pytest.raises`` context, so the error may come from either step.
    """
    api_key = os.getenv("OSS_ENDPOINT_API_KEY")
    url = os.getenv("OSS_ENDPOINT_URL")
    deployment = os.getenv("OSS_DEPLOYMENT_NAME")

    with pytest.raises(AttributeError):
        endpoint = AzureMLOnlineEndpoint(
            endpoint_api_key=api_key,
            endpoint_url=url,
            deployment_name=deployment,
        )
        endpoint("Foo")
def test_invalid_request_format() -> None:
    """Expect ``HTTPError`` when the request payload uses an unexpected shape.

    The local formatter nests the prompt under an ``"incorrect_input"`` key
    before the payload is sent to the OSS endpoint.
    """

    class CustomContentFormatter(ContentFormatterBase):
        # Header values declared by this formatter.
        content_type = "application/json"
        accepts = "application/json"

        def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
            """Build the deliberately mis-shaped JSON body as encoded bytes."""
            body = {
                "incorrect_input": {"input_string": [prompt]},
                "parameters": model_kwargs,
            }
            return json.dumps(body).encode()

        def format_response_payload(self, output: bytes) -> str:
            """Return the ``"0"`` entry of the first element of the JSON reply."""
            return json.loads(output)[0]["0"]

    api_key = os.getenv("OSS_ENDPOINT_API_KEY")
    url = os.getenv("OSS_ENDPOINT_URL")
    deployment = os.getenv("OSS_DEPLOYMENT_NAME")

    with pytest.raises(HTTPError):
        endpoint = AzureMLOnlineEndpoint(
            endpoint_api_key=api_key,
            endpoint_url=url,
            deployment_name=deployment,
            content_formatter=CustomContentFormatter(),
        )
        endpoint("Foo")
def test_incorrect_key() -> None:
    """Expect ``HTTPError`` when the endpoint is used with a wrong API key.

    The URL and deployment name come from the ``OSS_ENDPOINT_URL`` and
    ``OSS_DEPLOYMENT_NAME`` environment variables; the key is a fixed
    placeholder string.
    """
    url = os.getenv("OSS_ENDPOINT_URL")
    deployment = os.getenv("OSS_DEPLOYMENT_NAME")

    with pytest.raises(HTTPError):
        endpoint = AzureMLOnlineEndpoint(
            endpoint_api_key="incorrect-key",
            endpoint_url=url,
            deployment_name=deployment,
            content_formatter=OSSContentFormatter(),
        )
        endpoint("Foo")
def test_saving_loading_llm(tmp_path: Path) -> None:
    """Save an ``AzureMLOnlineEndpoint`` to YAML and load it back unchanged.

    Args:
        tmp_path: Directory in which ``azureml.yaml`` is written.
    """
    yaml_file = tmp_path / "azureml.yaml"

    original = AzureMLOnlineEndpoint(
        deployment_name="databricks-dolly-v2-12b-4",
        model_kwargs={"temperature": 0.03, "top_p": 0.4, "max_tokens": 200},
    )
    original.save(file_path=yaml_file)

    restored = load_llm(yaml_file)
    assert restored == original