langchain/libs/community/langchain_community/tools/gmail/search.py
Eugene Yurtsev 25fbe356b4
community[patch]: upgrade to recent version of mypy (#21616)
This PR upgrades community to a recent version of mypy. It inserts type:
ignore on all existing failures.
2024-05-13 14:55:07 -04:00

146 lines
5.1 KiB
Python

import base64
import email
from enum import Enum
from typing import Any, Dict, List, Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.tools.gmail.base import GmailBaseTool
from langchain_community.tools.gmail.utils import clean_email_body
class Resource(str, Enum):
    """Kinds of Gmail resource a search can be run against.

    Members also subclass ``str``, so each one compares equal to its
    plain string value.
    """

    # Search for threads.
    THREADS = "threads"
    # Search for individual messages.
    MESSAGES = "messages"
class SearchArgsSchema(BaseModel):
    """Argument schema accepted by the Gmail search tool.

    Fields:
        query: Required Gmail search expression.
        resource: Which kind of resource to search; defaults to messages.
        max_results: Upper bound on the number of results; defaults to 10.
    """

    # Operator reference: https://support.google.com/mail/answer/7190?hl=en
    query: str = Field(
        ...,
        description=(
            "The Gmail query. "
            "Example filters include from:sender, to:recipient, subject:subject,"
            " -filtered_term, in:folder, is:important|read|starred,"
            " after:year/mo/date, before:year/mo/date,"
            ' label:label_name "exact phrase". '
            "Search newer/older than using d (day), m (month), and y (year):"
            " newer_than:2d, older_than:1y. "
            "Attachments with extension example: filename:pdf. "
            "Multiple term matching example: from:amy OR from:david."
        ),
    )
    resource: Resource = Field(
        default=Resource.MESSAGES,
        description="Whether to search for threads or messages.",
    )
    max_results: int = Field(
        default=10,
        description="The maximum number of results to return.",
    )
class GmailSearch(GmailBaseTool):
    """Tool that searches for messages or threads in Gmail.

    All requests are issued through ``self.api_resource``. A search first
    lists the matching resources and then issues one additional ``get``
    request per result to fetch its details.
    """

    name: str = "search_gmail"
    description: str = (
        "Use this tool to search for email messages or threads."
        " The input must be a valid Gmail query."
        " The output is a JSON list of the requested resource."
    )
    args_schema: Type[SearchArgsSchema] = SearchArgsSchema

    def _parse_threads(self, threads: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Attach message snippets to each thread returned by a thread search.

        Args:
            threads: Thread entries from the list response; each must have
                an ``"id"`` key.

        Returns:
            The same thread dicts (mutated in place), each with a
            ``"messages"`` list of ``{"snippet", "id"}`` dicts.
        """
        results = []
        for thread in threads:
            thread_data = (
                self.api_resource.users()
                .threads()
                .get(userId="me", id=thread["id"])
                .execute()
            )
            # Keep only the snippet and id of each message in the thread.
            thread["messages"] = [
                {"snippet": message["snippet"], "id": message["id"]}
                for message in thread_data.get("messages", [])
            ]
            results.append(thread)
        return results

    def _decode_payload(self, part: Any) -> str:
        """Decode the transfer-decoded payload of an email part to text.

        UTF-8 is tried first; latin-1 is the fallback because it can decode
        any byte sequence. A part with no byte payload yields an empty string.

        Args:
            part: An email message (or sub-part) object exposing
                ``get_payload(decode=True)``.

        Returns:
            The decoded text, or ``""`` when there is no payload.
        """
        payload = part.get_payload(decode=True)
        if not isinstance(payload, bytes):
            return ""
        try:
            return payload.decode("utf-8")
        except UnicodeDecodeError:
            return payload.decode("latin-1")

    def _parse_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Fetch and summarize each message returned by a message search.

        Every message is retrieved in raw format, parsed as an email, and its
        body is taken from the first non-attachment ``text/plain`` part (or
        from the whole payload when the message is not multipart).

        Args:
            messages: Message entries from the list response; each must have
                an ``"id"`` key.

        Returns:
            One dict per message with the keys ``id``, ``threadId``,
            ``snippet``, ``body``, ``subject`` and ``sender``.
        """
        results = []
        for message in messages:
            message_id = message["id"]
            message_data = (
                self.api_resource.users()
                .messages()
                .get(userId="me", format="raw", id=message_id)
                .execute()
            )

            raw_message = base64.urlsafe_b64decode(message_data["raw"])
            email_msg = email.message_from_bytes(raw_message)

            subject = email_msg["Subject"]
            sender = email_msg["From"]

            message_body = ""
            if email_msg.is_multipart():
                for part in email_msg.walk():
                    ctype = part.get_content_type()
                    cdispo = str(part.get("Content-Disposition"))
                    # Use the first inline plain-text part; skip attachments.
                    if ctype == "text/plain" and "attachment" not in cdispo:
                        message_body = self._decode_payload(part)
                        break
            else:
                # Same decoding (with fallback) as the multipart branch, so a
                # non-UTF-8 single-part message no longer raises.
                message_body = self._decode_payload(email_msg)

            body = clean_email_body(message_body)

            results.append(
                {
                    "id": message["id"],
                    "threadId": message_data["threadId"],
                    "snippet": message_data["snippet"],
                    "body": body,
                    "subject": subject,
                    "sender": sender,
                }
            )
        return results

    def _run(
        self,
        query: str,
        resource: Resource = Resource.MESSAGES,
        max_results: int = 10,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> List[Dict[str, Any]]:
        """Run the tool.

        Args:
            query: Gmail search query.
            resource: Whether to search threads or messages.
            max_results: Maximum number of results requested from the API.
            run_manager: Optional callback manager; not used here.

        Returns:
            The parsed threads or messages matching the query.

        Raises:
            NotImplementedError: If ``resource`` is not a supported value.
        """
        users = self.api_resource.users()
        # Query the collection that matches the requested resource. Listing
        # messages and then reading a "threads" key from the response would
        # always produce an empty result.
        if resource == Resource.THREADS:
            collection = users.threads()
            parse = self._parse_threads
        elif resource == Resource.MESSAGES:
            collection = users.messages()
            parse = self._parse_messages
        else:
            raise NotImplementedError(f"Resource of type {resource} not implemented.")

        results = (
            collection.list(userId="me", q=query, maxResults=max_results)
            .execute()
            .get(resource.value, [])
        )
        return parse(results)