From a5d316e342c51ca338f5f5026b23110b0422d8f3 Mon Sep 17 00:00:00 2001
From: Tom Zehle
Date: Mon, 9 Jan 2023 14:19:59 +0100
Subject: [PATCH 1/2] Add data_augment.py data augmentation script
---
scripts/data_augment/data_augment.py | 464 +++++++++++++++++++++++++++
1 file changed, 464 insertions(+)
create mode 100644 scripts/data_augment/data_augment.py
diff --git a/scripts/data_augment/data_augment.py b/scripts/data_augment/data_augment.py
new file mode 100644
index 0000000000..6b10a8510b
--- /dev/null
+++ b/scripts/data_augment/data_augment.py
@@ -0,0 +1,464 @@
+import torch
+import pandas as pd
+
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
+import nltk
+from nltk.corpus import wordnet
+import spacy
+from collections import Counter
+#from syntax.syntax_injector import SyntaxBug
+#from logic.logic_injector import LogicBug
+
+import random
+import string
+from bs4 import BeautifulSoup as bs
+
+import requests
+import json
+import html
+import os
+import subprocess
+import argparse
+
+
+
+class DataArgumenter:
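+    """Base class for the augmenters below: parse() turns a list of texts into (prompt, answer) lists via parse_single()."""
+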
+ def __init__(self):
+ raise NotImplementedError()
+
+ def parse(self, essays):
+ prompts = []
+ preds = []
+
+ for essay in essays:
+ essay_prompts, essay_preds = self.parse_single(essay)
+
+ prompts += essay_prompts
+ preds += essay_preds
+
+ return prompts, preds
+
+ def parse_single(self, essay):
+ pass
+
+
+class EssayInstructor(DataArgumenter):
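+    """Summarizes each paragraph of an essay into a one-line topic and builds
+    "Write a ... paragraph ..." prompts paired with the original paragraphs."""
+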
+ def __init__(self, model_name=None):
+ if model_name is None:
+ model_name = "snrspeaks/t5-one-line-summary"
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ def parse_single(self, essay):
+ essay_paragraphs = essay.split("\n\n")
+ preds = []
+
+ for para in essay_paragraphs:
+ input_ids = self.tokenizer.encode(para, return_tensors="pt", add_special_tokens=True)
+ generated_ids = self.model.generate(
+ input_ids=input_ids,
+ num_beams=5,
+ max_length=35,
+ repetition_penalty=4.5,
+ length_penalty=1.5,
+ early_stopping=True,
+ num_return_sequences=1,
+ )
+ preds.append(
+ self.tokenizer.decode(
+ generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
+ )
+ )
+
+ prompts = (
+ ["Write an intro paragraph to an essay called"]
+ + ["Write a paragraph to an essay about"] * len(preds[1:-1])
+ + ["Write a concluding paragraph about"]
+ )
+
+ assert len(preds) == len(prompts)
+ prompts = [prompt + " " + pred for prompt, pred in zip(prompts, preds)]
+
+ return prompts, essay_paragraphs
+
+
+class EssayReviser(DataArgumenter):
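+    """Corrupts an essay in several ways (shuffled paragraphs, WordNet synonym swaps, random character
+    typos) and pairs "Fix ... errors" instructions with the original essay as the answer."""
+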
+ def __init__(self):
+ nltk.download("wordnet")
+ nltk.download("omw-1.4")
+
+ def parse_single(self, essay):
+ instructions = []
+
+        # Make a structure error (swap one paragraph with another)
+        essay_paragraphs = essay.split("\n\n")  # split the essay into paragraphs on blank lines
+
+ rand1 = random.randint(0, len(essay_paragraphs) - 1)
+ rand2 = random.randint(0, len(essay_paragraphs) - 1)
+
+ temp = essay_paragraphs[rand1]
+ essay_paragraphs[rand1] = essay_paragraphs[rand2]
+ essay_paragraphs[rand2] = temp
+
+ corrupted_essay = "\n\n".join(essay_paragraphs)
+
+ instructions.append("Fix structure errors in this essay" + corrupted_essay)
+
+ essay_words = essay.split()
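+        # Make grammar/word-choice errors: replace roughly 30% of the words with a random WordNet synonym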
+ for i in range(len(essay_words)):
+ if random.randint(0, 100) < 30:
+ suggestion = []
+ for syn in wordnet.synsets(essay_words[i]):
+ for l in syn.lemmas():
+ suggestion.append(l.name())
+ if suggestion != []:
+ essay_words[i] = suggestion[random.randint(0, len(suggestion) - 1)]
+
+ corrupted_essay = " ".join(essay_words)
+
+ instructions.append("Fix grammar errors in this essay: " + corrupted_essay)
+
+        # change the divisor 60 to control how heavily the essay is corrupted
+        corrupted_essay = essay
+        for _ in range(len(essay) // 60):
+            rand = random.randint(0, len(corrupted_essay) - 1)
+            corrupted_essay = corrupted_essay[:rand] + random.choice(string.ascii_letters) + corrupted_essay[rand+1:]
+
+ instructions.append("Fix typing errors in this essay" + corrupted_essay)
+
+ return instructions, [essay] * len(instructions)
+
+
+class StackExchangeBuilder(DataArgumenter):
+ def __init__(self, base_url=None, filter_opts=None):
+ self.base_url = base_url if base_url is not None else "https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/{0}&file=Posts.xml"
+ self.filter_opts = filter_opts if filter_opts is not None else ["accepted", "score", "convert_html", "clean_tags"]
+
+ def get_all_filenames(self):
+ response = requests.get("https://archive.org/download/stackexchange")
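+        # parse the archive.org directory listing and map each site name to its Posts.xml URL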
+ if response.ok:
+ soup = bs(response.content, "html.parser")
+ table = soup.find("table")
+ link_tags = table.find_all("a")
+ urls = {}
+ for link in link_tags:
+ url = link["href"]
+ name = url.split(".stackexchange")[0].replace(".", "_").replace("-", "_")
+ if url.endswith("7z"):
+ urls[name] = self.base_url.format(url)
+ return urls
+
+    def xml_to_df(self, response: requests.Response):
+        """
+        Collect and manually import XML into a DataFrame
+
+        pd.read_xml() errors when XML trees are too large; this is just a hack to
+        download an XML file and parse it into a DataFrame. **Not tested on huge XML files**
+
+        Parameters:
+            response (requests.Response): requests response object with the XML data
+
+ Returns:
+ df (DataFrame): A Dataframe from the XML file
+ """
+ xml_format_map = {
+ "Id": int,
+ "PostTypeId": int,
+ "CreationDate": str,
+ "Score": int,
+ "ViewCount": int,
+ "Body": str,
+ "AnswerCount": int,
+ "CommentCount": int,
+ "ContentLicense": str,
+ "AcceptedAnswerId": int,
+ "ParentId": int,
+ }
+ soup = bs(response.content, "xml")
+ posts = soup.find_all("row")
+
+ all_posts = [post.attrs for post in posts]
+
+ df = pd.DataFrame(all_posts)
+ df.AnswerCount.fillna(0, inplace=True)
+ df.ViewCount.fillna(0, inplace=True)
+ df.AcceptedAnswerId.fillna(0, inplace=True)
+ df.ParentId.fillna(0, inplace=True)
+ df["DataSource"] = response.url
+ df = df.astype(xml_format_map)
+ return df
+
+ def filter(self, df):
+ if "accepted" in self.filter_opts:
+ """**TODO**
+ Filter only to Questions with Accepted Answers
+
+ Filter dataframe by questions that have accepted answers, should also include
+ all rows of answers for those questions, even if not accepted."""
+
+ df = df[(df["AcceptedAnswerId"].notnull()) | (df["ParentId"] == df["Id"])]
+
+ if "score" in self.filter_opts:
+ """**TODO**
+ Filter Dataframe by minimum scores
+
+ Filter Question and Answer columns by score thresholds to trim lower scoring results"""
+ question_score_threshold = 0
+ answer_score_threshold = 5
+ df = df[
+ ((df["Score"] >= question_score_threshold) & (df.PostTypeId == 1))
+ | ((df["Score"] >= answer_score_threshold) & (df.PostTypeId == 2))
+ ]
+
+
+ if "clean_tags" in self.filter_opts:
+ """
+            Convert Tags into comma-separated form
+            Converts Tag slugs into comma-separated tags"""
+ df["TagsClean"] = df["Tags"].str.replace("-", " ").str.replace("><", ", ").str.replace("<", "").str.replace(">", "")
+
+ if "convert_html" in self.filter_opts:
+ """
+ Convert HTML tags to pure text
+
+ Feeds HTML text body into BeautifulSoup to parse it to only text. Set aside as
+ function to provide option to skip"""
+ column = "Body"
+ df.dropna(subset=[column], inplace=True)
+ df[f"{column}Clean"] = df[column].apply(lambda row: bs(row, "html.parser").text)
+
+
+ return df
+
+ def parse(self, _):
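+        # the input texts are ignored; question/answer data is downloaded from archive.org instead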
+ urls = self.get_all_filenames()
+ dataset_name = "ai"
+
+ xml_posts_path = urls.get(dataset_name)
+
+ response = requests.get(xml_posts_path)
+ df = self.xml_to_df(response)
+ df = self.filter(df)
+
+ questions = df[df.PostTypeId == 1]
+ answers = df[df.PostTypeId == 2]
+
+ df = pd.merge(
+ questions,
+ answers,
+ left_on="Id",
+ right_on="ParentId",
+ suffixes=("_q", "_a"),
+ how="left",
+ )
+ questions = df[["Title_q", "BodyClean_q"]]
+ # prepend title to question and make questions to list
+ questions = questions.apply(lambda x: x["Title_q"] + "\n" + x["BodyClean_q"], axis=1)
+ questions = questions.tolist()
+
+        answers = df["BodyClean_a"].tolist()
+
+ return questions, answers
+
+
+class HierachicalSummarizer(DataArgumenter):
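+    """Summarizes the "##"-separated sections of an essay on two levels and turns them into a
+    "Write a story about ..." instruction followed by entity-based "expand on ..." follow-ups."""
+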
+ def __init__(self):
+ self.summarizer = pipeline(
+ "summarization",
+ "pszemraj/long-t5-tglobal-base-16384-book-summary",
+ device=0 if torch.cuda.is_available() else -1,
+ )
+
+ self.params = {
+ "max_length": 1024,
+ "min_length": 8,
+ "no_repeat_ngram_size": 3,
+ "early_stopping": False,
+ "repetition_penalty": 3.5,
+ "length_penalty": 0.3,
+ "encoder_no_repeat_ngram_size": 3,
+ "num_beams": 4,
+ } # parameters for text generation out of model
+
+ self.nlp = spacy.load("en_core_web_sm")
+
+ def cleanup_summary(self, out):
+        out = (
+ out
+ .replace("The novel begins with the description of", "")
+ .replace("the description of", "")
+ .replace("The novel begins", "")
+ .replace("This chapter introduces us to", "")
+ .replace("In this chapter, ", "")
+ .replace("This chapter", "")
+ .strip(" ,")
+ )
+ return out
+
+ def parse_single(self, essay):
+ final_summary = ""
+ new_summary = ""
+ level_2_summary = []
+ level_1_summary = []
+ entities = []
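+        # summarize every "##"-separated section twice: a detailed level-2 pass and a condensed level-1 pass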
+ essay_parts = essay.split("##")
+ for section_text in essay_parts:
+ result = self.summarizer(section_text, **self.params)
+ out = self.cleanup_summary(result[0]['summary_text'])
+ level_2_summary.append(out)
+ result = self.summarizer(out, **self.params)
+ out = self.cleanup_summary(result[0]['summary_text'])
+ new_summary += "\n" + out
+ level_1_summary.append(out)
+
+ entity = recognize_entities(section_text, self.nlp, n=5, person="ignore")
+ entities.append(entity)
+
+
+
+ result = self.summarizer(new_summary, **self.params)
+ final_summary = self.cleanup_summary(result[0]['summary_text'])
+
+ first_instruction = "Write a story about the following:\n" + final_summary
+ first_answer = "\n".join(level_1_summary)
+ instructions = [first_instruction]
+ answers = [first_answer]
+
+ for entity, answer in zip(entities, level_2_summary):
+ instructions.append(f"Now expand on {entity}!")
+ answers.append(answer)
+
+ for entity, answer in zip(entities, level_1_summary):
+ instructions.append(f"Further expand on {entity}.")
+ answers.append(answer)
+
+
+ return instructions, answers
+
+
+class EntityRecognizedSummarizer(DataArgumenter):
+ def __init__(self):
+ self.nlp = spacy.load("en_core_web_sm") # run !python -m spacy download en_core_web_sm in order to download
+
+ def parse_single(self, essay):
+        characters = recognize_entities(essay, self.nlp, n=4, person=True)
+ topic = recognize_entities(essay, self.nlp, n=2, person=False)
+
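+        # frame the original essay as the answer to a story request built from its topic and characters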
+ question = f"Please write a story titled {topic} with the characters {characters}."
+ answer = f"Sure. Here is a story titled {topic}\n" + essay
+
+ return [question], [answer]
+
+
+class CodeBugger(DataArgumenter):
+ """
+ https://github.com/LAION-AI/Open-Assistant/blob/main/notebooks/code-bugger/openbugger_example.md
+ Openbugger is a Python package that allows you to inject syntax and logic errors into your code.
+ This can be useful for testing the robustness of your code or for creating test cases for debugging exercises or for training an assistant to debug code.
+ To install:
+ cwd = os.getcwd()
+
+ # Next, we'll use Git to clone the repository.
+ subprocess.run(["git", "clone", "https://github.com/furlat/OpenBugger", cwd + "/OpenBugger"])
+
+ # Now, we'll use pip to install the package from the local repository.
+ subprocess.run(["python3", "-m", "pip", "install", "--editable", cwd + "/OpenBugger"])
+ """
+ def __init__(self):
+ self.syntax_bug = SyntaxBug()
+ self.logic_bug = LogicBug()
+
+ def parse_single(self, code):
+        bugged_code = self.syntax_bug(code, "medium", num_errors=2)
+        bugged_code = self.logic_bug(bugged_code, "medium", num_errors = 2)
+
+        question = "Can you fix the following code?\n" + bugged_code
+
+        answer = "The following code is correct:\n" + code + "\nI hope I could help you fix your code. In case you need more help, feel free to ask me again."
+
+ return [question], [answer]
+
+def recognize_entities(text, model, n=4, person="ignore"):
+ """Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
+ doc = model(text)
+ if person == "ignore":
+ ents = Counter([ent.text.strip() for ent in list(doc.ents) if len(ent.text.strip()) >= 5])
+ elif person:
+ ents = Counter([ent.text.strip() for ent in list(doc.ents) if ent.label_ == "PERSON" and len(ent.text.strip()) >= 5])
+ else:
+ ents = Counter([ent.text.strip() for ent in list(doc.ents) if ent.label_ != "PERSON" and len(ent.text.strip()) >= 5])
+ ents = ents.most_common(n)
+ ents = ", ".join([a[0] for a in ents])
+
+ return ents
+
+
+def parse_arguments():
+ args = argparse.ArgumentParser()
+ args.add_argument("--dataset", type=str, required=True)
+ args.add_argument("--augmenter", type=str, required=True)
+ args.add_argument("--output", type=str, required=True)
+ args = args.parse_args()
+
+ assert args.dataset.endswith(".tsv"), "Dataset file must be a tsv file, containing a list of files to be augmented"
+ assert args.output.endswith(".json"), "Output file must be a json file"
+
+ return args
+
+
+def read_data(args):
+ files = pd.read_csv(args.dataset, sep="\t", header=None)
+ files = files[0].tolist()
+ data = []
+ for file in files:
+ with open(file, "r") as f:
+ text = f.read()
+ data.append(text)
+ return data
+
+
+def get_augmenter(args):
+ if args.augmenter == "essayinstruction":
+ augmenter = EssayInstructor()
+
+ elif args.augmenter == "essayrevision":
+ augmenter = EssayReviser()
+
+ elif args.augmenter == "stackexchange":
+ augmenter = StackExchangeBuilder()
+
+ elif args.augmenter == "hierarchicalsummarizer":
+ augmenter = HierachicalSummarizer()
+
+ elif args.augmenter == "entityrecognizedsummarizer":
+ augmenter = EntityRecognizedSummarizer()
+
+ elif args.augmenter == "codebugger":
+ augmenter = CodeBugger()
+
+ else:
+ raise ValueError("Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger")
+
+ return augmenter
+
+
+def main(args):
+ data = read_data(args)
+ augmenter = get_augmenter(args)
+
+ augmented_data = augmenter.parse(data)
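+    # augmenter.parse returns a (prompts, answers) tuple, serialized below as a two-element JSON list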
+
+ # write augmented data as json file
+ with open(args.output, "w") as f:
+ json.dump(augmented_data, f)
+
+
+if __name__ == "__main__":
+ args = parse_arguments()
+ main(args)
+
+
+
+
+
From 48dd5047aa69741354f057fe22321b4d989c6f3b Mon Sep 17 00:00:00 2001
From: Tom Zehle
Date: Tue, 10 Jan 2023 16:48:06 +0100
Subject: [PATCH 2/2] Add module docstring, clean up imports and formatting
---
scripts/data_augment/data_augment.py | 119 +++++++++++++++------------
1 file changed, 66 insertions(+), 53 deletions(-)
diff --git a/scripts/data_augment/data_augment.py b/scripts/data_augment/data_augment.py
index 6b10a8510b..b85d3bf5b6 100644
--- a/scripts/data_augment/data_augment.py
+++ b/scripts/data_augment/data_augment.py
@@ -1,25 +1,31 @@
-import torch
-import pandas as pd
+"""Script for a variety of data augmentation techniques for generating Question answer pairs.
+Depending on the class used it takes in the input files and generates summaries from essays (which then will result in a "write a story about [summary]"-> essay pair),#
+buggs code (in order to have bugged code + "please fix" -> code), ...
+example usage:
+ data_augment.py --dataset essays.tsv --augmenter hierarchicalsummarizer --output out.json
+args:
+ -- dataset: TSV file referencing txt files with essays/code
+ -- augmenter: the augmenter used: one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger"
+ -- output: where to save the output
+"""
-from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
-import nltk
-from nltk.corpus import wordnet
-import spacy
-from collections import Counter
-#from syntax.syntax_injector import SyntaxBug
-#from logic.logic_injector import LogicBug
+import argparse
+import json
import random
import string
-from bs4 import BeautifulSoup as bs
+from collections import Counter
+import nltk
+import pandas as pd
import requests
-import json
-import html
-import os
-import subprocess
-import argparse
-
+import spacy
+import torch
+from bs4 import BeautifulSoup as bs
+from logic.logic_injector import LogicBug
+from nltk.corpus import wordnet
+from syntax.syntax_injector import SyntaxBug
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
class DataArgumenter:
@@ -65,9 +71,7 @@ def parse_single(self, essay):
num_return_sequences=1,
)
preds.append(
- self.tokenizer.decode(
- generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
- )
+ self.tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
)
prompts = (
@@ -121,7 +125,7 @@ def parse_single(self, essay):
         # change the divisor 60 to control how heavily the essay is corrupted
         corrupted_essay = essay
         for _ in range(len(essay) // 60):
             rand = random.randint(0, len(corrupted_essay) - 1)
-            corrupted_essay = corrupted_essay[:rand] + random.choice(string.ascii_letters) + corrupted_essay[rand+1:]
+            corrupted_essay = corrupted_essay[:rand] + random.choice(string.ascii_letters) + corrupted_essay[rand + 1 :]
         instructions.append("Fix typing errors in this essay: " + corrupted_essay)
@@ -130,8 +134,14 @@ def parse_single(self, essay):
class StackExchangeBuilder(DataArgumenter):
def __init__(self, base_url=None, filter_opts=None):
- self.base_url = base_url if base_url is not None else "https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/{0}&file=Posts.xml"
- self.filter_opts = filter_opts if filter_opts is not None else ["accepted", "score", "convert_html", "clean_tags"]
+ self.base_url = (
+ base_url
+ if base_url is not None
+ else "https://ia600107.us.archive.org/view_archive.php?archive=/27/items/stackexchange/{0}&file=Posts.xml"
+ )
+ self.filter_opts = (
+ filter_opts if filter_opts is not None else ["accepted", "score", "convert_html", "clean_tags"]
+ )
def get_all_filenames(self):
response = requests.get("https://archive.org/download/stackexchange")
@@ -209,12 +219,13 @@ def filter(self, df):
| ((df["Score"] >= answer_score_threshold) & (df.PostTypeId == 2))
]
-
if "clean_tags" in self.filter_opts:
- """
+ """
             Convert Tags into comma-separated form
             Converts Tag slugs into comma-separated tags"""
- df["TagsClean"] = df["Tags"].str.replace("-", " ").str.replace("><", ", ").str.replace("<", "").str.replace(">", "")
+ df["TagsClean"] = (
+ df["Tags"].str.replace("-", " ").str.replace("><", ", ").str.replace("<", "").str.replace(">", "")
+ )
if "convert_html" in self.filter_opts:
"""
@@ -225,7 +236,6 @@ def filter(self, df):
column = "Body"
df.dropna(subset=[column], inplace=True)
df[f"{column}Clean"] = df[column].apply(lambda row: bs(row, "html.parser").text)
-
return df
@@ -269,7 +279,7 @@ def __init__(self):
device=0 if torch.cuda.is_available() else -1,
)
- self.params = {
+ self.params = {
"max_length": 1024,
"min_length": 8,
"no_repeat_ngram_size": 3,
@@ -278,14 +288,13 @@ def __init__(self):
"length_penalty": 0.3,
"encoder_no_repeat_ngram_size": 3,
"num_beams": 4,
- } # parameters for text generation out of model
+ } # parameters for text generation out of model
self.nlp = spacy.load("en_core_web_sm")
def cleanup_summary(self, out):
         out = (
- out
- .replace("The novel begins with the description of", "")
+ out.replace("The novel begins with the description of", "")
.replace("the description of", "")
.replace("The novel begins", "")
.replace("This chapter introduces us to", "")
@@ -304,20 +313,18 @@ def parse_single(self, essay):
essay_parts = essay.split("##")
for section_text in essay_parts:
result = self.summarizer(section_text, **self.params)
- out = self.cleanup_summary(result[0]['summary_text'])
+ out = self.cleanup_summary(result[0]["summary_text"])
level_2_summary.append(out)
result = self.summarizer(out, **self.params)
- out = self.cleanup_summary(result[0]['summary_text'])
+ out = self.cleanup_summary(result[0]["summary_text"])
new_summary += "\n" + out
level_1_summary.append(out)
entity = recognize_entities(section_text, self.nlp, n=5, person="ignore")
entities.append(entity)
-
-
result = self.summarizer(new_summary, **self.params)
- final_summary = self.cleanup_summary(result[0]['summary_text'])
+ final_summary = self.cleanup_summary(result[0]["summary_text"])
first_instruction = "Write a story about the following:\n" + final_summary
first_answer = "\n".join(level_1_summary)
@@ -332,14 +339,13 @@ def parse_single(self, essay):
instructions.append(f"Further expand on {entity}.")
answers.append(answer)
-
return instructions, answers
class EntityRecognizedSummarizer(DataArgumenter):
def __init__(self):
- self.nlp = spacy.load("en_core_web_sm") # run !python -m spacy download en_core_web_sm in order to download
-
+ self.nlp = spacy.load("en_core_web_sm") # run !python -m spacy download en_core_web_sm in order to download
+
def parse_single(self, essay):
         characters = recognize_entities(essay, self.nlp, n=4, person=True)
         topic = recognize_entities(essay, self.nlp, n=2, person=False)
@@ -365,29 +371,39 @@ class CodeBugger(DataArgumenter):
# Now, we'll use pip to install the package from the local repository.
subprocess.run(["python3", "-m", "pip", "install", "--editable", cwd + "/OpenBugger"])
"""
+
def __init__(self):
self.syntax_bug = SyntaxBug()
self.logic_bug = LogicBug()
def parse_single(self, code):
         bugged_code = self.syntax_bug(code, "medium", num_errors=2)
-        bugged_code = self.logic_bug(bugged_code, "medium", num_errors = 2)
+        bugged_code = self.logic_bug(bugged_code, "medium", num_errors=2)
         question = "Can you fix the following code?\n" + bugged_code
-        answer = "The following code is correct:\n" + code + "\nI hope I could help you fix your code. In case you need more help, feel free to ask me again."
+ answer = (
+ "The following code is correct:\n"
+ + code
+ + "\nI hope I could help you fixing your code. In case you need more help, feel free to ask me again."
+ )
return [question], [answer]
+
def recognize_entities(text, model, n=4, person="ignore"):
"""Given a text and a model for entity recognition, return the most occuring entites in the text as a string"""
doc = model(text)
if person == "ignore":
- ents = Counter([ent.text.strip() for ent in list(doc.ents) if len(ent.text.strip()) >= 5])
+ ents = Counter([ent.text.strip() for ent in list(doc.ents) if len(ent.text.strip()) >= 5])
elif person:
- ents = Counter([ent.text.strip() for ent in list(doc.ents) if ent.label_ == "PERSON" and len(ent.text.strip()) >= 5])
+ ents = Counter(
+ [ent.text.strip() for ent in list(doc.ents) if ent.label_ == "PERSON" and len(ent.text.strip()) >= 5]
+ )
else:
- ents = Counter([ent.text.strip() for ent in list(doc.ents) if ent.label_ != "PERSON" and len(ent.text.strip()) >= 5])
+ ents = Counter(
+ [ent.text.strip() for ent in list(doc.ents) if ent.label_ != "PERSON" and len(ent.text.strip()) >= 5]
+ )
ents = ents.most_common(n)
ents = ", ".join([a[0] for a in ents])
@@ -400,7 +416,7 @@ def parse_arguments():
args.add_argument("--augmenter", type=str, required=True)
args.add_argument("--output", type=str, required=True)
args = args.parse_args()
-
+
assert args.dataset.endswith(".tsv"), "Dataset file must be a tsv file, containing a list of files to be augmented"
assert args.output.endswith(".json"), "Output file must be a json file"
@@ -421,10 +437,10 @@ def read_data(args):
def get_augmenter(args):
if args.augmenter == "essayinstruction":
augmenter = EssayInstructor()
-
+
elif args.augmenter == "essayrevision":
augmenter = EssayReviser()
-
+
elif args.augmenter == "stackexchange":
augmenter = StackExchangeBuilder()
@@ -433,12 +449,14 @@ def get_augmenter(args):
elif args.augmenter == "entityrecognizedsummarizer":
augmenter = EntityRecognizedSummarizer()
-
+
elif args.augmenter == "codebugger":
augmenter = CodeBugger()
else:
- raise ValueError("Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger")
+ raise ValueError(
+ "Augmenter must be one of 'essayinstruction', 'essayrevision', 'stackexchange', 'hierarchicalsummarizer', 'entityrecognizedsummarizer', 'codebugger"
+ )
return augmenter
@@ -457,8 +475,3 @@ def main(args):
if __name__ == "__main__":
args = parse_arguments()
main(args)
-
-
-
-
-