langchain/libs/community/tests/integration_tests/embeddings/test_ipex_llm.py
Yuwen Hu ba0dca46d7
community[minor]: Add IPEX-LLM BGE embedding support on both Intel CPU and GPU (#22226)
**Description:** [IPEX-LLM](https://github.com/intel-analytics/ipex-llm)
is a PyTorch library for running LLM on Intel CPU and GPU (e.g., local
PC with iGPU, discrete GPU such as Arc, Flex and Max) with very low
latency. This PR adds ipex-llm integrations to langchain for BGE
embedding support on both Intel CPU and GPU.
**Dependencies:** `ipex-llm`, `sentence-transformers`
**Contribution maintainer**: @Oscilloscope98 
**tests and docs**: 
- langchain/docs/docs/integrations/text_embedding/ipex_llm.ipynb
- langchain/docs/docs/integrations/text_embedding/ipex_llm_gpu.ipynb
- langchain/libs/community/tests/integration_tests/embeddings/test_ipex_llm.py

---------

Co-authored-by: Shengsheng Huang <shannie.huang@gmail.com>
2024-06-03 12:37:10 -07:00

53 lines
1.6 KiB
Python

"""Test IPEX LLM"""
import os
import pytest
from langchain_community.embeddings import IpexLLMBgeEmbeddings
# Comma-separated list of model ids to test, read from the environment.
# Blank entries (unset variable, stray or trailing commas, whitespace-only
# values) are dropped so no test is ever parametrized with an empty model id.
_raw_model_ids = os.getenv("TEST_IPEXLLM_BGE_EMBEDDING_MODEL_IDS") or ""
model_ids_to_test = [
    model_id.strip() for model_id in _raw_model_ids.split(",") if model_id.strip()
]

# Skip marker applied to every test below when no usable model id was given.
# The condition is evaluated on the parsed list, so a value such as "," or
# "  " is treated the same as an unset variable.
skip_if_no_model_ids = pytest.mark.skipif(
    not model_ids_to_test,
    reason="TEST_IPEXLLM_BGE_EMBEDDING_MODEL_IDS environment variable not set.",
)

# Device forwarded to the embedding model through model_kwargs; "cpu" when
# the environment variable is unset or empty.
device = os.getenv("TEST_IPEXLLM_BGE_EMBEDDING_MODEL_DEVICE") or "cpu"

# Sample inputs shared by the tests.
sentence = (
    "IPEX-LLM is a PyTorch library for running LLM on Intel CPU and GPU "
    "(e.g., local PC with iGPU, discrete GPU such as Arc, Flex and Max) "
    "with very low latency."
)
query = "What is IPEX-LLM?"
@skip_if_no_model_ids
@pytest.mark.parametrize("model_id", model_ids_to_test)
def test_embed_documents(model_id: str) -> None:
    """Check that embed_documents returns one embedding per input text."""
    texts = [sentence, query]
    embedder = IpexLLMBgeEmbeddings(
        model_name=model_id,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
    )

    vectors = embedder.embed_documents(texts)

    # Two texts were submitted, so exactly two embeddings must come back.
    assert len(vectors) == 2
@skip_if_no_model_ids
@pytest.mark.parametrize(
    "model_id",
    model_ids_to_test,
)
def test_embed_query(model_id: str) -> None:
    """Test IpexLLMBgeEmbeddings embed_query"""
    embedding_model = IpexLLMBgeEmbeddings(
        model_name=model_id,
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": True},
    )
    output = embedding_model.embed_query(query)
    assert isinstance(output, list)
    # An empty list would pass the type check above but is not a usable
    # embedding, so also require at least one component.
    assert len(output) > 0