langchain/libs/community/tests/unit_tests/embeddings/test_imports.py
Eli Lucherini 6b2a57161a
community[patch]: allow additional kwargs in MlflowEmbeddings for compatibility with Cohere API (#15242)
- **Description:** add support for kwargs in `MlflowEmbeddings`
`embed_documents()` and `embed_query()` so that all the arguments
required by the Cohere API (and possibly other providers) can be passed down to the server.
  - **Issue:** #15234 
- **Dependencies:** MLflow with MLflow Deployments (`pip install mlflow[genai]`)

**Tests**
With this change, the following code for the Cohere API ([adapted from the
docs](https://python.langchain.com/docs/integrations/providers/mlflow#embeddings-example))
now works locally.

```python
"""
Setup
-----
export COHERE_API_KEY=...
mlflow deployments start-server --config-path examples/deployments/cohere/config.yaml

Run
---
python /path/to/this/file.py
"""
embeddings = MlflowCohereEmbeddings(target_uri="http://127.0.0.1:5000", endpoint="embeddings")
print(embeddings.embed_query("hello")[:3])
print(embeddings.embed_documents(["hello", "world"])[0][:3])
```

Output
```
[0.060455322, 0.028793335, -0.025848389]
[0.031707764, 0.021057129, -0.009361267]
```
2024-01-22 11:38:11 -08:00

64 lines
1.7 KiB
Python

from langchain_community.embeddings import __all__
# Every public name that ``langchain_community.embeddings`` should export via
# its ``__all__``. Keep this list in sync whenever an embeddings class is added
# or removed; ``test_all_imports`` compares it against ``__all__`` as a set, so
# the order of entries here does not matter.
EXPECTED_ALL = [
    "OpenAIEmbeddings",
    "AzureOpenAIEmbeddings",
    "ClarifaiEmbeddings",
    "CohereEmbeddings",
    "DatabricksEmbeddings",
    "ElasticsearchEmbeddings",
    "FastEmbedEmbeddings",
    "HuggingFaceEmbeddings",
    "HuggingFaceInferenceAPIEmbeddings",
    "InfinityEmbeddings",
    "GradientEmbeddings",
    "JinaEmbeddings",
    "LlamaCppEmbeddings",
    "LLMRailsEmbeddings",
    "HuggingFaceHubEmbeddings",
    "MlflowAIGatewayEmbeddings",
    "MlflowEmbeddings",
    "MlflowCohereEmbeddings",
    "ModelScopeEmbeddings",
    "TensorflowHubEmbeddings",
    "SagemakerEndpointEmbeddings",
    "HuggingFaceInstructEmbeddings",
    "MosaicMLInstructorEmbeddings",
    "SelfHostedEmbeddings",
    "SelfHostedHuggingFaceEmbeddings",
    "SelfHostedHuggingFaceInstructEmbeddings",
    "FakeEmbeddings",
    "DeterministicFakeEmbedding",
    "AlephAlphaAsymmetricSemanticEmbedding",
    "AlephAlphaSymmetricSemanticEmbedding",
    "SentenceTransformerEmbeddings",
    "GooglePalmEmbeddings",
    "MiniMaxEmbeddings",
    "VertexAIEmbeddings",
    "BedrockEmbeddings",
    "DeepInfraEmbeddings",
    "EdenAiEmbeddings",
    "DashScopeEmbeddings",
    "EmbaasEmbeddings",
    "OctoAIEmbeddings",
    "SpacyEmbeddings",
    "NLPCloudEmbeddings",
    "GPT4AllEmbeddings",
    "XinferenceEmbeddings",
    "LocalAIEmbeddings",
    "AwaEmbeddings",
    "HuggingFaceBgeEmbeddings",
    "ErnieEmbeddings",
    "JavelinAIGatewayEmbeddings",
    "OllamaEmbeddings",
    "QianfanEmbeddingsEndpoint",
    "JohnSnowLabsEmbeddings",
    "VoyageEmbeddings",
    "BookendEmbeddings",
    "VolcanoEmbeddings",
]
def test_all_imports() -> None:
    """Check that the embeddings package exports exactly ``EXPECTED_ALL``.

    Both the package's ``__all__`` and ``EXPECTED_ALL`` are converted to sets
    before comparison, so neither ordering nor repeated entries influence the
    outcome; only the collection of distinct names must match.
    """
    exported_names = set(__all__)
    expected_names = set(EXPECTED_ALL)
    assert exported_names == expected_names