Skip to content

Commit

Permalink
add cuda for mqe service.
Browse files Browse the repository at this point in the history
  • Loading branch information
Zimiao1025 committed Jul 15, 2024
1 parent 477834d commit 82f0666
Show file tree
Hide file tree
Showing 15 changed files with 172 additions and 59 deletions.
24 changes: 18 additions & 6 deletions batch_mqe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json
import requests
from loguru import logger
import os


def load_fasta(file_path, dir_name, data_suffix):
Expand All @@ -21,10 +20,22 @@ def load_fasta(file_path, dir_name, data_suffix):

# Sends the collected MQE request dicts to the MQE HTTP endpoint and prints the
# JSON reply; any exception raised along the way is logged through loguru
# rather than propagated to the caller.
def MQEWorker(request_dicts):

mqe_url = f"http://10.0.0.12:8081/mqe"

# NOTE(review): the bearer token is hard-coded in the script (masked in this
# view) -- consider reading it from the environment instead.
TOKEN = "***"
HEADERS = {
"User-Agent": "Python API Sample",
"Authorization": "Bearer " + TOKEN,
"Content-Type": "application/json"
}
API_URL = f"http://10.0.0.12:8081/mqe"
try:
# NOTE(review): this log call and un-authenticated POST appear together with
# the authenticated POST further down; presumably they are the superseded
# lines of this change rendered next to their replacements -- confirm that
# only one request is actually sent.
logger.info(f"------- Requests of mqe task: {request_dicts}")
requests.post(mqe_url , json={'requests': request_dicts})
# logger.info(f"------- Requests of mqe task: {request_dicts}")
# response = requests.post(url=API_URL , json={"requests": request_dicts})

# Serialise the payload by hand and send it as UTF-8 bytes together with the
# explicit headers defined above.
data = {'requests': request_dicts}
json_data =json.dumps(data).encode('utf8')
# NOTE(review): no timeout is passed, so this call can block indefinitely if
# the service never answers.
response = requests.post(url=API_URL, headers=HEADERS, data=json_data)
# Pretty-print the reply; json.loads raises on a non-JSON body, which the
# handler below catches and logs.
print(json.dumps(json.loads(response.text), sort_keys=True, indent=4, separators=(",", ": ")))
except Exception as e:
logger.error(str(e))

Expand All @@ -40,7 +51,8 @@ def main():
"./tmp/temp_6000_64_1_seqentropy_mmseqs.json",
"./tmp/temp_6000_64_1_plmsim_mmseqs.json"]

dir_names = os.listdir(cameo_dir)
# dir_names = os.listdir(cameo_dir)
dir_names = ['8BL5_A']
for dir_name in dir_names:
request_dicts = []
for json_file in json_files:
Expand All @@ -51,7 +63,7 @@ def main():
request_dict["sequence"] = sequence
request_dict["name"] = seq_name + "_" + case_suffix
request_dict["target"] = seq_name

print(request_dict)
request_dicts.append(request_dict)

MQEWorker(request_dicts)
Expand Down
7 changes: 4 additions & 3 deletions batch_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def pipelineWorker(request_dicts):

try:
logger.info(f"------- Requests of pipeline task: {request_dicts}")
exit()
requests.post(pipeline_url , json={'requests': request_dicts})
except Exception as e:
logger.error(str(e))
Expand All @@ -127,7 +128,7 @@ def main():
# with open("./tmp/temp_6000_64_1_seqentropy_mmseqs.json", 'r') as jf:
# with open("./tmp/temp_6000_64_1_plmsim_mmseqs.json", 'r') as jf:
# with open("./tmp/temp_6000_64_1_seqentropy.json", 'r') as jf:
with open("./tmp/temp_8000_64_1_plmsim_mmseqs.json", 'r') as jf:
with open("./tmp/temp_5000_64_1_seqentropy_mmseqs.json", 'r') as jf:
request_dict = json.load(jf)

# weeks = ['2024.02.17', '2024.02.24', '2024.03.02', '2024.03.09',
Expand All @@ -141,9 +142,9 @@ def main():

