[partner]: minor change to embeddings for Ollama (#24521)

This commit is contained in:
Isaac Francisco 2024-07-23 17:00:13 -07:00 committed by GitHub
parent 0f45ac4088
commit 464a525a5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,7 +14,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
from langchain_ollama import OllamaEmbeddings from langchain_ollama import OllamaEmbeddings
model = OllamaEmbeddings(model="llama3") embedder = OllamaEmbeddings(model="llama3")
embedder.embed_query("what is the place that jonathan worked at?") embedder.embed_query("what is the place that jonathan worked at?")
""" """
@ -28,9 +28,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
def embed_documents(self, texts: List[str]) -> List[List[float]]: def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs.""" """Embed search docs."""
embedded_docs = [] embedded_docs = ollama.embed(self.model, texts)["embeddings"]
for doc in texts:
embedded_docs.append(list(ollama.embeddings(self.model, doc)["embedding"]))
return embedded_docs return embedded_docs
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
@ -39,11 +37,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
async def aembed_documents(self, texts: List[str]) -> List[List[float]]: async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs.""" """Embed search docs."""
embedded_docs = [] embedded_docs = (await AsyncClient().embed(self.model, texts))["embeddings"]
for doc in texts:
embedded_docs.append(
list((await AsyncClient().embeddings(self.model, doc))["embedding"])
)
return embedded_docs return embedded_docs
async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> List[float]: