Commit

Update run.py
sunnweiwei authored Jan 8, 2024
1 parent c051e5c commit d3c1609
Showing 1 changed file with 16 additions and 63 deletions.
run.py (79 changes: 16 additions & 63 deletions)
@@ -706,16 +706,15 @@ def train(config):
         for batch in tk0:
             step += 1
             with accelerator.accumulate(model):
-                # losses = OurTrainer.train_step(model, batch, gathered=False)
-                # loss = sum([v * loss_w[k] for k, v in losses.items()])
-                # accelerator.backward(loss)
-                # accelerator.clip_grad_norm_(model.parameters(), 1.)
-                # optimizer.step()
-                # scheduler.step()
-                # optimizer.zero_grad()
-                #
-                # loss = accelerator.gather(loss).mean().item()
-                loss = 0.
+                losses = OurTrainer.train_step(model, batch, gathered=False)
+                loss = sum([v * loss_w[k] for k, v in losses.items()])
+                accelerator.backward(loss)
+                accelerator.clip_grad_norm_(model.parameters(), 1.)
+                optimizer.step()
+                scheduler.step()
+                optimizer.zero_grad()
+
+                loss = accelerator.gather(loss).mean().item()
                 loss_report.append(loss)
                 tk0.set_postfix(loss=sum(loss_report[-100:]) / len(loss_report[-100:]))

@@ -1007,22 +1006,21 @@ def do_epoch_encode(model: Model, data, corpus, ids, tokenizer, batch_size, save
     data_loader = torch.utils.data.DataLoader(corpus_data, collate_fn=corpus_data.collate_fn, batch_size=batch_size,
                                                shuffle=False, num_workers=16)

-    # collection, doc_code = our_encode(data_loader, model, 'doc')
-    collection = np.zeros((100, 768))
-    doc_code = [0] * len(corpus)
+    collection, doc_code = our_encode(data_loader, model, 'doc')

     print(collection.shape)
-    # index = build_index(collection, gpu=False)
+    index = build_index(collection, gpu=False)

     q_corpus = ['' for _ in range(len(corpus))]
     corpus_data = BiDataset(data=data, corpus=q_corpus, tokenizer=tokenizer, max_doc_len=128, max_q_len=32, ids=ids)
     data_loader = torch.utils.data.DataLoader(corpus_data, collate_fn=corpus_data.collate_fn, batch_size=batch_size,
                                                shuffle=False, num_workers=4)
     queries, query_code = our_encode(data_loader, model, 'query')

-    # rank, distance = do_retrieval(queries, index, k=100)
-    # rank = rank.tolist()
-    rank = None
+    rank, distance = do_retrieval(queries, index, k=100)
+    rank = rank.tolist()

     json.dump(rank, open(f'{save_path}/{epoch}.pt.rank', 'w'))
     all_doc_code = [prefix[1:] + [current] for prefix, current in zip(ids, doc_code)]
     json.dump(all_doc_code, open(f'{save_path}/{epoch}.pt.code', 'w'))
@@ -1031,21 +1029,16 @@ def do_epoch_encode(model: Model, data, corpus, ids, tokenizer, batch_size, save
     print('Doc_code balance', balance(doc_code, ids, ncentroids=n_code))
     print('Doc_code conflict', conflict(doc_code, ids))

-    # normed_collection = norm_by_prefix(collection, ids)
+    normed_collection = norm_by_prefix(collection, ids)
     nc = n_code
-    # centroids, code = constrained_km(normed_collection, nc)
-    code = [0] * len(corpus)
-    centroids = np.zeros((nc, 768))
+    centroids, code = constrained_km(normed_collection, nc)
     print('Kmeans balance', balance(code, ids))
     print('Kmeans conflict', conflict(code, ids))
     write_pkl(centroids, f'{save_path}/{epoch}.pt.kmeans.{nc}')
     json.dump(code, open(f'{save_path}/{epoch}.pt.kmeans_code.{nc}', 'w'))

     query_ids = [x[1] for x in data]

     from eval import eval_all
-    # print(eval_all(rank, query_ids))


 def test_dr(config):
     model_name = config.get('model_name', 't5-base')
@@ -1085,46 +1078,6 @@ def test_dr(config):
         do_epoch_encode(model, data, corpus, ids, tokenizer, batch_size, save_path, epoch, n_code=code_num)


-def test_case():
-    batch_size = 128
-    save_path = 'out/our-v12-512'
-    epoch = 300
-    data = json.load(open('dataset/nq320k/dev.json'))
-    corpus = json.load(open('dataset/nq320k/corpus_lite.json'))
-
-    print('DR evaluation', f'{save_path}')
-    t5 = AutoModelForSeq2SeqLM.from_pretrained('models/t5-base')
-    code_number = 512
-    model = Model(model=t5, use_constraint=False, code_length=1, zero_inp=False, code_number=code_number)
-    tokenizer = AutoTokenizer.from_pretrained('models/t5-base')
-    model = model.cuda()
-    model.eval()
-    ids = [[0] for i, j in
-           zip(json.load(open('out/our-v7-512/9.pt.code')), json.load(open('out/our-v9-512/19.pt.code')))]
-
-    safe_load(model, f'{save_path}/{epoch}.pt')
-    corpus_q = [['', i] for i in range(len(corpus))]
-    corpus_data = BiDataset(data=corpus_q, corpus=corpus, tokenizer=tokenizer, max_doc_len=128, max_q_len=32, ids=ids)
-    data_loader = torch.utils.data.DataLoader(corpus_data, collate_fn=corpus_data.collate_fn, batch_size=batch_size,
-                                               shuffle=False, num_workers=16)
-    keys = 'doc'
-    collection = []
-    code_collection = []
-    for batch in tqdm(data_loader):
-        batch = {k: v.cuda() for k, v in batch.items() if v is not None}
-        output: QuantizeOutput = model(input_ids=batch[keys], attention_mask=batch[keys].ne(0),
-                                       decoder_input_ids=batch['ids'],
-                                       aux_ids=None, return_code=False,
-                                       return_quantized_embedding=False, use_constraint=False)
-        sentence_embeddings = output.continuous_embeds.cpu().tolist()
-        code = output.probability.argmax(-1).cpu().tolist()
-        code_collection.extend(code)
-        collection.extend(sentence_embeddings)
-    collection = np.array(collection, dtype=np.float32)
-    print(collection.shape)
-    write_pkl(collection, f'case/l1.collection')


 # centroids, code
 def skl_kmeans(x, ncentroids=10, niter=300, n_init=10, mini=False, reassign=0.01):
     from sklearn.cluster import KMeans, MiniBatchKMeans
