Add code for training a FIRE on text8
Showing 10 changed files with 837 additions and 0 deletions.
@@ -0,0 +1,12 @@ scripts/build_vocab.py (path inferred from the invocation further down)

import argparse
import corpusit

parser = argparse.ArgumentParser()
parser.add_argument('--path_to_corpus', type=str)
parser.add_argument('--min_count', type=int, default=5)
parser.add_argument('--infreq_replace', type=str, default="<unk>")
args = parser.parse_args()

vocab = corpusit.Vocab.build(args.path_to_corpus, min_count=args.min_count, unk=args.infreq_replace)
print('vocab size:', len(vocab))
print(vocab)
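Judging by the flag names, corpusit.Vocab.build counts whitespace-separated tokens and maps anything seen fewer than --min_count times to the --infreq_replace symbol. A rough pure-Python sketch of that counting rule (illustrative only, not corpusit's actual implementation; build_vocab_sketch is a made-up name):

from collections import Counter

def build_vocab_sketch(path, min_count=5, unk='<unk>'):
    # one pass over a whitespace-tokenized corpus
    counts = Counter()
    with open(path) as f:
        for line in f:
            counts.update(line.split())
    # keep frequent tokens; rarer ones are pooled under the unk symbol
    vocab = {tok: c for tok, c in counts.items() if c >= min_count}
    vocab[unk] = vocab.get(unk, 0) + sum(c for tok, c in counts.items() if c < min_count)
    return vocab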
@@ -0,0 +1,105 @@ scripts/multiproc_yield.py (path inferred from the import in scripts/tokenize.py below)

__all__ = ['yielder']

import time
import heapq
from queue import Empty, Full  # faster_fifo raises the standard queue exceptions
from multiprocessing import Process
from faster_fifo import Queue

FOREVER = 1e10


class EndOfQueue:
    ''' Sentinel placed on a queue to signal that no more items will follow. '''
    pass


def wrapped_put(que, item, timeout=FOREVER):
    ''' Put with retries, so a temporarily full queue does not drop the item. '''
    for _ in range(max(1, int(timeout))):
        try:
            que.put(item, timeout=1)
            return
        except Full:
            time.sleep(1)
    raise Full


def wrapped_get(que, timeout=FOREVER):
    ''' Get with retries; raises Empty only once the whole timeout is spent. '''
    for _ in range(max(1, int(timeout))):
        try:
            return que.get(timeout=1)
        except Empty:
            time.sleep(1)
    raise Empty


def master_queue(generator, taskque: Queue, num_workers):
    ''' Feed (job_id, task) pairs to the workers, then one sentinel per worker. '''
    for i, task in enumerate(generator):
        wrapped_put(taskque, (i, task))
    for _ in range(num_workers):
        wrapped_put(taskque, EndOfQueue)


def slave_queue(func, taskque: Queue, resque: Queue, additional_kwds={}):
    ''' Consume tasks until the sentinel arrives, forwarding it to the results queue. '''
    stop = False
    while not stop:
        task = wrapped_get(taskque)
        if task is EndOfQueue:
            stop = True
            res = EndOfQueue
        else:
            jobid, task = task
            res = (jobid, func(task, **additional_kwds))
        wrapped_put(resque, res)


buff = []  # heap buffer for results that arrived ahead of their turn


def yielder(generator, func, num_workers, additional_kwds={},
            verbose=True, print_interval: int = 10000,
            max_size_bytes=1024*1024):
    ''' Parallel map over `generator` that yields results in input order. '''
    taskque = Queue(max_size_bytes=max_size_bytes)
    resque = Queue(max_size_bytes=max_size_bytes)

    buff.clear()
    _master = Process(target=master_queue, args=(generator, taskque, num_workers))
    _master.start()
    for _ in range(num_workers):
        _slave = Process(target=slave_queue, args=(func, taskque, resque, additional_kwds))
        _slave.start()

    for i, x in enumerate(ordered_results(resque, num_workers)):
        yield x
        if verbose and i % print_interval == 0:
            print(f'{i}: taskque: {taskque.qsize()}, resque: {resque.qsize()}. heap buffsize: {len(buff)}')


def ordered_results(resque: Queue, num_workers):
    ''' Online sort of the queued outputs, according to their job id. '''
    n = num_workers
    pos = 0
    while n > 0:  # run until every worker has sent its sentinel
        res = wrapped_get(resque)
        if res is EndOfQueue:
            n -= 1
        else:
            i, x = res
            heapq.heappush(buff, (i, x))
            while len(buff) and buff[0][0] == pos:
                i, x = heapq.heappop(buff)
                yield x
                pos += 1


if __name__ == '__main__':
    from time import time
    from tqdm import tqdm

    def job(x, margin):
        return max(x**3, margin)

    t0 = time()
    num_workers = 16
    generator = range(1000000)
    for x in tqdm(
            yielder(generator, func=job, num_workers=num_workers,
                    additional_kwds={'margin': 3}, max_size_bytes=1024*1024)):
        pass
    print(f'Elapsed: {time() - t0}')
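A smaller usage sketch than the benchmark above (illustrative, not part of the commit): yielder behaves like an order-preserving parallel map, so results come back in input order even when workers finish out of order. The import assumes the module sits at scripts/multiproc_yield.py, matching the import in scripts/tokenize.py.

from scripts.multiproc_yield import yielder

def square(x):
    return x * x

if __name__ == '__main__':
    # prints [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]: input order is preserved
    print(list(yielder(range(10), square, num_workers=2, verbose=False)))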
@@ -0,0 +1,3 @@

wget https://data.deepai.org/text8.zip -O text8.zip
mkdir -p data/corpus/text8/
unzip text8.zip -d data/corpus/text8/
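(For reference: text8 is the first 10^8 bytes of a cleaned English Wikipedia dump, lowercased and stripped to the letters a-z plus spaces, stored as one very long line of roughly 17 million words.)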
@@ -0,0 +1,2 @@

python scripts/tokenize.py \
--raw_path=data/corpus/text8/text8
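Since scripts/tokenize.py (shown further down) derives its output path from the input, this writes the tokenized corpus next to the raw file as data/corpus/text8/text8.uncased.tokens, lowercased by default (pass --cased to keep case), along with a .args JSON file recording the settings used.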
@@ -0,0 +1,4 @@

python scripts/build_vocab.py \
--path_to_corpus='data/corpus/text8/text8.uncased.tokens' \
--min_count=5 \
--infreq_replace="<unk>"
@@ -0,0 +1,18 @@

python train.py \
--corpus_path=data/corpus/text8/text8.uncased.tokens \
--sz_batch=8192 \
--n_neg=1 \
--lr=0.005 \
--lr_scheduler=OneCycleLR \
--dim=2 \
--n_iters=1000 \
--eval_interval=100000 \
--savedir=results/ \
--optimizer=adamw \
--seed=0 \
--accum_steps=10 \
--func='MLPlanarDiv(args.dim, 4)' \
--measure='DiracMixture(args.dim, 10)' \
--weight_decay=1e-6 \
--use_wandb \
--amp
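Two readings here are inferred rather than stated by the commit: the quoted --func and --measure values reference args.dim, which suggests train.py evaluates them as Python expressions to construct the FIRE's function component ('MLPlanarDiv(args.dim, 4)') and its base measure ('DiracMixture(args.dim, 10)'); and with --sz_batch=8192 and --accum_steps=10, each optimizer update would see an effective batch of 81,920 examples. train.py itself is not among the files shown in this excerpt.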
@@ -0,0 +1,82 @@ scripts/tokenize.py (path inferred from the invocation above)

import os


def parse_args():
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--raw_path", type=str, default="data/corpus/text8/text8")
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--cased", action="store_true")
    parser.add_argument("--sos", type=str, default="<s>")
    parser.add_argument("--eos", type=str, default="</s>")
    parser.add_argument("--linesep", type=str, default="\n")
    parser.add_argument("--tokensep", type=str, default=" ")

    return parser.parse_args()


def run(args):
    from nltk import word_tokenize
    import tqdm
    from scripts.multiproc_yield import yielder
    import logging
    import json

    logging.basicConfig(level=logging.INFO)  # without this, the INFO messages below are swallowed
    logger = logging.getLogger()
    path = os.path.abspath(args.raw_path)
    dirpath = os.path.dirname(path)
    filename = os.path.basename(path)
    savepath = (
        f"{dirpath}/{filename}" + ("" if args.cased else ".uncased") + ".tokens"
    )
    logger.info(f"Tokenized corpus to be saved at {savepath}")

    argsavepath = savepath + '.args'
    with open(argsavepath, "wt") as f:
        json.dump(args.__dict__, f, indent=2)
    logger.info(f"Arguments saved at {argsavepath}")

    def run_tokenize(corpus_path, save_path, args, max_size_bytes=1024 * 1024 * 1024):
        logger.info("Tokenizing...")

        def _tokenize(line, sos, eos):
            if not args.cased:
                line = line.lower()
            line = line.strip()
            if not line:
                return []
            else:
                return [sos] + word_tokenize(line) + [eos]

        def linecutter(lines, maxlen=10000):
            # text8 is a single huge line; cut it into bounded segments so each
            # task fits in the queue (a token may be split at a segment boundary)
            for line in lines:
                nseg = (len(line) + maxlen - 1) // maxlen
                for i in range(nseg):
                    yield line[i * maxlen : (i + 1) * maxlen]

        ntokens = 0
        with open(save_path, "wt") as fout, open(corpus_path, "rt") as f:
            for tokens in tqdm.tqdm(
                yielder(
                    linecutter(f),
                    _tokenize,  # nested function: relies on the 'fork' start method (Linux default)
                    num_workers=args.num_workers,
                    additional_kwds={"sos": args.sos, "eos": args.eos},
                    max_size_bytes=max_size_bytes,
                )
            ):
                if tokens:
                    ntokens += len(tokens)
                    fout.write(args.tokensep.join(tokens) + args.linesep)
        logger.info(f"{ntokens} tokens saved.")

    run_tokenize(path, savepath, args)
    logger.info("Finished.")


if __name__ == "__main__":
    args = parse_args()
    run(args)
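One setup caveat: nltk.word_tokenize depends on NLTK's punkt tokenizer models, which are not bundled with the package, so a one-time download is needed before the script will run:

import nltk

nltk.download('punkt')        # classic resource name
# nltk.download('punkt_tab')  # the name used by recent NLTK releases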