Commit
Fixed bugs in getting results from the AlphaFold service.
Zimiao1025 committed May 24, 2024
1 parent 7af52fa commit 651e7ef
Showing 10 changed files with 127 additions and 213 deletions.
2 changes: 1 addition & 1 deletion gui/stats.html

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions lib/pathtree.py
@@ -320,6 +320,10 @@ def template_feat(self):
     @property
     def selected_template_feat(self):
         return self.root / "selected_template_feat.pkl"
 
+    @property
+    def processed_feat(self):
+        return self.root / "processed_feat.pkl"
+
     @property
     def relaxed_pdbs(self):
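
Note: the new processed_feat property gives the monomer_msa2feature stage a stable pickle location under the AlphaFold output directory (it is consumed later in this commit as ptree.alphafold.processed_feat). A minimal, self-contained sketch of the pattern; the class name and root path below are illustrative stand-ins, not the project's actual definitions:

from pathlib import Path

class AlphaFoldPaths:  # hypothetical stand-in for the node exposed as ptree.alphafold
    def __init__(self, root):
        self.root = Path(root)

    @property
    def processed_feat(self) -> Path:
        # same pattern as the property added above
        return self.root / "processed_feat.pkl"

paths = AlphaFoldPaths("/data/targets/T0001/alphafold")
print(paths.processed_feat)  # /data/targets/T0001/alphafold/processed_feat.pkl
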
8 changes: 4 additions & 4 deletions lib/tool/alphafold/model/geometry/__init__.py
@@ -13,10 +13,10 @@
 # limitations under the License.
 """Geometry Module."""
 
-from alphafold.model.geometry import rigid_matrix_vector
-from alphafold.model.geometry import rotation_matrix
-from alphafold.model.geometry import struct_of_array
-from alphafold.model.geometry import vector
+from lib.tool.alphafold.model.geometry import rigid_matrix_vector
+from lib.tool.alphafold.model.geometry import rotation_matrix
+from lib.tool.alphafold.model.geometry import struct_of_array
+from lib.tool.alphafold.model.geometry import vector
 
 Rot3Array = rotation_matrix.Rot3Array
 Rigid3Array = rigid_matrix_vector.Rigid3Array
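
Since the AlphaFold sources are vendored under lib/tool/alphafold, absolute imports must start from the repository root. A quick import smoke test, assuming the repository root is on sys.path and the AlphaFold dependencies are installed (illustrative check, not part of the repository):

from lib.tool.alphafold.model import geometry

print(geometry.Rot3Array)      # resolves against the vendored copy
print(geometry.Rigid3Array)
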
2 changes: 1 addition & 1 deletion lib/utils/datatool.py
@@ -4,7 +4,7 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, List, Tuple, Optional, Union
+from typing import Any, Tuple, Optional, Union
 
 import jsonlines
 from loguru import logger
15 changes: 8 additions & 7 deletions services/alphafold/Dockerfile
@@ -48,9 +48,11 @@ RUN git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh
     && popd \
     && rm -rf /tmp/hh-suite
 
-# Install Miniconda & use Python 3.9
-ARG python=3.9
+# Install Miniconda & use Python 3.8
+ARG python=3.8
 ENV PYTHON_VERSION=${python}
+# https://repo.anaconda.com/miniconda/Miniconda3-py38_4.10.3-Linux-x86_64.sh
+# https://repo.anaconda.com/miniconda/Miniconda3-py37_4.10.3-Linux-x86_64.sh
 RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/install-conda.sh \
     && chmod +x /tmp/install-conda.sh \
     && bash /tmp/install-conda.sh -b -f -p /usr/local \
@@ -64,8 +66,7 @@ RUN conda update -qy conda \
     openmm=7.5.1 \
     cudatoolkit==${CUDA_VERSION} \
     pdbfixer \
-    pip \
-    python=3.9
+    pip
 
 COPY . /app/alphafold
 #RUN wget -q -P /app/alphafold/alphafold/common/ \
@@ -75,15 +76,15 @@ RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
 # Install pip packages.
 RUN pip3 install --upgrade pip \
     && pip3 install -r /app/alphafold/requirements.txt \
-    && pip3 install --upgrade jax==0.2.14 jaxlib==0.1.69+cuda$(cut -f1,2 -d. <<< ${CUDA} | sed 's/\.//g') -f \
+    && pip3 install --upgrade jax==0.2.25 jaxlib==0.1.69+cuda111 -f \
       https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 
 # Apply OpenMM patch.
-WORKDIR /usr/local/lib/python3.9/site-packages
+WORKDIR /usr/local/lib/python3.8/site-packages
 RUN patch -p0 < /app/alphafold/docker/openmm.patch
 
 # OpenMM Backward Compatible
-# WORKDIR /opt/conda/lib/python3.9/site-packages/simtk
+# WORKDIR /opt/conda/lib/python3.8/site-packages/simtk
 # RUN rm -rf openmm && ln -s ../openmm
 
 # Add SETUID bit to the ldconfig binary so that non-root users can run it.
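
A quick way to confirm the pinned toolchain inside the rebuilt image is a short interpreter check; the expected values below assume the versions pinned in this commit (Python 3.8, jax 0.2.25, jaxlib 0.1.69 CUDA 11.1 build):

import sys

import jax
import jaxlib

print("python :", sys.version.split()[0])  # expected 3.8.x
print("jax    :", jax.__version__)         # expected 0.2.25
print("jaxlib :", jaxlib.__version__)      # expected 0.1.69
print("devices:", jax.devices())           # should list a GPU device when CUDA is visible
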
26 changes: 16 additions & 10 deletions services/alphafold/requirements.txt
@@ -1,12 +1,18 @@
-tensorflow-gpu==2.8.2
-numpy==1.23
-absl-py
-biopython==1.81
-chex
-dm-haiku
-dm-tree
-immutabledict
-ml-collections
-pandas
+scipy==1.12.0
+absl-py==0.13.0
+biopython==1.79
+chex==0.0.7
+dm-haiku==0.0.4
+dm-tree==0.1.6
+docker==5.0.0
+immutabledict==2.0.0
+jax==0.2.25
+ml-collections==0.1.0
+# numpy==1.19.5
+pandas==1.3.4
+scipy==1.7.0
+ipykernel==6.13.0
+jupyter==1.0.0
+# tensorflow-cpu==2.5.0
 celery[redis]
 loguru
39 changes: 21 additions & 18 deletions services/alphafold/worker.py
@@ -2,8 +2,7 @@
 from celery import Celery
 
 from typing import Any, Dict
-# from loguru import logger
-
+import lib.utils.datatool as dtool
 from lib.tool.run_af2_stage import (
     search_template,
     make_template_feature,
@@ -29,25 +28,29 @@


 @celery.task(name="alphafold")
-def alphafoldTask(run_stage: str, argument_dict: Dict[str, Any]):
-    # run_stage
-    # enum_values=[
-    #     "search_template",
-    #     "make_template_feature",
-    #     "monomer_msa2feature",
-    #     "predict_structure",
-    #     "run_relaxation",
-    # ]
+def alphafoldTask(run_stage: str, output_path: str, argument_dict: Dict[str, Any]):
     if run_stage == "search_template":
-        results = search_template(**argument_dict)
+        pdb_template_hits = search_template(**argument_dict)
+        dtool.save_object_as_pickle(pdb_template_hits, output_path)
+        return output_path
     elif run_stage == "make_template_feature":
-        results = make_template_feature(**argument_dict)
+        template_feature = make_template_feature(**argument_dict)
+        dtool.save_object_as_pickle(template_feature, output_path)
+        return output_path
     elif run_stage == "monomer_msa2feature":
-        results = monomer_msa2feature(**argument_dict)
+        processed_feature, _ = monomer_msa2feature(**argument_dict)
+        dtool.save_object_as_pickle(processed_feature, output_path)
+        return output_path
     elif run_stage == "predict_structure":
-        results = predict_structure(**argument_dict)
+        pkl_output = output_path + "_output_raw.pkl"
+        pdb_output = output_path + "_unrelaxed.pdb"
+        prediction_results, unrelaxed_pdb_str, _ = predict_structure(**argument_dict)
+        dtool.save_object_as_pickle(prediction_results, pkl_output)
+        dtool.write_text_file(plaintext=unrelaxed_pdb_str, path=pdb_output)
+        return pdb_output
     elif run_stage == "run_relaxation":
-        results = run_relaxation(**argument_dict)
+        relaxed_pdb_str, _ = run_relaxation(**argument_dict)
+        dtool.write_text_file(relaxed_pdb_str, output_path)
+        return output_path
     else:
-        results = None
-    return results
+        return None
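
The task now persists each stage's result to output_path and returns the path instead of the (potentially very large) Python objects. The dtool helpers it calls live in lib/utils/datatool.py, which is not shown in this diff; a minimal sketch of equivalent helpers, assuming they are thin pickle/text wrappers:

import pickle
from pathlib import Path
from typing import Any

def save_object_as_pickle(obj: Any, path: str) -> str:
    # serialize an arbitrary object so the caller can read it back from shared storage
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as fd:
        pickle.dump(obj, fd)
    return path

def read_pickle(path: str) -> Any:
    with open(path, "rb") as fd:
        return pickle.load(fd)

def write_text_file(plaintext: str, path: str) -> str:
    # used here for unrelaxed/relaxed PDB strings
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    Path(path).write_text(plaintext)
    return path

def read_text_file(path: str) -> str:
    return Path(path).read_text()
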
114 changes: 48 additions & 66 deletions services/monostructure/worker.py
@@ -1,7 +1,7 @@
 import os
 from copy import deepcopy
 from celery import Celery
-from celery.result import AsyncResult
+from celery.result import AsyncResult, allow_join_result
 from pathlib import Path
 from typing import Any, Dict, List, Union
 import matplotlib.pyplot as plt
@@ -83,18 +83,11 @@ def run(
     ):
         if not isinstance(msa_paths, list):
             msa_paths = [msa_paths]
 
-        # processed_feature = monomer_msa2feature(
-        #     sequence=self.sequence,
-        #     target_name=self.target_name,
-        #     msa_paths=msa_paths,
-        #     template_feature=template_feat,
-        #     af2_config=af2_config,
-        #     model_name=model_name,
-        #     random_seed=random_seed # random.randint(0, 100000),
-        # )
-
         run_stage = "monomer_msa2feature"
+        ptree = get_pathtree(request=self.requests[0])
+        out_path = ptree.alphafold.processed_feat
+
         argument_dict = {
             "sequence": self.sequence,
             "target_name": self.target_name,
@@ -107,21 +100,27 @@ def run(
         for k, v in af2_config.items():
             if k not in argument_dict:
                 argument_dict[k] = v
-        task = celery_client.send_task("alphafold", args=[run_stage, argument_dict], queue="queue_alphafold")
-        task_result = AsyncResult(task.id, app=celery_client)
-        # if task_result.ready():
-        processed_feature, _ = task_result.get()
-
+        task = celery_client.send_task("alphafold", args=[run_stage, out_path, argument_dict], queue="queue_alphafold")
+        task_result = AsyncResult(task.id, app=celery_client)
 
-        ptree = get_pathtree(request=self.requests[0])
-        dtool.deduplicate_msa_a3m(msa_paths, str(ptree.alphafold.input_a3m))
 
-        self.save_msa_fig_from_a3m_files(
-            msa_paths=msa_paths,
-            save_path=ptree.alphafold.msa_coverage_image,
-        )
+        with allow_join_result():
+            try:
+                out_path = task_result.get()
+                processed_feature = dtool.read_pickle(out_path)
+
+                dtool.deduplicate_msa_a3m(msa_paths, str(ptree.alphafold.input_a3m))
+
+                self.save_msa_fig_from_a3m_files(
+                    msa_paths=msa_paths,
+                    save_path=ptree.alphafold.msa_coverage_image,
+                )
+                return processed_feature
+
+            except TimeoutError as exc:
+                print("--- Exception: %s\n Timeout!" %exc)
+                return
 
-        return processed_feature
 
     def on_run_end(self):
         if self.info_reportor is not None:
@@ -164,24 +163,6 @@ def run(
             )
             + "_unrelaxed.pdb"
         )
-
-        raw_output = (
-            os.path.join(
-                str(ptree.alphafold.root),
-                model_name,
-            )
-            + "_output_raw.pkl"
-        )
-
-        # gpu_devices = "".join([f"{i}" for i in get_available_gpus(1)])
-        # (prediction_result, unrelaxed_pdb_str,) = predict_structure(
-        #     af2_config=af2_config,
-        #     target_name=self.target_name,
-        #     processed_feature=processed_feat,
-        #     model_name=model_name,
-        #     random_seed=random_seed, # random.randint(0, 100000),
-        #     gpu_devices=gpu_devices,
-        # )
 
         run_stage = "predict_structure"
         argument_dict = {
@@ -192,21 +173,25 @@ def run(
"random_seed": random_seed,
"return_representations": True,
}
# structure_config = run_config["structure_prediction"]["alphafold"]
argument_dict = deepcopy(argument_dict)
for k, v in af2_config.items():
if k not in argument_dict:
argument_dict[k] = v

task = celery_client.send_task("alphafold", args=[run_stage, argument_dict], queue="queue_alphafold")
out_path = str(os.path.join(str(ptree.alphafold.root), model_name))
task = celery_client.send_task("alphafold", args=[run_stage, out_path, argument_dict], queue="queue_alphafold")
task_result = AsyncResult(task.id, app=celery_client)
# if task_result.ready():
prediction_results, unrelaxed_pdb_str, _ = task_result.get()

dtool.save_object_as_pickle(prediction_results, raw_output)
dtool.write_text_file(plaintext=unrelaxed_pdb_str, path=self.output_path)

return unrelaxed_pdb_str
with allow_join_result():
try:
un_relaxed_pdb_path = task_result.get()
unrelaxed_pdb_str = dtool.read_text_file(path=un_relaxed_pdb_path)
return unrelaxed_pdb_str

except TimeoutError as exc:
print("--- Exception: %s\n Timeout!" %exc)
return


def on_run_end(self):
if self.info_reportor is not None:
@@ -243,24 +228,21 @@ def run(self, unrelaxed_pdb_str, model_name):
         # relaxed_pdb_str = run_relaxation(
         #     unrelaxed_pdb_str=unrelaxed_pdb_str, gpu_devices=gpu_devices
         # )
-        run_stage = ""
+        run_stage = "run_relaxation"
         argument_dict = {"unrelaxed_pdb_str": unrelaxed_pdb_str}
-        task = celery_client.send_task("alphafold", args=[run_stage, argument_dict], queue="queue_alphafold")
+        out_path = str(os.path.join(str(ptree.alphafold.root), model_name)) + "_relaxed.pdb"
+        task = celery_client.send_task("alphafold", args=[run_stage, out_path, argument_dict], queue="queue_alphafold")
         task_result = AsyncResult(task.id, app=celery_client)
-        # if task_result.ready():
-        relaxed_pdb_str, _ = task_result.get()
-
-        self.output_path = (
-            os.path.join(
-                str(ptree.alphafold.root),
-                model_name,
-            )
-            + "_relaxed.pdb"
-        )
-
-        dtool.write_text_file(relaxed_pdb_str, self.output_path)
 
+        with allow_join_result():
+            try:
+                relaxed_pdb_path = task_result.get()
+                return relaxed_pdb_path
+
+            except TimeoutError as exc:
+                print("--- Exception: %s\n Timeout!" %exc)
+                return
 
-        return relaxed_pdb_str
+
     def on_run_end(self):
         if self.info_reportor is not None:
@@ -358,8 +340,8 @@ def run(self):
         if not unrelaxed_pdb_str:
             return
 
-        relaxed_pdb_str = self.amber_relax(
+        relaxed_pdb_path = self.amber_relax(
             unrelaxed_pdb_str=unrelaxed_pdb_str, model_name=m_name
         )
-        if not relaxed_pdb_str:
+        if not relaxed_pdb_path:
             return
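
On the calling side, the pattern adopted throughout this file is: send the task with (run_stage, out_path, argument_dict), join on the result inside allow_join_result() (Celery refuses result.get() inside another task otherwise), and load the artifact back from the returned path. A standalone sketch of that round trip; the broker/backend URLs are assumptions, not values from this diff:

import pickle

from celery import Celery
from celery.result import AsyncResult, allow_join_result

celery_client = Celery(
    "client",
    broker="redis://localhost:6379/0",   # assumed; the real URLs come from the service config
    backend="redis://localhost:6379/0",
)

def run_alphafold_stage(run_stage: str, out_path: str, argument_dict: dict):
    task = celery_client.send_task(
        "alphafold", args=[run_stage, out_path, argument_dict], queue="queue_alphafold"
    )
    task_result = AsyncResult(task.id, app=celery_client)
    with allow_join_result():
        result_path = task_result.get()   # the worker returns a path, not the heavy object
    with open(result_path, "rb") as fd:   # pickle-producing stages; PDB stages return text files
        return pickle.load(fd)
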