-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 4696111
Showing
63 changed files
with
1,460 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
env/ | ||
.vscode | ||
*pycache* | ||
*.bin | ||
*.pt | ||
*cached* | ||
*checkpoint* | ||
analysis/output/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# The r/Jokes Dataset: a Large Scale Humor Collection | ||
Code and Datasets from the paper, ["The r/Jokes Dataset: a Large Scale Humor Collection"](TODO) by Orion Weller and Kevin Seppi | ||
|
||
Dataset files are located in `data/{train/dev/test}.tsv` for the regression task, while the full unsplit data can be found in `data/preprocessed.tsv`. These files will need to be unzipped after cloning the repo. | ||
|
||
For related projects, see our work on [Humor Detection (separating the humorous jokes from the non-humorous)](https://github.com/orionw/RedditHumorDetection) or other cited work. | ||
|
||
** **We do not endorse these jokes. Please view at your own risk** ** | ||
|
||
## License | ||
The data is under the [Reddit License and Terms of Service](https://www.reddit.com/wiki/api-terms) and users must follow the Reddit User Agreement and Privacy Policy, as well as remove any posts if asked to by the original user. For more details on this, please see the link above. | ||
|
||
# Usage | ||
## Load the Required Packages | ||
0. Run `pip3 install -r requirements.txt` | ||
1. Gather the NLTK packages by running `bash download_nltk_packages.sh`. This downloads the packages `averaged_perceptron_tagger`, `words`, `stopwords`, `maxent_ne_chunker`, used for analysis/preprocessing. | ||
|
||
## Reproduce the current dataset (updated to Jan 1st 2020) | ||
### We chunk this process into three parts to avoid networking errors | ||
0. Run `python3 gather_reddit_pushshift.py` after `cd prepare_data` to gather the Reddit post ids. | ||
1. Run `python3 preprocess.py --update` to update the Reddit post IDs with the full post. | ||
2. Run `python3 preprocess.py --preprocess` to preprocess the Reddit posts into final datasets | ||
|
||
## Reproduce plots and analysis from the paper | ||
0. Run `cd analysis` | ||
1. Run `python3 time_statistics.py` to gather the statistics that display over time | ||
2. Run `python3 dataset_statistics.py` to gather the overall dataset statistics | ||
3. See plots in the `./plots` folder | ||
|
||
## Re-gather All Jokes and Extend With Newer Jokes | ||
0. Run the first two commands in the `Reproduce` section above | ||
1. Update the code in the `preprocess` function of the `preprocess.py` file to NOT remove all jokes after 2020 (line 89). Then run `python3 preprocess.py --preprocess` | ||
|
||
# Reference: | ||
If you found this repository helpful, please cite the following paper: | ||
``` | ||
@ARTICLE{rjokesData2020, | ||
title={The r/Jokes Dataset: a Large Scale Humor Collection}, | ||
author={Weller, Orion and Seppi, Kevin}, | ||
journal={"Proceedings of the 2020 Conference of Language Resources and Evaluation"}, | ||
month=May, | ||
year = "2020", | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
import os | ||
import json | ||
import string | ||
import random | ||
|
||
import pandas as pd | ||
import numpy as np | ||
from wordcloud import WordCloud | ||
from sklearn.feature_extraction.text import CountVectorizer | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
import nltk | ||
from nltk.corpus import stopwords | ||
from nltk.tokenize import word_tokenize | ||
from nltk.stem import PorterStemmer | ||
from nltk.stem import LancasterStemmer | ||
from nltk import ne_chunk, pos_tag, word_tokenize | ||
from nltk.tree import Tree | ||
|
||
from sklearn.decomposition import LatentDirichletAllocation as LDA | ||
|
||
random.seed(42) | ||
np.random.seed(42) | ||
|
||
EXT_TYPES = ["pdf", "png"] | ||
REMOVE_TOKENS = set(stopwords.words('english') + ["n't", "ve"]).union(set(string.punctuation)) | ||
sns.set_style('whitegrid') | ||
|
||
def get_continuous_chunks(text): | ||
chunked = ne_chunk(pos_tag(word_tokenize(text))) | ||
continuous_chunk = [] | ||
current_chunk = [] | ||
for i in chunked: | ||
if type(i) == Tree: | ||
current_chunk.append(" ".join([token for token, pos in i.leaves()])) | ||
elif current_chunk: | ||
named_entity = " ".join(current_chunk) | ||
if named_entity not in continuous_chunk: | ||
continuous_chunk.append(named_entity) | ||
current_chunk = [] | ||
else: | ||
continue | ||
return continuous_chunk | ||
|
||
|
||
def plot_n_most_common_words(text: pd.Series, data_name: str, n: int = 10, specific_year: str = None): | ||
entities = [] | ||
for sentence in text: | ||
entities.extend(get_continuous_chunks(sentence)) | ||
stemmer = PorterStemmer() | ||
filtered_words = [stemmer.stem(w) for w in entities if not w in REMOVE_TOKENS] | ||
# Initialise the count vectorizer with the English stop words | ||
count_vectorizer = CountVectorizer(stop_words='english') | ||
# Fit and transform the processed titles | ||
count_data = count_vectorizer.fit_transform(filtered_words) | ||
# Visualise the 10 most common words | ||
words = count_vectorizer.get_feature_names() | ||
total_counts = np.zeros(len(words)) | ||
for t in count_data: | ||
total_counts+=t.toarray()[0] | ||
|
||
get_lda(count_data, count_vectorizer) | ||
|
||
count_dict = (zip(words, total_counts)) | ||
count_dict = sorted(count_dict, key=lambda x:x[1], reverse=True)[0:n] | ||
words = [w[0] for w in count_dict] | ||
counts = [w[1] for w in count_dict] | ||
x_pos = np.arange(len(words)) | ||
|
||
for ext_type in EXT_TYPES: | ||
plt.figure(2, figsize=(15, 15/1.6180)) | ||
plt.subplot(title='{} most common words'.format(n)) | ||
sns.set_context("notebook", font_scale=1.25, rc={"lines.linewidth": 2.5}) | ||
sns.barplot(x_pos, counts, palette='husl') | ||
plt.xticks(x_pos, words, rotation=90) | ||
plt.xlabel('words') | ||
plt.ylabel('counts') | ||
plt.savefig(os.path.join(os.path.realpath('..'), "plots", data_name, | ||
"{}top_wordcounts.{}".format("" if specific_year is None else str(specific_year) + "_", ext_type))) | ||
plt.close() | ||
|
||
|
||
def run_wordcloud(text: list, data_name: str): | ||
word_tokens = [] | ||
try: | ||
for index, sentence in enumerate(text): | ||
if pd.isnull(sentence): | ||
print("index is null", index) | ||
continue | ||
word_tokens.extend(word_tokenize(sentence)) | ||
except Exception as e: | ||
import pdb; pdb.set_trace() | ||
print(e, sentence) | ||
filtered_words = [w for w in word_tokens if not w in REMOVE_TOKENS] | ||
# Create a WordCloud object | ||
wordcloud = WordCloud(background_color="white", width=800, height=800, max_words=150, contour_width=3, contour_color='steelblue') | ||
# Generate a word cloud | ||
wordcloud.generate(" ".join(filtered_words)) | ||
# Visualize the word cloud | ||
for ext_type in EXT_TYPES: | ||
wordcloud.to_file(os.path.join(os.path.realpath('..'), "plots", data_name, 'wordCloud.{}'.format(ext_type))) | ||
plt.close() | ||
|
||
|
||
def print_topics(model, count_vectorizer, n_top_words): | ||
words = count_vectorizer.get_feature_names() | ||
for topic_idx, topic in enumerate(model.components_): | ||
print("\nTopic #%d:" % topic_idx) | ||
print(" ".join([words[i] | ||
for i in topic.argsort()[:-n_top_words - 1:-1]])) | ||
|
||
def get_lda(count_data, count_vectorizer): | ||
# Tweak the two parameters below | ||
number_topics = 5 | ||
number_words = 10 | ||
# Create and fit the LDA model | ||
lda = LDA(n_components=number_topics, n_jobs=-1) | ||
lda.fit(count_data) | ||
print("Topics found via LDA:") | ||
print_topics(lda, count_vectorizer, number_words) | ||
|
||
def get_statistics(df: pd.DataFrame, data_name: str): | ||
distrib = df.copy(deep=True) | ||
distrib = distrib[distrib["score"] != 0] | ||
for ext_type in EXT_TYPES: | ||
sns.distplot(distrib["score"]) | ||
plt.savefig(os.path.join(os.path.realpath('..'), "plots", data_name, 'total_score_distribution.{}'.format(ext_type))) | ||
plt.close() | ||
|
||
punch = np.array([len(sentence) for sentence in df["punchline"].values]) | ||
body = np.array([len(sentence) for sentence in df["body"].values]) | ||
joke = np.array([len(sentence) for sentence in df["joke"].values]) | ||
|
||
average_punch = np.mean(punch) | ||
average_body = np.mean(body) | ||
average_joke = np.mean(joke) - 1.0 # Joke Token, `AND` to join body and punchline together | ||
|
||
std_punch = np.nanstd(punch) | ||
std_body = np.nanstd(body) | ||
std_joke = np.nanstd(joke) | ||
|
||
ave_tokens_punch = np.nanmean(np.array([len(nltk.word_tokenize(sentence)) for sentence in df["punchline"].dropna().values])) | ||
ave_tokens_body = np.nanmean(np.array([len(nltk.word_tokenize(sentence)) for sentence in df["body"].dropna().values])) | ||
ave_tokens_joke = np.nanmean(np.array([len(nltk.word_tokenize(sentence)) for sentence in df["joke"].dropna().values])) | ||
|
||
tokens = [] | ||
[tokens.extend(nltk.word_tokenize(joke)) for joke in df["joke"].dropna().values] | ||
total_tokens = len(set(tokens)) | ||
|
||
stat_df = pd.DataFrame([{"ave_punchline_len": average_punch, "ave_body_len": average_body, "ave_joke_len": average_joke, "std_punch": std_punch, | ||
"std_body": std_body, "std_joke": std_joke, "total_tokens": total_tokens}]) | ||
stat_df.to_csv(os.path.join(os.path.realpath('..'), "plots", data_name, 'statistics.txt')) | ||
|
||
|
||
def plot_sentiment(df: pd.DataFrame, data_name: str): | ||
for ext_type in EXT_TYPES: | ||
ax = sns.lineplot(x="date", y="prop", hue="sentiment", data=df, ci=False) | ||
# Find the x,y coordinates for each point | ||
x_coords = [] | ||
y_coords = [] | ||
for point_pair in ax.collections: | ||
for x, y in point_pair.get_offsets(): | ||
x_coords.append(x) | ||
y_coords.append(y) | ||
# create the custom error bars | ||
colors = ['steelblue']*2 + ['coral']*2 | ||
ax.errorbar(x_coords, y_coords, yerr=df["std"], | ||
ecolor=colors, fmt=' ', zorder=-1) | ||
ax.savefig(os.path.join(os.path.realpath('..'), "data", data_name, "sentiment_plot.{}".format(ext_type))) | ||
|
||
|
||
|
||
def gather_data(df: pd.DataFrame, data_name: str): | ||
if not os.path.isdir(os.path.join(os.path.realpath('..'), "plots", data_name)): | ||
os.mkdir(os.path.join(os.path.realpath('..'), "plots", data_name)) | ||
run_wordcloud(df["joke"].tolist(), data_name) | ||
plot_n_most_common_words(df["joke"].tolist(), data_name) | ||
get_statistics(df, data_name) | ||
|
||
|
||
def percentiles_upvotes(df: pd.DataFrame, data_name: str) -> pd.DataFrame: | ||
list_of_percentiles = [] | ||
for percentile in [0, 10, 25, 50, 75, 90, 100]: | ||
cur_per = np.percentile(df["score"], percentile) | ||
list_of_percentiles.append({"percentile": percentile, "value": cur_per}) | ||
percent_df = pd.DataFrame(list_of_percentiles) | ||
percent_df.to_csv(os.path.join(os.path.realpath('..'), "plots", data_name, "percentiles.csv")) | ||
return percent_df | ||
|
||
|
||
if __name__ == "__main__": | ||
# NOTE: log-distribution plots are found in the preprocess.py script | ||
df = pd.read_csv(os.path.join(os.path.realpath('..'), "data", "preprocessed.csv"), index_col=None, encoding="UTF-8", keep_default_na=False) | ||
df["date"] = pd.to_numeric(df["date"]) | ||
df["score"] = pd.to_numeric(df["score"]) | ||
df = df[df["date"].isna() == False] | ||
assert df.shape == df.dropna().shape, "was nans that are unaccounted for" | ||
percentiles_upvotes(df, "all") | ||
gather_data(df, "all") | ||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import os | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||
import seaborn as sns | ||
import glob | ||
|
||
def gather_data_from_huggingface(root_dir: str, output_dir: str): | ||
all_models = {} | ||
for result_dir in os.listdir(root_dir): | ||
model_name = result_dir.split("/")[-1] | ||
model_results = [] | ||
checkpoint_folder = None | ||
for checkpoint_folder in glob.glob(os.path.join(root_dir, result_dir, "*checkpoint-*/")): | ||
if not os.path.isfile(os.path.join(checkpoint_folder, "eval_results.txt")): | ||
print("No results file in path:", checkpoint_folder) | ||
else: | ||
results_on_check = {} | ||
with open(os.path.join(checkpoint_folder, "eval_results.txt"), "r") as f: | ||
for index, line in enumerate(f): | ||
name, value = line.split("=") | ||
name, value = name.strip(), float(value.strip()) | ||
results_on_check[name] = value | ||
results_on_check["model"] = model_name | ||
results_on_check["checkpoint"] = checkpoint_folder.split("/")[-2] | ||
model_results.append(results_on_check) | ||
results_df = pd.DataFrame(model_results) | ||
if checkpoint_folder is not None: | ||
results_df.to_csv(os.path.join(checkpoint_folder, "model_results.csv")) | ||
ax = sns.lineplot(x="index", y="rmse", data=results_df.reset_index()) | ||
plt.savefig(os.path.join(root_dir, result_dir, "rmse_plot.png")) | ||
plt.close() | ||
all_models[model_name] = results_df | ||
print("Wrote model file to ", os.path.join(result_dir, "rmse_plot.png")) | ||
|
||
if all_models: | ||
# now that all the results have been gathered, let's combine them | ||
full_df = pd.concat(list(all_models.values())).reset_index() | ||
full_df.to_csv(output_dir) | ||
# focus on lowest RMSE (could focus on Pearson, Spearmanr etc.) | ||
min_rmse = full_df.groupby("model")["rmse"].idxmin() | ||
best_results = full_df.iloc[min_rmse, :] | ||
|
||
# plot line plot here | ||
ax = sns.barplot(x="model", y="rmse", data=best_results) | ||
plt.savefig(os.path.join(root_dir, "rmse_plot.png")) | ||
print("Wrote full file to ", os.path.join(root_dir, "rmse_plot.png")) | ||
plt.close() | ||
print(best_results) | ||
|
||
|
||
if __name__ == "__main__": | ||
gather_data_from_huggingface("output/large", "regression_results_large.csv") | ||
gather_data_from_huggingface("output/base/", "regression_results_base.csv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
export DATA_DIR=../data | ||
# regression task set up | ||
export TASK_NAME=STS-B | ||
declare -a arr=("roberta" "bert" "xlnet") | ||
declare -a arrtwo=("roberta-base" "bert-base-uncased" "xlnet-base-cased") | ||
|
||
|
||
for ((i=0; i<3; i++)); | ||
do | ||
python run_glue.py \ | ||
--model_type "${arr[i]}" \ | ||
--model_name_or_path "${arrtwo[i]}" \ | ||
--task_name $TASK_NAME \ | ||
--do_train \ | ||
--do_eval \ | ||
--do_lower_case \ | ||
--data_dir $DATA_DIR \ | ||
--eval_all_checkpoints \ | ||
--max_seq_length 128 \ | ||
--per_gpu_train_batch_size 96 \ | ||
--learning_rate 2e-5 \ | ||
--num_train_epochs 5.0 \ | ||
--overwrite_output_dir \ | ||
--output_dir output/base/"${arr[i]}" | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
export DATA_DIR=../data | ||
# regression task set up | ||
export TASK_NAME=STS-B | ||
declare -a arr=("roberta" "bert" "xlnet") | ||
# I used batch size of 96 with xlnet-large and 128 with others | ||
declare -a arrtwo=("roberta-large" "bert-large-uncased" "xlnet-large-cased") | ||
|
||
|
||
for ((i=0; i<3; i++)); | ||
do | ||
python run_glue.py \ | ||
--model_type "${arr[i]}" \ | ||
--model_name_or_path "${arrtwo[i]}" \ | ||
--task_name $TASK_NAME \ | ||
--do_train \ | ||
--do_eval \ | ||
--do_lower_case \ | ||
--data_dir $DATA_DIR \ | ||
--eval_all_checkpoints \ | ||
--max_seq_length 128 \ | ||
--per_gpu_train_batch_size 128 \ | ||
--learning_rate 2e-5 \ | ||
--num_train_epochs 5.0 \ | ||
--overwrite_output_dir \ | ||
--output_dir output/large/"${arr[i]}" | ||
done |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
,index,corr,pearson,rmse,spearmanr,model,checkpoint | ||
0,0,0.4443397796905355,0.4651837567754322,1.7541763,0.42349580260563885,roberta,checkpoint-12000 | ||
1,1,0.4313989023673853,0.45019291276869466,1.6386057,0.41260489196607597,roberta,checkpoint-6000 | ||
2,2,0.4452852768941931,0.4658201273950414,1.8800247,0.4247504263933448,roberta,checkpoint-18000 | ||
3,0,0.4318083997703859,0.4531463200754483,1.8907644,0.4104704794653235,xlnet,checkpoint-12000 | ||
4,1,0.42182588036243795,0.4421904039165935,1.6474583,0.4014613568082824,xlnet,checkpoint-6000 | ||
5,2,0.43808309570939125,0.4580278612177484,1.9638016,0.4181383302010341,xlnet,checkpoint-18000 | ||
6,0,0.44518519943639,0.46596080774994675,1.7353058,0.4244095911228332,bert,checkpoint-12000 | ||
7,1,0.43999650071057195,0.4611161397762899,1.6263107,0.41887686164485405,bert,checkpoint-6000 | ||
8,2,0.44383916832300385,0.4650446676294345,1.8012,0.42263366901657323,bert,checkpoint-18000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
,index,corr,pearson,rmse,spearmanr,model,checkpoint | ||
0,0,0.4667788714460796,0.4883324584728209,1.7011992,0.4452252844193383,roberta,checkpoint-12000 | ||
1,1,0.4544410080159367,0.47388892823620876,1.6138555,0.43499308779566465,roberta,checkpoint-6000 | ||
2,2,0.4660406387792235,0.48559753527177185,1.8018064,0.4464837422866752,roberta,checkpoint-18000 | ||
3,0,0.43412252294017617,0.4567995734825274,1.7385621,0.411445472397825,xlnet,checkpoint-12000 | ||
4,1,0.37713699244381943,0.38511733976723006,1.9101037,0.3691566451204088,xlnet,checkpoint-6000 | ||
5,2,0.44718551599712086,0.4682513100855499,1.9024026,0.42611972190869185,xlnet,checkpoint-24000 | ||
6,3,0.44616641205790697,0.4687108119213232,1.772816,0.4236220121944907,xlnet,checkpoint-18000 | ||
7,0,0.44932231309459303,0.47116557636851286,1.7352395,0.4274790498206732,bert,checkpoint-12000 | ||
8,1,0.4506944418874647,0.4709748411712956,1.6188005,0.4304140426036338,bert,checkpoint-6000 | ||
9,2,0.4523183259704836,0.47327446616808244,1.8045843,0.43136218577288477,bert,checkpoint-18000 |
Oops, something went wrong.