From 07c26497530b88fbfe93edcedd65eda70e6c3952 Mon Sep 17 00:00:00 2001 From: Shwu Ku <65639964+EricLiclair@users.noreply.github.com> Date: Wed, 25 Oct 2023 23:25:13 +0530 Subject: [PATCH] response parser for ArceeRetriever (#12270) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - **Description:** Response parser for arcee retriever, - **Issue:** follow-up PR on #11578 and [discussion](https://github.com/arcee-ai/arcee-python/issues/15#issuecomment-1759874053), - **Dependencies:** NA This PR implements a parser that converts the response from ArceeRetriever into a langchain `Document`. This closes the loop of generation and retrieval for Arcee DALMs in langchain. The reference for the response parser is [api-docs:retrieve](https://api.arcee.ai/docs#/v2/retrieve_model) Attaching screenshot of working implementation: Screenshot 2023-10-25 at 7 42 34 PM \*api key deleted --- Successful tests, lints, etc. ```shell Re-run pytest with --snapshot-update to delete unused snapshots. ====================================================================================================================== slowest 5 durations ======================================================================================================================= 1.56s call tests/unit_tests/schema/runnable/test_runnable.py::test_retrying 0.63s call tests/unit_tests/schema/runnable/test_runnable.py::test_map_astream 0.33s call tests/unit_tests/schema/runnable/test_runnable.py::test_map_stream_iterator_input 0.30s call tests/unit_tests/schema/runnable/test_runnable.py::test_map_astream_iterator_input 0.20s call tests/unit_tests/indexes/test_indexing.py::test_cleanup_with_different_batchsize ========================================================================================================= 1265 passed, 270 skipped, 32 warnings in 6.55s ========================================================================================================= [ "."
= "" ] || poetry run black . All done! ✨ 🍰 ✨ 1871 files left unchanged. [ "." = "" ] || poetry run ruff --select I --fix . ./scripts/check_pydantic.sh . ./scripts/check_imports.sh poetry run ruff . [ "." = "" ] || poetry run black . --check All done! ✨ 🍰 ✨ 1871 files would be left unchanged. [ "." = "" ] || poetry run mypy . Success: no issues found in 1868 source files poetry run codespell --toml pyproject.toml poetry run codespell --toml pyproject.toml -w ``` Co-authored-by: Shubham Kushwaha --- libs/langchain/langchain/utilities/arcee.py | 46 +++++++++++++++++++-- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/utilities/arcee.py b/libs/langchain/langchain/utilities/arcee.py index f79da073ec..4927acf832 100644 --- a/libs/langchain/langchain/utilities/arcee.py +++ b/libs/langchain/langchain/utilities/arcee.py @@ -59,6 +59,43 @@ class DALMFilter(BaseModel): return values +class ArceeDocumentSource(BaseModel): + """Source of an Arcee document.""" + + document: str + name: str + id: str + + +class ArceeDocument(BaseModel): + """Arcee document.""" + + index: str + id: str + score: float + source: ArceeDocumentSource + + +class ArceeDocumentAdapter: + """Adapter for Arcee documents""" + + @classmethod + def adapt(cls, arcee_document: ArceeDocument) -> Document: + """Adapts an `ArceeDocument` to a langchain's `Document` object.""" + return Document( + page_content=arcee_document.source.document, + metadata={ + # arcee document; source metadata + "name": arcee_document.source.name, + "source_id": arcee_document.source.id, + # arcee document metadata + "index": arcee_document.index, + "id": arcee_document.id, + "score": arcee_document.score, + }, + ) + + class ArceeWrapper: """Wrapper for Arcee API.""" @@ -172,7 +209,7 @@ class ArceeWrapper: response = self._make_request( method="post", - route=ArceeRoute.generate, + route=ArceeRoute.generate.value, body=self._make_request_body_for_models( prompt=prompt, **kwargs, @@ -196,10 +233,13 
@@ class ArceeWrapper: response = self._make_request( method="post", - route=ArceeRoute.retrieve, + route=ArceeRoute.retrieve.value, body=self._make_request_body_for_models( prompt=query, **kwargs, ), ) - return [Document(**doc) for doc in response["documents"]] + return [ + ArceeDocumentAdapter.adapt(ArceeDocument(**doc)) + for doc in response["results"] + ]