"""Test LiteLLM Router API wrapper."""
import asyncio
from copy import deepcopy
from typing import Any, AsyncGenerator, Coroutine, Dict, List, Tuple, Union, cast
import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
# Model group name that every fake deployment is registered under.
model_group = "gpt-4"
fake_model_prefix = "azure/fake-deployment-name-"
fake_models_names = [fake_model_prefix + suffix for suffix in ["1", "2"]]
fake_api_key = "fakekeyvalue"
fake_api_version = "XXXX-XX-XX"
fake_api_base = "https://faketesturl/"
# Streaming responses emit these pieces in order; joined they form the answer.
fake_chunks = ["This is ", "a fake answer."]
fake_answer = "".join(fake_chunks)
# Key under which the chat model reports token usage in ``llm_output``.
token_usage_key_name = "token_usage"

# Router configuration: one deployment per fake model name, all in the same
# model group and sharing the same fake credentials.
model_list = [
    {
        "model_name": model_group,
        "litellm_params": {
            "model": deployment_name,
            "api_key": fake_api_key,
            "api_version": fake_api_version,
            "api_base": fake_api_base,
        },
    }
    for deployment_name in fake_models_names
]
class FakeCompletion:
    """Stand-in for ``litellm.completion`` / ``litellm.acompletion``.

    Every call records its keyword arguments in ``seen_inputs`` and produces
    canned OpenAI-style chat-completion payloads: delta chunks when
    ``stream`` is true, otherwise a single full message with usage.
    """

    def __init__(self) -> None:
        # kwargs of every fake completion call, in the order they started
        self.seen_inputs: List[Any] = []

    @staticmethod
    def _get_new_result_and_choices(
        base_result: Dict[str, Any],
    ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
        """Return a deep copy of ``base_result`` and that copy's ``choices``.

        The deep copy lets each yielded payload be mutated without touching
        the template or previously yielded payloads.
        """
        result = deepcopy(base_result)
        choices = cast(List[Dict[str, Any]], result["choices"])
        return result, choices

    @staticmethod
    def _get_next_result(
        agen: AsyncGenerator[Dict[str, Any], None],
    ) -> Dict[str, Any]:
        """Synchronously fetch the next item of ``agen``.

        Uses ``asyncio.run``, so it must not be called from a thread that
        already has a running event loop.

        Raises:
            StopAsyncIteration: when ``agen`` is exhausted.
        """
        # ``cast`` is a no-op at runtime; it only tells type checkers that the
        # awaitable returned by ``__anext__`` may be handed to ``asyncio.run``.
        coroutine = cast(Coroutine, agen.__anext__())
        return asyncio.run(coroutine)

    async def _get_fake_results_agenerator(
        self, **kwargs: Any
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Record ``kwargs`` and yield fake completion payloads.

        With ``stream`` true: one delta chunk per entry of ``fake_chunks``,
        then a final empty delta with ``finish_reason="stop"``.
        Otherwise: a single full message with a fixed ``Usage``.
        """
        from litellm import Usage

        # Record the call so check_inputs() can verify what the router sent.
        self.seen_inputs.append(kwargs)
        base_result = {
            "choices": [
                {
                    "index": 0,
                }
            ],
            "created": 0,
            "id": "",
            "model": model_group,
            "object": "chat.completion",
        }
        if kwargs.get("stream", False):
            for chunk in fake_chunks:
                result, choices = self._get_new_result_and_choices(base_result)
                choice = choices[0]
                choice["delta"] = {
                    "role": "assistant",
                    "content": chunk,
                    "function_call": None,
                }
                choice["finish_reason"] = None
                # no usage here, since no usage from OpenAI API for streaming yet
                yield result

            result, choices = self._get_new_result_and_choices(base_result)
            choice = choices[0]
            choice["delta"] = {}
            choice["finish_reason"] = "stop"
            # no usage here, since no usage from OpenAI API for streaming yet
            yield result
        else:
            result, choices = self._get_new_result_and_choices(base_result)
            choice = choices[0]
            choice["message"] = {
                "content": fake_answer,
                "role": "assistant",
            }
            choice["finish_reason"] = "stop"
            result["usage"] = Usage(
                completion_tokens=1, prompt_tokens=2, total_tokens=3
            )
            yield result

    def completion(self, **kwargs: Any) -> Union[List, Dict[str, Any]]:
        """Fake ``litellm.completion``.

        Returns:
            The list of all streamed chunks when ``stream`` is true, otherwise
            the single full-response dict.
        """
        agen = self._get_fake_results_agenerator(**kwargs)
        if not kwargs.get("stream", False):
            # there is only one result for non-streaming
            return self._get_next_result(agen)
        results: List[Dict[str, Any]] = []
        while True:
            try:
                results.append(self._get_next_result(agen))
            except StopAsyncIteration:
                return results

    async def acompletion(
        self, **kwargs: Any
    ) -> Union[AsyncGenerator[Dict[str, Any], None], Dict[str, Any]]:
        """Fake ``litellm.acompletion``.

        Returns:
            The async generator of chunks when ``stream`` is true, otherwise
            the single full-response dict.
        """
        agen = self._get_fake_results_agenerator(**kwargs)
        if kwargs.get("stream", False):
            return agen
        # there is only one result for non-streaming
        return await agen.__anext__()

    def check_inputs(self, expected_num_calls: int) -> None:
        """Assert the number of recorded calls and the kwargs of each.

        Every call must target a deployment from ``model_list`` and carry the
        fake credentials plus router metadata naming the model group.
        """
        assert len(self.seen_inputs) == expected_num_calls
        for kwargs in self.seen_inputs:
            metadata = kwargs["metadata"]
            assert metadata["model_group"] == model_group
            # LiteLLM router chooses one model name from the model_list
            assert kwargs["model"] in fake_models_names
            assert metadata["deployment"] in fake_models_names
            assert kwargs["api_key"] == fake_api_key
            assert kwargs["api_version"] == fake_api_version
            assert kwargs["api_base"] == fake_api_base
@pytest.fixture
def fake_completion(monkeypatch: pytest.MonkeyPatch) -> FakeCompletion:
    """Fake AI completion for testing.

    Routes ``litellm.completion`` / ``litellm.acompletion`` to a fresh
    ``FakeCompletion``. The patches go through ``monkeypatch``, so LiteLLM's
    real functions and telemetry flag are restored after each test instead
    of leaking into other tests.
    """
    import litellm

    fake = FakeCompletion()

    # Turn off LiteLLM's built-in telemetry; raising=False keeps the original
    # "set it regardless" behavior if the attribute does not exist.
    monkeypatch.setattr(litellm, "telemetry", False, raising=False)
    monkeypatch.setattr(litellm, "completion", fake.completion)
    monkeypatch.setattr(litellm, "acompletion", fake.acompletion)

    return fake
@pytest.fixture
def litellm_router() -> Any:
    """LiteLLM router for testing, configured with the fake ``model_list``."""
    from litellm import Router

    return Router(model_list=model_list)
def test_litellm_router_call(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """A single non-streaming invoke returns the canned answer as an AIMessage."""
    llm = ChatLiteLLMRouter(router=litellm_router)

    reply = llm.invoke([HumanMessage(content="Hello")])

    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
    assert reply.content == fake_answer
    # Token usage is not checked: invoke() yields only the AIMessage.
def test_litellm_router_generate(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test generate method of LiteLLM Router.

    Checks the generated text, that the input messages are left unmodified,
    and that the fake's token usage is reported in ``llm_output``.
    """
    from litellm import Usage

    chat = ChatLiteLLMRouter(router=litellm_router)
    chat_messages: List[List[BaseMessage]] = [
        [HumanMessage(content="How many toes do dogs have?")]
    ]
    messages_copy = [messages.copy() for messages in chat_messages]
    result: LLMResult = chat.generate(chat_messages)
    assert isinstance(result, LLMResult)
    # One generation list per input conversation; without this the loop
    # below would pass vacuously on an empty result.
    assert len(result.generations) == len(chat_messages)
    for generations in result.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.message.content == generation.text
            assert generation.message.content == fake_answer
    # generate() must not mutate the caller's message lists
    assert chat_messages == messages_copy
    assert result.llm_output is not None
    assert result.llm_output[token_usage_key_name] == Usage(
        completion_tokens=1, prompt_tokens=2, total_tokens=3
    )
def test_litellm_router_streaming(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """With streaming enabled, invoke still assembles the full canned answer."""
    streaming_llm = ChatLiteLLMRouter(router=litellm_router, streaming=True)

    reply = streaming_llm.invoke([HumanMessage(content="Hello")])

    assert isinstance(reply, AIMessage)
    assert isinstance(reply.content, str)
    assert reply.content == fake_answer
    # Token usage is not checked: invoke() yields only the AIMessage.
def test_litellm_router_streaming_callback(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test that streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    # streaming=True plus the handler are both required for the stream-count
    # assertion below to be meaningful.
    chat = ChatLiteLLMRouter(
        router=litellm_router,
        streaming=True,
        callbacks=[callback_handler],
    )
    message = HumanMessage(content="Write me a sentence with 10 words.")
    response = chat.invoke([message])
    # The fake streams several chunks, so more than one new-token event is
    # expected (llm_streams presumably counts on_llm_new_token calls).
    assert callback_handler.llm_streams > 1
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
    assert response.content == fake_answer
    # no usage check here, since response is only an AIMessage
@pytest.mark.asyncio
async def test_async_litellm_router(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test async generation.

    Two prompts produce two fake calls, so the expected usage is twice the
    single-call fake usage (1/2/3 tokens each).
    """
    from litellm import Usage

    chat = ChatLiteLLMRouter(router=litellm_router)
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.message.content == generation.text
            assert generation.message.content == fake_answer
    assert response.llm_output is not None
    assert response.llm_output[token_usage_key_name] == Usage(
        completion_tokens=2, prompt_tokens=4, total_tokens=6
    )
@pytest.mark.asyncio
async def test_async_litellm_router_streaming(
    fake_completion: FakeCompletion, litellm_router: Any
) -> None:
    """Test that async streaming correctly invokes on_llm_new_token callback."""
    callback_handler = FakeCallbackHandler()
    chat = ChatLiteLLMRouter(
        router=litellm_router,
        streaming=True,
        callbacks=[callback_handler],
    )
    message = HumanMessage(content="Hello")
    response = await chat.agenerate([[message], [message]])
    assert callback_handler.llm_streams > 0
    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.message.content == generation.text
            assert generation.message.content == fake_answer
    # no usage check here, since no usage from OpenAI API for streaming yet