|
|
|
@ -86,6 +86,15 @@ class VoyageEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
show_progress_bar: bool = False
|
|
|
|
|
"""Whether to show a progress bar when embedding. Must have tqdm installed if set
|
|
|
|
|
to True."""
|
|
|
|
|
truncation: Optional[bool] = None
|
|
|
|
|
"""Whether to truncate the input texts to fit within the context length.
|
|
|
|
|
|
|
|
|
|
If True, over-length input texts will be truncated to fit within the context
|
|
|
|
|
length, before being vectorized by the embedding model. If False, an error will be
|
|
|
|
|
raised if any given text exceeds the context length. If not specified
|
|
|
|
|
(defaults to None), we will truncate the input text before sending it to the
|
|
|
|
|
embedding model if it slightly exceeds the context window length. If it
|
|
|
|
|
significantly exceeds the context window length, an error will be raised."""
|
|
|
|
|
|
|
|
|
|
class Config:
|
|
|
|
|
"""Configuration for this pydantic object."""
|
|
|
|
@ -104,12 +113,14 @@ class VoyageEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
self, input: List[str], input_type: Optional[str] = None
|
|
|
|
|
) -> Dict:
|
|
|
|
|
api_key = cast(SecretStr, self.voyage_api_key).get_secret_value()
|
|
|
|
|
params = {
|
|
|
|
|
params: Dict = {
|
|
|
|
|
"url": self.voyage_api_base,
|
|
|
|
|
"headers": {"Authorization": f"Bearer {api_key}"},
|
|
|
|
|
"json": {"model": self.model, "input": input, "input_type": input_type},
|
|
|
|
|
"timeout": self.request_timeout,
|
|
|
|
|
}
|
|
|
|
|
if self.truncation is not None:
|
|
|
|
|
params["json"]["truncation"] = self.truncation
|
|
|
|
|
return params
|
|
|
|
|
|
|
|
|
|
def _get_embeddings(
|
|
|
|
|