diff --git a/.gitignore b/.gitignore index b713501..cd2b6cb 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,12 @@ openai_api_key.json prompts_backup/ prompt_data/mol_cluster.csv prompt_data/mole_graph_property.csv +create_dataset.py +create_pretraining_dataset.py +create_hug_repo.py +huggingface_dataset.py +download_huggingface_dataset.py +downstream_test_huggingface.py +download_huggingface_model.py +debug.py + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..08465f5 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 haiteng zhao + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 925cda5..f9127e7 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # GIMLET -This is the code for paper [GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning](https://www.biorxiv.org/content/10.1101/2023.05.30.542904). +This is the code for the paper [GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning](https://arxiv.org/pdf/2306.13089.pdf), published at NeurIPS 2023. GIMLET is a unified transformer model for both graph and text data and is pretrained on large scale molecule tasks with instructions, towards instruction-based molecule zero-shot learning. The framework and pretraining & downstream tasks are as follows: @@ -13,6 +13,34 @@ GIMLET is a unified transformer model for both graph and text data and is pretra We also benchmark baselines including KVPLM, MoMu, and Galactica on our downstream tasks for instruction-based zero-shot learning. +## Updates + +### 2023.9.24 + +Our work has been accepted at NeurIPS 2023! The camera-ready paper is at [https://proceedings.neurips.cc/paper_files/paper/2023/file/129033c7c08be683059559e8d6bfd460-Paper-Conference.pdf](https://proceedings.neurips.cc/paper_files/paper/2023/file/129033c7c08be683059559e8d6bfd460-Paper-Conference.pdf). + +### 2023.7.10 + +**1.** Now the datasets and the GIMLET model can be downloaded directly from HuggingFace 🤗 : [https://huggingface.co/datasets/haitengzhao/molecule_property_instruction](https://huggingface.co/datasets/haitengzhao/molecule_property_instruction) and [https://huggingface.co/haitengzhao/gimlet](https://huggingface.co/haitengzhao/gimlet).
+ +The GIMLET model can be downloaded and used as follows: + +``` +from model import GraphT5TransformerForConditionalGeneration +model = GraphT5TransformerForConditionalGeneration.from_pretrained("haitengzhao/gimlet") +``` + +Our datasets can be downloaded and used as follows: + +``` +from datasets import load_dataset +dataset = load_dataset("haitengzhao/molecule_property_instruction") +``` + +We have updated the pipeline and scripts to support these new loading methods; try them out in your projects. + +**2.** A few bugs in KVPLM testing have been fixed. + ## Installation To run GIMLET, please clone the repository to your local machine and install the required dependencies using the script provided. @@ -36,7 +64,7 @@ pip install torch_spline_conv-1.2.1-cp37-cp37m-linux_x86_64.whl pip install torch_geometric==1.7.2 -git clone https://github.com/huggingface/transformers +git clone -b v4.28.1 https://github.com/huggingface/transformers cd transformers pip install --editable ./ @@ -74,7 +102,19 @@ pip install openai ### Checkpoint Download -Please download pytorch_model.bin from [https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing](https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing) and move it to .\ckpts\gimlet. You can do this by the following scripts: + +#### Method 1: HuggingFace + +Our model can now be downloaded from HuggingFace 🤗 . To download the model parameters, you can simply specify **--model_name_or_path** as **haitengzhao/gimlet**. Here's an example: + +``` +from model import GraphT5TransformerForConditionalGeneration +model = GraphT5TransformerForConditionalGeneration.from_pretrained("haitengzhao/gimlet") +``` + +#### Method 2: Manual Download + +You can also download pytorch_model.bin from [https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing](https://drive.google.com/file/d/1ROU4SLW2NF9EtT70JC_SHC1OZIPB90id/view?usp=sharing) and move it to **./ckpts/gimlet**. You can do this with the following script: ``` mkdir ckpts @@ -90,9 +130,16 @@ cd .. cd .. ``` +In this case, the **--model_name_or_path** refers to the path of the checkpoint directory, which is **ckpts/gimlet**. + ### Dataset Download +#### Method 1: HuggingFace +Our datasets are available for download on HuggingFace 🤗 . The datasets are downloaded automatically and used through the HuggingFace dataset pipeline when you add the argument **--use_huggingface_pipeline** (a short loading sketch is given below). + +#### Method 2: Manual Download +Alternatively, you can run experiments from the original molecule datasets. In this pipeline, the instruction text is incorporated into the molecule data during the experiments. The MoleculeNet datasets, which comprise pcba, bace, hiv, muv, tox21, toxcast, bbbp, esol, lipo, and freesolv, can be conveniently downloaded automatically upon the first run.
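For Method 1, the data access performed by downstream_test.py is roughly the following. This is a minimal sketch, not a separate API: the `split`, `task_index`, and `text` fields simply follow how the downstream_test.py changes in this patch consume the Hub dataset, and `bace` / task `0` are example choices.

```
from datasets import load_dataset

# Minimal sketch of what --use_huggingface_pipeline does with the Hub dataset
# (field names follow the downstream_test.py changes in this patch).
dataset = load_dataset("haitengzhao/molecule_property_instruction")["bace"]

# Scaffold split and per-task filtering, mirroring downstream_test.py
test_set = dataset.filter(
    lambda ex: ex["split"] == "test" and ex["task_index"] == "0"
)
print(len(test_set))
print(test_set[0]["text"])  # instruction prompt(s) paired with this molecule
```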
Alternatively, you can manually download them by following the script below: ``` @@ -113,9 +160,12 @@ Besides MoleculeNet, we also includes CYP450 which can be downloaded from [https The script to run one downstream task is ``` -CUDA_VISIBLE_DEVICES=0 python downstream_test.py --zero_shot --transformer_backbone gimlet --model_name_or_path ckpts/gimlet --tokenizer_name t5-small --dataset bace --runseed 5 --batch_size 40 --grad_accum_step 1 --transform_in_collator +CUDA_VISIBLE_DEVICES=0 python downstream_test.py --zero_shot --transformer_backbone gimlet --model_name_or_path haitengzhao/gimlet --tokenizer_name t5-small --dataset bace --runseed 5 --batch_size 40 --grad_accum_step 1 --transform_in_collator --only_test --use_huggingface_pipeline ``` +You have the option to include the **--use_huggingface_pipeline** flag to utilize the HuggingFace dataset pipeline. This feature is applicable for both GIMLET and baseline models in downstream scenarios involving zero-shot and few-shot settings. + + To execute all the downstream tasks, you can utilize the script downstream_test.sh. Running this script will generate results that will be written into the file "./cache/testing_$modelname.csv". ``` bash downstream_test.sh $device $backbone $modelname_or_path ($few_shot_number) ($augment_type) @@ -149,7 +199,6 @@ bash downstream_test.sh 0 kvplm_aug ckpt_KV.pt 0 rewrite bash downstream_test.sh 0 momu_aug littlegin=graphclinit_bert=scibert_epoch=299-step=18300.pt 0 rewrite ``` - ## Run Few-Shot Learning You can run few-shot learning for all the downstream tasks by specify the few-shot number: @@ -170,9 +219,25 @@ bash downstream_test.sh 0 momu_fewshot littlegin=graphclinit_bert=scibert_epoch= ## Run the Pretraining -### Pretraining Data +### Run the Pretraining + +To reproduce the pretraining on Chembl and Chembl property datasets, you can run the following command: +``` +CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path t5-small --tokenizer_name t5-small --transformer_backbone gimlet --do_train --train_file haitengzhao/molecule_property_instruction --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new +``` + +You can validate the pretrained model on the splitted Chembl dataset (Chembl Zero Shot): + +``` +CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path ckpts/gimlet_new --tokenizer_name t5-small --transformer_backbone gimlet --do_eval --validation_file haitengzhao/molecule_property_instruction --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new +``` + +You can run your own pretraining by specifying --train_file as your pretraining file, or imply your model into the pipeline. -You can download the pretraining dataset if you want to reproduce the pretraining or train your own model. The Chembl dataset can be downloaded and processed by the following steps: + +### Reproducing the Pretraining Data Generation + +You can reproduce the pretraining dataset generation if you want to imply your own instruction methods. The Chembl dataset can be downloaded and processed by the following steps: ``` cd prompt_data/ @@ -208,6 +273,7 @@ cd .. 
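If you only want the ready-made pretraining splits used by the pretraining command above, rather than regenerating them with the scripts in this section, they can also be loaded directly from the Hub. This is a minimal sketch; the `chembl_pretraining` and `chembl_zero_shot` split names are taken from the pretraining_gimlet.py changes later in this patch.

```
from datasets import load_dataset, DatasetDict

# Ready-made pretraining splits on the Hub (split names follow pretraining_gimlet.py).
dataset_full = load_dataset("haitengzhao/molecule_property_instruction")

raw_datasets = DatasetDict({
    "train": dataset_full["chembl_pretraining"],     # Chembl instruction pretraining data
    "validation": dataset_full["chembl_zero_shot"],  # held-out Chembl assays (Chembl Zero Shot)
})
print(raw_datasets)
```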
Produce the pretraining dataset by the following script: ``` +cd prompts python generate_pretrain_dataset.py --generate_assay_text --generate_mole_text --split_non_overlap --add_negation --use_augmented_prompt ``` @@ -217,35 +283,18 @@ And merge the generated dataset together: python generate_pretrain_dataset_merge.py --merge_file_list assay_graph_text_train_non_overlap_split_0.csv assay_graph_text_detail_train_non_overlap_split_0.csv assay_graph_text_expand_train_non_overlap_split_0.csv assay_graph_text_rewrite_train_non_overlap_split_0.csv assay_graph_text_shorten_train_non_overlap_split_0.csv property_graph_text_negative05_train_non_overlap_split_0.csv property_graph_text_negative05_detail_train_non_overlap_split_0.csv property_graph_text_negative05_expand_train_non_overlap_split_0.csv property_graph_text_negative05_rewrite_train_non_overlap_split_0.csv property_graph_text_negative05_shorten_train_non_overlap_split_0.csv --merge_file_policy custom --merge_file_ratio 1.0 1.0 1.0 1.0 1.0 1.0 0.25 0.25 0.25 0.25 --final_file_name merge_split0.csv ``` -### Run the Pretraining - -After creating the pretraining datasets, you can reproduce the pretraining by yourself: - -``` -CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path t5-small --tokenizer_name t5-small --transformer_backbone gimlet --do_train --train_file pretrain_datasets/merge_split0.csv --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new -``` - -You can validate the pretrained model on the splitted Chembl dataset (Chembl Zero Shot): - -``` -CUDA_VISIBLE_DEVICES=0 python pretraining_gimlet.py --model_name_or_path ckpts/gimlet_new --tokenizer_name t5-small --transformer_backbone gimlet --do_eval --validation_file pretrain_datasets/assay_graph_text_valid_non_overlap_split_0.csv --transform_in_collator --per_device_train_batch_size 64 --gradient_accumulation_steps 1 --per_device_eval_batch_size 200 --line_by_line --loss_reduction_method sentence --save_steps 10000 --output_dir ckpts/gimlet_new -``` - - -You can run your own pretraining by specifying --train_file as your pretraining file, or imply your model into the pipeline. - +In this scenario, the pretraining data is the file "pretrain_datasets/merge_split0.csv". To validate the pretrained model, you can use the data file "pretrain_datasets/assay_graph_text_valid_non_overlap_split_0.csv". To specify these files as the training and validation data, use the arguments **--train_file** and **--validation_file** with their respective file paths. ## Citation -Please cite our paper if you find it helpful. +Please cite our paper if you find it helpful or use our datasets. 
``` -@article{zhao2023gimlet, - title={GIMLET: A Unified Graph-Text Model for Instruction-Based Molecule Zero-Shot Learning}, - author={Zhao, Haiteng and Liu, Shengchao and Ma, Chang and Xu, Hannan and Fu, Jie and Deng, Zhi-Hong and Kong, Lingpeng and Liu, Qi}, - journal={bioRxiv}, - pages={2023--05}, - year={2023}, - publisher={Cold Spring Harbor Laboratory} +@article{zhao2024gimlet, + title={Gimlet: A unified graph-text model for instruction-based molecule zero-shot learning}, + author={Zhao, Haiteng and Liu, Shengchao and Chang, Ma and Xu, Hannan and Fu, Jie and Deng, Zhihong and Kong, Lingpeng and Liu, Qi}, + journal={Advances in Neural Information Processing Systems}, + volume={36}, + year={2024} } ``` diff --git a/basic_pipeline.py b/basic_pipeline.py index 77749d1..7b3fcec 100644 --- a/basic_pipeline.py +++ b/basic_pipeline.py @@ -66,7 +66,7 @@ def eval_result(model, loader,label_dict,tokenizer,task_type,transformer_backbon batch[key] = batch[key].to(model.device) with torch.no_grad(): labels=batch["labels"] - if labels.shape[1]>1: # Yes + if labels.shape[1]>1 and not transformer_backbone in ['kvplm']: # Yes assert all((labels[:,1]==tokenizer.eos_token_id) + (labels[:,1]==id_invalid)) labels=labels[:,0].unsqueeze(1) del batch["labels"] diff --git a/dataloaders/galatica_smiles_collator.py b/dataloaders/galatica_smiles_collator.py index 9656420..e6a7e91 100644 --- a/dataloaders/galatica_smiles_collator.py +++ b/dataloaders/galatica_smiles_collator.py @@ -67,9 +67,10 @@ def torch_call(self, examples): def galactica_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs): data_new = {} + text = examples[text_column_name] if isinstance(examples[text_column_name], str) else examples[text_column_name][0] tokenized_input = tokenizer( # examples[text_column_name]+ ' ', - '[START_I_SMILES]' + examples['graph'] + '[END_I_SMILES]\n\n##Question: ' + examples[text_column_name] + '\n\nAnswer:', + '[START_I_SMILES]' + examples['graph'] + '[END_I_SMILES]\n\n##Question: ' + text + '\n\nAnswer:', padding=padding, truncation=True, max_length=max_seq_length, diff --git a/dataloaders/gpt3_smiles_collator.py b/dataloaders/gpt3_smiles_collator.py index 8facb9a..55975fd 100644 --- a/dataloaders/gpt3_smiles_collator.py +++ b/dataloaders/gpt3_smiles_collator.py @@ -24,8 +24,9 @@ def gpt3_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs): data_new = {} + text = examples[text_column_name] if isinstance(examples[text_column_name], str) else examples[text_column_name][0] tokenized_input = tokenizer( - 'Please answer questions on this molecule. The SMILES of this molecule is:' + examples['graph'] + '\n\n##Question: ' + examples[text_column_name] + '\n\nAnswer:', + 'Please answer questions on this molecule. 
The SMILES of this molecule is:' + examples['graph'] + '\n\n##Question: ' + text + '\n\nAnswer:', padding=padding, truncation=True, max_length=max_seq_length, diff --git a/dataloaders/graph_text_transform.py b/dataloaders/graph_text_transform.py index 5f9a9f1..0f9c02c 100644 --- a/dataloaders/graph_text_transform.py +++ b/dataloaders/graph_text_transform.py @@ -71,7 +71,7 @@ def tokenize_function_gin_T5(examples,tokenizer,text_column_name,padding,max_seq # Remove empty lines # examples[text_column_name] = [line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()] text = tokenizer( - examples[text_column_name], + examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0], padding=padding, truncation=True, max_length=max_seq_length, @@ -107,7 +107,7 @@ def tokenize_function_gimlet(examples, tokenizer, text_column_name, padding, max # Remove empty lines # examples[text_column_name] = [line for line in examples[text_column_name] if len(line) > 0 and not line.isspace()] text = tokenizer( - examples[text_column_name], + examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0], # if examples[text_column_name] is list padding=padding, truncation=True, max_length=max_seq_length, diff --git a/dataloaders/kvplm_smiles_collator.py b/dataloaders/kvplm_smiles_collator.py index 15fab00..299558d 100644 --- a/dataloaders/kvplm_smiles_collator.py +++ b/dataloaders/kvplm_smiles_collator.py @@ -113,9 +113,10 @@ def torch_call(self, examples): def kvplm_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,**kwargs): data_new = {} + text=examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0] tokenized_input = tokenizer( examples['graph'] + ' '+ - examples[text_column_name]+ ' ', + text+ ' ', padding=padding, truncation=True, max_length=max_seq_length, diff --git a/dataloaders/momu_collator.py b/dataloaders/momu_collator.py index 4101e3c..8e08e67 100644 --- a/dataloaders/momu_collator.py +++ b/dataloaders/momu_collator.py @@ -23,8 +23,9 @@ def contrastive_conditional_generation_tokenizer(examples,tokenizer,text_column_name,padding,max_seq_length,rich_features,**kwargs): label_dict={'Yes':[1],'No':[0]} data_new = {} - tokenized_input_pos=tokenizer(examples[text_column_name]+' '+'Yes',truncation=True,max_length=512) - tokenized_input_neg=tokenizer(examples[text_column_name]+' '+'No',truncation=True,max_length=512) + text=examples[text_column_name] if isinstance(examples[text_column_name],str) else examples[text_column_name][0] + tokenized_input_pos=tokenizer(text+' '+'Yes',truncation=True,max_length=512) + tokenized_input_neg=tokenizer(text+' '+'No',truncation=True,max_length=512) # if not transform_in_collator: # examples['graph'] = smiles2graph(examples['graph']) data_new['graph']=examples['graph'] diff --git a/downstream_test.py b/downstream_test.py index 9d7de9d..ccd59b4 100644 --- a/downstream_test.py +++ b/downstream_test.py @@ -17,7 +17,7 @@ from model import get_model from dataloaders import add_prompt_transform_dict,\ graph_text_collator_dict, \ - MoleculeDatasetSplitLabel + MoleculeDatasetSplitLabel,graph_text_tokenizer_dict from transformers import ( AutoTokenizer, @@ -26,6 +26,7 @@ from tqdm import tqdm import os import re +from datasets import load_dataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -39,10 +40,12 @@ parser.add_argument('--disable_tqdm',action='store_true') # about 
dataset and dataloader +parser.add_argument('--use_huggingface_pipeline',action='store_true') parser.add_argument('--dataset', type=str, default='bace') parser.add_argument('--num_workers', type=int, default=0) parser.add_argument('--rich_features',action='store_true') parser.add_argument('--transform_in_collator',action='store_true') +parser.add_argument('--overwrite_data_cache',action='store_true') # about multitask strategies parser.add_argument('--task_policy',type=str,default='traversal', choices=['single','traversal','multi_mixture','multi_label']) @@ -174,7 +177,7 @@ def train(model, loader, optimizer): -def downstream_task_by_transform(transform,model,train_loader,val_loader,test_loader,prompt=''): +def downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt=''): #reload the model parameter if args.few_shot: model = get_model(args, graph_args,tokenizer) @@ -324,65 +327,79 @@ def downstream_task_by_transform(transform,model,train_loader,val_loader,test_lo } tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, **tokenizer_kwargs) - if args.few_shot and args.few_shot_prompt_fashion!='traversal': - - def modify_name(name): - name = name.replace('.ckpt', '.pt') - name=name.replace('ckpts/','') - if name[-1]=='/': - name=name[:-1] - return name - - file_name=os.path.join('cache','result_'+args.few_shot_prompt_fashion+'_prompt_table.csv') - prompts_pd = pd.read_csv(file_name,index_col='unique_task_id') - rename_keys={} - for name in prompts_pd.columns: - rename_keys[name]=modify_name(name) - prompts_pd=prompts_pd.rename(columns=rename_keys) - prompt={} - model_name=modify_name(args.model_name_or_path) - for ind in range(get_num_task(args.dataset)): - if args.dataset + '@' + str(ind) in prompts_pd.index.values: - res=prompts_pd.loc[args.dataset+'@'+str(ind),model_name] - if pd.isna(res): - continue - prompt[str(ind)]=[res] + model=get_model(args,graph_args,tokenizer) - else: - if args.prompt_augmentation=='': - with open(os.path.join("prompts",args.prompt_file), 'r') as load_f: - prompts = commentjson.load(load_f) - prompt=prompts[args.dataset] - else: - with open(os.path.join("prompts",args.prompt_file), 'r') as load_f: - prompts = commentjson.load(load_f) - prompt_all=prompts[args.dataset] + def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + if args.return_model_size: + print('Model size: {}'.format(count_parameters(model))) + + + if not args.use_huggingface_pipeline: + #Load instruction files, and add them for molecule data. 
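+        # In the few-shot case (non-traversal prompt fashion) the prompt for each task is read from the
+        # cached result table under cache/; otherwise prompts come from the json file in prompts/ and are
+        # tokenized below into input_ids / attention_mask for the data transforms.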
+ if args.few_shot and args.few_shot_prompt_fashion!='traversal': + + def modify_name(name): + name = name.replace('.ckpt', '.pt') + name=name.replace('ckpts/','') + if name[-1]=='/': + name=name[:-1] + return name + + file_name=os.path.join('cache','result_'+args.few_shot_prompt_fashion+'_prompt_table.csv') + prompts_pd = pd.read_csv(file_name,index_col='unique_task_id') + rename_keys={} + for name in prompts_pd.columns: + rename_keys[name]=modify_name(name) + prompts_pd=prompts_pd.rename(columns=rename_keys) prompt={} - for key in prompt_all: - if args.prompt_augmentation in prompt_all[key]: - prompt[key]=prompt_all[key][args.prompt_augmentation] - else: - print('label split {} has no augmentation {}'.format(key, args.prompt_augmentation)) - - if isinstance(prompt,list): - prompt_token=tokenizer(prompt,return_special_tokens_mask=True) - input_ids = [item for item in prompt_token.data['input_ids']] - attention_mask = [item for item in prompt_token.data['attention_mask']] - if args.prompt_id is None: - args.prompt_id = list(range(len(prompt))) - elif isinstance(prompt,dict): - prompt_token={} - input_ids={} - attention_mask={} - args.prompt_id={} - for key in prompt.keys(): - if len(prompt[key])>0: - prompt_token[key]=tokenizer(prompt[key],return_special_tokens_mask=True) - input_ids[key] = [item for item in prompt_token[key].data['input_ids']] - attention_mask[key] = [item for item in prompt_token[key].data['attention_mask']] - args.prompt_id[key] = list(range(len(prompt[key]))) + model_name=modify_name(args.model_name_or_path) + for ind in range(get_num_task(args.dataset)): + if args.dataset + '@' + str(ind) in prompts_pd.index.values: + res=prompts_pd.loc[args.dataset+'@'+str(ind),model_name] + if pd.isna(res): + continue + prompt[str(ind)]=[res] + + else: + if args.prompt_augmentation=='': + with open(os.path.join("prompts",args.prompt_file), 'r') as load_f: + prompts = commentjson.load(load_f) + prompt=prompts[args.dataset] + else: + with open(os.path.join("prompts",args.prompt_file), 'r') as load_f: + prompts = commentjson.load(load_f) + prompt_all=prompts[args.dataset] + prompt={} + for key in prompt_all: + if args.prompt_augmentation in prompt_all[key]: + prompt[key]=prompt_all[key][args.prompt_augmentation] + else: + print('label split {} has no augmentation {}'.format(key, args.prompt_augmentation)) + + if isinstance(prompt,list): + prompt_token=tokenizer(prompt,return_special_tokens_mask=True) + input_ids = [item for item in prompt_token.data['input_ids']] + attention_mask = [item for item in prompt_token.data['attention_mask']] + if args.prompt_id is None: + args.prompt_id = list(range(len(prompt))) + elif isinstance(prompt,dict): + prompt_token={} + input_ids={} + attention_mask={} + args.prompt_id={} + for key in prompt.keys(): + if len(prompt[key])>0: + prompt_token[key]=tokenizer(prompt[key],return_special_tokens_mask=True) + input_ids[key] = [item for item in prompt_token[key].data['input_ids']] + attention_mask[key] = [item for item in prompt_token[key].data['attention_mask']] + args.prompt_id[key] = list(range(len(prompt[key]))) + else: + raise ValueError('Prompt type not supported. Only list or dict of (list of) prompts are supported.') + else: - raise ValueError('Prompt type not supported. Only list or dict of (list of) prompts are supported.') + print('Using huggingface pipeline. 
Prompt file not loaded.') label_ignore = [-100] raw_label = {1: 'Yes', 0: 'No', 'invalid': label_ignore} @@ -393,168 +410,311 @@ def modify_name(name): # Bunch of classification tasks num_tasks = get_num_task(args.dataset) - dataset_folder = 'property_data/' - if args.transformer_backbone in ['kvplm', 'galactica','gpt3']: - dataset = MoleculeDatasetSplitLabel(root=dataset_folder, name=args.dataset,return_smiles=True,split_label=args.split_label,single_split=args.single_split,rich_features=args.rich_features) - else: - dataset = MoleculeDatasetSplitLabel(root=dataset_folder, name=args.dataset,split_label=args.split_label,single_split=args.single_split,rich_features=args.rich_features) - - print(dataset) - print(dataset[0]) - - - if args.split == 'scaffold': - # if args.single_split is not None: - smiles_list = pd.read_csv(dataset_folder + args.dataset + '/processed/smiles.csv', - header=None)[0].tolist() - train_index, valid_index, test_index = scaffold_split( - torch.arange(len(smiles_list)), smiles_list, null_value=0, frac_train=0.8, - frac_valid=0.1, frac_test=0.1) - - train_index_total=[] - valid_index_total=[] - test_index_total=[] - for times in range(dataset.label_number): - train_index_times=train_index+times*dataset.len_oridata() - valid_index_times = valid_index + times * dataset.len_oridata() - test_index_times = test_index + times * dataset.len_oridata() - - train_index_total.append(train_index_times) - valid_index_total.append(valid_index_times) - test_index_total.append(test_index_times) - train_index_total=torch.cat(train_index_total,0) - valid_index_total=torch.cat(valid_index_total,0) - test_index_total=torch.cat(test_index_total,0) - - train_dataset = dataset[train_index_total] - valid_dataset = dataset[valid_index_total] - test_dataset = dataset[test_index_total] - - print('split via scaffold') - elif args.split == 'random': - train_dataset, valid_dataset, test_dataset = random_split( - dataset, null_value=0, frac_train=0.8, frac_valid=0.1, - frac_test=0.1, seed=args.seed) - print('randomly split') - elif args.split == 'random_scaffold': - smiles_list = pd.read_csv(dataset_folder + args.dataset + '/processed/smiles.csv', - header=None)[0].tolist() - train_dataset, valid_dataset, test_dataset = random_scaffold_split( - dataset, smiles_list, null_value=0, frac_train=0.8, - frac_valid=0.1, frac_test=0.1, seed=args.seed) - print('random scaffold') - else: - raise ValueError('Invalid split option.') - print(train_dataset[0]) + if not args.use_huggingface_pipeline: + # Loading Molecule Dataset + dataset_folder = 'property_data/' - data_collator = graph_text_collator_dict[args.transformer_backbone]( - tokenizer=tokenizer, - transform_in_collator=args.transform_in_collator, - rich_features=args.rich_features) + if args.transformer_backbone in ['kvplm', 'galactica','gpt3']: + dataset = MoleculeDatasetSplitLabel(root=dataset_folder, name=args.dataset,return_smiles=True,split_label=args.split_label,single_split=args.single_split,rich_features=args.rich_features) + else: + dataset = MoleculeDatasetSplitLabel(root=dataset_folder, name=args.dataset,split_label=args.split_label,single_split=args.single_split,rich_features=args.rich_features) - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, - shuffle=True, num_workers=args.num_workers,collate_fn=data_collator) - val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, - shuffle=False, num_workers=args.num_workers,collate_fn=data_collator) - test_loader = DataLoader(test_dataset, batch_size=args.batch_size, 
- shuffle=False, num_workers=args.num_workers,collate_fn=data_collator) + print(dataset) + print(dataset[0]) - model=get_model(args,graph_args,tokenizer) + if args.split == 'scaffold': + # if args.single_split is not None: + smiles_list = pd.read_csv(dataset_folder + args.dataset + '/processed/smiles.csv', + header=None)[0].tolist() + train_index, valid_index, test_index = scaffold_split( + torch.arange(len(smiles_list)), smiles_list, null_value=0, frac_train=0.8, + frac_valid=0.1, frac_test=0.1) - def count_parameters(model): - return sum(p.numel() for p in model.parameters() if p.requires_grad) + train_index_total=[] + valid_index_total=[] + test_index_total=[] + for times in range(dataset.label_number): + train_index_times=train_index+times*dataset.len_oridata() + valid_index_times = valid_index + times * dataset.len_oridata() + test_index_times = test_index + times * dataset.len_oridata() + + train_index_total.append(train_index_times) + valid_index_total.append(valid_index_times) + test_index_total.append(test_index_times) + train_index_total=torch.cat(train_index_total,0) + valid_index_total=torch.cat(valid_index_total,0) + test_index_total=torch.cat(test_index_total,0) + + train_dataset = dataset[train_index_total] + valid_dataset = dataset[valid_index_total] + test_dataset = dataset[test_index_total] + + print('split via scaffold') + elif args.split == 'random': + train_dataset, valid_dataset, test_dataset = random_split( + dataset, null_value=0, frac_train=0.8, frac_valid=0.1, + frac_test=0.1, seed=args.seed) + print('randomly split') + elif args.split == 'random_scaffold': + smiles_list = pd.read_csv(dataset_folder + args.dataset + '/processed/smiles.csv', + header=None)[0].tolist() + train_dataset, valid_dataset, test_dataset = random_scaffold_split( + dataset, smiles_list, null_value=0, frac_train=0.8, + frac_valid=0.1, frac_test=0.1, seed=args.seed) + print('random scaffold') + else: + raise ValueError('Invalid split option.') + print(train_dataset[0]) - if args.return_model_size: - print('Model size: {}'.format(count_parameters(model))) + data_collator = graph_text_collator_dict[args.transformer_backbone]( + tokenizer=tokenizer, + transform_in_collator=args.transform_in_collator, + rich_features=args.rich_features) + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers,collate_fn=data_collator) + val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers,collate_fn=data_collator) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers,collate_fn=data_collator) - if args.task_policy =='traversal': - recurrent_range=range(num_tasks) - elif args.task_policy =='single': - recurrent_range = [args.single_split] else: - raise ValueError('prompt_policy not implemented yet') + # Loading Huggingface Dataset + dataset = load_dataset("haitengzhao/molecule_property_instruction", + # download_mode = "force_redownload" + )[args.dataset] + - if args.not_retest_tasks_in_result_file: - if os.path.exists(args.output_result_to_file): - result_file=pd.read_csv(args.output_result_to_file,header=0,index_col=0) + print(dataset) + print(dataset[0]) + + if args.split == 'scaffold': + train_dataset_total = dataset.filter(lambda example: (example["split"] == 'train')) + valid_dataset_total = dataset.filter(lambda example: (example["split"] == 'valid')) + test_dataset_total = dataset.filter(lambda example: (example["split"] == 
'test')) else: - result_file=None - - for single_split_label in recurrent_range: - if args.task_policy in ['traversal','single']: - print('label split: ',single_split_label) - if not str(single_split_label) in prompt: - print('No prompt for label split {}'.format(single_split_label)) - continue - if args.not_retest_tasks_in_result_file and result_file is not None: - if len(result_file[(result_file['dataset']==args.dataset) & (result_file['split']==single_split_label)])>0: - print(args.dataset,' ',single_split_label,'has been tested') - continue + raise ValueError('Not implied split option for huggingface pipeline.') - train_loader.dataset.set_single_split(single_split_label) - val_loader.dataset.set_single_split(single_split_label) - test_loader.dataset.set_single_split(single_split_label) - - dataset.set_single_split(single_split_label) - if args.few_shot is not None: - ind_each_class = {} - for ind in train_index_total: - label=int(dataset[ind].y) - if label not in ind_each_class: - ind_each_class[label]=[ind] - else: - ind_each_class[label].append(ind) + def select_single_prompt(example, prompt_id): + example["text"] = example["text"][prompt_id] + return example - for key in ind_each_class.keys(): - ind_each_class[key]=np.random.choice(ind_each_class[key], size=min(len(ind_each_class[key]),args.few_shot),replace=False).tolist() - train_index_total=[] - for key in ind_each_class.keys(): - train_index_total+=ind_each_class[key] - train_dataset = dataset[train_index_total] - train_loader = DataLoader(train_dataset, batch_size=args.batch_size, - shuffle=True, num_workers=args.num_workers, collate_fn=data_collator) - train_loader.dataset.set_single_split(single_split_label) + tokenize_function = lambda x: graph_text_tokenizer_dict[args.transformer_backbone](examples=x, + tokenizer=tokenizer, + text_column_name='text', + padding=False, + max_seq_length=None, + rich_features=args.rich_features, + transform_in_collator=( + args.transform_in_collator)) - if args.prompt_policy == 'single': - print(prompt[args.prompt_id[0]]) + data_collator = graph_text_collator_dict[args.transformer_backbone]( + tokenizer=tokenizer, + transform_in_collator=args.transform_in_collator, + rich_features=args.rich_features) - #add prompt to graph data by data transform - transform=lambda x: add_prompt_transform_dict[args.transformer_backbone]( - data=x,data_label=x.y,input_ids=input_ids[args.prompt_id[0]], - attention_mask=attention_mask[args.prompt_id[0]],label_dict=label_dict, - rich_features=args.rich_features,transform_in_collator=args.transform_in_collator, - raw_prompts=prompt[args.prompt_id[0]],raw_label=raw_label,tokenizer=tokenizer, - generaltive_label=(task_type(args.dataset)=='reg')) - train_loader.dataset.transform = transform - val_loader.dataset.transform = transform - test_loader.dataset.transform = transform + if not args.use_huggingface_pipeline: #Different pre-processing for the two types of pipelines. 
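+            # Original (non-HuggingFace) pipeline: iterate over label splits and attach the prompts
+            # tokenized above to each molecule via add_prompt_transform_dict as a dataset transform.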
- downstream_task_by_transform(transform,model,train_loader,val_loader,test_loader,prompt[args.prompt_id[0]]) + if args.task_policy =='traversal': + recurrent_range=range(num_tasks) + elif args.task_policy =='single': + recurrent_range = [args.single_split] + else: + raise ValueError('prompt_policy not implemented yet') - elif args.prompt_policy == 'traversal': - for prompt_id in args.prompt_id[str(single_split_label)]: - print(prompt[str(single_split_label)][prompt_id]) + if args.not_retest_tasks_in_result_file: + if os.path.exists(args.output_result_to_file): + result_file=pd.read_csv(args.output_result_to_file,header=0,index_col=0) + else: + result_file=None + + for single_split_label in recurrent_range: + if args.task_policy in ['traversal','single']: + print('label split: ',single_split_label) + if not str(single_split_label) in prompt: + print('No prompt for label split {}'.format(single_split_label)) + continue + if args.not_retest_tasks_in_result_file and result_file is not None: + if len(result_file[(result_file['dataset']==args.dataset) & (result_file['split']==single_split_label)])>0: + print(args.dataset,' ',single_split_label,'has been tested') + continue + train_loader.dataset.set_single_split(single_split_label) + val_loader.dataset.set_single_split(single_split_label) + test_loader.dataset.set_single_split(single_split_label) + + dataset.set_single_split(single_split_label) + if args.few_shot is not None: + ind_each_class = {} + for ind in train_index_total: + label=int(dataset[ind].y) + if label not in ind_each_class: + ind_each_class[label]=[ind] + else: + ind_each_class[label].append(ind) + + for key in ind_each_class.keys(): + ind_each_class[key]=np.random.choice(ind_each_class[key], size=min(len(ind_each_class[key]),args.few_shot),replace=False).tolist() + train_index_total=[] + for key in ind_each_class.keys(): + train_index_total+=ind_each_class[key] + + train_dataset = dataset[train_index_total] + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers, collate_fn=data_collator) + train_loader.dataset.set_single_split(single_split_label) + + if args.prompt_policy == 'single': + print(prompt[args.prompt_id[0]]) + + #add prompt to graph data by data transform transform=lambda x: add_prompt_transform_dict[args.transformer_backbone]( - data=x,data_label=x.y,input_ids=input_ids[str(single_split_label)][prompt_id], - attention_mask=attention_mask[str(single_split_label)][prompt_id],label_dict=label_dict, - rich_features=args.rich_features,transform_in_collator=args.transform_in_collator, - raw_prompts=prompt[str(single_split_label)][prompt_id],raw_label=raw_label,tokenizer=tokenizer, + data=x,data_label=x.y,input_ids=input_ids[args.prompt_id[0]], + attention_mask=attention_mask[args.prompt_id[0]],label_dict=label_dict, + rich_features=args.rich_features,transform_in_collator=args.transform_in_collator, + raw_prompts=prompt[args.prompt_id[0]],raw_label=raw_label,tokenizer=tokenizer, generaltive_label=(task_type(args.dataset)=='reg')) train_loader.dataset.transform = transform val_loader.dataset.transform = transform test_loader.dataset.transform = transform - downstream_task_by_transform(transform,model,train_loader,val_loader,test_loader,prompt[str(single_split_label)][prompt_id]) + downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt[args.prompt_id[0]]) + elif args.prompt_policy == 'traversal': + for prompt_id in args.prompt_id[str(single_split_label)]: + 
print(prompt[str(single_split_label)][prompt_id]) + + transform=lambda x: add_prompt_transform_dict[args.transformer_backbone]( + data=x,data_label=x.y,input_ids=input_ids[str(single_split_label)][prompt_id], + attention_mask=attention_mask[str(single_split_label)][prompt_id],label_dict=label_dict, + rich_features=args.rich_features,transform_in_collator=args.transform_in_collator, + raw_prompts=prompt[str(single_split_label)][prompt_id],raw_label=raw_label,tokenizer=tokenizer, + generaltive_label=(task_type(args.dataset)=='reg')) + + train_loader.dataset.transform = transform + val_loader.dataset.transform = transform + test_loader.dataset.transform = transform + + downstream_task_by_transform(model,train_loader,val_loader,test_loader,prompt[str(single_split_label)][prompt_id]) + + else: + raise ValueError('prompt_policy not implemented yet') + + else: #Huggingface pipelie + + if args.task_policy == 'traversal': + recurrent_range = range(num_tasks) + elif args.task_policy == 'single': + recurrent_range = [args.single_split] else: raise ValueError('prompt_policy not implemented yet') + if args.not_retest_tasks_in_result_file: + if os.path.exists(args.output_result_to_file): + result_file = pd.read_csv(args.output_result_to_file, header=0, index_col=0) + else: + result_file = None + + for single_split_label in recurrent_range: + if args.task_policy in ['traversal', 'single']: + print('label split: ', single_split_label) + + if args.not_retest_tasks_in_result_file and result_file is not None: + if len(result_file[ + (result_file['dataset'] == args.dataset) & ( + result_file['split'] == single_split_label)]) > 0: + print(args.dataset, ' ', single_split_label, 'has been tested') + continue + + train_dataset_task = train_dataset_total.filter( + lambda example: (example["task_index"] == str(single_split_label))) + valid_dataset_task = valid_dataset_total.filter( + lambda example: (example["task_index"] == str(single_split_label))) + test_dataset_task = test_dataset_total.filter(lambda example: (example["task_index"] == str(single_split_label))) + + if len(test_dataset_task) == 0: + print('No label or prompt for label split {}'.format(single_split_label)) + continue + + if args.prompt_policy == 'single': + # print() + prompt_id_range = [args.prompt_id[0]] + elif args.prompt_policy == 'traversal': + if args.prompt_id is None: + prompt_id_range = range(len(train_dataset_task[0]['text'])) + else: + prompt_id_range = args.prompt_id[str(single_split_label)] + else: + raise ValueError('prompt_policy not implemented yet') + + for prompt_id in prompt_id_range: + + train_dataset = train_dataset_task.map(lambda example: select_single_prompt(example, prompt_id)) + valid_dataset = valid_dataset_task.map(lambda example: select_single_prompt(example, prompt_id)) + test_dataset = test_dataset_task.map(lambda example: select_single_prompt(example, prompt_id)) + + prompt = train_dataset[0]['text'] + print(prompt) + + train_dataset = train_dataset.map( + tokenize_function, + batched=False, + num_proc=None, + remove_columns=['text'], + load_from_cache_file=not args.overwrite_data_cache, + desc="Running tokenizer on dataset line_by_line", + ) + valid_dataset = valid_dataset.map( + tokenize_function, + batched=False, + num_proc=None, + remove_columns=['text'], + load_from_cache_file=not args.overwrite_data_cache, + desc="Running tokenizer on dataset line_by_line", + ) + test_dataset = test_dataset.map( + tokenize_function, + batched=False, + num_proc=None, + remove_columns=['text'], + load_from_cache_file=not 
args.overwrite_data_cache, + desc="Running tokenizer on dataset line_by_line", + ) + + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers, collate_fn=data_collator) + val_loader = DataLoader(valid_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers, collate_fn=data_collator) + test_loader = DataLoader(test_dataset, batch_size=args.batch_size, + shuffle=False, num_workers=args.num_workers, collate_fn=data_collator) + + if args.few_shot is not None: + ind_each_class = {} + for ind, data in enumerate(train_dataset): + label = data['label'] + if label not in ind_each_class: + ind_each_class[label] = [ind] + else: + ind_each_class[label].append(ind) + + for key in ind_each_class.keys(): + ind_each_class[key] = np.random.choice(ind_each_class[key], + size=min(len(ind_each_class[key]), args.few_shot), + replace=False).tolist() + train_index_total = [] + for key in ind_each_class.keys(): + train_index_total += ind_each_class[key] + + train_dataset = train_dataset.select(train_index_total) + train_loader = DataLoader(train_dataset, batch_size=args.batch_size, + shuffle=True, num_workers=args.num_workers, collate_fn=data_collator) + + downstream_task_by_transform(model, train_loader, val_loader, test_loader, + prompt) diff --git a/model/GIMLET/GIMLETTransformerForConditionalGeneration.py b/model/GIMLET/GIMLETTransformerForConditionalGeneration.py index 2c436b2..7c96a81 100644 --- a/model/GIMLET/GIMLETTransformerForConditionalGeneration.py +++ b/model/GIMLET/GIMLETTransformerForConditionalGeneration.py @@ -7,6 +7,7 @@ BaseModelOutput, Seq2SeqLMOutput, ) +from transformers import AutoConfig,PretrainedConfig from model.GIMLET.GIMLETEncoderStack import GraphT5EncoderStack_dict import copy logger = logging.get_logger(__name__) @@ -33,7 +34,7 @@ class GraphT5TransformerForConditionalGeneration(T5ForConditionalGeneration): r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", ] - def __init__(self, config,graph_args): + def __init__(self, config,graph_args=None): #for debug # config.dropout_rate=0.0 @@ -44,6 +45,13 @@ def __init__(self, config,graph_args): encoder_config.is_decoder = False encoder_config.use_cache = False encoder_config.is_encoder_decoder = False + + if graph_args is None: + assert hasattr(config,'graph_args') + graph_args= PretrainedConfig.from_dict(config.graph_args) + else: + config.graph_args = vars(graph_args) + self.config.loss_reduction_method = getattr(graph_args,'loss_reduction_method') self.encoder = GraphT5EncoderStack_dict[graph_args.transformer_backbone]\ (encoder_config,graph_args, self.shared) diff --git a/model/__init__.py b/model/__init__.py index 211e7d5..3da6dff 100644 --- a/model/__init__.py +++ b/model/__init__.py @@ -21,26 +21,36 @@ def get_model(args,graph_args,tokenizer): if not (args.transformer_backbone in ['kvplm','momu','galactica','gpt3']): - config_kwargs = { - "cache_dir": None, - "revision": 'main', - "use_auth_token": None, - } - config = AutoConfig.from_pretrained(args.tokenizer_name, **config_kwargs) - config.vocab_size=len(tokenizer) - graph_args.transformer_backbone = args.transformer_backbone - model = GraphTransformer_dict[args.transformer_backbone].from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - graph_args=graph_args, - cache_dir=None, - revision='main', - use_auth_token=None, - ignore_mismatched_sizes=True, - ) - model.resize_token_embeddings(len(tokenizer)) + if 
args.model_name_or_path=='haitengzhao/gimlet': + model = GraphTransformer_dict[args.transformer_backbone].from_pretrained( + args.model_name_or_path, + ) + + else: #load from local file: + config_kwargs = { + "cache_dir": None, + "revision": 'main', + "use_auth_token": None, + } + config = AutoConfig.from_pretrained(args.tokenizer_name, **config_kwargs) + config.vocab_size=len(tokenizer) + graph_args.transformer_backbone = args.transformer_backbone + config.graph_args = vars(graph_args) #use the user-provided graph args + + model = GraphTransformer_dict[args.transformer_backbone].from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + # graph_args=graph_args, + cache_dir=None, + revision='main', + use_auth_token=None, + ignore_mismatched_sizes=True, + ) + model.resize_token_embeddings(len(tokenizer)) + + elif args.transformer_backbone == 'kvplm': model = GraphTransformer_dict[args.transformer_backbone](graph_args) elif args.transformer_backbone == 'momu': diff --git a/pretraining_gimlet.py b/pretraining_gimlet.py index cc0d0d4..0dfa576 100644 --- a/pretraining_gimlet.py +++ b/pretraining_gimlet.py @@ -5,7 +5,7 @@ from itertools import chain from typing import Optional import datasets -from datasets import load_dataset +from datasets import load_dataset,DatasetDict import transformers from transformers import ( CONFIG_MAPPING, @@ -185,18 +185,18 @@ class DataTrainingArguments: transform_in_collator: Optional[bool] = field(default=False) wrap_dataset: Optional[bool] = field(default=False) - def __post_init__(self): - if self.dataset_name is None and self.train_file is None and self.validation_file is None: - raise ValueError("Need either a dataset name or a training/validation file.") - else: - if self.train_file is not None: - extension = self.train_file.split(".")[-1] - if extension not in ["csv", "json", "txt"]: - raise ValueError("`train_file` should be a csv, a json or a txt file.") - if self.validation_file is not None: - extension = self.validation_file.split(".")[-1] - if extension not in ["csv", "json", "txt"]: - raise ValueError("`validation_file` should be a csv, a json or a txt file.") + # def __post_init__(self): + # if self.dataset_name is None and self.train_file is None and self.validation_file is None: + # raise ValueError("Need either a dataset name or a training/validation file.") + # else: + # if self.train_file is not None: + # extension = self.train_file.split(".")[-1] + # if extension not in ["csv", "json", "txt"]: + # raise ValueError("`train_file` should be a csv, a json or a txt file.") + # if self.validation_file is not None: + # extension = self.validation_file.split(".")[-1] + # if extension not in ["csv", "json", "txt"]: + # raise ValueError("`validation_file` should be a csv, a json or a txt file.") # def eval_result(trainer,task_type='cla'): # @@ -416,29 +416,39 @@ def main(): # Set seed before initializing model. set_seed(training_args.seed) - if data_args.dataset_name is not None: - # Downloading and loading a dataset from the hub. 
- raw_datasets = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - if "validation" not in raw_datasets.keys(): - raw_datasets["validation"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[:{data_args.validation_split_percentage}%]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) - raw_datasets["train"] = load_dataset( - data_args.dataset_name, - data_args.dataset_config_name, - split=f"train[{data_args.validation_split_percentage}%:]", - cache_dir=model_args.cache_dir, - use_auth_token=True if model_args.use_auth_token else None, - ) + # if data_args.dataset_name is not None: + # # Downloading and loading a dataset from the hub. + # raw_datasets = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # cache_dir=model_args.cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + # if "validation" not in raw_datasets.keys(): + # raw_datasets["validation"] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=f"train[:{data_args.validation_split_percentage}%]", + # cache_dir=model_args.cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + # raw_datasets["train"] = load_dataset( + # data_args.dataset_name, + # data_args.dataset_config_name, + # split=f"train[{data_args.validation_split_percentage}%:]", + # cache_dir=model_args.cache_dir, + # use_auth_token=True if model_args.use_auth_token else None, + # ) + if data_args.train_file=='haitengzhao/molecule_property_instruction' or data_args.validation_file=='haitengzhao/molecule_property_instruction': + raw_datasets={} + dataset_full=load_dataset("haitengzhao/molecule_property_instruction", + # download_mode = "force_redownload" + ) + if training_args.do_train: + raw_datasets['train']=dataset_full['chembl_pretraining'] + if training_args.do_eval: + raw_datasets['validation']=dataset_full['chembl_zero_shot'] + raw_datasets=DatasetDict(raw_datasets) else: data_files = {} if data_args.train_file is not None: