Skip to content

Commit

Permalink
tests, task data, start of fine-tuning, test prints
Browse files Browse the repository at this point in the history
  • Loading branch information
Xeadriel committed May 18, 2024
1 parent 46e30c6 commit 2b1306d
Show file tree
Hide file tree
Showing 29 changed files with 8,005 additions and 7,484 deletions.
3 changes: 3 additions & 0 deletions firelang/models/_fireword.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,9 @@ def loss_skipgram(
loss = Loss()

logits = x1 * x2
print("########################################")
print(type(logits))
print("########################################")
loss_sim = F.binary_cross_entropy_with_logits(
logits, labels.float(), reduction="none"
)
Expand Down
2 changes: 1 addition & 1 deletion firelang/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _sanity_check(self):
)

def __getitem__(self, index: IndexLike):
new_shape = tuple(torch.empty(self.shape)[index].shape)
new_shape = tuple(torch.empty(self.shape, device="cuda")[index].shape)
to: StackingSlicing = self.restack(new_shape)

# A parameter not listed in `unsliceable_params` should be
Expand Down
89 changes: 40 additions & 49 deletions scripts/additionalBenchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,39 +31,39 @@ def main():
"--checkpointsMRPC",
nargs="+",
default=[
"checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/wacky_mlplanardiv_d2_l8_k20",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l8_k20",
],
)

parser.add_argument(
"--checkpointsSST",
nargs="+",
default=[
# "checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
# "checkpoints/wacky_mlplanardiv_d2_l4_k10",
# "checkpoints/wacky_mlplanardiv_d2_l8_k20",
# "checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy",
# "checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k10",
# "checkpoints/v1.1/wacky_mlplanardiv_d2_l8_k20",
],
)

parser.add_argument(
"--checkpointsSSTGlue",
nargs="+",
default=[
"checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/wacky_mlplanardiv_d2_l8_k20",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l8_k20",
],
)

parser.add_argument(
"--checkpointsRTE",
nargs="+",
default=[
"checkpoints/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/wacky_mlplanardiv_d2_l8_k20",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k10",
"checkpoints/v1.1/wacky_mlplanardiv_d2_l8_k20",
],
)

Expand All @@ -79,15 +79,15 @@ def main():
print("\t--------------------------------------------------------------------------------------------------------------------------------")
print("\tThe Stanford Sentiment Treebank")
print("\t--------------------------------------------------------------------------------------------------------------------------------")
trainPairsSST, trainLabelsSST, devPairsSST, devLabelsSST, testPairsSST, testLabelsSST = prepareSSTData('scripts/tasks/SST/datasetSplit.txt',
'scripts/tasks/SST/datasetSentences.txt', 'scripts/tasks/SST/dictionary.txt', 'scripts/tasks/SST/sentiment_labels.txt')
for checkpoint in args.checkpointsSST:
model = FireWord.from_pretrained(checkpoint).to(device)
print(f"checkpoint: {checkpoint}")
# trainPairsSST, trainLabelsSST, devPairsSST, devLabelsSST, testPairsSST, testLabelsSST = prepareSSTData('scripts/tasks/SST/datasetSplit.txt',
# 'scripts/tasks/SST/datasetSentences.txt', 'scripts/tasks/SST/dictionary.txt', 'scripts/tasks/SST/sentiment_labels.txt')
# for checkpoint in args.checkpointsSST:
# model = FireWord.from_pretrained(checkpoint).to(device)
# print(f"checkpoint: {checkpoint}")

accuracy = benchmarkSST(model, testPairsSST, testLabelsSST, devPairsSST, devLabelsSST, sifA)
# accuracy = benchmarkSST(model, testPairsSST, testLabelsSST, devPairsSST, devLabelsSST, sifA)

print(f"accuracy: {accuracy}")
# print(f"accuracy: {accuracy}")


print("\t--------------------------------------------------------------------------------------------------------------------------------")
Expand All @@ -96,28 +96,28 @@ def main():
print("\t\t--------------------------------------------------------------------------------------------------------------------------------")
print("\t\tMicrosoft Research Paraphrase Corpus")
print("\t\t--------------------------------------------------------------------------------------------------------------------------------")
pairsMRPC, labelsMRPC = prepareMRPCData('scripts/tasks/MRPC/msr_paraphrase_train.csv')
testPairsMRPC, testLabelsMRPC = prepareMRPCData('scripts/tasks/MRPC/msr_paraphrase_test.txt')
for checkpoint in args.checkpointsMRPC:
model = FireWord.from_pretrained(checkpoint).to(device)
print(f"\t\tcheckpoint: {checkpoint}")

predsMedianMRPC, predsThresholdMRPC, predsF1ThresholdMRPC, accuracy, f1 = benchmarkMRPC(model, pairsMRPC, labelsMRPC, sifA)
predsMedianMRPC, predsThresholdMRPC, predsF1ThresholdMRPC, accuracy, f1 = benchmarkMRPC(model, testPairsMRPC, testLabelsMRPC, sifA)

print(f"\t\taccuracy: {accuracy}\n\t\tf1: {f1}\n")

with open(f'scripts/taskResults/MRPC/median/MRPC-{checkpoint[12:]}.tsv', 'w', newline='') as csvfile:
with open(f'scripts/taskResults/MRPC/median/MRPC-{checkpoint[17:]}.tsv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile, delimiter='\t', quotechar='ß')
writer.writerow(["index", "prediction"])
for index, pred in zip(range(len(predsMedianMRPC)), predsMedianMRPC):
writer.writerow([index, pred])

with open(f'scripts/taskResults/MRPC/threshold/MRPC-{checkpoint[12:]}.tsv', 'w', newline='') as csvfile:
with open(f'scripts/taskResults/MRPC/threshold/MRPC-{checkpoint[17:]}.tsv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile, delimiter='\t', quotechar='ß')
writer.writerow(["index", "prediction"])
for index, pred in zip(range(len(predsThresholdMRPC)), predsThresholdMRPC):
writer.writerow([index, pred])

with open(f'scripts/taskResults/MRPC/f1Threshold/MRPC-{checkpoint[12:]}.tsv', 'w', newline='') as csvfile:
with open(f'scripts/taskResults/MRPC/f1Threshold/MRPC-{checkpoint[17:]}.tsv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile, delimiter='\t', quotechar='ß')
writer.writerow(["index", "prediction"])
for index, pred in zip(range(len(predsF1ThresholdMRPC)), predsF1ThresholdMRPC):
Expand All @@ -133,13 +133,13 @@ def main():
model = FireWord.from_pretrained(checkpoint).to(device)
predsMedianSSTGlue, predsThresholdSSTGlue = predictSSTGlue(model, testPairsSSTGlue, devPairsSSTGlue, devLabelsSSTGlue, sifA)

with open(f'scripts/taskResults/SSTGLUE/median/SST-2-{checkpoint[12:]}.tsv', 'w', newline='') as csvfile:
with open(f'scripts/taskResults/SSTGLUE/median/SST-2-{checkpoint[17:]}.tsv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile, delimiter='\t', quotechar='ß')
writer.writerow(["index", "prediction"])
for index, pred in zip(range(len(predsMedianSSTGlue)), predsMedianSSTGlue):
writer.writerow([index, pred])

with open(f'scripts/taskResults/SSTGLUE/threshold/SST-2-{checkpoint[12:]}.tsv', 'w', newline='') as csvfile:
with open(f'scripts/taskResults/SSTGLUE/threshold/SST-2-{checkpoint[17:]}.tsv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile, delimiter='\t', quotechar='ß')
writer.writerow(["index", "prediction"])
for index, pred in zip(testIndicesSSTGlue, predsThresholdSSTGlue):
Expand All @@ -154,28 +154,19 @@ def main():
model = FireWord.from_pretrained(checkpoint).to(device)
predsMedianRTE, predsThresholdRTE = predictRTE(model, testPairsRTE, devPairsRTE, devLabelsRTE, sifA)

with open(f'scripts/taskResults/RTE/median/RTE-{checkpoint[12:]}.tsv', 'w', newline='') as csvfile:
with open(f'scripts/taskResults/RTE/median/RTE-{checkpoint[17:]}.tsv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile, delimiter='\t', quotechar='ß')
writer.writerow(["index", "prediction"])
for index, pred in zip(range(len(predsMedianRTE)), predsMedianRTE):
writer.writerow([index, pred])

with open(f'scripts/taskResults/RTE/threshold/RTE-{checkpoint[12:]}.tsv', 'w', newline='') as csvfile:
with open(f'scripts/taskResults/RTE/threshold/RTE-{checkpoint[17:]}.tsv', 'w', newline='') as csvfile:
writer = csv.writer(csvfile, delimiter='\t', quotechar='ß')
writer.writerow(["index", "prediction"])
for index, pred in zip(testIndicesRTE, predsThresholdRTE):
writer.writerow([index, pred])



""" rescaling (smoothing) and exponential """

def simmatRescale(simmat) -> np.ndarray:
scale = np.abs(simmat).mean(axis=1, keepdims=True)
simmat = simmat / (scale * scale.T) ** 0.5

return simmat

@torch.no_grad()
@Timer(elapsed, "sentsim")
def benchmarkMRPC(
Expand Down Expand Up @@ -210,7 +201,7 @@ def benchmarkMRPC(
diag = np.diag(simmat)
simmat = simmat - 0.5 * (diag.reshape(-1, 1) + diag.reshape(1, -1))

""" smoothing by standardization """
""" smoothing by standardization """
with Timer(elapsed, "smooth"):
mean1 = np.mean(simmat, axis=1, keepdims=True)
std1 = np.std(simmat, axis=1, keepdims=True)
Expand All @@ -230,7 +221,7 @@ def benchmarkMRPC(
falseNegCount = sum([int(preds[i] >= medianThreshhold) == 0 and labels[i] == 1 for i in range(len(preds))])

medianf1Score = truePosCount / (truePosCount + 0.5 * (falsePosCount + falseNegCount))
print(f"\t\tmedianThreshhold: {medianThreshhold} \n\t\tmedianAccuracy: {medianScore}\n\t\tmedianF1: {medianf1Score}")
# print(f"\t\tmedianThreshhold: {medianThreshhold} \n\t\tmedianAccuracy: {medianScore}\n\t\tmedianF1: {medianf1Score}")

bestThreshold = 0
bestF1Threshold = 0
Expand All @@ -256,8 +247,8 @@ def benchmarkMRPC(
bestThreshold = threshold
bestAccuracy = accuracy

print(f"\t\taccuracy threshold = {bestThreshold}")
print(f"\t\tF1 threshold = {bestF1Threshold}")
# print(f"\t\taccuracy threshold = {bestThreshold}")
# print(f"\t\tF1 threshold = {bestF1Threshold}")

predsMedian = [int(pred >= medianThreshhold) for pred in preds]
predsThreshold = [int(pred >= bestThreshold) for pred in preds]
Expand Down Expand Up @@ -433,12 +424,12 @@ def computeDevThresholds():
(((not predsVN[i] >= thresholdVN) and not(predsN[i] >= thresholdN) and not(predsNeut[i] >= thresholdNeut)
and not(predsP[i] >= thresholdP) and (predsVP[i] >= thresholdVP)) and (devLabels[i] > 0.8 and devLabels[i] <= 1)) #very positve
) for i in range(len(N))]) / len(N)
print(
f"threshhold very negative:\t{thresholdVN}\n"
f"threshhold negative:\t\t{thresholdN}\n"
f"threshhold neutral:\t\t {thresholdNeut}\n"
f"threshhold positive:\t\t {thresholdP}\n"
f"threshhold very positive:\t{thresholdVP}")
# print(
# f"threshhold very negative:\t{thresholdVN}\n"
# f"threshhold negative:\t\t{thresholdN}\n"
# f"threshhold neutral:\t\t {thresholdNeut}\n"
# f"threshhold positive:\t\t {thresholdP}\n"
# f"threshhold very positive:\t{thresholdVP}")

return accuracy

Expand Down Expand Up @@ -486,7 +477,7 @@ def computeThresholdFromDevData():
std0 = np.std(simmat, axis=0, keepdims=True)
simmat = (simmat - (mean1 + mean0) / 2) / (std0 * std1) ** 0.5

N = len(pairs[0])
N = len(devPairs[0])
preds = [simmat[i, i + N] for i in range(N)]
preds = np.exp(preds)
preds = np.array(preds)
Expand Down Expand Up @@ -593,7 +584,7 @@ def computeThresholdFromDevData():
std0 = np.std(simmat, axis=0, keepdims=True)
simmat = (simmat - (mean1 + mean0) / 2) / (std0 * std1) ** 0.5

N = len(pairs[0])
N = len(devPairs[0])
preds = [simmat[i, i + N] for i in range(N)]
preds = np.exp(preds)
preds = np.array(preds)
Expand Down
51 changes: 46 additions & 5 deletions scripts/corpusPreprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def prepareMRPCData(path):
"""
Makes the sentences lower case, strips leading and trailing whitespace and removes punctuation.
Returns the sentence pairs with their labels.
Returns the sentence pairs with their labels (0 or 1).
parameters:
String : path: path of MRPC data
Expand Down Expand Up @@ -42,7 +42,8 @@ def prepareMRPCData(path):
def prepareSSTData(splitPath, sentencesPath, dictionaryPath, labelsPath):
"""
Makes the sentences lower case, strips leading and trailing whitespace and removes punctuation.
Returns the sentence pairs with their labels.
The sentiment "very negative", "negative" etc. is used as the second sentence.
Returns the sentence pairs with their labels (0 or 1, to be interpreted as false or true).
parameters:
String : splitPath: file with train, test, dev split information,
Expand All @@ -52,11 +53,11 @@ def prepareSSTData(splitPath, sentencesPath, dictionaryPath, labelsPath):
returns:
List : trainPairs: 2 lists that contain sentence pairs
List : trainLabels: integer values 1 or 0 that indicate whether the sentences are similar
List : trainLabels: integer values 1 or 0 that indicate whether the sentence fits the supplied sentiment
List : devPairs: 2 lists that contain sentence pairs
List : devLabels: integer values 1 or 0 that indicate whether the sentences are similar
List : devLabels: integer values 1 or 0 that indicate whether the sentence fits the supplied sentiment
List : testPairs: 2 lists that contain sentence pairs
List : testLabels: integer values 1 or 0 that indicate whether the sentences are similar
List : testLabels: integer values 1 or 0 that indicate whether the sentence fits the supplied sentiment
"""
labels = {}
Expand Down Expand Up @@ -216,6 +217,26 @@ def prepareSSTData(splitPath, sentencesPath, dictionaryPath, labelsPath):
return trainPairs, trainLabels, devPairs, devLabels, testPairs, testLabels

def prepareSSTGlueData(trainPath, devPath, testPath):
"""
Makes the sentences lower case, strips leading and trailing whitespace and removes punctuation.
The sentiment "positive" is used as the second sentence on every sentence.
Returns the sentence pairs with their labels (0 or 1; interpret as: not positive or positive).
parameters:
String : trainPath: path of the training data file,
String : devPath: path of the dev data file,
String : testPath: path of the test data file,
returns:
List : trainPairs: 2 lists that contain sentence pairs
List : trainLabels: integer values 1 or 0 that indicate whether the sentence is positive
List : devPairs: 2 lists that contain sentence pairs
List : devLabels: integer values 1 or 0 that indicate whether the sentence is positive
List : testPairs: 2 lists that contain sentence pairs
List : testIndices: integer indices of pairs
"""
trainPairs = [[], []]
devPairs = [[], []]
testPairs = [[], []]
Expand Down Expand Up @@ -281,6 +302,26 @@ def prepareSSTGlueData(trainPath, devPath, testPath):
return trainPairs, trainLabels, devPairs, devLabels, testPairs, testIndices

def prepareRTEGlueData(trainPath, devPath, testPath):
"""
Makes the sentences lower case, strips leading and trailing whitespace and removes punctuation.
Returns the sentence pairs with their labels (0 == not_entailment, 1 == entailment).
parameters:
String : trainPath: path of the training data file,
String : devPath: path of the dev data file,
String : testPath: path of the test data file,
returns:
List : trainPairs: 2 lists that contain sentence pairs
List : trainLabels: integer values 1 or 0 that indicate whether the sentences are entailed
List : devPairs: 2 lists that contain sentence pairs
List : devLabels: integer values 1 or 0 that indicate whether the sentences are entailed
List : testPairs: 2 lists that contain sentence pairs
List : testIndices: integer indices of pairs
"""

trainPairs = [[], []]
devPairs = [[], []]
testPairs = [[], []]
Expand Down
5 changes: 5 additions & 0 deletions scripts/fine-tuning/1_fine_tune_MRPC_all_checkpoints.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
REM Runs scripts.fineTune with --task=MRPC three times, one after another,
REM once for each v1.1 checkpoint (l8_k20, l4_k10, l4_k1_polysemy).
REM All three runs share the same hyperparameters; only --pretrainedModel and
REM the matching --savedir under results/fineTuningResults differ.
REM Paths are relative, so this is presumably meant to be started from the
REM repository root -- TODO confirm.
python -m scripts.fineTune --sz_batch=200 --lr=0.005 --lr_scheduler=OneCycleLR --n_iters=600000 --eval_interval=1000 --savedir=results/fineTuningResults/MRPC_v1.1_wacky_mlplanardiv_d2_l8_k20 --optimizer=adamw --seed=0 --accum_steps=10 --weight_decay=1e-6 --pretrainedModel=checkpoints/v1.1/wacky_mlplanardiv_d2_l8_k20 --task=MRPC

python -m scripts.fineTune --sz_batch=200 --lr=0.005 --lr_scheduler=OneCycleLR --n_iters=600000 --eval_interval=1000 --savedir=results/fineTuningResults/MRPC_v1.1_wacky_mlplanardiv_d2_l4_k10 --optimizer=adamw --seed=0 --accum_steps=10 --weight_decay=1e-6 --pretrainedModel=checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k10 --task=MRPC

python -m scripts.fineTune --sz_batch=200 --lr=0.005 --lr_scheduler=OneCycleLR --n_iters=600000 --eval_interval=1000 --savedir=results/fineTuningResults/MRPC_v1.1_wacky_mlplanardiv_d2_l4_k1_polysemy --optimizer=adamw --seed=0 --accum_steps=10 --weight_decay=1e-6 --pretrainedModel=checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy --task=MRPC
41 changes: 41 additions & 0 deletions scripts/fine-tuning/2_fine_tune_SST-2_all_checkpoints.bat
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
REM Runs scripts.train three times, one after another, once for each v1.1
REM checkpoint (l8_k20, l4_k10, l4_k1_polysemy), with identical hyperparameters.
REM
REM Each invocation is kept on ONE line on purpose: cmd.exe does not treat a
REM trailing backslash as a line continuation (its continuation character
REM is ^), and the earlier multi-line layout also dropped the continuation on
REM several lines, so options such as --optimizer, --pretrainedModel and
REM --task were executed as separate commands instead of being passed to python.
REM
REM NOTE(review): the file name says SST-2, but every run passes --task=MRPC and
REM saves under an MRPC_* directory, and the module is scripts.train while the
REM MRPC script calls scripts.fineTune -- confirm the intended values before
REM running. They are left unchanged here.
python -m scripts.train --sz_batch=32768 --lr=0.005 --lr_scheduler=OneCycleLR --n_iters=600000 --eval_interval=1000 --savedir=results/fineTuningResults/MRPC_v1.1_wacky_mlplanardiv_d2_l8_k20 --optimizer=adamw --seed=0 --accum_steps=10 --weight_decay=1e-6 --pretrainedModel=checkpoints/v1.1/wacky_mlplanardiv_d2_l8_k20 --task=MRPC

python -m scripts.train --sz_batch=32768 --lr=0.005 --lr_scheduler=OneCycleLR --n_iters=600000 --eval_interval=1000 --savedir=results/fineTuningResults/MRPC_v1.1_wacky_mlplanardiv_d2_l4_k10 --optimizer=adamw --seed=0 --accum_steps=10 --weight_decay=1e-6 --pretrainedModel=checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k10 --task=MRPC

python -m scripts.train --sz_batch=32768 --lr=0.005 --lr_scheduler=OneCycleLR --n_iters=600000 --eval_interval=1000 --savedir=results/fineTuningResults/MRPC_v1.1_wacky_mlplanardiv_d2_l4_k1_polysemy --optimizer=adamw --seed=0 --accum_steps=10 --weight_decay=1e-6 --pretrainedModel=checkpoints/v1.1/wacky_mlplanardiv_d2_l4_k1_polysemy --task=MRPC
Loading

0 comments on commit 2b1306d

Please sign in to comment.