diff --git a/libs/community/langchain_community/embeddings/baichuan.py b/libs/community/langchain_community/embeddings/baichuan.py index d0f54fff0d..21175fb901 100644 --- a/libs/community/langchain_community/embeddings/baichuan.py +++ b/libs/community/langchain_community/embeddings/baichuan.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional import requests from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from requests import RequestException @@ -37,9 +37,16 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings): """ session: Any #: :meta private: - model_name: str = "Baichuan-Text-Embedding" - baichuan_api_key: Optional[SecretStr] = None + model_name: str = Field(default="Baichuan-Text-Embedding", alias="model") + baichuan_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") """Automatically inferred from env var `BAICHUAN_API_KEY` if not provided.""" + chunk_size: int = 16 + """Chunk size when multiple texts are input""" + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True @root_validator(allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: @@ -78,26 +85,35 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings): A list of list of floats representing the embeddings, or None if an error occurs. """ - response = self.session.post( - BAICHUAN_API_URL, json={"input": texts, "model": self.model_name} - ) - # Raise exception if response status code from 400 to 600 - response.raise_for_status() - # Check if the response status code indicates success - if response.status_code == 200: - resp = response.json() - embeddings = resp.get("data", []) - # Sort resulting embeddings by index - sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0)) - # Return just the embeddings - return [result.get("embedding", []) for result in sorted_embeddings] - else: - # Log error or handle unsuccessful response appropriately - # Handle 100 <= status_code < 400, not include 200 - raise RequestException( - f"Error: Received status code {response.status_code} from " - "`BaichuanEmbedding` API" + chunk_texts = [ + texts[i : i + self.chunk_size] + for i in range(0, len(texts), self.chunk_size) + ] + embed_results = [] + for chunk in chunk_texts: + response = self.session.post( + BAICHUAN_API_URL, json={"input": chunk, "model": self.model_name} ) + # Raise exception if response status code from 400 to 600 + response.raise_for_status() + # Check if the response status code indicates success + if response.status_code == 200: + resp = response.json() + embeddings = resp.get("data", []) + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e.get("index", 0)) + # Return just the embeddings + embed_results.extend( + [result.get("embedding", []) for result in sorted_embeddings] + ) + else: + # Log error or handle unsuccessful response appropriately + # Handle 100 <= status_code < 400, not include 200 + raise RequestException( + f"Error: Received status code {response.status_code} from " + "`BaichuanEmbedding` API" + ) + return embed_results def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override] """Public method to get embeddings for a list of documents. diff --git a/libs/community/tests/integration_tests/embeddings/test_baichuan.py b/libs/community/tests/integration_tests/embeddings/test_baichuan.py index b8f8e68bff..fd5921642f 100644 --- a/libs/community/tests/integration_tests/embeddings/test_baichuan.py +++ b/libs/community/tests/integration_tests/embeddings/test_baichuan.py @@ -17,3 +17,13 @@ def test_baichuan_embedding_query() -> None: embedding = BaichuanTextEmbeddings() # type: ignore[call-arg] output = embedding.embed_query(document) assert len(output) == 1024 # type: ignore[arg-type] + + +def test_baichuan_embeddings_multi_documents() -> None: + """Test Baichuan Text Embedding for documents with multi texts.""" + document = "午餐吃了螺蛳粉" + doc_amount = 35 + embeddings = BaichuanTextEmbeddings() # type: ignore[call-arg] + output = embeddings.embed_documents([document] * doc_amount) + assert len(output) == doc_amount # type: ignore[arg-type] + assert len(output[0]) == 1024 # type: ignore[index] diff --git a/libs/community/tests/unit_tests/embeddings/test_baichuan.py b/libs/community/tests/unit_tests/embeddings/test_baichuan.py new file mode 100644 index 0000000000..10513948f9 --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_baichuan.py @@ -0,0 +1,18 @@ +from typing import cast + +from langchain_core.pydantic_v1 import SecretStr + +from langchain_community.embeddings import BaichuanTextEmbeddings + + +def test_sparkllm_initialization_by_alias() -> None: + # Effective initialization + embeddings = BaichuanTextEmbeddings( # type: ignore[call-arg] + model="embedding_model", # type: ignore[arg-type] + api_key="your-api-key", # type: ignore[arg-type] + ) + assert embeddings.model_name == "embedding_model" + assert ( + cast(SecretStr, embeddings.baichuan_api_key).get_secret_value() + == "your-api-key" + )