Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
dialvarezs committed Jan 28, 2022
2 parents ed9beb5 + 57a2455 commit b020cba
Show file tree
Hide file tree
Showing 14 changed files with 180 additions and 89 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ parameters are made available under the terms of the CC BY 4.0 license. Please
see the [Disclaimer](#license-and-disclaimer) below for more detail.

The AlphaFold parameters are available from
- https://storage.googleapis.com/alphafold/alphafold_params_2021-10-27.tar, and
+ https://storage.googleapis.com/alphafold/alphafold_params_2022-01-19.tar, and
are downloaded as part of the `scripts/download_all_data.sh` script. This script
will download parameters for:

Expand Down
18 changes: 2 additions & 16 deletions alphafold/data/msa_pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

import collections
import functools
import re
import string
from typing import Any, Dict, Iterable, List, Sequence

Expand Down Expand Up @@ -58,14 +57,6 @@
CHAIN_FEATURES = ('num_alignments', 'seq_length')


domain_name_pattern = re.compile(
r'''^(?P<pdb>[a-z\d]{4})
\{(?P<bioassembly>[\d+(\+\d+)?])\}
(?P<chain>[a-zA-Z\d]+)
\{(?P<transform_index>\d+)\}$
''', re.VERBOSE)


def create_paired_features(
chains: Iterable[pipeline.FeatureDict],
prokaryotic: bool,
Expand Down Expand Up @@ -618,6 +609,7 @@ def deduplicate_unpaired_sequences(
msa_features = MSA_FEATURES

for chain in np_chains:
# Convert the msa_all_seq numpy array to a tuple for hashing.
sequence_set = set(tuple(s) for s in chain['msa_all_seq'])
keep_rows = []
# Go through unpaired MSA seqs and remove any rows that correspond to the
Expand All @@ -627,12 +619,6 @@ def deduplicate_unpaired_sequences(
keep_rows.append(row_num)
for feature_name in feature_names:
if feature_name in msa_features:
if keep_rows:
chain[feature_name] = chain[feature_name][keep_rows]
else:
new_shape = list(chain[feature_name].shape)
new_shape[0] = 0
chain[feature_name] = np.zeros(new_shape,
dtype=chain[feature_name].dtype)
chain[feature_name] = chain[feature_name][keep_rows]
chain['num_alignments'] = np.array(chain['msa'].shape[0], dtype=np.int32)
return np_chains
36 changes: 21 additions & 15 deletions alphafold/data/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
import string
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Set

# Internal import (7716).


DeletionMatrix = Sequence[Sequence[int]]


Expand Down Expand Up @@ -271,24 +274,27 @@ def _keep_line(line: str, seqnames: Set[str]) -> bool:
return seqname in seqnames


def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str:
"""Truncates a stockholm file to a maximum number of sequences."""
def truncate_stockholm_msa(stockholm_msa_path: str, max_sequences: int) -> str:
"""Reads + truncates a Stockholm file while preventing excessive RAM usage."""
seqnames = set()
filtered_lines = []
for line in stockholm_msa.splitlines():
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break

for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)

return '\n'.join(filtered_lines) + '\n'
with open(stockholm_msa_path) as f:
for line in f:
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break

f.seek(0)
for line in f:
if _keep_line(line, seqnames):
filtered_lines.append(line)

return ''.join(filtered_lines)


def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
Expand Down
52 changes: 35 additions & 17 deletions alphafold/data/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,25 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:

def run_msa_tool(msa_runner, input_fasta_path: str, msa_out_path: str,
msa_format: str, use_precomputed_msas: bool,
max_sto_sequences: Optional[int] = None
) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first."""
if not use_precomputed_msas or not os.path.exists(msa_out_path):
result = msa_runner.query(input_fasta_path)[0]
if msa_format == 'sto' and max_sto_sequences is not None:
result = msa_runner.query(input_fasta_path, max_sto_sequences)[0] # pytype: disable=wrong-arg-count
else:
result = msa_runner.query(input_fasta_path)[0]
with open(msa_out_path, 'w') as f:
f.write(result[msa_format])
else:
logging.warning('Reading MSA from file %s', msa_out_path)
with open(msa_out_path, 'r') as f:
result = {msa_format: f.read()}
if msa_format == 'sto' and max_sto_sequences is not None:
precomputed_msa = parsers.truncate_stockholm_msa(
msa_out_path, max_sto_sequences)
result = {'sto': precomputed_msa}
else:
with open(msa_out_path, 'r') as f:
result = {msa_format: f.read()}
return result


Expand Down Expand Up @@ -157,18 +166,23 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:

uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
self.jackhmmer_uniref90_runner, input_fasta_path, uniref90_out_path,
'sto', self.use_precomputed_msas)
msa_runner=self.jackhmmer_uniref90_runner,
input_fasta_path=input_fasta_path,
msa_out_path=uniref90_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.uniref_max_hits)
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
jackhmmer_mgnify_result = run_msa_tool(
self.jackhmmer_mgnify_runner, input_fasta_path, mgnify_out_path, 'sto',
self.use_precomputed_msas)
msa_runner=self.jackhmmer_mgnify_runner,
input_fasta_path=input_fasta_path,
msa_out_path=mgnify_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas,
max_sto_sequences=self.mgnify_max_hits)

msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.truncate_stockholm_msa(
msa_for_templates, max_sequences=self.uniref_max_hits)
msa_for_templates = parsers.deduplicate_stockholm_msa(
msa_for_templates)
msa_for_templates = parsers.deduplicate_stockholm_msa(msa_for_templates)
msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
msa_for_templates)

Expand All @@ -187,24 +201,28 @@ def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
f.write(pdb_templates_result)

uniref90_msa = parsers.parse_stockholm(jackhmmer_uniref90_result['sto'])
uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)

pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)

if self._use_small_bfd:
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
jackhmmer_small_bfd_result = run_msa_tool(
self.jackhmmer_small_bfd_runner, input_fasta_path, bfd_out_path,
'sto', self.use_precomputed_msas)
msa_runner=self.jackhmmer_small_bfd_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='sto',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_stockholm(jackhmmer_small_bfd_result['sto'])
else:
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
hhblits_bfd_uniclust_result = run_msa_tool(
self.hhblits_bfd_uniclust_runner, input_fasta_path, bfd_out_path,
'a3m', self.use_precomputed_msas)
msa_runner=self.hhblits_bfd_uniclust_runner,
input_fasta_path=input_fasta_path,
msa_out_path=bfd_out_path,
msa_format='a3m',
use_precomputed_msas=self.use_precomputed_msas)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])

templates_result = self.template_featurizer.get_templates(
Expand Down
26 changes: 18 additions & 8 deletions alphafold/data/tools/jackhmmer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from absl import logging

from alphafold.data import parsers
from alphafold.data.tools import utils
# Internal import (7716).

Expand Down Expand Up @@ -86,8 +87,10 @@ def __init__(self,
self.get_tblout = get_tblout
self.streaming_callback = streaming_callback

def _query_chunk(self, input_fasta_path: str, database_path: str
) -> Mapping[str, Any]:
def _query_chunk(self,
input_fasta_path: str,
database_path: str,
max_sequences: Optional[int] = None) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager() as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, 'output.sto')
Expand Down Expand Up @@ -145,8 +148,11 @@ def _query_chunk(self, input_fasta_path: str, database_path: str
with open(tblout_path) as f:
tbl = f.read()

with open(sto_path) as f:
sto = f.read()
if max_sequences is None:
with open(sto_path) as f:
sto = f.read()
else:
sto = parsers.truncate_stockholm_msa(sto_path, max_sequences)

raw_output = dict(
sto=sto,
Expand All @@ -157,10 +163,14 @@ def _query_chunk(self, input_fasta_path: str, database_path: str

return raw_output

def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
def query(self,
input_fasta_path: str,
max_sequences: Optional[int] = None) -> Sequence[Mapping[str, Any]]:
"""Queries the database using Jackhmmer."""
if self.num_streamed_chunks is None:
return [self._query_chunk(input_fasta_path, self.database_path)]
single_chunk_result = self._query_chunk(
input_fasta_path, self.database_path, max_sequences)
return [single_chunk_result]

db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
Expand All @@ -187,8 +197,8 @@ def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:

# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
self._query_chunk(input_fasta_path, db_local_chunk(i)))
chunked_output.append(self._query_chunk(
input_fasta_path, db_local_chunk(i), max_sequences))

# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
Expand Down
2 changes: 1 addition & 1 deletion alphafold/model/folding_multimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __call__(


class InvariantPointAttention(hk.Module):
- """Covariant attention module.
+ """Invariant point attention module.
The high-level idea is that this attention module works over a set of points
and associated orientations in 3D space (e.g. protein residues).
Expand Down
15 changes: 11 additions & 4 deletions alphafold/relax/amber_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def _openmm_minimize(
tolerance: unit.Unit,
stiffness: unit.Unit,
restraint_set: str,
exclude_residues: Sequence[int]):
exclude_residues: Sequence[int],
use_gpu: bool):
"""Minimize energy via openmm."""

pdb_file = io.StringIO(pdb_str)
Expand All @@ -90,7 +91,7 @@ def _openmm_minimize(
_add_restraints(system, pdb, stiffness, restraint_set, exclude_residues)

integrator = openmm.LangevinIntegrator(0, 0.01, 0.0)
- platform = openmm.Platform.getPlatformByName("CPU")
+ platform = openmm.Platform.getPlatformByName("CUDA" if use_gpu else "CPU")
simulation = openmm_app.Simulation(
pdb.topology, system, integrator, platform)
simulation.context.setPositions(pdb.positions)
Expand Down Expand Up @@ -371,6 +372,7 @@ def _run_one_iteration(
stiffness: float,
restraint_set: str,
max_attempts: int,
use_gpu: bool,
exclude_residues: Optional[Collection[int]] = None):
"""Runs the minimization pipeline.
Expand All @@ -383,6 +385,7 @@ def _run_one_iteration(
potential.
restraint_set: The set of atoms to restrain.
max_attempts: The maximum number of minimization attempts.
use_gpu: Whether to run on GPU.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Expand All @@ -407,7 +410,8 @@ def _run_one_iteration(
pdb_string, max_iterations=max_iterations,
tolerance=tolerance, stiffness=stiffness,
restraint_set=restraint_set,
exclude_residues=exclude_residues)
exclude_residues=exclude_residues,
use_gpu=use_gpu)
minimized = True
except Exception as e: # pylint: disable=broad-except
logging.info(e)
Expand All @@ -421,6 +425,7 @@ def _run_one_iteration(
def run_pipeline(
prot: protein.Protein,
stiffness: float,
use_gpu: bool,
max_outer_iterations: int = 1,
place_hydrogens_every_iteration: bool = True,
max_iterations: int = 0,
Expand All @@ -438,6 +443,7 @@ def run_pipeline(
Args:
prot: A protein to be relaxed.
stiffness: kcal/mol A**2, the restraint stiffness.
use_gpu: Whether to run on GPU.
max_outer_iterations: The maximum number of iterative minimization.
place_hydrogens_every_iteration: Whether hydrogens are re-initialized
prior to every minimization.
Expand Down Expand Up @@ -473,7 +479,8 @@ def run_pipeline(
tolerance=tolerance,
stiffness=stiffness,
restraint_set=restraint_set,
max_attempts=max_attempts)
max_attempts=max_attempts,
use_gpu=use_gpu)
prot = protein.from_pdb_string(ret["min_pdb"])
if place_hydrogens_every_iteration:
pdb_string = clean_protein(prot, checks=True)
Expand Down
9 changes: 6 additions & 3 deletions alphafold/relax/amber_minimize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import numpy as np
# Internal import (7716).

_USE_GPU = False


def _load_test_protein(data_path):
pdb_path = os.path.join(absltest.get_default_test_srcdir(), data_path)
Expand All @@ -35,7 +37,7 @@ def test_multiple_disulfides_target(self):
'alphafold/relax/testdata/multiple_disulfides_target.pdb'
)
ret = amber_minimize.run_pipeline(prot, max_iterations=10, max_attempts=1,
stiffness=10.)
stiffness=10., use_gpu=_USE_GPU)
self.assertIn('opt_time', ret)
self.assertIn('min_attempts', ret)

Expand All @@ -50,7 +52,8 @@ def test_raises_invalid_protein_assertion(self):
' residues. This protein contains at least one residue with no atoms.'):
amber_minimize.run_pipeline(prot, max_iterations=10,
stiffness=1.,
max_attempts=1)
max_attempts=1,
use_gpu=_USE_GPU)

def test_iterative_relax(self):
prot = _load_test_protein(
Expand All @@ -59,7 +62,7 @@ def test_iterative_relax(self):
violations = amber_minimize.get_violation_metrics(prot)
self.assertGreater(violations['num_residue_violations'], 0)
out = amber_minimize.run_pipeline(
prot=prot, max_outer_iterations=10, stiffness=10.)
prot=prot, max_outer_iterations=10, stiffness=10., use_gpu=_USE_GPU)
self.assertLess(out['efinal'], out['einit'])
self.assertEqual(0, out['num_residue_violations'])

Expand Down
Loading

0 comments on commit b020cba

Please sign in to comment.