Improvements to the Clarifai integration (#9290)

- Improved docs
- Improved performance in multiple ways through batching, threading,
etc.
 - fixed error message 
 - Added support for metadata filtering during similarity search.

@baskaryan PTAL
This commit is contained in:
Matthew Zeiler 2023-08-21 15:53:36 -04:00 committed by GitHub
parent 66a47d9a61
commit 949b2cf177
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 261 additions and 100 deletions

View File

@ -37,7 +37,7 @@ There is a Clarifai Embedding model in LangChain, which you can access with:
from langchain.embeddings import ClarifaiEmbeddings
embeddings = ClarifaiEmbeddings(pat=CLARIFAI_PAT, user_id=USER_ID, app_id=APP_ID, model_id=MODEL_ID)
```
For more details, the docs on the Clarifai Embeddings wrapper provide a [detailed walthrough](/docs/integrations/text_embedding/clarifai.html).
For more details, the docs on the Clarifai Embeddings wrapper provide a [detailed walkthrough](/docs/integrations/text_embedding/clarifai.html).
## Vectorstore
@ -49,4 +49,4 @@ You an also add data directly from LangChain as well, and the auto-indexing will
from langchain.vectorstores import Clarifai
clarifai_vector_db = Clarifai.from_texts(user_id=USER_ID, app_id=APP_ID, texts=texts, pat=CLARIFAI_PAT, number_of_docs=NUMBER_OF_DOCS, metadatas = metadatas)
```
For more details, the docs on the Clarifai vector store provide a [detailed walthrough](/docs/integrations/text_embedding/clarifai.html).
For more details, the docs on the Clarifai vector store provide a [detailed walkthrough](/docs/integrations/text_embedding/clarifai.html).

View File

@ -130,9 +130,9 @@
"metadata": {},
"outputs": [],
"source": [
"USER_ID = \"openai\"\n",
"APP_ID = \"embed\"\n",
"MODEL_ID = \"text-embedding-ada\"\n",
"USER_ID = \"salesforce\"\n",
"APP_ID = \"blip\"\n",
"MODEL_ID = \"multimodal-embedder-blip-2\"\n",
"\n",
"# You can provide a specific model version as the model_version_id arg.\n",
"# MODEL_VERSION_ID = \"MODEL_VERSION_ID\""

View File

@ -53,7 +53,15 @@
"execution_count": 1,
"id": "c1e38361-c1fe-4ac6-86e9-c90ebaf7ae87",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
" ········\n"
]
}
],
"source": [
"# Please login and get your API key from https://clarifai.com/settings/security\n",
"from getpass import getpass\n",
@ -61,18 +69,9 @@
"CLARIFAI_PAT = getpass()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "320af802-9271-46ee-948f-d2453933d44b",
"metadata": {},
"source": [
"We want to use `OpenAIEmbeddings` so we have to get the OpenAI API Key."
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 6,
"id": "aac9563e",
"metadata": {
"tags": []
@ -99,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "4d853395",
"metadata": {},
"outputs": [],
@ -134,7 +133,7 @@
" \"I love playing soccer with my friends\",\n",
"]\n",
"\n",
"metadatas = [{\"id\": i, \"text\": text} for i, text in enumerate(texts)]"
"metadatas = [{\"id\": i, \"text\": text, \"source\": \"book 1\", \"category\": [\"books\", \"modern\"]} for i, text in enumerate(texts)]"
]
},
{
@ -156,21 +155,17 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "e755cdce",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[Document(page_content='I really enjoy spending time with you', metadata={'text': 'I really enjoy spending time with you', 'id': 0.0}),\n",
" Document(page_content='I went to the movies yesterday', metadata={'text': 'I went to the movies yesterday', 'id': 3.0}),\n",
" Document(page_content='zab', metadata={'page': '2'}),\n",
" Document(page_content='zab', metadata={'page': '2'})]"
"[Document(page_content='I really enjoy spending time with you', metadata={'text': 'I really enjoy spending time with you', 'id': 0.0, 'source': 'book 1', 'category': ['books', 'modern']}),\n",
" Document(page_content='I went to the movies yesterday', metadata={'text': 'I went to the movies yesterday', 'id': 3.0, 'source': 'book 1', 'category': ['books', 'modern']})]"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
@ -179,6 +174,21 @@
"docs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "140103ec-0936-454a-9f4a-7d5beefc138f",
"metadata": {},
"outputs": [],
"source": [
"# There is lots of powerful filtering you can do within an app by leveraging metadata filters. \n",
"# This one will limit the similarity query to only the texts that have key of \"source\" matching value of \"book 1\"\n",
"book1_similar_docs = clarifai_vector_db.similarity_search(\"I would love to see you\", filter={\"source\": \"book 1\"})\n",
"\n",
"# you can also use lists in the input's metadata and then select things that match an item in the list. This is useful for categories like below:\n",
"book_category_similar_docs = clarifai_vector_db.similarity_search(\"I would love to see you\", filter={\"category\": [\"books\"]})"
]
},
{
"attachments": {},
"cell_type": "markdown",
@ -249,7 +259,7 @@
" user_id=USER_ID,\n",
" app_id=APP_ID,\n",
" documents=docs,\n",
" pat=CLARIFAI_PAT_KEY,\n",
" pat=CLARIFAI_PAT,\n",
" number_of_docs=NUMBER_OF_DOCS,\n",
")"
]
@ -278,6 +288,55 @@
"docs = clarifai_vector_db.similarity_search(\"Texts related to criminals and violence\")\n",
"docs"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "7b332ca4-416b-4ea6-99da-b6949f399d72",
"metadata": {},
"source": [
"## From existing App\n",
"Within Clarifai we have great tools for adding data to applications (essentially projects) via API or UI. Most users will already have done that before interacting with LangChain so this example will use the data in an existing app to perform searches. Check out our [API docs](https://docs.clarifai.com/api-guide/data/create-get-update-delete) and [UI docs](https://docs.clarifai.com/portal-guide/data). The Clarifai Application can then be used for semantic search to find relevant documents."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "807c1141-591b-436d-abaa-f2c325e66d39",
"metadata": {},
"outputs": [],
"source": [
"USER_ID = \"USERNAME_ID\"\n",
"APP_ID = \"APPLICATION_ID\"\n",
"NUMBER_OF_DOCS = 4"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "762d74ef-f7df-43d6-b121-4980c4059fc0",
"metadata": {},
"outputs": [],
"source": [
"clarifai_vector_db = Clarifai(\n",
" user_id=USER_ID,\n",
" app_id=APP_ID,\n",
" documents=docs,\n",
" pat=CLARIFAI_PAT,\n",
" number_of_docs=NUMBER_OF_DOCS,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f7636b0f-68ab-4b8f-ba0f-3c27061e3631",
"metadata": {},
"outputs": [],
"source": [
"docs = clarifai_vector_db.similarity_search(\"Texts related to criminals and violence\")\n",
"docs"
]
}
],
"metadata": {

View File

@ -103,37 +103,44 @@ class ClarifaiEmbeddings(BaseModel, Embeddings):
"Please install it with `pip install clarifai`."
)
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
user_app_id=self.userDataObject,
model_id=self.model_id,
version_id=self.model_version_id,
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
)
for t in texts
],
)
post_model_outputs_response = self.stub.PostModelOutputs(
post_model_outputs_request
)
batch_size = 32
embeddings = []
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
logger.error(post_model_outputs_response.status)
first_output_failure = (
post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs[0])
else None
post_model_outputs_request = service_pb2.PostModelOutputsRequest(
user_app_id=self.userDataObject,
model_id=self.model_id,
version_id=self.model_version_id,
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(text=resources_pb2.Text(raw=t))
)
for t in batch
],
)
raise Exception(
f"Post model outputs failed, status: "
f"{post_model_outputs_response.status}, first output failure: "
f"{first_output_failure}"
post_model_outputs_response = self.stub.PostModelOutputs(
post_model_outputs_request
)
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
logger.error(post_model_outputs_response.status)
first_output_failure = (
post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs)
else None
)
raise Exception(
f"Post model outputs failed, status: "
f"{post_model_outputs_response.status}, first output failure: "
f"{first_output_failure}"
)
embeddings.extend(
[
list(o.data.embeddings[0].vector)
for o in post_model_outputs_response.outputs
]
)
embeddings = [
list(o.data.embeddings[0].vector)
for o in post_model_outputs_response.outputs
]
return embeddings
def embed_query(self, text: str) -> List[float]:

