Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10481410master
@@ -99,6 +99,10 @@ class Models(object): | |||
team = 'team-multi-modal-similarity' | |||
video_clip = 'video-clip-multi-modal-embedding' | |||
# science models | |||
unifold = 'unifold' | |||
unifold_symmetry = 'unifold-symmetry' | |||
class TaskModels(object): | |||
# nlp task | |||
@@ -266,6 +270,9 @@ class Pipelines(object): | |||
image_text_retrieval = 'image-text-retrieval' | |||
ofa_ocr_recognition = 'ofa-ocr-recognition' | |||
# science tasks | |||
protein_structure = 'unifold-protein-structure' | |||
class Trainers(object): | |||
""" Names for different trainer. | |||
@@ -368,6 +375,9 @@ class Preprocessors(object): | |||
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' | |||
mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' | |||
# science preprocessor | |||
unifold_preprocessor = 'unifold-preprocessor' | |||
class Metrics(object): | |||
""" Names for different metrics. | |||
@@ -0,0 +1,21 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .unifold import UnifoldForProteinStructrue | |||
else: | |||
_import_structure = {'unifold': ['UnifoldForProteinStructrue']} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1 @@ | |||
from .model import UnifoldForProteinStructrue |
@@ -0,0 +1,636 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import copy | |||
from typing import Any | |||
import ml_collections as mlc | |||
N_RES = 'number of residues' | |||
N_MSA = 'number of MSA sequences' | |||
N_EXTRA_MSA = 'number of extra MSA sequences' | |||
N_TPL = 'number of templates' | |||
d_pair = mlc.FieldReference(128, field_type=int) | |||
d_msa = mlc.FieldReference(256, field_type=int) | |||
d_template = mlc.FieldReference(64, field_type=int) | |||
d_extra_msa = mlc.FieldReference(64, field_type=int) | |||
d_single = mlc.FieldReference(384, field_type=int) | |||
max_recycling_iters = mlc.FieldReference(3, field_type=int) | |||
chunk_size = mlc.FieldReference(4, field_type=int) | |||
aux_distogram_bins = mlc.FieldReference(64, field_type=int) | |||
eps = mlc.FieldReference(1e-8, field_type=float) | |||
inf = mlc.FieldReference(3e4, field_type=float) | |||
use_templates = mlc.FieldReference(True, field_type=bool) | |||
is_multimer = mlc.FieldReference(False, field_type=bool) | |||
def base_config(): | |||
return mlc.ConfigDict({ | |||
'data': { | |||
'common': { | |||
'features': { | |||
'aatype': [N_RES], | |||
'all_atom_mask': [N_RES, None], | |||
'all_atom_positions': [N_RES, None, None], | |||
'alt_chi_angles': [N_RES, None], | |||
'atom14_alt_gt_exists': [N_RES, None], | |||
'atom14_alt_gt_positions': [N_RES, None, None], | |||
'atom14_atom_exists': [N_RES, None], | |||
'atom14_atom_is_ambiguous': [N_RES, None], | |||
'atom14_gt_exists': [N_RES, None], | |||
'atom14_gt_positions': [N_RES, None, None], | |||
'atom37_atom_exists': [N_RES, None], | |||
'frame_mask': [N_RES], | |||
'true_frame_tensor': [N_RES, None, None], | |||
'bert_mask': [N_MSA, N_RES], | |||
'chi_angles_sin_cos': [N_RES, None, None], | |||
'chi_mask': [N_RES, None], | |||
'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES], | |||
'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES], | |||
'extra_msa': [N_EXTRA_MSA, N_RES], | |||
'extra_msa_mask': [N_EXTRA_MSA, N_RES], | |||
'extra_msa_row_mask': [N_EXTRA_MSA], | |||
'is_distillation': [], | |||
'msa_feat': [N_MSA, N_RES, None], | |||
'msa_mask': [N_MSA, N_RES], | |||
'msa_chains': [N_MSA, None], | |||
'msa_row_mask': [N_MSA], | |||
'num_recycling_iters': [], | |||
'pseudo_beta': [N_RES, None], | |||
'pseudo_beta_mask': [N_RES], | |||
'residue_index': [N_RES], | |||
'residx_atom14_to_atom37': [N_RES, None], | |||
'residx_atom37_to_atom14': [N_RES, None], | |||
'resolution': [], | |||
'rigidgroups_alt_gt_frames': [N_RES, None, None, None], | |||
'rigidgroups_group_exists': [N_RES, None], | |||
'rigidgroups_group_is_ambiguous': [N_RES, None], | |||
'rigidgroups_gt_exists': [N_RES, None], | |||
'rigidgroups_gt_frames': [N_RES, None, None, None], | |||
'seq_length': [], | |||
'seq_mask': [N_RES], | |||
'target_feat': [N_RES, None], | |||
'template_aatype': [N_TPL, N_RES], | |||
'template_all_atom_mask': [N_TPL, N_RES, None], | |||
'template_all_atom_positions': [N_TPL, N_RES, None, None], | |||
'template_alt_torsion_angles_sin_cos': [ | |||
N_TPL, | |||
N_RES, | |||
None, | |||
None, | |||
], | |||
'template_frame_mask': [N_TPL, N_RES], | |||
'template_frame_tensor': [N_TPL, N_RES, None, None], | |||
'template_mask': [N_TPL], | |||
'template_pseudo_beta': [N_TPL, N_RES, None], | |||
'template_pseudo_beta_mask': [N_TPL, N_RES], | |||
'template_sum_probs': [N_TPL, None], | |||
'template_torsion_angles_mask': [N_TPL, N_RES, None], | |||
'template_torsion_angles_sin_cos': | |||
[N_TPL, N_RES, None, None], | |||
'true_msa': [N_MSA, N_RES], | |||
'use_clamped_fape': [], | |||
'assembly_num_chains': [1], | |||
'asym_id': [N_RES], | |||
'sym_id': [N_RES], | |||
'entity_id': [N_RES], | |||
'num_sym': [N_RES], | |||
'asym_len': [None], | |||
'cluster_bias_mask': [N_MSA], | |||
}, | |||
'masked_msa': { | |||
'profile_prob': 0.1, | |||
'same_prob': 0.1, | |||
'uniform_prob': 0.1, | |||
}, | |||
'block_delete_msa': { | |||
'msa_fraction_per_block': 0.3, | |||
'randomize_num_blocks': False, | |||
'num_blocks': 5, | |||
'min_num_msa': 16, | |||
}, | |||
'random_delete_msa': { | |||
'max_msa_entry': 1 << 25, # := 33554432 | |||
}, | |||
'v2_feature': | |||
False, | |||
'gumbel_sample': | |||
False, | |||
'max_extra_msa': | |||
1024, | |||
'msa_cluster_features': | |||
True, | |||
'reduce_msa_clusters_by_max_templates': | |||
True, | |||
'resample_msa_in_recycling': | |||
True, | |||
'template_features': [ | |||
'template_all_atom_positions', | |||
'template_sum_probs', | |||
'template_aatype', | |||
'template_all_atom_mask', | |||
], | |||
'unsupervised_features': [ | |||
'aatype', | |||
'residue_index', | |||
'msa', | |||
'msa_chains', | |||
'num_alignments', | |||
'seq_length', | |||
'between_segment_residues', | |||
'deletion_matrix', | |||
'num_recycling_iters', | |||
'crop_and_fix_size_seed', | |||
], | |||
'recycling_features': [ | |||
'msa_chains', | |||
'msa_mask', | |||
'msa_row_mask', | |||
'bert_mask', | |||
'true_msa', | |||
'msa_feat', | |||
'extra_msa_deletion_value', | |||
'extra_msa_has_deletion', | |||
'extra_msa', | |||
'extra_msa_mask', | |||
'extra_msa_row_mask', | |||
'is_distillation', | |||
], | |||
'multimer_features': [ | |||
'assembly_num_chains', | |||
'asym_id', | |||
'sym_id', | |||
'num_sym', | |||
'entity_id', | |||
'asym_len', | |||
'cluster_bias_mask', | |||
], | |||
'use_templates': | |||
use_templates, | |||
'is_multimer': | |||
is_multimer, | |||
'use_template_torsion_angles': | |||
use_templates, | |||
'max_recycling_iters': | |||
max_recycling_iters, | |||
}, | |||
'supervised': { | |||
'use_clamped_fape_prob': | |||
1.0, | |||
'supervised_features': [ | |||
'all_atom_mask', | |||
'all_atom_positions', | |||
'resolution', | |||
'use_clamped_fape', | |||
'is_distillation', | |||
], | |||
}, | |||
'predict': { | |||
'fixed_size': True, | |||
'subsample_templates': False, | |||
'block_delete_msa': False, | |||
'random_delete_msa': True, | |||
'masked_msa_replace_fraction': 0.15, | |||
'max_msa_clusters': 128, | |||
'max_templates': 4, | |||
'num_ensembles': 2, | |||
'crop': False, | |||
'crop_size': None, | |||
'supervised': False, | |||
'biased_msa_by_chain': False, | |||
'share_mask': False, | |||
}, | |||
'eval': { | |||
'fixed_size': True, | |||
'subsample_templates': False, | |||
'block_delete_msa': False, | |||
'random_delete_msa': True, | |||
'masked_msa_replace_fraction': 0.15, | |||
'max_msa_clusters': 128, | |||
'max_templates': 4, | |||
'num_ensembles': 1, | |||
'crop': False, | |||
'crop_size': None, | |||
'spatial_crop_prob': 0.5, | |||
'ca_ca_threshold': 10.0, | |||
'supervised': True, | |||
'biased_msa_by_chain': False, | |||
'share_mask': False, | |||
}, | |||
'train': { | |||
'fixed_size': True, | |||
'subsample_templates': True, | |||
'block_delete_msa': True, | |||
'random_delete_msa': True, | |||
'masked_msa_replace_fraction': 0.15, | |||
'max_msa_clusters': 128, | |||
'max_templates': 4, | |||
'num_ensembles': 1, | |||
'crop': True, | |||
'crop_size': 256, | |||
'spatial_crop_prob': 0.5, | |||
'ca_ca_threshold': 10.0, | |||
'supervised': True, | |||
'use_clamped_fape_prob': 1.0, | |||
'max_distillation_msa_clusters': 1000, | |||
'biased_msa_by_chain': True, | |||
'share_mask': True, | |||
}, | |||
}, | |||
'globals': { | |||
'chunk_size': chunk_size, | |||
'block_size': None, | |||
'd_pair': d_pair, | |||
'd_msa': d_msa, | |||
'd_template': d_template, | |||
'd_extra_msa': d_extra_msa, | |||
'd_single': d_single, | |||
'eps': eps, | |||
'inf': inf, | |||
'max_recycling_iters': max_recycling_iters, | |||
'alphafold_original_mode': False, | |||
}, | |||
'model': { | |||
'is_multimer': is_multimer, | |||
'input_embedder': { | |||
'tf_dim': 22, | |||
'msa_dim': 49, | |||
'd_pair': d_pair, | |||
'd_msa': d_msa, | |||
'relpos_k': 32, | |||
'max_relative_chain': 2, | |||
}, | |||
'recycling_embedder': { | |||
'd_pair': d_pair, | |||
'd_msa': d_msa, | |||
'min_bin': 3.25, | |||
'max_bin': 20.75, | |||
'num_bins': 15, | |||
'inf': 1e8, | |||
}, | |||
'template': { | |||
'distogram': { | |||
'min_bin': 3.25, | |||
'max_bin': 50.75, | |||
'num_bins': 39, | |||
}, | |||
'template_angle_embedder': { | |||
'd_in': 57, | |||
'd_out': d_msa, | |||
}, | |||
'template_pair_embedder': { | |||
'd_in': 88, | |||
'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1], | |||
'd_pair': d_pair, | |||
'd_out': d_template, | |||
'v2_feature': False, | |||
}, | |||
'template_pair_stack': { | |||
'd_template': d_template, | |||
'd_hid_tri_att': 16, | |||
'd_hid_tri_mul': 64, | |||
'num_blocks': 2, | |||
'num_heads': 4, | |||
'pair_transition_n': 2, | |||
'dropout_rate': 0.25, | |||
'inf': 1e9, | |||
'tri_attn_first': True, | |||
}, | |||
'template_pointwise_attention': { | |||
'enabled': True, | |||
'd_template': d_template, | |||
'd_pair': d_pair, | |||
'd_hid': 16, | |||
'num_heads': 4, | |||
'inf': 1e5, | |||
}, | |||
'inf': 1e5, | |||
'eps': 1e-6, | |||
'enabled': use_templates, | |||
'embed_angles': use_templates, | |||
}, | |||
'extra_msa': { | |||
'extra_msa_embedder': { | |||
'd_in': 25, | |||
'd_out': d_extra_msa, | |||
}, | |||
'extra_msa_stack': { | |||
'd_msa': d_extra_msa, | |||
'd_pair': d_pair, | |||
'd_hid_msa_att': 8, | |||
'd_hid_opm': 32, | |||
'd_hid_mul': 128, | |||
'd_hid_pair_att': 32, | |||
'num_heads_msa': 8, | |||
'num_heads_pair': 4, | |||
'num_blocks': 4, | |||
'transition_n': 4, | |||
'msa_dropout': 0.15, | |||
'pair_dropout': 0.25, | |||
'inf': 1e9, | |||
'eps': 1e-10, | |||
'outer_product_mean_first': False, | |||
}, | |||
'enabled': True, | |||
}, | |||
'evoformer_stack': { | |||
'd_msa': d_msa, | |||
'd_pair': d_pair, | |||
'd_hid_msa_att': 32, | |||
'd_hid_opm': 32, | |||
'd_hid_mul': 128, | |||
'd_hid_pair_att': 32, | |||
'd_single': d_single, | |||
'num_heads_msa': 8, | |||
'num_heads_pair': 4, | |||
'num_blocks': 48, | |||
'transition_n': 4, | |||
'msa_dropout': 0.15, | |||
'pair_dropout': 0.25, | |||
'inf': 1e9, | |||
'eps': 1e-10, | |||
'outer_product_mean_first': False, | |||
}, | |||
'structure_module': { | |||
'd_single': d_single, | |||
'd_pair': d_pair, | |||
'd_ipa': 16, | |||
'd_angle': 128, | |||
'num_heads_ipa': 12, | |||
'num_qk_points': 4, | |||
'num_v_points': 8, | |||
'dropout_rate': 0.1, | |||
'num_blocks': 8, | |||
'no_transition_layers': 1, | |||
'num_resnet_blocks': 2, | |||
'num_angles': 7, | |||
'trans_scale_factor': 10, | |||
'epsilon': 1e-12, | |||
'inf': 1e5, | |||
'separate_kv': False, | |||
'ipa_bias': True, | |||
}, | |||
'heads': { | |||
'plddt': { | |||
'num_bins': 50, | |||
'd_in': d_single, | |||
'd_hid': 128, | |||
}, | |||
'distogram': { | |||
'd_pair': d_pair, | |||
'num_bins': aux_distogram_bins, | |||
'disable_enhance_head': False, | |||
}, | |||
'pae': { | |||
'd_pair': d_pair, | |||
'num_bins': aux_distogram_bins, | |||
'enabled': False, | |||
'iptm_weight': 0.8, | |||
'disable_enhance_head': False, | |||
}, | |||
'masked_msa': { | |||
'd_msa': d_msa, | |||
'd_out': 23, | |||
'disable_enhance_head': False, | |||
}, | |||
'experimentally_resolved': { | |||
'd_single': d_single, | |||
'd_out': 37, | |||
'enabled': False, | |||
'disable_enhance_head': False, | |||
}, | |||
}, | |||
}, | |||
'loss': { | |||
'distogram': { | |||
'min_bin': 2.3125, | |||
'max_bin': 21.6875, | |||
'num_bins': 64, | |||
'eps': 1e-6, | |||
'weight': 0.3, | |||
}, | |||
'experimentally_resolved': { | |||
'eps': 1e-8, | |||
'min_resolution': 0.1, | |||
'max_resolution': 3.0, | |||
'weight': 0.0, | |||
}, | |||
'fape': { | |||
'backbone': { | |||
'clamp_distance': 10.0, | |||
'clamp_distance_between_chains': 30.0, | |||
'loss_unit_distance': 10.0, | |||
'loss_unit_distance_between_chains': 20.0, | |||
'weight': 0.5, | |||
'eps': 1e-4, | |||
}, | |||
'sidechain': { | |||
'clamp_distance': 10.0, | |||
'length_scale': 10.0, | |||
'weight': 0.5, | |||
'eps': 1e-4, | |||
}, | |||
'weight': 1.0, | |||
}, | |||
'plddt': { | |||
'min_resolution': 0.1, | |||
'max_resolution': 3.0, | |||
'cutoff': 15.0, | |||
'num_bins': 50, | |||
'eps': 1e-10, | |||
'weight': 0.01, | |||
}, | |||
'masked_msa': { | |||
'eps': 1e-8, | |||
'weight': 2.0, | |||
}, | |||
'supervised_chi': { | |||
'chi_weight': 0.5, | |||
'angle_norm_weight': 0.01, | |||
'eps': 1e-6, | |||
'weight': 1.0, | |||
}, | |||
'violation': { | |||
'violation_tolerance_factor': 12.0, | |||
'clash_overlap_tolerance': 1.5, | |||
'bond_angle_loss_weight': 0.3, | |||
'eps': 1e-6, | |||
'weight': 0.0, | |||
}, | |||
'pae': { | |||
'max_bin': 31, | |||
'num_bins': 64, | |||
'min_resolution': 0.1, | |||
'max_resolution': 3.0, | |||
'eps': 1e-8, | |||
'weight': 0.0, | |||
}, | |||
'repr_norm': { | |||
'weight': 0.01, | |||
'tolerance': 1.0, | |||
}, | |||
'chain_centre_mass': { | |||
'weight': 0.0, | |||
'eps': 1e-8, | |||
}, | |||
}, | |||
}) | |||
def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None): | |||
with c.unlocked(): | |||
for k, v in c.items(): | |||
if ignore is not None and k == ignore: | |||
continue | |||
if isinstance(v, mlc.ConfigDict): | |||
recursive_set(v, key, value) | |||
elif k == key: | |||
c[k] = value | |||
def model_config(name, train=False): | |||
c = copy.deepcopy(base_config()) | |||
def model_2_v2(c): | |||
recursive_set(c, 'v2_feature', True) | |||
recursive_set(c, 'gumbel_sample', True) | |||
c.model.heads.masked_msa.d_out = 22 | |||
c.model.structure_module.separate_kv = True | |||
c.model.structure_module.ipa_bias = False | |||
c.model.template.template_angle_embedder.d_in = 34 | |||
return c | |||
def multimer(c): | |||
recursive_set(c, 'is_multimer', True) | |||
recursive_set(c, 'max_extra_msa', 1152) | |||
recursive_set(c, 'max_msa_clusters', 128) | |||
recursive_set(c, 'v2_feature', True) | |||
recursive_set(c, 'gumbel_sample', True) | |||
c.model.template.template_angle_embedder.d_in = 34 | |||
c.model.template.template_pair_stack.tri_attn_first = False | |||
c.model.template.template_pointwise_attention.enabled = False | |||
c.model.heads.pae.enabled = True | |||
# we forget to enable it in our training, so disable it here | |||
c.model.heads.pae.disable_enhance_head = True | |||
c.model.heads.masked_msa.d_out = 22 | |||
c.model.structure_module.separate_kv = True | |||
c.model.structure_module.ipa_bias = False | |||
c.model.structure_module.trans_scale_factor = 20 | |||
c.loss.pae.weight = 0.1 | |||
c.model.input_embedder.tf_dim = 21 | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.02 | |||
c.loss.chain_centre_mass.weight = 1.0 | |||
return c | |||
if name == 'model_1': | |||
pass | |||
elif name == 'model_1_ft': | |||
recursive_set(c, 'max_extra_msa', 5120) | |||
recursive_set(c, 'max_msa_clusters', 512) | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.02 | |||
elif name == 'model_1_af2': | |||
recursive_set(c, 'max_extra_msa', 5120) | |||
recursive_set(c, 'max_msa_clusters', 512) | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.02 | |||
c.loss.repr_norm.weight = 0 | |||
c.model.heads.experimentally_resolved.enabled = True | |||
c.loss.experimentally_resolved.weight = 0.01 | |||
c.globals.alphafold_original_mode = True | |||
elif name == 'model_2': | |||
pass | |||
elif name == 'model_init': | |||
pass | |||
elif name == 'model_init_af2': | |||
c.globals.alphafold_original_mode = True | |||
pass | |||
elif name == 'model_2_ft': | |||
recursive_set(c, 'max_extra_msa', 1024) | |||
recursive_set(c, 'max_msa_clusters', 512) | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.02 | |||
elif name == 'model_2_af2': | |||
recursive_set(c, 'max_extra_msa', 1024) | |||
recursive_set(c, 'max_msa_clusters', 512) | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.02 | |||
c.loss.repr_norm.weight = 0 | |||
c.model.heads.experimentally_resolved.enabled = True | |||
c.loss.experimentally_resolved.weight = 0.01 | |||
c.globals.alphafold_original_mode = True | |||
elif name == 'model_2_v2': | |||
c = model_2_v2(c) | |||
elif name == 'model_2_v2_ft': | |||
c = model_2_v2(c) | |||
recursive_set(c, 'max_extra_msa', 1024) | |||
recursive_set(c, 'max_msa_clusters', 512) | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.02 | |||
elif name == 'model_3_af2' or name == 'model_4_af2': | |||
recursive_set(c, 'max_extra_msa', 5120) | |||
recursive_set(c, 'max_msa_clusters', 512) | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.02 | |||
c.loss.repr_norm.weight = 0 | |||
c.model.heads.experimentally_resolved.enabled = True | |||
c.loss.experimentally_resolved.weight = 0.01 | |||
c.globals.alphafold_original_mode = True | |||
c.model.template.enabled = False | |||
c.model.template.embed_angles = False | |||
recursive_set(c, 'use_templates', False) | |||
recursive_set(c, 'use_template_torsion_angles', False) | |||
elif name == 'model_5_af2': | |||
recursive_set(c, 'max_extra_msa', 1024) | |||
recursive_set(c, 'max_msa_clusters', 512) | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.02 | |||
c.loss.repr_norm.weight = 0 | |||
c.model.heads.experimentally_resolved.enabled = True | |||
c.loss.experimentally_resolved.weight = 0.01 | |||
c.globals.alphafold_original_mode = True | |||
c.model.template.enabled = False | |||
c.model.template.embed_angles = False | |||
recursive_set(c, 'use_templates', False) | |||
recursive_set(c, 'use_template_torsion_angles', False) | |||
elif name == 'multimer': | |||
c = multimer(c) | |||
elif name == 'multimer_ft': | |||
c = multimer(c) | |||
recursive_set(c, 'max_extra_msa', 1152) | |||
recursive_set(c, 'max_msa_clusters', 256) | |||
c.data.train.crop_size = 384 | |||
c.loss.violation.weight = 0.5 | |||
elif name == 'multimer_af2': | |||
recursive_set(c, 'max_extra_msa', 1152) | |||
recursive_set(c, 'max_msa_clusters', 256) | |||
recursive_set(c, 'is_multimer', True) | |||
recursive_set(c, 'v2_feature', True) | |||
recursive_set(c, 'gumbel_sample', True) | |||
c.model.template.template_angle_embedder.d_in = 34 | |||
c.model.template.template_pair_stack.tri_attn_first = False | |||
c.model.template.template_pointwise_attention.enabled = False | |||
c.model.heads.pae.enabled = True | |||
c.model.heads.experimentally_resolved.enabled = True | |||
c.model.heads.masked_msa.d_out = 22 | |||
c.model.structure_module.separate_kv = True | |||
c.model.structure_module.ipa_bias = False | |||
c.model.structure_module.trans_scale_factor = 20 | |||
c.loss.pae.weight = 0.1 | |||
c.loss.violation.weight = 0.5 | |||
c.loss.experimentally_resolved.weight = 0.01 | |||
c.model.input_embedder.tf_dim = 21 | |||
c.globals.alphafold_original_mode = True | |||
c.data.train.crop_size = 384 | |||
c.loss.repr_norm.weight = 0 | |||
c.loss.chain_centre_mass.weight = 1.0 | |||
recursive_set(c, 'outer_product_mean_first', True) | |||
else: | |||
raise ValueError(f'invalid --model-name: {name}.') | |||
if train: | |||
c.globals.chunk_size = None | |||
recursive_set(c, 'inf', 3e4) | |||
recursive_set(c, 'eps', 1e-5, 'loss') | |||
return c |
@@ -0,0 +1,14 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Data pipeline for model features.""" |
@@ -0,0 +1,526 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Pairing logic for multimer data """ | |||
import collections | |||
from typing import Dict, Iterable, List, Sequence | |||
import numpy as np | |||
import pandas as pd | |||
import scipy.linalg | |||
from .data_ops import NumpyDict | |||
from .residue_constants import restypes_with_x_and_gap | |||
MSA_GAP_IDX = restypes_with_x_and_gap.index('-') | |||
SEQUENCE_GAP_CUTOFF = 0.5 | |||
SEQUENCE_SIMILARITY_CUTOFF = 0.9 | |||
MSA_PAD_VALUES = { | |||
'msa_all_seq': MSA_GAP_IDX, | |||
'msa_mask_all_seq': 1, | |||
'deletion_matrix_all_seq': 0, | |||
'deletion_matrix_int_all_seq': 0, | |||
'msa': MSA_GAP_IDX, | |||
'msa_mask': 1, | |||
'deletion_matrix': 0, | |||
'deletion_matrix_int': 0, | |||
} | |||
MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int') | |||
SEQ_FEATURES = ( | |||
'residue_index', | |||
'aatype', | |||
'all_atom_positions', | |||
'all_atom_mask', | |||
'seq_mask', | |||
'between_segment_residues', | |||
'has_alt_locations', | |||
'has_hetatoms', | |||
'asym_id', | |||
'entity_id', | |||
'sym_id', | |||
'entity_mask', | |||
'deletion_mean', | |||
'prediction_atom_mask', | |||
'literature_positions', | |||
'atom_indices_to_group_indices', | |||
'rigid_group_default_frame', | |||
# zy | |||
'num_sym', | |||
) | |||
TEMPLATE_FEATURES = ( | |||
'template_aatype', | |||
'template_all_atom_positions', | |||
'template_all_atom_mask', | |||
) | |||
CHAIN_FEATURES = ('num_alignments', 'seq_length') | |||
def create_paired_features(chains: Iterable[NumpyDict], ) -> List[NumpyDict]: | |||
"""Returns the original chains with paired NUM_SEQ features. | |||
Args: | |||
chains: A list of feature dictionaries for each chain. | |||
Returns: | |||
A list of feature dictionaries with sequence features including only | |||
rows to be paired. | |||
""" | |||
chains = list(chains) | |||
chain_keys = chains[0].keys() | |||
if len(chains) < 2: | |||
return chains | |||
else: | |||
updated_chains = [] | |||
paired_chains_to_paired_row_indices = pair_sequences(chains) | |||
paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices) | |||
for chain_num, chain in enumerate(chains): | |||
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k} | |||
for feature_name in chain_keys: | |||
if feature_name.endswith('_all_seq'): | |||
feats_padded = pad_features(chain[feature_name], | |||
feature_name) | |||
new_chain[feature_name] = feats_padded[ | |||
paired_rows[:, chain_num]] | |||
new_chain['num_alignments_all_seq'] = np.asarray( | |||
len(paired_rows[:, chain_num])) | |||
updated_chains.append(new_chain) | |||
return updated_chains | |||
def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: | |||
"""Add a 'padding' row at the end of the features list. | |||
The padding row will be selected as a 'paired' row in the case of partial | |||
alignment - for the chain that doesn't have paired alignment. | |||
Args: | |||
feature: The feature to be padded. | |||
feature_name: The name of the feature to be padded. | |||
Returns: | |||
The feature with an additional padding row. | |||
""" | |||
assert feature.dtype != np.dtype(np.string_) | |||
if feature_name in ( | |||
'msa_all_seq', | |||
'msa_mask_all_seq', | |||
'deletion_matrix_all_seq', | |||
'deletion_matrix_int_all_seq', | |||
): | |||
num_res = feature.shape[1] | |||
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], | |||
feature.dtype) | |||
elif feature_name == 'msa_species_identifiers_all_seq': | |||
padding = [b''] | |||
else: | |||
return feature | |||
feats_padded = np.concatenate([feature, padding], axis=0) | |||
return feats_padded | |||
def _make_msa_df(chain_features: NumpyDict) -> pd.DataFrame: | |||
"""Makes dataframe with msa features needed for msa pairing.""" | |||
chain_msa = chain_features['msa_all_seq'] | |||
query_seq = chain_msa[0] | |||
per_seq_similarity = np.sum( | |||
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq)) | |||
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq)) | |||
msa_df = pd.DataFrame({ | |||
'msa_species_identifiers': | |||
chain_features['msa_species_identifiers_all_seq'], | |||
'msa_row': | |||
np.arange(len(chain_features['msa_species_identifiers_all_seq'])), | |||
'msa_similarity': | |||
per_seq_similarity, | |||
'gap': | |||
per_seq_gap, | |||
}) | |||
return msa_df | |||
def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]: | |||
"""Creates mapping from species to msa dataframe of that species.""" | |||
species_lookup = {} | |||
for species, species_df in msa_df.groupby('msa_species_identifiers'): | |||
species_lookup[species] = species_df | |||
return species_lookup | |||
def _match_rows_by_sequence_similarity( | |||
this_species_msa_dfs: List[pd.DataFrame], ) -> List[List[int]]: # noqa | |||
"""Finds MSA sequence pairings across chains based on sequence similarity. | |||
Each chain's MSA sequences are first sorted by their sequence similarity to | |||
their respective target sequence. The sequences are then paired, starting | |||
from the sequences most similar to their target sequence. | |||
Args: | |||
this_species_msa_dfs: a list of dataframes containing MSA features for | |||
sequences for a specific species. | |||
Returns: | |||
A list of lists, each containing M indices corresponding to paired MSA rows, | |||
where M is the number of chains. | |||
""" | |||
all_paired_msa_rows = [] | |||
num_seqs = [ | |||
len(species_df) for species_df in this_species_msa_dfs | |||
if species_df is not None | |||
] | |||
take_num_seqs = np.min(num_seqs) | |||
# sort_by_similarity = lambda x: x.sort_values( | |||
# 'msa_similarity', axis=0, ascending=False) | |||
def sort_by_similarity(x): | |||
return x.sort_values('msa_similarity', axis=0, ascending=False) | |||
for species_df in this_species_msa_dfs: | |||
if species_df is not None: | |||
species_df_sorted = sort_by_similarity(species_df) | |||
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values | |||
else: | |||
msa_rows = [-1] * take_num_seqs # take the last 'padding' row | |||
all_paired_msa_rows.append(msa_rows) | |||
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose()) | |||
return all_paired_msa_rows | |||
def pair_sequences(examples: List[NumpyDict]) -> Dict[int, np.ndarray]: | |||
"""Returns indices for paired MSA sequences across chains.""" | |||
num_examples = len(examples) | |||
all_chain_species_dict = [] | |||
common_species = set() | |||
for chain_features in examples: | |||
msa_df = _make_msa_df(chain_features) | |||
species_dict = _create_species_dict(msa_df) | |||
all_chain_species_dict.append(species_dict) | |||
common_species.update(set(species_dict)) | |||
common_species = sorted(common_species) | |||
common_species.remove(b'') # Remove target sequence species. | |||
all_paired_msa_rows = [np.zeros(len(examples), int)] | |||
all_paired_msa_rows_dict = {k: [] for k in range(num_examples)} | |||
all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)] | |||
for species in common_species: | |||
if not species: | |||
continue | |||
this_species_msa_dfs = [] | |||
species_dfs_present = 0 | |||
for species_dict in all_chain_species_dict: | |||
if species in species_dict: | |||
this_species_msa_dfs.append(species_dict[species]) | |||
species_dfs_present += 1 | |||
else: | |||
this_species_msa_dfs.append(None) | |||
# Skip species that are present in only one chain. | |||
if species_dfs_present <= 1: | |||
continue | |||
if np.any( | |||
np.array([ | |||
len(species_df) for species_df in this_species_msa_dfs | |||
if isinstance(species_df, pd.DataFrame) | |||
]) > 600): | |||
continue | |||
paired_msa_rows = _match_rows_by_sequence_similarity( | |||
this_species_msa_dfs) | |||
all_paired_msa_rows.extend(paired_msa_rows) | |||
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) | |||
all_paired_msa_rows_dict = { | |||
num_examples: np.array(paired_msa_rows) | |||
for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items() | |||
} | |||
return all_paired_msa_rows_dict | |||
def reorder_paired_rows( | |||
all_paired_msa_rows_dict: Dict[int, np.ndarray]) -> np.ndarray: | |||
"""Creates a list of indices of paired MSA rows across chains. | |||
Args: | |||
all_paired_msa_rows_dict: a mapping from the number of paired chains to the | |||
paired indices. | |||
Returns: | |||
a list of lists, each containing indices of paired MSA rows across chains. | |||
The paired-index lists are ordered by: | |||
1) the number of chains in the paired alignment, i.e, all-chain pairings | |||
will come first. | |||
2) e-values | |||
""" | |||
all_paired_msa_rows = [] | |||
for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True): | |||
paired_rows = all_paired_msa_rows_dict[num_pairings] | |||
paired_rows_product = np.abs( | |||
np.array( | |||
[np.prod(rows.astype(np.float64)) for rows in paired_rows])) | |||
paired_rows_sort_index = np.argsort(paired_rows_product) | |||
all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index]) | |||
return np.array(all_paired_msa_rows) | |||
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray: | |||
"""Like scipy.linalg.block_diag but with an optional padding value.""" | |||
ones_arrs = [np.ones_like(x) for x in arrs] | |||
off_diag_mask = 1 - scipy.linalg.block_diag(*ones_arrs) | |||
diag = scipy.linalg.block_diag(*arrs) | |||
diag += (off_diag_mask * pad_value).astype(diag.dtype) | |||
return diag | |||
def _correct_post_merged_feats(np_example: NumpyDict, | |||
np_chains_list: Sequence[NumpyDict], | |||
pair_msa_sequences: bool) -> NumpyDict: | |||
"""Adds features that need to be computed/recomputed post merging.""" | |||
np_example['seq_length'] = np.asarray( | |||
np_example['aatype'].shape[0], dtype=np.int32) | |||
np_example['num_alignments'] = np.asarray( | |||
np_example['msa'].shape[0], dtype=np.int32) | |||
if not pair_msa_sequences: | |||
# Generate a bias that is 1 for the first row of every block in the | |||
# block diagonal MSA - i.e. make sure the cluster stack always includes | |||
# the query sequences for each chain (since the first row is the query | |||
# sequence). | |||
cluster_bias_masks = [] | |||
for chain in np_chains_list: | |||
mask = np.zeros(chain['msa'].shape[0]) | |||
mask[0] = 1 | |||
cluster_bias_masks.append(mask) | |||
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks) | |||
# Initialize Bert mask with masked out off diagonals. | |||
msa_masks = [ | |||
np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list | |||
] | |||
np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0) | |||
else: | |||
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0]) | |||
np_example['cluster_bias_mask'][0] = 1 | |||
# Initialize Bert mask with masked out off diagonals. | |||
msa_masks = [ | |||
np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list | |||
] | |||
msa_masks_all_seq = [ | |||
np.ones(x['msa_all_seq'].shape, dtype=np.int8) | |||
for x in np_chains_list | |||
] | |||
msa_mask_block_diag = block_diag(*msa_masks, pad_value=0) | |||
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1) | |||
np_example['bert_mask'] = np.concatenate( | |||
[msa_mask_all_seq, msa_mask_block_diag], axis=0) | |||
return np_example | |||
def _pad_templates(chains: Sequence[NumpyDict], | |||
max_templates: int) -> Sequence[NumpyDict]: | |||
"""For each chain pad the number of templates to a fixed size. | |||
Args: | |||
chains: A list of protein chains. | |||
max_templates: Each chain will be padded to have this many templates. | |||
Returns: | |||
The list of chains, updated to have template features padded to | |||
max_templates. | |||
""" | |||
for chain in chains: | |||
for k, v in chain.items(): | |||
if k in TEMPLATE_FEATURES: | |||
padding = np.zeros_like(v.shape) | |||
padding[0] = max_templates - v.shape[0] | |||
padding = [(0, p) for p in padding] | |||
chain[k] = np.pad(v, padding, mode='constant') | |||
return chains | |||
def _merge_features_from_multiple_chains( | |||
chains: Sequence[NumpyDict], pair_msa_sequences: bool) -> NumpyDict: | |||
"""Merge features from multiple chains. | |||
Args: | |||
chains: A list of feature dictionaries that we want to merge. | |||
pair_msa_sequences: Whether to concatenate MSA features along the | |||
num_res dimension (if True), or to block diagonalize them (if False). | |||
Returns: | |||
A feature dictionary for the merged example. | |||
""" | |||
merged_example = {} | |||
for feature_name in chains[0]: | |||
feats = [x[feature_name] for x in chains] | |||
feature_name_split = feature_name.split('_all_seq')[0] | |||
if feature_name_split in MSA_FEATURES: | |||
if pair_msa_sequences or '_all_seq' in feature_name: | |||
merged_example[feature_name] = np.concatenate(feats, axis=1) | |||
if feature_name_split == 'msa': | |||
merged_example['msa_chains_all_seq'] = np.ones( | |||
merged_example[feature_name].shape[0]).reshape(-1, 1) | |||
else: | |||
merged_example[feature_name] = block_diag( | |||
*feats, pad_value=MSA_PAD_VALUES[feature_name]) | |||
if feature_name_split == 'msa': | |||
msa_chains = [] | |||
for i, feat in enumerate(feats): | |||
cur_shape = feat.shape[0] | |||
vals = np.ones(cur_shape) * (i + 2) | |||
msa_chains.append(vals) | |||
merged_example['msa_chains'] = np.concatenate( | |||
msa_chains).reshape(-1, 1) | |||
elif feature_name_split in SEQ_FEATURES: | |||
merged_example[feature_name] = np.concatenate(feats, axis=0) | |||
elif feature_name_split in TEMPLATE_FEATURES: | |||
merged_example[feature_name] = np.concatenate(feats, axis=1) | |||
elif feature_name_split in CHAIN_FEATURES: | |||
merged_example[feature_name] = np.sum(feats).astype(np.int32) | |||
else: | |||
merged_example[feature_name] = feats[0] | |||
return merged_example | |||
def _merge_homomers_dense_msa( | |||
chains: Iterable[NumpyDict]) -> Sequence[NumpyDict]: | |||
"""Merge all identical chains, making the resulting MSA dense. | |||
Args: | |||
chains: An iterable of features for each chain. | |||
Returns: | |||
A list of feature dictionaries. All features with the same entity_id | |||
will be merged - MSA features will be concatenated along the num_res | |||
dimension - making them dense. | |||
""" | |||
entity_chains = collections.defaultdict(list) | |||
for chain in chains: | |||
entity_id = chain['entity_id'][0] | |||
entity_chains[entity_id].append(chain) | |||
grouped_chains = [] | |||
for entity_id in sorted(entity_chains): | |||
chains = entity_chains[entity_id] | |||
grouped_chains.append(chains) | |||
chains = [ | |||
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True) | |||
for chains in grouped_chains | |||
] | |||
return chains | |||
def _concatenate_paired_and_unpaired_features(example: NumpyDict) -> NumpyDict: | |||
"""Merges paired and block-diagonalised features.""" | |||
features = MSA_FEATURES + ('msa_chains', ) | |||
for feature_name in features: | |||
if feature_name in example: | |||
feat = example[feature_name] | |||
feat_all_seq = example[feature_name + '_all_seq'] | |||
try: | |||
merged_feat = np.concatenate([feat_all_seq, feat], axis=0) | |||
except Exception as ex: | |||
raise Exception( | |||
'concat failed.', | |||
feature_name, | |||
feat_all_seq.shape, | |||
feat.shape, | |||
ex.__class__, | |||
ex, | |||
) | |||
example[feature_name] = merged_feat | |||
example['num_alignments'] = np.array( | |||
example['msa'].shape[0], dtype=np.int32) | |||
return example | |||
def merge_chain_features(np_chains_list: List[NumpyDict], | |||
pair_msa_sequences: bool, | |||
max_templates: int) -> NumpyDict: | |||
"""Merges features for multiple chains to single FeatureDict. | |||
Args: | |||
np_chains_list: List of FeatureDicts for each chain. | |||
pair_msa_sequences: Whether to merge paired MSAs. | |||
max_templates: The maximum number of templates to include. | |||
Returns: | |||
Single FeatureDict for entire complex. | |||
""" | |||
np_chains_list = _pad_templates( | |||
np_chains_list, max_templates=max_templates) | |||
np_chains_list = _merge_homomers_dense_msa(np_chains_list) | |||
# Unpaired MSA features will be always block-diagonalised; paired MSA | |||
# features will be concatenated. | |||
np_example = _merge_features_from_multiple_chains( | |||
np_chains_list, pair_msa_sequences=False) | |||
if pair_msa_sequences: | |||
np_example = _concatenate_paired_and_unpaired_features(np_example) | |||
np_example = _correct_post_merged_feats( | |||
np_example=np_example, | |||
np_chains_list=np_chains_list, | |||
pair_msa_sequences=pair_msa_sequences, | |||
) | |||
return np_example | |||
def deduplicate_unpaired_sequences( | |||
np_chains: List[NumpyDict]) -> List[NumpyDict]: | |||
"""Removes unpaired sequences which duplicate a paired sequence.""" | |||
feature_names = np_chains[0].keys() | |||
msa_features = MSA_FEATURES | |||
cache_msa_features = {} | |||
for chain in np_chains: | |||
entity_id = int(chain['entity_id'][0]) | |||
if entity_id not in cache_msa_features: | |||
sequence_set = set(s.tobytes() for s in chain['msa_all_seq']) | |||
keep_rows = [] | |||
# Go through unpaired MSA seqs and remove any rows that correspond to the | |||
# sequences that are already present in the paired MSA. | |||
for row_num, seq in enumerate(chain['msa']): | |||
if seq.tobytes() not in sequence_set: | |||
keep_rows.append(row_num) | |||
new_msa_features = {} | |||
for feature_name in feature_names: | |||
if feature_name in msa_features: | |||
if keep_rows: | |||
new_msa_features[feature_name] = chain[feature_name][ | |||
keep_rows] | |||
else: | |||
new_shape = list(chain[feature_name].shape) | |||
new_shape[0] = 0 | |||
new_msa_features[feature_name] = np.zeros( | |||
new_shape, dtype=chain[feature_name].dtype) | |||
cache_msa_features[entity_id] = new_msa_features | |||
for feature_name in cache_msa_features[entity_id]: | |||
chain[feature_name] = cache_msa_features[entity_id][feature_name] | |||
chain['num_alignments'] = np.array( | |||
chain['msa'].shape[0], dtype=np.int32) | |||
return np_chains |
@@ -0,0 +1,264 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from typing import Optional | |||
import numpy as np | |||
import torch | |||
from modelscope.models.science.unifold.data import data_ops | |||
def nonensembled_fns(common_cfg, mode_cfg): | |||
"""Input pipeline data transformers that are not ensembled.""" | |||
v2_feature = common_cfg.v2_feature | |||
operators = [] | |||
if mode_cfg.random_delete_msa: | |||
operators.append( | |||
data_ops.random_delete_msa(common_cfg.random_delete_msa)) | |||
operators.extend([ | |||
data_ops.cast_to_64bit_ints, | |||
data_ops.correct_msa_restypes, | |||
data_ops.squeeze_features, | |||
data_ops.randomly_replace_msa_with_unknown(0.0), | |||
data_ops.make_seq_mask, | |||
data_ops.make_msa_mask, | |||
]) | |||
operators.append(data_ops.make_hhblits_profile_v2 | |||
if v2_feature else data_ops.make_hhblits_profile) | |||
if common_cfg.use_templates: | |||
operators.extend([ | |||
data_ops.make_template_mask, | |||
data_ops.make_pseudo_beta('template_'), | |||
]) | |||
operators.append( | |||
data_ops.crop_templates( | |||
max_templates=mode_cfg.max_templates, | |||
subsample_templates=mode_cfg.subsample_templates, | |||
)) | |||
if common_cfg.use_template_torsion_angles: | |||
operators.extend([ | |||
data_ops.atom37_to_torsion_angles('template_'), | |||
]) | |||
operators.append(data_ops.make_atom14_masks) | |||
operators.append(data_ops.make_target_feat) | |||
return operators | |||
def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed): | |||
operators = [] | |||
if common_cfg.reduce_msa_clusters_by_max_templates: | |||
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates | |||
else: | |||
pad_msa_clusters = mode_cfg.max_msa_clusters | |||
crop_feats = dict(common_cfg.features) | |||
if mode_cfg.fixed_size: | |||
if mode_cfg.crop: | |||
if common_cfg.is_multimer: | |||
crop_fn = data_ops.crop_to_size_multimer( | |||
crop_size=mode_cfg.crop_size, | |||
shape_schema=crop_feats, | |||
seed=crop_and_fix_size_seed, | |||
spatial_crop_prob=mode_cfg.spatial_crop_prob, | |||
ca_ca_threshold=mode_cfg.ca_ca_threshold, | |||
) | |||
else: | |||
crop_fn = data_ops.crop_to_size_single( | |||
crop_size=mode_cfg.crop_size, | |||
shape_schema=crop_feats, | |||
seed=crop_and_fix_size_seed, | |||
) | |||
operators.append(crop_fn) | |||
operators.append(data_ops.select_feat(crop_feats)) | |||
operators.append( | |||
data_ops.make_fixed_size( | |||
crop_feats, | |||
pad_msa_clusters, | |||
common_cfg.max_extra_msa, | |||
mode_cfg.crop_size, | |||
mode_cfg.max_templates, | |||
)) | |||
return operators | |||
def ensembled_fns(common_cfg, mode_cfg): | |||
"""Input pipeline data transformers that can be ensembled and averaged.""" | |||
operators = [] | |||
multimer_mode = common_cfg.is_multimer | |||
v2_feature = common_cfg.v2_feature | |||
# multimer don't use block delete msa | |||
if mode_cfg.block_delete_msa and not multimer_mode: | |||
operators.append( | |||
data_ops.block_delete_msa(common_cfg.block_delete_msa)) | |||
if 'max_distillation_msa_clusters' in mode_cfg: | |||
operators.append( | |||
data_ops.sample_msa_distillation( | |||
mode_cfg.max_distillation_msa_clusters)) | |||
if common_cfg.reduce_msa_clusters_by_max_templates: | |||
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates | |||
else: | |||
pad_msa_clusters = mode_cfg.max_msa_clusters | |||
max_msa_clusters = pad_msa_clusters | |||
max_extra_msa = common_cfg.max_extra_msa | |||
assert common_cfg.resample_msa_in_recycling | |||
gumbel_sample = common_cfg.gumbel_sample | |||
operators.append( | |||
data_ops.sample_msa( | |||
max_msa_clusters, | |||
keep_extra=True, | |||
gumbel_sample=gumbel_sample, | |||
biased_msa_by_chain=mode_cfg.biased_msa_by_chain, | |||
)) | |||
if 'masked_msa' in common_cfg: | |||
# Masked MSA should come *before* MSA clustering so that | |||
# the clustering and full MSA profile do not leak information about | |||
# the masked locations and secret corrupted locations. | |||
operators.append( | |||
data_ops.make_masked_msa( | |||
common_cfg.masked_msa, | |||
mode_cfg.masked_msa_replace_fraction, | |||
gumbel_sample=gumbel_sample, | |||
share_mask=mode_cfg.share_mask, | |||
)) | |||
if common_cfg.msa_cluster_features: | |||
if v2_feature: | |||
operators.append(data_ops.nearest_neighbor_clusters_v2()) | |||
else: | |||
operators.append(data_ops.nearest_neighbor_clusters()) | |||
operators.append(data_ops.summarize_clusters) | |||
if v2_feature: | |||
operators.append(data_ops.make_msa_feat_v2) | |||
else: | |||
operators.append(data_ops.make_msa_feat) | |||
# Crop after creating the cluster profiles. | |||
if max_extra_msa: | |||
if v2_feature: | |||
operators.append(data_ops.make_extra_msa_feat(max_extra_msa)) | |||
else: | |||
operators.append(data_ops.crop_extra_msa(max_extra_msa)) | |||
else: | |||
operators.append(data_ops.delete_extra_msa) | |||
# operators.append(data_operators.select_feat(common_cfg.recycling_features)) | |||
return operators | |||
def process_features(tensors, common_cfg, mode_cfg): | |||
"""Based on the config, apply filters and transformations to the data.""" | |||
is_distillation = bool(tensors.get('is_distillation', 0)) | |||
multimer_mode = common_cfg.is_multimer | |||
crop_and_fix_size_seed = int(tensors['crop_and_fix_size_seed']) | |||
crop_fn = crop_and_fix_size_fns( | |||
common_cfg, | |||
mode_cfg, | |||
crop_and_fix_size_seed, | |||
) | |||
def wrap_ensemble_fn(data, i): | |||
"""Function to be mapped over the ensemble dimension.""" | |||
d = data.copy() | |||
fns = ensembled_fns( | |||
common_cfg, | |||
mode_cfg, | |||
) | |||
new_d = compose(fns)(d) | |||
if not multimer_mode or is_distillation: | |||
new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d) | |||
return compose(crop_fn)(new_d) | |||
else: # select after crop for spatial cropping | |||
d = compose(crop_fn)(d) | |||
d = data_ops.select_feat(common_cfg.recycling_features)(d) | |||
return d | |||
nonensembled = nonensembled_fns(common_cfg, mode_cfg) | |||
if mode_cfg.supervised and (not multimer_mode or is_distillation): | |||
nonensembled.extend(label_transform_fn()) | |||
tensors = compose(nonensembled)(tensors) | |||
num_recycling = int(tensors['num_recycling_iters']) + 1 | |||
num_ensembles = mode_cfg.num_ensembles | |||
ensemble_tensors = map_fn( | |||
lambda x: wrap_ensemble_fn(tensors, x), | |||
torch.arange(num_recycling * num_ensembles), | |||
) | |||
tensors = compose(crop_fn)(tensors) | |||
# add a dummy dim to align with recycling features | |||
tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors} | |||
tensors.update(ensemble_tensors) | |||
return tensors | |||
@data_ops.curry1 | |||
def compose(x, fs): | |||
for f in fs: | |||
x = f(x) | |||
return x | |||
def pad_then_stack(values, ): | |||
if len(values[0].shape) >= 1: | |||
size = max(v.shape[0] for v in values) | |||
new_values = [] | |||
for v in values: | |||
if v.shape[0] < size: | |||
res = values[0].new_zeros(size, *v.shape[1:]) | |||
res[:v.shape[0], ...] = v | |||
else: | |||
res = v | |||
new_values.append(res) | |||
else: | |||
new_values = values | |||
return torch.stack(new_values, dim=0) | |||
def map_fn(fun, x): | |||
ensembles = [fun(elem) for elem in x] | |||
features = ensembles[0].keys() | |||
ensembled_dict = {} | |||
for feat in features: | |||
ensembled_dict[feat] = pad_then_stack( | |||
[dict_i[feat] for dict_i in ensembles]) | |||
return ensembled_dict | |||
def process_single_label(label: dict, | |||
num_ensemble: Optional[int] = None) -> dict: | |||
assert 'aatype' in label | |||
assert 'all_atom_positions' in label | |||
assert 'all_atom_mask' in label | |||
label = compose(label_transform_fn())(label) | |||
if num_ensemble is not None: | |||
label = { | |||
k: torch.stack([v for _ in range(num_ensemble)]) | |||
for k, v in label.items() | |||
} | |||
return label | |||
def process_labels(labels_list, num_ensemble: Optional[int] = None): | |||
return [process_single_label(ll, num_ensemble) for ll in labels_list] | |||
def label_transform_fn(): | |||
return [ | |||
data_ops.make_atom14_masks, | |||
data_ops.make_atom14_positions, | |||
data_ops.atom37_to_frames, | |||
data_ops.atom37_to_torsion_angles(''), | |||
data_ops.make_pseudo_beta(''), | |||
data_ops.get_backbone_frames, | |||
data_ops.get_chi_angles, | |||
] |
@@ -0,0 +1,417 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Feature processing logic for multimer data """ | |||
import collections | |||
from typing import Iterable, List, MutableMapping | |||
import numpy as np | |||
from modelscope.models.science.unifold.data import (msa_pairing, | |||
residue_constants) | |||
from .utils import correct_template_restypes | |||
FeatureDict = MutableMapping[str, np.ndarray] | |||
REQUIRED_FEATURES = frozenset({ | |||
'aatype', | |||
'all_atom_mask', | |||
'all_atom_positions', | |||
'all_chains_entity_ids', | |||
'all_crops_all_chains_mask', | |||
'all_crops_all_chains_positions', | |||
'all_crops_all_chains_residue_ids', | |||
'assembly_num_chains', | |||
'asym_id', | |||
'bert_mask', | |||
'cluster_bias_mask', | |||
'deletion_matrix', | |||
'deletion_mean', | |||
'entity_id', | |||
'entity_mask', | |||
'mem_peak', | |||
'msa', | |||
'msa_mask', | |||
'num_alignments', | |||
'num_templates', | |||
'queue_size', | |||
'residue_index', | |||
'resolution', | |||
'seq_length', | |||
'seq_mask', | |||
'sym_id', | |||
'template_aatype', | |||
'template_all_atom_mask', | |||
'template_all_atom_positions', | |||
# zy added: | |||
'asym_len', | |||
'template_sum_probs', | |||
'num_sym', | |||
'msa_chains', | |||
}) | |||
MAX_TEMPLATES = 4 | |||
MSA_CROP_SIZE = 2048 | |||
def _is_homomer_or_monomer(chains: Iterable[FeatureDict]) -> bool: | |||
"""Checks if a list of chains represents a homomer/monomer example.""" | |||
# Note that an entity_id of 0 indicates padding. | |||
num_unique_chains = len( | |||
np.unique( | |||
np.concatenate([ | |||
np.unique(chain['entity_id'][chain['entity_id'] > 0]) | |||
for chain in chains | |||
]))) | |||
return num_unique_chains == 1 | |||
def pair_and_merge( | |||
all_chain_features: MutableMapping[str, FeatureDict]) -> FeatureDict: | |||
"""Runs processing on features to augment, pair and merge. | |||
Args: | |||
all_chain_features: A MutableMap of dictionaries of features for each chain. | |||
Returns: | |||
A dictionary of features. | |||
""" | |||
process_unmerged_features(all_chain_features) | |||
np_chains_list = all_chain_features | |||
pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list) | |||
if pair_msa_sequences: | |||
np_chains_list = msa_pairing.create_paired_features( | |||
chains=np_chains_list) | |||
np_chains_list = msa_pairing.deduplicate_unpaired_sequences( | |||
np_chains_list) | |||
np_chains_list = crop_chains( | |||
np_chains_list, | |||
msa_crop_size=MSA_CROP_SIZE, | |||
pair_msa_sequences=pair_msa_sequences, | |||
max_templates=MAX_TEMPLATES, | |||
) | |||
np_example = msa_pairing.merge_chain_features( | |||
np_chains_list=np_chains_list, | |||
pair_msa_sequences=pair_msa_sequences, | |||
max_templates=MAX_TEMPLATES, | |||
) | |||
np_example = process_final(np_example) | |||
return np_example | |||
def crop_chains( | |||
chains_list: List[FeatureDict], | |||
msa_crop_size: int, | |||
pair_msa_sequences: bool, | |||
max_templates: int, | |||
) -> List[FeatureDict]: | |||
"""Crops the MSAs for a set of chains. | |||
Args: | |||
chains_list: A list of chains to be cropped. | |||
msa_crop_size: The total number of sequences to crop from the MSA. | |||
pair_msa_sequences: Whether we are operating in sequence-pairing mode. | |||
max_templates: The maximum templates to use per chain. | |||
Returns: | |||
The chains cropped. | |||
""" | |||
# Apply the cropping. | |||
cropped_chains = [] | |||
for chain in chains_list: | |||
cropped_chain = _crop_single_chain( | |||
chain, | |||
msa_crop_size=msa_crop_size, | |||
pair_msa_sequences=pair_msa_sequences, | |||
max_templates=max_templates, | |||
) | |||
cropped_chains.append(cropped_chain) | |||
return cropped_chains | |||
def _crop_single_chain(chain: FeatureDict, msa_crop_size: int, | |||
pair_msa_sequences: bool, | |||
max_templates: int) -> FeatureDict: | |||
"""Crops msa sequences to `msa_crop_size`.""" | |||
msa_size = chain['num_alignments'] | |||
if pair_msa_sequences: | |||
msa_size_all_seq = chain['num_alignments_all_seq'] | |||
msa_crop_size_all_seq = np.minimum(msa_size_all_seq, | |||
msa_crop_size // 2) | |||
# We reduce the number of un-paired sequences, by the number of times a | |||
# sequence from this chain's MSA is included in the paired MSA. This keeps | |||
# the MSA size for each chain roughly constant. | |||
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :] | |||
num_non_gapped_pairs = np.sum( | |||
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)) | |||
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, | |||
msa_crop_size_all_seq) | |||
# Restrict the unpaired crop size so that paired+unpaired sequences do not | |||
# exceed msa_seqs_per_chain for each chain. | |||
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0) | |||
msa_crop_size = np.minimum(msa_size, max_msa_crop_size) | |||
else: | |||
msa_crop_size = np.minimum(msa_size, msa_crop_size) | |||
include_templates = 'template_aatype' in chain and max_templates | |||
if include_templates: | |||
num_templates = chain['template_aatype'].shape[0] | |||
templates_crop_size = np.minimum(num_templates, max_templates) | |||
for k in chain: | |||
k_split = k.split('_all_seq')[0] | |||
if k_split in msa_pairing.TEMPLATE_FEATURES: | |||
chain[k] = chain[k][:templates_crop_size, :] | |||
elif k_split in msa_pairing.MSA_FEATURES: | |||
if '_all_seq' in k and pair_msa_sequences: | |||
chain[k] = chain[k][:msa_crop_size_all_seq, :] | |||
else: | |||
chain[k] = chain[k][:msa_crop_size, :] | |||
chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32) | |||
if include_templates: | |||
chain['num_templates'] = np.asarray( | |||
templates_crop_size, dtype=np.int32) | |||
if pair_msa_sequences: | |||
chain['num_alignments_all_seq'] = np.asarray( | |||
msa_crop_size_all_seq, dtype=np.int32) | |||
return chain | |||
def process_final(np_example: FeatureDict) -> FeatureDict: | |||
"""Final processing steps in data pipeline, after merging and pairing.""" | |||
np_example = _make_seq_mask(np_example) | |||
np_example = _make_msa_mask(np_example) | |||
np_example = _filter_features(np_example) | |||
return np_example | |||
def _make_seq_mask(np_example): | |||
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32) | |||
return np_example | |||
def _make_msa_mask(np_example): | |||
"""Mask features are all ones, but will later be zero-padded.""" | |||
np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.int8) | |||
seq_mask = (np_example['entity_id'] > 0).astype(np.int8) | |||
np_example['msa_mask'] *= seq_mask[None] | |||
return np_example | |||
def _filter_features(np_example: FeatureDict) -> FeatureDict: | |||
"""Filters features of example to only those requested.""" | |||
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES} | |||
def process_unmerged_features(all_chain_features: MutableMapping[str, | |||
FeatureDict]): | |||
"""Postprocessing stage for per-chain features before merging.""" | |||
num_chains = len(all_chain_features) | |||
for chain_features in all_chain_features: | |||
# Convert deletion matrices to float. | |||
if 'deletion_matrix_int' in chain_features: | |||
chain_features['deletion_matrix'] = np.asarray( | |||
chain_features.pop('deletion_matrix_int'), dtype=np.float32) | |||
if 'deletion_matrix_int_all_seq' in chain_features: | |||
chain_features['deletion_matrix_all_seq'] = np.asarray( | |||
chain_features.pop('deletion_matrix_int_all_seq'), | |||
dtype=np.float32) | |||
chain_features['deletion_mean'] = np.mean( | |||
chain_features['deletion_matrix'], axis=0) | |||
if 'all_atom_positions' not in chain_features: | |||
# Add all_atom_mask and dummy all_atom_positions based on aatype. | |||
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[ | |||
chain_features['aatype']] | |||
chain_features['all_atom_mask'] = all_atom_mask | |||
chain_features['all_atom_positions'] = np.zeros( | |||
list(all_atom_mask.shape) + [3]) | |||
# Add assembly_num_chains. | |||
chain_features['assembly_num_chains'] = np.asarray(num_chains) | |||
# Add entity_mask. | |||
for chain_features in all_chain_features: | |||
chain_features['entity_mask'] = ( | |||
chain_features['entity_id'] != # noqa W504 | |||
0).astype(np.int32) | |||
def empty_template_feats(n_res): | |||
return { | |||
'template_aatype': | |||
np.zeros((0, n_res)).astype(np.int64), | |||
'template_all_atom_positions': | |||
np.zeros((0, n_res, 37, 3)).astype(np.float32), | |||
'template_sum_probs': | |||
np.zeros((0, 1)).astype(np.float32), | |||
'template_all_atom_mask': | |||
np.zeros((0, n_res, 37)).astype(np.float32), | |||
} | |||
def convert_monomer_features(monomer_features: FeatureDict) -> FeatureDict: | |||
"""Reshapes and modifies monomer features for multimer models.""" | |||
if monomer_features['template_aatype'].shape[0] == 0: | |||
monomer_features.update( | |||
empty_template_feats(monomer_features['aatype'].shape[0])) | |||
converted = {} | |||
unnecessary_leading_dim_feats = { | |||
'sequence', | |||
'domain_name', | |||
'num_alignments', | |||
'seq_length', | |||
} | |||
for feature_name, feature in monomer_features.items(): | |||
if feature_name in unnecessary_leading_dim_feats: | |||
# asarray ensures it's a np.ndarray. | |||
feature = np.asarray(feature[0], dtype=feature.dtype) | |||
elif feature_name == 'aatype': | |||
# The multimer model performs the one-hot operation itself. | |||
feature = np.argmax(feature, axis=-1).astype(np.int32) | |||
elif feature_name == 'template_aatype': | |||
if feature.shape[0] > 0: | |||
feature = correct_template_restypes(feature) | |||
elif feature_name == 'template_all_atom_masks': | |||
feature_name = 'template_all_atom_mask' | |||
elif feature_name == 'msa': | |||
feature = feature.astype(np.uint8) | |||
if feature_name.endswith('_mask'): | |||
feature = feature.astype(np.float32) | |||
converted[feature_name] = feature | |||
if 'deletion_matrix_int' in monomer_features: | |||
monomer_features['deletion_matrix'] = monomer_features.pop( | |||
'deletion_matrix_int').astype(np.float32) | |||
converted.pop( | |||
'template_sum_probs' | |||
) # zy: this input is checked to be dirty in shape. TODO: figure out why and make it right. | |||
return converted | |||
def int_id_to_str_id(num: int) -> str: | |||
"""Encodes a number as a string, using reverse spreadsheet style naming. | |||
Args: | |||
num: A positive integer. | |||
Returns: | |||
A string that encodes the positive integer using reverse spreadsheet style, | |||
naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the | |||
usual way to encode chain IDs in mmCIF files. | |||
""" | |||
if num <= 0: | |||
raise ValueError(f'Only positive integers allowed, got {num}.') | |||
num = num - 1 # 1-based indexing. | |||
output = [] | |||
while num >= 0: | |||
output.append(chr(num % 26 + ord('A'))) | |||
num = num // 26 - 1 | |||
return ''.join(output) | |||
def add_assembly_features(all_chain_features, ): | |||
"""Add features to distinguish between chains. | |||
Args: | |||
all_chain_features: A dictionary which maps chain_id to a dictionary of | |||
features for each chain. | |||
Returns: | |||
all_chain_features: A dictionary which maps strings of the form | |||
`<seq_id>_<sym_id>` to the corresponding chain features. E.g. two | |||
chains from a homodimer would have keys A_1 and A_2. Two chains from a | |||
heterodimer would have keys A_1 and B_1. | |||
""" | |||
# Group the chains by sequence | |||
seq_to_entity_id = {} | |||
grouped_chains = collections.defaultdict(list) | |||
for chain_features in all_chain_features: | |||
assert 'sequence' in chain_features | |||
seq = str(chain_features['sequence']) | |||
if seq not in seq_to_entity_id: | |||
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 | |||
grouped_chains[seq_to_entity_id[seq]].append(chain_features) | |||
new_all_chain_features = [] | |||
chain_id = 1 | |||
for entity_id, group_chain_features in grouped_chains.items(): | |||
num_sym = len(group_chain_features) # zy | |||
for sym_id, chain_features in enumerate(group_chain_features, start=1): | |||
seq_length = chain_features['seq_length'] | |||
chain_features['asym_id'] = chain_id * np.ones(seq_length) | |||
chain_features['sym_id'] = sym_id * np.ones(seq_length) | |||
chain_features['entity_id'] = entity_id * np.ones(seq_length) | |||
chain_features['num_sym'] = num_sym * np.ones(seq_length) | |||
chain_id += 1 | |||
new_all_chain_features.append(chain_features) | |||
return new_all_chain_features | |||
def pad_msa(np_example, min_num_seq): | |||
np_example = dict(np_example) | |||
num_seq = np_example['msa'].shape[0] | |||
if num_seq < min_num_seq: | |||
for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask', | |||
'msa_chains'): | |||
np_example[feat] = np.pad(np_example[feat], | |||
((0, min_num_seq - num_seq), (0, 0))) | |||
np_example['cluster_bias_mask'] = np.pad( | |||
np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq), )) | |||
return np_example | |||
def post_process(np_example): | |||
np_example = pad_msa(np_example, 512) | |||
no_dim_keys = [ | |||
'num_alignments', | |||
'assembly_num_chains', | |||
'num_templates', | |||
'seq_length', | |||
'resolution', | |||
] | |||
for k in no_dim_keys: | |||
if k in np_example: | |||
np_example[k] = np_example[k].reshape(-1) | |||
return np_example | |||
def merge_msas(msa, del_mat, new_msa, new_del_mat): | |||
cur_msa_set = set([tuple(m) for m in msa]) | |||
new_rows = [] | |||
for i, s in enumerate(new_msa): | |||
if tuple(s) not in cur_msa_set: | |||
new_rows.append(i) | |||
ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0) | |||
ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0) | |||
return ret_msa, ret_del_mat |
@@ -0,0 +1,322 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Protein data type.""" | |||
import dataclasses | |||
import io | |||
from typing import Any, Mapping, Optional | |||
import numpy as np | |||
from Bio.PDB import PDBParser | |||
from modelscope.models.science.unifold.data import residue_constants | |||
FeatureDict = Mapping[str, np.ndarray] | |||
ModelOutput = Mapping[str, Any] # Is a nested dict. | |||
# Complete sequence of chain IDs supported by the PDB format. | |||
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' | |||
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. | |||
@dataclasses.dataclass(frozen=True) | |||
class Protein: | |||
"""Protein structure representation.""" | |||
# Cartesian coordinates of atoms in angstroms. The atom types correspond to | |||
# residue_constants.atom_types, i.e. the first three are N, CA, CB. | |||
atom_positions: np.ndarray # [num_res, num_atom_type, 3] | |||
# Amino-acid type for each residue represented as an integer between 0 and | |||
# 20, where 20 is 'X'. | |||
aatype: np.ndarray # [num_res] | |||
# Binary float mask to indicate presence of a particular atom. 1.0 if an atom | |||
# is present and 0.0 if not. This should be used for loss masking. | |||
atom_mask: np.ndarray # [num_res, num_atom_type] | |||
# Residue index as used in PDB. It is not necessarily continuous or 0-indexed. | |||
residue_index: np.ndarray # [num_res] | |||
# 0-indexed number corresponding to the chain in the protein that this residue | |||
# belongs to. | |||
chain_index: np.ndarray # [num_res] | |||
# B-factors, or temperature factors, of each residue (in sq. angstroms units), | |||
# representing the displacement of the residue from its ground truth mean | |||
# value. | |||
b_factors: np.ndarray # [num_res, num_atom_type] | |||
def __post_init__(self): | |||
if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: | |||
raise ValueError( | |||
f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' | |||
'because these cannot be written to PDB format.') | |||
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: | |||
"""Takes a PDB string and constructs a Protein object. | |||
WARNING: All non-standard residue types will be converted into UNK. All | |||
non-standard atoms will be ignored. | |||
Args: | |||
pdb_str: The contents of the pdb file | |||
chain_id: If chain_id is specified (e.g. A), then only that chain | |||
is parsed. Otherwise all chains are parsed. | |||
Returns: | |||
A new `Protein` parsed from the pdb contents. | |||
""" | |||
pdb_fh = io.StringIO(pdb_str) | |||
parser = PDBParser(QUIET=True) | |||
structure = parser.get_structure('none', pdb_fh) | |||
models = list(structure.get_models()) | |||
if len(models) != 1: | |||
raise ValueError( | |||
f'Only single model PDBs are supported. Found {len(models)} models.' | |||
) | |||
model = models[0] | |||
atom_positions = [] | |||
aatype = [] | |||
atom_mask = [] | |||
residue_index = [] | |||
chain_ids = [] | |||
b_factors = [] | |||
for chain in model: | |||
if chain_id is not None and chain.id != chain_id: | |||
continue | |||
for res in chain: | |||
if res.id[2] != ' ': | |||
raise ValueError( | |||
f'PDB contains an insertion code at chain {chain.id} and residue ' | |||
f'index {res.id[1]}. These are not supported.') | |||
res_shortname = residue_constants.restype_3to1.get( | |||
res.resname, 'X') | |||
restype_idx = residue_constants.restype_order.get( | |||
res_shortname, residue_constants.restype_num) | |||
pos = np.zeros((residue_constants.atom_type_num, 3)) | |||
mask = np.zeros((residue_constants.atom_type_num, )) | |||
res_b_factors = np.zeros((residue_constants.atom_type_num, )) | |||
for atom in res: | |||
if atom.name not in residue_constants.atom_types: | |||
continue | |||
pos[residue_constants.atom_order[atom.name]] = atom.coord | |||
mask[residue_constants.atom_order[atom.name]] = 1.0 | |||
res_b_factors[residue_constants.atom_order[ | |||
atom.name]] = atom.bfactor | |||
if np.sum(mask) < 0.5: | |||
# If no known atom positions are reported for the residue then skip it. | |||
continue | |||
aatype.append(restype_idx) | |||
atom_positions.append(pos) | |||
atom_mask.append(mask) | |||
residue_index.append(res.id[1]) | |||
chain_ids.append(chain.id) | |||
b_factors.append(res_b_factors) | |||
# Chain IDs are usually characters so map these to ints. | |||
unique_chain_ids = np.unique(chain_ids) | |||
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} | |||
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) | |||
return Protein( | |||
atom_positions=np.array(atom_positions), | |||
atom_mask=np.array(atom_mask), | |||
aatype=np.array(aatype), | |||
residue_index=np.array(residue_index), | |||
chain_index=chain_index, | |||
b_factors=np.array(b_factors), | |||
) | |||
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: | |||
chain_end = 'TER' | |||
return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' | |||
f'{chain_name:>1}{residue_index:>4}') | |||
def to_pdb(prot: Protein) -> str: | |||
"""Converts a `Protein` instance to a PDB string. | |||
Args: | |||
prot: The protein to convert to PDB. | |||
Returns: | |||
PDB string. | |||
""" | |||
restypes = residue_constants.restypes + ['X'] | |||
# res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') | |||
def res_1to3(r): | |||
return residue_constants.restype_1to3.get(restypes[r], 'UNK') | |||
atom_types = residue_constants.atom_types | |||
pdb_lines = [] | |||
atom_mask = prot.atom_mask | |||
aatype = prot.aatype | |||
atom_positions = prot.atom_positions | |||
residue_index = prot.residue_index.astype(np.int32) | |||
chain_index = prot.chain_index.astype(np.int32) | |||
b_factors = prot.b_factors | |||
if np.any(aatype > residue_constants.restype_num): | |||
raise ValueError('Invalid aatypes.') | |||
# Construct a mapping from chain integer indices to chain ID strings. | |||
chain_ids = {} | |||
for i in np.unique(chain_index): # np.unique gives sorted output. | |||
if i >= PDB_MAX_CHAINS: | |||
raise ValueError( | |||
f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') | |||
chain_ids[i] = PDB_CHAIN_IDS[i] | |||
pdb_lines.append('MODEL 1') | |||
atom_index = 1 | |||
last_chain_index = chain_index[0] | |||
# Add all atom sites. | |||
for i in range(aatype.shape[0]): | |||
# Close the previous chain if in a multichain PDB. | |||
if last_chain_index != chain_index[i]: | |||
pdb_lines.append( | |||
_chain_end( | |||
atom_index, | |||
res_1to3(aatype[i - 1]), | |||
chain_ids[chain_index[i - 1]], | |||
residue_index[i - 1], | |||
)) | |||
last_chain_index = chain_index[i] | |||
atom_index += 1 # Atom index increases at the TER symbol. | |||
res_name_3 = res_1to3(aatype[i]) | |||
for atom_name, pos, mask, b_factor in zip(atom_types, | |||
atom_positions[i], | |||
atom_mask[i], b_factors[i]): | |||
if mask < 0.5: | |||
continue | |||
record_type = 'ATOM' | |||
name = atom_name if len(atom_name) == 4 else f' {atom_name}' | |||
alt_loc = '' | |||
insertion_code = '' | |||
occupancy = 1.00 | |||
element = atom_name[ | |||
0] # Protein supports only C, N, O, S, this works. | |||
charge = '' | |||
# PDB is a columnar format, every space matters here! | |||
atom_line = ( | |||
f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' | |||
f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' | |||
f'{residue_index[i]:>4}{insertion_code:>1} ' | |||
f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' | |||
f'{occupancy:>6.2f}{b_factor:>6.2f} ' | |||
f'{element:>2}{charge:>2}') | |||
pdb_lines.append(atom_line) | |||
atom_index += 1 | |||
# Close the final chain. | |||
pdb_lines.append( | |||
_chain_end( | |||
atom_index, | |||
res_1to3(aatype[-1]), | |||
chain_ids[chain_index[-1]], | |||
residue_index[-1], | |||
)) | |||
pdb_lines.append('ENDMDL') | |||
pdb_lines.append('END') | |||
# Pad all lines to 80 characters. | |||
pdb_lines = [line.ljust(80) for line in pdb_lines] | |||
return '\n'.join(pdb_lines) + '\n' # Add terminating newline. | |||
def ideal_atom_mask(prot: Protein) -> np.ndarray: | |||
"""Computes an ideal atom mask. | |||
`Protein.atom_mask` typically is defined according to the atoms that are | |||
reported in the PDB. This function computes a mask according to heavy atoms | |||
that should be present in the given sequence of amino acids. | |||
Args: | |||
prot: `Protein` whose fields are `numpy.ndarray` objects. | |||
Returns: | |||
An ideal atom mask. | |||
""" | |||
return residue_constants.STANDARD_ATOM_MASK[prot.aatype] | |||
def from_prediction(features: FeatureDict, | |||
result: ModelOutput, | |||
b_factors: Optional[np.ndarray] = None) -> Protein: | |||
"""Assembles a protein from a prediction. | |||
Args: | |||
features: Dictionary holding model inputs. | |||
fold_output: Dictionary holding model outputs. | |||
b_factors: (Optional) B-factors to use for the protein. | |||
Returns: | |||
A protein instance. | |||
""" | |||
if 'asym_id' in features: | |||
chain_index = features['asym_id'] - 1 | |||
else: | |||
chain_index = np.zeros_like((features['aatype'])) | |||
if b_factors is None: | |||
b_factors = np.zeros_like(result['final_atom_mask']) | |||
return Protein( | |||
aatype=features['aatype'], | |||
atom_positions=result['final_atom_positions'], | |||
atom_mask=result['final_atom_mask'], | |||
residue_index=features['residue_index'] + 1, | |||
chain_index=chain_index, | |||
b_factors=b_factors, | |||
) | |||
def from_feature(features: FeatureDict, | |||
b_factors: Optional[np.ndarray] = None) -> Protein: | |||
"""Assembles a standard pdb from input atom positions & mask. | |||
Args: | |||
features: Dictionary holding model inputs. | |||
b_factors: (Optional) B-factors to use for the protein. | |||
Returns: | |||
A protein instance. | |||
""" | |||
if 'asym_id' in features: | |||
chain_index = features['asym_id'] - 1 | |||
else: | |||
chain_index = np.zeros_like((features['aatype'])) | |||
if b_factors is None: | |||
b_factors = np.zeros_like(features['all_atom_mask']) | |||
return Protein( | |||
aatype=features['aatype'], | |||
atom_positions=features['all_atom_positions'], | |||
atom_mask=features['all_atom_mask'], | |||
residue_index=features['residue_index'] + 1, | |||
chain_index=chain_index, | |||
b_factors=b_factors, | |||
) |
@@ -0,0 +1,345 @@ | |||
Bond Residue Mean StdDev | |||
CA-CB ALA 1.520 0.021 | |||
N-CA ALA 1.459 0.020 | |||
CA-C ALA 1.525 0.026 | |||
C-O ALA 1.229 0.019 | |||
CA-CB ARG 1.535 0.022 | |||
CB-CG ARG 1.521 0.027 | |||
CG-CD ARG 1.515 0.025 | |||
CD-NE ARG 1.460 0.017 | |||
NE-CZ ARG 1.326 0.013 | |||
CZ-NH1 ARG 1.326 0.013 | |||
CZ-NH2 ARG 1.326 0.013 | |||
N-CA ARG 1.459 0.020 | |||
CA-C ARG 1.525 0.026 | |||
C-O ARG 1.229 0.019 | |||
CA-CB ASN 1.527 0.026 | |||
CB-CG ASN 1.506 0.023 | |||
CG-OD1 ASN 1.235 0.022 | |||
CG-ND2 ASN 1.324 0.025 | |||
N-CA ASN 1.459 0.020 | |||
CA-C ASN 1.525 0.026 | |||
C-O ASN 1.229 0.019 | |||
CA-CB ASP 1.535 0.022 | |||
CB-CG ASP 1.513 0.021 | |||
CG-OD1 ASP 1.249 0.023 | |||
CG-OD2 ASP 1.249 0.023 | |||
N-CA ASP 1.459 0.020 | |||
CA-C ASP 1.525 0.026 | |||
C-O ASP 1.229 0.019 | |||
CA-CB CYS 1.526 0.013 | |||
CB-SG CYS 1.812 0.016 | |||
N-CA CYS 1.459 0.020 | |||
CA-C CYS 1.525 0.026 | |||
C-O CYS 1.229 0.019 | |||
CA-CB GLU 1.535 0.022 | |||
CB-CG GLU 1.517 0.019 | |||
CG-CD GLU 1.515 0.015 | |||
CD-OE1 GLU 1.252 0.011 | |||
CD-OE2 GLU 1.252 0.011 | |||
N-CA GLU 1.459 0.020 | |||
CA-C GLU 1.525 0.026 | |||
C-O GLU 1.229 0.019 | |||
CA-CB GLN 1.535 0.022 | |||
CB-CG GLN 1.521 0.027 | |||
CG-CD GLN 1.506 0.023 | |||
CD-OE1 GLN 1.235 0.022 | |||
CD-NE2 GLN 1.324 0.025 | |||
N-CA GLN 1.459 0.020 | |||
CA-C GLN 1.525 0.026 | |||
C-O GLN 1.229 0.019 | |||
N-CA GLY 1.456 0.015 | |||
CA-C GLY 1.514 0.016 | |||
C-O GLY 1.232 0.016 | |||
CA-CB HIS 1.535 0.022 | |||
CB-CG HIS 1.492 0.016 | |||
CG-ND1 HIS 1.369 0.015 | |||
CG-CD2 HIS 1.353 0.017 | |||
ND1-CE1 HIS 1.343 0.025 | |||
CD2-NE2 HIS 1.415 0.021 | |||
CE1-NE2 HIS 1.322 0.023 | |||
N-CA HIS 1.459 0.020 | |||
CA-C HIS 1.525 0.026 | |||
C-O HIS 1.229 0.019 | |||
CA-CB ILE 1.544 0.023 | |||
CB-CG1 ILE 1.536 0.028 | |||
CB-CG2 ILE 1.524 0.031 | |||
CG1-CD1 ILE 1.500 0.069 | |||
N-CA ILE 1.459 0.020 | |||
CA-C ILE 1.525 0.026 | |||
C-O ILE 1.229 0.019 | |||
CA-CB LEU 1.533 0.023 | |||
CB-CG LEU 1.521 0.029 | |||
CG-CD1 LEU 1.514 0.037 | |||
CG-CD2 LEU 1.514 0.037 | |||
N-CA LEU 1.459 0.020 | |||
CA-C LEU 1.525 0.026 | |||
C-O LEU 1.229 0.019 | |||
CA-CB LYS 1.535 0.022 | |||
CB-CG LYS 1.521 0.027 | |||
CG-CD LYS 1.520 0.034 | |||
CD-CE LYS 1.508 0.025 | |||
CE-NZ LYS 1.486 0.025 | |||
N-CA LYS 1.459 0.020 | |||
CA-C LYS 1.525 0.026 | |||
C-O LYS 1.229 0.019 | |||
CA-CB MET 1.535 0.022 | |||
CB-CG MET 1.509 0.032 | |||
CG-SD MET 1.807 0.026 | |||
SD-CE MET 1.774 0.056 | |||
N-CA MET 1.459 0.020 | |||
CA-C MET 1.525 0.026 | |||
C-O MET 1.229 0.019 | |||
CA-CB PHE 1.535 0.022 | |||
CB-CG PHE 1.509 0.017 | |||
CG-CD1 PHE 1.383 0.015 | |||
CG-CD2 PHE 1.383 0.015 | |||
CD1-CE1 PHE 1.388 0.020 | |||
CD2-CE2 PHE 1.388 0.020 | |||
CE1-CZ PHE 1.369 0.019 | |||
CE2-CZ PHE 1.369 0.019 | |||
N-CA PHE 1.459 0.020 | |||
CA-C PHE 1.525 0.026 | |||
C-O PHE 1.229 0.019 | |||
CA-CB PRO 1.531 0.020 | |||
CB-CG PRO 1.495 0.050 | |||
CG-CD PRO 1.502 0.033 | |||
CD-N PRO 1.474 0.014 | |||
N-CA PRO 1.468 0.017 | |||
CA-C PRO 1.524 0.020 | |||
C-O PRO 1.228 0.020 | |||
CA-CB SER 1.525 0.015 | |||
CB-OG SER 1.418 0.013 | |||
N-CA SER 1.459 0.020 | |||
CA-C SER 1.525 0.026 | |||
C-O SER 1.229 0.019 | |||
CA-CB THR 1.529 0.026 | |||
CB-OG1 THR 1.428 0.020 | |||
CB-CG2 THR 1.519 0.033 | |||
N-CA THR 1.459 0.020 | |||
CA-C THR 1.525 0.026 | |||
C-O THR 1.229 0.019 | |||
CA-CB TRP 1.535 0.022 | |||
CB-CG TRP 1.498 0.018 | |||
CG-CD1 TRP 1.363 0.014 | |||
CG-CD2 TRP 1.432 0.017 | |||
CD1-NE1 TRP 1.375 0.017 | |||
NE1-CE2 TRP 1.371 0.013 | |||
CD2-CE2 TRP 1.409 0.012 | |||
CD2-CE3 TRP 1.399 0.015 | |||
CE2-CZ2 TRP 1.393 0.017 | |||
CE3-CZ3 TRP 1.380 0.017 | |||
CZ2-CH2 TRP 1.369 0.019 | |||
CZ3-CH2 TRP 1.396 0.016 | |||
N-CA TRP 1.459 0.020 | |||
CA-C TRP 1.525 0.026 | |||
C-O TRP 1.229 0.019 | |||
CA-CB TYR 1.535 0.022 | |||
CB-CG TYR 1.512 0.015 | |||
CG-CD1 TYR 1.387 0.013 | |||
CG-CD2 TYR 1.387 0.013 | |||
CD1-CE1 TYR 1.389 0.015 | |||
CD2-CE2 TYR 1.389 0.015 | |||
CE1-CZ TYR 1.381 0.013 | |||
CE2-CZ TYR 1.381 0.013 | |||
CZ-OH TYR 1.374 0.017 | |||
N-CA TYR 1.459 0.020 | |||
CA-C TYR 1.525 0.026 | |||
C-O TYR 1.229 0.019 | |||
CA-CB VAL 1.543 0.021 | |||
CB-CG1 VAL 1.524 0.021 | |||
CB-CG2 VAL 1.524 0.021 | |||
N-CA VAL 1.459 0.020 | |||
CA-C VAL 1.525 0.026 | |||
C-O VAL 1.229 0.019 | |||
- | |||
Angle Residue Mean StdDev | |||
N-CA-CB ALA 110.1 1.4 | |||
CB-CA-C ALA 110.1 1.5 | |||
N-CA-C ALA 111.0 2.7 | |||
CA-C-O ALA 120.1 2.1 | |||
N-CA-CB ARG 110.6 1.8 | |||
CB-CA-C ARG 110.4 2.0 | |||
CA-CB-CG ARG 113.4 2.2 | |||
CB-CG-CD ARG 111.6 2.6 | |||
CG-CD-NE ARG 111.8 2.1 | |||
CD-NE-CZ ARG 123.6 1.4 | |||
NE-CZ-NH1 ARG 120.3 0.5 | |||
NE-CZ-NH2 ARG 120.3 0.5 | |||
NH1-CZ-NH2 ARG 119.4 1.1 | |||
N-CA-C ARG 111.0 2.7 | |||
CA-C-O ARG 120.1 2.1 | |||
N-CA-CB ASN 110.6 1.8 | |||
CB-CA-C ASN 110.4 2.0 | |||
CA-CB-CG ASN 113.4 2.2 | |||
CB-CG-ND2 ASN 116.7 2.4 | |||
CB-CG-OD1 ASN 121.6 2.0 | |||
ND2-CG-OD1 ASN 121.9 2.3 | |||
N-CA-C ASN 111.0 2.7 | |||
CA-C-O ASN 120.1 2.1 | |||
N-CA-CB ASP 110.6 1.8 | |||
CB-CA-C ASP 110.4 2.0 | |||
CA-CB-CG ASP 113.4 2.2 | |||
CB-CG-OD1 ASP 118.3 0.9 | |||
CB-CG-OD2 ASP 118.3 0.9 | |||
OD1-CG-OD2 ASP 123.3 1.9 | |||
N-CA-C ASP 111.0 2.7 | |||
CA-C-O ASP 120.1 2.1 | |||
N-CA-CB CYS 110.8 1.5 | |||
CB-CA-C CYS 111.5 1.2 | |||
CA-CB-SG CYS 114.2 1.1 | |||
N-CA-C CYS 111.0 2.7 | |||
CA-C-O CYS 120.1 2.1 | |||
N-CA-CB GLU 110.6 1.8 | |||
CB-CA-C GLU 110.4 2.0 | |||
CA-CB-CG GLU 113.4 2.2 | |||
CB-CG-CD GLU 114.2 2.7 | |||
CG-CD-OE1 GLU 118.3 2.0 | |||
CG-CD-OE2 GLU 118.3 2.0 | |||
OE1-CD-OE2 GLU 123.3 1.2 | |||
N-CA-C GLU 111.0 2.7 | |||
CA-C-O GLU 120.1 2.1 | |||
N-CA-CB GLN 110.6 1.8 | |||
CB-CA-C GLN 110.4 2.0 | |||
CA-CB-CG GLN 113.4 2.2 | |||
CB-CG-CD GLN 111.6 2.6 | |||
CG-CD-OE1 GLN 121.6 2.0 | |||
CG-CD-NE2 GLN 116.7 2.4 | |||
OE1-CD-NE2 GLN 121.9 2.3 | |||
N-CA-C GLN 111.0 2.7 | |||
CA-C-O GLN 120.1 2.1 | |||
N-CA-C GLY 113.1 2.5 | |||
CA-C-O GLY 120.6 1.8 | |||
N-CA-CB HIS 110.6 1.8 | |||
CB-CA-C HIS 110.4 2.0 | |||
CA-CB-CG HIS 113.6 1.7 | |||
CB-CG-ND1 HIS 123.2 2.5 | |||
CB-CG-CD2 HIS 130.8 3.1 | |||
CG-ND1-CE1 HIS 108.2 1.4 | |||
ND1-CE1-NE2 HIS 109.9 2.2 | |||
CE1-NE2-CD2 HIS 106.6 2.5 | |||
NE2-CD2-CG HIS 109.2 1.9 | |||
CD2-CG-ND1 HIS 106.0 1.4 | |||
N-CA-C HIS 111.0 2.7 | |||
CA-C-O HIS 120.1 2.1 | |||
N-CA-CB ILE 110.8 2.3 | |||
CB-CA-C ILE 111.6 2.0 | |||
CA-CB-CG1 ILE 111.0 1.9 | |||
CB-CG1-CD1 ILE 113.9 2.8 | |||
CA-CB-CG2 ILE 110.9 2.0 | |||
CG1-CB-CG2 ILE 111.4 2.2 | |||
N-CA-C ILE 111.0 2.7 | |||
CA-C-O ILE 120.1 2.1 | |||
N-CA-CB LEU 110.4 2.0 | |||
CB-CA-C LEU 110.2 1.9 | |||
CA-CB-CG LEU 115.3 2.3 | |||
CB-CG-CD1 LEU 111.0 1.7 | |||
CB-CG-CD2 LEU 111.0 1.7 | |||
CD1-CG-CD2 LEU 110.5 3.0 | |||
N-CA-C LEU 111.0 2.7 | |||
CA-C-O LEU 120.1 2.1 | |||
N-CA-CB LYS 110.6 1.8 | |||
CB-CA-C LYS 110.4 2.0 | |||
CA-CB-CG LYS 113.4 2.2 | |||
CB-CG-CD LYS 111.6 2.6 | |||
CG-CD-CE LYS 111.9 3.0 | |||
CD-CE-NZ LYS 111.7 2.3 | |||
N-CA-C LYS 111.0 2.7 | |||
CA-C-O LYS 120.1 2.1 | |||
N-CA-CB MET 110.6 1.8 | |||
CB-CA-C MET 110.4 2.0 | |||
CA-CB-CG MET 113.3 1.7 | |||
CB-CG-SD MET 112.4 3.0 | |||
CG-SD-CE MET 100.2 1.6 | |||
N-CA-C MET 111.0 2.7 | |||
CA-C-O MET 120.1 2.1 | |||
N-CA-CB PHE 110.6 1.8 | |||
CB-CA-C PHE 110.4 2.0 | |||
CA-CB-CG PHE 113.9 2.4 | |||
CB-CG-CD1 PHE 120.8 0.7 | |||
CB-CG-CD2 PHE 120.8 0.7 | |||
CD1-CG-CD2 PHE 118.3 1.3 | |||
CG-CD1-CE1 PHE 120.8 1.1 | |||
CG-CD2-CE2 PHE 120.8 1.1 | |||
CD1-CE1-CZ PHE 120.1 1.2 | |||
CD2-CE2-CZ PHE 120.1 1.2 | |||
CE1-CZ-CE2 PHE 120.0 1.8 | |||
N-CA-C PHE 111.0 2.7 | |||
CA-C-O PHE 120.1 2.1 | |||
N-CA-CB PRO 103.3 1.2 | |||
CB-CA-C PRO 111.7 2.1 | |||
CA-CB-CG PRO 104.8 1.9 | |||
CB-CG-CD PRO 106.5 3.9 | |||
CG-CD-N PRO 103.2 1.5 | |||
CA-N-CD PRO 111.7 1.4 | |||
N-CA-C PRO 112.1 2.6 | |||
CA-C-O PRO 120.2 2.4 | |||
N-CA-CB SER 110.5 1.5 | |||
CB-CA-C SER 110.1 1.9 | |||
CA-CB-OG SER 111.2 2.7 | |||
N-CA-C SER 111.0 2.7 | |||
CA-C-O SER 120.1 2.1 | |||
N-CA-CB THR 110.3 1.9 | |||
CB-CA-C THR 111.6 2.7 | |||
CA-CB-OG1 THR 109.0 2.1 | |||
CA-CB-CG2 THR 112.4 1.4 | |||
OG1-CB-CG2 THR 110.0 2.3 | |||
N-CA-C THR 111.0 2.7 | |||
CA-C-O THR 120.1 2.1 | |||
N-CA-CB TRP 110.6 1.8 | |||
CB-CA-C TRP 110.4 2.0 | |||
CA-CB-CG TRP 113.7 1.9 | |||
CB-CG-CD1 TRP 127.0 1.3 | |||
CB-CG-CD2 TRP 126.6 1.3 | |||
CD1-CG-CD2 TRP 106.3 0.8 | |||
CG-CD1-NE1 TRP 110.1 1.0 | |||
CD1-NE1-CE2 TRP 109.0 0.9 | |||
NE1-CE2-CD2 TRP 107.3 1.0 | |||
CE2-CD2-CG TRP 107.3 0.8 | |||
CG-CD2-CE3 TRP 133.9 0.9 | |||
NE1-CE2-CZ2 TRP 130.4 1.1 | |||
CE3-CD2-CE2 TRP 118.7 1.2 | |||
CD2-CE2-CZ2 TRP 122.3 1.2 | |||
CE2-CZ2-CH2 TRP 117.4 1.0 | |||
CZ2-CH2-CZ3 TRP 121.6 1.2 | |||
CH2-CZ3-CE3 TRP 121.2 1.1 | |||
CZ3-CE3-CD2 TRP 118.8 1.3 | |||
N-CA-C TRP 111.0 2.7 | |||
CA-C-O TRP 120.1 2.1 | |||
N-CA-CB TYR 110.6 1.8 | |||
CB-CA-C TYR 110.4 2.0 | |||
CA-CB-CG TYR 113.4 1.9 | |||
CB-CG-CD1 TYR 121.0 0.6 | |||
CB-CG-CD2 TYR 121.0 0.6 | |||
CD1-CG-CD2 TYR 117.9 1.1 | |||
CG-CD1-CE1 TYR 121.3 0.8 | |||
CG-CD2-CE2 TYR 121.3 0.8 | |||
CD1-CE1-CZ TYR 119.8 0.9 | |||
CD2-CE2-CZ TYR 119.8 0.9 | |||
CE1-CZ-CE2 TYR 119.8 1.6 | |||
CE1-CZ-OH TYR 120.1 2.7 | |||
CE2-CZ-OH TYR 120.1 2.7 | |||
N-CA-C TYR 111.0 2.7 | |||
CA-C-O TYR 120.1 2.1 | |||
N-CA-CB VAL 111.5 2.2 | |||
CB-CA-C VAL 111.4 1.9 | |||
CA-CB-CG1 VAL 110.9 1.5 | |||
CA-CB-CG2 VAL 110.9 1.5 | |||
CG1-CB-CG2 VAL 110.9 1.6 | |||
N-CA-C VAL 111.0 2.7 | |||
CA-C-O VAL 120.1 2.1 | |||
- | |||
Non-bonded distance Minimum Dist Tolerance | |||
C-C 3.4 1.5 | |||
C-N 3.25 1.5 | |||
C-S 3.5 1.5 | |||
C-O 3.22 1.5 | |||
N-N 3.1 1.5 | |||
N-S 3.35 1.5 | |||
N-O 3.07 1.5 | |||
O-S 3.32 1.5 | |||
O-O 3.04 1.5 | |||
S-S 2.03 1.0 | |||
- |
@@ -0,0 +1,161 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import copy as copy_lib | |||
import functools | |||
import gzip | |||
import pickle | |||
from typing import Any, Dict | |||
import json | |||
import numpy as np | |||
from scipy import sparse as sp | |||
from . import residue_constants as rc | |||
from .data_ops import NumpyDict | |||
# from typing import * | |||
def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False): | |||
if deepcopy: | |||
def decorator(f): | |||
cached_func = functools.lru_cache(maxsize, typed)(f) | |||
@functools.wraps(f) | |||
def wrapper(*args, **kwargs): | |||
return copy_lib.deepcopy(cached_func(*args, **kwargs)) | |||
return wrapper | |||
elif copy: | |||
def decorator(f): | |||
cached_func = functools.lru_cache(maxsize, typed)(f) | |||
@functools.wraps(f) | |||
def wrapper(*args, **kwargs): | |||
return copy_lib.copy(cached_func(*args, **kwargs)) | |||
return wrapper | |||
else: | |||
decorator = functools.lru_cache(maxsize, typed) | |||
return decorator | |||
@lru_cache(maxsize=8, deepcopy=True) | |||
def load_pickle_safe(path: str) -> Dict[str, Any]: | |||
def load(path): | |||
assert path.endswith('.pkl') or path.endswith( | |||
'.pkl.gz'), f'bad suffix in {path} as pickle file.' | |||
open_fn = gzip.open if path.endswith('.gz') else open | |||
with open_fn(path, 'rb') as f: | |||
return pickle.load(f) | |||
ret = load(path) | |||
ret = uncompress_features(ret) | |||
return ret | |||
@lru_cache(maxsize=8, copy=True) | |||
def load_pickle(path: str) -> Dict[str, Any]: | |||
def load(path): | |||
assert path.endswith('.pkl') or path.endswith( | |||
'.pkl.gz'), f'bad suffix in {path} as pickle file.' | |||
open_fn = gzip.open if path.endswith('.gz') else open | |||
with open_fn(path, 'rb') as f: | |||
return pickle.load(f) | |||
ret = load(path) | |||
ret = uncompress_features(ret) | |||
return ret | |||
def correct_template_restypes(feature): | |||
"""Correct template restype to have the same order as residue_constants.""" | |||
feature = np.argmax(feature, axis=-1).astype(np.int32) | |||
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE | |||
feature = np.take(new_order_list, feature.astype(np.int32), axis=0) | |||
return feature | |||
def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict: | |||
feature['msa'] = feature['msa'].astype(np.uint8) | |||
if 'num_alignments' in feature: | |||
feature.pop('num_alignments') | |||
# make_all_seq_key = lambda k: f'{k}_all_seq' if not k.endswith('_all_seq') else k | |||
def make_all_seq_key(k): | |||
if not k.endswith('_all_seq'): | |||
return f'{k}_all_seq' | |||
return k | |||
return {make_all_seq_key(k): v for k, v in feature.items()} | |||
def to_dense_matrix(spmat_dict: NumpyDict): | |||
spmat = sp.coo_matrix( | |||
(spmat_dict['data'], (spmat_dict['row'], spmat_dict['col'])), | |||
shape=spmat_dict['shape'], | |||
dtype=np.float32, | |||
) | |||
return spmat.toarray() | |||
FEATS_DTYPE = {'msa': np.int32} | |||
def uncompress_features(feats: NumpyDict) -> NumpyDict: | |||
if 'sparse_deletion_matrix_int' in feats: | |||
v = feats.pop('sparse_deletion_matrix_int') | |||
v = to_dense_matrix(v) | |||
feats['deletion_matrix'] = v | |||
return feats | |||
def filter(feature: NumpyDict, **kwargs) -> NumpyDict: | |||
assert len(kwargs) == 1, f'wrong usage of filter with kwargs: {kwargs}' | |||
if 'desired_keys' in kwargs: | |||
feature = { | |||
k: v | |||
for k, v in feature.items() if k in kwargs['desired_keys'] | |||
} | |||
elif 'required_keys' in kwargs: | |||
for k in kwargs['required_keys']: | |||
assert k in feature, f'cannot find required key {k}.' | |||
elif 'ignored_keys' in kwargs: | |||
feature = { | |||
k: v | |||
for k, v in feature.items() if k not in kwargs['ignored_keys'] | |||
} | |||
else: | |||
raise AssertionError(f'wrong usage of filter with kwargs: {kwargs}') | |||
return feature | |||
def compress_features(features: NumpyDict): | |||
change_dtype = { | |||
'msa': np.uint8, | |||
} | |||
sparse_keys = ['deletion_matrix_int'] | |||
compressed_features = {} | |||
for k, v in features.items(): | |||
if k in change_dtype: | |||
v = v.astype(change_dtype[k]) | |||
if k in sparse_keys: | |||
v = sp.coo_matrix(v, dtype=v.dtype) | |||
sp_v = { | |||
'shape': v.shape, | |||
'row': v.row, | |||
'col': v.col, | |||
'data': v.data | |||
} | |||
k = f'sparse_{k}' | |||
v = sp_v | |||
compressed_features[k] = v | |||
return compressed_features |
@@ -0,0 +1,514 @@ | |||
import copy | |||
import logging | |||
import os | |||
# from typing import * | |||
from typing import Dict, Iterable, List, Optional, Tuple, Union | |||
import json | |||
import ml_collections as mlc | |||
import numpy as np | |||
import torch | |||
from unicore.data import UnicoreDataset, data_utils | |||
from unicore.distributed import utils as distributed_utils | |||
from .data import utils | |||
from .data.data_ops import NumpyDict, TorchDict | |||
from .data.process import process_features, process_labels | |||
from .data.process_multimer import (add_assembly_features, | |||
convert_monomer_features, merge_msas, | |||
pair_and_merge, post_process) | |||
Rotation = Iterable[Iterable] | |||
Translation = Iterable | |||
Operation = Union[str, Tuple[Rotation, Translation]] | |||
NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]] | |||
TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]] | |||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name | |||
def make_data_config( | |||
config: mlc.ConfigDict, | |||
mode: str, | |||
num_res: int, | |||
) -> Tuple[mlc.ConfigDict, List[str]]: | |||
cfg = copy.deepcopy(config) | |||
mode_cfg = cfg[mode] | |||
with cfg.unlocked(): | |||
if mode_cfg.crop_size is None: | |||
mode_cfg.crop_size = num_res | |||
feature_names = cfg.common.unsupervised_features + cfg.common.recycling_features | |||
if cfg.common.use_templates: | |||
feature_names += cfg.common.template_features | |||
if cfg.common.is_multimer: | |||
feature_names += cfg.common.multimer_features | |||
if cfg[mode].supervised: | |||
feature_names += cfg.supervised.supervised_features | |||
return cfg, feature_names | |||
def process_label(all_atom_positions: np.ndarray, | |||
operation: Operation) -> np.ndarray: | |||
if operation == 'I': | |||
return all_atom_positions | |||
rot, trans = operation | |||
rot = np.array(rot).reshape(3, 3) | |||
trans = np.array(trans).reshape(3) | |||
return all_atom_positions @ rot.T + trans | |||
@utils.lru_cache(maxsize=8, copy=True) | |||
def load_single_feature( | |||
sequence_id: str, | |||
monomer_feature_dir: str, | |||
uniprot_msa_dir: Optional[str] = None, | |||
is_monomer: bool = False, | |||
) -> NumpyDict: | |||
monomer_feature = utils.load_pickle( | |||
os.path.join(monomer_feature_dir, f'{sequence_id}.feature.pkl.gz')) | |||
monomer_feature = convert_monomer_features(monomer_feature) | |||
chain_feature = {**monomer_feature} | |||
if uniprot_msa_dir is not None: | |||
all_seq_feature = utils.load_pickle( | |||
os.path.join(uniprot_msa_dir, f'{sequence_id}.uniprot.pkl.gz')) | |||
if is_monomer: | |||
chain_feature['msa'], chain_feature[ | |||
'deletion_matrix'] = merge_msas( | |||
chain_feature['msa'], | |||
chain_feature['deletion_matrix'], | |||
all_seq_feature['msa'], | |||
all_seq_feature['deletion_matrix'], | |||
) # noqa | |||
else: | |||
all_seq_feature = utils.convert_all_seq_feature(all_seq_feature) | |||
for key in [ | |||
'msa_all_seq', | |||
'msa_species_identifiers_all_seq', | |||
'deletion_matrix_all_seq', | |||
]: | |||
chain_feature[key] = all_seq_feature[key] | |||
return chain_feature | |||
def load_single_label( | |||
label_id: str, | |||
label_dir: str, | |||
symmetry_operation: Optional[Operation] = None, | |||
) -> NumpyDict: | |||
label = utils.load_pickle( | |||
os.path.join(label_dir, f'{label_id}.label.pkl.gz')) | |||
if symmetry_operation is not None: | |||
label['all_atom_positions'] = process_label( | |||
label['all_atom_positions'], symmetry_operation) | |||
label = { | |||
k: v | |||
for k, v in label.items() if k in | |||
['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] | |||
} | |||
return label | |||
def load( | |||
sequence_ids: List[str], | |||
monomer_feature_dir: str, | |||
uniprot_msa_dir: Optional[str] = None, | |||
label_ids: Optional[List[str]] = None, | |||
label_dir: Optional[str] = None, | |||
symmetry_operations: Optional[List[Operation]] = None, | |||
is_monomer: bool = False, | |||
) -> NumpyExample: | |||
all_chain_features = [ | |||
load_single_feature(s, monomer_feature_dir, uniprot_msa_dir, | |||
is_monomer) for s in sequence_ids | |||
] | |||
if label_ids is not None: | |||
# load labels | |||
assert len(label_ids) == len(sequence_ids) | |||
assert label_dir is not None | |||
if symmetry_operations is None: | |||
symmetry_operations = ['I' for _ in label_ids] | |||
all_chain_labels = [ | |||
load_single_label(ll, label_dir, o) | |||
for ll, o in zip(label_ids, symmetry_operations) | |||
] | |||
# update labels into features to calculate spatial cropping etc. | |||
[f.update(ll) for f, ll in zip(all_chain_features, all_chain_labels)] | |||
all_chain_features = add_assembly_features(all_chain_features) | |||
# get labels back from features, as add_assembly_features may alter the order of inputs. | |||
if label_ids is not None: | |||
all_chain_labels = [{ | |||
k: f[k] | |||
for k in | |||
['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] | |||
} for f in all_chain_features] | |||
else: | |||
all_chain_labels = None | |||
asym_len = np.array([c['seq_length'] for c in all_chain_features], | |||
dtype=np.int64) | |||
if is_monomer: | |||
all_chain_features = all_chain_features[0] | |||
else: | |||
all_chain_features = pair_and_merge(all_chain_features) | |||
all_chain_features = post_process(all_chain_features) | |||
all_chain_features['asym_len'] = asym_len | |||
return all_chain_features, all_chain_labels | |||
def process( | |||
config: mlc.ConfigDict, | |||
mode: str, | |||
features: NumpyDict, | |||
labels: Optional[List[NumpyDict]] = None, | |||
seed: int = 0, | |||
batch_idx: Optional[int] = None, | |||
data_idx: Optional[int] = None, | |||
is_distillation: bool = False, | |||
) -> TorchExample: | |||
if mode == 'train': | |||
assert batch_idx is not None | |||
with data_utils.numpy_seed(seed, batch_idx, key='recycling'): | |||
num_iters = np.random.randint( | |||
0, config.common.max_recycling_iters + 1) | |||
use_clamped_fape = np.random.rand( | |||
) < config[mode].use_clamped_fape_prob | |||
else: | |||
num_iters = config.common.max_recycling_iters | |||
use_clamped_fape = 1 | |||
features['num_recycling_iters'] = int(num_iters) | |||
features['use_clamped_fape'] = int(use_clamped_fape) | |||
features['is_distillation'] = int(is_distillation) | |||
if is_distillation and 'msa_chains' in features: | |||
features.pop('msa_chains') | |||
num_res = int(features['seq_length']) | |||
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) | |||
if labels is not None: | |||
features['resolution'] = labels[0]['resolution'].reshape(-1) | |||
with data_utils.numpy_seed(seed, data_idx, key='protein_feature'): | |||
features['crop_and_fix_size_seed'] = np.random.randint(0, 63355) | |||
features = utils.filter(features, desired_keys=feature_names) | |||
features = {k: torch.tensor(v) for k, v in features.items()} | |||
with torch.no_grad(): | |||
features = process_features(features, cfg.common, cfg[mode]) | |||
if labels is not None: | |||
labels = [{k: torch.tensor(v) for k, v in ll.items()} for ll in labels] | |||
with torch.no_grad(): | |||
labels = process_labels(labels) | |||
return features, labels | |||
def load_and_process( | |||
config: mlc.ConfigDict, | |||
mode: str, | |||
seed: int = 0, | |||
batch_idx: Optional[int] = None, | |||
data_idx: Optional[int] = None, | |||
is_distillation: bool = False, | |||
**load_kwargs, | |||
): | |||
is_monomer = ( | |||
is_distillation | |||
if 'is_monomer' not in load_kwargs else load_kwargs.pop('is_monomer')) | |||
features, labels = load(**load_kwargs, is_monomer=is_monomer) | |||
features, labels = process(config, mode, features, labels, seed, batch_idx, | |||
data_idx, is_distillation) | |||
return features, labels | |||
class UnifoldDataset(UnicoreDataset): | |||
def __init__( | |||
self, | |||
args, | |||
seed, | |||
config, | |||
data_path, | |||
mode='train', | |||
max_step=None, | |||
disable_sd=False, | |||
json_prefix='', | |||
): | |||
self.path = data_path | |||
def load_json(filename): | |||
return json.load(open(filename, 'r')) | |||
sample_weight = load_json( | |||
os.path.join(self.path, | |||
json_prefix + mode + '_sample_weight.json')) | |||
self.multi_label = load_json( | |||
os.path.join(self.path, json_prefix + mode + '_multi_label.json')) | |||
self.inverse_multi_label = self._inverse_map(self.multi_label) | |||
self.sample_weight = {} | |||
for chain in self.inverse_multi_label: | |||
entity = self.inverse_multi_label[chain] | |||
self.sample_weight[chain] = sample_weight[entity] | |||
self.seq_sample_weight = sample_weight | |||
logger.info('load {} chains (unique {} sequences)'.format( | |||
len(self.sample_weight), len(self.seq_sample_weight))) | |||
self.feature_path = os.path.join(self.path, 'pdb_features') | |||
self.label_path = os.path.join(self.path, 'pdb_labels') | |||
sd_sample_weight_path = os.path.join( | |||
self.path, json_prefix + 'sd_train_sample_weight.json') | |||
if mode == 'train' and os.path.isfile( | |||
sd_sample_weight_path) and not disable_sd: | |||
self.sd_sample_weight = load_json(sd_sample_weight_path) | |||
logger.info('load {} self-distillation samples.'.format( | |||
len(self.sd_sample_weight))) | |||
self.sd_feature_path = os.path.join(self.path, 'sd_features') | |||
self.sd_label_path = os.path.join(self.path, 'sd_labels') | |||
else: | |||
self.sd_sample_weight = None | |||
self.batch_size = ( | |||
args.batch_size * distributed_utils.get_data_parallel_world_size() | |||
* args.update_freq[0]) | |||
self.data_len = ( | |||
max_step * self.batch_size | |||
if max_step is not None else len(self.sample_weight)) | |||
self.mode = mode | |||
self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight( | |||
self.seq_sample_weight) | |||
self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( | |||
self.sample_weight) | |||
if self.sd_sample_weight is not None: | |||
( | |||
self.sd_num_chain, | |||
self.sd_chain_keys, | |||
self.sd_sample_prob, | |||
) = self.cal_sample_weight(self.sd_sample_weight) | |||
self.config = config.data | |||
self.seed = seed | |||
self.sd_prob = args.sd_prob | |||
def cal_sample_weight(self, sample_weight): | |||
prot_keys = list(sample_weight.keys()) | |||
sum_weight = sum(sample_weight.values()) | |||
sample_prob = [sample_weight[k] / sum_weight for k in prot_keys] | |||
num_prot = len(prot_keys) | |||
return num_prot, prot_keys, sample_prob | |||
def sample_chain(self, idx, sample_by_seq=False): | |||
is_distillation = False | |||
if self.mode == 'train': | |||
with data_utils.numpy_seed(self.seed, idx, key='data_sample'): | |||
is_distillation = ((np.random.rand(1)[0] < self.sd_prob) | |||
if self.sd_sample_weight is not None else | |||
False) | |||
if is_distillation: | |||
prot_idx = np.random.choice( | |||
self.sd_num_chain, p=self.sd_sample_prob) | |||
label_name = self.sd_chain_keys[prot_idx] | |||
seq_name = label_name | |||
else: | |||
if not sample_by_seq: | |||
prot_idx = np.random.choice( | |||
self.num_chain, p=self.sample_prob) | |||
label_name = self.chain_keys[prot_idx] | |||
seq_name = self.inverse_multi_label[label_name] | |||
else: | |||
seq_idx = np.random.choice( | |||
self.num_seq, p=self.seq_sample_prob) | |||
seq_name = self.seq_keys[seq_idx] | |||
label_name = np.random.choice( | |||
self.multi_label[seq_name]) | |||
else: | |||
label_name = self.chain_keys[idx] | |||
seq_name = self.inverse_multi_label[label_name] | |||
return seq_name, label_name, is_distillation | |||
def __getitem__(self, idx): | |||
sequence_id, label_id, is_distillation = self.sample_chain( | |||
idx, sample_by_seq=True) | |||
feature_dir, label_dir = ((self.feature_path, | |||
self.label_path) if not is_distillation else | |||
(self.sd_feature_path, self.sd_label_path)) | |||
features, _ = load_and_process( | |||
self.config, | |||
self.mode, | |||
self.seed, | |||
batch_idx=(idx // self.batch_size), | |||
data_idx=idx, | |||
is_distillation=is_distillation, | |||
sequence_ids=[sequence_id], | |||
monomer_feature_dir=feature_dir, | |||
uniprot_msa_dir=None, | |||
label_ids=[label_id], | |||
label_dir=label_dir, | |||
symmetry_operations=None, | |||
is_monomer=True, | |||
) | |||
return features | |||
def __len__(self): | |||
return self.data_len | |||
@staticmethod | |||
def collater(samples): | |||
# first dim is recyling. bsz is at the 2nd dim | |||
return data_utils.collate_dict(samples, dim=1) | |||
@staticmethod | |||
def _inverse_map(mapping: Dict[str, List[str]]): | |||
inverse_mapping = {} | |||
for ent, refs in mapping.items(): | |||
for ref in refs: | |||
if ref in inverse_mapping: # duplicated ent for this ref. | |||
ent_2 = inverse_mapping[ref] | |||
assert ( | |||
ent == ent_2 | |||
), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.' | |||
inverse_mapping[ref] = ent | |||
return inverse_mapping | |||
class UnifoldMultimerDataset(UnifoldDataset): | |||
def __init__( | |||
self, | |||
args: mlc.ConfigDict, | |||
seed: int, | |||
config: mlc.ConfigDict, | |||
data_path: str, | |||
mode: str = 'train', | |||
max_step: Optional[int] = None, | |||
disable_sd: bool = False, | |||
json_prefix: str = '', | |||
**kwargs, | |||
): | |||
super().__init__(args, seed, config, data_path, mode, max_step, | |||
disable_sd, json_prefix) | |||
self.data_path = data_path | |||
self.pdb_assembly = json.load( | |||
open( | |||
os.path.join(self.data_path, | |||
json_prefix + 'pdb_assembly.json'))) | |||
self.pdb_chains = self.get_chains(self.inverse_multi_label) | |||
self.monomer_feature_path = os.path.join(self.data_path, | |||
'pdb_features') | |||
self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots') | |||
self.label_path = os.path.join(self.data_path, 'pdb_labels') | |||
self.max_chains = args.max_chains | |||
if self.mode == 'train': | |||
self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains( | |||
self.pdb_chains, self.pdb_assembly, self.sample_weight, | |||
self.max_chains) | |||
self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( | |||
self.sample_weight) | |||
def __getitem__(self, idx): | |||
seq_id, label_id, is_distillation = self.sample_chain(idx) | |||
if is_distillation: | |||
label_ids = [label_id] | |||
sequence_ids = [seq_id] | |||
monomer_feature_path, uniprot_msa_path, label_path = ( | |||
self.sd_feature_path, | |||
None, | |||
self.sd_label_path, | |||
) | |||
symmetry_operations = None | |||
else: | |||
pdb_id = self.get_pdb_name(label_id) | |||
if pdb_id in self.pdb_assembly and self.mode == 'train': | |||
label_ids = [ | |||
pdb_id + '_' + id | |||
for id in self.pdb_assembly[pdb_id]['chains'] | |||
] | |||
symmetry_operations = [ | |||
t for t in self.pdb_assembly[pdb_id]['opers'] | |||
] | |||
else: | |||
label_ids = self.pdb_chains[pdb_id] | |||
symmetry_operations = None | |||
sequence_ids = [ | |||
self.inverse_multi_label[chain_id] for chain_id in label_ids | |||
] | |||
monomer_feature_path, uniprot_msa_path, label_path = ( | |||
self.monomer_feature_path, | |||
self.uniprot_msa_path, | |||
self.label_path, | |||
) | |||
return load_and_process( | |||
self.config, | |||
self.mode, | |||
self.seed, | |||
batch_idx=(idx // self.batch_size), | |||
data_idx=idx, | |||
is_distillation=is_distillation, | |||
sequence_ids=sequence_ids, | |||
monomer_feature_dir=monomer_feature_path, | |||
uniprot_msa_dir=uniprot_msa_path, | |||
label_ids=label_ids, | |||
label_dir=label_path, | |||
symmetry_operations=symmetry_operations, | |||
is_monomer=False, | |||
) | |||
@staticmethod | |||
def collater(samples): | |||
# first dim is recyling. bsz is at the 2nd dim | |||
if len(samples) <= 0: # tackle empty batch | |||
return None | |||
feats = [s[0] for s in samples] | |||
labs = [s[1] for s in samples if s[1] is not None] | |||
try: | |||
feats = data_utils.collate_dict(feats, dim=1) | |||
except BaseException: | |||
raise ValueError('cannot collate features', feats) | |||
if not labs: | |||
labs = None | |||
return feats, labs | |||
@staticmethod | |||
def get_pdb_name(chain): | |||
return chain.split('_')[0] | |||
@staticmethod | |||
def get_chains(canon_chain_map): | |||
pdb_chains = {} | |||
for chain in canon_chain_map: | |||
pdb = UnifoldMultimerDataset.get_pdb_name(chain) | |||
if pdb not in pdb_chains: | |||
pdb_chains[pdb] = [] | |||
pdb_chains[pdb].append(chain) | |||
return pdb_chains | |||
@staticmethod | |||
def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight, | |||
max_chains): | |||
new_pdb_chains = {} | |||
for chain in pdb_chains: | |||
if chain in pdb_assembly: | |||
size = len(pdb_assembly[chain]['chains']) | |||
if size <= max_chains: | |||
new_pdb_chains[chain] = pdb_chains[chain] | |||
else: | |||
size = len(pdb_chains[chain]) | |||
if size == 1: | |||
new_pdb_chains[chain] = pdb_chains[chain] | |||
new_sample_weight = { | |||
k: sample_weight[k] | |||
for k in sample_weight | |||
if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains | |||
} | |||
logger.info( | |||
f'filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs ' | |||
f'({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) ' | |||
f'by max_chains {max_chains}') | |||
return new_pdb_chains, new_sample_weight |
@@ -0,0 +1,75 @@ | |||
import argparse | |||
import os | |||
from typing import Any | |||
import torch | |||
from modelscope.metainfo import Models | |||
from modelscope.models import TorchModel | |||
from modelscope.models.builder import MODELS | |||
from modelscope.utils.constant import ModelFile, Tasks | |||
from .config import model_config | |||
from .modules.alphafold import AlphaFold | |||
__all__ = ['UnifoldForProteinStructrue'] | |||
@MODELS.register_module(Tasks.protein_structure, module_name=Models.unifold) | |||
class UnifoldForProteinStructrue(TorchModel): | |||
@staticmethod | |||
def add_args(parser): | |||
"""Add model-specific arguments to the parser.""" | |||
parser.add_argument( | |||
'--model-name', | |||
help='choose the model config', | |||
) | |||
def __init__(self, **kwargs): | |||
super().__init__() | |||
parser = argparse.ArgumentParser() | |||
parse_comm = [] | |||
for key in kwargs: | |||
parser.add_argument(f'--{key}') | |||
parse_comm.append(f'--{key}') | |||
parse_comm.append(kwargs[key]) | |||
args = parser.parse_args(parse_comm) | |||
base_architecture(args) | |||
self.args = args | |||
config = model_config( | |||
self.args.model_name, | |||
train=True, | |||
) | |||
self.model = AlphaFold(config) | |||
self.config = config | |||
# load model state dict | |||
param_path = os.path.join(kwargs['model_dir'], | |||
ModelFile.TORCH_MODEL_BIN_FILE) | |||
state_dict = torch.load(param_path)['ema']['params'] | |||
state_dict = { | |||
'.'.join(k.split('.')[1:]): v | |||
for k, v in state_dict.items() | |||
} | |||
self.model.load_state_dict(state_dict) | |||
def half(self): | |||
self.model = self.model.half() | |||
return self | |||
def bfloat16(self): | |||
self.model = self.model.bfloat16() | |||
return self | |||
@classmethod | |||
def build_model(cls, args, task): | |||
"""Build a new model instance.""" | |||
return cls(args) | |||
def forward(self, batch, **kwargs): | |||
outputs = self.model.forward(batch) | |||
return outputs, self.config.loss | |||
def base_architecture(args): | |||
args.model_name = getattr(args, 'model_name', 'model_2') |
@@ -0,0 +1,450 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import torch | |||
import torch.nn as nn | |||
from unicore.utils import tensor_tree_map | |||
from ..data import residue_constants | |||
from .attentions import gen_msa_attn_mask, gen_tri_attn_mask | |||
from .auxillary_heads import AuxiliaryHeads | |||
from .common import residual | |||
from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder, | |||
TemplateAngleEmbedder, TemplatePairEmbedder) | |||
from .evoformer import EvoformerStack, ExtraMSAStack | |||
from .featurization import (atom14_to_atom37, build_extra_msa_feat, | |||
build_template_angle_feat, | |||
build_template_pair_feat, | |||
build_template_pair_feat_v2, pseudo_beta_fn) | |||
from .structure_module import StructureModule | |||
from .template import (TemplatePairStack, TemplatePointwiseAttention, | |||
TemplateProjection) | |||
class AlphaFold(nn.Module): | |||
def __init__(self, config): | |||
super(AlphaFold, self).__init__() | |||
self.globals = config.globals | |||
config = config.model | |||
template_config = config.template | |||
extra_msa_config = config.extra_msa | |||
self.input_embedder = InputEmbedder( | |||
**config['input_embedder'], | |||
use_chain_relative=config.is_multimer, | |||
) | |||
self.recycling_embedder = RecyclingEmbedder( | |||
**config['recycling_embedder'], ) | |||
if config.template.enabled: | |||
self.template_angle_embedder = TemplateAngleEmbedder( | |||
**template_config['template_angle_embedder'], ) | |||
self.template_pair_embedder = TemplatePairEmbedder( | |||
**template_config['template_pair_embedder'], ) | |||
self.template_pair_stack = TemplatePairStack( | |||
**template_config['template_pair_stack'], ) | |||
else: | |||
self.template_pair_stack = None | |||
self.enable_template_pointwise_attention = template_config[ | |||
'template_pointwise_attention'].enabled | |||
if self.enable_template_pointwise_attention: | |||
self.template_pointwise_att = TemplatePointwiseAttention( | |||
**template_config['template_pointwise_attention'], ) | |||
else: | |||
self.template_proj = TemplateProjection( | |||
**template_config['template_pointwise_attention'], ) | |||
self.extra_msa_embedder = ExtraMSAEmbedder( | |||
**extra_msa_config['extra_msa_embedder'], ) | |||
self.extra_msa_stack = ExtraMSAStack( | |||
**extra_msa_config['extra_msa_stack'], ) | |||
self.evoformer = EvoformerStack(**config['evoformer_stack'], ) | |||
self.structure_module = StructureModule(**config['structure_module'], ) | |||
self.aux_heads = AuxiliaryHeads(config['heads'], ) | |||
self.config = config | |||
self.dtype = torch.float | |||
self.inf = self.globals.inf | |||
if self.globals.alphafold_original_mode: | |||
self.alphafold_original_mode() | |||
def __make_input_float__(self): | |||
self.input_embedder = self.input_embedder.float() | |||
self.recycling_embedder = self.recycling_embedder.float() | |||
def half(self): | |||
super().half() | |||
if (not getattr(self, 'inference', False)): | |||
self.__make_input_float__() | |||
self.dtype = torch.half | |||
return self | |||
def bfloat16(self): | |||
super().bfloat16() | |||
if (not getattr(self, 'inference', False)): | |||
self.__make_input_float__() | |||
self.dtype = torch.bfloat16 | |||
return self | |||
def alphafold_original_mode(self): | |||
def set_alphafold_original_mode(module): | |||
if hasattr(module, 'apply_alphafold_original_mode'): | |||
module.apply_alphafold_original_mode() | |||
if hasattr(module, 'act'): | |||
module.act = nn.ReLU() | |||
self.apply(set_alphafold_original_mode) | |||
def inference_mode(self): | |||
def set_inference_mode(module): | |||
setattr(module, 'inference', True) | |||
self.apply(set_inference_mode) | |||
def __convert_input_dtype__(self, batch): | |||
for key in batch: | |||
# only convert features with mask | |||
if batch[key].dtype != self.dtype and 'mask' in key: | |||
batch[key] = batch[key].type(self.dtype) | |||
return batch | |||
def embed_templates_pair_core(self, batch, z, pair_mask, | |||
tri_start_attn_mask, tri_end_attn_mask, | |||
templ_dim, multichain_mask_2d): | |||
if self.config.template.template_pair_embedder.v2_feature: | |||
t = build_template_pair_feat_v2( | |||
batch, | |||
inf=self.config.template.inf, | |||
eps=self.config.template.eps, | |||
multichain_mask_2d=multichain_mask_2d, | |||
**self.config.template.distogram, | |||
) | |||
num_template = t[0].shape[-4] | |||
single_templates = [ | |||
self.template_pair_embedder([x[..., ti, :, :, :] | |||
for x in t], z) | |||
for ti in range(num_template) | |||
] | |||
else: | |||
t = build_template_pair_feat( | |||
batch, | |||
inf=self.config.template.inf, | |||
eps=self.config.template.eps, | |||
**self.config.template.distogram, | |||
) | |||
single_templates = [ | |||
self.template_pair_embedder(x, z) | |||
for x in torch.unbind(t, dim=templ_dim) | |||
] | |||
t = self.template_pair_stack( | |||
single_templates, | |||
pair_mask, | |||
tri_start_attn_mask=tri_start_attn_mask, | |||
tri_end_attn_mask=tri_end_attn_mask, | |||
templ_dim=templ_dim, | |||
chunk_size=self.globals.chunk_size, | |||
block_size=self.globals.block_size, | |||
return_mean=not self.enable_template_pointwise_attention, | |||
) | |||
return t | |||
def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask, | |||
tri_end_attn_mask, templ_dim): | |||
if self.config.template.template_pair_embedder.v2_feature and 'asym_id' in batch: | |||
multichain_mask_2d = ( | |||
batch['asym_id'][..., :, None] == batch['asym_id'][..., | |||
None, :]) | |||
multichain_mask_2d = multichain_mask_2d.unsqueeze(0) | |||
else: | |||
multichain_mask_2d = None | |||
if self.training or self.enable_template_pointwise_attention: | |||
t = self.embed_templates_pair_core(batch, z, pair_mask, | |||
tri_start_attn_mask, | |||
tri_end_attn_mask, templ_dim, | |||
multichain_mask_2d) | |||
if self.enable_template_pointwise_attention: | |||
t = self.template_pointwise_att( | |||
t, | |||
z, | |||
template_mask=batch['template_mask'], | |||
chunk_size=self.globals.chunk_size, | |||
) | |||
t_mask = torch.sum( | |||
batch['template_mask'], dim=-1, keepdims=True) > 0 | |||
t_mask = t_mask[..., None, None].type(t.dtype) | |||
t *= t_mask | |||
else: | |||
t = self.template_proj(t, z) | |||
else: | |||
template_aatype_shape = batch['template_aatype'].shape | |||
# template_aatype is either [n_template, n_res] or [1, n_template_, n_res] | |||
batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0 | |||
n_templ = batch['template_aatype'].shape[batch_templ_dim] | |||
if n_templ <= 0: | |||
t = None | |||
else: | |||
template_batch = { | |||
k: v | |||
for k, v in batch.items() if k.startswith('template_') | |||
} | |||
def embed_one_template(i): | |||
def slice_template_tensor(t): | |||
s = [slice(None) for _ in t.shape] | |||
s[batch_templ_dim] = slice(i, i + 1) | |||
return t[s] | |||
template_feats = tensor_tree_map( | |||
slice_template_tensor, | |||
template_batch, | |||
) | |||
t = self.embed_templates_pair_core( | |||
template_feats, z, pair_mask, tri_start_attn_mask, | |||
tri_end_attn_mask, templ_dim, multichain_mask_2d) | |||
return t | |||
t = embed_one_template(0) | |||
# iterate templates one by one | |||
for i in range(1, n_templ): | |||
t += embed_one_template(i) | |||
t /= n_templ | |||
t = self.template_proj(t, z) | |||
return t | |||
def embed_templates_angle(self, batch): | |||
template_angle_feat, template_angle_mask = build_template_angle_feat( | |||
batch, | |||
v2_feature=self.config.template.template_pair_embedder.v2_feature) | |||
t = self.template_angle_embedder(template_angle_feat) | |||
return t, template_angle_mask | |||
def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev): | |||
batch_dims = feats['target_feat'].shape[:-2] | |||
n = feats['target_feat'].shape[-2] | |||
seq_mask = feats['seq_mask'] | |||
pair_mask = seq_mask[..., None] * seq_mask[..., None, :] | |||
msa_mask = feats['msa_mask'] | |||
m, z = self.input_embedder( | |||
feats['target_feat'], | |||
feats['msa_feat'], | |||
) | |||
if m_1_prev is None: | |||
m_1_prev = m.new_zeros( | |||
(*batch_dims, n, self.config.input_embedder.d_msa), | |||
requires_grad=False, | |||
) | |||
if z_prev is None: | |||
z_prev = z.new_zeros( | |||
(*batch_dims, n, n, self.config.input_embedder.d_pair), | |||
requires_grad=False, | |||
) | |||
if x_prev is None: | |||
x_prev = z.new_zeros( | |||
(*batch_dims, n, residue_constants.atom_type_num, 3), | |||
requires_grad=False, | |||
) | |||
x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None) | |||
z += self.recycling_embedder.recyle_pos(x_prev) | |||
m_1_prev_emb, z_prev_emb = self.recycling_embedder( | |||
m_1_prev, | |||
z_prev, | |||
) | |||
m[..., 0, :, :] += m_1_prev_emb | |||
z += z_prev_emb | |||
z += self.input_embedder.relpos_emb( | |||
feats['residue_index'].long(), | |||
feats.get('sym_id', None), | |||
feats.get('asym_id', None), | |||
feats.get('entity_id', None), | |||
feats.get('num_sym', None), | |||
) | |||
m = m.type(self.dtype) | |||
z = z.type(self.dtype) | |||
tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask( | |||
pair_mask, self.inf) | |||
if self.config.template.enabled: | |||
template_mask = feats['template_mask'] | |||
if torch.any(template_mask): | |||
z = residual( | |||
z, | |||
self.embed_templates_pair( | |||
feats, | |||
z, | |||
pair_mask, | |||
tri_start_attn_mask, | |||
tri_end_attn_mask, | |||
templ_dim=-4, | |||
), | |||
self.training, | |||
) | |||
if self.config.extra_msa.enabled: | |||
a = self.extra_msa_embedder(build_extra_msa_feat(feats)) | |||
extra_msa_row_mask = gen_msa_attn_mask( | |||
feats['extra_msa_mask'], | |||
inf=self.inf, | |||
gen_col_mask=False, | |||
) | |||
z = self.extra_msa_stack( | |||
a, | |||
z, | |||
msa_mask=feats['extra_msa_mask'], | |||
chunk_size=self.globals.chunk_size, | |||
block_size=self.globals.block_size, | |||
pair_mask=pair_mask, | |||
msa_row_attn_mask=extra_msa_row_mask, | |||
msa_col_attn_mask=None, | |||
tri_start_attn_mask=tri_start_attn_mask, | |||
tri_end_attn_mask=tri_end_attn_mask, | |||
) | |||
if self.config.template.embed_angles: | |||
template_1d_feat, template_1d_mask = self.embed_templates_angle( | |||
feats) | |||
m = torch.cat([m, template_1d_feat], dim=-3) | |||
msa_mask = torch.cat([feats['msa_mask'], template_1d_mask], dim=-2) | |||
msa_row_mask, msa_col_mask = gen_msa_attn_mask( | |||
msa_mask, | |||
inf=self.inf, | |||
) | |||
m, z, s = self.evoformer( | |||
m, | |||
z, | |||
msa_mask=msa_mask, | |||
pair_mask=pair_mask, | |||
msa_row_attn_mask=msa_row_mask, | |||
msa_col_attn_mask=msa_col_mask, | |||
tri_start_attn_mask=tri_start_attn_mask, | |||
tri_end_attn_mask=tri_end_attn_mask, | |||
chunk_size=self.globals.chunk_size, | |||
block_size=self.globals.block_size, | |||
) | |||
return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb | |||
def iteration_evoformer_structure_module(self, | |||
batch, | |||
m_1_prev, | |||
z_prev, | |||
x_prev, | |||
cycle_no, | |||
num_recycling, | |||
num_ensembles=1): | |||
z, s = 0, 0 | |||
n_seq = batch['msa_feat'].shape[-3] | |||
assert num_ensembles >= 1 | |||
for ensemble_no in range(num_ensembles): | |||
idx = cycle_no * num_ensembles + ensemble_no | |||
# fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...] | |||
def fetch_cur_batch(t): | |||
return t[min(t.shape[0] - 1, idx), ...] | |||
feats = tensor_tree_map(fetch_cur_batch, batch) | |||
m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer( | |||
feats, m_1_prev, z_prev, x_prev) | |||
z += z0 | |||
s += s0 | |||
del z0, s0 | |||
if num_ensembles > 1: | |||
z /= float(num_ensembles) | |||
s /= float(num_ensembles) | |||
outputs = {} | |||
outputs['msa'] = m[..., :n_seq, :, :] | |||
outputs['pair'] = z | |||
outputs['single'] = s | |||
# norm loss | |||
if (not getattr(self, 'inference', | |||
False)) and num_recycling == (cycle_no + 1): | |||
delta_msa = m | |||
delta_msa[..., | |||
0, :, :] = delta_msa[..., | |||
0, :, :] - m_1_prev_emb.detach() | |||
delta_pair = z - z_prev_emb.detach() | |||
outputs['delta_msa'] = delta_msa | |||
outputs['delta_pair'] = delta_pair | |||
outputs['msa_norm_mask'] = msa_mask | |||
outputs['sm'] = self.structure_module( | |||
s, | |||
z, | |||
feats['aatype'], | |||
mask=feats['seq_mask'], | |||
) | |||
outputs['final_atom_positions'] = atom14_to_atom37( | |||
outputs['sm']['positions'], feats) | |||
outputs['final_atom_mask'] = feats['atom37_atom_exists'] | |||
outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1] | |||
# use float32 for numerical stability | |||
if (not getattr(self, 'inference', False)): | |||
m_1_prev = m[..., 0, :, :].float() | |||
z_prev = z.float() | |||
x_prev = outputs['final_atom_positions'].float() | |||
else: | |||
m_1_prev = m[..., 0, :, :] | |||
z_prev = z | |||
x_prev = outputs['final_atom_positions'] | |||
return outputs, m_1_prev, z_prev, x_prev | |||
def forward(self, batch): | |||
m_1_prev = batch.get('m_1_prev', None) | |||
z_prev = batch.get('z_prev', None) | |||
x_prev = batch.get('x_prev', None) | |||
is_grad_enabled = torch.is_grad_enabled() | |||
num_iters = int(batch['num_recycling_iters']) + 1 | |||
num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters | |||
if self.training: | |||
# don't use ensemble during training | |||
assert num_ensembles == 1 | |||
# convert dtypes in batch | |||
batch = self.__convert_input_dtype__(batch) | |||
for cycle_no in range(num_iters): | |||
is_final_iter = cycle_no == (num_iters - 1) | |||
with torch.set_grad_enabled(is_grad_enabled and is_final_iter): | |||
( | |||
outputs, | |||
m_1_prev, | |||
z_prev, | |||
x_prev, | |||
) = self.iteration_evoformer_structure_module( | |||
batch, | |||
m_1_prev, | |||
z_prev, | |||
x_prev, | |||
cycle_no=cycle_no, | |||
num_recycling=num_iters, | |||
num_ensembles=num_ensembles, | |||
) | |||
if not is_final_iter: | |||
del outputs | |||
if 'asym_id' in batch: | |||
outputs['asym_id'] = batch['asym_id'][0, ...] | |||
outputs.update(self.aux_heads(outputs)) | |||
return outputs |
@@ -0,0 +1,430 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from functools import partialmethod | |||
from typing import List, Optional | |||
import torch | |||
import torch.nn as nn | |||
from unicore.modules import LayerNorm, softmax_dropout | |||
from unicore.utils import permute_final_dims | |||
from .common import Linear, chunk_layer | |||
def gen_attn_mask(mask, neg_inf): | |||
assert neg_inf < -1e4 | |||
attn_mask = torch.zeros_like(mask) | |||
attn_mask[mask == 0] = neg_inf | |||
return attn_mask | |||
class Attention(nn.Module): | |||
def __init__( | |||
self, | |||
q_dim: int, | |||
k_dim: int, | |||
v_dim: int, | |||
head_dim: int, | |||
num_heads: int, | |||
gating: bool = True, | |||
): | |||
super(Attention, self).__init__() | |||
self.num_heads = num_heads | |||
total_dim = head_dim * self.num_heads | |||
self.gating = gating | |||
self.linear_q = Linear(q_dim, total_dim, bias=False, init='glorot') | |||
self.linear_k = Linear(k_dim, total_dim, bias=False, init='glorot') | |||
self.linear_v = Linear(v_dim, total_dim, bias=False, init='glorot') | |||
self.linear_o = Linear(total_dim, q_dim, init='final') | |||
self.linear_g = None | |||
if self.gating: | |||
self.linear_g = Linear(q_dim, total_dim, init='gating') | |||
# precompute the 1/sqrt(head_dim) | |||
self.norm = head_dim**-0.5 | |||
def forward( | |||
self, | |||
q: torch.Tensor, | |||
k: torch.Tensor, | |||
v: torch.Tensor, | |||
mask: torch.Tensor = None, | |||
bias: Optional[torch.Tensor] = None, | |||
) -> torch.Tensor: | |||
g = None | |||
if self.linear_g is not None: | |||
# gating, use raw query input | |||
g = self.linear_g(q) | |||
q = self.linear_q(q) | |||
q *= self.norm | |||
k = self.linear_k(k) | |||
v = self.linear_v(v) | |||
q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose( | |||
-2, -3).contiguous() | |||
k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose( | |||
-2, -3).contiguous() | |||
v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3) | |||
attn = torch.matmul(q, k.transpose(-1, -2)) | |||
del q, k | |||
attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias) | |||
o = torch.matmul(attn, v) | |||
del attn, v | |||
o = o.transpose(-2, -3).contiguous() | |||
o = o.view(*o.shape[:-2], -1) | |||
if g is not None: | |||
o = torch.sigmoid(g) * o | |||
# merge heads | |||
o = nn.functional.linear(o, self.linear_o.weight) | |||
return o | |||
def get_output_bias(self): | |||
return self.linear_o.bias | |||
class GlobalAttention(nn.Module): | |||
def __init__(self, input_dim, head_dim, num_heads, inf, eps): | |||
super(GlobalAttention, self).__init__() | |||
self.num_heads = num_heads | |||
self.inf = inf | |||
self.eps = eps | |||
self.linear_q = Linear( | |||
input_dim, head_dim * num_heads, bias=False, init='glorot') | |||
self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot') | |||
self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot') | |||
self.linear_g = Linear(input_dim, head_dim * num_heads, init='gating') | |||
self.linear_o = Linear(head_dim * num_heads, input_dim, init='final') | |||
self.sigmoid = nn.Sigmoid() | |||
# precompute the 1/sqrt(head_dim) | |||
self.norm = head_dim**-0.5 | |||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: | |||
# gating | |||
g = self.sigmoid(self.linear_g(x)) | |||
k = self.linear_k(x) | |||
v = self.linear_v(x) | |||
q = torch.sum( | |||
x * mask.unsqueeze(-1), dim=-2) / ( | |||
torch.sum(mask, dim=-1, keepdims=True) + self.eps) | |||
q = self.linear_q(q) | |||
q *= self.norm | |||
q = q.view(q.shape[:-1] + (self.num_heads, -1)) | |||
attn = torch.matmul(q, k.transpose(-1, -2)) | |||
del q, k | |||
attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :] | |||
attn = softmax_dropout(attn, 0, self.training, mask=attn_mask) | |||
o = torch.matmul( | |||
attn, | |||
v, | |||
) | |||
del attn, v | |||
g = g.view(g.shape[:-1] + (self.num_heads, -1)) | |||
o = o.unsqueeze(-3) * g | |||
del g | |||
# merge heads | |||
o = o.reshape(o.shape[:-2] + (-1, )) | |||
return self.linear_o(o) | |||
def gen_msa_attn_mask(mask, inf, gen_col_mask=True): | |||
row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] | |||
if gen_col_mask: | |||
col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, | |||
None, :] | |||
return row_mask, col_mask | |||
else: | |||
return row_mask | |||
class MSAAttention(nn.Module): | |||
def __init__( | |||
self, | |||
d_in, | |||
d_hid, | |||
num_heads, | |||
pair_bias=False, | |||
d_pair=None, | |||
): | |||
super(MSAAttention, self).__init__() | |||
self.pair_bias = pair_bias | |||
self.layer_norm_m = LayerNorm(d_in) | |||
self.layer_norm_z = None | |||
self.linear_z = None | |||
if self.pair_bias: | |||
self.layer_norm_z = LayerNorm(d_pair) | |||
self.linear_z = Linear( | |||
d_pair, num_heads, bias=False, init='normal') | |||
self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) | |||
@torch.jit.ignore | |||
def _chunk( | |||
self, | |||
m: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None, | |||
bias: Optional[torch.Tensor] = None, | |||
chunk_size: int = None, | |||
) -> torch.Tensor: | |||
return chunk_layer( | |||
self._attn_forward, | |||
{ | |||
'm': m, | |||
'mask': mask, | |||
'bias': bias | |||
}, | |||
chunk_size=chunk_size, | |||
num_batch_dims=len(m.shape[:-2]), | |||
) | |||
@torch.jit.ignore | |||
def _attn_chunk_forward( | |||
self, | |||
m: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None, | |||
bias: Optional[torch.Tensor] = None, | |||
chunk_size: Optional[int] = 2560, | |||
) -> torch.Tensor: | |||
m = self.layer_norm_m(m) | |||
num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size | |||
outputs = [] | |||
for i in range(num_chunk): | |||
chunk_start = i * chunk_size | |||
chunk_end = min(m.shape[-3], chunk_start + chunk_size) | |||
cur_m = m[..., chunk_start:chunk_end, :, :] | |||
cur_mask = ( | |||
mask[..., chunk_start:chunk_end, :, :, :] | |||
if mask is not None else None) | |||
outputs.append( | |||
self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) | |||
return torch.concat(outputs, dim=-3) | |||
def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): | |||
m = self.layer_norm_m(m) | |||
return self.mha(q=m, k=m, v=m, mask=mask, bias=bias) | |||
def forward( | |||
self, | |||
m: torch.Tensor, | |||
z: Optional[torch.Tensor] = None, | |||
attn_mask: Optional[torch.Tensor] = None, | |||
chunk_size: Optional[int] = None, | |||
) -> torch.Tensor: | |||
bias = None | |||
if self.pair_bias: | |||
z = self.layer_norm_z(z) | |||
bias = ( | |||
permute_final_dims(self.linear_z(z), | |||
(2, 0, 1)).unsqueeze(-4).contiguous()) | |||
if chunk_size is not None: | |||
m = self._chunk(m, attn_mask, bias, chunk_size) | |||
else: | |||
attn_chunk_size = 2560 | |||
if m.shape[-3] <= attn_chunk_size: | |||
m = self._attn_forward(m, attn_mask, bias) | |||
else: | |||
# reduce the peak memory cost in extra_msa_stack | |||
return self._attn_chunk_forward( | |||
m, attn_mask, bias, chunk_size=attn_chunk_size) | |||
return m | |||
def get_output_bias(self): | |||
return self.mha.get_output_bias() | |||
class MSARowAttentionWithPairBias(MSAAttention): | |||
def __init__(self, d_msa, d_pair, d_hid, num_heads): | |||
super(MSARowAttentionWithPairBias, self).__init__( | |||
d_msa, | |||
d_hid, | |||
num_heads, | |||
pair_bias=True, | |||
d_pair=d_pair, | |||
) | |||
class MSAColumnAttention(MSAAttention): | |||
def __init__(self, d_msa, d_hid, num_heads): | |||
super(MSAColumnAttention, self).__init__( | |||
d_in=d_msa, | |||
d_hid=d_hid, | |||
num_heads=num_heads, | |||
pair_bias=False, | |||
d_pair=None, | |||
) | |||
def forward( | |||
self, | |||
m: torch.Tensor, | |||
attn_mask: Optional[torch.Tensor] = None, | |||
chunk_size: Optional[int] = None, | |||
) -> torch.Tensor: | |||
m = m.transpose(-2, -3) | |||
m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size) | |||
m = m.transpose(-2, -3) | |||
return m | |||
class MSAColumnGlobalAttention(nn.Module): | |||
def __init__( | |||
self, | |||
d_in, | |||
d_hid, | |||
num_heads, | |||
inf=1e9, | |||
eps=1e-10, | |||
): | |||
super(MSAColumnGlobalAttention, self).__init__() | |||
self.layer_norm_m = LayerNorm(d_in) | |||
self.global_attention = GlobalAttention( | |||
d_in, | |||
d_hid, | |||
num_heads, | |||
inf=inf, | |||
eps=eps, | |||
) | |||
@torch.jit.ignore | |||
def _chunk( | |||
self, | |||
m: torch.Tensor, | |||
mask: torch.Tensor, | |||
chunk_size: int, | |||
) -> torch.Tensor: | |||
return chunk_layer( | |||
self._attn_forward, | |||
{ | |||
'm': m, | |||
'mask': mask | |||
}, | |||
chunk_size=chunk_size, | |||
num_batch_dims=len(m.shape[:-2]), | |||
) | |||
def _attn_forward(self, m, mask): | |||
m = self.layer_norm_m(m) | |||
return self.global_attention(m, mask=mask) | |||
def forward( | |||
self, | |||
m: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None, | |||
chunk_size: Optional[int] = None, | |||
) -> torch.Tensor: | |||
m = m.transpose(-2, -3) | |||
mask = mask.transpose(-1, -2) | |||
if chunk_size is not None: | |||
m = self._chunk(m, mask, chunk_size) | |||
else: | |||
m = self._attn_forward(m, mask=mask) | |||
m = m.transpose(-2, -3) | |||
return m | |||
def gen_tri_attn_mask(mask, inf): | |||
start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] | |||
end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, | |||
None, :] | |||
return start_mask, end_mask | |||
class TriangleAttention(nn.Module): | |||
def __init__( | |||
self, | |||
d_in, | |||
d_hid, | |||
num_heads, | |||
starting, | |||
): | |||
super(TriangleAttention, self).__init__() | |||
self.starting = starting | |||
self.layer_norm = LayerNorm(d_in) | |||
self.linear = Linear(d_in, num_heads, bias=False, init='normal') | |||
self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) | |||
@torch.jit.ignore | |||
def _chunk( | |||
self, | |||
x: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None, | |||
bias: Optional[torch.Tensor] = None, | |||
chunk_size: int = None, | |||
) -> torch.Tensor: | |||
return chunk_layer( | |||
self.mha, | |||
{ | |||
'q': x, | |||
'k': x, | |||
'v': x, | |||
'mask': mask, | |||
'bias': bias | |||
}, | |||
chunk_size=chunk_size, | |||
num_batch_dims=len(x.shape[:-2]), | |||
) | |||
def forward( | |||
self, | |||
x: torch.Tensor, | |||
attn_mask: Optional[torch.Tensor] = None, | |||
chunk_size: Optional[int] = None, | |||
) -> torch.Tensor: | |||
if not self.starting: | |||
x = x.transpose(-2, -3) | |||
x = self.layer_norm(x) | |||
triangle_bias = ( | |||
permute_final_dims(self.linear(x), | |||
(2, 0, 1)).unsqueeze(-4).contiguous()) | |||
if chunk_size is not None: | |||
x = self._chunk(x, attn_mask, triangle_bias, chunk_size) | |||
else: | |||
x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias) | |||
if not self.starting: | |||
x = x.transpose(-2, -3) | |||
return x | |||
def get_output_bias(self): | |||
return self.mha.get_output_bias() | |||
class TriangleAttentionStarting(TriangleAttention): | |||
__init__ = partialmethod(TriangleAttention.__init__, starting=True) | |||
class TriangleAttentionEnding(TriangleAttention): | |||
__init__ = partialmethod(TriangleAttention.__init__, starting=False) |
@@ -0,0 +1,171 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from typing import Dict | |||
import torch.nn as nn | |||
from unicore.modules import LayerNorm | |||
from .common import Linear | |||
from .confidence import (predicted_aligned_error, predicted_lddt, | |||
predicted_tm_score) | |||
class AuxiliaryHeads(nn.Module): | |||
def __init__(self, config): | |||
super(AuxiliaryHeads, self).__init__() | |||
self.plddt = PredictedLDDTHead(**config['plddt'], ) | |||
self.distogram = DistogramHead(**config['distogram'], ) | |||
self.masked_msa = MaskedMSAHead(**config['masked_msa'], ) | |||
if config.experimentally_resolved.enabled: | |||
self.experimentally_resolved = ExperimentallyResolvedHead( | |||
**config['experimentally_resolved'], ) | |||
if config.pae.enabled: | |||
self.pae = PredictedAlignedErrorHead(**config.pae, ) | |||
self.config = config | |||
def forward(self, outputs): | |||
aux_out = {} | |||
plddt_logits = self.plddt(outputs['sm']['single']) | |||
aux_out['plddt_logits'] = plddt_logits | |||
aux_out['plddt'] = predicted_lddt(plddt_logits.detach()) | |||
distogram_logits = self.distogram(outputs['pair']) | |||
aux_out['distogram_logits'] = distogram_logits | |||
masked_msa_logits = self.masked_msa(outputs['msa']) | |||
aux_out['masked_msa_logits'] = masked_msa_logits | |||
if self.config.experimentally_resolved.enabled: | |||
exp_res_logits = self.experimentally_resolved(outputs['single']) | |||
aux_out['experimentally_resolved_logits'] = exp_res_logits | |||
if self.config.pae.enabled: | |||
pae_logits = self.pae(outputs['pair']) | |||
aux_out['pae_logits'] = pae_logits | |||
pae_logits = pae_logits.detach() | |||
aux_out.update( | |||
predicted_aligned_error( | |||
pae_logits, | |||
**self.config.pae, | |||
)) | |||
aux_out['ptm'] = predicted_tm_score( | |||
pae_logits, interface=False, **self.config.pae) | |||
iptm_weight = self.config.pae.get('iptm_weight', 0.0) | |||
if iptm_weight > 0.0: | |||
aux_out['iptm'] = predicted_tm_score( | |||
pae_logits, | |||
interface=True, | |||
asym_id=outputs['asym_id'], | |||
**self.config.pae, | |||
) | |||
aux_out['iptm+ptm'] = ( | |||
iptm_weight * aux_out['iptm'] + # noqa W504 | |||
(1.0 - iptm_weight) * aux_out['ptm']) | |||
return aux_out | |||
class PredictedLDDTHead(nn.Module): | |||
def __init__(self, num_bins, d_in, d_hid): | |||
super(PredictedLDDTHead, self).__init__() | |||
self.num_bins = num_bins | |||
self.d_in = d_in | |||
self.d_hid = d_hid | |||
self.layer_norm = LayerNorm(self.d_in) | |||
self.linear_1 = Linear(self.d_in, self.d_hid, init='relu') | |||
self.linear_2 = Linear(self.d_hid, self.d_hid, init='relu') | |||
self.act = nn.GELU() | |||
self.linear_3 = Linear(self.d_hid, self.num_bins, init='final') | |||
def forward(self, s): | |||
s = self.layer_norm(s) | |||
s = self.linear_1(s) | |||
s = self.act(s) | |||
s = self.linear_2(s) | |||
s = self.act(s) | |||
s = self.linear_3(s) | |||
return s | |||
class EnhancedHeadBase(nn.Module): | |||
def __init__(self, d_in, d_out, disable_enhance_head): | |||
super(EnhancedHeadBase, self).__init__() | |||
if disable_enhance_head: | |||
self.layer_norm = None | |||
self.linear_in = None | |||
else: | |||
self.layer_norm = LayerNorm(d_in) | |||
self.linear_in = Linear(d_in, d_in, init='relu') | |||
self.act = nn.GELU() | |||
self.linear = Linear(d_in, d_out, init='final') | |||
def apply_alphafold_original_mode(self): | |||
self.layer_norm = None | |||
self.linear_in = None | |||
def forward(self, x): | |||
if self.layer_norm is not None: | |||
x = self.layer_norm(x) | |||
x = self.act(self.linear_in(x)) | |||
logits = self.linear(x) | |||
return logits | |||
class DistogramHead(EnhancedHeadBase): | |||
def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): | |||
super(DistogramHead, self).__init__( | |||
d_in=d_pair, | |||
d_out=num_bins, | |||
disable_enhance_head=disable_enhance_head, | |||
) | |||
def forward(self, x): | |||
logits = super().forward(x) | |||
logits = logits + logits.transpose(-2, -3) | |||
return logits | |||
class PredictedAlignedErrorHead(EnhancedHeadBase): | |||
def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): | |||
super(PredictedAlignedErrorHead, self).__init__( | |||
d_in=d_pair, | |||
d_out=num_bins, | |||
disable_enhance_head=disable_enhance_head, | |||
) | |||
class MaskedMSAHead(EnhancedHeadBase): | |||
def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs): | |||
super(MaskedMSAHead, self).__init__( | |||
d_in=d_msa, | |||
d_out=d_out, | |||
disable_enhance_head=disable_enhance_head, | |||
) | |||
class ExperimentallyResolvedHead(EnhancedHeadBase): | |||
def __init__(self, d_single, d_out, disable_enhance_head, **kwargs): | |||
super(ExperimentallyResolvedHead, self).__init__( | |||
d_in=d_single, | |||
d_out=d_out, | |||
disable_enhance_head=disable_enhance_head, | |||
) |
@@ -0,0 +1,387 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from functools import partial | |||
from typing import Any, Callable, Dict, Iterable, List, Optional | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import torch.utils.checkpoint | |||
from unicore.modules import LayerNorm | |||
from unicore.utils import tensor_tree_map | |||
class Linear(nn.Linear): | |||
def __init__( | |||
self, | |||
d_in: int, | |||
d_out: int, | |||
bias: bool = True, | |||
init: str = 'default', | |||
): | |||
super(Linear, self).__init__(d_in, d_out, bias=bias) | |||
self.use_bias = bias | |||
if self.use_bias: | |||
with torch.no_grad(): | |||
self.bias.fill_(0) | |||
if init == 'default': | |||
self._trunc_normal_init(1.0) | |||
elif init == 'relu': | |||
self._trunc_normal_init(2.0) | |||
elif init == 'glorot': | |||
self._glorot_uniform_init() | |||
elif init == 'gating': | |||
self._zero_init(self.use_bias) | |||
elif init == 'normal': | |||
self._normal_init() | |||
elif init == 'final': | |||
self._zero_init(False) | |||
else: | |||
raise ValueError('Invalid init method.') | |||
def _trunc_normal_init(self, scale=1.0): | |||
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) | |||
TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 | |||
_, fan_in = self.weight.shape | |||
scale = scale / max(1, fan_in) | |||
std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR | |||
nn.init.trunc_normal_(self.weight, mean=0.0, std=std) | |||
def _glorot_uniform_init(self): | |||
nn.init.xavier_uniform_(self.weight, gain=1) | |||
def _zero_init(self, use_bias=True): | |||
with torch.no_grad(): | |||
self.weight.fill_(0.0) | |||
if use_bias: | |||
with torch.no_grad(): | |||
self.bias.fill_(1.0) | |||
def _normal_init(self): | |||
torch.nn.init.kaiming_normal_(self.weight, nonlinearity='linear') | |||
class Transition(nn.Module): | |||
def __init__(self, d_in, n): | |||
super(Transition, self).__init__() | |||
self.d_in = d_in | |||
self.n = n | |||
self.layer_norm = LayerNorm(self.d_in) | |||
self.linear_1 = Linear(self.d_in, self.n * self.d_in, init='relu') | |||
self.act = nn.GELU() | |||
self.linear_2 = Linear(self.n * self.d_in, d_in, init='final') | |||
def _transition(self, x): | |||
x = self.layer_norm(x) | |||
x = self.linear_1(x) | |||
x = self.act(x) | |||
x = self.linear_2(x) | |||
return x | |||
@torch.jit.ignore | |||
def _chunk( | |||
self, | |||
x: torch.Tensor, | |||
chunk_size: int, | |||
) -> torch.Tensor: | |||
return chunk_layer( | |||
self._transition, | |||
{'x': x}, | |||
chunk_size=chunk_size, | |||
num_batch_dims=len(x.shape[:-2]), | |||
) | |||
def forward( | |||
self, | |||
x: torch.Tensor, | |||
chunk_size: Optional[int] = None, | |||
) -> torch.Tensor: | |||
if chunk_size is not None: | |||
x = self._chunk(x, chunk_size) | |||
else: | |||
x = self._transition(x=x) | |||
return x | |||
class OuterProductMean(nn.Module): | |||
def __init__(self, d_msa, d_pair, d_hid, eps=1e-3): | |||
super(OuterProductMean, self).__init__() | |||
self.d_msa = d_msa | |||
self.d_pair = d_pair | |||
self.d_hid = d_hid | |||
self.eps = eps | |||
self.layer_norm = LayerNorm(d_msa) | |||
self.linear_1 = Linear(d_msa, d_hid) | |||
self.linear_2 = Linear(d_msa, d_hid) | |||
self.linear_out = Linear(d_hid**2, d_pair, init='relu') | |||
self.act = nn.GELU() | |||
self.linear_z = Linear(self.d_pair, self.d_pair, init='final') | |||
self.layer_norm_out = LayerNorm(self.d_pair) | |||
def _opm(self, a, b): | |||
outer = torch.einsum('...bac,...dae->...bdce', a, b) | |||
outer = outer.reshape(outer.shape[:-2] + (-1, )) | |||
outer = self.linear_out(outer) | |||
return outer | |||
@torch.jit.ignore | |||
def _chunk(self, a: torch.Tensor, b: torch.Tensor, | |||
chunk_size: int) -> torch.Tensor: | |||
a = a.reshape((-1, ) + a.shape[-3:]) | |||
b = b.reshape((-1, ) + b.shape[-3:]) | |||
out = [] | |||
# TODO: optimize this | |||
for a_prime, b_prime in zip(a, b): | |||
outer = chunk_layer( | |||
partial(self._opm, b=b_prime), | |||
{'a': a_prime}, | |||
chunk_size=chunk_size, | |||
num_batch_dims=1, | |||
) | |||
out.append(outer) | |||
if len(out) == 1: | |||
outer = out[0].unsqueeze(0) | |||
else: | |||
outer = torch.stack(out, dim=0) | |||
outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) | |||
return outer | |||
def apply_alphafold_original_mode(self): | |||
self.linear_z = None | |||
self.layer_norm_out = None | |||
def forward( | |||
self, | |||
m: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None, | |||
chunk_size: Optional[int] = None, | |||
) -> torch.Tensor: | |||
m = self.layer_norm(m) | |||
mask = mask.unsqueeze(-1) | |||
if self.layer_norm_out is not None: | |||
# for numerical stability | |||
mask = mask * (mask.size(-2)**-0.5) | |||
a = self.linear_1(m) | |||
b = self.linear_2(m) | |||
if self.training: | |||
a = a * mask | |||
b = b * mask | |||
else: | |||
a *= mask | |||
b *= mask | |||
a = a.transpose(-2, -3) | |||
b = b.transpose(-2, -3) | |||
if chunk_size is not None: | |||
z = self._chunk(a, b, chunk_size) | |||
else: | |||
z = self._opm(a, b) | |||
norm = torch.einsum('...abc,...adc->...bdc', mask, mask) | |||
z /= self.eps + norm | |||
if self.layer_norm_out is not None: | |||
z = self.act(z) | |||
z = self.layer_norm_out(z) | |||
z = self.linear_z(z) | |||
return z | |||
def residual(residual, x, training): | |||
if training: | |||
return x + residual | |||
else: | |||
residual += x | |||
return residual | |||
@torch.jit.script | |||
def fused_bias_dropout_add( | |||
x: torch.Tensor, | |||
bias: torch.Tensor, | |||
residual: torch.Tensor, | |||
dropmask: torch.Tensor, | |||
prob: float, | |||
) -> torch.Tensor: | |||
return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual | |||
@torch.jit.script | |||
def fused_bias_dropout_add_inference( | |||
x: torch.Tensor, | |||
bias: torch.Tensor, | |||
residual: torch.Tensor, | |||
) -> torch.Tensor: | |||
residual += bias + x | |||
return residual | |||
def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob, | |||
training): | |||
bias = module.get_output_bias() | |||
if training: | |||
shape = list(x.shape) | |||
shape[dropout_shared_dim] = 1 | |||
with torch.no_grad(): | |||
mask = x.new_ones(shape) | |||
return fused_bias_dropout_add(x, bias, residual, mask, prob) | |||
else: | |||
return fused_bias_dropout_add_inference(x, bias, residual) | |||
@torch.jit.script | |||
def fused_bias_gated_dropout_add( | |||
x: torch.Tensor, | |||
bias: torch.Tensor, | |||
g: torch.Tensor, | |||
g_bias: torch.Tensor, | |||
residual: torch.Tensor, | |||
dropout_mask: torch.Tensor, | |||
prob: float, | |||
) -> torch.Tensor: | |||
return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout( | |||
dropout_mask, | |||
p=prob, | |||
training=True, | |||
) + residual | |||
def tri_mul_residual( | |||
module, | |||
residual, | |||
outputs, | |||
dropout_shared_dim, | |||
prob, | |||
training, | |||
block_size, | |||
): | |||
if training: | |||
x, g = outputs | |||
bias, g_bias = module.get_output_bias() | |||
shape = list(x.shape) | |||
shape[dropout_shared_dim] = 1 | |||
with torch.no_grad(): | |||
mask = x.new_ones(shape) | |||
return fused_bias_gated_dropout_add( | |||
x, | |||
bias, | |||
g, | |||
g_bias, | |||
residual, | |||
mask, | |||
prob, | |||
) | |||
elif block_size is None: | |||
x, g = outputs | |||
bias, g_bias = module.get_output_bias() | |||
residual += (torch.sigmoid(g + g_bias) * (x + bias)) | |||
return residual | |||
else: | |||
# gated is not used here | |||
residual += outputs | |||
return residual | |||
class SimpleModuleList(nn.ModuleList): | |||
def __repr__(self): | |||
return str(len(self)) + ' X ...\n' + self[0].__repr__() | |||
def chunk_layer( | |||
layer: Callable, | |||
inputs: Dict[str, Any], | |||
chunk_size: int, | |||
num_batch_dims: int, | |||
) -> Any: | |||
# TODO: support inplace add to output | |||
if not (len(inputs) > 0): | |||
raise ValueError('Must provide at least one input') | |||
def _dict_get_shapes(input): | |||
shapes = [] | |||
if type(input) is torch.Tensor: | |||
shapes.append(input.shape) | |||
elif type(input) is dict: | |||
for v in input.values(): | |||
shapes.extend(_dict_get_shapes(v)) | |||
elif isinstance(input, Iterable): | |||
for v in input: | |||
shapes.extend(_dict_get_shapes(v)) | |||
else: | |||
raise ValueError('Not supported') | |||
return shapes | |||
inputs = {k: v for k, v in inputs.items() if v is not None} | |||
initial_dims = [ | |||
shape[:num_batch_dims] for shape in _dict_get_shapes(inputs) | |||
] | |||
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) | |||
flat_batch_dim = 1 | |||
for d in orig_batch_dims: | |||
flat_batch_dim *= d | |||
num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size | |||
def _flat_inputs(t): | |||
t = t.view(-1, *t.shape[num_batch_dims:]) | |||
assert ( | |||
t.shape[0] == flat_batch_dim or t.shape[0] == 1 | |||
), 'batch dimension must be 1 or equal to the flat batch dimension' | |||
return t | |||
flat_inputs = tensor_tree_map(_flat_inputs, inputs) | |||
out = None | |||
for i in range(num_chunks): | |||
chunk_start = i * chunk_size | |||
chunk_end = min((i + 1) * chunk_size, flat_batch_dim) | |||
def select_chunk(t): | |||
if t.shape[0] == 1: | |||
return t[0:1] | |||
else: | |||
return t[chunk_start:chunk_end] | |||
chunkes = tensor_tree_map(select_chunk, flat_inputs) | |||
output_chunk = layer(**chunkes) | |||
if out is None: | |||
out = tensor_tree_map( | |||
lambda t: t.new_zeros((flat_batch_dim, ) + t.shape[1:]), | |||
output_chunk) | |||
out_type = type(output_chunk) | |||
if out_type is tuple: | |||
for x, y in zip(out, output_chunk): | |||
x[chunk_start:chunk_end] = y | |||
elif out_type is torch.Tensor: | |||
out[chunk_start:chunk_end] = output_chunk | |||
else: | |||
raise ValueError('Not supported') | |||
# reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) | |||
def reshape(t): | |||
return t.view(orig_batch_dims + t.shape[1:]) | |||
out = tensor_tree_map(reshape, out) | |||
return out |
@@ -0,0 +1,159 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from typing import Dict, Optional, Tuple | |||
import torch | |||
def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor: | |||
"""Computes per-residue pLDDT from logits. | |||
Args: | |||
logits: [num_res, num_bins] output from the PredictedLDDTHead. | |||
Returns: | |||
plddt: [num_res] per-residue pLDDT. | |||
""" | |||
num_bins = plddt_logits.shape[-1] | |||
bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1) | |||
bin_width = 1.0 / num_bins | |||
bounds = torch.arange( | |||
start=0.5 * bin_width, | |||
end=1.0, | |||
step=bin_width, | |||
device=plddt_logits.device) | |||
plddt = torch.sum( | |||
bin_probs | |||
* bounds.view(*((1, ) * len(bin_probs.shape[:-1])), *bounds.shape), | |||
dim=-1, | |||
) | |||
return plddt | |||
def compute_bin_values(breaks: torch.Tensor): | |||
"""Gets the bin centers from the bin edges. | |||
Args: | |||
breaks: [num_bins - 1] the error bin edges. | |||
Returns: | |||
bin_centers: [num_bins] the error bin centers. | |||
""" | |||
step = breaks[1] - breaks[0] | |||
bin_values = breaks + step / 2 | |||
bin_values = torch.cat([bin_values, (bin_values[-1] + step).unsqueeze(-1)], | |||
dim=0) | |||
return bin_values | |||
def compute_predicted_aligned_error( | |||
bin_edges: torch.Tensor, | |||
bin_probs: torch.Tensor, | |||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||
"""Calculates expected aligned distance errors for every pair of residues. | |||
Args: | |||
alignment_confidence_breaks: [num_bins - 1] the error bin edges. | |||
aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted | |||
probs for each error bin, for each pair of residues. | |||
Returns: | |||
predicted_aligned_error: [num_res, num_res] the expected aligned distance | |||
error for each pair of residues. | |||
max_predicted_aligned_error: The maximum predicted error possible. | |||
""" | |||
bin_values = compute_bin_values(bin_edges) | |||
return torch.sum(bin_probs * bin_values, dim=-1) | |||
def predicted_aligned_error( | |||
pae_logits: torch.Tensor, | |||
max_bin: int = 31, | |||
num_bins: int = 64, | |||
**kwargs, | |||
) -> Dict[str, torch.Tensor]: | |||
"""Computes aligned confidence metrics from logits. | |||
Args: | |||
logits: [num_res, num_res, num_bins] the logits output from | |||
PredictedAlignedErrorHead. | |||
breaks: [num_bins - 1] the error bin edges. | |||
Returns: | |||
aligned_confidence_probs: [num_res, num_res, num_bins] the predicted | |||
aligned error probabilities over bins for each residue pair. | |||
predicted_aligned_error: [num_res, num_res] the expected aligned distance | |||
error for each pair of residues. | |||
max_predicted_aligned_error: The maximum predicted error possible. | |||
""" | |||
bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1) | |||
bin_edges = torch.linspace( | |||
0, max_bin, steps=(num_bins - 1), device=pae_logits.device) | |||
predicted_aligned_error = compute_predicted_aligned_error( | |||
bin_edges=bin_edges, | |||
bin_probs=bin_probs, | |||
) | |||
return { | |||
'aligned_error_probs_per_bin': bin_probs, | |||
'predicted_aligned_error': predicted_aligned_error, | |||
} | |||
def predicted_tm_score( | |||
pae_logits: torch.Tensor, | |||
residue_weights: Optional[torch.Tensor] = None, | |||
max_bin: int = 31, | |||
num_bins: int = 64, | |||
eps: float = 1e-8, | |||
asym_id: Optional[torch.Tensor] = None, | |||
interface: bool = False, | |||
**kwargs, | |||
) -> torch.Tensor: | |||
"""Computes predicted TM alignment or predicted interface TM alignment score. | |||
Args: | |||
logits: [num_res, num_res, num_bins] the logits output from | |||
PredictedAlignedErrorHead. | |||
breaks: [num_bins] the error bins. | |||
residue_weights: [num_res] the per residue weights to use for the | |||
expectation. | |||
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for | |||
ipTM calculation, i.e. when interface=True. | |||
interface: If True, interface predicted TM score is computed. | |||
Returns: | |||
ptm_score: The predicted TM alignment or the predicted iTM score. | |||
""" | |||
pae_logits = pae_logits.float() | |||
if residue_weights is None: | |||
residue_weights = pae_logits.new_ones(pae_logits.shape[:-2]) | |||
breaks = torch.linspace( | |||
0, max_bin, steps=(num_bins - 1), device=pae_logits.device) | |||
def tm_kernal(nres): | |||
clipped_n = max(nres, 19) | |||
d0 = 1.24 * (clipped_n - 15)**(1.0 / 3.0) - 1.8 | |||
return lambda x: 1.0 / (1.0 + (x / d0)**2) | |||
def rmsd_kernal(eps): # leave for compute pRMS | |||
return lambda x: 1. / (x + eps) | |||
bin_centers = compute_bin_values(breaks) | |||
probs = torch.nn.functional.softmax(pae_logits, dim=-1) | |||
tm_per_bin = tm_kernal(nres=pae_logits.shape[-2])(bin_centers) | |||
# tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2)) | |||
# rmsd_per_bin = rmsd_kernal()(bin_centers) | |||
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) | |||
pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape) | |||
if interface: | |||
assert asym_id is not None, 'must provide asym_id for iptm calculation.' | |||
pair_mask *= asym_id[..., :, None] != asym_id[..., None, :] | |||
predicted_tm_term *= pair_mask | |||
pair_residue_weights = pair_mask * ( | |||
residue_weights[None, :] * residue_weights[:, None]) | |||
normed_residue_mask = pair_residue_weights / ( | |||
eps + pair_residue_weights.sum(dim=-1, keepdim=True)) | |||
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) | |||
weighted = per_alignment * residue_weights | |||
ret = per_alignment.gather( | |||
dim=-1, index=weighted.max(dim=-1, | |||
keepdim=True).indices).squeeze(dim=-1) | |||
return ret |
@@ -0,0 +1,290 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from typing import Optional, Tuple | |||
import torch | |||
import torch.nn as nn | |||
from unicore.modules import LayerNorm | |||
from unicore.utils import one_hot | |||
from .common import Linear, SimpleModuleList, residual | |||
class InputEmbedder(nn.Module): | |||
def __init__( | |||
self, | |||
tf_dim: int, | |||
msa_dim: int, | |||
d_pair: int, | |||
d_msa: int, | |||
relpos_k: int, | |||
use_chain_relative: bool = False, | |||
max_relative_chain: Optional[int] = None, | |||
**kwargs, | |||
): | |||
super(InputEmbedder, self).__init__() | |||
self.tf_dim = tf_dim | |||
self.msa_dim = msa_dim | |||
self.d_pair = d_pair | |||
self.d_msa = d_msa | |||
self.linear_tf_z_i = Linear(tf_dim, d_pair) | |||
self.linear_tf_z_j = Linear(tf_dim, d_pair) | |||
self.linear_tf_m = Linear(tf_dim, d_msa) | |||
self.linear_msa_m = Linear(msa_dim, d_msa) | |||
# RPE stuff | |||
self.relpos_k = relpos_k | |||
self.use_chain_relative = use_chain_relative | |||
self.max_relative_chain = max_relative_chain | |||
if not self.use_chain_relative: | |||
self.num_bins = 2 * self.relpos_k + 1 | |||
else: | |||
self.num_bins = 2 * self.relpos_k + 2 | |||
self.num_bins += 1 # entity id | |||
self.num_bins += 2 * max_relative_chain + 2 | |||
self.linear_relpos = Linear(self.num_bins, d_pair) | |||
def _relpos_indices( | |||
self, | |||
res_id: torch.Tensor, | |||
sym_id: Optional[torch.Tensor] = None, | |||
asym_id: Optional[torch.Tensor] = None, | |||
entity_id: Optional[torch.Tensor] = None, | |||
): | |||
max_rel_res = self.relpos_k | |||
rp = res_id[..., None] - res_id[..., None, :] | |||
rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res | |||
if not self.use_chain_relative: | |||
return rp | |||
else: | |||
asym_id_same = asym_id[..., :, None] == asym_id[..., None, :] | |||
rp[~asym_id_same] = 2 * max_rel_res + 1 | |||
entity_id_same = entity_id[..., :, None] == entity_id[..., None, :] | |||
rp_entity_id = entity_id_same.type(rp.dtype)[..., None] | |||
rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :] | |||
max_rel_chain = self.max_relative_chain | |||
clipped_rel_chain = torch.clamp( | |||
rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain) | |||
clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1 | |||
return rp, rp_entity_id, clipped_rel_chain | |||
def relpos_emb( | |||
self, | |||
res_id: torch.Tensor, | |||
sym_id: Optional[torch.Tensor] = None, | |||
asym_id: Optional[torch.Tensor] = None, | |||
entity_id: Optional[torch.Tensor] = None, | |||
num_sym: Optional[torch.Tensor] = None, | |||
): | |||
dtype = self.linear_relpos.weight.dtype | |||
if not self.use_chain_relative: | |||
rp = self._relpos_indices(res_id=res_id) | |||
return self.linear_relpos( | |||
one_hot(rp, num_classes=self.num_bins, dtype=dtype)) | |||
else: | |||
rp, rp_entity_id, rp_rel_chain = self._relpos_indices( | |||
res_id=res_id, | |||
sym_id=sym_id, | |||
asym_id=asym_id, | |||
entity_id=entity_id) | |||
rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype) | |||
rp_entity_id = rp_entity_id.type(dtype) | |||
rp_rel_chain = one_hot( | |||
rp_rel_chain, | |||
num_classes=(2 * self.max_relative_chain + 2), | |||
dtype=dtype) | |||
return self.linear_relpos( | |||
torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1)) | |||
def forward( | |||
self, | |||
tf: torch.Tensor, | |||
msa: torch.Tensor, | |||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||
# [*, N_res, d_pair] | |||
if self.tf_dim == 21: | |||
# multimer use 21 target dim | |||
tf = tf[..., 1:] | |||
# convert type if necessary | |||
tf = tf.type(self.linear_tf_z_i.weight.dtype) | |||
msa = msa.type(self.linear_tf_z_i.weight.dtype) | |||
n_clust = msa.shape[-3] | |||
msa_emb = self.linear_msa_m(msa) | |||
# target_feat (aatype) into msa representation | |||
tf_m = ( | |||
self.linear_tf_m(tf).unsqueeze(-3).expand( | |||
((-1, ) * len(tf.shape[:-2]) + # noqa W504 | |||
(n_clust, -1, -1)))) | |||
msa_emb += tf_m | |||
tf_emb_i = self.linear_tf_z_i(tf) | |||
tf_emb_j = self.linear_tf_z_j(tf) | |||
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] | |||
return msa_emb, pair_emb | |||
class RecyclingEmbedder(nn.Module): | |||
def __init__( | |||
self, | |||
d_msa: int, | |||
d_pair: int, | |||
min_bin: float, | |||
max_bin: float, | |||
num_bins: int, | |||
inf: float = 1e8, | |||
**kwargs, | |||
): | |||
super(RecyclingEmbedder, self).__init__() | |||
self.d_msa = d_msa | |||
self.d_pair = d_pair | |||
self.min_bin = min_bin | |||
self.max_bin = max_bin | |||
self.num_bins = num_bins | |||
self.inf = inf | |||
self.squared_bins = None | |||
self.linear = Linear(self.num_bins, self.d_pair) | |||
self.layer_norm_m = LayerNorm(self.d_msa) | |||
self.layer_norm_z = LayerNorm(self.d_pair) | |||
def forward( | |||
self, | |||
m: torch.Tensor, | |||
z: torch.Tensor, | |||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||
m_update = self.layer_norm_m(m) | |||
z_update = self.layer_norm_z(z) | |||
return m_update, z_update | |||
def recyle_pos( | |||
self, | |||
x: torch.Tensor, | |||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||
if self.squared_bins is None: | |||
bins = torch.linspace( | |||
self.min_bin, | |||
self.max_bin, | |||
self.num_bins, | |||
dtype=torch.float if self.training else x.dtype, | |||
device=x.device, | |||
requires_grad=False, | |||
) | |||
self.squared_bins = bins**2 | |||
upper = torch.cat( | |||
[self.squared_bins[1:], | |||
self.squared_bins.new_tensor([self.inf])], | |||
dim=-1) | |||
if self.training: | |||
x = x.float() | |||
d = torch.sum( | |||
(x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True) | |||
d = ((d > self.squared_bins) * # noqa W504 | |||
(d < upper)).type(self.linear.weight.dtype) | |||
d = self.linear(d) | |||
return d | |||
class TemplateAngleEmbedder(nn.Module): | |||
def __init__( | |||
self, | |||
d_in: int, | |||
d_out: int, | |||
**kwargs, | |||
): | |||
super(TemplateAngleEmbedder, self).__init__() | |||
self.d_out = d_out | |||
self.d_in = d_in | |||
self.linear_1 = Linear(self.d_in, self.d_out, init='relu') | |||
self.act = nn.GELU() | |||
self.linear_2 = Linear(self.d_out, self.d_out, init='relu') | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
x = self.linear_1(x.type(self.linear_1.weight.dtype)) | |||
x = self.act(x) | |||
x = self.linear_2(x) | |||
return x | |||
class TemplatePairEmbedder(nn.Module): | |||
def __init__( | |||
self, | |||
d_in: int, | |||
v2_d_in: list, | |||
d_out: int, | |||
d_pair: int, | |||
v2_feature: bool = False, | |||
**kwargs, | |||
): | |||
super(TemplatePairEmbedder, self).__init__() | |||
self.d_out = d_out | |||
self.v2_feature = v2_feature | |||
if self.v2_feature: | |||
self.d_in = v2_d_in | |||
self.linear = SimpleModuleList() | |||
for d_in in self.d_in: | |||
self.linear.append(Linear(d_in, self.d_out, init='relu')) | |||
self.z_layer_norm = LayerNorm(d_pair) | |||
self.z_linear = Linear(d_pair, self.d_out, init='relu') | |||
else: | |||
self.d_in = d_in | |||
self.linear = Linear(self.d_in, self.d_out, init='relu') | |||
def forward( | |||
self, | |||
x, | |||
z, | |||
) -> torch.Tensor: | |||
if not self.v2_feature: | |||
x = self.linear(x.type(self.linear.weight.dtype)) | |||
return x | |||
else: | |||
dtype = self.z_linear.weight.dtype | |||
t = self.linear[0](x[0].type(dtype)) | |||
for i, s in enumerate(x[1:]): | |||
t = residual(t, self.linear[i + 1](s.type(dtype)), | |||
self.training) | |||
t = residual(t, self.z_linear(self.z_layer_norm(z)), self.training) | |||
return t | |||
class ExtraMSAEmbedder(nn.Module): | |||
def __init__( | |||
self, | |||
d_in: int, | |||
d_out: int, | |||
**kwargs, | |||
): | |||
super(ExtraMSAEmbedder, self).__init__() | |||
self.d_in = d_in | |||
self.d_out = d_out | |||
self.linear = Linear(self.d_in, self.d_out) | |||
def forward(self, x: torch.Tensor) -> torch.Tensor: | |||
return self.linear(x.type(self.linear.weight.dtype)) |
@@ -0,0 +1,362 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from functools import partial | |||
from typing import Optional, Tuple | |||
import torch | |||
import torch.nn as nn | |||
from unicore.utils import checkpoint_sequential | |||
from .attentions import (MSAColumnAttention, MSAColumnGlobalAttention, | |||
MSARowAttentionWithPairBias, TriangleAttentionEnding, | |||
TriangleAttentionStarting) | |||
from .common import (Linear, OuterProductMean, SimpleModuleList, Transition, | |||
bias_dropout_residual, residual, tri_mul_residual) | |||
from .triangle_multiplication import (TriangleMultiplicationIncoming, | |||
TriangleMultiplicationOutgoing) | |||
class EvoformerIteration(nn.Module): | |||
def __init__( | |||
self, | |||
d_msa: int, | |||
d_pair: int, | |||
d_hid_msa_att: int, | |||
d_hid_opm: int, | |||
d_hid_mul: int, | |||
d_hid_pair_att: int, | |||
num_heads_msa: int, | |||
num_heads_pair: int, | |||
transition_n: int, | |||
msa_dropout: float, | |||
pair_dropout: float, | |||
outer_product_mean_first: bool, | |||
inf: float, | |||
eps: float, | |||
_is_extra_msa_stack: bool = False, | |||
): | |||
super(EvoformerIteration, self).__init__() | |||
self._is_extra_msa_stack = _is_extra_msa_stack | |||
self.outer_product_mean_first = outer_product_mean_first | |||
self.msa_att_row = MSARowAttentionWithPairBias( | |||
d_msa=d_msa, | |||
d_pair=d_pair, | |||
d_hid=d_hid_msa_att, | |||
num_heads=num_heads_msa, | |||
) | |||
if _is_extra_msa_stack: | |||
self.msa_att_col = MSAColumnGlobalAttention( | |||
d_in=d_msa, | |||
d_hid=d_hid_msa_att, | |||
num_heads=num_heads_msa, | |||
inf=inf, | |||
eps=eps, | |||
) | |||
else: | |||
self.msa_att_col = MSAColumnAttention( | |||
d_msa, | |||
d_hid_msa_att, | |||
num_heads_msa, | |||
) | |||
self.msa_transition = Transition( | |||
d_in=d_msa, | |||
n=transition_n, | |||
) | |||
self.outer_product_mean = OuterProductMean( | |||
d_msa, | |||
d_pair, | |||
d_hid_opm, | |||
) | |||
self.tri_mul_out = TriangleMultiplicationOutgoing( | |||
d_pair, | |||
d_hid_mul, | |||
) | |||
self.tri_mul_in = TriangleMultiplicationIncoming( | |||
d_pair, | |||
d_hid_mul, | |||
) | |||
self.tri_att_start = TriangleAttentionStarting( | |||
d_pair, | |||
d_hid_pair_att, | |||
num_heads_pair, | |||
) | |||
self.tri_att_end = TriangleAttentionEnding( | |||
d_pair, | |||
d_hid_pair_att, | |||
num_heads_pair, | |||
) | |||
self.pair_transition = Transition( | |||
d_in=d_pair, | |||
n=transition_n, | |||
) | |||
self.row_dropout_share_dim = -3 | |||
self.col_dropout_share_dim = -2 | |||
self.msa_dropout = msa_dropout | |||
self.pair_dropout = pair_dropout | |||
def forward( | |||
self, | |||
m: torch.Tensor, | |||
z: torch.Tensor, | |||
msa_mask: torch.Tensor, | |||
pair_mask: torch.Tensor, | |||
msa_row_attn_mask: torch.Tensor, | |||
msa_col_attn_mask: Optional[torch.Tensor], | |||
tri_start_attn_mask: torch.Tensor, | |||
tri_end_attn_mask: torch.Tensor, | |||
chunk_size: Optional[int] = None, | |||
block_size: Optional[int] = None, | |||
) -> Tuple[torch.Tensor, torch.Tensor]: | |||
if self.outer_product_mean_first: | |||
z = residual( | |||
z, | |||
self.outer_product_mean( | |||
m, mask=msa_mask, chunk_size=chunk_size), self.training) | |||
m = bias_dropout_residual( | |||
self.msa_att_row, | |||
m, | |||
self.msa_att_row( | |||
m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size), | |||
self.row_dropout_share_dim, | |||
self.msa_dropout, | |||
self.training, | |||
) | |||
if self._is_extra_msa_stack: | |||
m = residual( | |||
m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size), | |||
self.training) | |||
else: | |||
m = bias_dropout_residual( | |||
self.msa_att_col, | |||
m, | |||
self.msa_att_col( | |||
m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size), | |||
self.col_dropout_share_dim, | |||
self.msa_dropout, | |||
self.training, | |||
) | |||
m = residual(m, self.msa_transition(m, chunk_size=chunk_size), | |||
self.training) | |||
if not self.outer_product_mean_first: | |||
z = residual( | |||
z, | |||
self.outer_product_mean( | |||
m, mask=msa_mask, chunk_size=chunk_size), self.training) | |||
z = tri_mul_residual( | |||
self.tri_mul_out, | |||
z, | |||
self.tri_mul_out(z, mask=pair_mask, block_size=block_size), | |||
self.row_dropout_share_dim, | |||
self.pair_dropout, | |||
self.training, | |||
block_size=block_size, | |||
) | |||
z = tri_mul_residual( | |||
self.tri_mul_in, | |||
z, | |||
self.tri_mul_in(z, mask=pair_mask, block_size=block_size), | |||
self.row_dropout_share_dim, | |||
self.pair_dropout, | |||
self.training, | |||
block_size=block_size, | |||
) | |||
z = bias_dropout_residual( | |||
self.tri_att_start, | |||
z, | |||
self.tri_att_start( | |||
z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), | |||
self.row_dropout_share_dim, | |||
self.pair_dropout, | |||
self.training, | |||
) | |||
z = bias_dropout_residual( | |||
self.tri_att_end, | |||
z, | |||
self.tri_att_end( | |||
z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), | |||
self.col_dropout_share_dim, | |||
self.pair_dropout, | |||
self.training, | |||
) | |||
z = residual(z, self.pair_transition(z, chunk_size=chunk_size), | |||
self.training) | |||
return m, z | |||
class EvoformerStack(nn.Module): | |||
def __init__( | |||
self, | |||
d_msa: int, | |||
d_pair: int, | |||
d_hid_msa_att: int, | |||
d_hid_opm: int, | |||
d_hid_mul: int, | |||
d_hid_pair_att: int, | |||
d_single: int, | |||
num_heads_msa: int, | |||
num_heads_pair: int, | |||
num_blocks: int, | |||
transition_n: int, | |||
msa_dropout: float, | |||
pair_dropout: float, | |||
outer_product_mean_first: bool, | |||
inf: float, | |||
eps: float, | |||
_is_extra_msa_stack: bool = False, | |||
**kwargs, | |||
): | |||
super(EvoformerStack, self).__init__() | |||
self._is_extra_msa_stack = _is_extra_msa_stack | |||
self.blocks = SimpleModuleList() | |||
for _ in range(num_blocks): | |||
self.blocks.append( | |||
EvoformerIteration( | |||
d_msa=d_msa, | |||
d_pair=d_pair, | |||
d_hid_msa_att=d_hid_msa_att, | |||
d_hid_opm=d_hid_opm, | |||
d_hid_mul=d_hid_mul, | |||
d_hid_pair_att=d_hid_pair_att, | |||
num_heads_msa=num_heads_msa, | |||
num_heads_pair=num_heads_pair, | |||
transition_n=transition_n, | |||
msa_dropout=msa_dropout, | |||
pair_dropout=pair_dropout, | |||
outer_product_mean_first=outer_product_mean_first, | |||
inf=inf, | |||
eps=eps, | |||
_is_extra_msa_stack=_is_extra_msa_stack, | |||
)) | |||
if not self._is_extra_msa_stack: | |||
self.linear = Linear(d_msa, d_single) | |||
else: | |||
self.linear = None | |||
def forward( | |||
self, | |||
m: torch.Tensor, | |||
z: torch.Tensor, | |||
msa_mask: torch.Tensor, | |||
pair_mask: torch.Tensor, | |||
msa_row_attn_mask: torch.Tensor, | |||
msa_col_attn_mask: torch.Tensor, | |||
tri_start_attn_mask: torch.Tensor, | |||
tri_end_attn_mask: torch.Tensor, | |||
chunk_size: int, | |||
block_size: int, | |||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||
blocks = [ | |||
partial( | |||
b, | |||
msa_mask=msa_mask, | |||
pair_mask=pair_mask, | |||
msa_row_attn_mask=msa_row_attn_mask, | |||
msa_col_attn_mask=msa_col_attn_mask, | |||
tri_start_attn_mask=tri_start_attn_mask, | |||
tri_end_attn_mask=tri_end_attn_mask, | |||
chunk_size=chunk_size, | |||
block_size=block_size) for b in self.blocks | |||
] | |||
m, z = checkpoint_sequential( | |||
blocks, | |||
input=(m, z), | |||
) | |||
s = None | |||
if not self._is_extra_msa_stack: | |||
seq_dim = -3 | |||
index = torch.tensor([0], device=m.device) | |||
s = self.linear(torch.index_select(m, dim=seq_dim, index=index)) | |||
s = s.squeeze(seq_dim) | |||
return m, z, s | |||
class ExtraMSAStack(EvoformerStack): | |||
def __init__( | |||
self, | |||
d_msa: int, | |||
d_pair: int, | |||
d_hid_msa_att: int, | |||
d_hid_opm: int, | |||
d_hid_mul: int, | |||
d_hid_pair_att: int, | |||
num_heads_msa: int, | |||
num_heads_pair: int, | |||
num_blocks: int, | |||
transition_n: int, | |||
msa_dropout: float, | |||
pair_dropout: float, | |||
outer_product_mean_first: bool, | |||
inf: float, | |||
eps: float, | |||
**kwargs, | |||
): | |||
super(ExtraMSAStack, self).__init__( | |||
d_msa=d_msa, | |||
d_pair=d_pair, | |||
d_hid_msa_att=d_hid_msa_att, | |||
d_hid_opm=d_hid_opm, | |||
d_hid_mul=d_hid_mul, | |||
d_hid_pair_att=d_hid_pair_att, | |||
d_single=None, | |||
num_heads_msa=num_heads_msa, | |||
num_heads_pair=num_heads_pair, | |||
num_blocks=num_blocks, | |||
transition_n=transition_n, | |||
msa_dropout=msa_dropout, | |||
pair_dropout=pair_dropout, | |||
outer_product_mean_first=outer_product_mean_first, | |||
inf=inf, | |||
eps=eps, | |||
_is_extra_msa_stack=True, | |||
) | |||
def forward( | |||
self, | |||
m: torch.Tensor, | |||
z: torch.Tensor, | |||
msa_mask: Optional[torch.Tensor] = None, | |||
pair_mask: Optional[torch.Tensor] = None, | |||
msa_row_attn_mask: torch.Tensor = None, | |||
msa_col_attn_mask: torch.Tensor = None, | |||
tri_start_attn_mask: torch.Tensor = None, | |||
tri_end_attn_mask: torch.Tensor = None, | |||
chunk_size: int = None, | |||
block_size: int = None, | |||
) -> torch.Tensor: | |||
_, z, _ = super().forward( | |||
m, | |||
z, | |||
msa_mask=msa_mask, | |||
pair_mask=pair_mask, | |||
msa_row_attn_mask=msa_row_attn_mask, | |||
msa_col_attn_mask=msa_col_attn_mask, | |||
tri_start_attn_mask=tri_start_attn_mask, | |||
tri_end_attn_mask=tri_end_attn_mask, | |||
chunk_size=chunk_size, | |||
block_size=block_size) | |||
return z |
@@ -0,0 +1,195 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from typing import Dict | |||
import torch | |||
import torch.nn as nn | |||
from unicore.utils import batched_gather, one_hot | |||
from modelscope.models.science.unifold.data import residue_constants as rc | |||
from .frame import Frame | |||
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): | |||
is_gly = aatype == rc.restype_order['G'] | |||
ca_idx = rc.atom_order['CA'] | |||
cb_idx = rc.atom_order['CB'] | |||
pseudo_beta = torch.where( | |||
is_gly[..., None].expand(*((-1, ) * len(is_gly.shape)), 3), | |||
all_atom_positions[..., ca_idx, :], | |||
all_atom_positions[..., cb_idx, :], | |||
) | |||
if all_atom_masks is not None: | |||
pseudo_beta_mask = torch.where( | |||
is_gly, | |||
all_atom_masks[..., ca_idx], | |||
all_atom_masks[..., cb_idx], | |||
) | |||
return pseudo_beta, pseudo_beta_mask | |||
else: | |||
return pseudo_beta | |||
def atom14_to_atom37(atom14, batch): | |||
atom37_data = batched_gather( | |||
atom14, | |||
batch['residx_atom37_to_atom14'], | |||
dim=-2, | |||
num_batch_dims=len(atom14.shape[:-2]), | |||
) | |||
atom37_data = atom37_data * batch['atom37_atom_exists'][..., None] | |||
return atom37_data | |||
def build_template_angle_feat(template_feats, v2_feature=False): | |||
template_aatype = template_feats['template_aatype'] | |||
torsion_angles_sin_cos = template_feats['template_torsion_angles_sin_cos'] | |||
torsion_angles_mask = template_feats['template_torsion_angles_mask'] | |||
if not v2_feature: | |||
alt_torsion_angles_sin_cos = template_feats[ | |||
'template_alt_torsion_angles_sin_cos'] | |||
template_angle_feat = torch.cat( | |||
[ | |||
one_hot(template_aatype, 22), | |||
torsion_angles_sin_cos.reshape( | |||
*torsion_angles_sin_cos.shape[:-2], 14), | |||
alt_torsion_angles_sin_cos.reshape( | |||
*alt_torsion_angles_sin_cos.shape[:-2], 14), | |||
torsion_angles_mask, | |||
], | |||
dim=-1, | |||
) | |||
template_angle_mask = torsion_angles_mask[..., 2] | |||
else: | |||
chi_mask = torsion_angles_mask[..., 3:] | |||
chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask | |||
chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask | |||
template_angle_feat = torch.cat( | |||
[ | |||
one_hot(template_aatype, 22), | |||
chi_angles_sin, | |||
chi_angles_cos, | |||
chi_mask, | |||
], | |||
dim=-1, | |||
) | |||
template_angle_mask = chi_mask[..., 0] | |||
return template_angle_feat, template_angle_mask | |||
def build_template_pair_feat( | |||
batch, | |||
min_bin, | |||
max_bin, | |||
num_bins, | |||
eps=1e-20, | |||
inf=1e8, | |||
): | |||
template_mask = batch['template_pseudo_beta_mask'] | |||
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] | |||
tpb = batch['template_pseudo_beta'] | |||
dgram = torch.sum( | |||
(tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) | |||
lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 | |||
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) | |||
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) | |||
to_concat = [dgram, template_mask_2d[..., None]] | |||
aatype_one_hot = nn.functional.one_hot( | |||
batch['template_aatype'], | |||
rc.restype_num + 2, | |||
) | |||
n_res = batch['template_aatype'].shape[-1] | |||
to_concat.append(aatype_one_hot[..., None, :, :].expand( | |||
*aatype_one_hot.shape[:-2], n_res, -1, -1)) | |||
to_concat.append(aatype_one_hot[..., | |||
None, :].expand(*aatype_one_hot.shape[:-2], | |||
-1, n_res, -1)) | |||
to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3)) | |||
to_concat.append(template_mask_2d[..., None]) | |||
act = torch.cat(to_concat, dim=-1) | |||
act = act * template_mask_2d[..., None] | |||
return act | |||
def build_template_pair_feat_v2( | |||
batch, | |||
min_bin, | |||
max_bin, | |||
num_bins, | |||
multichain_mask_2d=None, | |||
eps=1e-20, | |||
inf=1e8, | |||
): | |||
template_mask = batch['template_pseudo_beta_mask'] | |||
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] | |||
if multichain_mask_2d is not None: | |||
template_mask_2d *= multichain_mask_2d | |||
tpb = batch['template_pseudo_beta'] | |||
dgram = torch.sum( | |||
(tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) | |||
lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 | |||
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) | |||
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) | |||
dgram *= template_mask_2d[..., None] | |||
to_concat = [dgram, template_mask_2d[..., None]] | |||
aatype_one_hot = one_hot( | |||
batch['template_aatype'], | |||
rc.restype_num + 2, | |||
) | |||
n_res = batch['template_aatype'].shape[-1] | |||
to_concat.append(aatype_one_hot[..., None, :, :].expand( | |||
*aatype_one_hot.shape[:-2], n_res, -1, -1)) | |||
to_concat.append(aatype_one_hot[..., | |||
None, :].expand(*aatype_one_hot.shape[:-2], | |||
-1, n_res, -1)) | |||
n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']] | |||
rigids = Frame.make_transform_from_reference( | |||
n_xyz=batch['template_all_atom_positions'][..., n, :], | |||
ca_xyz=batch['template_all_atom_positions'][..., ca, :], | |||
c_xyz=batch['template_all_atom_positions'][..., c, :], | |||
eps=eps, | |||
) | |||
points = rigids.get_trans()[..., None, :, :] | |||
rigid_vec = rigids[..., None].invert_apply(points) | |||
inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) | |||
t_aa_masks = batch['template_all_atom_mask'] | |||
backbone_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., | |||
c] | |||
backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[..., | |||
None, :] | |||
if multichain_mask_2d is not None: | |||
backbone_mask_2d *= multichain_mask_2d | |||
inv_distance_scalar = inv_distance_scalar * backbone_mask_2d | |||
unit_vector_data = rigid_vec * inv_distance_scalar[..., None] | |||
to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1)) | |||
to_concat.append(backbone_mask_2d[..., None]) | |||
return to_concat | |||
def build_extra_msa_feat(batch): | |||
msa_1hot = one_hot(batch['extra_msa'], 23) | |||
msa_feat = [ | |||
msa_1hot, | |||
batch['extra_msa_has_deletion'].unsqueeze(-1), | |||
batch['extra_msa_deletion_value'].unsqueeze(-1), | |||
] | |||
return torch.cat(msa_feat, dim=-1) |
@@ -0,0 +1,562 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from __future__ import annotations # noqa | |||
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple | |||
import numpy as np | |||
import torch | |||
def zero_translation( | |||
batch_dims: Tuple[int], | |||
dtype: Optional[torch.dtype] = torch.float, | |||
device: Optional[torch.device] = torch.device('cpu'), | |||
requires_grad: bool = False, | |||
) -> torch.Tensor: | |||
trans = torch.zeros((*batch_dims, 3), | |||
dtype=dtype, | |||
device=device, | |||
requires_grad=requires_grad) | |||
return trans | |||
# pylint: disable=bad-whitespace | |||
_QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) | |||
_QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr | |||
_QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii | |||
_QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj | |||
_QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk | |||
_QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij | |||
_QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik | |||
_QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk | |||
_QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir | |||
_QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr | |||
_QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr | |||
_QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9) | |||
_QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT) | |||
_QUAT_MULTIPLY = np.zeros((4, 4, 4)) | |||
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], | |||
[0, 0, 0, -1]] | |||
_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], | |||
[0, 0, -1, 0]] | |||
_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], | |||
[0, 1, 0, 0]] | |||
_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], | |||
[1, 0, 0, 0]] | |||
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] | |||
_QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC) | |||
class Rotation: | |||
def __init__( | |||
self, | |||
mat: torch.Tensor, | |||
): | |||
if mat.shape[-2:] != (3, 3): | |||
raise ValueError(f'incorrect rotation shape: {mat.shape}') | |||
self._mat = mat | |||
@staticmethod | |||
def identity( | |||
shape, | |||
dtype: Optional[torch.dtype] = torch.float, | |||
device: Optional[torch.device] = torch.device('cpu'), | |||
requires_grad: bool = False, | |||
) -> Rotation: | |||
mat = torch.eye( | |||
3, dtype=dtype, device=device, requires_grad=requires_grad) | |||
mat = mat.view(*((1, ) * len(shape)), 3, 3) | |||
mat = mat.expand(*shape, -1, -1) | |||
return Rotation(mat) | |||
@staticmethod | |||
def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: | |||
return (a.float() @ b.float()).type(a.dtype) | |||
@staticmethod | |||
def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: | |||
return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype) | |||
def __getitem__(self, index: Any) -> Rotation: | |||
if not isinstance(index, tuple): | |||
index = (index, ) | |||
return Rotation(mat=self._mat[index + (slice(None), slice(None))]) | |||
def __mul__(self, right: Any) -> Rotation: | |||
if isinstance(right, (int, float)): | |||
return Rotation(mat=self._mat * right) | |||
elif isinstance(right, torch.Tensor): | |||
return Rotation(mat=self._mat * right[..., None, None]) | |||
else: | |||
raise TypeError( | |||
f'multiplicand must be a tensor or a number, got {type(right)}.' | |||
) | |||
def __rmul__(self, left: Any) -> Rotation: | |||
return self.__mul__(left) | |||
def __matmul__(self, other: Rotation) -> Rotation: | |||
new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat) | |||
return Rotation(mat=new_mat) | |||
@property | |||
def _inv_mat(self): | |||
return self._mat.transpose(-1, -2) | |||
@property | |||
def rot_mat(self) -> torch.Tensor: | |||
return self._mat | |||
def invert(self) -> Rotation: | |||
return Rotation(mat=self._inv_mat) | |||
def apply(self, pts: torch.Tensor) -> torch.Tensor: | |||
return Rotation.mat_mul_vec(self._mat, pts) | |||
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: | |||
return Rotation.mat_mul_vec(self._inv_mat, pts) | |||
# inherit tensor behaviors | |||
@property | |||
def shape(self) -> torch.Size: | |||
s = self._mat.shape[:-2] | |||
return s | |||
@property | |||
def dtype(self) -> torch.dtype: | |||
return self._mat.dtype | |||
@property | |||
def device(self) -> torch.device: | |||
return self._mat.device | |||
@property | |||
def requires_grad(self) -> bool: | |||
return self._mat.requires_grad | |||
def unsqueeze(self, dim: int) -> Rotation: | |||
if dim >= len(self.shape): | |||
raise ValueError('Invalid dimension') | |||
rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2) | |||
return Rotation(mat=rot_mats) | |||
def map_tensor_fn(self, fn: Callable[[torch.Tensor], | |||
torch.Tensor]) -> Rotation: | |||
mat = self._mat.view(self._mat.shape[:-2] + (9, )) | |||
mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1) | |||
mat = mat.view(mat.shape[:-1] + (3, 3)) | |||
return Rotation(mat=mat) | |||
@staticmethod | |||
def cat(rs: Sequence[Rotation], dim: int) -> Rotation: | |||
rot_mats = [r.rot_mat for r in rs] | |||
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) | |||
return Rotation(mat=rot_mats) | |||
def cuda(self) -> Rotation: | |||
return Rotation(mat=self._mat.cuda()) | |||
def to(self, device: Optional[torch.device], | |||
dtype: Optional[torch.dtype]) -> Rotation: | |||
return Rotation(mat=self._mat.to(device=device, dtype=dtype)) | |||
def type(self, dtype: Optional[torch.dtype]) -> Rotation: | |||
return Rotation(mat=self._mat.type(dtype)) | |||
def detach(self) -> Rotation: | |||
return Rotation(mat=self._mat.detach()) | |||
class Frame: | |||
def __init__( | |||
self, | |||
rotation: Optional[Rotation], | |||
translation: Optional[torch.Tensor], | |||
): | |||
if rotation is None and translation is None: | |||
rotation = Rotation.identity((0, )) | |||
translation = zero_translation((0, )) | |||
elif translation is None: | |||
translation = zero_translation(rotation.shape, rotation.dtype, | |||
rotation.device, | |||
rotation.requires_grad) | |||
elif rotation is None: | |||
rotation = Rotation.identity( | |||
translation.shape[:-1], | |||
translation.dtype, | |||
translation.device, | |||
translation.requires_grad, | |||
) | |||
if (rotation.shape != translation.shape[:-1]) or (rotation.device | |||
!= # noqa W504 | |||
translation.device): | |||
raise ValueError('RotationMatrix and translation incompatible') | |||
self._r = rotation | |||
self._t = translation | |||
@staticmethod | |||
def identity( | |||
shape: Iterable[int], | |||
dtype: Optional[torch.dtype] = torch.float, | |||
device: Optional[torch.device] = torch.device('cpu'), | |||
requires_grad: bool = False, | |||
) -> Frame: | |||
return Frame( | |||
Rotation.identity(shape, dtype, device, requires_grad), | |||
zero_translation(shape, dtype, device, requires_grad), | |||
) | |||
def __getitem__( | |||
self, | |||
index: Any, | |||
) -> Frame: | |||
if type(index) != tuple: | |||
index = (index, ) | |||
return Frame( | |||
self._r[index], | |||
self._t[index + (slice(None), )], | |||
) | |||
def __mul__( | |||
self, | |||
right: torch.Tensor, | |||
) -> Frame: | |||
if not (isinstance(right, torch.Tensor)): | |||
raise TypeError('The other multiplicand must be a Tensor') | |||
new_rots = self._r * right | |||
new_trans = self._t * right[..., None] | |||
return Frame(new_rots, new_trans) | |||
def __rmul__( | |||
self, | |||
left: torch.Tensor, | |||
) -> Frame: | |||
return self.__mul__(left) | |||
@property | |||
def shape(self) -> torch.Size: | |||
s = self._t.shape[:-1] | |||
return s | |||
@property | |||
def device(self) -> torch.device: | |||
return self._t.device | |||
def get_rots(self) -> Rotation: | |||
return self._r | |||
def get_trans(self) -> torch.Tensor: | |||
return self._t | |||
def compose( | |||
self, | |||
other: Frame, | |||
) -> Frame: | |||
new_rot = self._r @ other._r | |||
new_trans = self._r.apply(other._t) + self._t | |||
return Frame(new_rot, new_trans) | |||
def apply( | |||
self, | |||
pts: torch.Tensor, | |||
) -> torch.Tensor: | |||
rotated = self._r.apply(pts) | |||
return rotated + self._t | |||
def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: | |||
pts = pts - self._t | |||
return self._r.invert_apply(pts) | |||
def invert(self) -> Frame: | |||
rot_inv = self._r.invert() | |||
trn_inv = rot_inv.apply(self._t) | |||
return Frame(rot_inv, -1 * trn_inv) | |||
def map_tensor_fn(self, fn: Callable[[torch.Tensor], | |||
torch.Tensor]) -> Frame: | |||
new_rots = self._r.map_tensor_fn(fn) | |||
new_trans = torch.stack( | |||
list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1) | |||
return Frame(new_rots, new_trans) | |||
def to_tensor_4x4(self) -> torch.Tensor: | |||
tensor = self._t.new_zeros((*self.shape, 4, 4)) | |||
tensor[..., :3, :3] = self._r.rot_mat | |||
tensor[..., :3, 3] = self._t | |||
tensor[..., 3, 3] = 1 | |||
return tensor | |||
@staticmethod | |||
def from_tensor_4x4(t: torch.Tensor) -> Frame: | |||
if t.shape[-2:] != (4, 4): | |||
raise ValueError('Incorrectly shaped input tensor') | |||
rots = Rotation(mat=t[..., :3, :3]) | |||
trans = t[..., :3, 3] | |||
return Frame(rots, trans) | |||
@staticmethod | |||
def from_3_points( | |||
p_neg_x_axis: torch.Tensor, | |||
origin: torch.Tensor, | |||
p_xy_plane: torch.Tensor, | |||
eps: float = 1e-8, | |||
) -> Frame: | |||
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) | |||
origin = torch.unbind(origin, dim=-1) | |||
p_xy_plane = torch.unbind(p_xy_plane, dim=-1) | |||
e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] | |||
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] | |||
denom = torch.sqrt(sum((c * c for c in e0)) + eps) | |||
e0 = [c / denom for c in e0] | |||
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) | |||
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] | |||
denom = torch.sqrt(sum((c * c for c in e1)) + eps) | |||
e1 = [c / denom for c in e1] | |||
e2 = [ | |||
e0[1] * e1[2] - e0[2] * e1[1], | |||
e0[2] * e1[0] - e0[0] * e1[2], | |||
e0[0] * e1[1] - e0[1] * e1[0], | |||
] | |||
rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) | |||
rots = rots.reshape(rots.shape[:-1] + (3, 3)) | |||
rot_obj = Rotation(mat=rots) | |||
return Frame(rot_obj, torch.stack(origin, dim=-1)) | |||
def unsqueeze( | |||
self, | |||
dim: int, | |||
) -> Frame: | |||
if dim >= len(self.shape): | |||
raise ValueError('Invalid dimension') | |||
rots = self._r.unsqueeze(dim) | |||
trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1) | |||
return Frame(rots, trans) | |||
@staticmethod | |||
def cat( | |||
Ts: Sequence[Frame], | |||
dim: int, | |||
) -> Frame: | |||
rots = Rotation.cat([T._r for T in Ts], dim) | |||
trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1) | |||
return Frame(rots, trans) | |||
def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame: | |||
return Frame(fn(self._r), self._t) | |||
def apply_trans_fn(self, fn: Callable[[torch.Tensor], | |||
torch.Tensor]) -> Frame: | |||
return Frame(self._r, fn(self._t)) | |||
def scale_translation(self, trans_scale_factor: float) -> Frame: | |||
# fn = lambda t: t * trans_scale_factor | |||
def fn(t): | |||
return t * trans_scale_factor | |||
return self.apply_trans_fn(fn) | |||
def stop_rot_gradient(self) -> Frame: | |||
# fn = lambda r: r.detach() | |||
def fn(r): | |||
return r.detach() | |||
return self.apply_rot_fn(fn) | |||
@staticmethod | |||
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): | |||
input_dtype = ca_xyz.dtype | |||
n_xyz = n_xyz.float() | |||
ca_xyz = ca_xyz.float() | |||
c_xyz = c_xyz.float() | |||
n_xyz = n_xyz - ca_xyz | |||
c_xyz = c_xyz - ca_xyz | |||
c_x, c_y, d_pair = [c_xyz[..., i] for i in range(3)] | |||
norm = torch.sqrt(eps + c_x**2 + c_y**2) | |||
sin_c1 = -c_y / norm | |||
cos_c1 = c_x / norm | |||
c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) | |||
c1_rots[..., 0, 0] = cos_c1 | |||
c1_rots[..., 0, 1] = -1 * sin_c1 | |||
c1_rots[..., 1, 0] = sin_c1 | |||
c1_rots[..., 1, 1] = cos_c1 | |||
c1_rots[..., 2, 2] = 1 | |||
norm = torch.sqrt(eps + c_x**2 + c_y**2 + d_pair**2) | |||
sin_c2 = d_pair / norm | |||
cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm | |||
c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) | |||
c2_rots[..., 0, 0] = cos_c2 | |||
c2_rots[..., 0, 2] = sin_c2 | |||
c2_rots[..., 1, 1] = 1 | |||
c2_rots[..., 2, 0] = -1 * sin_c2 | |||
c2_rots[..., 2, 2] = cos_c2 | |||
c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots) | |||
n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz) | |||
_, n_y, n_z = [n_xyz[..., i] for i in range(3)] | |||
norm = torch.sqrt(eps + n_y**2 + n_z**2) | |||
sin_n = -n_z / norm | |||
cos_n = n_y / norm | |||
n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) | |||
n_rots[..., 0, 0] = 1 | |||
n_rots[..., 1, 1] = cos_n | |||
n_rots[..., 1, 2] = -1 * sin_n | |||
n_rots[..., 2, 1] = sin_n | |||
n_rots[..., 2, 2] = cos_n | |||
rots = Rotation.mat_mul_mat(n_rots, c_rots) | |||
rots = rots.transpose(-1, -2) | |||
rot_obj = Rotation(mat=rots.type(input_dtype)) | |||
return Frame(rot_obj, ca_xyz.type(input_dtype)) | |||
def cuda(self) -> Frame: | |||
return Frame(self._r.cuda(), self._t.cuda()) | |||
@property | |||
def dtype(self) -> torch.dtype: | |||
assert self._r.dtype == self._t.dtype | |||
return self._r.dtype | |||
def type(self, dtype) -> Frame: | |||
return Frame(self._r.type(dtype), self._t.type(dtype)) | |||
class Quaternion: | |||
def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor): | |||
if quaternion.shape[-1] != 4: | |||
raise ValueError(f'incorrect quaternion shape: {quaternion.shape}') | |||
self._q = quaternion | |||
self._t = translation | |||
@staticmethod | |||
def identity( | |||
shape: Iterable[int], | |||
dtype: Optional[torch.dtype] = torch.float, | |||
device: Optional[torch.device] = torch.device('cpu'), | |||
requires_grad: bool = False, | |||
) -> Quaternion: | |||
trans = zero_translation(shape, dtype, device, requires_grad) | |||
quats = torch.zeros((*shape, 4), | |||
dtype=dtype, | |||
device=device, | |||
requires_grad=requires_grad) | |||
with torch.no_grad(): | |||
quats[..., 0] = 1 | |||
return Quaternion(quats, trans) | |||
def get_quats(self): | |||
return self._q | |||
def get_trans(self): | |||
return self._t | |||
def get_rot_mats(self): | |||
quats = self.get_quats() | |||
rot_mats = Quaternion.quat_to_rot(quats) | |||
return rot_mats | |||
@staticmethod | |||
def quat_to_rot(normalized_quat): | |||
global _QUAT_TO_ROT_tensor | |||
dtype = normalized_quat.dtype | |||
normalized_quat = normalized_quat.float() | |||
if _QUAT_TO_ROT_tensor.device != normalized_quat.device: | |||
_QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to( | |||
normalized_quat.device) | |||
rot_tensor = torch.sum( | |||
_QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None] | |||
* normalized_quat[..., None, :, None], | |||
dim=(-3, -2), | |||
) | |||
rot_tensor = rot_tensor.type(dtype) | |||
rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3) | |||
return rot_tensor | |||
@staticmethod | |||
def normalize_quat(quats): | |||
dtype = quats.dtype | |||
quats = quats.float() | |||
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) | |||
quats = quats.type(dtype) | |||
return quats | |||
@staticmethod | |||
def quat_multiply_by_vec(quat, vec): | |||
dtype = quat.dtype | |||
quat = quat.float() | |||
vec = vec.float() | |||
global _QUAT_MULTIPLY_BY_VEC_tensor | |||
if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device: | |||
_QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to( | |||
quat.device) | |||
mat = _QUAT_MULTIPLY_BY_VEC_tensor | |||
reshaped_mat = mat.view((1, ) * len(quat.shape[:-1]) + mat.shape) | |||
return torch.sum( | |||
reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], | |||
dim=(-3, -2), | |||
).type(dtype) | |||
def compose_q_update_vec(self, | |||
q_update_vec: torch.Tensor, | |||
normalize_quats: bool = True) -> torch.Tensor: | |||
quats = self.get_quats() | |||
new_quats = quats + Quaternion.quat_multiply_by_vec( | |||
quats, q_update_vec) | |||
if normalize_quats: | |||
new_quats = Quaternion.normalize_quat(new_quats) | |||
return new_quats | |||
def compose_update_vec( | |||
self, | |||
update_vec: torch.Tensor, | |||
pre_rot_mat: Rotation, | |||
) -> Quaternion: | |||
q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:] | |||
new_quats = self.compose_q_update_vec(q_vec) | |||
trans_update = pre_rot_mat.apply(t_vec) | |||
new_trans = self._t + trans_update | |||
return Quaternion(new_quats, new_trans) | |||
def stop_rot_gradient(self) -> Quaternion: | |||
return Quaternion(self._q.detach(), self._t) |
@@ -0,0 +1,592 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import math | |||
from typing import Tuple | |||
import torch | |||
import torch.nn as nn | |||
from unicore.modules import LayerNorm, softmax_dropout | |||
from unicore.utils import dict_multimap, one_hot, permute_final_dims | |||
from modelscope.models.science.unifold.data.residue_constants import ( | |||
restype_atom14_mask, restype_atom14_rigid_group_positions, | |||
restype_atom14_to_rigid_group, restype_rigid_group_default_frame) | |||
from .attentions import gen_attn_mask | |||
from .common import Linear, SimpleModuleList, residual | |||
from .frame import Frame, Quaternion, Rotation | |||
def ipa_point_weights_init_(weights): | |||
with torch.no_grad(): | |||
softplus_inverse_1 = 0.541324854612918 | |||
weights.fill_(softplus_inverse_1) | |||
def torsion_angles_to_frames( | |||
frame: Frame, | |||
alpha: torch.Tensor, | |||
aatype: torch.Tensor, | |||
default_frames: torch.Tensor, | |||
): | |||
default_frame = Frame.from_tensor_4x4(default_frames[aatype, ...]) | |||
bb_rot = alpha.new_zeros((*((1, ) * len(alpha.shape[:-1])), 2)) | |||
bb_rot[..., 1] = 1 | |||
alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], | |||
dim=-2) | |||
all_rots = alpha.new_zeros(default_frame.get_rots().rot_mat.shape) | |||
all_rots[..., 0, 0] = 1 | |||
all_rots[..., 1, 1] = alpha[..., 1] | |||
all_rots[..., 1, 2] = -alpha[..., 0] | |||
all_rots[..., 2, 1:] = alpha | |||
all_rots = Frame(Rotation(mat=all_rots), None) | |||
all_frames = default_frame.compose(all_rots) | |||
chi2_frame_to_frame = all_frames[..., 5] | |||
chi3_frame_to_frame = all_frames[..., 6] | |||
chi4_frame_to_frame = all_frames[..., 7] | |||
chi1_frame_to_bb = all_frames[..., 4] | |||
chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) | |||
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) | |||
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) | |||
all_frames_to_bb = Frame.cat( | |||
[ | |||
all_frames[..., :5], | |||
chi2_frame_to_bb.unsqueeze(-1), | |||
chi3_frame_to_bb.unsqueeze(-1), | |||
chi4_frame_to_bb.unsqueeze(-1), | |||
], | |||
dim=-1, | |||
) | |||
all_frames_to_global = frame[..., None].compose(all_frames_to_bb) | |||
return all_frames_to_global | |||
def frames_and_literature_positions_to_atom14_pos( | |||
frame: Frame, | |||
aatype: torch.Tensor, | |||
default_frames, | |||
group_idx, | |||
atom_mask, | |||
lit_positions, | |||
): | |||
group_mask = group_idx[aatype, ...] | |||
group_mask = one_hot( | |||
group_mask, | |||
num_classes=default_frames.shape[-3], | |||
) | |||
t_atoms_to_global = frame[..., None, :] * group_mask | |||
t_atoms_to_global = t_atoms_to_global.map_tensor_fn( | |||
lambda x: torch.sum(x, dim=-1)) | |||
atom_mask = atom_mask[aatype, ...].unsqueeze(-1) | |||
lit_positions = lit_positions[aatype, ...] | |||
pred_positions = t_atoms_to_global.apply(lit_positions) | |||
pred_positions = pred_positions * atom_mask | |||
return pred_positions | |||
class SideChainAngleResnetIteration(nn.Module): | |||
def __init__(self, d_hid): | |||
super(SideChainAngleResnetIteration, self).__init__() | |||
self.d_hid = d_hid | |||
self.linear_1 = Linear(self.d_hid, self.d_hid, init='relu') | |||
self.act = nn.GELU() | |||
self.linear_2 = Linear(self.d_hid, self.d_hid, init='final') | |||
def forward(self, s: torch.Tensor) -> torch.Tensor: | |||
x = self.act(s) | |||
x = self.linear_1(x) | |||
x = self.act(x) | |||
x = self.linear_2(x) | |||
return residual(s, x, self.training) | |||
class SidechainAngleResnet(nn.Module): | |||
def __init__(self, d_in, d_hid, num_blocks, num_angles): | |||
super(SidechainAngleResnet, self).__init__() | |||
self.linear_in = Linear(d_in, d_hid) | |||
self.act = nn.GELU() | |||
self.linear_initial = Linear(d_in, d_hid) | |||
self.layers = SimpleModuleList() | |||
for _ in range(num_blocks): | |||
self.layers.append(SideChainAngleResnetIteration(d_hid=d_hid)) | |||
self.linear_out = Linear(d_hid, num_angles * 2) | |||
def forward(self, s: torch.Tensor, | |||
initial_s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |||
initial_s = self.linear_initial(self.act(initial_s)) | |||
s = self.linear_in(self.act(s)) | |||
s = s + initial_s | |||
for layer in self.layers: | |||
s = layer(s) | |||
s = self.linear_out(self.act(s)) | |||
s = s.view(s.shape[:-1] + (-1, 2)) | |||
unnormalized_s = s | |||
norm_denom = torch.sqrt( | |||
torch.clamp( | |||
torch.sum(s.float()**2, dim=-1, keepdim=True), | |||
min=1e-12, | |||
)) | |||
s = s.float() / norm_denom | |||
return unnormalized_s, s.type(unnormalized_s.dtype) | |||
class InvariantPointAttention(nn.Module): | |||
def __init__( | |||
self, | |||
d_single: int, | |||
d_pair: int, | |||
d_hid: int, | |||
num_heads: int, | |||
num_qk_points: int, | |||
num_v_points: int, | |||
separate_kv: bool = False, | |||
bias: bool = True, | |||
eps: float = 1e-8, | |||
): | |||
super(InvariantPointAttention, self).__init__() | |||
self.d_hid = d_hid | |||
self.num_heads = num_heads | |||
self.num_qk_points = num_qk_points | |||
self.num_v_points = num_v_points | |||
self.eps = eps | |||
hc = self.d_hid * self.num_heads | |||
self.linear_q = Linear(d_single, hc, bias=bias) | |||
self.separate_kv = separate_kv | |||
if self.separate_kv: | |||
self.linear_k = Linear(d_single, hc, bias=bias) | |||
self.linear_v = Linear(d_single, hc, bias=bias) | |||
else: | |||
self.linear_kv = Linear(d_single, 2 * hc, bias=bias) | |||
hpq = self.num_heads * self.num_qk_points * 3 | |||
self.linear_q_points = Linear(d_single, hpq) | |||
hpk = self.num_heads * self.num_qk_points * 3 | |||
hpv = self.num_heads * self.num_v_points * 3 | |||
if self.separate_kv: | |||
self.linear_k_points = Linear(d_single, hpk) | |||
self.linear_v_points = Linear(d_single, hpv) | |||
else: | |||
hpkv = hpk + hpv | |||
self.linear_kv_points = Linear(d_single, hpkv) | |||
self.linear_b = Linear(d_pair, self.num_heads) | |||
self.head_weights = nn.Parameter(torch.zeros((num_heads))) | |||
ipa_point_weights_init_(self.head_weights) | |||
concat_out_dim = self.num_heads * ( | |||
d_pair + self.d_hid + self.num_v_points * 4) | |||
self.linear_out = Linear(concat_out_dim, d_single, init='final') | |||
self.softplus = nn.Softplus() | |||
def forward( | |||
self, | |||
s: torch.Tensor, | |||
z: torch.Tensor, | |||
f: Frame, | |||
square_mask: torch.Tensor, | |||
) -> torch.Tensor: | |||
q = self.linear_q(s) | |||
q = q.view(q.shape[:-1] + (self.num_heads, -1)) | |||
if self.separate_kv: | |||
k = self.linear_k(s) | |||
v = self.linear_v(s) | |||
k = k.view(k.shape[:-1] + (self.num_heads, -1)) | |||
v = v.view(v.shape[:-1] + (self.num_heads, -1)) | |||
else: | |||
kv = self.linear_kv(s) | |||
kv = kv.view(kv.shape[:-1] + (self.num_heads, -1)) | |||
k, v = torch.split(kv, self.d_hid, dim=-1) | |||
q_pts = self.linear_q_points(s) | |||
def process_points(pts, no_points): | |||
shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3) | |||
if self.separate_kv: | |||
# alphafold-multimer uses different layout | |||
pts = pts.view(pts.shape[:-1] | |||
+ (self.num_heads, no_points * 3)) | |||
pts = torch.split(pts, pts.shape[-1] // 3, dim=-1) | |||
pts = torch.stack(pts, dim=-1).view(*shape) | |||
pts = f[..., None].apply(pts) | |||
pts = pts.view(pts.shape[:-2] + (self.num_heads, no_points, 3)) | |||
return pts | |||
q_pts = process_points(q_pts, self.num_qk_points) | |||
if self.separate_kv: | |||
k_pts = self.linear_k_points(s) | |||
v_pts = self.linear_v_points(s) | |||
k_pts = process_points(k_pts, self.num_qk_points) | |||
v_pts = process_points(v_pts, self.num_v_points) | |||
else: | |||
kv_pts = self.linear_kv_points(s) | |||
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) | |||
kv_pts = torch.stack(kv_pts, dim=-1) | |||
kv_pts = f[..., None].apply(kv_pts) | |||
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) | |||
k_pts, v_pts = torch.split( | |||
kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) | |||
bias = self.linear_b(z) | |||
attn = torch.matmul( | |||
permute_final_dims(q, (1, 0, 2)), | |||
permute_final_dims(k, (1, 2, 0)), | |||
) | |||
if self.training: | |||
attn = attn * math.sqrt(1.0 / (3 * self.d_hid)) | |||
attn = attn + ( | |||
math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) | |||
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) | |||
pt_att = pt_att.float()**2 | |||
else: | |||
attn *= math.sqrt(1.0 / (3 * self.d_hid)) | |||
attn += (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) | |||
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) | |||
pt_att *= pt_att | |||
pt_att = pt_att.sum(dim=-1) | |||
head_weights = self.softplus(self.head_weights).view( | |||
*((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) | |||
head_weights = head_weights * math.sqrt( | |||
1.0 / (3 * (self.num_qk_points * 9.0 / 2))) | |||
pt_att *= head_weights * (-0.5) | |||
pt_att = torch.sum(pt_att, dim=-1) | |||
pt_att = permute_final_dims(pt_att, (2, 0, 1)) | |||
attn += square_mask | |||
attn = softmax_dropout( | |||
attn, 0, self.training, bias=pt_att.type(attn.dtype)) | |||
del pt_att, q_pts, k_pts, bias | |||
o = torch.matmul(attn, v.transpose(-2, -3)).transpose(-2, -3) | |||
o = o.contiguous().view(*o.shape[:-2], -1) | |||
del q, k, v | |||
o_pts = torch.sum( | |||
(attn[..., None, :, :, None] | |||
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), | |||
dim=-2, | |||
) | |||
o_pts = permute_final_dims(o_pts, (2, 0, 3, 1)) | |||
o_pts = f[..., None, None].invert_apply(o_pts) | |||
if self.training: | |||
o_pts_norm = torch.sqrt( | |||
torch.sum(o_pts.float()**2, dim=-1) + self.eps).type( | |||
o_pts.dtype) | |||
else: | |||
o_pts_norm = torch.sqrt(torch.sum(o_pts**2, dim=-1) | |||
+ self.eps).type(o_pts.dtype) | |||
o_pts_norm = o_pts_norm.view(*o_pts_norm.shape[:-2], -1) | |||
o_pts = o_pts.view(*o_pts.shape[:-3], -1, 3) | |||
o_pair = torch.matmul(attn.transpose(-2, -3), z) | |||
o_pair = o_pair.view(*o_pair.shape[:-2], -1) | |||
s = self.linear_out( | |||
torch.cat((o, *torch.unbind(o_pts, dim=-1), o_pts_norm, o_pair), | |||
dim=-1)) | |||
return s | |||
class BackboneUpdate(nn.Module): | |||
def __init__(self, d_single): | |||
super(BackboneUpdate, self).__init__() | |||
self.linear = Linear(d_single, 6, init='final') | |||
def forward(self, s: torch.Tensor): | |||
return self.linear(s) | |||
class StructureModuleTransitionLayer(nn.Module): | |||
def __init__(self, c): | |||
super(StructureModuleTransitionLayer, self).__init__() | |||
self.linear_1 = Linear(c, c, init='relu') | |||
self.linear_2 = Linear(c, c, init='relu') | |||
self.act = nn.GELU() | |||
self.linear_3 = Linear(c, c, init='final') | |||
def forward(self, s): | |||
s_old = s | |||
s = self.linear_1(s) | |||
s = self.act(s) | |||
s = self.linear_2(s) | |||
s = self.act(s) | |||
s = self.linear_3(s) | |||
s = residual(s_old, s, self.training) | |||
return s | |||
class StructureModuleTransition(nn.Module): | |||
def __init__(self, c, num_layers, dropout_rate): | |||
super(StructureModuleTransition, self).__init__() | |||
self.num_layers = num_layers | |||
self.dropout_rate = dropout_rate | |||
self.layers = SimpleModuleList() | |||
for _ in range(self.num_layers): | |||
self.layers.append(StructureModuleTransitionLayer(c)) | |||
self.dropout = nn.Dropout(self.dropout_rate) | |||
self.layer_norm = LayerNorm(c) | |||
def forward(self, s): | |||
for layer in self.layers: | |||
s = layer(s) | |||
s = self.dropout(s) | |||
s = self.layer_norm(s) | |||
return s | |||
class StructureModule(nn.Module): | |||
def __init__( | |||
self, | |||
d_single, | |||
d_pair, | |||
d_ipa, | |||
d_angle, | |||
num_heads_ipa, | |||
num_qk_points, | |||
num_v_points, | |||
dropout_rate, | |||
num_blocks, | |||
no_transition_layers, | |||
num_resnet_blocks, | |||
num_angles, | |||
trans_scale_factor, | |||
separate_kv, | |||
ipa_bias, | |||
epsilon, | |||
inf, | |||
**kwargs, | |||
): | |||
super(StructureModule, self).__init__() | |||
self.num_blocks = num_blocks | |||
self.trans_scale_factor = trans_scale_factor | |||
self.default_frames = None | |||
self.group_idx = None | |||
self.atom_mask = None | |||
self.lit_positions = None | |||
self.inf = inf | |||
self.layer_norm_s = LayerNorm(d_single) | |||
self.layer_norm_z = LayerNorm(d_pair) | |||
self.linear_in = Linear(d_single, d_single) | |||
self.ipa = InvariantPointAttention( | |||
d_single, | |||
d_pair, | |||
d_ipa, | |||
num_heads_ipa, | |||
num_qk_points, | |||
num_v_points, | |||
separate_kv=separate_kv, | |||
bias=ipa_bias, | |||
eps=epsilon, | |||
) | |||
self.ipa_dropout = nn.Dropout(dropout_rate) | |||
self.layer_norm_ipa = LayerNorm(d_single) | |||
self.transition = StructureModuleTransition( | |||
d_single, | |||
no_transition_layers, | |||
dropout_rate, | |||
) | |||
self.bb_update = BackboneUpdate(d_single) | |||
self.angle_resnet = SidechainAngleResnet( | |||
d_single, | |||
d_angle, | |||
num_resnet_blocks, | |||
num_angles, | |||
) | |||
def forward( | |||
self, | |||
s, | |||
z, | |||
aatype, | |||
mask=None, | |||
): | |||
if mask is None: | |||
mask = s.new_ones(s.shape[:-1]) | |||
# generate square mask | |||
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) | |||
square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3) | |||
s = self.layer_norm_s(s) | |||
z = self.layer_norm_z(z) | |||
initial_s = s | |||
s = self.linear_in(s) | |||
quat_encoder = Quaternion.identity( | |||
s.shape[:-1], | |||
s.dtype, | |||
s.device, | |||
requires_grad=False, | |||
) | |||
backb_to_global = Frame( | |||
Rotation(mat=quat_encoder.get_rot_mats(), ), | |||
quat_encoder.get_trans(), | |||
) | |||
outputs = [] | |||
for i in range(self.num_blocks): | |||
s = residual(s, self.ipa(s, z, backb_to_global, square_mask), | |||
self.training) | |||
s = self.ipa_dropout(s) | |||
s = self.layer_norm_ipa(s) | |||
s = self.transition(s) | |||
# update quaternion encoder | |||
# use backb_to_global to avoid quat-to-rot conversion | |||
quat_encoder = quat_encoder.compose_update_vec( | |||
self.bb_update(s), pre_rot_mat=backb_to_global.get_rots()) | |||
# initial_s is always used to update the backbone | |||
unnormalized_angles, angles = self.angle_resnet(s, initial_s) | |||
# convert quaternion to rotation matrix | |||
backb_to_global = Frame( | |||
Rotation(mat=quat_encoder.get_rot_mats(), ), | |||
quat_encoder.get_trans(), | |||
) | |||
if i == self.num_blocks - 1: | |||
all_frames_to_global = self.torsion_angles_to_frames( | |||
backb_to_global.scale_translation(self.trans_scale_factor), | |||
angles, | |||
aatype, | |||
) | |||
pred_positions = self.frames_and_literature_positions_to_atom14_pos( | |||
all_frames_to_global, | |||
aatype, | |||
) | |||
preds = { | |||
'frames': | |||
backb_to_global.scale_translation( | |||
self.trans_scale_factor).to_tensor_4x4(), | |||
'unnormalized_angles': | |||
unnormalized_angles, | |||
'angles': | |||
angles, | |||
} | |||
outputs.append(preds) | |||
if i < (self.num_blocks - 1): | |||
# stop gradient in iteration | |||
quat_encoder = quat_encoder.stop_rot_gradient() | |||
backb_to_global = backb_to_global.stop_rot_gradient() | |||
outputs = dict_multimap(torch.stack, outputs) | |||
outputs['sidechain_frames'] = all_frames_to_global.to_tensor_4x4() | |||
outputs['positions'] = pred_positions | |||
outputs['single'] = s | |||
return outputs | |||
def _init_residue_constants(self, float_dtype, device): | |||
if self.default_frames is None: | |||
self.default_frames = torch.tensor( | |||
restype_rigid_group_default_frame, | |||
dtype=float_dtype, | |||
device=device, | |||
requires_grad=False, | |||
) | |||
if self.group_idx is None: | |||
self.group_idx = torch.tensor( | |||
restype_atom14_to_rigid_group, | |||
device=device, | |||
requires_grad=False, | |||
) | |||
if self.atom_mask is None: | |||
self.atom_mask = torch.tensor( | |||
restype_atom14_mask, | |||
dtype=float_dtype, | |||
device=device, | |||
requires_grad=False, | |||
) | |||
if self.lit_positions is None: | |||
self.lit_positions = torch.tensor( | |||
restype_atom14_rigid_group_positions, | |||
dtype=float_dtype, | |||
device=device, | |||
requires_grad=False, | |||
) | |||
def torsion_angles_to_frames(self, frame, alpha, aatype): | |||
self._init_residue_constants(alpha.dtype, alpha.device) | |||
return torsion_angles_to_frames(frame, alpha, aatype, | |||
self.default_frames) | |||
def frames_and_literature_positions_to_atom14_pos(self, frame, aatype): | |||
self._init_residue_constants(frame.get_rots().dtype, | |||
frame.get_rots().device) | |||
return frames_and_literature_positions_to_atom14_pos( | |||
frame, | |||
aatype, | |||
self.default_frames, | |||
self.group_idx, | |||
self.atom_mask, | |||
self.lit_positions, | |||
) |
@@ -0,0 +1,330 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import math | |||
from functools import partial | |||
from typing import List, Optional, Tuple | |||
import torch | |||
import torch.nn as nn | |||
from unicore.modules import LayerNorm | |||
from unicore.utils import (checkpoint_sequential, permute_final_dims, | |||
tensor_tree_map) | |||
from .attentions import (Attention, TriangleAttentionEnding, | |||
TriangleAttentionStarting, gen_attn_mask) | |||
from .common import (Linear, SimpleModuleList, Transition, | |||
bias_dropout_residual, chunk_layer, residual, | |||
tri_mul_residual) | |||
from .featurization import build_template_pair_feat_v2 | |||
from .triangle_multiplication import (TriangleMultiplicationIncoming, | |||
TriangleMultiplicationOutgoing) | |||
class TemplatePointwiseAttention(nn.Module): | |||
def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs): | |||
super(TemplatePointwiseAttention, self).__init__() | |||
self.inf = inf | |||
self.mha = Attention( | |||
d_pair, | |||
d_template, | |||
d_template, | |||
d_hid, | |||
num_heads, | |||
gating=False, | |||
) | |||
def _chunk( | |||
self, | |||
z: torch.Tensor, | |||
t: torch.Tensor, | |||
mask: torch.Tensor, | |||
chunk_size: int, | |||
) -> torch.Tensor: | |||
mha_inputs = { | |||
'q': z, | |||
'k': t, | |||
'v': t, | |||
'mask': mask, | |||
} | |||
return chunk_layer( | |||
self.mha, | |||
mha_inputs, | |||
chunk_size=chunk_size, | |||
num_batch_dims=len(z.shape[:-2]), | |||
) | |||
def forward( | |||
self, | |||
t: torch.Tensor, | |||
z: torch.Tensor, | |||
template_mask: Optional[torch.Tensor] = None, | |||
chunk_size: Optional[int] = None, | |||
) -> torch.Tensor: | |||
if template_mask is None: | |||
template_mask = t.new_ones(t.shape[:-3]) | |||
mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None, | |||
None, :] | |||
z = z.unsqueeze(-2) | |||
t = permute_final_dims(t, (1, 2, 0, 3)) | |||
if chunk_size is not None: | |||
z = self._chunk(z, t, mask, chunk_size) | |||
else: | |||
z = self.mha(z, t, t, mask=mask) | |||
z = z.squeeze(-2) | |||
return z | |||
class TemplateProjection(nn.Module): | |||
def __init__(self, d_template, d_pair, **kwargs): | |||
super(TemplateProjection, self).__init__() | |||
self.d_pair = d_pair | |||
self.act = nn.ReLU() | |||
self.output_linear = Linear(d_template, d_pair, init='relu') | |||
def forward(self, t, z) -> torch.Tensor: | |||
if t is None: | |||
# handle for non-template case | |||
shape = z.shape | |||
shape[-1] = self.d_pair | |||
t = torch.zeros(shape, dtype=z.dtype, device=z.device) | |||
t = self.act(t) | |||
z_t = self.output_linear(t) | |||
return z_t | |||
class TemplatePairStackBlock(nn.Module): | |||
def __init__( | |||
self, | |||
d_template: int, | |||
d_hid_tri_att: int, | |||
d_hid_tri_mul: int, | |||
num_heads: int, | |||
pair_transition_n: int, | |||
dropout_rate: float, | |||
tri_attn_first: bool, | |||
inf: float, | |||
**kwargs, | |||
): | |||
super(TemplatePairStackBlock, self).__init__() | |||
self.tri_att_start = TriangleAttentionStarting( | |||
d_template, | |||
d_hid_tri_att, | |||
num_heads, | |||
) | |||
self.tri_att_end = TriangleAttentionEnding( | |||
d_template, | |||
d_hid_tri_att, | |||
num_heads, | |||
) | |||
self.tri_mul_out = TriangleMultiplicationOutgoing( | |||
d_template, | |||
d_hid_tri_mul, | |||
) | |||
self.tri_mul_in = TriangleMultiplicationIncoming( | |||
d_template, | |||
d_hid_tri_mul, | |||
) | |||
self.pair_transition = Transition( | |||
d_template, | |||
pair_transition_n, | |||
) | |||
self.tri_attn_first = tri_attn_first | |||
self.dropout = dropout_rate | |||
self.row_dropout_share_dim = -3 | |||
self.col_dropout_share_dim = -2 | |||
def forward( | |||
self, | |||
s: torch.Tensor, | |||
mask: torch.Tensor, | |||
tri_start_attn_mask: torch.Tensor, | |||
tri_end_attn_mask: torch.Tensor, | |||
chunk_size: Optional[int] = None, | |||
block_size: Optional[int] = None, | |||
): | |||
if self.tri_attn_first: | |||
s = bias_dropout_residual( | |||
self.tri_att_start, | |||
s, | |||
self.tri_att_start( | |||
s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), | |||
self.row_dropout_share_dim, | |||
self.dropout, | |||
self.training, | |||
) | |||
s = bias_dropout_residual( | |||
self.tri_att_end, | |||
s, | |||
self.tri_att_end( | |||
s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), | |||
self.col_dropout_share_dim, | |||
self.dropout, | |||
self.training, | |||
) | |||
s = tri_mul_residual( | |||
self.tri_mul_out, | |||
s, | |||
self.tri_mul_out(s, mask=mask, block_size=block_size), | |||
self.row_dropout_share_dim, | |||
self.dropout, | |||
self.training, | |||
block_size=block_size, | |||
) | |||
s = tri_mul_residual( | |||
self.tri_mul_in, | |||
s, | |||
self.tri_mul_in(s, mask=mask, block_size=block_size), | |||
self.row_dropout_share_dim, | |||
self.dropout, | |||
self.training, | |||
block_size=block_size, | |||
) | |||
else: | |||
s = tri_mul_residual( | |||
self.tri_mul_out, | |||
s, | |||
self.tri_mul_out(s, mask=mask, block_size=block_size), | |||
self.row_dropout_share_dim, | |||
self.dropout, | |||
self.training, | |||
block_size=block_size, | |||
) | |||
s = tri_mul_residual( | |||
self.tri_mul_in, | |||
s, | |||
self.tri_mul_in(s, mask=mask, block_size=block_size), | |||
self.row_dropout_share_dim, | |||
self.dropout, | |||
self.training, | |||
block_size=block_size, | |||
) | |||
s = bias_dropout_residual( | |||
self.tri_att_start, | |||
s, | |||
self.tri_att_start( | |||
s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), | |||
self.row_dropout_share_dim, | |||
self.dropout, | |||
self.training, | |||
) | |||
s = bias_dropout_residual( | |||
self.tri_att_end, | |||
s, | |||
self.tri_att_end( | |||
s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), | |||
self.col_dropout_share_dim, | |||
self.dropout, | |||
self.training, | |||
) | |||
s = residual(s, self.pair_transition( | |||
s, | |||
chunk_size=chunk_size, | |||
), self.training) | |||
return s | |||
class TemplatePairStack(nn.Module): | |||
def __init__( | |||
self, | |||
d_template, | |||
d_hid_tri_att, | |||
d_hid_tri_mul, | |||
num_blocks, | |||
num_heads, | |||
pair_transition_n, | |||
dropout_rate, | |||
tri_attn_first, | |||
inf=1e9, | |||
**kwargs, | |||
): | |||
super(TemplatePairStack, self).__init__() | |||
self.blocks = SimpleModuleList() | |||
for _ in range(num_blocks): | |||
self.blocks.append( | |||
TemplatePairStackBlock( | |||
d_template=d_template, | |||
d_hid_tri_att=d_hid_tri_att, | |||
d_hid_tri_mul=d_hid_tri_mul, | |||
num_heads=num_heads, | |||
pair_transition_n=pair_transition_n, | |||
dropout_rate=dropout_rate, | |||
inf=inf, | |||
tri_attn_first=tri_attn_first, | |||
)) | |||
self.layer_norm = LayerNorm(d_template) | |||
def forward( | |||
self, | |||
single_templates: Tuple[torch.Tensor], | |||
mask: torch.tensor, | |||
tri_start_attn_mask: torch.Tensor, | |||
tri_end_attn_mask: torch.Tensor, | |||
templ_dim: int, | |||
chunk_size: int, | |||
block_size: int, | |||
return_mean: bool, | |||
): | |||
def one_template(i): | |||
(s, ) = checkpoint_sequential( | |||
functions=[ | |||
partial( | |||
b, | |||
mask=mask, | |||
tri_start_attn_mask=tri_start_attn_mask, | |||
tri_end_attn_mask=tri_end_attn_mask, | |||
chunk_size=chunk_size, | |||
block_size=block_size, | |||
) for b in self.blocks | |||
], | |||
input=(single_templates[i], ), | |||
) | |||
return s | |||
n_templ = len(single_templates) | |||
if n_templ > 0: | |||
new_single_templates = [one_template(0)] | |||
if return_mean: | |||
t = self.layer_norm(new_single_templates[0]) | |||
for i in range(1, n_templ): | |||
s = one_template(i) | |||
if return_mean: | |||
t = residual(t, self.layer_norm(s), self.training) | |||
else: | |||
new_single_templates.append(s) | |||
if return_mean: | |||
if n_templ > 0: | |||
t /= n_templ | |||
else: | |||
t = None | |||
else: | |||
t = torch.cat( | |||
[s.unsqueeze(templ_dim) for s in new_single_templates], | |||
dim=templ_dim) | |||
t = self.layer_norm(t) | |||
return t |
@@ -0,0 +1,158 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
from functools import partialmethod | |||
from typing import List, Optional | |||
import torch | |||
import torch.nn as nn | |||
from unicore.modules import LayerNorm | |||
from unicore.utils import permute_final_dims | |||
from .common import Linear | |||
class TriangleMultiplication(nn.Module): | |||
def __init__(self, d_pair, d_hid, outgoing=True): | |||
super(TriangleMultiplication, self).__init__() | |||
self.outgoing = outgoing | |||
self.linear_ab_p = Linear(d_pair, d_hid * 2) | |||
self.linear_ab_g = Linear(d_pair, d_hid * 2, init='gating') | |||
self.linear_g = Linear(d_pair, d_pair, init='gating') | |||
self.linear_z = Linear(d_hid, d_pair, init='final') | |||
self.layer_norm_in = LayerNorm(d_pair) | |||
self.layer_norm_out = LayerNorm(d_hid) | |||
self._alphafold_original_mode = False | |||
def _chunk_2d( | |||
self, | |||
z: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None, | |||
block_size: int = None, | |||
) -> torch.Tensor: | |||
# avoid too small chunk size | |||
# block_size = max(block_size, 256) | |||
new_z = z.new_zeros(z.shape) | |||
dim1 = z.shape[-3] | |||
def _slice_linear(z, linear: Linear, a=True): | |||
d_hid = linear.bias.shape[0] // 2 | |||
index = 0 if a else d_hid | |||
p = ( | |||
nn.functional.linear(z, linear.weight[index:index + d_hid]) | |||
+ linear.bias[index:index + d_hid]) | |||
return p | |||
def _chunk_projection(z, mask, a=True): | |||
p = _slice_linear(z, self.linear_ab_p, a) * mask | |||
p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a)) | |||
return p | |||
num_chunk = (dim1 + block_size - 1) // block_size | |||
for i in range(num_chunk): | |||
chunk_start = i * block_size | |||
chunk_end = min(chunk_start + block_size, dim1) | |||
if self.outgoing: | |||
a_chunk = _chunk_projection( | |||
z[..., chunk_start:chunk_end, :, :], | |||
mask[..., chunk_start:chunk_end, :, :], | |||
a=True, | |||
) | |||
a_chunk = permute_final_dims(a_chunk, (2, 0, 1)) | |||
else: | |||
a_chunk = _chunk_projection( | |||
z[..., :, chunk_start:chunk_end, :], | |||
mask[..., :, chunk_start:chunk_end, :], | |||
a=True, | |||
) | |||
a_chunk = a_chunk.transpose(-1, -3) | |||
for j in range(num_chunk): | |||
j_chunk_start = j * block_size | |||
j_chunk_end = min(j_chunk_start + block_size, dim1) | |||
if self.outgoing: | |||
b_chunk = _chunk_projection( | |||
z[..., j_chunk_start:j_chunk_end, :, :], | |||
mask[..., j_chunk_start:j_chunk_end, :, :], | |||
a=False, | |||
) | |||
b_chunk = b_chunk.transpose(-1, -3) | |||
else: | |||
b_chunk = _chunk_projection( | |||
z[..., :, j_chunk_start:j_chunk_end, :], | |||
mask[..., :, j_chunk_start:j_chunk_end, :], | |||
a=False, | |||
) | |||
b_chunk = permute_final_dims(b_chunk, (2, 0, 1)) | |||
x_chunk = torch.matmul(a_chunk, b_chunk) | |||
del b_chunk | |||
x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) | |||
x_chunk = self.layer_norm_out(x_chunk) | |||
x_chunk = self.linear_z(x_chunk) | |||
x_chunk *= torch.sigmoid( | |||
self.linear_g(z[..., chunk_start:chunk_end, | |||
j_chunk_start:j_chunk_end, :])) | |||
new_z[..., chunk_start:chunk_end, | |||
j_chunk_start:j_chunk_end, :] = x_chunk | |||
del x_chunk | |||
del a_chunk | |||
return new_z | |||
def forward( | |||
self, | |||
z: torch.Tensor, | |||
mask: Optional[torch.Tensor] = None, | |||
block_size=None, | |||
) -> torch.Tensor: | |||
mask = mask.unsqueeze(-1) | |||
if not self._alphafold_original_mode: | |||
# divided by 1/sqrt(dim) for numerical stability | |||
mask = mask * (mask.shape[-2]**-0.5) | |||
z = self.layer_norm_in(z) | |||
if not self.training and block_size is not None: | |||
return self._chunk_2d(z, mask, block_size=block_size) | |||
g = nn.functional.linear(z, self.linear_g.weight) | |||
if self.training: | |||
ab = self.linear_ab_p(z) * mask * torch.sigmoid( | |||
self.linear_ab_g(z)) | |||
else: | |||
ab = self.linear_ab_p(z) | |||
ab *= mask | |||
ab *= torch.sigmoid(self.linear_ab_g(z)) | |||
a, b = torch.chunk(ab, 2, dim=-1) | |||
del z, ab | |||
if self.outgoing: | |||
a = permute_final_dims(a, (2, 0, 1)) | |||
b = b.transpose(-1, -3) | |||
else: | |||
b = permute_final_dims(b, (2, 0, 1)) | |||
a = a.transpose(-1, -3) | |||
x = torch.matmul(a, b) | |||
del a, b | |||
x = permute_final_dims(x, (1, 2, 0)) | |||
x = self.layer_norm_out(x) | |||
x = nn.functional.linear(x, self.linear_z.weight) | |||
return x, g | |||
def get_output_bias(self): | |||
return self.linear_z.bias, self.linear_g.bias | |||
class TriangleMultiplicationOutgoing(TriangleMultiplication): | |||
__init__ = partialmethod(TriangleMultiplication.__init__, outgoing=True) | |||
class TriangleMultiplicationIncoming(TriangleMultiplication): | |||
__init__ = partialmethod(TriangleMultiplication.__init__, outgoing=False) |
@@ -0,0 +1 @@ | |||
""" Scripts for MSA & template searching. """ |
@@ -0,0 +1,483 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Parses the mmCIF file format.""" | |||
import collections | |||
import dataclasses | |||
import functools | |||
import io | |||
from typing import Any, Mapping, Optional, Sequence, Tuple | |||
from absl import logging | |||
from Bio import PDB | |||
from Bio.Data import SCOPData | |||
from Bio.PDB.MMCIFParser import MMCIFParser | |||
# Type aliases: | |||
ChainId = str | |||
PdbHeader = Mapping[str, Any] | |||
PdbStructure = PDB.Structure.Structure | |||
SeqRes = str | |||
MmCIFDict = Mapping[str, Sequence[str]] | |||
@dataclasses.dataclass(frozen=True) | |||
class Monomer: | |||
id: str | |||
num: int | |||
# Note - mmCIF format provides no guarantees on the type of author-assigned | |||
# sequence numbers. They need not be integers. | |||
@dataclasses.dataclass(frozen=True) | |||
class AtomSite: | |||
residue_name: str | |||
author_chain_id: str | |||
mmcif_chain_id: str | |||
author_seq_num: str | |||
mmcif_seq_num: int | |||
insertion_code: str | |||
hetatm_atom: str | |||
model_num: int | |||
# Used to map SEQRES index to a residue in the structure. | |||
@dataclasses.dataclass(frozen=True) | |||
class ResiduePosition: | |||
chain_id: str | |||
residue_number: int | |||
insertion_code: str | |||
@dataclasses.dataclass(frozen=True) | |||
class ResidueAtPosition: | |||
position: Optional[ResiduePosition] | |||
name: str | |||
is_missing: bool | |||
hetflag: str | |||
@dataclasses.dataclass(frozen=True) | |||
class MmcifObject: | |||
"""Representation of a parsed mmCIF file. | |||
Contains: | |||
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all | |||
files being processed. | |||
header: Biopython header. | |||
structure: Biopython structure. | |||
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. | |||
{'A': 'ABCDEFG'} | |||
seqres_to_structure: Dict; for each chain_id contains a mapping between | |||
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, 1: ResidueAtPosition, ...}} | |||
raw_string: The raw string used to construct the MmcifObject. | |||
""" | |||
file_id: str | |||
header: PdbHeader | |||
structure: PdbStructure | |||
chain_to_seqres: Mapping[ChainId, SeqRes] | |||
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] | |||
raw_string: Any | |||
mmcif_to_author_chain_id: Mapping[ChainId, ChainId] | |||
valid_chains: Mapping[ChainId, str] | |||
@dataclasses.dataclass(frozen=True) | |||
class ParsingResult: | |||
"""Returned by the parse function. | |||
Contains: | |||
mmcif_object: A MmcifObject, may be None if no chain could be successfully | |||
parsed. | |||
errors: A dict mapping (file_id, chain_id) to any exception generated. | |||
""" | |||
mmcif_object: Optional[MmcifObject] | |||
errors: Mapping[Tuple[str, str], Any] | |||
class ParseError(Exception): | |||
"""An error indicating that an mmCIF file could not be parsed.""" | |||
def mmcif_loop_to_list(prefix: str, | |||
parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]: | |||
"""Extracts loop associated with a prefix from mmCIF data as a list. | |||
Reference for loop_ in mmCIF: | |||
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html | |||
Args: | |||
prefix: Prefix shared by each of the data items in the loop. | |||
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, | |||
_entity_poly_seq.mon_id. Should include the trailing period. | |||
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython | |||
parser. | |||
Returns: | |||
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. | |||
""" | |||
cols = [] | |||
data = [] | |||
for key, value in parsed_info.items(): | |||
if key.startswith(prefix): | |||
cols.append(key) | |||
data.append(value) | |||
assert all([ | |||
len(xs) == len(data[0]) for xs in data | |||
]), ('mmCIF error: Not all loops are the same length: %s' % cols) | |||
return [dict(zip(cols, xs)) for xs in zip(*data)] | |||
def mmcif_loop_to_dict( | |||
prefix: str, | |||
index: str, | |||
parsed_info: MmCIFDict, | |||
) -> Mapping[str, Mapping[str, str]]: | |||
"""Extracts loop associated with a prefix from mmCIF data as a dictionary. | |||
Args: | |||
prefix: Prefix shared by each of the data items in the loop. | |||
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, | |||
_entity_poly_seq.mon_id. Should include the trailing period. | |||
index: Which item of loop data should serve as the key. | |||
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython | |||
parser. | |||
Returns: | |||
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, | |||
indexed by the index column. | |||
""" | |||
entries = mmcif_loop_to_list(prefix, parsed_info) | |||
return {entry[index]: entry for entry in entries} | |||
@functools.lru_cache(16, typed=False) | |||
def fast_parse(*, | |||
file_id: str, | |||
mmcif_string: str, | |||
catch_all_errors: bool = True) -> ParsingResult: | |||
"""Entry point, parses an mmcif_string. | |||
Args: | |||
file_id: A string identifier for this file. Should be unique within the | |||
collection of files being processed. | |||
mmcif_string: Contents of an mmCIF file. | |||
catch_all_errors: If True, all exceptions are caught and error messages are | |||
returned as part of the ParsingResult. If False exceptions will be allowed | |||
to propagate. | |||
Returns: | |||
A ParsingResult. | |||
""" | |||
errors = {} | |||
try: | |||
parser = MMCIFParser(QUIET=True) | |||
# handle = io.StringIO(mmcif_string) | |||
# full_structure = parser.get_structure('', handle) | |||
parsed_info = parser._mmcif_dict # pylint:disable=protected-access | |||
# Ensure all values are lists, even if singletons. | |||
for key, value in parsed_info.items(): | |||
if not isinstance(value, list): | |||
parsed_info[key] = [value] | |||
header = _get_header(parsed_info) | |||
# Determine the protein chains, and their start numbers according to the | |||
# internal mmCIF numbering scheme (likely but not guaranteed to be 1). | |||
valid_chains = _get_protein_chains(parsed_info=parsed_info) | |||
if not valid_chains: | |||
return ParsingResult( | |||
None, {(file_id, ''): 'No protein chains found in this file.'}) | |||
mmcif_to_author_chain_id = {} | |||
# seq_to_structure_mappings = {} | |||
for atom in _get_atom_site_list(parsed_info): | |||
if atom.model_num != '1': | |||
# We only process the first model at the moment. | |||
continue | |||
mmcif_to_author_chain_id[ | |||
atom.mmcif_chain_id] = atom.author_chain_id | |||
mmcif_object = MmcifObject( | |||
file_id=file_id, | |||
header=header, | |||
structure=None, | |||
chain_to_seqres=None, | |||
seqres_to_structure=None, | |||
raw_string=parsed_info, | |||
mmcif_to_author_chain_id=mmcif_to_author_chain_id, | |||
valid_chains=valid_chains, | |||
) | |||
return ParsingResult(mmcif_object=mmcif_object, errors=errors) | |||
except Exception as e: # pylint:disable=broad-except | |||
errors[(file_id, '')] = e | |||
if not catch_all_errors: | |||
raise | |||
return ParsingResult(mmcif_object=None, errors=errors) | |||
@functools.lru_cache(16, typed=False) | |||
def parse(*, | |||
file_id: str, | |||
mmcif_string: str, | |||
catch_all_errors: bool = True) -> ParsingResult: | |||
"""Entry point, parses an mmcif_string. | |||
Args: | |||
file_id: A string identifier for this file. Should be unique within the | |||
collection of files being processed. | |||
mmcif_string: Contents of an mmCIF file. | |||
catch_all_errors: If True, all exceptions are caught and error messages are | |||
returned as part of the ParsingResult. If False exceptions will be allowed | |||
to propagate. | |||
Returns: | |||
A ParsingResult. | |||
""" | |||
errors = {} | |||
try: | |||
parser = PDB.MMCIFParser(QUIET=True) | |||
handle = io.StringIO(mmcif_string) | |||
full_structure = parser.get_structure('', handle) | |||
first_model_structure = _get_first_model(full_structure) | |||
# Extract the _mmcif_dict from the parser, which contains useful fields not | |||
# reflected in the Biopython structure. | |||
parsed_info = parser._mmcif_dict # pylint:disable=protected-access | |||
# Ensure all values are lists, even if singletons. | |||
for key, value in parsed_info.items(): | |||
if not isinstance(value, list): | |||
parsed_info[key] = [value] | |||
header = _get_header(parsed_info) | |||
# Determine the protein chains, and their start numbers according to the | |||
# internal mmCIF numbering scheme (likely but not guaranteed to be 1). | |||
valid_chains = _get_protein_chains(parsed_info=parsed_info) | |||
if not valid_chains: | |||
return ParsingResult( | |||
None, {(file_id, ''): 'No protein chains found in this file.'}) | |||
seq_start_num = { | |||
chain_id: min([monomer.num for monomer in seq]) | |||
for chain_id, seq in valid_chains.items() | |||
} | |||
# Loop over the atoms for which we have coordinates. Populate two mappings: | |||
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used | |||
# the authors / Biopython). | |||
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). | |||
mmcif_to_author_chain_id = {} | |||
seq_to_structure_mappings = {} | |||
for atom in _get_atom_site_list(parsed_info): | |||
if atom.model_num != '1': | |||
# We only process the first model at the moment. | |||
continue | |||
mmcif_to_author_chain_id[ | |||
atom.mmcif_chain_id] = atom.author_chain_id | |||
if atom.mmcif_chain_id in valid_chains: | |||
hetflag = ' ' | |||
if atom.hetatm_atom == 'HETATM': | |||
# Water atoms are assigned a special hetflag of W in Biopython. We | |||
# need to do the same, so that this hetflag can be used to fetch | |||
# a residue from the Biopython structure by id. | |||
if atom.residue_name in ('HOH', 'WAT'): | |||
hetflag = 'W' | |||
else: | |||
hetflag = 'H_' + atom.residue_name | |||
insertion_code = atom.insertion_code | |||
if not _is_set(atom.insertion_code): | |||
insertion_code = ' ' | |||
position = ResiduePosition( | |||
chain_id=atom.author_chain_id, | |||
residue_number=int(atom.author_seq_num), | |||
insertion_code=insertion_code, | |||
) | |||
seq_idx = int( | |||
atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] | |||
current = seq_to_structure_mappings.get( | |||
atom.author_chain_id, {}) | |||
current[seq_idx] = ResidueAtPosition( | |||
position=position, | |||
name=atom.residue_name, | |||
is_missing=False, | |||
hetflag=hetflag, | |||
) | |||
seq_to_structure_mappings[atom.author_chain_id] = current | |||
# Add missing residue information to seq_to_structure_mappings. | |||
for chain_id, seq_info in valid_chains.items(): | |||
author_chain = mmcif_to_author_chain_id[chain_id] | |||
current_mapping = seq_to_structure_mappings[author_chain] | |||
for idx, monomer in enumerate(seq_info): | |||
if idx not in current_mapping: | |||
current_mapping[idx] = ResidueAtPosition( | |||
position=None, | |||
name=monomer.id, | |||
is_missing=True, | |||
hetflag=' ') | |||
author_chain_to_sequence = {} | |||
for chain_id, seq_info in valid_chains.items(): | |||
author_chain = mmcif_to_author_chain_id[chain_id] | |||
seq = [] | |||
for monomer in seq_info: | |||
code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') | |||
seq.append(code if len(code) == 1 else 'X') | |||
seq = ''.join(seq) | |||
author_chain_to_sequence[author_chain] = seq | |||
mmcif_object = MmcifObject( | |||
file_id=file_id, | |||
header=header, | |||
structure=first_model_structure, | |||
chain_to_seqres=author_chain_to_sequence, | |||
seqres_to_structure=seq_to_structure_mappings, | |||
raw_string=parsed_info, | |||
mmcif_to_author_chain_id=mmcif_to_author_chain_id, | |||
valid_chains=valid_chains, | |||
) | |||
return ParsingResult(mmcif_object=mmcif_object, errors=errors) | |||
except Exception as e: # pylint:disable=broad-except | |||
errors[(file_id, '')] = e | |||
if not catch_all_errors: | |||
raise | |||
return ParsingResult(mmcif_object=None, errors=errors) | |||
def _get_first_model(structure: PdbStructure) -> PdbStructure: | |||
"""Returns the first model in a Biopython structure.""" | |||
return next(structure.get_models()) | |||
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 | |||
def get_release_date(parsed_info: MmCIFDict) -> str: | |||
"""Returns the oldest revision date.""" | |||
revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] | |||
return min(revision_dates) | |||
def _get_header(parsed_info: MmCIFDict) -> PdbHeader: | |||
"""Returns a basic header containing method, release date and resolution.""" | |||
header = {} | |||
experiments = mmcif_loop_to_list('_exptl.', parsed_info) | |||
header['structure_method'] = ','.join( | |||
[experiment['_exptl.method'].lower() for experiment in experiments]) | |||
# Note: The release_date here corresponds to the oldest revision. We prefer to | |||
# use this for dataset filtering over the deposition_date. | |||
if '_pdbx_audit_revision_history.revision_date' in parsed_info: | |||
header['release_date'] = get_release_date(parsed_info) | |||
else: | |||
logging.warning('Could not determine release_date: %s', | |||
parsed_info['_entry.id']) | |||
header['resolution'] = 0.00 | |||
for res_key in ( | |||
'_refine.ls_d_res_high', | |||
'_em_3d_reconstruction.resolution', | |||
'_reflns.d_resolution_high', | |||
): | |||
if res_key in parsed_info: | |||
try: | |||
raw_resolution = parsed_info[res_key][0] | |||
header['resolution'] = float(raw_resolution) | |||
except ValueError: | |||
logging.debug('Invalid resolution format: %s', | |||
parsed_info[res_key]) | |||
return header | |||
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: | |||
"""Returns list of atom sites; contains data not present in the structure.""" | |||
return [ | |||
AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension | |||
parsed_info['_atom_site.label_comp_id'], | |||
parsed_info['_atom_site.auth_asym_id'], | |||
parsed_info['_atom_site.label_asym_id'], | |||
parsed_info['_atom_site.auth_seq_id'], | |||
parsed_info['_atom_site.label_seq_id'], | |||
parsed_info['_atom_site.pdbx_PDB_ins_code'], | |||
parsed_info['_atom_site.group_PDB'], | |||
parsed_info['_atom_site.pdbx_PDB_model_num'], | |||
) | |||
] | |||
def _get_protein_chains( | |||
*, parsed_info: Mapping[str, | |||
Any]) -> Mapping[ChainId, Sequence[Monomer]]: | |||
"""Extracts polymer information for protein chains only. | |||
Args: | |||
parsed_info: _mmcif_dict produced by the Biopython parser. | |||
Returns: | |||
A dict mapping mmcif chain id to a list of Monomers. | |||
""" | |||
# Get polymer information for each entity in the structure. | |||
entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) | |||
polymers = collections.defaultdict(list) | |||
for entity_poly_seq in entity_poly_seqs: | |||
polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( | |||
Monomer( | |||
id=entity_poly_seq['_entity_poly_seq.mon_id'], | |||
num=int(entity_poly_seq['_entity_poly_seq.num']), | |||
)) | |||
# Get chemical compositions. Will allow us to identify which of these polymers | |||
# are proteins. | |||
chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', | |||
parsed_info) | |||
# Get chains information for each entity. Necessary so that we can return a | |||
# dict keyed on chain id rather than entity. | |||
struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) | |||
entity_to_mmcif_chains = collections.defaultdict(list) | |||
for struct_asym in struct_asyms: | |||
chain_id = struct_asym['_struct_asym.id'] | |||
entity_id = struct_asym['_struct_asym.entity_id'] | |||
entity_to_mmcif_chains[entity_id].append(chain_id) | |||
# Identify and return the valid protein chains. | |||
valid_chains = {} | |||
for entity_id, seq_info in polymers.items(): | |||
chain_ids = entity_to_mmcif_chains[entity_id] | |||
# Reject polymers without any peptide-like components, such as DNA/RNA. | |||
if any([ | |||
'peptide' in chem_comps[monomer.id]['_chem_comp.type'] | |||
for monomer in seq_info | |||
]): | |||
for chain_id in chain_ids: | |||
valid_chains[chain_id] = seq_info | |||
return valid_chains | |||
def _is_set(data: str) -> bool: | |||
"""Returns False if data is a special mmCIF character indicating 'unset'.""" | |||
return data not in ('.', '?') |
@@ -0,0 +1,88 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Utilities for extracting identifiers from MSA sequence descriptions.""" | |||
import dataclasses | |||
import re | |||
from typing import Optional | |||
# Sequences coming from UniProtKB database come in the | |||
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` | |||
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). | |||
_UNIPROT_PATTERN = re.compile( | |||
r""" | |||
^ | |||
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot | |||
(?:tr|sp) | |||
\| | |||
# A primary accession number of the UniProtKB entry. | |||
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10}) | |||
# Occasionally there is a _0 or _1 isoform suffix, which we ignore. | |||
(?:_\d)? | |||
\| | |||
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic | |||
# protein ID code. | |||
(?:[A-Za-z0-9]+) | |||
_ | |||
# A mnemonic species identification code. | |||
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5}) | |||
# Small BFD uses a final value after an underscore, which we ignore. | |||
(?:_\d+)? | |||
$ | |||
""", | |||
re.VERBOSE, | |||
) | |||
@dataclasses.dataclass(frozen=True) | |||
class Identifiers: | |||
species_id: str = '' | |||
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: | |||
"""Gets accession id and species from an msa sequence identifier. | |||
The sequence identifier has the format specified by | |||
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. | |||
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` | |||
Args: | |||
msa_sequence_identifier: a sequence identifier. | |||
Returns: | |||
An `Identifiers` instance with a species_id. These | |||
can be empty in the case where no identifier was found. | |||
""" | |||
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) | |||
if matches: | |||
return Identifiers(species_id=matches.group('SpeciesIdentifier')) | |||
return Identifiers() | |||
def _extract_sequence_identifier(description: str) -> Optional[str]: | |||
"""Extracts sequence identifier from description. Returns None if no match.""" | |||
split_description = description.split() | |||
if split_description: | |||
return split_description[0].partition('/')[0] | |||
else: | |||
return None | |||
def get_identifiers(description: str) -> Identifiers: | |||
"""Computes extra MSA features from the description.""" | |||
sequence_identifier = _extract_sequence_identifier(description) | |||
if sequence_identifier is None: | |||
return Identifiers() | |||
else: | |||
return _parse_sequence_identifier(sequence_identifier) |
@@ -0,0 +1,627 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Functions for parsing various file formats.""" | |||
import collections | |||
import dataclasses | |||
import itertools | |||
import re | |||
import string | |||
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple | |||
DeletionMatrix = Sequence[Sequence[int]] | |||
@dataclasses.dataclass(frozen=True) | |||
class Msa: | |||
"""Class representing a parsed MSA file.""" | |||
sequences: Sequence[str] | |||
deletion_matrix: DeletionMatrix | |||
descriptions: Sequence[str] | |||
def __post_init__(self): | |||
if not (len(self.sequences) == len(self.deletion_matrix) == len( | |||
self.descriptions)): | |||
raise ValueError( | |||
'All fields for an MSA must have the same length. ' | |||
f'Got {len(self.sequences)} sequences, ' | |||
f'{len(self.deletion_matrix)} rows in the deletion matrix and ' | |||
f'{len(self.descriptions)} descriptions.') | |||
def __len__(self): | |||
return len(self.sequences) | |||
def truncate(self, max_seqs: int): | |||
return Msa( | |||
sequences=self.sequences[:max_seqs], | |||
deletion_matrix=self.deletion_matrix[:max_seqs], | |||
descriptions=self.descriptions[:max_seqs], | |||
) | |||
@dataclasses.dataclass(frozen=True) | |||
class TemplateHit: | |||
"""Class representing a template hit.""" | |||
index: int | |||
name: str | |||
aligned_cols: int | |||
sum_probs: Optional[float] | |||
query: str | |||
hit_sequence: str | |||
indices_query: List[int] | |||
indices_hit: List[int] | |||
def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: | |||
"""Parses FASTA string and returns list of strings with amino-acid sequences. | |||
Arguments: | |||
fasta_string: The string contents of a FASTA file. | |||
Returns: | |||
A tuple of two lists: | |||
* A list of sequences. | |||
* A list of sequence descriptions taken from the comment lines. In the | |||
same order as the sequences. | |||
""" | |||
sequences = [] | |||
descriptions = [] | |||
index = -1 | |||
for line in fasta_string.splitlines(): | |||
line = line.strip() | |||
if line.startswith('>'): | |||
index += 1 | |||
descriptions.append(line[1:]) # Remove the '>' at the beginning. | |||
sequences.append('') | |||
continue | |||
elif not line: | |||
continue # Skip blank lines. | |||
sequences[index] += line | |||
return sequences, descriptions | |||
def parse_stockholm(stockholm_string: str) -> Msa: | |||
"""Parses sequences and deletion matrix from stockholm format alignment. | |||
Args: | |||
stockholm_string: The string contents of a stockholm file. The first | |||
sequence in the file should be the query sequence. | |||
Returns: | |||
A tuple of: | |||
* A list of sequences that have been aligned to the query. These | |||
might contain duplicates. | |||
* The deletion matrix for the alignment as a list of lists. The element | |||
at `deletion_matrix[i][j]` is the number of residues deleted from | |||
the aligned sequence i at residue position j. | |||
* The names of the targets matched, including the jackhmmer subsequence | |||
suffix. | |||
""" | |||
name_to_sequence = collections.OrderedDict() | |||
for line in stockholm_string.splitlines(): | |||
line = line.strip() | |||
if not line or line.startswith(('#', '//')): | |||
continue | |||
name, sequence = line.split() | |||
if name not in name_to_sequence: | |||
name_to_sequence[name] = '' | |||
name_to_sequence[name] += sequence | |||
msa = [] | |||
deletion_matrix = [] | |||
query = '' | |||
keep_columns = [] | |||
for seq_index, sequence in enumerate(name_to_sequence.values()): | |||
if seq_index == 0: | |||
# Gather the columns with gaps from the query | |||
query = sequence | |||
keep_columns = [i for i, res in enumerate(query) if res != '-'] | |||
# Remove the columns with gaps in the query from all sequences. | |||
aligned_sequence = ''.join([sequence[c] for c in keep_columns]) | |||
msa.append(aligned_sequence) | |||
# Count the number of deletions w.r.t. query. | |||
deletion_vec = [] | |||
deletion_count = 0 | |||
for seq_res, query_res in zip(sequence, query): | |||
if seq_res != '-' or query_res != '-': | |||
if query_res == '-': | |||
deletion_count += 1 | |||
else: | |||
deletion_vec.append(deletion_count) | |||
deletion_count = 0 | |||
deletion_matrix.append(deletion_vec) | |||
return Msa( | |||
sequences=msa, | |||
deletion_matrix=deletion_matrix, | |||
descriptions=list(name_to_sequence.keys()), | |||
) | |||
def parse_a3m(a3m_string: str) -> Msa: | |||
"""Parses sequences and deletion matrix from a3m format alignment. | |||
Args: | |||
a3m_string: The string contents of a a3m file. The first sequence in the | |||
file should be the query sequence. | |||
Returns: | |||
A tuple of: | |||
* A list of sequences that have been aligned to the query. These | |||
might contain duplicates. | |||
* The deletion matrix for the alignment as a list of lists. The element | |||
at `deletion_matrix[i][j]` is the number of residues deleted from | |||
the aligned sequence i at residue position j. | |||
* A list of descriptions, one per sequence, from the a3m file. | |||
""" | |||
sequences, descriptions = parse_fasta(a3m_string) | |||
deletion_matrix = [] | |||
for msa_sequence in sequences: | |||
deletion_vec = [] | |||
deletion_count = 0 | |||
for j in msa_sequence: | |||
if j.islower(): | |||
deletion_count += 1 | |||
else: | |||
deletion_vec.append(deletion_count) | |||
deletion_count = 0 | |||
deletion_matrix.append(deletion_vec) | |||
# Make the MSA matrix out of aligned (deletion-free) sequences. | |||
deletion_table = str.maketrans('', '', string.ascii_lowercase) | |||
aligned_sequences = [s.translate(deletion_table) for s in sequences] | |||
return Msa( | |||
sequences=aligned_sequences, | |||
deletion_matrix=deletion_matrix, | |||
descriptions=descriptions, | |||
) | |||
def _convert_sto_seq_to_a3m(query_non_gaps: Sequence[bool], | |||
sto_seq: str) -> Iterable[str]: | |||
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): | |||
if is_query_res_non_gap: | |||
yield sequence_res | |||
elif sequence_res != '-': | |||
yield sequence_res.lower() | |||
def convert_stockholm_to_a3m( | |||
stockholm_format: str, | |||
max_sequences: Optional[int] = None, | |||
remove_first_row_gaps: bool = True, | |||
) -> str: | |||
"""Converts MSA in Stockholm format to the A3M format.""" | |||
descriptions = {} | |||
sequences = {} | |||
reached_max_sequences = False | |||
for line in stockholm_format.splitlines(): | |||
reached_max_sequences = max_sequences and len( | |||
sequences) >= max_sequences | |||
if line.strip() and not line.startswith(('#', '//')): | |||
# Ignore blank lines, markup and end symbols - remainder are alignment | |||
# sequence parts. | |||
seqname, aligned_seq = line.split(maxsplit=1) | |||
if seqname not in sequences: | |||
if reached_max_sequences: | |||
continue | |||
sequences[seqname] = '' | |||
sequences[seqname] += aligned_seq | |||
for line in stockholm_format.splitlines(): | |||
if line[:4] == '#=GS': | |||
# Description row - example format is: | |||
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... | |||
columns = line.split(maxsplit=3) | |||
seqname, feature = columns[1:3] | |||
value = columns[3] if len(columns) == 4 else '' | |||
if feature != 'DE': | |||
continue | |||
if reached_max_sequences and seqname not in sequences: | |||
continue | |||
descriptions[seqname] = value | |||
if len(descriptions) == len(sequences): | |||
break | |||
# Convert sto format to a3m line by line | |||
a3m_sequences = {} | |||
if remove_first_row_gaps: | |||
# query_sequence is assumed to be the first sequence | |||
query_sequence = next(iter(sequences.values())) | |||
query_non_gaps = [res != '-' for res in query_sequence] | |||
for seqname, sto_sequence in sequences.items(): | |||
# Dots are optional in a3m format and are commonly removed. | |||
out_sequence = sto_sequence.replace('.', '') | |||
if remove_first_row_gaps: | |||
out_sequence = ''.join( | |||
_convert_sto_seq_to_a3m(query_non_gaps, out_sequence)) | |||
a3m_sequences[seqname] = out_sequence | |||
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" | |||
for k in a3m_sequences) | |||
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. | |||
def _keep_line(line: str, seqnames: Set[str]) -> bool: | |||
"""Function to decide which lines to keep.""" | |||
if not line.strip(): | |||
return True | |||
if line.strip() == '//': # End tag | |||
return True | |||
if line.startswith('# STOCKHOLM'): # Start tag | |||
return True | |||
if line.startswith('#=GC RF'): # Reference Annotation Line | |||
return True | |||
if line[:4] == '#=GS': # Description lines - keep if sequence in list. | |||
_, seqname, _ = line.split(maxsplit=2) | |||
return seqname in seqnames | |||
elif line.startswith('#'): # Other markup - filter out | |||
return False | |||
else: # Alignment data - keep if sequence in list. | |||
seqname = line.partition(' ')[0] | |||
return seqname in seqnames | |||
def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str: | |||
"""Truncates a stockholm file to a maximum number of sequences.""" | |||
seqnames = set() | |||
filtered_lines = [] | |||
for line in stockholm_msa.splitlines(): | |||
if line.strip() and not line.startswith(('#', '//')): | |||
# Ignore blank lines, markup and end symbols - remainder are alignment | |||
# sequence parts. | |||
seqname = line.partition(' ')[0] | |||
seqnames.add(seqname) | |||
if len(seqnames) >= max_sequences: | |||
break | |||
for line in stockholm_msa.splitlines(): | |||
if _keep_line(line, seqnames): | |||
filtered_lines.append(line) | |||
return '\n'.join(filtered_lines) + '\n' | |||
def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str: | |||
"""Removes empty columns (dashes-only) from a Stockholm MSA.""" | |||
processed_lines = {} | |||
unprocessed_lines = {} | |||
for i, line in enumerate(stockholm_msa.splitlines()): | |||
if line.startswith('#=GC RF'): | |||
reference_annotation_i = i | |||
reference_annotation_line = line | |||
# Reached the end of this chunk of the alignment. Process chunk. | |||
_, _, first_alignment = line.rpartition(' ') | |||
mask = [] | |||
for j in range(len(first_alignment)): | |||
for _, unprocessed_line in unprocessed_lines.items(): | |||
prefix, _, alignment = unprocessed_line.rpartition(' ') | |||
if alignment[j] != '-': | |||
mask.append(True) | |||
break | |||
else: # Every row contained a hyphen - empty column. | |||
mask.append(False) | |||
# Add reference annotation for processing with mask. | |||
unprocessed_lines[ | |||
reference_annotation_i] = reference_annotation_line | |||
if not any( | |||
mask | |||
): # All columns were empty. Output empty lines for chunk. | |||
for line_index in unprocessed_lines: | |||
processed_lines[line_index] = '' | |||
else: | |||
for line_index, unprocessed_line in unprocessed_lines.items(): | |||
prefix, _, alignment = unprocessed_line.rpartition(' ') | |||
masked_alignment = ''.join( | |||
itertools.compress(alignment, mask)) | |||
processed_lines[ | |||
line_index] = f'{prefix} {masked_alignment}' | |||
# Clear raw_alignments. | |||
unprocessed_lines = {} | |||
elif line.strip() and not line.startswith(('#', '//')): | |||
unprocessed_lines[i] = line | |||
else: | |||
processed_lines[i] = line | |||
return '\n'.join((processed_lines[i] for i in range(len(processed_lines)))) | |||
def deduplicate_stockholm_msa(stockholm_msa: str) -> str: | |||
"""Remove duplicate sequences (ignoring insertions wrt query).""" | |||
sequence_dict = collections.defaultdict(str) | |||
# First we must extract all sequences from the MSA. | |||
for line in stockholm_msa.splitlines(): | |||
# Only consider the alignments - ignore reference annotation, empty lines, | |||
# descriptions or markup. | |||
if line.strip() and not line.startswith(('#', '//')): | |||
line = line.strip() | |||
seqname, alignment = line.split() | |||
sequence_dict[seqname] += alignment | |||
seen_sequences = set() | |||
seqnames = set() | |||
# First alignment is the query. | |||
query_align = next(iter(sequence_dict.values())) | |||
mask = [c != '-' for c in query_align] # Mask is False for insertions. | |||
for seqname, alignment in sequence_dict.items(): | |||
# Apply mask to remove all insertions from the string. | |||
masked_alignment = ''.join(itertools.compress(alignment, mask)) | |||
if masked_alignment in seen_sequences: | |||
continue | |||
else: | |||
seen_sequences.add(masked_alignment) | |||
seqnames.add(seqname) | |||
filtered_lines = [] | |||
for line in stockholm_msa.splitlines(): | |||
if _keep_line(line, seqnames): | |||
filtered_lines.append(line) | |||
return '\n'.join(filtered_lines) + '\n' | |||
def _get_hhr_line_regex_groups(regex_pattern: str, | |||
line: str) -> Sequence[Optional[str]]: | |||
match = re.match(regex_pattern, line) | |||
if match is None: | |||
raise RuntimeError(f'Could not parse query line {line}') | |||
return match.groups() | |||
def _update_hhr_residue_indices_list(sequence: str, start_index: int, | |||
indices_list: List[int]): | |||
"""Computes the relative indices for each residue with respect to the original sequence.""" | |||
counter = start_index | |||
for symbol in sequence: | |||
if symbol == '-': | |||
indices_list.append(-1) | |||
else: | |||
indices_list.append(counter) | |||
counter += 1 | |||
def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: | |||
"""Parses the detailed HMM HMM comparison section for a single Hit. | |||
This works on .hhr files generated from both HHBlits and HHSearch. | |||
Args: | |||
detailed_lines: A list of lines from a single comparison section between 2 | |||
sequences (which each have their own HMM's) | |||
Returns: | |||
A dictionary with the information from that detailed comparison section | |||
Raises: | |||
RuntimeError: If a certain line cannot be processed | |||
""" | |||
# Parse first 2 lines. | |||
number_of_hit = int(detailed_lines[0].split()[-1]) | |||
name_hit = detailed_lines[1][1:] | |||
# Parse the summary line. | |||
pattern = ( | |||
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' | |||
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' | |||
']*Template_Neff=(.*)') | |||
match = re.match(pattern, detailed_lines[2]) | |||
if match is None: | |||
raise RuntimeError( | |||
'Could not parse section: %s. Expected this: \n%s to contain summary.' | |||
% (detailed_lines, detailed_lines[2])) | |||
(_, _, _, aligned_cols, _, _, sum_probs, | |||
_) = [float(x) for x in match.groups()] | |||
# The next section reads the detailed comparisons. These are in a 'human | |||
# readable' format which has a fixed length. The strategy employed is to | |||
# assume that each block starts with the query sequence line, and to parse | |||
# that with a regexp in order to deduce the fixed length used for that block. | |||
query = '' | |||
hit_sequence = '' | |||
indices_query = [] | |||
indices_hit = [] | |||
length_block = None | |||
for line in detailed_lines[3:]: | |||
# Parse the query sequence line | |||
if (line.startswith('Q ') and not line.startswith('Q ss_dssp') | |||
and not line.startswith('Q ss_pred') | |||
and not line.startswith('Q Consensus')): | |||
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse | |||
# everything after that. | |||
# start sequence end total_sequence_length | |||
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' | |||
groups = _get_hhr_line_regex_groups(patt, line[17:]) | |||
# Get the length of the parsed block using the start and finish indices, | |||
# and ensure it is the same as the actual block length. | |||
start = int(groups[0]) - 1 # Make index zero based. | |||
delta_query = groups[1] | |||
end = int(groups[2]) | |||
num_insertions = len([x for x in delta_query if x == '-']) | |||
length_block = end - start + num_insertions | |||
assert length_block == len(delta_query) | |||
# Update the query sequence and indices list. | |||
query += delta_query | |||
_update_hhr_residue_indices_list(delta_query, start, indices_query) | |||
elif line.startswith('T '): | |||
# Parse the hit sequence. | |||
if (not line.startswith('T ss_dssp') | |||
and not line.startswith('T ss_pred') | |||
and not line.startswith('T Consensus')): | |||
# Thus the first 17 characters must be 'T <hit_name> ', and we can | |||
# parse everything after that. | |||
# start sequence end total_sequence_length | |||
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' | |||
groups = _get_hhr_line_regex_groups(patt, line[17:]) | |||
start = int(groups[0]) - 1 # Make index zero based. | |||
delta_hit_sequence = groups[1] | |||
assert length_block == len(delta_hit_sequence) | |||
# Update the hit sequence and indices list. | |||
hit_sequence += delta_hit_sequence | |||
_update_hhr_residue_indices_list(delta_hit_sequence, start, | |||
indices_hit) | |||
return TemplateHit( | |||
index=number_of_hit, | |||
name=name_hit, | |||
aligned_cols=int(aligned_cols), | |||
sum_probs=sum_probs, | |||
query=query, | |||
hit_sequence=hit_sequence, | |||
indices_query=indices_query, | |||
indices_hit=indices_hit, | |||
) | |||
def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: | |||
"""Parses the content of an entire HHR file.""" | |||
lines = hhr_string.splitlines() | |||
# Each .hhr file starts with a results table, then has a sequence of hit | |||
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We | |||
# iterate through each paragraph to parse each hit. | |||
block_starts = [ | |||
i for i, line in enumerate(lines) if line.startswith('No ') | |||
] | |||
hits = [] | |||
if block_starts: | |||
block_starts.append(len(lines)) # Add the end of the final block. | |||
for i in range(len(block_starts) - 1): | |||
hits.append( | |||
_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) | |||
return hits | |||
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: | |||
"""Parse target to e-value mapping parsed from Jackhmmer tblout string.""" | |||
e_values = {'query': 0} | |||
lines = [line for line in tblout.splitlines() if line[0] != '#'] | |||
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are | |||
# space-delimited. Relevant fields are (1) target name: and | |||
# (5) E-value (full sequence) (numbering from 1). | |||
for line in lines: | |||
fields = line.split() | |||
e_value = fields[4] | |||
target_name = fields[0] | |||
e_values[target_name] = float(e_value) | |||
return e_values | |||
def _get_indices(sequence: str, start: int) -> List[int]: | |||
"""Returns indices for non-gap/insert residues starting at the given index.""" | |||
indices = [] | |||
counter = start | |||
for symbol in sequence: | |||
# Skip gaps but add a placeholder so that the alignment is preserved. | |||
if symbol == '-': | |||
indices.append(-1) | |||
# Skip deleted residues, but increase the counter. | |||
elif symbol.islower(): | |||
counter += 1 | |||
# Normal aligned residue. Increase the counter and append to indices. | |||
else: | |||
indices.append(counter) | |||
counter += 1 | |||
return indices | |||
@dataclasses.dataclass(frozen=True) | |||
class HitMetadata: | |||
pdb_id: str | |||
chain: str | |||
start: int | |||
end: int | |||
length: int | |||
text: str | |||
def _parse_hmmsearch_description(description: str) -> HitMetadata: | |||
"""Parses the hmmsearch A3M sequence description line.""" | |||
# Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text | |||
# Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352 | |||
match = re.match( | |||
r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$', | |||
description.strip(), | |||
) | |||
if not match: | |||
raise ValueError(f'Could not parse description: "{description}".') | |||
return HitMetadata( | |||
pdb_id=match[1], | |||
chain=match[2], | |||
start=int(match[3]), | |||
end=int(match[4]), | |||
length=int(match[5]), | |||
text=match[6], | |||
) | |||
def parse_hmmsearch_a3m(query_sequence: str, | |||
a3m_string: str, | |||
skip_first: bool = True) -> Sequence[TemplateHit]: | |||
"""Parses an a3m string produced by hmmsearch. | |||
Args: | |||
query_sequence: The query sequence. | |||
a3m_string: The a3m string produced by hmmsearch. | |||
skip_first: Whether to skip the first sequence in the a3m string. | |||
Returns: | |||
A sequence of `TemplateHit` results. | |||
""" | |||
# Zip the descriptions and MSAs together, skip the first query sequence. | |||
parsed_a3m = list(zip(*parse_fasta(a3m_string))) | |||
if skip_first: | |||
parsed_a3m = parsed_a3m[1:] | |||
indices_query = _get_indices(query_sequence, start=0) | |||
hits = [] | |||
for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1): | |||
if 'mol:protein' not in hit_description: | |||
continue # Skip non-protein chains. | |||
metadata = _parse_hmmsearch_description(hit_description) | |||
# Aligned columns are only the match states. | |||
aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence]) | |||
indices_hit = _get_indices(hit_sequence, start=metadata.start - 1) | |||
hit = TemplateHit( | |||
index=i, | |||
name=f'{metadata.pdb_id}_{metadata.chain}', | |||
aligned_cols=aligned_cols, | |||
sum_probs=None, | |||
query=query_sequence, | |||
hit_sequence=hit_sequence.upper(), | |||
indices_query=indices_query, | |||
indices_hit=indices_hit, | |||
) | |||
hits.append(hit) | |||
return hits |
@@ -0,0 +1,282 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Functions for building the input features for the unifold model.""" | |||
import os | |||
from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union | |||
import numpy as np | |||
from absl import logging | |||
from modelscope.models.science.unifold.data import residue_constants | |||
from modelscope.models.science.unifold.msa import (msa_identifiers, parsers, | |||
templates) | |||
from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch, | |||
hmmsearch, jackhmmer) | |||
FeatureDict = MutableMapping[str, np.ndarray] | |||
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] | |||
def make_sequence_features(sequence: str, description: str, | |||
num_res: int) -> FeatureDict: | |||
"""Constructs a feature dict of sequence features.""" | |||
features = {} | |||
features['aatype'] = residue_constants.sequence_to_onehot( | |||
sequence=sequence, | |||
mapping=residue_constants.restype_order_with_x, | |||
map_unknown_to_x=True, | |||
) | |||
features['between_segment_residues'] = np.zeros((num_res, ), | |||
dtype=np.int32) | |||
features['domain_name'] = np.array([description.encode('utf-8')], | |||
dtype=np.object_) | |||
features['residue_index'] = np.array(range(num_res), dtype=np.int32) | |||
features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) | |||
features['sequence'] = np.array([sequence.encode('utf-8')], | |||
dtype=np.object_) | |||
return features | |||
def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: | |||
"""Constructs a feature dict of MSA features.""" | |||
if not msas: | |||
raise ValueError('At least one MSA must be provided.') | |||
int_msa = [] | |||
deletion_matrix = [] | |||
species_ids = [] | |||
seen_sequences = set() | |||
for msa_index, msa in enumerate(msas): | |||
if not msa: | |||
raise ValueError( | |||
f'MSA {msa_index} must contain at least one sequence.') | |||
for sequence_index, sequence in enumerate(msa.sequences): | |||
if sequence in seen_sequences: | |||
continue | |||
seen_sequences.add(sequence) | |||
int_msa.append( | |||
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) | |||
deletion_matrix.append(msa.deletion_matrix[sequence_index]) | |||
identifiers = msa_identifiers.get_identifiers( | |||
msa.descriptions[sequence_index]) | |||
species_ids.append(identifiers.species_id.encode('utf-8')) | |||
num_res = len(msas[0].sequences[0]) | |||
num_alignments = len(int_msa) | |||
features = {} | |||
features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) | |||
features['msa'] = np.array(int_msa, dtype=np.int32) | |||
features['num_alignments'] = np.array( | |||
[num_alignments] * num_res, dtype=np.int32) | |||
features['msa_species_identifiers'] = np.array( | |||
species_ids, dtype=np.object_) | |||
return features | |||
def run_msa_tool( | |||
msa_runner, | |||
input_fasta_path: str, | |||
msa_out_path: str, | |||
msa_format: str, | |||
use_precomputed_msas: bool, | |||
) -> Mapping[str, Any]: | |||
"""Runs an MSA tool, checking if output already exists first.""" | |||
if not use_precomputed_msas or not os.path.exists(msa_out_path): | |||
result = msa_runner.query(input_fasta_path)[0] | |||
with open(msa_out_path, 'w') as f: | |||
f.write(result[msa_format]) | |||
else: | |||
logging.warning('Reading MSA from file %s', msa_out_path) | |||
with open(msa_out_path, 'r') as f: | |||
result = {msa_format: f.read()} | |||
return result | |||
class DataPipeline: | |||
"""Runs the alignment tools and assembles the input features.""" | |||
def __init__( | |||
self, | |||
jackhmmer_binary_path: str, | |||
hhblits_binary_path: str, | |||
uniref90_database_path: str, | |||
mgnify_database_path: str, | |||
bfd_database_path: Optional[str], | |||
uniclust30_database_path: Optional[str], | |||
small_bfd_database_path: Optional[str], | |||
uniprot_database_path: Optional[str], | |||
template_searcher: TemplateSearcher, | |||
template_featurizer: templates.TemplateHitFeaturizer, | |||
use_small_bfd: bool, | |||
mgnify_max_hits: int = 501, | |||
uniref_max_hits: int = 10000, | |||
use_precomputed_msas: bool = False, | |||
): | |||
"""Initializes the data pipeline.""" | |||
self._use_small_bfd = use_small_bfd | |||
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( | |||
binary_path=jackhmmer_binary_path, | |||
database_path=uniref90_database_path) | |||
if use_small_bfd: | |||
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( | |||
binary_path=jackhmmer_binary_path, | |||
database_path=small_bfd_database_path) | |||
else: | |||
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( | |||
binary_path=hhblits_binary_path, | |||
databases=[bfd_database_path, uniclust30_database_path], | |||
) | |||
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( | |||
binary_path=jackhmmer_binary_path, | |||
database_path=mgnify_database_path) | |||
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer( | |||
binary_path=jackhmmer_binary_path, | |||
database_path=uniprot_database_path) | |||
self.template_searcher = template_searcher | |||
self.template_featurizer = template_featurizer | |||
self.mgnify_max_hits = mgnify_max_hits | |||
self.uniref_max_hits = uniref_max_hits | |||
self.use_precomputed_msas = use_precomputed_msas | |||
def process(self, input_fasta_path: str, | |||
msa_output_dir: str) -> FeatureDict: | |||
"""Runs alignment tools on the input sequence and creates features.""" | |||
with open(input_fasta_path) as f: | |||
input_fasta_str = f.read() | |||
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) | |||
if len(input_seqs) != 1: | |||
raise ValueError( | |||
f'More than one input sequence found in {input_fasta_path}.') | |||
input_sequence = input_seqs[0] | |||
input_description = input_descs[0] | |||
num_res = len(input_sequence) | |||
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') | |||
jackhmmer_uniref90_result = run_msa_tool( | |||
self.jackhmmer_uniref90_runner, | |||
input_fasta_path, | |||
uniref90_out_path, | |||
'sto', | |||
self.use_precomputed_msas, | |||
) | |||
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') | |||
jackhmmer_mgnify_result = run_msa_tool( | |||
self.jackhmmer_mgnify_runner, | |||
input_fasta_path, | |||
mgnify_out_path, | |||
'sto', | |||
self.use_precomputed_msas, | |||
) | |||
msa_for_templates = jackhmmer_uniref90_result['sto'] | |||
msa_for_templates = parsers.truncate_stockholm_msa( | |||
msa_for_templates, max_sequences=self.uniref_max_hits) | |||
msa_for_templates = parsers.deduplicate_stockholm_msa( | |||
msa_for_templates) | |||
msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa( | |||
msa_for_templates) | |||
if self.template_searcher.input_format == 'sto': | |||
pdb_templates_result = self.template_searcher.query( | |||
msa_for_templates) | |||
elif self.template_searcher.input_format == 'a3m': | |||
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( | |||
msa_for_templates) | |||
pdb_templates_result = self.template_searcher.query( | |||
uniref90_msa_as_a3m) | |||
else: | |||
raise ValueError('Unrecognized template input format: ' | |||
f'{self.template_searcher.input_format}') | |||
pdb_hits_out_path = os.path.join( | |||
msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}') | |||
with open(pdb_hits_out_path, 'w') as f: | |||
f.write(pdb_templates_result) | |||
uniref90_msa = parsers.parse_stockholm( | |||
jackhmmer_uniref90_result['sto']) | |||
uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) | |||
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) | |||
mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) | |||
pdb_template_hits = self.template_searcher.get_template_hits( | |||
output_string=pdb_templates_result, input_sequence=input_sequence) | |||
if self._use_small_bfd: | |||
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') | |||
jackhmmer_small_bfd_result = run_msa_tool( | |||
self.jackhmmer_small_bfd_runner, | |||
input_fasta_path, | |||
bfd_out_path, | |||
'sto', | |||
self.use_precomputed_msas, | |||
) | |||
bfd_msa = parsers.parse_stockholm( | |||
jackhmmer_small_bfd_result['sto']) | |||
else: | |||
bfd_out_path = os.path.join(msa_output_dir, | |||
'bfd_uniclust_hits.a3m') | |||
hhblits_bfd_uniclust_result = run_msa_tool( | |||
self.hhblits_bfd_uniclust_runner, | |||
input_fasta_path, | |||
bfd_out_path, | |||
'a3m', | |||
self.use_precomputed_msas, | |||
) | |||
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) | |||
templates_result = self.template_featurizer.get_templates( | |||
query_sequence=input_sequence, hits=pdb_template_hits) | |||
sequence_features = make_sequence_features( | |||
sequence=input_sequence, | |||
description=input_description, | |||
num_res=num_res) | |||
msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa)) | |||
logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) | |||
logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) | |||
logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) | |||
logging.info( | |||
'Final (deduplicated) MSA size: %d sequences.', | |||
msa_features['num_alignments'][0], | |||
) | |||
logging.info( | |||
'Total number of templates (NB: this can include bad ' | |||
'templates and is later filtered to top 4): %d.', | |||
templates_result.features['template_domain_names'].shape[0], | |||
) | |||
return { | |||
**sequence_features, | |||
**msa_features, | |||
**templates_result.features | |||
} | |||
def process_uniprot(self, input_fasta_path: str, | |||
msa_output_dir: str) -> FeatureDict: | |||
uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') | |||
uniprot_result = run_msa_tool( | |||
self.jackhmmer_uniprot_runner, | |||
input_fasta_path, | |||
uniprot_path, | |||
'sto', | |||
self.use_precomputed_msas, | |||
) | |||
msa = parsers.parse_stockholm(uniprot_result['sto']) | |||
msa = msa.truncate(max_seqs=50000) | |||
all_seq_dict = make_msa_features([msa]) | |||
return all_seq_dict |
@@ -0,0 +1,14 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Python wrappers for third party tools.""" |
@@ -0,0 +1,170 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Library to run HHblits from Python.""" | |||
import glob | |||
import os | |||
import subprocess | |||
from typing import Any, List, Mapping, Optional, Sequence | |||
from absl import logging | |||
from . import utils | |||
_HHBLITS_DEFAULT_P = 20 | |||
_HHBLITS_DEFAULT_Z = 500 | |||
class HHBlits: | |||
"""Python wrapper of the HHblits binary.""" | |||
def __init__( | |||
self, | |||
*, | |||
binary_path: str, | |||
databases: Sequence[str], | |||
n_cpu: int = 4, | |||
n_iter: int = 3, | |||
e_value: float = 0.001, | |||
maxseq: int = 1_000_000, | |||
realign_max: int = 100_000, | |||
maxfilt: int = 100_000, | |||
min_prefilter_hits: int = 1000, | |||
all_seqs: bool = False, | |||
alt: Optional[int] = None, | |||
p: int = _HHBLITS_DEFAULT_P, | |||
z: int = _HHBLITS_DEFAULT_Z, | |||
): | |||
"""Initializes the Python HHblits wrapper. | |||
Args: | |||
binary_path: The path to the HHblits executable. | |||
databases: A sequence of HHblits database paths. This should be the | |||
common prefix for the database files (i.e. up to but not including | |||
_hhm.ffindex etc.) | |||
n_cpu: The number of CPUs to give HHblits. | |||
n_iter: The number of HHblits iterations. | |||
e_value: The E-value, see HHblits docs for more details. | |||
maxseq: The maximum number of rows in an input alignment. Note that this | |||
parameter is only supported in HHBlits version 3.1 and higher. | |||
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. | |||
maxfilt: Max number of hits allowed to pass the 2nd prefilter. | |||
HHblits default: 20000. | |||
min_prefilter_hits: Min number of hits to pass prefilter. | |||
HHblits default: 100. | |||
all_seqs: Return all sequences in the MSA / Do not filter the result MSA. | |||
HHblits default: False. | |||
alt: Show up to this many alternative alignments. | |||
p: Minimum Prob for a hit to be included in the output hhr file. | |||
HHblits default: 20. | |||
z: Hard cap on number of hits reported in the hhr file. | |||
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. | |||
Raises: | |||
RuntimeError: If HHblits binary not found within the path. | |||
""" | |||
self.binary_path = binary_path | |||
self.databases = databases | |||
for database_path in self.databases: | |||
if not glob.glob(database_path + '_*'): | |||
logging.error('Could not find HHBlits database %s', | |||
database_path) | |||
raise ValueError( | |||
f'Could not find HHBlits database {database_path}') | |||
self.n_cpu = n_cpu | |||
self.n_iter = n_iter | |||
self.e_value = e_value | |||
self.maxseq = maxseq | |||
self.realign_max = realign_max | |||
self.maxfilt = maxfilt | |||
self.min_prefilter_hits = min_prefilter_hits | |||
self.all_seqs = all_seqs | |||
self.alt = alt | |||
self.p = p | |||
self.z = z | |||
def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]: | |||
"""Queries the database using HHblits.""" | |||
with utils.tmpdir_manager() as query_tmp_dir: | |||
a3m_path = os.path.join(query_tmp_dir, 'output.a3m') | |||
db_cmd = [] | |||
for db_path in self.databases: | |||
db_cmd.append('-d') | |||
db_cmd.append(db_path) | |||
cmd = [ | |||
self.binary_path, | |||
'-i', | |||
input_fasta_path, | |||
'-cpu', | |||
str(self.n_cpu), | |||
'-oa3m', | |||
a3m_path, | |||
'-o', | |||
'/dev/null', | |||
'-n', | |||
str(self.n_iter), | |||
'-e', | |||
str(self.e_value), | |||
'-maxseq', | |||
str(self.maxseq), | |||
'-realign_max', | |||
str(self.realign_max), | |||
'-maxfilt', | |||
str(self.maxfilt), | |||
'-min_prefilter_hits', | |||
str(self.min_prefilter_hits), | |||
] | |||
if self.all_seqs: | |||
cmd += ['-all'] | |||
if self.alt: | |||
cmd += ['-alt', str(self.alt)] | |||
if self.p != _HHBLITS_DEFAULT_P: | |||
cmd += ['-p', str(self.p)] | |||
if self.z != _HHBLITS_DEFAULT_Z: | |||
cmd += ['-Z', str(self.z)] | |||
cmd += db_cmd | |||
logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||
process = subprocess.Popen( | |||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with utils.timing('HHblits query'): | |||
stdout, stderr = process.communicate() | |||
retcode = process.wait() | |||
if retcode: | |||
# Logs have a 15k character limit, so log HHblits error line by line. | |||
logging.error('HHblits failed. HHblits stderr begin:') | |||
for error_line in stderr.decode('utf-8').splitlines(): | |||
if error_line.strip(): | |||
logging.error(error_line.strip()) | |||
logging.error('HHblits stderr end') | |||
raise RuntimeError( | |||
'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % | |||
(stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) | |||
with open(a3m_path) as f: | |||
a3m = f.read() | |||
raw_output = dict( | |||
a3m=a3m, | |||
output=stdout, | |||
stderr=stderr, | |||
n_iter=self.n_iter, | |||
e_value=self.e_value, | |||
) | |||
return [raw_output] |
@@ -0,0 +1,111 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Library to run HHsearch from Python.""" | |||
import glob | |||
import os | |||
import subprocess | |||
from typing import Sequence | |||
from absl import logging | |||
from modelscope.models.science.unifold.msa import parsers | |||
from . import utils | |||
class HHSearch: | |||
"""Python wrapper of the HHsearch binary.""" | |||
def __init__(self, | |||
*, | |||
binary_path: str, | |||
databases: Sequence[str], | |||
maxseq: int = 1_000_000): | |||
"""Initializes the Python HHsearch wrapper. | |||
Args: | |||
binary_path: The path to the HHsearch executable. | |||
databases: A sequence of HHsearch database paths. This should be the | |||
common prefix for the database files (i.e. up to but not including | |||
_hhm.ffindex etc.) | |||
maxseq: The maximum number of rows in an input alignment. Note that this | |||
parameter is only supported in HHBlits version 3.1 and higher. | |||
Raises: | |||
RuntimeError: If HHsearch binary not found within the path. | |||
""" | |||
self.binary_path = binary_path | |||
self.databases = databases | |||
self.maxseq = maxseq | |||
for database_path in self.databases: | |||
if not glob.glob(database_path + '_*'): | |||
logging.error('Could not find HHsearch database %s', | |||
database_path) | |||
raise ValueError( | |||
f'Could not find HHsearch database {database_path}') | |||
@property | |||
def output_format(self) -> str: | |||
return 'hhr' | |||
@property | |||
def input_format(self) -> str: | |||
return 'a3m' | |||
def query(self, a3m: str) -> str: | |||
"""Queries the database using HHsearch using a given a3m.""" | |||
with utils.tmpdir_manager() as query_tmp_dir: | |||
input_path = os.path.join(query_tmp_dir, 'query.a3m') | |||
hhr_path = os.path.join(query_tmp_dir, 'output.hhr') | |||
with open(input_path, 'w') as f: | |||
f.write(a3m) | |||
db_cmd = [] | |||
for db_path in self.databases: | |||
db_cmd.append('-d') | |||
db_cmd.append(db_path) | |||
cmd = [ | |||
self.binary_path, | |||
'-i', | |||
input_path, | |||
'-o', | |||
hhr_path, | |||
'-maxseq', | |||
str(self.maxseq), | |||
] + db_cmd | |||
logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||
process = subprocess.Popen( | |||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with utils.timing('HHsearch query'): | |||
stdout, stderr = process.communicate() | |||
retcode = process.wait() | |||
if retcode: | |||
# Stderr is truncated to prevent proto size errors in Beam. | |||
raise RuntimeError( | |||
'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % | |||
(stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) | |||
with open(hhr_path) as f: | |||
hhr = f.read() | |||
return hhr | |||
def get_template_hits( | |||
self, output_string: str, | |||
input_sequence: str) -> Sequence[parsers.TemplateHit]: | |||
"""Gets parsed template hits from the raw string output by the tool.""" | |||
del input_sequence # Used by hmmseach but not needed for hhsearch. | |||
return parsers.parse_hhr(output_string) |
@@ -0,0 +1,143 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" | |||
import os | |||
import re | |||
import subprocess | |||
from absl import logging | |||
from . import utils | |||
class Hmmbuild(object): | |||
"""Python wrapper of the hmmbuild binary.""" | |||
def __init__(self, *, binary_path: str, singlemx: bool = False): | |||
"""Initializes the Python hmmbuild wrapper. | |||
Args: | |||
binary_path: The path to the hmmbuild executable. | |||
singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to | |||
just use a common substitution score matrix. | |||
Raises: | |||
RuntimeError: If hmmbuild binary not found within the path. | |||
""" | |||
self.binary_path = binary_path | |||
self.singlemx = singlemx | |||
def build_profile_from_sto(self, | |||
sto: str, | |||
model_construction='fast') -> str: | |||
"""Builds a HHM for the aligned sequences given as an A3M string. | |||
Args: | |||
sto: A string with the aligned sequences in the Stockholm format. | |||
model_construction: Whether to use reference annotation in the msa to | |||
determine consensus columns ('hand') or default ('fast'). | |||
Returns: | |||
A string with the profile in the HMM format. | |||
Raises: | |||
RuntimeError: If hmmbuild fails. | |||
""" | |||
return self._build_profile(sto, model_construction=model_construction) | |||
def build_profile_from_a3m(self, a3m: str) -> str: | |||
"""Builds a HHM for the aligned sequences given as an A3M string. | |||
Args: | |||
a3m: A string with the aligned sequences in the A3M format. | |||
Returns: | |||
A string with the profile in the HMM format. | |||
Raises: | |||
RuntimeError: If hmmbuild fails. | |||
""" | |||
lines = [] | |||
for line in a3m.splitlines(): | |||
if not line.startswith('>'): | |||
line = re.sub('[a-z]+', '', line) # Remove inserted residues. | |||
lines.append(line + '\n') | |||
msa = ''.join(lines) | |||
return self._build_profile(msa, model_construction='fast') | |||
def _build_profile(self, | |||
msa: str, | |||
model_construction: str = 'fast') -> str: | |||
"""Builds a HMM for the aligned sequences given as an MSA string. | |||
Args: | |||
msa: A string with the aligned sequences, in A3M or STO format. | |||
model_construction: Whether to use reference annotation in the msa to | |||
determine consensus columns ('hand') or default ('fast'). | |||
Returns: | |||
A string with the profile in the HMM format. | |||
Raises: | |||
RuntimeError: If hmmbuild fails. | |||
ValueError: If unspecified arguments are provided. | |||
""" | |||
if model_construction not in {'hand', 'fast'}: | |||
raise ValueError( | |||
f'Invalid model_construction {model_construction} - only' | |||
'hand and fast supported.') | |||
with utils.tmpdir_manager() as query_tmp_dir: | |||
input_query = os.path.join(query_tmp_dir, 'query.msa') | |||
output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') | |||
with open(input_query, 'w') as f: | |||
f.write(msa) | |||
cmd = [self.binary_path] | |||
# If adding flags, we have to do so before the output and input: | |||
if model_construction == 'hand': | |||
cmd.append(f'--{model_construction}') | |||
if self.singlemx: | |||
cmd.append('--singlemx') | |||
cmd.extend([ | |||
'--amino', | |||
output_hmm_path, | |||
input_query, | |||
]) | |||
logging.info('Launching subprocess %s', cmd) | |||
process = subprocess.Popen( | |||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with utils.timing('hmmbuild query'): | |||
stdout, stderr = process.communicate() | |||
retcode = process.wait() | |||
logging.info( | |||
'hmmbuild stdout:\n%s\n\nstderr:\n%s\n', | |||
stdout.decode('utf-8'), | |||
stderr.decode('utf-8'), | |||
) | |||
if retcode: | |||
raise RuntimeError( | |||
'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' % | |||
(stdout.decode('utf-8'), stderr.decode('utf-8'))) | |||
with open(output_hmm_path, encoding='utf-8') as f: | |||
hmm = f.read() | |||
return hmm |
@@ -0,0 +1,146 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""A Python wrapper for hmmsearch - search profile against a sequence db.""" | |||
import os | |||
import subprocess | |||
from typing import Optional, Sequence | |||
from absl import logging | |||
from modelscope.models.science.unifold.msa import parsers | |||
from . import hmmbuild, utils | |||
class Hmmsearch(object): | |||
"""Python wrapper of the hmmsearch binary.""" | |||
def __init__( | |||
self, | |||
*, | |||
binary_path: str, | |||
hmmbuild_binary_path: str, | |||
database_path: str, | |||
flags: Optional[Sequence[str]] = None, | |||
): | |||
"""Initializes the Python hmmsearch wrapper. | |||
Args: | |||
binary_path: The path to the hmmsearch executable. | |||
hmmbuild_binary_path: The path to the hmmbuild executable. Used to build | |||
an hmm from an input a3m. | |||
database_path: The path to the hmmsearch database (FASTA format). | |||
flags: List of flags to be used by hmmsearch. | |||
Raises: | |||
RuntimeError: If hmmsearch binary not found within the path. | |||
""" | |||
self.binary_path = binary_path | |||
self.hmmbuild_runner = hmmbuild.Hmmbuild( | |||
binary_path=hmmbuild_binary_path) | |||
self.database_path = database_path | |||
if flags is None: | |||
# Default hmmsearch run settings. | |||
flags = [ | |||
'--F1', | |||
'0.1', | |||
'--F2', | |||
'0.1', | |||
'--F3', | |||
'0.1', | |||
'--incE', | |||
'100', | |||
'-E', | |||
'100', | |||
'--domE', | |||
'100', | |||
'--incdomE', | |||
'100', | |||
] | |||
self.flags = flags | |||
if not os.path.exists(self.database_path): | |||
logging.error('Could not find hmmsearch database %s', | |||
database_path) | |||
raise ValueError( | |||
f'Could not find hmmsearch database {database_path}') | |||
@property | |||
def output_format(self) -> str: | |||
return 'sto' | |||
@property | |||
def input_format(self) -> str: | |||
return 'sto' | |||
def query(self, msa_sto: str) -> str: | |||
"""Queries the database using hmmsearch using a given stockholm msa.""" | |||
hmm = self.hmmbuild_runner.build_profile_from_sto( | |||
msa_sto, model_construction='hand') | |||
return self.query_with_hmm(hmm) | |||
def query_with_hmm(self, hmm: str) -> str: | |||
"""Queries the database using hmmsearch using a given hmm.""" | |||
with utils.tmpdir_manager() as query_tmp_dir: | |||
hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') | |||
out_path = os.path.join(query_tmp_dir, 'output.sto') | |||
with open(hmm_input_path, 'w') as f: | |||
f.write(hmm) | |||
cmd = [ | |||
self.binary_path, | |||
'--noali', # Don't include the alignment in stdout. | |||
'--cpu', | |||
'8', | |||
] | |||
# If adding flags, we have to do so before the output and input: | |||
if self.flags: | |||
cmd.extend(self.flags) | |||
cmd.extend([ | |||
'-A', | |||
out_path, | |||
hmm_input_path, | |||
self.database_path, | |||
]) | |||
logging.info('Launching sub-process %s', cmd) | |||
process = subprocess.Popen( | |||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with utils.timing( | |||
f'hmmsearch ({os.path.basename(self.database_path)}) query' | |||
): | |||
stdout, stderr = process.communicate() | |||
retcode = process.wait() | |||
if retcode: | |||
raise RuntimeError( | |||
'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % | |||
(stdout.decode('utf-8'), stderr.decode('utf-8'))) | |||
with open(out_path) as f: | |||
out_msa = f.read() | |||
return out_msa | |||
def get_template_hits( | |||
self, output_string: str, | |||
input_sequence: str) -> Sequence[parsers.TemplateHit]: | |||
"""Gets parsed template hits from the raw string output by the tool.""" | |||
a3m_string = parsers.convert_stockholm_to_a3m( | |||
output_string, remove_first_row_gaps=False) | |||
template_hits = parsers.parse_hmmsearch_a3m( | |||
query_sequence=input_sequence, | |||
a3m_string=a3m_string, | |||
skip_first=False) | |||
return template_hits |
@@ -0,0 +1,224 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Library to run Jackhmmer from Python.""" | |||
import glob | |||
import os | |||
import subprocess | |||
from concurrent import futures | |||
from typing import Any, Callable, Mapping, Optional, Sequence | |||
from urllib import request | |||
from absl import logging | |||
from . import utils | |||
class Jackhmmer: | |||
"""Python wrapper of the Jackhmmer binary.""" | |||
def __init__( | |||
self, | |||
*, | |||
binary_path: str, | |||
database_path: str, | |||
n_cpu: int = 8, | |||
n_iter: int = 1, | |||
e_value: float = 0.0001, | |||
z_value: Optional[int] = None, | |||
get_tblout: bool = False, | |||
filter_f1: float = 0.0005, | |||
filter_f2: float = 0.00005, | |||
filter_f3: float = 0.0000005, | |||
incdom_e: Optional[float] = None, | |||
dom_e: Optional[float] = None, | |||
num_streamed_chunks: Optional[int] = None, | |||
streaming_callback: Optional[Callable[[int], None]] = None, | |||
): | |||
"""Initializes the Python Jackhmmer wrapper. | |||
Args: | |||
binary_path: The path to the jackhmmer executable. | |||
database_path: The path to the jackhmmer database (FASTA format). | |||
n_cpu: The number of CPUs to give Jackhmmer. | |||
n_iter: The number of Jackhmmer iterations. | |||
e_value: The E-value, see Jackhmmer docs for more details. | |||
z_value: The Z-value, see Jackhmmer docs for more details. | |||
get_tblout: Whether to save tblout string. | |||
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. | |||
filter_f2: Viterbi pre-filter, set to >1.0 to turn off. | |||
filter_f3: Forward pre-filter, set to >1.0 to turn off. | |||
incdom_e: Domain e-value criteria for inclusion of domains in MSA/next | |||
round. | |||
dom_e: Domain e-value criteria for inclusion in tblout. | |||
num_streamed_chunks: Number of database chunks to stream over. | |||
streaming_callback: Callback function run after each chunk iteration with | |||
the iteration number as argument. | |||
""" | |||
self.binary_path = binary_path | |||
self.database_path = database_path | |||
self.num_streamed_chunks = num_streamed_chunks | |||
if not os.path.exists( | |||
self.database_path) and num_streamed_chunks is None: | |||
logging.error('Could not find Jackhmmer database %s', | |||
database_path) | |||
raise ValueError( | |||
f'Could not find Jackhmmer database {database_path}') | |||
self.n_cpu = n_cpu | |||
self.n_iter = n_iter | |||
self.e_value = e_value | |||
self.z_value = z_value | |||
self.filter_f1 = filter_f1 | |||
self.filter_f2 = filter_f2 | |||
self.filter_f3 = filter_f3 | |||
self.incdom_e = incdom_e | |||
self.dom_e = dom_e | |||
self.get_tblout = get_tblout | |||
self.streaming_callback = streaming_callback | |||
def _query_chunk(self, input_fasta_path: str, | |||
database_path: str) -> Mapping[str, Any]: | |||
"""Queries the database chunk using Jackhmmer.""" | |||
with utils.tmpdir_manager() as query_tmp_dir: | |||
sto_path = os.path.join(query_tmp_dir, 'output.sto') | |||
# The F1/F2/F3 are the expected proportion to pass each of the filtering | |||
# stages (which get progressively more expensive), reducing these | |||
# speeds up the pipeline at the expensive of sensitivity. They are | |||
# currently set very low to make querying Mgnify run in a reasonable | |||
# amount of time. | |||
cmd_flags = [ | |||
# Don't pollute stdout with Jackhmmer output. | |||
'-o', | |||
'/dev/null', | |||
'-A', | |||
sto_path, | |||
'--noali', | |||
'--F1', | |||
str(self.filter_f1), | |||
'--F2', | |||
str(self.filter_f2), | |||
'--F3', | |||
str(self.filter_f3), | |||
'--incE', | |||
str(self.e_value), | |||
# Report only sequences with E-values <= x in per-sequence output. | |||
'-E', | |||
str(self.e_value), | |||
'--cpu', | |||
str(self.n_cpu), | |||
'-N', | |||
str(self.n_iter), | |||
] | |||
if self.get_tblout: | |||
tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') | |||
cmd_flags.extend(['--tblout', tblout_path]) | |||
if self.z_value: | |||
cmd_flags.extend(['-Z', str(self.z_value)]) | |||
if self.dom_e is not None: | |||
cmd_flags.extend(['--domE', str(self.dom_e)]) | |||
if self.incdom_e is not None: | |||
cmd_flags.extend(['--incdomE', str(self.incdom_e)]) | |||
cmd = [self.binary_path | |||
] + cmd_flags + [input_fasta_path, database_path] | |||
logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||
process = subprocess.Popen( | |||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with utils.timing( | |||
f'Jackhmmer ({os.path.basename(database_path)}) query'): | |||
_, stderr = process.communicate() | |||
retcode = process.wait() | |||
if retcode: | |||
raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' | |||
% stderr.decode('utf-8')) | |||
# Get e-values for each target name | |||
tbl = '' | |||
if self.get_tblout: | |||
with open(tblout_path) as f: | |||
tbl = f.read() | |||
with open(sto_path) as f: | |||
sto = f.read() | |||
raw_output = dict( | |||
sto=sto, | |||
tbl=tbl, | |||
stderr=stderr, | |||
n_iter=self.n_iter, | |||
e_value=self.e_value) | |||
return raw_output | |||
def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: | |||
"""Queries the database using Jackhmmer.""" | |||
if self.num_streamed_chunks is None: | |||
return [self._query_chunk(input_fasta_path, self.database_path)] | |||
db_basename = os.path.basename(self.database_path) | |||
def db_remote_chunk(db_idx): | |||
return f'{self.database_path}.{db_idx}' | |||
def db_local_chunk(db_idx): | |||
return f'/tmp/ramdisk/{db_basename}.{db_idx}' | |||
# db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' | |||
# db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' | |||
# Remove existing files to prevent OOM | |||
for f in glob.glob(db_local_chunk('[0-9]*')): | |||
try: | |||
os.remove(f) | |||
except OSError: | |||
print(f'OSError while deleting {f}') | |||
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk | |||
with futures.ThreadPoolExecutor(max_workers=2) as executor: | |||
chunked_output = [] | |||
for i in range(1, self.num_streamed_chunks + 1): | |||
# Copy the chunk locally | |||
if i == 1: | |||
future = executor.submit(request.urlretrieve, | |||
db_remote_chunk(i), | |||
db_local_chunk(i)) | |||
if i < self.num_streamed_chunks: | |||
next_future = executor.submit( | |||
request.urlretrieve, | |||
db_remote_chunk(i + 1), | |||
db_local_chunk(i + 1), | |||
) | |||
# Run Jackhmmer with the chunk | |||
future.result() | |||
chunked_output.append( | |||
self._query_chunk(input_fasta_path, db_local_chunk(i))) | |||
# Remove the local copy of the chunk | |||
os.remove(db_local_chunk(i)) | |||
# Do not set next_future for the last chunk so that this works even for | |||
# databases with only 1 chunk. | |||
if i < self.num_streamed_chunks: | |||
future = next_future | |||
if self.streaming_callback: | |||
self.streaming_callback(i) | |||
return chunked_output |
@@ -0,0 +1,110 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""A Python wrapper for Kalign.""" | |||
import os | |||
import subprocess | |||
from typing import Sequence | |||
from absl import logging | |||
from . import utils | |||
def _to_a3m(sequences: Sequence[str]) -> str: | |||
"""Converts sequences to an a3m file.""" | |||
names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] | |||
a3m = [] | |||
for sequence, name in zip(sequences, names): | |||
a3m.append('>' + name + '\n') | |||
a3m.append(sequence + '\n') | |||
return ''.join(a3m) | |||
class Kalign: | |||
"""Python wrapper of the Kalign binary.""" | |||
def __init__(self, *, binary_path: str): | |||
"""Initializes the Python Kalign wrapper. | |||
Args: | |||
binary_path: The path to the Kalign binary. | |||
Raises: | |||
RuntimeError: If Kalign binary not found within the path. | |||
""" | |||
self.binary_path = binary_path | |||
def align(self, sequences: Sequence[str]) -> str: | |||
"""Aligns the sequences and returns the alignment in A3M string. | |||
Args: | |||
sequences: A list of query sequence strings. The sequences have to be at | |||
least 6 residues long (Kalign requires this). Note that the order in | |||
which you give the sequences might alter the output slightly as | |||
different alignment tree might get constructed. | |||
Returns: | |||
A string with the alignment in a3m format. | |||
Raises: | |||
RuntimeError: If Kalign fails. | |||
ValueError: If any of the sequences is less than 6 residues long. | |||
""" | |||
logging.info('Aligning %d sequences', len(sequences)) | |||
for s in sequences: | |||
if len(s) < 6: | |||
raise ValueError( | |||
'Kalign requires all sequences to be at least 6 ' | |||
'residues long. Got %s (%d residues).' % (s, len(s))) | |||
with utils.tmpdir_manager() as query_tmp_dir: | |||
input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') | |||
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') | |||
with open(input_fasta_path, 'w') as f: | |||
f.write(_to_a3m(sequences)) | |||
cmd = [ | |||
self.binary_path, | |||
'-i', | |||
input_fasta_path, | |||
'-o', | |||
output_a3m_path, | |||
'-format', | |||
'fasta', | |||
] | |||
logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||
process = subprocess.Popen( | |||
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||
with utils.timing('Kalign query'): | |||
stdout, stderr = process.communicate() | |||
retcode = process.wait() | |||
logging.info( | |||
'Kalign stdout:\n%s\n\nstderr:\n%s\n', | |||
stdout.decode('utf-8'), | |||
stderr.decode('utf-8'), | |||
) | |||
if retcode: | |||
raise RuntimeError( | |||
'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % | |||
(stdout.decode('utf-8'), stderr.decode('utf-8'))) | |||
with open(output_a3m_path) as f: | |||
a3m = f.read() | |||
return a3m |
@@ -0,0 +1,40 @@ | |||
# Copyright 2021 DeepMind Technologies Limited | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
"""Common utilities for data pipeline tools.""" | |||
import contextlib | |||
import shutil | |||
import tempfile | |||
import time | |||
from typing import Optional | |||
from absl import logging | |||
@contextlib.contextmanager | |||
def tmpdir_manager(base_dir: Optional[str] = None): | |||
"""Context manager that deletes a temporary directory on exit.""" | |||
tmpdir = tempfile.mkdtemp(dir=base_dir) | |||
try: | |||
yield tmpdir | |||
finally: | |||
shutil.rmtree(tmpdir, ignore_errors=True) | |||
@contextlib.contextmanager | |||
def timing(msg: str): | |||
logging.info('Started %s', msg) | |||
tic = time.time() | |||
yield | |||
toc = time.time() | |||
logging.info('Finished %s in %.3f seconds', msg, toc - tic) |
@@ -0,0 +1,89 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import os | |||
from typing import Mapping, Sequence | |||
import json | |||
from absl import logging | |||
from modelscope.models.science.unifold.data import protein | |||
def get_chain_id_map( | |||
sequences: Sequence[str], | |||
descriptions: Sequence[str], | |||
): | |||
""" | |||
Makes a mapping from PDB-format chain ID to sequence and description, | |||
and parses the order of multi-chains | |||
""" | |||
unique_seqs = [] | |||
for seq in sequences: | |||
if seq not in unique_seqs: | |||
unique_seqs.append(seq) | |||
chain_id_map = { | |||
chain_id: { | |||
'descriptions': [], | |||
'sequence': seq | |||
} | |||
for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs) | |||
} | |||
chain_order = [] | |||
for seq, des in zip(sequences, descriptions): | |||
chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)] | |||
chain_id_map[chain_id]['descriptions'].append(des) | |||
chain_order.append(chain_id) | |||
return chain_id_map, chain_order | |||
def divide_multi_chains( | |||
fasta_name: str, | |||
output_dir_base: str, | |||
sequences: Sequence[str], | |||
descriptions: Sequence[str], | |||
): | |||
""" | |||
Divides the multi-chains fasta into several single fasta files and | |||
records multi-chains mapping information. | |||
""" | |||
if len(sequences) != len(descriptions): | |||
raise ValueError('sequences and descriptions must have equal length. ' | |||
f'Got {len(sequences)} != {len(descriptions)}.') | |||
if len(sequences) > protein.PDB_MAX_CHAINS: | |||
raise ValueError( | |||
'Cannot process more chains than the PDB format supports. ' | |||
f'Got {len(sequences)} chains.') | |||
chain_id_map, chain_order = get_chain_id_map(sequences, descriptions) | |||
output_dir = os.path.join(output_dir_base, fasta_name) | |||
if not os.path.exists(output_dir): | |||
os.makedirs(output_dir) | |||
chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json') | |||
with open(chain_id_map_path, 'w') as f: | |||
json.dump(chain_id_map, f, indent=4, sort_keys=True) | |||
chain_order_path = os.path.join(output_dir, 'chains.txt') | |||
with open(chain_order_path, 'w') as f: | |||
f.write(' '.join(chain_order)) | |||
logging.info('Mapping multi-chains fasta with chain order: %s', | |||
' '.join(chain_order)) | |||
temp_names = [] | |||
temp_paths = [] | |||
for chain_id in chain_id_map.keys(): | |||
temp_name = fasta_name + '_{}'.format(chain_id) | |||
temp_path = os.path.join(output_dir, temp_name + '.fasta') | |||
des = 'chain_{}'.format(chain_id) | |||
seq = chain_id_map[chain_id]['sequence'] | |||
with open(temp_path, 'w') as f: | |||
f.write('>' + des + '\n' + seq) | |||
temp_names.append(temp_name) | |||
temp_paths.append(temp_path) | |||
return temp_names, temp_paths |
@@ -0,0 +1,22 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .protein_structure_pipeline import ProteinStructurePipeline | |||
else: | |||
_import_structure = { | |||
'protein_structure_pipeline': ['ProteinStructurePipeline'] | |||
} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,215 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import os | |||
import time | |||
from typing import Any, Dict, List, Optional, Union | |||
import json | |||
import numpy as np | |||
import torch | |||
from unicore.utils import tensor_tree_map | |||
from modelscope.metainfo import Pipelines | |||
from modelscope.models.base import Model | |||
from modelscope.models.science.unifold.config import model_config | |||
from modelscope.models.science.unifold.data import protein, residue_constants | |||
from modelscope.models.science.unifold.dataset import (UnifoldDataset, | |||
load_and_process) | |||
from modelscope.outputs import OutputKeys | |||
from modelscope.pipelines.base import Pipeline, Tensor | |||
from modelscope.pipelines.builder import PIPELINES | |||
from modelscope.preprocessors import Preprocessor, build_preprocessor | |||
from modelscope.utils.constant import Fields, Frameworks, Tasks | |||
from modelscope.utils.device import device_placement | |||
from modelscope.utils.hub import read_config | |||
from modelscope.utils.logger import get_logger | |||
logger = get_logger() | |||
__all__ = ['ProteinStructurePipeline'] | |||
def automatic_chunk_size(seq_len): | |||
if seq_len < 512: | |||
chunk_size = 256 | |||
elif seq_len < 1024: | |||
chunk_size = 128 | |||
elif seq_len < 2048: | |||
chunk_size = 32 | |||
elif seq_len < 3072: | |||
chunk_size = 16 | |||
else: | |||
chunk_size = 1 | |||
return chunk_size | |||
def load_feature_for_one_target( | |||
config, | |||
data_folder, | |||
seed=0, | |||
is_multimer=False, | |||
use_uniprot=False, | |||
symmetry_group=None, | |||
): | |||
if not is_multimer: | |||
uniprot_msa_dir = None | |||
sequence_ids = ['A'] | |||
if use_uniprot: | |||
uniprot_msa_dir = data_folder | |||
else: | |||
uniprot_msa_dir = data_folder | |||
sequence_ids = open(os.path.join(data_folder, | |||
'chains.txt')).readline().split() | |||
if symmetry_group is None: | |||
batch, _ = load_and_process( | |||
config=config.data, | |||
mode='predict', | |||
seed=seed, | |||
batch_idx=None, | |||
data_idx=0, | |||
is_distillation=False, | |||
sequence_ids=sequence_ids, | |||
monomer_feature_dir=data_folder, | |||
uniprot_msa_dir=uniprot_msa_dir, | |||
) | |||
else: | |||
raise NotImplementedError | |||
batch = UnifoldDataset.collater([batch]) | |||
return batch | |||
@PIPELINES.register_module( | |||
Tasks.protein_structure, module_name=Pipelines.protein_structure) | |||
class ProteinStructurePipeline(Pipeline): | |||
def __init__(self, | |||
model: Union[Model, str], | |||
preprocessor: Optional[Preprocessor] = None, | |||
**kwargs): | |||
"""Use `model` and `preprocessor` to create a protein structure pipeline for prediction. | |||
Args: | |||
model (str or Model): Supply either a local model dir which supported the protein structure task, | |||
or a model id from the model hub, or a torch model instance. | |||
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||
the model if supplied. | |||
Example: | |||
>>> from modelscope.pipelines import pipeline | |||
>>> pipeline_ins = pipeline(task='protein-structure', | |||
>>> model='DPTech/uni-fold-monomer') | |||
>>> protein = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' | |||
>>> print(pipeline_ins(protein)) | |||
""" | |||
import copy | |||
model_path = copy.deepcopy(model) if isinstance(model, str) else None | |||
cfg = read_config(model_path) # only model is str | |||
self.cfg = cfg | |||
self.config = model_config( | |||
cfg['pipeline']['model_name']) # alphafold config | |||
model = model if isinstance( | |||
model, Model) else Model.from_pretrained(model_path) | |||
self.postprocessor = cfg.pop('postprocessor', None) | |||
if preprocessor is None: | |||
preprocessor_cfg = cfg.preprocessor | |||
preprocessor = build_preprocessor(preprocessor_cfg, Fields.science) | |||
model.eval() | |||
model.model.inference_mode() | |||
model.model_dir = model_path | |||
super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
def _sanitize_parameters(self, **pipeline_parameters): | |||
return pipeline_parameters, pipeline_parameters, pipeline_parameters | |||
def _process_single(self, input, *args, **kwargs) -> Dict[str, Any]: | |||
preprocess_params = kwargs.get('preprocess_params', {}) | |||
forward_params = kwargs.get('forward_params', {}) | |||
postprocess_params = kwargs.get('postprocess_params', {}) | |||
out = self.preprocess(input, **preprocess_params) | |||
with device_placement(self.framework, self.device_name): | |||
with torch.no_grad(): | |||
out = self.forward(out, **forward_params) | |||
out = self.postprocess(out, **postprocess_params) | |||
return out | |||
def forward(self, inputs: Dict[str, Any], | |||
**forward_params) -> Dict[str, Any]: | |||
plddts = {} | |||
ptms = {} | |||
output_dir = os.path.join(self.preprocessor.output_dir_base, | |||
inputs['target_id']) | |||
pdbs = [] | |||
for seed in range(self.cfg['pipeline']['times']): | |||
cur_seed = hash((42, seed)) % 100000 | |||
batch = load_feature_for_one_target( | |||
self.config, | |||
output_dir, | |||
cur_seed, | |||
is_multimer=inputs['is_multimer'], | |||
use_uniprot=inputs['is_multimer'], | |||
symmetry_group=self.preprocessor.symmetry_group, | |||
) | |||
seq_len = batch['aatype'].shape[-1] | |||
self.model.model.globals.chunk_size = automatic_chunk_size(seq_len) | |||
with torch.no_grad(): | |||
batch = { | |||
k: torch.as_tensor(v, device='cuda:0') | |||
for k, v in batch.items() | |||
} | |||
out = self.model(batch) | |||
def to_float(x): | |||
if x.dtype == torch.bfloat16 or x.dtype == torch.half: | |||
return x.float() | |||
else: | |||
return x | |||
# Toss out the recycling dimensions --- we don't need them anymore | |||
batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch) | |||
batch = tensor_tree_map(to_float, batch) | |||
out = tensor_tree_map(lambda t: t[0, ...], out[0]) | |||
out = tensor_tree_map(to_float, out) | |||
batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch) | |||
out = tensor_tree_map(lambda x: np.array(x.cpu()), out) | |||
plddt = out['plddt'] | |||
mean_plddt = np.mean(plddt) | |||
plddt_b_factors = np.repeat( | |||
plddt[..., None], residue_constants.atom_type_num, axis=-1) | |||
# TODO: , may need to reorder chains, based on entity_ids | |||
cur_protein = protein.from_prediction( | |||
features=batch, result=out, b_factors=plddt_b_factors) | |||
cur_save_name = (f'{cur_seed}') | |||
plddts[cur_save_name] = str(mean_plddt) | |||
if inputs[ | |||
'is_multimer'] and self.preprocessor.symmetry_group is None: | |||
ptms[cur_save_name] = str(np.mean(out['iptm+ptm'])) | |||
with open(os.path.join(output_dir, cur_save_name + '.pdb'), | |||
'w') as f: | |||
f.write(protein.to_pdb(cur_protein)) | |||
pdbs.append(protein.to_pdb(cur_protein)) | |||
logger.info('plddts:' + str(plddts)) | |||
model_name = self.cfg['pipeline']['model_name'] | |||
score_name = f'{model_name}' | |||
plddt_fname = score_name + '_plddt.json' | |||
with open(os.path.join(output_dir, plddt_fname), 'w') as f: | |||
json.dump(plddts, f, indent=4) | |||
if ptms: | |||
logger.info('ptms' + str(ptms)) | |||
ptm_fname = score_name + '_ptm.json' | |||
with open(os.path.join(output_dir, ptm_fname), 'w') as f: | |||
json.dump(ptms, f, indent=4) | |||
return pdbs | |||
def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params): | |||
return inputs |
@@ -0,0 +1,20 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
from typing import TYPE_CHECKING | |||
from modelscope.utils.import_utils import LazyImportModule | |||
if TYPE_CHECKING: | |||
from .unifold import (UniFoldPreprocessor) | |||
else: | |||
_import_structure = {'unifold': ['UniFoldPreprocessor']} | |||
import sys | |||
sys.modules[__name__] = LazyImportModule( | |||
__name__, | |||
globals()['__file__'], | |||
_import_structure, | |||
module_spec=__spec__, | |||
extra_objects={}, | |||
) |
@@ -0,0 +1,569 @@ | |||
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||
# and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||
import gzip | |||
import hashlib | |||
import logging | |||
import os | |||
import pickle | |||
import random | |||
import re | |||
import tarfile | |||
import time | |||
from pathlib import Path | |||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union | |||
from unittest import result | |||
import json | |||
import numpy as np | |||
import requests | |||
import torch | |||
from tqdm import tqdm | |||
from modelscope.metainfo import Preprocessors | |||
from modelscope.models.science.unifold.data import protein, residue_constants | |||
from modelscope.models.science.unifold.data.protein import PDB_CHAIN_IDS | |||
from modelscope.models.science.unifold.data.utils import compress_features | |||
from modelscope.models.science.unifold.msa import parsers, pipeline, templates | |||
from modelscope.models.science.unifold.msa.tools import hhsearch | |||
from modelscope.models.science.unifold.msa.utils import divide_multi_chains | |||
from modelscope.preprocessors.base import Preprocessor | |||
from modelscope.preprocessors.builder import PREPROCESSORS | |||
from modelscope.utils.constant import Fields | |||
__all__ = [ | |||
'UniFoldPreprocessor', | |||
] | |||
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]' | |||
DEFAULT_API_SERVER = 'https://api.colabfold.com' | |||
def run_mmseqs2( | |||
x, | |||
prefix, | |||
use_env=True, | |||
use_templates=False, | |||
use_pairing=False, | |||
host_url='https://api.colabfold.com') -> Tuple[List[str], List[str]]: | |||
submission_endpoint = 'ticket/pair' if use_pairing else 'ticket/msa' | |||
def submit(seqs, mode, N=101): | |||
n, query = N, '' | |||
for seq in seqs: | |||
query += f'>{n}\n{seq}\n' | |||
n += 1 | |||
res = requests.post( | |||
f'{host_url}/{submission_endpoint}', | |||
data={ | |||
'q': query, | |||
'mode': mode | |||
}) | |||
try: | |||
out = res.json() | |||
except ValueError: | |||
out = {'status': 'ERROR'} | |||
return out | |||
def status(ID): | |||
res = requests.get(f'{host_url}/ticket/{ID}') | |||
try: | |||
out = res.json() | |||
except ValueError: | |||
out = {'status': 'ERROR'} | |||
return out | |||
def download(ID, path): | |||
res = requests.get(f'{host_url}/result/download/{ID}') | |||
with open(path, 'wb') as out: | |||
out.write(res.content) | |||
# process input x | |||
seqs = [x] if isinstance(x, str) else x | |||
mode = 'env' | |||
if use_pairing: | |||
mode = '' | |||
use_templates = False | |||
use_env = False | |||
# define path | |||
path = f'{prefix}' | |||
if not os.path.isdir(path): | |||
os.mkdir(path) | |||
# call mmseqs2 api | |||
tar_gz_file = f'{path}/out_{mode}.tar.gz' | |||
N, REDO = 101, True | |||
# deduplicate and keep track of order | |||
seqs_unique = [] | |||
# TODO this might be slow for large sets | |||
[seqs_unique.append(x) for x in seqs if x not in seqs_unique] | |||
Ms = [N + seqs_unique.index(seq) for seq in seqs] | |||
# lets do it! | |||
if not os.path.isfile(tar_gz_file): | |||
TIME_ESTIMATE = 150 * len(seqs_unique) | |||
with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar: | |||
while REDO: | |||
pbar.set_description('SUBMIT') | |||
# Resubmit job until it goes through | |||
out = submit(seqs_unique, mode, N) | |||
while out['status'] in ['UNKNOWN', 'RATELIMIT']: | |||
sleep_time = 5 + random.randint(0, 5) | |||
# logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}") | |||
# resubmit | |||
time.sleep(sleep_time) | |||
out = submit(seqs_unique, mode, N) | |||
if out['status'] == 'ERROR': | |||
error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' | |||
error = error + 'If error persists, please try again an hour later.' | |||
raise Exception(error) | |||
if out['status'] == 'MAINTENANCE': | |||
raise Exception( | |||
'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.' | |||
) | |||
# wait for job to finish | |||
ID, TIME = out['id'], 0 | |||
pbar.set_description(out['status']) | |||
while out['status'] in ['UNKNOWN', 'RUNNING', 'PENDING']: | |||
t = 5 + random.randint(0, 5) | |||
# logger.error(f"Sleeping for {t}s. Reason: {out['status']}") | |||
time.sleep(t) | |||
out = status(ID) | |||
pbar.set_description(out['status']) | |||
if out['status'] == 'RUNNING': | |||
TIME += t | |||
pbar.update(n=t) | |||
if out['status'] == 'COMPLETE': | |||
if TIME < TIME_ESTIMATE: | |||
pbar.update(n=(TIME_ESTIMATE - TIME)) | |||
REDO = False | |||
if out['status'] == 'ERROR': | |||
REDO = False | |||
error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' | |||
error = error + 'If error persists, please try again an hour later.' | |||
raise Exception(error) | |||
# Download results | |||
download(ID, tar_gz_file) | |||
# prep list of a3m files | |||
if use_pairing: | |||
a3m_files = [f'{path}/pair.a3m'] | |||
else: | |||
a3m_files = [f'{path}/uniref.a3m'] | |||
if use_env: | |||
a3m_files.append(f'{path}/bfd.mgnify30.metaeuk30.smag30.a3m') | |||
# extract a3m files | |||
if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files): | |||
with tarfile.open(tar_gz_file) as tar_gz: | |||
tar_gz.extractall(path) | |||
# templates | |||
if use_templates: | |||
templates = {} | |||
with open(f'{path}/pdb70.m8', 'r') as f: | |||
lines = f.readlines() | |||
for line in lines: | |||
p = line.rstrip().split() | |||
M, pdb, _, _ = p[0], p[1], p[2], p[10] # qid, e_value | |||
M = int(M) | |||
if M not in templates: | |||
templates[M] = [] | |||
templates[M].append(pdb) | |||
template_paths = {} | |||
for k, TMPL in templates.items(): | |||
TMPL_PATH = f'{prefix}/templates_{k}' | |||
if not os.path.isdir(TMPL_PATH): | |||
os.mkdir(TMPL_PATH) | |||
TMPL_LINE = ','.join(TMPL[:20]) | |||
os.system( | |||
f'curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/' | |||
) | |||
os.system( | |||
f'cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex' | |||
) | |||
os.system(f'touch {TMPL_PATH}/pdb70_cs219.ffdata') | |||
template_paths[k] = TMPL_PATH | |||
# gather a3m lines | |||
a3m_lines = {} | |||
for a3m_file in a3m_files: | |||
update_M, M = True, None | |||
with open(a3m_file, 'r') as f: | |||
lines = f.readlines() | |||
for line in lines: | |||
if len(line) > 0: | |||
if '\x00' in line: | |||
line = line.replace('\x00', '') | |||
update_M = True | |||
if line.startswith('>') and update_M: | |||
M = int(line[1:].rstrip()) | |||
update_M = False | |||
if M not in a3m_lines: | |||
a3m_lines[M] = [] | |||
a3m_lines[M].append(line) | |||
# return results | |||
a3m_lines = [''.join(a3m_lines[n]) for n in Ms] | |||
if use_templates: | |||
template_paths_ = [] | |||
for n in Ms: | |||
if n not in template_paths: | |||
template_paths_.append(None) | |||
# print(f"{n-N}\tno_templates_found") | |||
else: | |||
template_paths_.append(template_paths[n]) | |||
template_paths = template_paths_ | |||
return (a3m_lines, template_paths) if use_templates else a3m_lines | |||
def get_null_template(query_sequence: Union[List[str], str], | |||
num_temp: int = 1) -> Dict[str, Any]: | |||
ln = ( | |||
len(query_sequence) if isinstance(query_sequence, str) else sum( | |||
len(s) for s in query_sequence)) | |||
output_templates_sequence = 'A' * ln | |||
# output_confidence_scores = np.full(ln, 1.0) | |||
templates_all_atom_positions = np.zeros( | |||
(ln, templates.residue_constants.atom_type_num, 3)) | |||
templates_all_atom_masks = np.zeros( | |||
(ln, templates.residue_constants.atom_type_num)) | |||
templates_aatype = templates.residue_constants.sequence_to_onehot( | |||
output_templates_sequence, | |||
templates.residue_constants.HHBLITS_AA_TO_ID) | |||
template_features = { | |||
'template_all_atom_positions': | |||
np.tile(templates_all_atom_positions[None], [num_temp, 1, 1, 1]), | |||
'template_all_atom_masks': | |||
np.tile(templates_all_atom_masks[None], [num_temp, 1, 1]), | |||
'template_sequence': ['none'.encode()] * num_temp, | |||
'template_aatype': | |||
np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]), | |||
'template_domain_names': ['none'.encode()] * num_temp, | |||
'template_sum_probs': | |||
np.zeros([num_temp], dtype=np.float32), | |||
} | |||
return template_features | |||
def get_template(a3m_lines: str, template_path: str, | |||
query_sequence: str) -> Dict[str, Any]: | |||
template_featurizer = templates.HhsearchHitFeaturizer( | |||
mmcif_dir=template_path, | |||
max_template_date='2100-01-01', | |||
max_hits=20, | |||
kalign_binary_path='kalign', | |||
release_dates_path=None, | |||
obsolete_pdbs_path=None, | |||
) | |||
hhsearch_pdb70_runner = hhsearch.HHSearch( | |||
binary_path='hhsearch', databases=[f'{template_path}/pdb70']) | |||
hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines) | |||
hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result) | |||
templates_result = template_featurizer.get_templates( | |||
query_sequence=query_sequence, hits=hhsearch_hits) | |||
return dict(templates_result.features) | |||
@PREPROCESSORS.register_module( | |||
Fields.science, module_name=Preprocessors.unifold_preprocessor) | |||
class UniFoldPreprocessor(Preprocessor): | |||
def __init__(self, **cfg): | |||
self.symmetry_group = cfg['symmetry_group'] # "C1" | |||
if not self.symmetry_group: | |||
self.symmetry_group = None | |||
self.MIN_SINGLE_SEQUENCE_LENGTH = 16 # TODO: change to cfg | |||
self.MAX_SINGLE_SEQUENCE_LENGTH = 1000 | |||
self.MAX_MULTIMER_LENGTH = 1000 | |||
self.jobname = 'unifold' | |||
self.output_dir_base = './unifold-predictions' | |||
os.makedirs(self.output_dir_base, exist_ok=True) | |||
def clean_and_validate_sequence(self, input_sequence: str, min_length: int, | |||
max_length: int) -> str: | |||
clean_sequence = input_sequence.translate( | |||
str.maketrans('', '', ' \n\t')).upper() | |||
aatypes = set(residue_constants.restypes) # 20 standard aatypes. | |||
if not set(clean_sequence).issubset(aatypes): | |||
raise ValueError( | |||
f'Input sequence contains non-amino acid letters: ' | |||
f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard ' | |||
'amino acids as inputs.') | |||
if len(clean_sequence) < min_length: | |||
raise ValueError( | |||
f'Input sequence is too short: {len(clean_sequence)} amino acids, ' | |||
f'while the minimum is {min_length}') | |||
if len(clean_sequence) > max_length: | |||
raise ValueError( | |||
f'Input sequence is too long: {len(clean_sequence)} amino acids, while ' | |||
f'the maximum is {max_length}. You may be able to run it with the full ' | |||
f'Uni-Fold system depending on your resources (system memory, ' | |||
f'GPU memory).') | |||
return clean_sequence | |||
def validate_input(self, input_sequences: Sequence[str], | |||
symmetry_group: str, min_length: int, max_length: int, | |||
max_multimer_length: int) -> Tuple[Sequence[str], bool]: | |||
"""Validates and cleans input sequences and determines which model to use.""" | |||
sequences = [] | |||
for input_sequence in input_sequences: | |||
if input_sequence.strip(): | |||
input_sequence = self.clean_and_validate_sequence( | |||
input_sequence=input_sequence, | |||
min_length=min_length, | |||
max_length=max_length) | |||
sequences.append(input_sequence) | |||
if symmetry_group is not None and symmetry_group != 'C1': | |||
if symmetry_group.startswith( | |||
'C') and symmetry_group[1:].isnumeric(): | |||
print( | |||
f'Using UF-Symmetry with group {symmetry_group}. If you do not ' | |||
f'want to use UF-Symmetry, please use `C1` and copy the AU ' | |||
f'sequences to the count in the assembly.') | |||
is_multimer = (len(sequences) > 1) | |||
return sequences, is_multimer, symmetry_group | |||
else: | |||
raise ValueError( | |||
f'UF-Symmetry does not support symmetry group ' | |||
f'{symmetry_group} currently. Cyclic groups (Cx) are ' | |||
f'supported only.') | |||
elif len(sequences) == 1: | |||
print('Using the single-chain model.') | |||
return sequences, False, None | |||
elif len(sequences) > 1: | |||
total_multimer_length = sum([len(seq) for seq in sequences]) | |||
if total_multimer_length > max_multimer_length: | |||
raise ValueError( | |||
f'The total length of multimer sequences is too long: ' | |||
f'{total_multimer_length}, while the maximum is ' | |||
f'{max_multimer_length}. Please use the full AlphaFold ' | |||
f'system for long multimers.') | |||
print(f'Using the multimer model with {len(sequences)} sequences.') | |||
return sequences, True, None | |||
else: | |||
raise ValueError( | |||
'No input amino acid sequence provided, please provide at ' | |||
'least one sequence.') | |||
def add_hash(self, x, y): | |||
return x + '_' + hashlib.sha1(y.encode()).hexdigest()[:5] | |||
def get_msa_and_templates( | |||
self, | |||
jobname: str, | |||
query_seqs_unique: Union[str, List[str]], | |||
result_dir: Path, | |||
msa_mode: str, | |||
use_templates: bool, | |||
homooligomers_num: int = 1, | |||
host_url: str = DEFAULT_API_SERVER, | |||
) -> Tuple[Optional[List[str]], Optional[List[str]], List[str], List[int], | |||
List[Dict[str, Any]]]: | |||
use_env = msa_mode == 'MMseqs2' | |||
template_features = [] | |||
if use_templates: | |||
a3m_lines_mmseqs2, template_paths = run_mmseqs2( | |||
query_seqs_unique, | |||
str(result_dir.joinpath(jobname)), | |||
use_env, | |||
use_templates=True, | |||
host_url=host_url, | |||
) | |||
if template_paths is None: | |||
for index in range(0, len(query_seqs_unique)): | |||
template_feature = get_null_template( | |||
query_seqs_unique[index]) | |||
template_features.append(template_feature) | |||
else: | |||
for index in range(0, len(query_seqs_unique)): | |||
if template_paths[index] is not None: | |||
template_feature = get_template( | |||
a3m_lines_mmseqs2[index], | |||
template_paths[index], | |||
query_seqs_unique[index], | |||
) | |||
if len(template_feature['template_domain_names']) == 0: | |||
template_feature = get_null_template( | |||
query_seqs_unique[index]) | |||
else: | |||
template_feature = get_null_template( | |||
query_seqs_unique[index]) | |||
template_features.append(template_feature) | |||
else: | |||
for index in range(0, len(query_seqs_unique)): | |||
template_feature = get_null_template(query_seqs_unique[index]) | |||
template_features.append(template_feature) | |||
if msa_mode == 'single_sequence': | |||
a3m_lines = [] | |||
num = 101 | |||
for i, seq in enumerate(query_seqs_unique): | |||
a3m_lines.append('>' + str(num + i) + '\n' + seq) | |||
else: | |||
# find normal a3ms | |||
a3m_lines = run_mmseqs2( | |||
query_seqs_unique, | |||
str(result_dir.joinpath(jobname)), | |||
use_env, | |||
use_pairing=False, | |||
host_url=host_url, | |||
) | |||
if len(query_seqs_unique) > 1: | |||
# find paired a3m if not a homooligomers | |||
paired_a3m_lines = run_mmseqs2( | |||
query_seqs_unique, | |||
str(result_dir.joinpath(jobname)), | |||
use_env, | |||
use_pairing=True, | |||
host_url=host_url, | |||
) | |||
else: | |||
num = 101 | |||
paired_a3m_lines = [] | |||
for i in range(0, homooligomers_num): | |||
paired_a3m_lines.append('>' + str(num + i) + '\n' | |||
+ query_seqs_unique[0] + '\n') | |||
return ( | |||
a3m_lines, | |||
paired_a3m_lines, | |||
template_features, | |||
) | |||
def __call__(self, data: Union[str, Tuple]): | |||
if isinstance(data, str): | |||
data = [data, '', '', ''] | |||
basejobname = ''.join(data) | |||
basejobname = re.sub(r'\W+', '', basejobname) | |||
target_id = self.add_hash(self.jobname, basejobname) | |||
sequences, is_multimer, _ = self.validate_input( | |||
input_sequences=data, | |||
symmetry_group=self.symmetry_group, | |||
min_length=self.MIN_SINGLE_SEQUENCE_LENGTH, | |||
max_length=self.MAX_SINGLE_SEQUENCE_LENGTH, | |||
max_multimer_length=self.MAX_MULTIMER_LENGTH) | |||
descriptions = [ | |||
'> ' + target_id + ' seq' + str(ii) | |||
for ii in range(len(sequences)) | |||
] | |||
if is_multimer: | |||
divide_multi_chains(target_id, self.output_dir_base, sequences, | |||
descriptions) | |||
s = [] | |||
for des, seq in zip(descriptions, sequences): | |||
s += [des, seq] | |||
unique_sequences = [] | |||
[ | |||
unique_sequences.append(x) for x in sequences | |||
if x not in unique_sequences | |||
] | |||
if len(unique_sequences) == 1: | |||
homooligomers_num = len(sequences) | |||
else: | |||
homooligomers_num = 1 | |||
with open(f'{self.jobname}.fasta', 'w') as f: | |||
f.write('\n'.join(s)) | |||
result_dir = Path(self.output_dir_base) | |||
output_dir = os.path.join(self.output_dir_base, target_id) | |||
# msa_mode = 'single_sequence' | |||
msa_mode = 'MMseqs2' | |||
use_templates = True | |||
unpaired_msa, paired_msa, template_results = self.get_msa_and_templates( | |||
target_id, | |||
unique_sequences, | |||
result_dir=result_dir, | |||
msa_mode=msa_mode, | |||
use_templates=use_templates, | |||
homooligomers_num=homooligomers_num) | |||
features = [] | |||
pair_features = [] | |||
for idx, seq in enumerate(unique_sequences): | |||
chain_id = PDB_CHAIN_IDS[idx] | |||
sequence_features = pipeline.make_sequence_features( | |||
sequence=seq, | |||
description=f'> {self.jobname} seq {chain_id}', | |||
num_res=len(seq)) | |||
monomer_msa = parsers.parse_a3m(unpaired_msa[idx]) | |||
msa_features = pipeline.make_msa_features([monomer_msa]) | |||
template_features = template_results[idx] | |||
feature_dict = { | |||
**sequence_features, | |||
**msa_features, | |||
**template_features | |||
} | |||
feature_dict = compress_features(feature_dict) | |||
features_output_path = os.path.join( | |||
output_dir, '{}.feature.pkl.gz'.format(chain_id)) | |||
pickle.dump( | |||
feature_dict, | |||
gzip.GzipFile(features_output_path, 'wb'), | |||
protocol=4) | |||
features.append(feature_dict) | |||
if is_multimer: | |||
multimer_msa = parsers.parse_a3m(paired_msa[idx]) | |||
pair_features = pipeline.make_msa_features([multimer_msa]) | |||
pair_feature_dict = compress_features(pair_features) | |||
uniprot_output_path = os.path.join( | |||
output_dir, '{}.uniprot.pkl.gz'.format(chain_id)) | |||
pickle.dump( | |||
pair_feature_dict, | |||
gzip.GzipFile(uniprot_output_path, 'wb'), | |||
protocol=4, | |||
) | |||
pair_features.append(pair_feature_dict) | |||
# return features, pair_features, target_id | |||
return { | |||
'features': features, | |||
'pair_features': pair_features, | |||
'target_id': target_id, | |||
'is_multimer': is_multimer, | |||
} | |||
if __name__ == '__main__': | |||
proc = UniFoldPreprocessor() | |||
protein_example = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + \ | |||
'TVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI' | |||
features, pair_features = proc.__call__(protein_example) | |||
import ipdb | |||
ipdb.set_trace() |
@@ -9,6 +9,7 @@ class Fields(object): | |||
nlp = 'nlp' | |||
audio = 'audio' | |||
multi_modal = 'multi-modal' | |||
science = 'science' | |||
class CVTasks(object): | |||
@@ -151,6 +152,10 @@ class MultiModalTasks(object): | |||
image_text_retrieval = 'image-text-retrieval' | |||
class ScienceTasks(object): | |||
protein_structure = 'protein-structure' | |||
class TasksIODescriptions(object): | |||
image_to_image = 'image_to_image', | |||
images_to_image = 'images_to_image', | |||
@@ -167,7 +172,7 @@ class TasksIODescriptions(object): | |||
generative_multi_modal_embedding = 'generative_multi_modal_embedding' | |||
class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | |||
class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks): | |||
""" Names for tasks supported by modelscope. | |||
Holds the standard task name to use for identifying different tasks. | |||
@@ -196,6 +201,10 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | |||
getattr(Tasks, attr) for attr in dir(MultiModalTasks) | |||
if not attr.startswith('__') | |||
], | |||
Fields.science: [ | |||
getattr(Tasks, attr) for attr in dir(ScienceTasks) | |||
if not attr.startswith('__') | |||
], | |||
} | |||
for field, tasks in field_dict.items(): | |||
@@ -0,0 +1,6 @@ | |||
iopath | |||
lmdb | |||
ml_collections | |||
scipy | |||
tensorboardX | |||
tokenizers |
@@ -0,0 +1,34 @@ | |||
# Copyright (c) Alibaba, Inc. and its affiliates. | |||
import unittest | |||
from modelscope.hub.snapshot_download import snapshot_download | |||
from modelscope.pipelines import pipeline | |||
from modelscope.utils.constant import Tasks | |||
from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
from modelscope.utils.test_utils import test_level | |||
class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): | |||
def setUp(self) -> None: | |||
self.task = Tasks.protein_structure | |||
self.model_id = 'DPTech/uni-fold-monomer' | |||
self.model_id_multimer = 'DPTech/uni-fold-multimer' | |||
self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD' | |||
self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ | |||
'NIAALKNHIDKIKPIAMQIYKKYSKNIP' | |||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||
def test_run_by_direct_model_download(self): | |||
model_dir = snapshot_download(self.model_id) | |||
mono_pipeline_ins = pipeline(task=self.task, model=model_dir) | |||
_ = mono_pipeline_ins(self.protein) | |||
model_dir1 = snapshot_download(self.model_id_multimer) | |||
multi_pipeline_ins = pipeline(task=self.task, model=model_dir1) | |||
_ = multi_pipeline_ins(self.protein_multimer) | |||
if __name__ == '__main__': | |||
unittest.main() |