Skip to content

Commit

Permalink
upmodified the msa selection procedure from merge-select to merge_par…
Browse files Browse the repository at this point in the history
…t-select-merge.
  • Loading branch information
Zimiao1025 committed Jun 25, 2024
1 parent d02e13b commit b9a2c4c
Showing 7 changed files with 392 additions and 146 deletions.
28 changes: 18 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
- Flower for monitoring the Celery tasks

## Introduction
AIRFold is
AIRFold is

## Quick Start

@@ -56,6 +56,23 @@ Please follow these steps:
* [PDB70](http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/hhsuite_dbs/),
* [PDB (MMCIF format)](https://www.rcsb.org/)

**Data structure**

```tree
├── model_params (models and parameters for AlphaFold2, RoseTTAFold2, ect.)
├── bfd
├── blast_dbs
├── JGIclust
├── metaclust
├── mgnify
├── pdb70
├── pdb_mmcif
├── small_bfd
├── uniclust30
├── uniref30
└── uniref90
```


### Third-party tools

@@ -108,12 +125,3 @@ Please follow these steps:
```


## Reference


- If you’re using **AlphaFold**, please also cite: <br />
Jumper et al. "Highly accurate protein structure prediction with AlphaFold." <br />
Nature (2021) doi: [10.1038/s41586-021-03819-2](https://doi.org/10.1038/s41586-021-03819-2)
- If you are using **RoseTTAFold**, please also cite: <br />
Minkyung et al. "Accurate prediction of protein structures and interactions using a three-track neural network." <br />
Science (2021) doi: [10.1126/science.abj8754](https://doi.org/10.1126/science.abj8754)
2 changes: 1 addition & 1 deletion gui/stats.html

Large diffs are not rendered by default.

38 changes: 36 additions & 2 deletions lib/pathtree.py
Original file line number Diff line number Diff line change
@@ -103,9 +103,42 @@ class SearchPathTree(BasePathTree):
@property
def integrated_search_a3m(self) -> Path:
return self.root / "intergrated_a3m" / f"{self.id}.a3m"
# integrate part
@property
def integrated_search_hj_a3m(self) -> Path:
return self.root / "intergrated_a3m" / f"{self.id}_hj.a3m"
@property
def integrated_search_bl_a3m(self) -> Path:
return self.root / "intergrated_a3m" / f"{self.id}_bl.a3m"
@property
def integrated_search_dq_a3m(self) -> Path:
return self.root / "intergrated_a3m" / f"{self.id}_dq.a3m"
@property
def integrated_search_dm_a3m(self) -> Path:
return self.root / "intergrated_a3m" / f"{self.id}_dm.a3m"
@property
def integrated_search_mm_a3m(self) -> Path:
return self.root / "intergrated_a3m" / f"{self.id}_mm.a3m"

@property
def integrated_search_a3m_dp(self) -> Path:
return self.root / "intergrated_a3m_dp" / f"{self.id}.a3m"
# integrate part duplicated
@property
def integrated_search_hj_a3m_dp(self) -> Path:
return self.root / "intergrated_a3m_dp" / f"{self.id}_hj.a3m"
@property
def integrated_search_bl_a3m_dp(self) -> Path:
return self.root / "intergrated_a3m_dp" / f"{self.id}_bl.a3m"
@property
def integrated_search_dq_a3m_dp(self) -> Path:
return self.root / "intergrated_a3m_dp" / f"{self.id}_dq.a3m"
@property
def integrated_search_dm_a3m_dp(self) -> Path:
return self.root / "intergrated_a3m_dp" / f"{self.id}_dm.a3m"
@property
def integrated_search_mm_a3m_dp(self) -> Path:
return self.root / "intergrated_a3m_dp" / f"{self.id}_mm.a3m"

@property
def integrated_search_fa(self) -> Path:
@@ -295,8 +328,9 @@ def parse_strgy(s_: Dict[str, Any]):
# print(p_)
for method_ in str_dict.keys():
# print(p_)
p_ = p_ / (method_[:5] + parse_strgy(str_dict[method_]))
path_l.append(p_ / f"{self.id}.a3m")
p_ = p_ / (method_[:5] + parse_strgy(str_dict[method_]["least_seqs"]))
# path_l.append(p_ / f"{self.id}.a3m")
path_l.append(p_ / f"{self.id}")
# return the list of
return path_l

2 changes: 1 addition & 1 deletion lib/strategy/seq_entropy.py
Original file line number Diff line number Diff line change
@@ -130,7 +130,7 @@ def _run(fasta_dir, strategy_dir, reduce_ratio, least_seqs):
# a3m_dir,strategy_dir,seq_id,cov_id,sample,rm_tmp_files=False
parser.add_argument("-i", "--input_a3m_path", required=True, type=str)
parser.add_argument("-o", "--output_a3m_path", required=True, type=str)
parser.add_argument("-r", "--reduce_ratio", required=True, type=float)
parser.add_argument("-r", "--reduce_ratio", default=0.1, type=float)
parser.add_argument("-l", "--least_seqs", required=True, type=int)
# parser.add_argument("--cid", default=0.8)
# parser.add_argument("--rm", default=False)
160 changes: 126 additions & 34 deletions services/alphafold/worker.py
Original file line number Diff line number Diff line change
@@ -50,8 +50,7 @@ def alphafoldTask(requests: List[Dict[str, Any]]):
TemplateSearchRunner(requests=requests)()
TemplateFeaturizationRunner(requests=requests)()
TPLTSelectRunner(requests=requests)()
MonoFeatureRunner(requests=requests)()
MonoStructureRunner(requests=requests)()
AlphaStrucRunner(requests=requests)()
AmberRelaxationRunner(requests=requests)()


@@ -284,9 +283,10 @@ def run(self):
# get msa_path
str_dict = misc.safe_get(self.requests[0], ["run_config", "msa_select"])
key_list = list(str_dict.keys())
for index in range(len(key_list)):
selected_msa_path = ptree.strategy.strategy_list[index]
msa_paths = [str(selected_msa_path)]
msa_paths = []
for idx in range(len(key_list)):
selected_msa_path = ptree.strategy.strategy_list[idx] + "_dp.a3m"
msa_paths.append(str(selected_msa_path))

# get selected_template_feat
selected_template_feat_path = str(ptree.alphafold.selected_template_feat)
@@ -363,15 +363,8 @@ def run(self):
for model_name in models:
fea_path = str(ptree.alphafold.processed_feat) + f"_{model_name}.pkl"
processed_feature = dtool.read_pickle(fea_path)
os.remove(fea_path)

self.output_path = (
os.path.join(
str(ptree.alphafold.root),
model_name,
)
+ "_unrelaxed.pdb"
)

argument_dict = {
"target_name": self.target_name,
"processed_feature": processed_feature,
@@ -386,15 +379,17 @@ def run(self):
argument_dict[k] = v

out_preffix = str(os.path.join(str(ptree.alphafold.root), model_name))
try:
pdb_output = alphafold_func(run_stage="predict_structure",
output_path=out_preffix,
argument_dict=argument_dict
)
self.output_paths.append(pdb_output)
except TimeoutError as exc:
logger.exception(exc)
return False
out_path = str(os.path.join(str(ptree.alphafold.root), model_name)) + "_unrelaxed.pdb"
if not os.path.exists(out_path):
try:
pdb_output = alphafold_func(run_stage="predict_structure",
output_path=out_preffix,
argument_dict=argument_dict
)
self.output_paths.append(pdb_output)
except TimeoutError as exc:
logger.exception(exc)
return False


def on_run_end(self):
@@ -412,6 +407,101 @@ def on_run_end(self):
)


class AlphaStrucRunner(BaseRunner):
def __init__(
self,
requests: List[Dict[str, Any]]
) -> None:
super().__init__(requests)
self.error_code = State.STRUCTURE_ERROR
self.success_code = State.STRUCTURE_SUCCESS
self.start_code = State.STRUCTURE_START
self.sequence = self.requests[0][SEQUENCE]
self.target_name = self.requests[0][TARGET]

@property
def start_stage(self) -> State:
return self.start_code

def run(self):
ptree = get_pathtree(request=self.requests[0])
# get msa_path
str_dict = misc.safe_get(self.requests[0], ["run_config", "msa_select"])
key_list = list(str_dict.keys())
for index in range(len(key_list)):
selected_msa_path = ptree.strategy.strategy_list[index]
msa_paths = [str(selected_msa_path)]

# get selected_template_feat
selected_template_feat_path = str(ptree.alphafold.selected_template_feat)

af2_config = self.requests[0]["run_config"]["structure_prediction"]["alphafold"]
models = af2_config["model_name"].split(",")
random_seed = af2_config.get("random_seed", 0)

self.output_paths = []
for model_name in models:
fea_output_path = str(ptree.alphafold.processed_feat) + f"_{model_name}.pkl"
template_feat = dtool.read_pickle(selected_template_feat_path)
argument_dict1 = {
"sequence": self.sequence,
"target_name": self.target_name,
"msa_paths": msa_paths,
"template_feature": template_feat,
"model_name": model_name,
"random_seed": random_seed,
}
argument_dict1 = deepcopy(argument_dict1)
for k, v in af2_config.items():
if k not in argument_dict1:
argument_dict1[k] = v

processed_feature = alphafold_func(run_stage="monomer_msa2feature",
output_path=fea_output_path,
argument_dict=argument_dict1
)

argument_dict2 = {
"target_name": self.target_name,
"processed_feature": processed_feature,
"model_name": model_name,
"data_dir": str(AF_PARAMS_ROOT),
"random_seed": random_seed,
"return_representations": True,
}
argument_dict2 = deepcopy(argument_dict2)
for k, v in af2_config.items():
if k not in argument_dict2:
argument_dict2[k] = v

out_preffix = str(os.path.join(str(ptree.alphafold.root), model_name))
out_path = str(os.path.join(str(ptree.alphafold.root), model_name)) + "_unrelaxed.pdb"
if not os.path.exists(out_path):
try:
pdb_output = alphafold_func(run_stage="predict_structure",
output_path=out_preffix,
argument_dict=argument_dict2
)
self.output_paths.append(pdb_output)
except TimeoutError as exc:
logger.exception(exc)
return False


def on_run_end(self):
if self.info_reportor is not None:
for request in self.requests:
if all([Path(p).exists() for p in self.output_paths]):
self.info_reportor.update_state(
hash_id=request[info_report.HASH_ID],
state=self.success_code,
)
else:
self.info_reportor.update_state(
hash_id=request[info_report.HASH_ID],
state=self.error_code,
)

class AmberRelaxationRunner(BaseRunner):
def __init__(
self,
@@ -433,17 +523,19 @@ def run(self):
self.output_paths = []
for model_name in models:
input_path = str(os.path.join(str(ptree.alphafold.root), model_name)) + "_unrelaxed.pdb"
unrelaxed_pdb_str = dtool.read_text_file(input_path)
output_path = str(os.path.join(str(ptree.alphafold.root), model_name)) + "_relaxed.pdb"
argument_dict = {"unrelaxed_pdb_str": input_path}
try:
relaxed_pdb_path = alphafold_func(run_stage="run_relaxation",
output_path=output_path,
argument_dict=argument_dict
)
self.output_paths.append(relaxed_pdb_path)
except TimeoutError as exc:
logger.exception(exc)
return False
argument_dict = {"unrelaxed_pdb_str": unrelaxed_pdb_str}
if not os.path.exists(output_path):
try:
relaxed_pdb_path = alphafold_func(run_stage="run_relaxation",
output_path=output_path,
argument_dict=argument_dict
)
self.output_paths.append(relaxed_pdb_path)
except TimeoutError as exc:
logger.exception(exc)
return False


def on_run_end(self):
@@ -483,8 +575,8 @@ def alphafold_func(run_stage: str, output_path: str, argument_dict: Dict[str, An
return output_path
elif run_stage == "monomer_msa2feature":
processed_feature, _ = monomer_msa2feature(**argument_dict)
dtool.save_object_as_pickle(processed_feature, output_path)
return output_path
# dtool.save_object_as_pickle(processed_feature, output_path)
return processed_feature
elif run_stage == "predict_structure":
pkl_output = output_path + "_output_raw.pkl"
pdb_output = output_path + "_unrelaxed.pdb"
Loading

0 comments on commit b9a2c4c

Please sign in to comment.