langchain/libs/community/tests/integration_tests/embeddings/test_huggingface_hub.py
Raunak e26e1f8b37
community: Added functions to make async calls to HuggingFaceHub's embedding endpoint in HuggingFaceHubEmbeddings class (#15737)
**Description:**
Added the async methods aembed_documents() and aembed_query() to the
HuggingFaceHubEmbeddings class in
langchain_community\embeddings\huggingface_hub.py. They make async calls
to HuggingFaceHub's embedding endpoint, so embeddings can be generated
asynchronously.
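Because `aembed_documents()` and `aembed_query()` are coroutines, callers `await` them and can overlap several requests with `asyncio.gather`. The sketch below shows that calling pattern offline. `StubHubEmbeddings` is a hypothetical stand-in that only mirrors the two method names and returns fixed 768-dimensional vectors (the size the tests assert). It is not the actual `HuggingFaceHubEmbeddings` implementation and never contacts the Hub.

```python
import asyncio
from typing import List

EMBEDDING_DIM = 768  # matches the dimension asserted by the integration tests


class StubHubEmbeddings:
    """Hypothetical offline stand-in mirroring the async method names.

    It does NOT call the Hugging Face Hub; asyncio.sleep(0) stands in
    for awaiting a network response.
    """

    async def aembed_query(self, text: str) -> List[float]:
        await asyncio.sleep(0)  # yield to the event loop, as a real request would
        # Deterministic fake vector derived from the text length.
        return [float(len(text))] * EMBEDDING_DIM

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        # Embed every document concurrently; gather preserves input order.
        return list(await asyncio.gather(*(self.aembed_query(t) for t in texts)))


async def main() -> None:
    embedder = StubHubEmbeddings()
    # Run a documents call and a query call concurrently.
    docs, query = await asyncio.gather(
        embedder.aembed_documents(["foo bar", "baz"]),
        embedder.aembed_query("foo bar"),
    )
    print(len(docs), len(docs[0]), len(query))  # prints "2 768 768"


asyncio.run(main())
```

With the real class, the same `await` calls would go to the Hub's embedding endpoint and therefore need network access.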

**Test Cases:** Added test_huggingfacehub_embedding_async_documents() and
test_huggingfacehub_embedding_async_query() to test_huggingface_hub.py
to cover the two new async methods of the HuggingFaceHubEmbeddings class.
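The two new tests are `async def` functions. pytest does not run coroutine tests by itself, so they rely on an async plugin such as pytest-asyncio being configured for the project (an assumption here; this file does not show that configuration). The sketch below mirrors the assertions of the async tests against a hypothetical offline stub, `FakeAsyncEmbeddings`, and drives each coroutine with `asyncio.run`, so it can be checked without a plugin or network access.

```python
import asyncio
from typing import List


class FakeAsyncEmbeddings:
    """Hypothetical offline stub that returns constant 768-dim vectors."""

    dim = 768

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        await asyncio.sleep(0)  # simulate awaiting an HTTP response
        return [[0.0] * self.dim for _ in texts]

    async def aembed_query(self, text: str) -> List[float]:
        await asyncio.sleep(0)  # simulate awaiting an HTTP response
        return [0.0] * self.dim


async def check_async_documents() -> None:
    # Same assertions as test_huggingfacehub_embedding_async_documents().
    output = await FakeAsyncEmbeddings().aembed_documents(["foo bar"])
    assert len(output) == 1
    assert len(output[0]) == 768


async def check_async_query() -> None:
    # Same assertion as test_huggingfacehub_embedding_async_query().
    output = await FakeAsyncEmbeddings().aembed_query("foo bar")
    assert len(output) == 768


# asyncio.run() creates an event loop, runs the coroutine, then closes the loop.
for check in (check_async_documents, check_async_query):
    asyncio.run(check())
print("async checks passed")
```

Under pytest-asyncio, marking a test with `@pytest.mark.asyncio` (or enabling the plugin's auto mode) plays the role that `asyncio.run` plays here.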

**Documentation:** Updated huggingfacehub.ipynb with steps to install the
huggingface_hub package and use HuggingFaceHubEmbeddings.

**Dependencies:** None
**Twitter handle:** I do not have a Twitter account

---------

Co-authored-by: H161961 <Raunak.Raunak@Honeywell.com>
2024-01-11 21:52:55 -08:00


"""Test HuggingFaceHub embeddings."""
import pytest

from langchain_community.embeddings import HuggingFaceHubEmbeddings


def test_huggingfacehub_embedding_documents() -> None:
    """Test huggingfacehub embeddings."""
    documents = ["foo bar"]
    embedding = HuggingFaceHubEmbeddings()
    output = embedding.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 768


async def test_huggingfacehub_embedding_async_documents() -> None:
    """Test huggingfacehub embeddings."""
    documents = ["foo bar"]
    embedding = HuggingFaceHubEmbeddings()
    output = await embedding.aembed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 768


def test_huggingfacehub_embedding_query() -> None:
    """Test huggingfacehub embeddings."""
    document = "foo bar"
    embedding = HuggingFaceHubEmbeddings()
    output = embedding.embed_query(document)
    assert len(output) == 768


async def test_huggingfacehub_embedding_async_query() -> None:
    """Test huggingfacehub embeddings."""
    document = "foo bar"
    embedding = HuggingFaceHubEmbeddings()
    output = await embedding.aembed_query(document)
    assert len(output) == 768


def test_huggingfacehub_embedding_invalid_repo() -> None:
    """Test huggingfacehub embedding repo id validation."""
    # Only sentence-transformers models are currently supported.
    with pytest.raises(ValueError):
        HuggingFaceHubEmbeddings(repo_id="allenai/specter")