Skip to content

Commit

Permalink
Merge feature, structure and relaxation services into one service.
Browse files Browse the repository at this point in the history
  • Loading branch information
Zimiao1025 committed Mar 11, 2024
1 parent 4c62126 commit da873e7
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 364 deletions.
10 changes: 0 additions & 10 deletions api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,11 @@ async def selecttpl_task(requests: List[Dict[str, Any]]):
task = celery_client.send_task("selecttpl", args=[requests], queue="queue_selecttpl")
return {"task_id": task.id}

@app.post("/msafeature")
async def msafeature_task(requests: List[Dict[str, Any]]):
task = celery_client.send_task("msafeature", args=[requests], queue="queue_msafeature")
return {"task_id": task.id}

@app.post("/monostructure")
async def monostructure_task(requests: List[Dict[str, Any]]):
task = celery_client.send_task("monostructure", args=[requests], queue="queue_monostructure")
return {"task_id": task.id}

@app.post("/relaxation")
async def relaxation_task(requests: List[Dict[str, Any]]):
task = celery_client.send_task("relaxation", args=[requests], queue="queue_relaxation")
return {"task_id": task.id}

@app.post("/analysis")
async def analysis_task(requests: List[Dict[str, Any]]):
task = celery_client.send_task("analysis", args=[requests], queue="queue_analysis")
Expand Down
44 changes: 0 additions & 44 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -215,28 +215,6 @@ services:
- ./lib:/worker/lib
command: celery -A worker worker --loglevel=info -Q queue_tplfeature --concurrency=4

worker_msafeature:
build:
context: ./services/msafeature
dockerfile: Dockerfile
args:
- http_proxy=${http_proxy}
- https_proxy=${https_proxy}
- USER_ID=${UID}
- GROUP_ID=${GID}
- USER_NAME=${USER_NAME}
depends_on:
- rabbitmq
environment:
- CELERY_BROKER_URL=amqp://rabbitmq:5672
- CELERY_RESULT_BACKEND=redis://redis
volumes:
- /var/run/docker.sock:/var/run/docker.sock # run docker in docker
- ./services/msafeature:/worker
- /data:/data
- ./lib:/worker/lib
command: celery -A worker worker --loglevel=info -Q queue_msafeature --concurrency=4

worker_monostructure:
build:
context: ./services/monostructure
Expand All @@ -259,28 +237,6 @@ services:
- ./lib:/worker/lib
command: celery -A worker worker --loglevel=info -Q queue_monostructure --concurrency=4

worker_relaxation:
build:
context: ./services/relaxation
dockerfile: Dockerfile
args:
- http_proxy=${http_proxy}
- https_proxy=${https_proxy}
- USER_ID=${UID}
- GROUP_ID=${GID}
- USER_NAME=${USER_NAME}
depends_on:
- rabbitmq
environment:
- CELERY_BROKER_URL=amqp://rabbitmq:5672
- CELERY_RESULT_BACKEND=redis://redis
volumes:
- /var/run/docker.sock:/var/run/docker.sock # run docker in docker
- ./services/relaxation:/worker
- /data:/data
- ./lib:/worker/lib
command: celery -A worker worker --loglevel=info -Q queue_relaxation --concurrency=4

worker_analysis:
build:
context: ./services/analysis
Expand Down
2 changes: 1 addition & 1 deletion gui/stats.html

Large diffs are not rendered by default.

235 changes: 226 additions & 9 deletions services/monostructure/worker.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import os
import shutil
from celery import Celery
from loguru import logger
from pathlib import Path
from typing import Any, Dict, List, Union
import matplotlib.pyplot as plt
from loguru import logger

from lib.base import BaseRunner
from lib.state import State
from lib.pathtree import get_pathtree
import lib.utils.datatool as dtool
from lib.monitor import info_report
from lib.func_from_docker import predict_structure
from lib.tool import plot
from lib.func_from_docker import monomer_msa2feature, predict_structure, run_relaxation
from lib.utils.systool import get_available_gpus
from lib.utils import misc

