@ -1,6 +1,6 @@
from __future__ import annotations
from __future__ import annotations
from typing import TYPE_CHECKING , Any , Dict , List
from typing import TYPE_CHECKING , Any , Dict , List , Optional
from langchain_core . callbacks import (
from langchain_core . callbacks import (
AsyncCallbackManagerForRetrieverRun ,
AsyncCallbackManagerForRetrieverRun ,
@ -17,15 +17,16 @@ if TYPE_CHECKING:
def _get_docs ( response : Any ) - > List [ Document ] :
def _get_docs ( response : Any ) - > List [ Document ] :
docs = (
docs = [ ]
[ ]
if (
if " documents " not in response . generation_info
" documents " in response . generation_info
or len ( response . generation_info [ " documents " ] ) == 0
and len ( response . generation_info [ " documents " ] ) > 0
else [
) :
Document ( page_content = doc [ " snippet " ] , metadata = doc )
for doc in response . generation_info [ " documents " ] :
for doc in response . generation_info [ " documents " ]
content = doc . get ( " snippet " , None ) or doc . get ( " text " , None )
]
if content is not None :
)
docs . append ( Document ( page_content = content , metadata = doc ) )
docs . append (
docs . append (
Document (
Document (
page_content = response . message . content ,
page_content = response . message . content ,
@ -63,12 +64,18 @@ class CohereRagRetriever(BaseRetriever):
""" Allow arbitrary types. """
""" Allow arbitrary types. """
def _get_relevant_documents (
def _get_relevant_documents (
self , query : str , * , run_manager : CallbackManagerForRetrieverRun , * * kwargs : Any
self ,
query : str ,
* ,
run_manager : CallbackManagerForRetrieverRun ,
documents : Optional [ List [ Dict [ str , str ] ] ] = None ,
* * kwargs : Any ,
) - > List [ Document ] :
) - > List [ Document ] :
messages : List [ List [ BaseMessage ] ] = [ [ HumanMessage ( content = query ) ] ]
messages : List [ List [ BaseMessage ] ] = [ [ HumanMessage ( content = query ) ] ]
res = self . llm . generate (
res = self . llm . generate (
messages ,
messages ,
connectors = self . connectors ,
connectors = self . connectors if documents is None else None ,
documents = documents ,
callbacks = run_manager . get_child ( ) ,
callbacks = run_manager . get_child ( ) ,
* * kwargs ,
* * kwargs ,
) . generations [ 0 ] [ 0 ]
) . generations [ 0 ] [ 0 ]
@ -79,13 +86,15 @@ class CohereRagRetriever(BaseRetriever):
query : str ,
query : str ,
* ,
* ,
run_manager : AsyncCallbackManagerForRetrieverRun ,
run_manager : AsyncCallbackManagerForRetrieverRun ,
documents : Optional [ List [ Dict [ str , str ] ] ] = None ,
* * kwargs : Any ,
* * kwargs : Any ,
) - > List [ Document ] :
) - > List [ Document ] :
messages : List [ List [ BaseMessage ] ] = [ [ HumanMessage ( content = query ) ] ]
messages : List [ List [ BaseMessage ] ] = [ [ HumanMessage ( content = query ) ] ]
res = (
res = (
await self . llm . agenerate (
await self . llm . agenerate (
messages ,
messages ,
connectors = self . connectors ,
connectors = self . connectors if documents is None else None ,
documents = documents ,
callbacks = run_manager . get_child ( ) ,
callbacks = run_manager . get_child ( ) ,
* * kwargs ,
* * kwargs ,
)
)