Add code for training a FIRE on text8
kduxin committed Oct 15, 2022
1 parent 216c3ca commit 9b03a75
Showing 10 changed files with 837 additions and 0 deletions.
23 changes: 23 additions & 0 deletions README.md
@@ -74,3 +74,26 @@ We selected a [subset](/data/wordnet-542.txt) of the ["Core"
WordNet](https://wordnet.princeton.edu/download/standoff-files) dataset and
constructed a list of 542 strongly polysemous / strongly monosemous words.
See `/data/wordnet-542.txt`.


# Training FIRE

We provide scripts in `/scripts/text8/` to train a FIRE model on the *text8* dataset.
```bash
$ bash scripts/text8/1_download_text8.sh
# download the *text8* corpus

$ bash scripts/text8/2_tokenize.sh
# tokenize the corpus with the NLTK tokenizer

$ bash scripts/text8/3_build_vocab.sh
# build a vocabulary with the tokenized corpus

$ bash scripts/text8/4_train.sh
# training from scratch
```

Training is carried out with the SkipGram method.
For fast SkipGram-style sampling from the tokenized corpus, we use another
Python package, [`corpusit`](https://github.com/kduxin/corpusit),
which is written in Rust (with Python bindings via [PyO3](https://github.com/PyO3/pyo3)).
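
As a rough illustration of what SkipGram-style sampling produces (this is *not* the `corpusit` API, which implements the sampling in Rust), the sketch below draws (center, context) pairs with negative samples from a sequence of token ids. The function name, the word2vec-style random window, and the uniform negative distribution are simplifying assumptions made for exposition only.
```python
# Illustrative sketch only -- corpusit implements this kind of sampling in Rust.
import random

def skipgram_pairs(token_ids, window=5, n_neg=1, vocab_size=10000, seed=0):
    """Yield (center, context, negatives) triples from a sequence of token ids."""
    rng = random.Random(seed)
    for pos, center in enumerate(token_ids):
        # Pick a random window size in [1, window], as in word2vec.
        span = rng.randint(1, window)
        lo, hi = max(0, pos - span), min(len(token_ids), pos + span + 1)
        for ctx_pos in range(lo, hi):
            if ctx_pos == pos:
                continue
            # Uniform negatives for simplicity; real samplers typically
            # draw from a unigram^0.75 distribution.
            negatives = [rng.randrange(vocab_size) for _ in range(n_neg)]
            yield center, token_ids[ctx_pos], negatives

# Example: print the first few samples from a toy corpus.
for triple in list(skipgram_pairs([1, 2, 3, 4, 5], window=2, vocab_size=6))[:4]:
    print(triple)
```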
12 changes: 12 additions & 0 deletions scripts/build_vocab.py
@@ -0,0 +1,12 @@
import argparse
import corpusit

parser = argparse.ArgumentParser()
parser.add_argument('--path_to_corpus', type=str)
parser.add_argument('--min_count', type=int, default=5)
parser.add_argument('--infreq_replace', type=str, default="<unk>")
args = parser.parse_args()

vocab = corpusit.Vocab.build(args.path_to_corpus, min_count=args.min_count, unk=args.infreq_replace)
print('vocab size:', len(vocab))
print(vocab)
105 changes: 105 additions & 0 deletions scripts/multiproc_yield.py
@@ -0,0 +1,105 @@

__all__ = ['yielder']

import time
import heapq
from multiprocessing import Process
from queue import Empty, Full  # timeout exceptions raised by faster_fifo's Queue
from faster_fifo import Queue

FOREVER = 1e10

class EndOfQueue:
    """Sentinel signalling that no more items will be placed on a queue."""
    pass

def wrapped_put(que, item, timeout=FOREVER):
    """Retry `que.put` in 1-second attempts until it succeeds or ~`timeout` seconds pass."""
    for _ in range(max(1, int(timeout))):
        try:
            que.put(item, timeout=1)
        except Full:
            continue
        break

def wrapped_get(que, timeout=FOREVER):
    """Retry `que.get` in 1-second attempts until an item arrives or ~`timeout` seconds pass."""
    for _ in range(max(1, int(timeout))):
        try:
            return que.get(timeout=1)
        except Empty:
            continue
    raise Empty

def master_queue(generator, taskque: Queue, num_workers):
for i, task in enumerate(generator):
wrapped_put(taskque, (i, task))
for i in range(num_workers):
wrapped_put(taskque, EndOfQueue)

def slave_queue(func, taskque: Queue, resque: Queue, additional_kwds={}):
stop = False
while not stop:
task = wrapped_get(taskque)
if task is EndOfQueue:
stop = True
res = EndOfQueue
else:
jobid, task = task
res = (jobid, func(task, **additional_kwds))
wrapped_put(resque, res)

def yielder(generator, func, num_workers, additional_kwds={},
            verbose=True, print_interval: int = 10000,
            max_size_bytes=1024*1024):
    """Apply `func` to the items of `generator` in `num_workers` worker processes
    and yield the results in the original order."""
    taskque = Queue(max_size_bytes=max_size_bytes)
    resque = Queue(max_size_bytes=max_size_bytes)

    buff.clear()  # module-level heap used by ordered_results (defined below)
    _master = Process(target=master_queue, args=(generator, taskque, num_workers))
    _master.start()
    for _ in range(num_workers):
        _slave = Process(target=slave_queue, args=(func, taskque, resque, additional_kwds))
        _slave.start()

    for i, x in enumerate(ordered_results(resque, num_workers)):
        yield x

        if verbose and i % print_interval == 0:
            print(f'{i}: taskque: {taskque.qsize()}, resque: {resque.qsize()}. heap buffsize: {len(buff)}')

buff = []  # heap buffer for results that must wait their turn to be output
def ordered_results(resque: Queue, num_workers):
    '''Online sort of the queued outputs according to their job id.'''
    n = num_workers
    pos = 0
    while n > 0:  # run until every worker has sent its EndOfQueue sentinel
        res = wrapped_get(resque)
        if res is EndOfQueue:
            n -= 1
        else:
            i, x = res
            heapq.heappush(buff, (i, x))
            while len(buff) and buff[0][0] == pos:
                i, x = heapq.heappop(buff)
                yield x
                pos += 1


if __name__ == '__main__':
    from tqdm import tqdm

    def job(x, margin):
        return max(x**3, margin)

    t0 = time.time()
    num_workers = 16
    generator = range(1000000)
    for x in tqdm(
            yielder(generator, func=job, num_workers=num_workers,
                    additional_kwds={'margin': 3}, max_size_bytes=1024*1024)):
        pass
    print(f'Elapsed: {time.time() - t0}')
3 changes: 3 additions & 0 deletions scripts/text8/1_download_text8.sh
@@ -0,0 +1,3 @@
wget https://data.deepai.org/text8.zip -O text8.zip
mkdir -p data/corpus/text8/
unzip text8.zip -d data/corpus/text8/
2 changes: 2 additions & 0 deletions scripts/text8/2_tokenize.sh
@@ -0,0 +1,2 @@
python scripts/tokenize.py \
--raw_path=data/corpus/text8/text8
4 changes: 4 additions & 0 deletions scripts/text8/3_build_vocab.sh
@@ -0,0 +1,4 @@
python scripts/build_vocab.py \
--path_to_corpus='data/corpus/text8/text8.uncased.tokens' \
--min_count=5 \
--infreq_replace="<unk>"
18 changes: 18 additions & 0 deletions scripts/text8/4_train.sh
@@ -0,0 +1,18 @@
python train.py \
--corpus_path=data/corpus/text8/text8.uncased.tokens \
--sz_batch=8192 \
--n_neg=1 \
--lr=0.005 \
--lr_scheduler=OneCycleLR \
--dim=2 \
--n_iters=1000 \
--eval_interval=100000 \
--savedir=results/ \
--optimizer=adamw \
--seed=0 \
--accum_steps=10 \
--func='MLPlanarDiv(args.dim, 4)' \
--measure='DiracMixture(args.dim, 10)' \
--weight_decay=1e-6 \
--use_wandb \
--amp
82 changes: 82 additions & 0 deletions scripts/tokenize.py
@@ -0,0 +1,82 @@
import os

def parse_args():
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--raw_path", type=str, default="data/corpus/text8/text8")
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--cased", action="store_true")
parser.add_argument("--sos", type=str, default="<s>")
parser.add_argument("--eos", type=str, default="</s>")
parser.add_argument("--linesep", type=str, default="\n")
parser.add_argument("--tokensep", type=str, default=" ")

return parser.parse_args()


def run(args):
    from nltk import word_tokenize  # requires the NLTK "punkt" models: nltk.download("punkt")
    import tqdm
    from scripts.multiproc_yield import yielder
    import logging
    import json

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger()
path = os.path.abspath(args.raw_path)
dirpath = os.path.dirname(path)
filename = os.path.basename(path)
savepath = (
f"{dirpath}/{filename}" + ("" if args.cased else ".uncased") + ".tokens"
)
logger.info(f"Tokenized corpus to be saved at {savepath}")

argsavepath = savepath + '.args'
with open(argsavepath, "wt") as f:
json.dump(args.__dict__, f, indent=2)
logger.info(f"Arguments saved at {argsavepath}")

def run_tokenize(corpus_path, save_path, args, max_size_bytes=1024 * 1024 * 1024):
logger.info("Tokenizing...")

def _tokenize(line, sos, eos):
if not args.cased:
line = line.lower()
line = line.strip()
if not line:
return []
else:
return [sos] + word_tokenize(line) + [eos]

def linecutter(lines, maxlen=10000):
for line in lines:
nseg = (len(line) + maxlen - 1) // maxlen
for i in range(nseg):
yield line[i * maxlen : (i + 1) * maxlen]

ntokens = 0
fout = open(save_path, "wt")
with open(corpus_path, "rt") as f:
for tokens in tqdm.tqdm(
yielder(
linecutter(f),
_tokenize,
num_workers=args.num_workers,
additional_kwds={"sos": args.sos, "eos": args.eos},
max_size_bytes=max_size_bytes,
)
):
if tokens:
ntokens += len(tokens)
fout.write(args.tokensep.join(tokens) + args.linesep)
logger.info(f"{ntokens} tokens in saved.")

fout.close()

run_tokenize(path, savepath, args)
logger.info("Finished.")


if __name__ == "__main__":
args = parse_args()
run(args)