View File

@ -5,6 +5,7 @@ from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Extra, root_validator
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
@ -163,7 +164,7 @@ class Clarifai(LLM):
logger.error(post_model_outputs_response.status)
first_model_failure = (
post_model_outputs_response.outputs[0].status
if len(post_model_outputs_response.outputs[0])
if len(post_model_outputs_response.outputs)
else None
)
raise Exception(
@ -178,3 +179,67 @@ class Clarifai(LLM):
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text
def _generate(
    self,
    prompts: List[str],
    stop: Optional[List[str]] = None,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> LLMResult:
    """Generate one completion per prompt, sending prompts in batches.

    Prompts are sliced into groups of at most 32 and each group is sent in
    a single ``PostModelOutputs`` request through ``self.stub``.

    Args:
        prompts: The prompt strings to run through the model.
        stop: Optional stop sequences; when given, each returned text is
            passed through ``enforce_stop_tokens``.
        run_manager: Accepted for interface compatibility; not used in
            this method body.
        **kwargs: Accepted for interface compatibility; not used in this
            method body.

    Returns:
        An ``LLMResult`` whose ``generations`` holds one single-element
        list of ``Generation`` per output returned by the service.

    Raises:
        ImportError: If the ``clarifai_grpc`` modules cannot be imported.
        Exception: If any batch response has a non-SUCCESS status code.
    """
    try:
        from clarifai_grpc.grpc.api import (
            resources_pb2,
            service_pb2,
        )
        from clarifai_grpc.grpc.api.status import status_code_pb2
    except ImportError:
        raise ImportError(
            "Could not import clarifai python package. "
            "Please install it with `pip install clarifai`."
        )

    # TODO: add caching here.
    chunk_size = 32
    results = []
    for start in range(0, len(prompts), chunk_size):
        chunk = prompts[start : start + chunk_size]

        # Wrap every prompt of this chunk in its own text Input message.
        chunk_inputs = [
            resources_pb2.Input(
                data=resources_pb2.Data(text=resources_pb2.Text(raw=prompt))
            )
            for prompt in chunk
        ]
        request = service_pb2.PostModelOutputsRequest(
            user_app_id=self.userDataObject,
            model_id=self.model_id,
            version_id=self.model_version_id,
            inputs=chunk_inputs,
        )
        response = self.stub.PostModelOutputs(request)

        if response.status.code != status_code_pb2.SUCCESS:
            logger.error(response.status)
            # Report the first per-output status too, when one exists.
            if len(response.outputs):
                first_model_failure = response.outputs[0].status
            else:
                first_model_failure = None
            raise Exception(
                f"Post model outputs failed, status: "
                f"{response.status}, first output failure: "
                f"{first_model_failure}"
            )

        for item in response.outputs:
            raw_text = item.data.text.raw
            text = raw_text if stop is None else enforce_stop_tokens(raw_text, stop)
            results.append([Generation(text=text)])

    return LLMResult(generations=results)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import logging
import os
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Iterable, List, Optional, Tuple
import requests
@ -84,7 +85,9 @@ class Clarifai(VectorStore):
self._userDataObject = self._auth.get_user_app_id_proto()
self._number_of_docs = number_of_docs
def _post_text_input(self, text: str, metadata: dict) -> str:
def _post_texts_as_inputs(
self, texts: List[str], metadatas: Optional[List[dict]] = None
) -> List[str]:
"""Post texts to Clarifai and return the IDs of the created inputs.
Args:
@ -104,20 +107,29 @@ class Clarifai(VectorStore):
"Please install it with `pip install clarifai`."
) from e
input_metadata = Struct()
input_metadata.update(metadata)
if metadatas is not None:
assert len(list(texts)) == len(
metadatas
), "Number of texts and metadatas should be the same."
inputs = []
for idx, text in enumerate(texts):
if metadatas is not None:
input_metadata = Struct()
input_metadata.update(metadatas[idx])
inputs.append(
resources_pb2.Input(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=text),
metadata=input_metadata,
)
)
)
post_inputs_response = self._stub.PostInputs(
service_pb2.PostInputsRequest(
user_app_id=self._userDataObject,
inputs=[
resources_pb2.Input(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=text),
metadata=input_metadata,
)
)
],
inputs=inputs,
)
)
@ -127,9 +139,11 @@ class Clarifai(VectorStore):
"Post inputs failed, status: " + post_inputs_response.status.description
)
input_id = post_inputs_response.inputs[0].id
input_ids = []
for input in post_inputs_response.inputs:
input_ids.append(input.id)
return input_id
return input_ids
def add_texts(
self,
@ -140,7 +154,7 @@ class Clarifai(VectorStore):
) -> List[str]:
"""Add texts to the Clarifai vectorstore. This will push the text
to a Clarifai application.
Application use base workflow that create and store embedding for each text.
Applications use a base workflow that creates and stores an embedding for each text.
Make sure you are using a base workflow that is compatible with text
(such as Language Understanding).
@ -153,20 +167,26 @@ class Clarifai(VectorStore):
List[str]: List of IDs of the added texts.
"""
assert len(list(texts)) > 0, "No texts provided to add to the vectorstore."
ltexts = list(texts)
length = len(ltexts)
assert length > 0, "No texts provided to add to the vectorstore."
if metadatas is not None:
assert len(list(texts)) == len(
assert length == len(
metadatas
), "Number of texts and metadatas should be the same."
batch_size = 32
input_ids = []
for idx, text in enumerate(texts):
for idx in range(0, length, batch_size):
try:
metadata = metadatas[idx] if metadatas else {}
input_id = self._post_text_input(text, metadata)
input_ids.append(input_id)
logger.debug(f"Input {input_id} posted successfully.")
batch_texts = ltexts[idx : idx + batch_size]
batch_metadatas = (
metadatas[idx : idx + batch_size] if metadatas else None
)
result_ids = self._post_texts_as_inputs(batch_texts, batch_metadatas)
input_ids.extend(result_ids)
logger.debug(f"Input {result_ids} posted successfully.")
except Exception as error:
logger.warning(f"Post inputs failed: {error}")
traceback.print_exc()
@ -196,6 +216,7 @@ class Clarifai(VectorStore):
from clarifai_grpc.grpc.api import resources_pb2, service_pb2
from clarifai_grpc.grpc.api.status import status_code_pb2
from google.protobuf import json_format # type: ignore
from google.protobuf.struct_pb2 import Struct # type: ignore
except ImportError as e:
raise ImportError(
"Could not import clarifai python package. "
@ -206,28 +227,35 @@ class Clarifai(VectorStore):
if self._number_of_docs is not None:
k = self._number_of_docs
post_annotations_searches_response = self._stub.PostAnnotationsSearches(
service_pb2.PostAnnotationsSearchesRequest(
user_app_id=self._userDataObject,
searches=[
resources_pb2.Search(
query=resources_pb2.Query(
ranks=[
resources_pb2.Rank(
annotation=resources_pb2.Annotation(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=query),
)
req = service_pb2.PostAnnotationsSearchesRequest(
user_app_id=self._userDataObject,
searches=[
resources_pb2.Search(
query=resources_pb2.Query(
ranks=[
resources_pb2.Rank(
annotation=resources_pb2.Annotation(
data=resources_pb2.Data(
text=resources_pb2.Text(raw=query),
)
)
]
)
)
]
)
],
pagination=service_pb2.Pagination(page=1, per_page=k),
)
)
],
pagination=service_pb2.Pagination(page=1, per_page=k),
)
# Add filter by metadata if provided.
if filter is not None:
search_metadata = Struct()
search_metadata.update(filter)
f = req.searches[0].query.filters.add()
f.annotation.data.metadata.update(search_metadata)
post_annotations_searches_response = self._stub.PostAnnotationsSearches(req)
# Check if search was successful
if post_annotations_searches_response.status.code != status_code_pb2.SUCCESS:
raise Exception(
@ -238,11 +266,12 @@ class Clarifai(VectorStore):
# Retrieve hits
hits = post_annotations_searches_response.hits
docs_and_scores = []
# Iterate over hits and retrieve metadata and text
for hit in hits:
executor = ThreadPoolExecutor(max_workers=10)
def hit_to_document(hit: resources_pb2.Hit) -> Tuple[Document, float]:
metadata = json_format.MessageToDict(hit.input.data.metadata)
request = requests.get(hit.input.data.text.url)
h = {"Authorization": f"Key {self._auth.pat}"}
request = requests.get(hit.input.data.text.url, headers=h)
# override encoding by real educated guess as provided by chardet
request.encoding = request.apparent_encoding
@ -252,10 +281,11 @@ class Clarifai(VectorStore):
f"\tScore {hit.score:.2f} for annotation: {hit.annotation.id}\
off input: {hit.input.id}, text: {requested_text[:125]}"
)
return (Document(page_content=requested_text, metadata=metadata), hit.score)
docs_and_scores.append(
(Document(page_content=requested_text, metadata=metadata), hit.score)
)
# Iterate over hits and retrieve metadata and text
futures = [executor.submit(hit_to_document, hit) for hit in hits]
docs_and_scores = [future.result() for future in futures]
return docs_and_scores