|
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional

import pandas as pd
import requests

from my_logger import setup_logger

logger = setup_logger(__name__)
|
|
def get_pushshift_data(subreddit: str, before: Optional[int] = None,
                       after: Optional[int] = None, aggs: Optional[str] = None) -> Optional[Dict[str, Any]]:
    """
    Fetch data from the Pushshift API for the specified subreddit.

    :param subreddit: The name of the subreddit to scrape.
    :param before: The upper limit for the created_utc attribute of the submissions.
    :param after: The lower limit for the created_utc attribute of the submissions.
    :param aggs: The aggregation summary option to use.
    :return: A dictionary containing the fetched data and aggregations if available.
    """
    url = "https://api.pushshift.io/reddit/search/submission/"
    params = {
        "subreddit": subreddit,
        "size": 1000,
        "sort": "created_utc",
        "sort_type": "desc",
    }
    if before is not None:
        params["before"] = before
    if after is not None:
        params["after"] = after
    if aggs is not None:
        params["aggs"] = aggs

    # A timeout keeps a stalled connection from hanging the scraper indefinitely.
    response = requests.get(url, params=params, timeout=30)
    if response.status_code == 200:
        return response.json()
    logger.error(f"Error fetching data: {response.status_code}")
    return None
|
|
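# Hedged usage sketch (not part of the original flow): the timestamps below
# bound 2013-03-01 00:00:00 UTC to 2013-03-02 00:00:00 UTC, and the "data"
# key follows the response shape consumed elsewhere in this module.
#
# >>> payload = get_pushshift_data("askreddit", after=1362096000, before=1362182400)
# >>> if payload is not None:
# ...     print(len(payload["data"]))  # submissions in the window, newest first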
|
|
def get_post_count_for_day(subreddit: str, day_to_scrape: str) -> int:
    """
    Get the total number of posts for a specific day in the specified subreddit using the Pushshift API.

    :param subreddit: The name of the subreddit to get the post count for.
    :param day_to_scrape: The date for which to get the post count (format: "YYYY-MM-DD").
    :return: The total number of posts for the specified day.
    """
    # Anchor the day in UTC: a naive datetime would be interpreted in local time
    # by .timestamp(), skewing the window against the API's UTC created_utc field.
    date_obj = datetime.strptime(day_to_scrape, "%Y-%m-%d").replace(tzinfo=timezone.utc)
    after = int(date_obj.timestamp())
    before = int((date_obj + timedelta(days=1)).timestamp())

    response = get_pushshift_data(subreddit, before=before, after=after, aggs="created_utc")
    if response is not None:
        aggs = response.get("aggs", {}).get("created_utc", [])
        # Sum over all histogram buckets in case the API splits the day into
        # more than one bucket.
        return sum(bucket["doc_count"] for bucket in aggs)
    return 0
|
|
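# Hedged usage sketch: sizing up a day before committing to a full scrape.
# The bucket layout of Pushshift's "aggs" response is an assumption here;
# the function itself already guards against it being absent.
#
# >>> count = get_post_count_for_day("askreddit", "2013-03-01")
# >>> print(f"Expect roughly {count} submissions, i.e. about {count // 1000 + 1} pages")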
|
|
def fetch_data(subreddit: str, before: int, after: int) -> Optional[Dict[str, Any]]:
    """Fetch a bounded window of submissions; thin wrapper around get_pushshift_data."""
    # Previously duplicated the request logic verbatim; delegating keeps a single
    # code path for the URL, parameters, and error handling.
    return get_pushshift_data(subreddit, before=before, after=after)
|
|
def convert_timestamp_to_datetime(timestamp: int) -> str:
    """Convert a Unix timestamp to a "YYYY-MM-DD HH:MM:SS" string in UTC."""
    # fromtimestamp with an explicit tz replaces the deprecated utcfromtimestamp.
    datetime_obj_utc = datetime.fromtimestamp(timestamp, tz=timezone.utc)
    return datetime_obj_utc.strftime('%Y-%m-%d %H:%M:%S')
|
|
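# Example: 1362096000 is 2013-03-01 00:00:00 UTC.
# >>> convert_timestamp_to_datetime(1362096000)
# '2013-03-01 00:00:00'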
|
|
def scrape_submissions_by_day(subreddit_to_scrape: str, day_to_scrape: str) -> List[Dict[str, Any]]:
    """
    Scrape all submissions posted to a subreddit on a given UTC day.

    :param subreddit_to_scrape: The name of the subreddit to scrape.
    :param day_to_scrape: The date to scrape (format: "YYYY-MM-DD").
    :return: A list of dictionaries containing the scraped submission data.
    """
    start_time = time.time()
    scraped_submissions = []
    date_obj = datetime.strptime(day_to_scrape, "%Y-%m-%d").replace(tzinfo=timezone.utc)

    if date_obj > datetime.now(timezone.utc) - timedelta(days=7):
        logger.error("The specified date might not be available in the Pushshift API yet. "
                     "Please try an earlier date or wait for the API to be updated.")
        return scraped_submissions

    after = int(date_obj.timestamp())
    before = int((date_obj + timedelta(days=1)).timestamp())

    actual_requests = 0
    while after < before:
        after_str, before_str = convert_timestamp_to_datetime(after), convert_timestamp_to_datetime(before)
        logger.info(f"Fetching data between timestamps {after_str} and {before_str}")
        data = get_pushshift_data(subreddit_to_scrape, before=before, after=after)
        if data is None or len(data["data"]) == 0:
            break

        scraped_submissions.extend(data["data"])
        # Cursor pagination: results are sorted by created_utc descending, so the
        # oldest submission on this page becomes the upper bound of the next request.
        before = data["data"][-1]["created_utc"]

        actual_requests += 1
        time.sleep(1)  # stay polite to the API between requests

    elapsed_time = time.time() - start_time
    if actual_requests:
        logger.info(f"{actual_requests}it [{int(elapsed_time // 60):02d}:{elapsed_time % 60:05.2f}, "
                    f"{elapsed_time / actual_requests:.2f}s/it]")
    logger.info(f"Finished scraping {len(scraped_submissions)} submissions in {elapsed_time:.2f} seconds "
                f"over {actual_requests} requests")
    return scraped_submissions
|
|
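# A minimal retry sketch, assuming the dominant failure mode is a transient
# HTTP error (get_pushshift_data returns None on any non-200 response). This
# hypothetical helper is illustrative and is not wired into the loop above.
def fetch_page_with_retries(subreddit: str, before: int, after: int,
                            max_retries: int = 3, base_delay: float = 2.0) -> Optional[Dict[str, Any]]:
    """Retry get_pushshift_data with exponential backoff (illustrative only)."""
    for attempt in range(max_retries):
        data = get_pushshift_data(subreddit, before=before, after=after)
        if data is not None:
            return data
        time.sleep(base_delay ** attempt)  # 1s, 2s, 4s with the defaults
    return None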
|
|
def submissions_to_dataframe(submissions: List[Dict[str, Any]]) -> pd.DataFrame:
    """
    Parse a list of submissions into a pandas DataFrame.

    :param submissions: A list of dictionaries containing the scraped submission data.
    :return: A pandas DataFrame containing the submission data.
    """
    cols = ['score', 'num_comments', 'title', 'permalink', 'selftext', 'url', 'created_utc', 'author', 'id',
            'downs', 'ups']
    df = pd.DataFrame(submissions)
    df = df.convert_dtypes()
    # Keep only the columns actually present: older Pushshift records do not
    # always include every field (e.g. 'downs' and 'ups').
    df = df[[col for col in cols if col in df.columns]]

    df['created_utc'] = pd.to_datetime(df['created_utc'], unit='s').dt.tz_localize('UTC').dt.strftime(
        '%Y-%m-%d %H:%M:%S')
    return df
|
|
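# Hedged downstream sketch: persisting the frame. The filename pattern is
# illustrative, not part of the original script.
# >>> df = submissions_to_dataframe(submissions)
# >>> df.to_csv(f"{subreddit_to_scrape}_{day_to_scrape}.csv", index=False)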
|
|
if __name__ == '__main__':
    subreddit_to_scrape = "askreddit"
    day_to_scrape = "2013-03-01"
    submissions = scrape_submissions_by_day(subreddit_to_scrape, day_to_scrape)
    df = submissions_to_dataframe(submissions)
    print(df.head().to_string())
    logger.info(f"Scraped {len(submissions)} submissions from r/{subreddit_to_scrape} on {day_to_scrape}")
|