# for run dir or run bad case
# run dir
dir_names = os.listdir(cameo_dir)
# dir_names = os.listdir(cameo_dir)
# run bad case
# dir_names = ['8BL5_A']
dir_names = ['8BL5_A']
for dir_name in dir_names:
seq_file = cameo_dir + dir_name + "/" + "target.fasta"
seq_name, sequence = load_fasta(seq_file, dir_name, data_suffix)
Expand Down
35 changes: 35 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,41 @@ services:
- ./lib:/worker/lib
command: celery -A worker worker --loglevel=info -Q queue_alphafold --concurrency=16

# Celery worker for the MQE service: the image is built from ./services/mqe
# and the container consumes the queue_mqe queue (see `command`).
worker_mqe:
build:
context: ./services/mqe
dockerfile: Dockerfile
args:
# Proxy settings and the host user's id/group/name are passed to the image build.
- http_proxy=${http_proxy}
- https_proxy=${https_proxy}
- USER_ID=${UID}
- GROUP_ID=${GID}
- USER_NAME=${USER_NAME}
depends_on:
# NOTE(review): only rabbitmq is listed although CELERY_RESULT_BACKEND below
# points at redis -- confirm whether redis should be a dependency as well.
- rabbitmq
environment:
- TZ=Asia/Shanghai
- CELERY_BROKER_URL=amqp://rabbitmq:5672
- CELERY_RESULT_BACKEND=redis://redis
- DOCKER_CLIENT_TIMEOUT=${DOCKER_CLIENT_TIMEOUT}
- COMPOSE_HTTP_TIMEOUT=${COMPOSE_HTTP_TIMEOUT}
- TF_FORCE_UNIFIED_MEMORY=${TF_FORCE_UNIFIED_MEMORY}
- XLA_PYTHON_CLIENT_MEM_FRACTION=${XLA_PYTHON_CLIENT_MEM_FRACTION}
# Make every host GPU visible inside the container.
- NVIDIA_VISIBLE_DEVICES=all
deploy:
resources:
reservations:
devices:
# Reserve all GPUs through the nvidia driver.
- driver: "nvidia"
count: "all"
capabilities: ["gpu"]
volumes:
# The service source is bind-mounted over /worker, next to the shared data
# directory, the host /tmp and the shared lib package.
- ./services/mqe:/worker
- /data:/data
- /tmp:/tmp
- ./lib:/worker/lib
# NOTE(review): 16 concurrent worker processes share the reserved GPUs --
# confirm this fits the available GPU memory.
command: celery -A worker worker --loglevel=info -Q queue_mqe --concurrency=16

worker_analysis:
build:
context: ./services/analysis
Expand Down
2 changes: 1 addition & 1 deletion gui/stats.html

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion lib/tool/enqa/data/process_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from scipy.special import softmax
from scipy.spatial.distance import pdist, squareform

from data.process_label import generate_lddt_score, parse_pdbfile, get_coords_ca
from lib.tool.enqa.data.process_label import generate_lddt_score, parse_pdbfile, get_coords_ca