CELERY_RESULT_BACKEND = os.environ.get("CELERY_RESULT_BACKEND", "rpc://")
CELERY_BROKER_URL = (
Expand All @@ -34,8 +36,84 @@

@celery.task(name="monostructure")
def monostructureTask(requests: List[Dict[str, Any]]):
MonoStructureRunner(requests=requests, db_path=DB_PATH).run()
af2_config = requests[0]["run_config"]["structure_prediction"]["alphafold"]
random_seed = af2_config.get("random_seed", 0)
MonoStructureRunner(requests=requests, db_path=DB_PATH).run(af2_config=af2_config, random_seed=random_seed)


class MonoFeatureRunner(BaseRunner):
def __init__(
self,
requests: List[Dict[str, Any]],
db_path: Union[str, Path] = None,
) -> None:
super().__init__(requests, db_path)
self.error_code = State.MSA2FEATURE_ERROR
self.success_code = State.MSA2FEATURE_SUCCESS
self.start_code = State.MSA2FEATURE_START
self.sequence = self.requests[0][SEQUENCE]
self.target_name = self.requests[0][TARGET]

@property
def start_stage(self) -> State:
return self.start_code

@staticmethod
def save_msa_fig_from_a3m_files(msa_paths, save_path):

delete_lowercase = lambda line: "".join(
[t for t in list(line) if not t.islower()]
)
msa_collection = []
for p in msa_paths:
with open(p) as fd:
_lines = fd.read().strip().split("\n")
_lines = [
delete_lowercase(l) for l in _lines if not l.startswith(">") and l
]
msa_collection.extend(_lines)
plot.plot_msas([msa_collection])
plt.savefig(save_path, bbox_inches="tight", dpi=200)

def run(
self,
msa_paths: List[Union[str, Path]],
template_feat: Dict[str, Any],
af2_config: Dict[str, Any],
model_name: str = "model_1",
*args,
dry=False,
**kwargs,
):
if not isinstance(msa_paths, list):
msa_paths = [msa_paths]
processed_feature = monomer_msa2feature(
sequence=self.sequence,
target_name=self.target_name,
msa_paths=msa_paths,
template_feature=template_feat,
af2_config=af2_config,
model_name=model_name,
random_seed=kwargs.get("random_seed", 0), # random.randint(0, 100000),
)

ptree = get_pathtree(request=self.requests[0])
dtool.deduplicate_msa_a3m(msa_paths, str(ptree.alphafold.input_a3m))

self.save_msa_fig_from_a3m_files(
msa_paths=msa_paths,
save_path=ptree.alphafold.msa_coverage_image,
)

return processed_feature

def on_run_end(self):
if self.info_reportor is not None:
for request in self.requests:
self.info_reportor.update_state(
hash_id=request[info_report.HASH_ID],
state=self.success_code,
)

class MonoStructureRunner(BaseRunner):
def __init__(
Expand All @@ -57,11 +135,10 @@ def run(
self,
processed_feat: Dict,
af2_config: Dict,
model_name: str,
*args,
dry=False,
**kwargs,
random_seed: int,
model_name: str
):


ptree = get_pathtree(request=self.requests[0])

Expand All @@ -87,7 +164,7 @@ def run(
target_name=self.target_name,
processed_feature=processed_feat,
model_name=model_name,
random_seed=kwargs.get("random_seed", 0), # random.randint(0, 100000),
random_seed=random_seed, # random.randint(0, 100000),
gpu_devices=gpu_devices,
)
dtool.save_object_as_pickle(prediction_result, raw_output)
Expand All @@ -108,3 +185,143 @@ def on_run_end(self):
hash_id=request[info_report.HASH_ID],
state=self.error_code,
)


class AmberRelaxationRunner(BaseRunner):
def __init__(
self,
requests: List[Dict[str, Any]],
db_path: Union[str, Path] = None,
) -> None:
super().__init__(requests, db_path)
self.error_code = State.RELAX_ERROR
self.success_code = State.RELAX_SUCCESS
self.start_code = State.RELAX_START

@property
def start_stage(self) -> State:
return self.start_code

def run(self, unrelaxed_pdb_str, model_name, *args, dry=False, **kwargs):
ptree = get_pathtree(request=self.requests[0])
gpu_devices = "".join([f"{i}" for i in get_available_gpus(1)])
relaxed_pdb_str = run_relaxation(
unrelaxed_pdb_str=unrelaxed_pdb_str, gpu_devices=gpu_devices
)

self.output_path = (
os.path.join(
str(ptree.alphafold.root),
model_name,
)
+ "_relaxed.pdb"
)

dtool.write_text_file(relaxed_pdb_str, self.output_path)

return relaxed_pdb_str

def on_run_end(self):
if self.info_reportor is not None:
for request in self.requests:
if os.path.exists(self.output_path):
self.info_reportor.update_state(
hash_id=request[info_report.HASH_ID],
state=self.success_code,
)
else:
self.info_reportor.update_state(
hash_id=request[info_report.HASH_ID],
state=self.error_code,
)

class AirFoldRunner(BaseRunner):
def __init__(
self, requests: List[Dict[str, Any]], db_path: Union[str, Path] = None
) -> None:
"""_summary_
Parameters
----------
request : List[Dict], optional
Request for the pipeline. Each item in the list includes the basic
information about a protein sequence, e.g. name, time stamp, etc.,
as well as the strategy for structure prediction.
See `sample_request.jsonl` as an example.
"""
super().__init__(requests, db_path)
"""
Here we make a request, cut the request according to the segment part.
"""
# self.info_reportor.update_reserved(
# hash_id=requests[0]["hash_id"], update_dict={"pid": os.getpid()}
# )
# logger.info(f"#### the process id is {os.getpid()}")

self.mono_msa2feature = MonoFeatureRunner(
requests=self.requests, db_path=db_path
)

self.mono_structure = MonoStructureRunner(
requests=self.requests, db_path=db_path
)
self.amber_relax = AmberRelaxationRunner(
requests=self.requests, db_path=db_path
)

@property
def start_stage(self) -> int:
return State.AIRFOLD_START

def run(self, dry=False):

af2_config = self.requests[0]["run_config"]["structure_prediction"]["alphafold"]
models = af2_config["model_name"].split(",")
random_seed = af2_config.get("random_seed", 0)
af2_config = {
k: v
for k, v in af2_config.items()
if v != "model_name" and v != "random_seed"
}

# get msa_path
ptree = get_pathtree(request=self.requests[0])
integrated_search_a3m = str(ptree.search.integrated_search_a3m)
str_dict = misc.safe_get(self.requests[0], ["run_config", "msa_select"])
key_list = list(str_dict.keys())
selected_msa_path = integrated_search_a3m
for index in range(len(key_list)):
selected_msa_path = ptree.strategy.strategy_list[index]
# selected_msa_path = ptree.strategy.strategy_list[index]
selected_template_feat = ptree.alphafold.selected_template_feat
for m_name in models:
processed_feature = self.mono_msa2feature(
msa_paths=selected_msa_path,
template_feat=selected_template_feat,
af2_config=af2_config,
model_name=m_name,
random_seed=random_seed,
)
logger.info(
f"the shape of msa_feat is: {processed_feature['msa_feat'].shape}"
)
if not processed_feature:
return
unrelaxed_pdb_str = self.mono_structure(
processed_feat=processed_feature,
af2_config=af2_config,
model_name=m_name,
random_seed=random_seed,
)
if not unrelaxed_pdb_str:
return

relaxed_pdb_str = self.amber_relax(
unrelaxed_pdb_str=unrelaxed_pdb_str, model_name=m_name
)
if not relaxed_pdb_str:
return
success = self.gen_analysis()
if not success:
return
self.submit(dry=dry)
39 changes: 0 additions & 39 deletions services/msafeature/Dockerfile

This file was deleted.

Loading

0 comments on commit da873e7

Please sign in to comment.