diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index f79da073ec..4927acf832 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -59,6 +59,43 @@ class DALMFilter(BaseModel): return values +class ArceeDocumentSource(BaseModel): + """Source of an Arcee document.""" + + document: str + name: str + id: str + + +class ArceeDocument(BaseModel): + """Arcee document.""" + + index: str + id: str + score: float + source: ArceeDocumentSource + + +class ArceeDocumentAdapter: + """Adapter for Arcee documents""" + + @classmethod + def adapt(cls, arcee_document: ArceeDocument) -> Document: + """Adapts an `ArceeDocument` to a langchain's `Document` object.""" + return Document( + page_content=arcee_document.source.document, + metadata={ + # arcee document; source metadata + "name": arcee_document.source.name, + "source_id": arcee_document.source.id, + # arcee document metadata + "index": arcee_document.index, + "id": arcee_document.id, + "score": arcee_document.score, + }, + ) + + class ArceeWrapper: """Wrapper for Arcee API.""" @@ -172,7 +209,7 @@ class ArceeWrapper: response = self._make_request( method="post", - route=ArceeRoute.generate, + route=ArceeRoute.generate.value, body=self._make_request_body_for_models( prompt=prompt, **kwargs, @@ -196,10 +233,13 @@ class ArceeWrapper: response = self._make_request( method="post", - route=ArceeRoute.retrieve, + route=ArceeRoute.retrieve.value, body=self._make_request_body_for_models( prompt=query, **kwargs, ), ) - return [Document(**doc) for doc in response["documents"]] + return [ + ArceeDocumentAdapter.adapt(ArceeDocument(**doc)) + for doc in response["results"] + ]