def process_alphafold_model(input_model_path, alphafold_prediction_path, lddt_cmd, n_models=5,
Expand Down
4 changes: 2 additions & 2 deletions lib/tool/enqa/network/resEGNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch
from torch.nn import functional as F

from network.EGNN import EGNN, EGNN_ne
from network.resnet import ResNet
from lib.tool.enqa.network.EGNN import EGNN, EGNN_ne
from lib.tool.enqa.network.resnet import ResNet


def task_loss(pred, target, use_mean=True):
Expand Down
4 changes: 2 additions & 2 deletions lib/tool/enqa/utils/SGCN/common/checks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

from utils.SGCN.common import utils
from utils.SGCN.common.format import create_atom_label
from lib.tool.enqa.utils.SGCN.common import utils
from lib.tool.enqa.utils.SGCN.common.format import create_atom_label


def is_hydrogen(atom_name):
Expand Down
2 changes: 1 addition & 1 deletion lib/tool/enqa/utils/SGCN/common/covalent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np

from Bio.PDB import NeighborSearch, PDBParser
from utils.SGCN.common import format
from lib.tool.enqa.utils.SGCN.common import format

SEARCH_RADIUS = 6
KMIN_DISTANCE_BETWEEN_ATOMS = 0.01
Expand Down
2 changes: 1 addition & 1 deletion lib/tool/enqa/utils/SGCN/common/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import torch.sparse as sparse

from utils.SGCN.common import names, utils
from lib.tool.enqa.utils.SGCN.common import names, utils


CORRECT_MODEL_FILES_SET = {
Expand Down
2 changes: 1 addition & 1 deletion lib/tool/enqa/utils/SGCN/common/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from biopandas.pdb import PandasPdb
from functools import cmp_to_key

from utils.SGCN.common import checks, covalent, format, names, utils
from lib.tool.enqa.utils.SGCN.common import checks, covalent, format, names, utils

TARGET_SCORES_DF_NAMES = ['chain_id', 'residue_number', '#2', '#3', '#4', 'residue_name', 'atom_name', 'score']
CONTACTS_DF_NAMES = [
Expand Down
2 changes: 1 addition & 1 deletion lib/tool/enqa/utils/SGCN/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import shutil

from utils.SGCN.common import names
from lib.tool.enqa.utils.SGCN.common import names


def path(parts):
Expand Down
70 changes: 66 additions & 4 deletions services/mqe/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,4 +1,51 @@
FROM python:3.9
ARG CUDA=11.1.1
FROM nvidia/cuda:${CUDA}-cudnn8-runtime-ubuntu18.04 AS base_af2_env
ARG CUDA

# Use bash to support string substitution.
SHELL ["/bin/bash", "-c"]

# NVIDIA rotated the signing key of its CUDA apt repositories. Drop the stale
# repository entry and key shipped in the base image, then install the
# cuda-keyring package so the following `apt-get update` can verify packages.
RUN rm /etc/apt/sources.list.d/cuda.list
#RUN rm /etc/apt/sources.list.d/nvidia-ml.list
RUN apt-key del 7fa2af80
# The keyring must come from the repository that matches the base image's
# distribution (ubuntu18.04): the package also registers that repository with
# apt, so taking it from the ubuntu2004 tree would make later CUDA package
# installs resolve against packages built for a different Ubuntu release.
ADD https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/cuda-keyring_1.0-1_all.deb .
RUN dpkg -i cuda-keyring_1.0-1_all.deb

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \
build-essential \
cmake \
cuda-command-line-tools-$(cut -f1,2 -d- <<< ${CUDA//./-}) \
git \
hmmer \
kalign \
tzdata \
wget \
&& rm -rf /var/lib/apt/lists/*

# Compile HHsuite from source.
RUN git clone --branch v3.3.0 https://github.com/soedinglab/hh-suite.git /tmp/hh-suite \
&& mkdir /tmp/hh-suite/build \
&& pushd /tmp/hh-suite/build \
&& cmake -DCMAKE_INSTALL_PREFIX=/opt/hhsuite .. \
&& make -j 4 && make install \
&& ln -s /opt/hhsuite/bin/* /usr/bin \
&& popd \
&& rm -rf /tmp/hh-suite

# Install Miniconda & use Python 3.8
ARG python=3.8
ENV PYTHON_VERSION=${python}
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O /tmp/install-conda.sh \
&& chmod +x /tmp/install-conda.sh \
&& bash /tmp/install-conda.sh -b -f -p /usr/local \
&& rm -f /tmp/install-conda.sh \
&& conda install -y python=${PYTHON_VERSION} \
&& conda clean -y --all

# Install conda packages.
RUN conda update -qy conda \
&& conda install -y -c conda-forge \
cudatoolkit==${CUDA_VERSION}

ENV CELERY_BROKER_URL pyamqp://guest:guest@localhost:5672/
ENV CELERY_RESULT_BACKEND rpc://
Expand All @@ -15,15 +62,30 @@ RUN pip install -U setuptools pip
RUN pip install --no-cache-dir -r /tmp/requirements.txt

# Install PyTorch
ENV PYTORCH_VERSION=1.8.1+cu111
ENV TORCHVISION_VERSION=0.9.1+cu111
# RUN pip install torch==1.10.0+rocm4.2 torchvision==0.11.0+rocm4.2 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html

# Install PyTorch
ENV PYTORCH_VERSION=1.10.0+cu111
ENV TORCHVISION_VERSION=0.11.0+cu111
RUN pip install --no-cache-dir \
torch==${PYTORCH_VERSION} \
torchvision==${TORCHVISION_VERSION} \
-f https://mirror.sjtu.edu.cn/pytorch-wheels/torch_stable.html

# Install torch-geometric
RUN pip install torch-geometric==2.2.0
RUN pip install https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_cluster-1.6.0-cp38-cp38-linux_x86_64.whl
RUN pip install https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_scatter-2.0.9-cp38-cp38-linux_x86_64.whl
RUN pip install https://data.pyg.org/whl/torch-1.10.0%2Bcu113/torch_sparse-0.6.12-cp38-cp38-linux_x86_64.whl

# Install pytorch3d
RUN pip install "git+https://github.com/facebookresearch/pytorch3d.git"

# Install ESM
RUN pip install git+https://github.com/facebookresearch/esm.git

# Install PyRosetta4
RUN pip install https://graylab.jhu.edu/download/PyRosetta4/archive/release/PyRosetta4.Release.python38.linux.wheel/pyrosetta-2024.15+release.d972b59-cp38-cp38-linux_x86_64.whl

# add user
RUN if [ $USER_NAME != "root" ] ; \
Expand All @@ -38,4 +100,4 @@ USER ${USER_NAME}
WORKDIR /worker

# Specify the command to run on container start
CMD ["celery", "-A", "mqe", "worker", "--loglevel=info", "-Q", "queue_mqe"]
CMD ["celery", "-A", "mqe", "worker", "--loglevel=info", "-Q", "queue_mqe"]
5 changes: 4 additions & 1 deletion services/mqe/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ tqdm
seaborn
logzero
scikit-learn
biotite
einops
numpy==1.23
biopandas==0.3.0dev0
pandas==1.3.4
pandas==1.3.4
setuptools==69.5.1
22 changes: 11 additions & 11 deletions services/mqe/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

from typing import Any, Dict, List

from lib.base import BaseRunner
# from lib.state import State
from lib.pathtree import get_pathtree
from lib.utils import misc
# from lib.monitor import info_report
import lib.utils.datatool as dtool
from lib.tool.enqa import enqa_msa
from lib.tool.gcpl import gcpl_qa
# from lib.tool.gcpl import gcpl_qa

CELERY_RESULT_BACKEND = os.environ.get("CELERY_RESULT_BACKEND", "rpc://")
CELERY_BROKER_URL = (
Expand All @@ -31,14 +30,16 @@

@celery.task(name="mqe")
def mqeTask(requests: List[Dict[str, Any]]):
MQERunner(requests=requests)()
MQERunner(requests=requests, method="enqa").run()


class MQERunner(BaseRunner):
class MQERunner():
def __init__(
self, requests: List[Dict[str, Any]]
self, requests: List[Dict[str, Any]], method: str
):
super().__init__(requests)
# super().__init__(requests)
self.requests = requests
self.mqe_method = method
# self.error_code = State.MQE_ERROR
# self.success_code = State.MQE_SUCCESS
# self.start_code = State.MQE_START
Expand All @@ -49,8 +50,7 @@ def __init__(

def run(self):
ptree_base = get_pathtree(self.requests[0])
mqe_method = misc.safe_get(self.requests[0], ["run_config", "mse"])
if mqe_method == "enqa":
if self.mqe_method == "enqa":
# EnQA
ptree_base.mqe.enqa_temp.parent.mkdir(exist_ok=True, parents=True)
mqe_tmp_dir = ptree_base.mqe.enqa_temp
Expand All @@ -76,10 +76,10 @@ def run(self):
for key, val in plddt_dict.items():
predicted_pdb = target_dir + key + "_relaxed.pdb"
predicted_result[predicted_pdb] = val
if mqe_method == "enqa":
if self.mqe_method == "enqa":
score = enqa_msa.evaluation(input_pdb=predicted_pdb, tmp_dir=mqe_tmp_dir)
else:
score = gcpl_qa.evaluation(fasta_file=ptree.seq.fasta, decoy_file=predicted_pdb, tmp_dir=mqe_tmp_dir)
# else:
# score = gcpl_qa.evaluation(fasta_file=ptree.seq.fasta, decoy_file=predicted_pdb, tmp_dir=mqe_tmp_dir)
predicted_result[ms_config+"_"+key] = {"predicted_pdb": predicted_pdb, "plddt": val, "score": score}

dtool.write_json(mqe_rank_file, data=predicted_result)
Expand Down
Loading

0 comments on commit 82f0666

Please sign in to comment.