mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
122 lines
4.4 KiB
Python
122 lines
4.4 KiB
Python
|
"""Wrapper for the Reddit API"""
|
||
|
|
||
|
from typing import Any, Dict, List, Optional
|
||
|
|
||
|
from langchain_core.pydantic_v1 import BaseModel, root_validator
|
||
|
from langchain_core.utils import get_from_dict_or_env
|
||
|
|
||
|
|
||
|
class RedditSearchAPIWrapper(BaseModel):
|
||
|
"""Wrapper for Reddit API
|
||
|
|
||
|
To use, set the environment variables ``REDDIT_CLIENT_ID``,
|
||
|
``REDDIT_CLIENT_SECRET``, ``REDDIT_USER_AGENT`` to set the client ID,
|
||
|
client secret, and user agent, respectively, as given by Reddit's API.
|
||
|
Alternatively, all three can be supplied as named parameters in the
|
||
|
constructor: ``reddit_client_id``, ``reddit_client_secret``, and
|
||
|
``reddit_user_agent``, respectively.
|
||
|
|
||
|
Example:
|
||
|
.. code-block:: python
|
||
|
|
||
|
from langchain_community.utilities import RedditSearchAPIWrapper
|
||
|
reddit_search = RedditSearchAPIWrapper()
|
||
|
"""
|
||
|
|
||
|
reddit_client: Any
|
||
|
|
||
|
# Values required to access Reddit API via praw
|
||
|
reddit_client_id: Optional[str]
|
||
|
reddit_client_secret: Optional[str]
|
||
|
reddit_user_agent: Optional[str]
|
||
|
|
||
|
@root_validator()
|
||
|
def validate_environment(cls, values: Dict) -> Dict:
|
||
|
"""Validate that the API ID, secret and user agent exists in environment
|
||
|
and check that praw module is present.
|
||
|
"""
|
||
|
reddit_client_id = get_from_dict_or_env(
|
||
|
values, "reddit_client_id", "REDDIT_CLIENT_ID"
|
||
|
)
|
||
|
values["reddit_client_id"] = reddit_client_id
|
||
|
|
||
|
reddit_client_secret = get_from_dict_or_env(
|
||
|
values, "reddit_client_secret", "REDDIT_CLIENT_SECRET"
|
||
|
)
|
||
|
values["reddit_client_secret"] = reddit_client_secret
|
||
|
|
||
|
reddit_user_agent = get_from_dict_or_env(
|
||
|
values, "reddit_user_agent", "REDDIT_USER_AGENT"
|
||
|
)
|
||
|
values["reddit_user_agent"] = reddit_user_agent
|
||
|
|
||
|
try:
|
||
|
import praw
|
||
|
except ImportError:
|
||
|
raise ImportError(
|
||
|
"praw package not found, please install it with pip install praw"
|
||
|
)
|
||
|
|
||
|
reddit_client = praw.Reddit(
|
||
|
client_id=reddit_client_id,
|
||
|
client_secret=reddit_client_secret,
|
||
|
user_agent=reddit_user_agent,
|
||
|
)
|
||
|
values["reddit_client"] = reddit_client
|
||
|
|
||
|
return values
|
||
|
|
||
|
def run(
|
||
|
self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
|
||
|
) -> str:
|
||
|
"""Search Reddit and return posts as a single string."""
|
||
|
results: List[Dict] = self.results(
|
||
|
query=query,
|
||
|
sort=sort,
|
||
|
time_filter=time_filter,
|
||
|
subreddit=subreddit,
|
||
|
limit=limit,
|
||
|
)
|
||
|
if len(results) > 0:
|
||
|
output: List[str] = [f"Searching r/{subreddit} found {len(results)} posts:"]
|
||
|
for r in results:
|
||
|
category = "N/A" if r["post_category"] is None else r["post_category"]
|
||
|
p = f"Post Title: '{r['post_title']}'\n\
|
||
|
User: {r['post_author']}\n\
|
||
|
Subreddit: {r['post_subreddit']}:\n\
|
||
|
Text body: {r['post_text']}\n\
|
||
|
Post URL: {r['post_url']}\n\
|
||
|
Post Category: {category}.\n\
|
||
|
Score: {r['post_score']}\n"
|
||
|
output.append(p)
|
||
|
return "\n".join(output)
|
||
|
else:
|
||
|
return f"Searching r/{subreddit} did not find any posts:"
|
||
|
|
||
|
def results(
|
||
|
self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
|
||
|
) -> List[Dict]:
|
||
|
"""Use praw to search Reddit and return a list of dictionaries,
|
||
|
one for each post.
|
||
|
"""
|
||
|
subredditObject = self.reddit_client.subreddit(subreddit)
|
||
|
search_results = subredditObject.search(
|
||
|
query=query, sort=sort, time_filter=time_filter, limit=limit
|
||
|
)
|
||
|
search_results = [r for r in search_results]
|
||
|
results_object = []
|
||
|
for submission in search_results:
|
||
|
results_object.append(
|
||
|
{
|
||
|
"post_subreddit": submission.subreddit_name_prefixed,
|
||
|
"post_category": submission.category,
|
||
|
"post_title": submission.title,
|
||
|
"post_text": submission.selftext,
|
||
|
"post_score": submission.score,
|
||
|
"post_id": submission.id,
|
||
|
"post_url": submission.url,
|
||
|
"post_author": submission.author,
|
||
|
}
|
||
|
)
|
||
|
return results_object
|