langchain/libs/community/tests/integration_tests/chat_models/test_premai.py
Anindyadeep b2a11ce686
community[minor]: Prem AI langchain integration (#19113)
### Prem SDK integration in LangChain

This PR adds the integration of [PremAI's](https://www.premai.io/)
prem-sdk with LangChain. Users can now access deployed models
(LLMs/embeddings) and use them with LangChain's ecosystem. This PR adds
the following:

### This PR adds the following:

- [x]  Add chat support
- [X]  Adding embedding support
- [X]  writing integration tests
    - [X]  writing tests for chat 
    - [X]  writing tests for embedding
- [X]  writing unit tests
    - [X]  writing tests for chat 
    - [X]  writing tests for embedding
- [X]  Adding documentation
    - [X]  writing documentation for chat
    - [X]  writing documentation for embedding
- [X] run `make test`
- [X] run `make lint`, `make lint_diff` 
- [X]  Final checks (spell check, lint, format and overall testing)

---------

Co-authored-by: Anindyadeep Sannigrahi <anindyadeepsannigrahi@Anindyadeeps-MacBook-Pro.local>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
2024-03-26 01:37:19 +00:00

71 lines
2.3 KiB
Python

"""Test ChatPremAI from PremAI API wrapper.
Note: This test must be run with the PREMAI_API_KEY environment variable set to a valid
API key and a valid project_id.
For this we need to have a project setup in PremAI's platform: https://app.premai.io
"""
import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_community.chat_models import ChatPremAI
@pytest.fixture
def chat() -> ChatPremAI:
    """Shared ChatPremAI client bound to project 8 for the tests below."""
    client = ChatPremAI(project_id=8)
    return client
def test_chat_premai() -> None:
    """Test the ChatPremAI wrapper with a single human message.

    Uses ``invoke`` — the ``chat([message])`` ``__call__`` form is
    deprecated in langchain-core in favor of the Runnable interface.
    """
    chat = ChatPremAI(project_id=8)
    message = HumanMessage(content="Hello")
    response = chat.invoke([message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)
def test_chat_prem_system_message() -> None:
    """Test the ChatPremAI wrapper with a system + human message pair.

    Uses ``invoke`` — the ``chat([messages])`` ``__call__`` form is
    deprecated in langchain-core in favor of the Runnable interface.
    """
    chat = ChatPremAI(project_id=8)
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat.invoke([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)
def test_chat_prem_model() -> None:
    """ChatPremAI should expose the model name it was constructed with."""
    model_name = "foo"
    chat = ChatPremAI(model=model_name, project_id=8)
    assert chat.model == model_name
def test_chat_prem_generate() -> None:
    """Test ChatPremAI ``generate`` over two identical message batches.

    Expects one generation list per input batch, each containing
    ChatGeneration objects whose ``text`` mirrors the message content.
    """
    chat = ChatPremAI(project_id=8)
    prompt = HumanMessage(content="Hello")
    result = chat.generate([[prompt], [prompt]])
    assert isinstance(result, LLMResult)
    assert len(result.generations) == 2
    for batch in result.generations:
        for gen in batch:
            assert isinstance(gen, ChatGeneration)
            assert isinstance(gen.text, str)
            assert gen.text == gen.message.content
def test_prem_invoke(chat: ChatPremAI) -> None:
    """Test chat completion via ``invoke`` using the shared fixture.

    This was declared ``async def`` but contained no ``await``; without a
    pytest asyncio plugin configured, such a test produces an unawaited
    coroutine and never actually runs. Made synchronous since the body
    only calls the blocking ``invoke`` API.
    """
    result = chat.invoke("How is the weather in New York today?")
    assert isinstance(result.content, str)
def test_prem_streaming() -> None:
    """Every streamed chunk from Prem should carry string content."""
    chat = ChatPremAI(project_id=8, streaming=True)
    stream = chat.stream("I'm Pickle Rick")
    for chunk in stream:
        assert isinstance(chunk.content, str)