// reddit/client.rs

1use crate::{
2    error::Error,
3    types::Thing,
4};
5
6// Guesses for good defaults for the user agent.
7
8// TODO: Extract from target
9const DEFAULT_PLATFORM: &str = "pc";
10
11const DEFAULT_APP_ID: &str = env!("CARGO_PKG_NAME");
12const DEFAULT_APP_VERSION: &str = env!("CARGO_PKG_VERSION");
13
14// TODO: Is there really a good default to choose here?
15const DEFAULT_REDDIT_USERNAME: &str = "deleted";
16
17/// A client to access reddit
18#[derive(Clone)]
19pub struct Client {
20    /// The inner http client.
21    ///
22    /// It probably shouldn't be used directly by you.
23    /// It also sets a strange user-agent as well in accordance with reddit's request.
24    pub client: reqwest::Client,
25}
26
27impl Client {
28    /// Create a new [`Client`].
29    pub fn new() -> Self {
30        Self::new_with_user_agent(
31            DEFAULT_PLATFORM,
32            DEFAULT_APP_ID,
33            DEFAULT_APP_VERSION,
34            DEFAULT_REDDIT_USERNAME,
35        )
36    }
37
38    /// Create a new [`Client`] with a user-agent.
39    ///
40    /// See https://github.com/reddit-archive/reddit/wiki/API#rules
41    pub fn new_with_user_agent(
42        platform: &str,
43        app_id: &str,
44        app_version: &str,
45        reddit_username: &str,
46    ) -> Self {
47        let user_agent = format!("{platform}:{app_id}:v{app_version} (by /u/{reddit_username})");
48
49        let mut client_builder = reqwest::Client::builder();
50        client_builder = client_builder.user_agent(user_agent);
51
52        let client = client_builder
53            .build()
54            .expect("failed to build reddit client");
55
56        Self { client }
57    }
58
59    /// Get the top posts of a subreddit where subreddit is the name and num_posts is the number of posts to retrieve.
60    pub async fn get_subreddit(&self, subreddit: &str, num_posts: usize) -> Result<Thing, Error> {
61        let url = format!("https://www.reddit.com/r/{subreddit}.json?limit={num_posts}");
62        let res = self.client.get(&url).send().await?.error_for_status()?;
63
64        // Reddit will redirect us here if the subreddit could not be found.
65        const SEARCH_URL: &str = "https://www.reddit.com/subreddits/search.json?";
66        if res.url().as_str().starts_with(SEARCH_URL) {
67            return Err(Error::SubredditNotFound);
68        }
69
70        let text = res.text().await?;
71        serde_json::from_str(&text).map_err(|error| Error::Json {
72            data: text.into(),
73            error,
74        })
75    }
76
77    /// Get the post data for a post from a given subreddit
78    pub async fn get_post(&self, subreddit: &str, post_id: &str) -> Result<Vec<Thing>, Error> {
79        let url = format!("https://www.reddit.com/r/{subreddit}/comments/{post_id}.json");
80        Ok(self
81            .client
82            .get(&url)
83            .send()
84            .await?
85            .error_for_status()?
86            .json()
87            .await?)
88    }
89}
90
91impl Default for Client {
92    fn default() -> Self {
93        Self::new()
94    }
95}
96
#[cfg(test)]
mod test {
    use super::*;

    /// How many bytes before a JSON error's column to include in the reported snippet.
    const ERROR_CONTEXT_BYTES: usize = 30;

    /// Fetch subreddit `name` with a default client and print how many posts came back.
    async fn get_subreddit(name: &str) -> Result<(), Error> {
        let client = Client::new();
        // Reddit's default page size is 25; ask for more to exercise larger listings.
        let subreddit = client.get_subreddit(name, 100).await?;
        println!(
            "# of children: {}",
            subreddit.data.as_listing().unwrap().children.len()
        );
        Ok(())
    }

    /// Return the text of `data` near a `serde_json` error position, if that line exists.
    ///
    /// `line` is 1-based (as reported by `serde_json::Error::line`). The start index is
    /// clamped to the line length and moved back to a `char` boundary, so that a column
    /// which lands inside a multi-byte character cannot cause a slicing panic that would
    /// hide the original error.
    fn error_snippet(data: &str, line: usize, column: usize) -> Option<&str> {
        let text = data.split('\n').nth(line.saturating_sub(1))?;
        let mut start = column.saturating_sub(ERROR_CONTEXT_BYTES).min(text.len());
        while !text.is_char_boundary(start) {
            start -= 1;
        }
        Some(&text[start..])
    }

    #[tokio::test]
    #[ignore]
    async fn get_post_works() {
        let post_data = [
            ("dankmemes", "h966lq"),
            // ("dankvideos", "h8p0py"), // Subreddit got privated, last tested 12/23/2022. Uncomment in the future to see if that is still the case.
            ("oddlysatisfying", "ha7obv"),
        ];
        let client = Client::new();

        for (subreddit, post_id) in post_data.iter() {
            let post = client
                .get_post(subreddit, post_id)
                .await
                .expect("failed to get post");
            dbg!(&post);
        }
    }

    #[tokio::test]
    #[ignore]
    async fn get_subreddit_works() {
        let subreddits = [
            "forbiddensnacks",
            "dankmemes",
            "cursedimages",
            "MEOW_IRL",
            "cuddleroll",
            "cromch",
            "cats",
            "cursed_images",
            "aww",
        ];

        for subreddit in subreddits.iter() {
            match get_subreddit(subreddit).await {
                Ok(()) => {}
                Err(Error::Json { data, error }) => {
                    let maybe_data = error_snippet(&data, error.line(), error.column());

                    // Best-effort: save the full body for inspection, but never let a
                    // failed write mask the real parse error reported below.
                    if let Err(write_error) =
                        tokio::fs::write("subreddit-error.json", data.as_bytes()).await
                    {
                        eprintln!("failed to write subreddit-error.json: {write_error}");
                    }

                    panic!(
                        "failed to get subreddit \"{subreddit}\": {error:#?}\ndata: {maybe_data:?}"
                    );
                }
                Err(error) => {
                    panic!("failed to get subreddit \"{subreddit}\": {error:#?}");
                }
            }
        }
    }

    #[tokio::test]
    #[ignore]
    async fn invalid_subreddit() {
        let client = Client::new();
        let error = client.get_subreddit("gfdghfj", 25).await.unwrap_err();
        assert!(error.is_subreddit_not_found(), "error = {error:#?}");
    }
}
182}