@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional
import requests
from langchain_core . embeddings import Embeddings
from langchain_core . pydantic_v1 import BaseModel , SecretStr, root_validator
from langchain_core . pydantic_v1 import BaseModel , Field, SecretStr, root_validator
from langchain_core . utils import convert_to_secret_str , get_from_dict_or_env
from requests import RequestException
@ -37,9 +37,16 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
"""
session : Any #: :meta private:
model_name : str = " Baichuan-Text-Embedding "
baichuan_api_key : Optional [ SecretStr ] = None
model_name : str = Field ( default = " Baichuan-Text-Embedding " , alias = " model " )
baichuan_api_key : Optional [ SecretStr ] = Field ( default = None , alias = " api_key " )
""" Automatically inferred from env var `BAICHUAN_API_KEY` if not provided. """
chunk_size : int = 16
""" Chunk size when multiple texts are input """
class Config :
""" Configuration for this pydantic object. """
allow_population_by_field_name = True
@root_validator ( allow_reuse = True )
def validate_environment ( cls , values : Dict ) - > Dict :
@ -78,26 +85,35 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
A list of list of floats representing the embeddings , or None if an
error occurs .
"""
response = self . session . post (
BAICHUAN_API_URL , json = { " input " : texts , " model " : self . model_name }
)
# Raise exception if response status code from 400 to 600
response . raise_for_status ( )
# Check if the response status code indicates success
if response . status_code == 200 :
resp = response . json ( )
embeddings = resp . get ( " data " , [ ] )
# Sort resulting embeddings by index
sorted_embeddings = sorted ( embeddings , key = lambda e : e . get ( " index " , 0 ) )
# Return just the embeddings
return [ result . get ( " embedding " , [ ] ) for result in sorted_embeddings ]
else :
# Log error or handle unsuccessful response appropriately
# Handle 100 <= status_code < 400, not include 200
raise RequestException (
f " Error: Received status code { response . status_code } from "
" `BaichuanEmbedding` API "
chunk_texts = [
texts [ i : i + self . chunk_size ]
for i in range ( 0 , len ( texts ) , self . chunk_size )
]
embed_results = [ ]
for chunk in chunk_texts :
response = self . session . post (
BAICHUAN_API_URL , json = { " input " : chunk , " model " : self . model_name }
)
# Raise exception if response status code from 400 to 600
response . raise_for_status ( )
# Check if the response status code indicates success
if response . status_code == 200 :
resp = response . json ( )
embeddings = resp . get ( " data " , [ ] )
# Sort resulting embeddings by index
sorted_embeddings = sorted ( embeddings , key = lambda e : e . get ( " index " , 0 ) )
# Return just the embeddings
embed_results . extend (
[ result . get ( " embedding " , [ ] ) for result in sorted_embeddings ]
)
else :
# Log error or handle unsuccessful response appropriately
# Handle 100 <= status_code < 400, not include 200
raise RequestException (
f " Error: Received status code { response . status_code } from "
" `BaichuanEmbedding` API "
)
return embed_results
def embed_documents ( self , texts : List [ str ] ) - > Optional [ List [ List [ float ] ] ] : # type: ignore[override]
""" Public method to get embeddings for a list of documents.