diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 16190eb8..c5067c39 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -99,6 +99,10 @@ class Models(object): team = 'team-multi-modal-similarity' video_clip = 'video-clip-multi-modal-embedding' + # science models + unifold = 'unifold' + unifold_symmetry = 'unifold-symmetry' + class TaskModels(object): # nlp task @@ -266,6 +270,9 @@ class Pipelines(object): image_text_retrieval = 'image-text-retrieval' ofa_ocr_recognition = 'ofa-ocr-recognition' + # science tasks + protein_structure = 'unifold-protein-structure' + class Trainers(object): """ Names for different trainer. @@ -368,6 +375,9 @@ class Preprocessors(object): ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' + # science preprocessor + unifold_preprocessor = 'unifold-preprocessor' + class Metrics(object): """ Names for different metrics. diff --git a/modelscope/models/science/__init__.py b/modelscope/models/science/__init__.py new file mode 100644 index 00000000..50ab55d7 --- /dev/null +++ b/modelscope/models/science/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + + from .unifold import UnifoldForProteinStructrue + +else: + _import_structure = {'unifold': ['UnifoldForProteinStructrue']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/science/unifold/__init__.py b/modelscope/models/science/unifold/__init__.py new file mode 100644 index 00000000..75435fed --- /dev/null +++ b/modelscope/models/science/unifold/__init__.py @@ -0,0 +1 @@ +from .model import UnifoldForProteinStructrue diff --git a/modelscope/models/science/unifold/config.py b/modelscope/models/science/unifold/config.py new file mode 100644 index 00000000..e760fbf9 --- /dev/null +++ b/modelscope/models/science/unifold/config.py @@ -0,0 +1,636 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
+ +import copy +from typing import Any + +import ml_collections as mlc + +N_RES = 'number of residues' +N_MSA = 'number of MSA sequences' +N_EXTRA_MSA = 'number of extra MSA sequences' +N_TPL = 'number of templates' + +d_pair = mlc.FieldReference(128, field_type=int) +d_msa = mlc.FieldReference(256, field_type=int) +d_template = mlc.FieldReference(64, field_type=int) +d_extra_msa = mlc.FieldReference(64, field_type=int) +d_single = mlc.FieldReference(384, field_type=int) +max_recycling_iters = mlc.FieldReference(3, field_type=int) +chunk_size = mlc.FieldReference(4, field_type=int) +aux_distogram_bins = mlc.FieldReference(64, field_type=int) +eps = mlc.FieldReference(1e-8, field_type=float) +inf = mlc.FieldReference(3e4, field_type=float) +use_templates = mlc.FieldReference(True, field_type=bool) +is_multimer = mlc.FieldReference(False, field_type=bool) + + +def base_config(): + return mlc.ConfigDict({ + 'data': { + 'common': { + 'features': { + 'aatype': [N_RES], + 'all_atom_mask': [N_RES, None], + 'all_atom_positions': [N_RES, None, None], + 'alt_chi_angles': [N_RES, None], + 'atom14_alt_gt_exists': [N_RES, None], + 'atom14_alt_gt_positions': [N_RES, None, None], + 'atom14_atom_exists': [N_RES, None], + 'atom14_atom_is_ambiguous': [N_RES, None], + 'atom14_gt_exists': [N_RES, None], + 'atom14_gt_positions': [N_RES, None, None], + 'atom37_atom_exists': [N_RES, None], + 'frame_mask': [N_RES], + 'true_frame_tensor': [N_RES, None, None], + 'bert_mask': [N_MSA, N_RES], + 'chi_angles_sin_cos': [N_RES, None, None], + 'chi_mask': [N_RES, None], + 'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES], + 'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES], + 'extra_msa': [N_EXTRA_MSA, N_RES], + 'extra_msa_mask': [N_EXTRA_MSA, N_RES], + 'extra_msa_row_mask': [N_EXTRA_MSA], + 'is_distillation': [], + 'msa_feat': [N_MSA, N_RES, None], + 'msa_mask': [N_MSA, N_RES], + 'msa_chains': [N_MSA, None], + 'msa_row_mask': [N_MSA], + 'num_recycling_iters': [], + 'pseudo_beta': [N_RES, None], + 'pseudo_beta_mask': [N_RES], + 'residue_index': [N_RES], + 'residx_atom14_to_atom37': [N_RES, None], + 'residx_atom37_to_atom14': [N_RES, None], + 'resolution': [], + 'rigidgroups_alt_gt_frames': [N_RES, None, None, None], + 'rigidgroups_group_exists': [N_RES, None], + 'rigidgroups_group_is_ambiguous': [N_RES, None], + 'rigidgroups_gt_exists': [N_RES, None], + 'rigidgroups_gt_frames': [N_RES, None, None, None], + 'seq_length': [], + 'seq_mask': [N_RES], + 'target_feat': [N_RES, None], + 'template_aatype': [N_TPL, N_RES], + 'template_all_atom_mask': [N_TPL, N_RES, None], + 'template_all_atom_positions': [N_TPL, N_RES, None, None], + 'template_alt_torsion_angles_sin_cos': [ + N_TPL, + N_RES, + None, + None, + ], + 'template_frame_mask': [N_TPL, N_RES], + 'template_frame_tensor': [N_TPL, N_RES, None, None], + 'template_mask': [N_TPL], + 'template_pseudo_beta': [N_TPL, N_RES, None], + 'template_pseudo_beta_mask': [N_TPL, N_RES], + 'template_sum_probs': [N_TPL, None], + 'template_torsion_angles_mask': [N_TPL, N_RES, None], + 'template_torsion_angles_sin_cos': + [N_TPL, N_RES, None, None], + 'true_msa': [N_MSA, N_RES], + 'use_clamped_fape': [], + 'assembly_num_chains': [1], + 'asym_id': [N_RES], + 'sym_id': [N_RES], + 'entity_id': [N_RES], + 'num_sym': [N_RES], + 'asym_len': [None], + 'cluster_bias_mask': [N_MSA], + }, + 'masked_msa': { + 'profile_prob': 0.1, + 'same_prob': 0.1, + 'uniform_prob': 0.1, + }, + 'block_delete_msa': { + 'msa_fraction_per_block': 0.3, + 'randomize_num_blocks': False, + 'num_blocks': 5, + 'min_num_msa': 16, + }, 
+ 'random_delete_msa': { + 'max_msa_entry': 1 << 25, # := 33554432 + }, + 'v2_feature': + False, + 'gumbel_sample': + False, + 'max_extra_msa': + 1024, + 'msa_cluster_features': + True, + 'reduce_msa_clusters_by_max_templates': + True, + 'resample_msa_in_recycling': + True, + 'template_features': [ + 'template_all_atom_positions', + 'template_sum_probs', + 'template_aatype', + 'template_all_atom_mask', + ], + 'unsupervised_features': [ + 'aatype', + 'residue_index', + 'msa', + 'msa_chains', + 'num_alignments', + 'seq_length', + 'between_segment_residues', + 'deletion_matrix', + 'num_recycling_iters', + 'crop_and_fix_size_seed', + ], + 'recycling_features': [ + 'msa_chains', + 'msa_mask', + 'msa_row_mask', + 'bert_mask', + 'true_msa', + 'msa_feat', + 'extra_msa_deletion_value', + 'extra_msa_has_deletion', + 'extra_msa', + 'extra_msa_mask', + 'extra_msa_row_mask', + 'is_distillation', + ], + 'multimer_features': [ + 'assembly_num_chains', + 'asym_id', + 'sym_id', + 'num_sym', + 'entity_id', + 'asym_len', + 'cluster_bias_mask', + ], + 'use_templates': + use_templates, + 'is_multimer': + is_multimer, + 'use_template_torsion_angles': + use_templates, + 'max_recycling_iters': + max_recycling_iters, + }, + 'supervised': { + 'use_clamped_fape_prob': + 1.0, + 'supervised_features': [ + 'all_atom_mask', + 'all_atom_positions', + 'resolution', + 'use_clamped_fape', + 'is_distillation', + ], + }, + 'predict': { + 'fixed_size': True, + 'subsample_templates': False, + 'block_delete_msa': False, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 2, + 'crop': False, + 'crop_size': None, + 'supervised': False, + 'biased_msa_by_chain': False, + 'share_mask': False, + }, + 'eval': { + 'fixed_size': True, + 'subsample_templates': False, + 'block_delete_msa': False, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 1, + 'crop': False, + 'crop_size': None, + 'spatial_crop_prob': 0.5, + 'ca_ca_threshold': 10.0, + 'supervised': True, + 'biased_msa_by_chain': False, + 'share_mask': False, + }, + 'train': { + 'fixed_size': True, + 'subsample_templates': True, + 'block_delete_msa': True, + 'random_delete_msa': True, + 'masked_msa_replace_fraction': 0.15, + 'max_msa_clusters': 128, + 'max_templates': 4, + 'num_ensembles': 1, + 'crop': True, + 'crop_size': 256, + 'spatial_crop_prob': 0.5, + 'ca_ca_threshold': 10.0, + 'supervised': True, + 'use_clamped_fape_prob': 1.0, + 'max_distillation_msa_clusters': 1000, + 'biased_msa_by_chain': True, + 'share_mask': True, + }, + }, + 'globals': { + 'chunk_size': chunk_size, + 'block_size': None, + 'd_pair': d_pair, + 'd_msa': d_msa, + 'd_template': d_template, + 'd_extra_msa': d_extra_msa, + 'd_single': d_single, + 'eps': eps, + 'inf': inf, + 'max_recycling_iters': max_recycling_iters, + 'alphafold_original_mode': False, + }, + 'model': { + 'is_multimer': is_multimer, + 'input_embedder': { + 'tf_dim': 22, + 'msa_dim': 49, + 'd_pair': d_pair, + 'd_msa': d_msa, + 'relpos_k': 32, + 'max_relative_chain': 2, + }, + 'recycling_embedder': { + 'd_pair': d_pair, + 'd_msa': d_msa, + 'min_bin': 3.25, + 'max_bin': 20.75, + 'num_bins': 15, + 'inf': 1e8, + }, + 'template': { + 'distogram': { + 'min_bin': 3.25, + 'max_bin': 50.75, + 'num_bins': 39, + }, + 'template_angle_embedder': { + 'd_in': 57, + 'd_out': d_msa, + }, + 'template_pair_embedder': { + 'd_in': 88, + 'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1], + 'd_pair': d_pair, + 
'd_out': d_template, + 'v2_feature': False, + }, + 'template_pair_stack': { + 'd_template': d_template, + 'd_hid_tri_att': 16, + 'd_hid_tri_mul': 64, + 'num_blocks': 2, + 'num_heads': 4, + 'pair_transition_n': 2, + 'dropout_rate': 0.25, + 'inf': 1e9, + 'tri_attn_first': True, + }, + 'template_pointwise_attention': { + 'enabled': True, + 'd_template': d_template, + 'd_pair': d_pair, + 'd_hid': 16, + 'num_heads': 4, + 'inf': 1e5, + }, + 'inf': 1e5, + 'eps': 1e-6, + 'enabled': use_templates, + 'embed_angles': use_templates, + }, + 'extra_msa': { + 'extra_msa_embedder': { + 'd_in': 25, + 'd_out': d_extra_msa, + }, + 'extra_msa_stack': { + 'd_msa': d_extra_msa, + 'd_pair': d_pair, + 'd_hid_msa_att': 8, + 'd_hid_opm': 32, + 'd_hid_mul': 128, + 'd_hid_pair_att': 32, + 'num_heads_msa': 8, + 'num_heads_pair': 4, + 'num_blocks': 4, + 'transition_n': 4, + 'msa_dropout': 0.15, + 'pair_dropout': 0.25, + 'inf': 1e9, + 'eps': 1e-10, + 'outer_product_mean_first': False, + }, + 'enabled': True, + }, + 'evoformer_stack': { + 'd_msa': d_msa, + 'd_pair': d_pair, + 'd_hid_msa_att': 32, + 'd_hid_opm': 32, + 'd_hid_mul': 128, + 'd_hid_pair_att': 32, + 'd_single': d_single, + 'num_heads_msa': 8, + 'num_heads_pair': 4, + 'num_blocks': 48, + 'transition_n': 4, + 'msa_dropout': 0.15, + 'pair_dropout': 0.25, + 'inf': 1e9, + 'eps': 1e-10, + 'outer_product_mean_first': False, + }, + 'structure_module': { + 'd_single': d_single, + 'd_pair': d_pair, + 'd_ipa': 16, + 'd_angle': 128, + 'num_heads_ipa': 12, + 'num_qk_points': 4, + 'num_v_points': 8, + 'dropout_rate': 0.1, + 'num_blocks': 8, + 'no_transition_layers': 1, + 'num_resnet_blocks': 2, + 'num_angles': 7, + 'trans_scale_factor': 10, + 'epsilon': 1e-12, + 'inf': 1e5, + 'separate_kv': False, + 'ipa_bias': True, + }, + 'heads': { + 'plddt': { + 'num_bins': 50, + 'd_in': d_single, + 'd_hid': 128, + }, + 'distogram': { + 'd_pair': d_pair, + 'num_bins': aux_distogram_bins, + 'disable_enhance_head': False, + }, + 'pae': { + 'd_pair': d_pair, + 'num_bins': aux_distogram_bins, + 'enabled': False, + 'iptm_weight': 0.8, + 'disable_enhance_head': False, + }, + 'masked_msa': { + 'd_msa': d_msa, + 'd_out': 23, + 'disable_enhance_head': False, + }, + 'experimentally_resolved': { + 'd_single': d_single, + 'd_out': 37, + 'enabled': False, + 'disable_enhance_head': False, + }, + }, + }, + 'loss': { + 'distogram': { + 'min_bin': 2.3125, + 'max_bin': 21.6875, + 'num_bins': 64, + 'eps': 1e-6, + 'weight': 0.3, + }, + 'experimentally_resolved': { + 'eps': 1e-8, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'weight': 0.0, + }, + 'fape': { + 'backbone': { + 'clamp_distance': 10.0, + 'clamp_distance_between_chains': 30.0, + 'loss_unit_distance': 10.0, + 'loss_unit_distance_between_chains': 20.0, + 'weight': 0.5, + 'eps': 1e-4, + }, + 'sidechain': { + 'clamp_distance': 10.0, + 'length_scale': 10.0, + 'weight': 0.5, + 'eps': 1e-4, + }, + 'weight': 1.0, + }, + 'plddt': { + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'cutoff': 15.0, + 'num_bins': 50, + 'eps': 1e-10, + 'weight': 0.01, + }, + 'masked_msa': { + 'eps': 1e-8, + 'weight': 2.0, + }, + 'supervised_chi': { + 'chi_weight': 0.5, + 'angle_norm_weight': 0.01, + 'eps': 1e-6, + 'weight': 1.0, + }, + 'violation': { + 'violation_tolerance_factor': 12.0, + 'clash_overlap_tolerance': 1.5, + 'bond_angle_loss_weight': 0.3, + 'eps': 1e-6, + 'weight': 0.0, + }, + 'pae': { + 'max_bin': 31, + 'num_bins': 64, + 'min_resolution': 0.1, + 'max_resolution': 3.0, + 'eps': 1e-8, + 'weight': 0.0, + }, + 'repr_norm': { + 'weight': 0.01, + 
'tolerance': 1.0, + }, + 'chain_centre_mass': { + 'weight': 0.0, + 'eps': 1e-8, + }, + }, + }) + + +def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None): + with c.unlocked(): + for k, v in c.items(): + if ignore is not None and k == ignore: + continue + if isinstance(v, mlc.ConfigDict): + recursive_set(v, key, value) + elif k == key: + c[k] = value + + +def model_config(name, train=False): + c = copy.deepcopy(base_config()) + + def model_2_v2(c): + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.template.template_angle_embedder.d_in = 34 + return c + + def multimer(c): + recursive_set(c, 'is_multimer', True) + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 128) + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.template.template_angle_embedder.d_in = 34 + c.model.template.template_pair_stack.tri_attn_first = False + c.model.template.template_pointwise_attention.enabled = False + c.model.heads.pae.enabled = True + # we forget to enable it in our training, so disable it here + c.model.heads.pae.disable_enhance_head = True + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.structure_module.trans_scale_factor = 20 + c.loss.pae.weight = 0.1 + c.model.input_embedder.tf_dim = 21 + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.chain_centre_mass.weight = 1.0 + return c + + if name == 'model_1': + pass + elif name == 'model_1_ft': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_1_af2': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + elif name == 'model_2': + pass + elif name == 'model_init': + pass + elif name == 'model_init_af2': + c.globals.alphafold_original_mode = True + pass + elif name == 'model_2_ft': + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_2_af2': + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + elif name == 'model_2_v2': + c = model_2_v2(c) + elif name == 'model_2_v2_ft': + c = model_2_v2(c) + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + elif name == 'model_3_af2' or name == 'model_4_af2': + recursive_set(c, 'max_extra_msa', 5120) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + 
c.model.template.enabled = False + c.model.template.embed_angles = False + recursive_set(c, 'use_templates', False) + recursive_set(c, 'use_template_torsion_angles', False) + elif name == 'model_5_af2': + recursive_set(c, 'max_extra_msa', 1024) + recursive_set(c, 'max_msa_clusters', 512) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.02 + c.loss.repr_norm.weight = 0 + c.model.heads.experimentally_resolved.enabled = True + c.loss.experimentally_resolved.weight = 0.01 + c.globals.alphafold_original_mode = True + c.model.template.enabled = False + c.model.template.embed_angles = False + recursive_set(c, 'use_templates', False) + recursive_set(c, 'use_template_torsion_angles', False) + elif name == 'multimer': + c = multimer(c) + elif name == 'multimer_ft': + c = multimer(c) + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 256) + c.data.train.crop_size = 384 + c.loss.violation.weight = 0.5 + elif name == 'multimer_af2': + recursive_set(c, 'max_extra_msa', 1152) + recursive_set(c, 'max_msa_clusters', 256) + recursive_set(c, 'is_multimer', True) + recursive_set(c, 'v2_feature', True) + recursive_set(c, 'gumbel_sample', True) + c.model.template.template_angle_embedder.d_in = 34 + c.model.template.template_pair_stack.tri_attn_first = False + c.model.template.template_pointwise_attention.enabled = False + c.model.heads.pae.enabled = True + c.model.heads.experimentally_resolved.enabled = True + c.model.heads.masked_msa.d_out = 22 + c.model.structure_module.separate_kv = True + c.model.structure_module.ipa_bias = False + c.model.structure_module.trans_scale_factor = 20 + c.loss.pae.weight = 0.1 + c.loss.violation.weight = 0.5 + c.loss.experimentally_resolved.weight = 0.01 + c.model.input_embedder.tf_dim = 21 + c.globals.alphafold_original_mode = True + c.data.train.crop_size = 384 + c.loss.repr_norm.weight = 0 + c.loss.chain_centre_mass.weight = 1.0 + recursive_set(c, 'outer_product_mean_first', True) + else: + raise ValueError(f'invalid --model-name: {name}.') + if train: + c.globals.chunk_size = None + recursive_set(c, 'inf', 3e4) + recursive_set(c, 'eps', 1e-5, 'loss') + return c diff --git a/modelscope/models/science/unifold/data/__init__.py b/modelscope/models/science/unifold/data/__init__.py new file mode 100644 index 00000000..9821d212 --- /dev/null +++ b/modelscope/models/science/unifold/data/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Data pipeline for model features.""" diff --git a/modelscope/models/science/unifold/data/data_ops.py b/modelscope/models/science/unifold/data/data_ops.py new file mode 100644 index 00000000..637aa0cd --- /dev/null +++ b/modelscope/models/science/unifold/data/data_ops.py @@ -0,0 +1,1397 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
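Before the feature ops below, one note on how the presets defined in config.py above are consumed: model_config(name) deep-copies the base config and then mutates it branch by branch, so a caller typically needs nothing more than the following (a minimal sketch; only model_config from the file above is assumed, and the preset name is one of the branches listed there).

from modelscope.models.science.unifold.config import model_config

cfg = model_config('multimer_ft')               # any preset name handled above
print(cfg.model.evoformer_stack.num_blocks)     # 48
print(cfg.globals.chunk_size)                   # 4; set to None when train=True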
+ +import itertools +from functools import reduce, wraps +from operator import add +from typing import List, MutableMapping, Optional + +import numpy as np +import torch +from unicore.data import data_utils +from unicore.utils import batched_gather, one_hot, tensor_tree_map, tree_map + +from modelscope.models.science.unifold.config import (N_EXTRA_MSA, N_MSA, + N_RES, N_TPL) +from modelscope.models.science.unifold.data import residue_constants as rc +from modelscope.models.science.unifold.modules.frame import Frame, Rotation + +NumpyDict = MutableMapping[str, np.ndarray] +TorchDict = MutableMapping[str, np.ndarray] + +protein: TorchDict + +MSA_FEATURE_NAMES = [ + 'msa', + 'deletion_matrix', + 'msa_mask', + 'msa_row_mask', + 'bert_mask', + 'true_msa', + 'msa_chains', +] + + +def cast_to_64bit_ints(protein): + # We keep all ints as int64 + for k, v in protein.items(): + if k.endswith('_mask'): + protein[k] = v.type(torch.float32) + elif v.dtype in (torch.int32, torch.uint8, torch.int8): + protein[k] = v.type(torch.int64) + + return protein + + +def make_seq_mask(protein): + protein['seq_mask'] = torch.ones( + protein['aatype'].shape, dtype=torch.float32) + return protein + + +def make_template_mask(protein): + protein['template_mask'] = torch.ones( + protein['template_aatype'].shape[0], dtype=torch.float32) + return protein + + +def curry1(f): + """Supply all arguments but the first.""" + + @wraps(f) + def fc(*args, **kwargs): + return lambda x: f(x, *args, **kwargs) + + return fc + + +def correct_msa_restypes(protein): + """Correct MSA restype to have the same order as rc.""" + protein['msa'] = protein['msa'].long() + new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = ( + torch.tensor(new_order_list, dtype=torch.int8).unsqueeze(-1).expand( + -1, protein['msa'].shape[1])) + protein['msa'] = torch.gather(new_order, 0, protein['msa']).long() + + return protein + + +def squeeze_features(protein): + """Remove singleton and repeated dimensions in protein features.""" + if len(protein['aatype'].shape) == 2: + protein['aatype'] = torch.argmax(protein['aatype'], dim=-1) + if 'resolution' in protein and len(protein['resolution'].shape) == 1: + # use tensor for resolution + protein['resolution'] = protein['resolution'][0] + for k in [ + 'domain_name', + 'msa', + 'num_alignments', + 'seq_length', + 'sequence', + 'superfamily', + 'deletion_matrix', + 'between_segment_residues', + 'residue_index', + 'template_all_atom_mask', + ]: + if k in protein and len(protein[k].shape): + final_dim = protein[k].shape[-1] + if isinstance(final_dim, int) and final_dim == 1: + if torch.is_tensor(protein[k]): + protein[k] = torch.squeeze(protein[k], dim=-1) + else: + protein[k] = np.squeeze(protein[k], axis=-1) + + for k in ['seq_length', 'num_alignments']: + if k in protein and len(protein[k].shape): + protein[k] = protein[k][0] + + return protein + + +@curry1 +def randomly_replace_msa_with_unknown(protein, replace_proportion): + """Replace a portion of the MSA with 'X'.""" + if replace_proportion > 0.0: + msa_mask = np.random.rand(protein['msa'].shape) < replace_proportion + x_idx = 20 + gap_idx = 21 + msa_mask = torch.logical_and(msa_mask, protein['msa'] != gap_idx) + protein['msa'] = torch.where(msa_mask, + torch.ones_like(protein['msa']) * x_idx, + protein['msa']) + aatype_mask = np.random.rand( + protein['aatype'].shape) < replace_proportion + + protein['aatype'] = torch.where( + aatype_mask, + torch.ones_like(protein['aatype']) * x_idx, + protein['aatype'], + ) + return protein + + +def 
gumbel_noise(shape): + """Generate Gumbel Noise of given Shape. + This generates samples from Gumbel(0, 1). + Args: + shape: Shape of noise to return. + Returns: + Gumbel noise of given shape. + """ + epsilon = 1e-6 + uniform_noise = torch.from_numpy(np.random.uniform(0, 1, shape)) + gumbel = -torch.log(-torch.log(uniform_noise + epsilon) + epsilon) + return gumbel + + +def gumbel_max_sample(logits): + """Samples from a probability distribution given by 'logits'. + This uses Gumbel-max trick to implement the sampling in an efficient manner. + Args: + logits: Logarithm of probabilities to sample from, probabilities can be + unnormalized. + Returns: + Sample from logprobs in one-hot form. + """ + z = gumbel_noise(logits.shape) + return torch.argmax(logits + z, dim=-1) + + +def gumbel_argsort_sample_idx(logits): + """Samples with replacement from a distribution given by 'logits'. + This uses Gumbel trick to implement the sampling an efficient manner. For a + distribution over k items this samples k times without replacement, so this + is effectively sampling a random permutation with probabilities over the + permutations derived from the logprobs. + Args: + logits: Logarithm of probabilities to sample from, probabilities can be + unnormalized. + Returns: + Sample from logprobs in index + """ + z = gumbel_noise(logits.shape) + return torch.argsort(logits + z, dim=-1, descending=True) + + +def uniform_permutation(num_seq): + shuffled = torch.from_numpy(np.random.permutation(num_seq - 1) + 1) + return torch.cat((torch.tensor([0]), shuffled), dim=0) + + +def gumbel_permutation(msa_mask, msa_chains=None): + has_msa = torch.sum(msa_mask.long(), dim=-1) > 0 + # default logits is zero + logits = torch.zeros_like(has_msa, dtype=torch.float32) + logits[~has_msa] = -1e6 + # one sample only + assert len(logits.shape) == 1 + # skip first row + logits = logits[1:] + has_msa = has_msa[1:] + if logits.shape[0] == 0: + return torch.tensor([0]) + if msa_chains is not None: + # skip first row + msa_chains = msa_chains[1:].reshape(-1) + msa_chains[~has_msa] = 0 + keys, counts = np.unique(msa_chains, return_counts=True) + num_has_msa = has_msa.sum() + num_pair = (msa_chains == 1).sum() + num_unpair = num_has_msa - num_pair + num_chains = (keys > 1).sum() + logits[has_msa] = 1.0 / (num_has_msa + 1e-6) + logits[~has_msa] = 0 + for k in keys: + if k > 1: + cur_mask = msa_chains == k + cur_cnt = cur_mask.sum() + if cur_cnt > 0: + logits[cur_mask] *= num_unpair / (num_chains * cur_cnt) + logits = torch.log(logits + 1e-6) + shuffled = gumbel_argsort_sample_idx(logits) + 1 + return torch.cat((torch.tensor([0]), shuffled), dim=0) + + +@curry1 +def sample_msa(protein, + max_seq, + keep_extra, + gumbel_sample=False, + biased_msa_by_chain=False): + """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.""" + num_seq = protein['msa'].shape[0] + num_sel = min(max_seq, num_seq) + if not gumbel_sample: + index_order = uniform_permutation(num_seq) + else: + msa_chains = ( + protein['msa_chains'] if + (biased_msa_by_chain and 'msa_chains' in protein) else None) + index_order = gumbel_permutation(protein['msa_mask'], msa_chains) + num_sel = min(max_seq, num_seq) + sel_seq, not_sel_seq = torch.split(index_order, + [num_sel, num_seq - num_sel]) + + for k in MSA_FEATURE_NAMES: + if k in protein: + if keep_extra: + protein['extra_' + k] = torch.index_select( + protein[k], 0, not_sel_seq) + protein[k] = torch.index_select(protein[k], 0, sel_seq) + + return protein + + +@curry1 +def 
sample_msa_distillation(protein, max_seq): + if 'is_distillation' in protein and protein['is_distillation'] == 1: + protein = sample_msa(max_seq, keep_extra=False)(protein) + return protein + + +@curry1 +def random_delete_msa(protein, config): + # to reduce the cost of msa features + num_seq = protein['msa'].shape[0] + seq_len = protein['msa'].shape[1] + max_seq = config.max_msa_entry // seq_len + if num_seq > max_seq: + keep_index = ( + torch.from_numpy( + np.random.choice(num_seq - 1, max_seq - 1, + replace=False)).long() + 1) + keep_index = torch.sort(keep_index)[0] + keep_index = torch.cat((torch.tensor([0]), keep_index), dim=0) + for k in MSA_FEATURE_NAMES: + if k in protein: + protein[k] = torch.index_select(protein[k], 0, keep_index) + return protein + + +@curry1 +def crop_extra_msa(protein, max_extra_msa): + num_seq = protein['extra_msa'].shape[0] + num_sel = min(max_extra_msa, num_seq) + select_indices = torch.from_numpy(np.random.permutation(num_seq)[:num_sel]) + for k in MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + protein['extra_' + k] = torch.index_select(protein['extra_' + k], + 0, select_indices) + + return protein + + +def delete_extra_msa(protein): + for k in MSA_FEATURE_NAMES: + if 'extra_' + k in protein: + del protein['extra_' + k] + return protein + + +@curry1 +def block_delete_msa(protein, config): + if 'is_distillation' in protein and protein['is_distillation'] == 1: + return protein + num_seq = protein['msa'].shape[0] + if num_seq <= config.min_num_msa: + return protein + block_num_seq = torch.floor( + torch.tensor(num_seq, dtype=torch.float32) + * config.msa_fraction_per_block).to(torch.int32) + + if config.randomize_num_blocks: + nb = np.random.randint(0, config.num_blocks + 1) + else: + nb = config.num_blocks + + del_block_starts = torch.from_numpy(np.random.randint(0, num_seq, [nb])) + del_blocks = del_block_starts[:, None] + torch.arange(0, block_num_seq) + del_blocks = torch.clip(del_blocks, 0, num_seq - 1) + del_indices = torch.unique(del_blocks.view(-1)) + # add zeros to ensure cnt_zero > 1 + combined = torch.hstack((torch.arange(0, num_seq)[None], del_indices[None], + torch.zeros(2)[None])).long() + uniques, counts = combined.unique(return_counts=True) + difference = uniques[counts == 1] + # intersection = uniques[counts > 1] + keep_indices = difference.view(-1) + keep_indices = torch.hstack( + [torch.zeros(1).long()[None], keep_indices[None]]).view(-1) + assert int(keep_indices[0]) == 0 + for k in MSA_FEATURE_NAMES: + if k in protein: + protein[k] = torch.index_select(protein[k], 0, index=keep_indices) + return protein + + +@curry1 +def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0): + weights = torch.cat( + [torch.ones(21), gap_agreement_weight * torch.ones(1), + torch.zeros(1)], + 0, + ) + + msa_one_hot = one_hot(protein['msa'], 23) + sample_one_hot = protein['msa_mask'][:, :, None] * msa_one_hot + extra_msa_one_hot = one_hot(protein['extra_msa'], 23) + extra_one_hot = protein['extra_msa_mask'][:, :, None] * extra_msa_one_hot + + num_seq, num_res, _ = sample_one_hot.shape + extra_num_seq, _, _ = extra_one_hot.shape + + # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights) + # in an optimized fashion to avoid possible memory or computation blowup. 
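+    # (The reshapes below flatten the (residue, restype) axes so the weighted einsum
+    # collapses into a single matmul: a has shape (extra_num_seq, num_res * 23), b has
+    # shape (num_res * 23, num_seq), hence agreement[i, j] is the weighted number of
+    # positions where extra sequence i and sampled sequence j agree.)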
+ a = extra_one_hot.view(extra_num_seq, num_res * 23) + b = (sample_one_hot * weights).view(num_seq, num_res * 23).transpose(0, 1) + agreement = a @ b + # Assign each sequence in the extra sequences to the closest MSA sample + protein['extra_cluster_assignment'] = torch.argmax(agreement, dim=1).long() + + return protein + + +def unsorted_segment_sum(data, segment_ids, num_segments): + assert len( + segment_ids.shape) == 1 and segment_ids.shape[0] == data.shape[0] + segment_ids = segment_ids.view(segment_ids.shape[0], + *((1, ) * len(data.shape[1:]))) + segment_ids = segment_ids.expand(data.shape) + shape = [num_segments] + list(data.shape[1:]) + tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float()) + tensor = tensor.type(data.dtype) + return tensor + + +def summarize_clusters(protein): + """Produce profile and deletion_matrix_mean within each cluster.""" + num_seq = protein['msa'].shape[0] + + def csum(x): + return unsorted_segment_sum(x, protein['extra_cluster_assignment'], + num_seq) + + mask = protein['extra_msa_mask'] + mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center + + # TODO: this line is very slow + msa_sum = csum(mask[:, :, None] * one_hot(protein['extra_msa'], 23)) + msa_sum += one_hot(protein['msa'], 23) # Original sequence + protein['cluster_profile'] = msa_sum / mask_counts[:, :, None] + del msa_sum + + del_sum = csum(mask * protein['extra_deletion_matrix']) + del_sum += protein['deletion_matrix'] # Original sequence + protein['cluster_deletion_mean'] = del_sum / mask_counts + del del_sum + + return protein + + +@curry1 +def nearest_neighbor_clusters_v2(batch, gap_agreement_weight=0.0): + """Assign each extra MSA sequence to its nearest neighbor in sampled MSA.""" + + # Determine how much weight we assign to each agreement. In theory, we could + # use a full blosum matrix here, but right now let's just down-weight gap + # agreement because it could be spurious. + # Never put weight on agreeing on BERT mask. + + weights = torch.tensor( + [1.0] * 21 + [gap_agreement_weight] + [0.0], dtype=torch.float32) + + msa_mask = batch['msa_mask'] + extra_mask = batch['extra_msa_mask'] + msa_one_hot = one_hot(batch['msa'], 23) + extra_one_hot = one_hot(batch['extra_msa'], 23) + + msa_one_hot_masked = msa_mask[:, :, None] * msa_one_hot + extra_one_hot_masked = extra_mask[:, :, None] * extra_one_hot + + t1 = weights * msa_one_hot_masked + t1 = t1.view(t1.shape[0], t1.shape[1] * t1.shape[2]) + t2 = extra_one_hot_masked.view( + extra_one_hot.shape[0], + extra_one_hot.shape[1] * extra_one_hot.shape[2]) + agreement = t1 @ t2.T + + cluster_assignment = torch.nn.functional.softmax(1e3 * agreement, dim=0) + cluster_assignment *= torch.einsum('mr, nr->mn', msa_mask, extra_mask) + + cluster_count = torch.sum(cluster_assignment, dim=-1) + cluster_count += 1.0 # We always include the sequence itself. + + msa_sum = torch.einsum('nm, mrc->nrc', cluster_assignment, + extra_one_hot_masked) + msa_sum += msa_one_hot_masked + + cluster_profile = msa_sum / cluster_count[:, None, None] + + deletion_matrix = batch['deletion_matrix'] + extra_deletion_matrix = batch['extra_deletion_matrix'] + + del_sum = torch.einsum('nm, mc->nc', cluster_assignment, + extra_mask * extra_deletion_matrix) + del_sum += deletion_matrix # Original sequence. 
+ cluster_deletion_mean = del_sum / cluster_count[:, None] + batch['cluster_profile'] = cluster_profile + batch['cluster_deletion_mean'] = cluster_deletion_mean + + return batch + + +def make_msa_mask(protein): + """Mask features are all ones, but will later be zero-padded.""" + if 'msa_mask' not in protein: + protein['msa_mask'] = torch.ones( + protein['msa'].shape, dtype=torch.float32) + protein['msa_row_mask'] = torch.ones((protein['msa'].shape[0]), + dtype=torch.float32) + return protein + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask): + """Create pseudo beta features.""" + if aatype.shape[0] > 0: + is_gly = torch.eq(aatype, rc.restype_order['G']) + ca_idx = rc.atom_order['CA'] + cb_idx = rc.atom_order['CB'] + pseudo_beta = torch.where( + torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + else: + pseudo_beta = all_atom_positions.new_zeros(*aatype.shape, 3) + if all_atom_mask is not None: + if aatype.shape[0] > 0: + pseudo_beta_mask = torch.where(is_gly, all_atom_mask[..., ca_idx], + all_atom_mask[..., cb_idx]) + else: + pseudo_beta_mask = torch.zeros_like(aatype).float() + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +@curry1 +def make_pseudo_beta(protein, prefix=''): + """Create pseudo-beta (alpha for glycine) position and mask.""" + assert prefix in ['', 'template_'] + ( + protein[prefix + 'pseudo_beta'], + protein[prefix + 'pseudo_beta_mask'], + ) = pseudo_beta_fn( + protein['template_aatype' if prefix else 'aatype'], + protein[prefix + 'all_atom_positions'], + protein['template_all_atom_mask' if prefix else 'all_atom_mask'], + ) + return protein + + +@curry1 +def add_constant_field(protein, key, value): + protein[key] = torch.tensor(value) + return protein + + +def shaped_categorical(probs, epsilon=1e-10): + ds = probs.shape + num_classes = ds[-1] + probs = torch.reshape(probs + epsilon, [-1, num_classes]) + gen = torch.Generator() + gen.manual_seed(np.random.randint(65535)) + counts = torch.multinomial(probs, 1, generator=gen) + return torch.reshape(counts, ds[:-1]) + + +def make_hhblits_profile(protein): + """Compute the HHblits MSA profile if not already present.""" + if 'hhblits_profile' in protein: + return protein + + # Compute the profile for every residue (over all MSA sequences). + msa_one_hot = one_hot(protein['msa'], 22) + + protein['hhblits_profile'] = torch.mean(msa_one_hot, dim=0) + return protein + + +def make_msa_profile(batch): + """Compute the MSA profile.""" + # Compute the profile for every residue (over all MSA sequences). 
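+    # This is a mask-weighted mean over the sequence axis: one_hot(msa) has shape
+    # (num_seq, num_res, 22), masked entries contribute zero, and the 1e-10 keeps
+    # columns with no aligned sequences from dividing by zero.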
+ oh = one_hot(batch['msa'], 22) + mask = batch['msa_mask'][:, :, None] + oh *= mask + return oh.sum(dim=0) / (mask.sum(dim=0) + 1e-10) + + +def make_hhblits_profile_v2(protein): + """Compute the HHblits MSA profile if not already present.""" + if 'hhblits_profile' in protein: + return protein + protein['hhblits_profile'] = make_msa_profile(protein) + return protein + + +def share_mask_by_entity(mask_position, protein): # new in unifold + if 'num_sym' not in protein: + return mask_position + entity_id = protein['entity_id'] + sym_id = protein['sym_id'] + num_sym = protein['num_sym'] + unique_entity_ids = entity_id.unique() + first_sym_mask = sym_id == 1 + for cur_entity_id in unique_entity_ids: + cur_entity_mask = entity_id == cur_entity_id + cur_num_sym = int(num_sym[cur_entity_mask][0]) + if cur_num_sym > 1: + cur_sym_mask = first_sym_mask & cur_entity_mask + cur_sym_bert_mask = mask_position[:, cur_sym_mask] + mask_position[:, cur_entity_mask] = cur_sym_bert_mask.repeat( + 1, cur_num_sym) + return mask_position + + +@curry1 +def make_masked_msa(protein, + config, + replace_fraction, + gumbel_sample=False, + share_mask=False): + """Create data for BERT on raw MSA.""" + # Add a random amino acid uniformly. + random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32) + + categorical_probs = ( + config.uniform_prob * random_aa + + config.profile_prob * protein['hhblits_profile'] + + config.same_prob * one_hot(protein['msa'], 22)) + + # Put all remaining probability on [MASK] which is a new column + pad_shapes = list( + reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])) + pad_shapes[1] = 1 + mask_prob = 1.0 - config.profile_prob - config.same_prob - config.uniform_prob + assert mask_prob >= 0.0 + categorical_probs = torch.nn.functional.pad( + categorical_probs, pad_shapes, value=mask_prob) + sh = protein['msa'].shape + mask_position = torch.from_numpy(np.random.rand(*sh) < replace_fraction) + mask_position &= protein['msa_mask'].bool() + + if 'bert_mask' in protein: + mask_position &= protein['bert_mask'].bool() + + if share_mask: + mask_position = share_mask_by_entity(mask_position, protein) + if gumbel_sample: + logits = torch.log(categorical_probs + 1e-6) + bert_msa = gumbel_max_sample(logits) + else: + bert_msa = shaped_categorical(categorical_probs) + bert_msa = torch.where(mask_position, bert_msa, protein['msa']) + bert_msa *= protein['msa_mask'].long() + + # Mix real and masked MSA + protein['bert_mask'] = mask_position.to(torch.float32) + protein['true_msa'] = protein['msa'] + protein['msa'] = bert_msa + + return protein + + +@curry1 +def make_fixed_size( + protein, + shape_schema, + msa_cluster_size, + extra_msa_size, + num_res=0, + num_templates=0, +): + """Guess at the MSA and sequence dimension to make fixed size.""" + + def get_pad_size(cur_size, multiplier=4): + return max(multiplier, + ((cur_size + multiplier - 1) // multiplier) * multiplier) + + if num_res is not None: + input_num_res = ( + protein['aatype'].shape[0] + if 'aatype' in protein else protein['msa_mask'].shape[1]) + if input_num_res != num_res: + num_res = get_pad_size(input_num_res, 4) + if 'extra_msa_mask' in protein: + input_extra_msa_size = protein['extra_msa_mask'].shape[0] + if input_extra_msa_size != extra_msa_size: + extra_msa_size = get_pad_size(input_extra_msa_size, 8) + pad_size_map = { + N_RES: num_res, + N_MSA: msa_cluster_size, + N_EXTRA_MSA: extra_msa_size, + N_TPL: num_templates, + } + + for k, v in protein.items(): + # Don't transfer this to the accelerator. 
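+        # ('extra_cluster_assignment' has no entry in the feature shape schema above;
+        # it is only needed transiently to build the cluster profile features.)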
+ if k == 'extra_cluster_assignment': + continue + shape = list(v.shape) + schema = shape_schema[k] + msg = 'Rank mismatch between shape and shape schema for' + assert len(shape) == len(schema), f'{msg} {k}: {shape} vs {schema}' + pad_size = [ + pad_size_map.get(s2, None) or s1 + for (s1, s2) in zip(shape, schema) + ] + + padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)] + padding.reverse() + padding = list(itertools.chain(*padding)) + if padding: + protein[k] = torch.nn.functional.pad(v, padding) + protein[k] = torch.reshape(protein[k], pad_size) + + return protein + + +def make_target_feat(protein): + """Create and concatenate MSA features.""" + protein['aatype'] = protein['aatype'].long() + + if 'between_segment_residues' in protein: + has_break = torch.clip( + protein['between_segment_residues'].to(torch.float32), 0, 1) + else: + has_break = torch.zeros_like(protein['aatype'], dtype=torch.float32) + if 'asym_len' in protein: + asym_len = protein['asym_len'] + entity_ends = torch.cumsum(asym_len, dim=-1)[:-1] + has_break[entity_ends] = 1.0 + has_break = has_break.float() + aatype_1hot = one_hot(protein['aatype'], 21) + target_feat = [ + torch.unsqueeze(has_break, dim=-1), + aatype_1hot, # Everyone gets the original sequence. + ] + protein['target_feat'] = torch.cat(target_feat, dim=-1) + return protein + + +def make_msa_feat(protein): + """Create and concatenate MSA features.""" + msa_1hot = one_hot(protein['msa'], 23) + has_deletion = torch.clip(protein['deletion_matrix'], 0.0, 1.0) + deletion_value = torch.atan( + protein['deletion_matrix'] / 3.0) * (2.0 / np.pi) + msa_feat = [ + msa_1hot, + torch.unsqueeze(has_deletion, dim=-1), + torch.unsqueeze(deletion_value, dim=-1), + ] + if 'cluster_profile' in protein: + deletion_mean_value = torch.atan( + protein['cluster_deletion_mean'] / 3.0) * (2.0 / np.pi) + msa_feat.extend([ + protein['cluster_profile'], + torch.unsqueeze(deletion_mean_value, dim=-1), + ]) + + if 'extra_deletion_matrix' in protein: + protein['extra_msa_has_deletion'] = torch.clip( + protein['extra_deletion_matrix'], 0.0, 1.0) + protein['extra_msa_deletion_value'] = torch.atan( + protein['extra_deletion_matrix'] / 3.0) * (2.0 / np.pi) + + protein['msa_feat'] = torch.cat(msa_feat, dim=-1) + return protein + + +def make_msa_feat_v2(batch): + """Create and concatenate MSA features.""" + msa_1hot = one_hot(batch['msa'], 23) + deletion_matrix = batch['deletion_matrix'] + has_deletion = torch.clip(deletion_matrix, 0.0, 1.0)[..., None] + deletion_value = (torch.atan(deletion_matrix / 3.0) * (2.0 / np.pi))[..., + None] + + deletion_mean_value = ( + torch.arctan(batch['cluster_deletion_mean'] / 3.0) * # noqa W504 + (2.0 / np.pi))[..., None] + + msa_feat = [ + msa_1hot, + has_deletion, + deletion_value, + batch['cluster_profile'], + deletion_mean_value, + ] + batch['msa_feat'] = torch.concat(msa_feat, dim=-1) + return batch + + +@curry1 +def make_extra_msa_feat(batch, num_extra_msa): + # 23 = 20 amino acids + 'X' for unknown + gap + bert mask + extra_msa = batch['extra_msa'][:num_extra_msa] + deletion_matrix = batch['extra_deletion_matrix'][:num_extra_msa] + has_deletion = torch.clip(deletion_matrix, 0.0, 1.0) + deletion_value = torch.atan(deletion_matrix / 3.0) * (2.0 / np.pi) + extra_msa_mask = batch['extra_msa_mask'][:num_extra_msa] + batch['extra_msa'] = extra_msa + batch['extra_msa_mask'] = extra_msa_mask + batch['extra_msa_has_deletion'] = has_deletion + batch['extra_msa_deletion_value'] = deletion_value + return batch + + +@curry1 +def select_feat(protein, 
feature_list): + return {k: v for k, v in protein.items() if k in feature_list} + + +def make_atom14_masks(protein): + """Construct denser atom positions (14 dimensions instead of 37).""" + + if 'atom14_atom_exists' in protein: # lazy move + return protein + + restype_atom14_to_atom37 = torch.tensor( + rc.restype_atom14_to_atom37, + dtype=torch.int64, + device=protein['aatype'].device, + ) + restype_atom37_to_atom14 = torch.tensor( + rc.restype_atom37_to_atom14, + dtype=torch.int64, + device=protein['aatype'].device, + ) + restype_atom14_mask = torch.tensor( + rc.restype_atom14_mask, + dtype=torch.float32, + device=protein['aatype'].device, + ) + restype_atom37_mask = torch.tensor( + rc.restype_atom37_mask, + dtype=torch.float32, + device=protein['aatype'].device) + + protein_aatype = protein['aatype'].long() + protein['residx_atom14_to_atom37'] = restype_atom14_to_atom37[ + protein_aatype].long() + protein['residx_atom37_to_atom14'] = restype_atom37_to_atom14[ + protein_aatype].long() + protein['atom14_atom_exists'] = restype_atom14_mask[protein_aatype] + protein['atom37_atom_exists'] = restype_atom37_mask[protein_aatype] + + return protein + + +def make_atom14_masks_np(batch): + batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray) + out = make_atom14_masks(batch) + out = tensor_tree_map(lambda t: np.array(t), out) + return out + + +def make_atom14_positions(protein): + """Constructs denser atom positions (14 dimensions instead of 37).""" + protein['aatype'] = protein['aatype'].long() + protein['all_atom_mask'] = protein['all_atom_mask'].float() + protein['all_atom_positions'] = protein['all_atom_positions'].float() + residx_atom14_mask = protein['atom14_atom_exists'] + residx_atom14_to_atom37 = protein['residx_atom14_to_atom37'] + + # Create a mask for known ground truth positions. + residx_atom14_gt_mask = residx_atom14_mask * batched_gather( + protein['all_atom_mask'], + residx_atom14_to_atom37, + dim=-1, + num_batch_dims=len(protein['all_atom_mask'].shape[:-1]), + ) + + # Gather the ground truth positions. + residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * ( + batched_gather( + protein['all_atom_positions'], + residx_atom14_to_atom37, + dim=-2, + num_batch_dims=len(protein['all_atom_positions'].shape[:-2]), + )) + + protein['atom14_atom_exists'] = residx_atom14_mask + protein['atom14_gt_exists'] = residx_atom14_gt_mask + protein['atom14_gt_positions'] = residx_atom14_gt_positions + + renaming_matrices = torch.tensor( + rc.renaming_matrices, + dtype=protein['all_atom_mask'].dtype, + device=protein['all_atom_mask'].device, + ) + + # Pick the transformation matrices for the given residue sequence + # shape (num_res, 14, 14). + renaming_transform = renaming_matrices[protein['aatype']] + + # Apply it to the ground truth positions. shape (num_res, 14, 3). + alternative_gt_positions = torch.einsum('...rac,...rab->...rbc', + residx_atom14_gt_positions, + renaming_transform) + protein['atom14_alt_gt_positions'] = alternative_gt_positions + + # Create the mask for the alternative ground truth (differs from the + # ground truth mask, if only one of the atoms in an ambiguous pair has a + # ground truth position). 
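+    # (renaming_transform holds one 14x14 matrix per residue: the identity for
+    # unambiguous residues and a permutation swapping the symmetric atom pair,
+    # e.g. ASP OD1/OD2 or GLU OE1/OE2, otherwise, so this einsum permutes the
+    # mask exactly as the ground-truth positions were permuted above.)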
+ alternative_gt_mask = torch.einsum('...ra,...rab->...rb', + residx_atom14_gt_mask, + renaming_transform) + protein['atom14_alt_gt_exists'] = alternative_gt_mask + + restype_atom14_is_ambiguous = torch.tensor( + rc.restype_atom14_is_ambiguous, + dtype=protein['all_atom_mask'].dtype, + device=protein['all_atom_mask'].device, + ) + # From this create an ambiguous_mask for the given sequence. + protein['atom14_atom_is_ambiguous'] = restype_atom14_is_ambiguous[ + protein['aatype']] + + return protein + + +def atom37_to_frames(protein, eps=1e-8): + # TODO: extract common part and put them into residue constants. + aatype = protein['aatype'] + all_atom_positions = protein['all_atom_positions'] + all_atom_mask = protein['all_atom_mask'] + + batch_dims = len(aatype.shape[:-1]) + + restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object) + restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N'] + restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O'] + + for restype, restype_letter in enumerate(rc.restypes): + resname = rc.restype_1to3[restype_letter] + for chi_idx in range(4): + if rc.chi_angles_mask[restype][chi_idx]: + names = rc.chi_angles_atoms[resname][chi_idx] + restype_rigidgroup_base_atom_names[restype, + chi_idx + 4, :] = names[1:] + + restype_rigidgroup_mask = all_atom_mask.new_zeros( + (*aatype.shape[:-1], 21, 8), ) + restype_rigidgroup_mask[..., 0] = 1 + restype_rigidgroup_mask[..., 3] = 1 + restype_rigidgroup_mask[..., :20, + 4:] = all_atom_mask.new_tensor(rc.chi_angles_mask) + + lookuptable = rc.atom_order.copy() + lookuptable[''] = 0 + lookup = np.vectorize(lambda x: lookuptable[x]) + restype_rigidgroup_base_atom37_idx = lookup( + restype_rigidgroup_base_atom_names, ) + restype_rigidgroup_base_atom37_idx = aatype.new_tensor( + restype_rigidgroup_base_atom37_idx, ) + restype_rigidgroup_base_atom37_idx = restype_rigidgroup_base_atom37_idx.view( + *((1, ) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape) + + residx_rigidgroup_base_atom37_idx = batched_gather( + restype_rigidgroup_base_atom37_idx, + aatype, + dim=-3, + num_batch_dims=batch_dims, + ) + + base_atom_pos = batched_gather( + all_atom_positions, + residx_rigidgroup_base_atom37_idx, + dim=-2, + num_batch_dims=len(all_atom_positions.shape[:-2]), + ) + + gt_frames = Frame.from_3_points( + p_neg_x_axis=base_atom_pos[..., 0, :], + origin=base_atom_pos[..., 1, :], + p_xy_plane=base_atom_pos[..., 2, :], + eps=eps, + ) + + group_exists = batched_gather( + restype_rigidgroup_mask, + aatype, + dim=-2, + num_batch_dims=batch_dims, + ) + + gt_atoms_exist = batched_gather( + all_atom_mask, + residx_rigidgroup_base_atom37_idx, + dim=-1, + num_batch_dims=len(all_atom_mask.shape[:-1]), + ) + gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists + + rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device) + rots = torch.tile(rots, (*((1, ) * batch_dims), 8, 1, 1)) + rots[..., 0, 0, 0] = -1 + rots[..., 0, 2, 2] = -1 + rots = Rotation(mat=rots) + + gt_frames = gt_frames.compose(Frame(rots, None)) + + restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( + *((1, ) * batch_dims), 21, 8) + restype_rigidgroup_rots = torch.eye( + 3, dtype=all_atom_mask.dtype, device=aatype.device) + restype_rigidgroup_rots = torch.tile( + restype_rigidgroup_rots, + (*((1, ) * batch_dims), 21, 8, 1, 1), + ) + + for resname, _ in rc.residue_atom_renaming_swaps.items(): + restype = rc.restype_order[rc.restype_3to1[resname]] + chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1) + 
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1 + restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1 + + residx_rigidgroup_is_ambiguous = batched_gather( + restype_rigidgroup_is_ambiguous, + aatype, + dim=-2, + num_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = batched_gather( + restype_rigidgroup_rots, + aatype, + dim=-4, + num_batch_dims=batch_dims, + ) + + residx_rigidgroup_ambiguity_rot = Rotation( + mat=residx_rigidgroup_ambiguity_rot) + alt_gt_frames = gt_frames.compose( + Frame(residx_rigidgroup_ambiguity_rot, None)) + + gt_frames_tensor = gt_frames.to_tensor_4x4() + alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4() + + protein['rigidgroups_gt_frames'] = gt_frames_tensor + protein['rigidgroups_gt_exists'] = gt_exists + protein['rigidgroups_group_exists'] = group_exists + protein['rigidgroups_group_is_ambiguous'] = residx_rigidgroup_is_ambiguous + protein['rigidgroups_alt_gt_frames'] = alt_gt_frames_tensor + + return protein + + +@curry1 +def atom37_to_torsion_angles( + protein, + prefix='', +): + aatype = protein[prefix + 'aatype'] + all_atom_positions = protein[prefix + 'all_atom_positions'] + all_atom_mask = protein[prefix + 'all_atom_mask'] + if aatype.shape[-1] == 0: + base_shape = aatype.shape + protein[prefix + + 'torsion_angles_sin_cos'] = all_atom_positions.new_zeros( + *base_shape, 7, 2) + protein[prefix + + 'alt_torsion_angles_sin_cos'] = all_atom_positions.new_zeros( + *base_shape, 7, 2) + protein[prefix + 'torsion_angles_mask'] = all_atom_positions.new_zeros( + *base_shape, 7) + return protein + + aatype = torch.clamp(aatype, max=20) + + pad = all_atom_positions.new_zeros( + [*all_atom_positions.shape[:-3], 1, 37, 3]) + prev_all_atom_positions = torch.cat( + [pad, all_atom_positions[..., :-1, :, :]], dim=-3) + + pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37]) + prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2) + + pre_omega_atom_pos = torch.cat( + [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]], + dim=-2, + ) + phi_atom_pos = torch.cat( + [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]], + dim=-2, + ) + psi_atom_pos = torch.cat( + [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]], + dim=-2, + ) + + pre_omega_mask = torch.prod( + prev_all_atom_mask[..., 1:3], dim=-1) * torch.prod( + all_atom_mask[..., :2], dim=-1) + phi_mask = prev_all_atom_mask[..., 2] * torch.prod( + all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + psi_mask = ( + torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) + * all_atom_mask[..., 4]) + + chi_atom_indices = torch.as_tensor( + rc.chi_atom_indices, device=aatype.device) + + atom_indices = chi_atom_indices[..., aatype, :, :] + chis_atom_pos = batched_gather(all_atom_positions, atom_indices, -2, + len(atom_indices.shape[:-2])) + + chi_angles_mask = list(rc.chi_angles_mask) + chi_angles_mask.append([0.0, 0.0, 0.0, 0.0]) + chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask) + + chis_mask = chi_angles_mask[aatype, :] + + chi_angle_atoms_mask = batched_gather( + all_atom_mask, + atom_indices, + dim=-1, + num_batch_dims=len(atom_indices.shape[:-2]), + ) + chi_angle_atoms_mask = torch.prod( + chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype) + chis_mask = chis_mask * chi_angle_atoms_mask + + torsions_atom_pos = torch.cat( + [ + pre_omega_atom_pos[..., None, :, :], + phi_atom_pos[..., None, :, :], 
+ psi_atom_pos[..., None, :, :], + chis_atom_pos, + ], + dim=-3, + ) + + torsion_angles_mask = torch.cat( + [ + pre_omega_mask[..., None], + phi_mask[..., None], + psi_mask[..., None], + chis_mask, + ], + dim=-1, + ) + + torsion_frames = Frame.from_3_points( + torsions_atom_pos[..., 1, :], + torsions_atom_pos[..., 2, :], + torsions_atom_pos[..., 0, :], + eps=1e-8, + ) + + fourth_atom_rel_pos = torsion_frames.invert().apply( + torsions_atom_pos[..., 3, :]) + + torsion_angles_sin_cos = torch.stack( + [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1) + + denom = torch.sqrt( + torch.sum( + torch.square(torsion_angles_sin_cos), + dim=-1, + dtype=torsion_angles_sin_cos.dtype, + keepdims=True, + ) + 1e-8) + torsion_angles_sin_cos = torsion_angles_sin_cos / denom + + torsion_angles_sin_cos = ( + torsion_angles_sin_cos + * all_atom_mask.new_tensor([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0], )[ + ((None, ) * len(torsion_angles_sin_cos.shape[:-2])) + + (slice(None), None)]) + + chi_is_ambiguous = torsion_angles_sin_cos.new_tensor( + rc.chi_pi_periodic, )[aatype, ...] + + mirror_torsion_angles = torch.cat( + [ + all_atom_mask.new_ones(*aatype.shape, 3), + 1.0 - 2.0 * chi_is_ambiguous, + ], + dim=-1, + ) + + alt_torsion_angles_sin_cos = ( + torsion_angles_sin_cos * mirror_torsion_angles[..., None]) + + if prefix == '': + # consistent to uni-fold. use [1, 0] placeholder + placeholder_torsions = torch.stack( + [ + torch.ones(torsion_angles_sin_cos.shape[:-1]), + torch.zeros(torsion_angles_sin_cos.shape[:-1]), + ], + dim=-1, + ) + torsion_angles_sin_cos = torsion_angles_sin_cos * torsion_angles_mask[ + ..., + None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + alt_torsion_angles_sin_cos = alt_torsion_angles_sin_cos * torsion_angles_mask[ + ..., + None] + placeholder_torsions * (1 - torsion_angles_mask[..., None]) + + protein[prefix + 'torsion_angles_sin_cos'] = torsion_angles_sin_cos + protein[prefix + 'alt_torsion_angles_sin_cos'] = alt_torsion_angles_sin_cos + protein[prefix + 'torsion_angles_mask'] = torsion_angles_mask + + return protein + + +def get_backbone_frames(protein): + protein['true_frame_tensor'] = protein['rigidgroups_gt_frames'][..., + 0, :, :] + protein['frame_mask'] = protein['rigidgroups_gt_exists'][..., 0] + + return protein + + +def get_chi_angles(protein): + dtype = protein['all_atom_mask'].dtype + protein['chi_angles_sin_cos'] = ( + protein['torsion_angles_sin_cos'][..., 3:, :]).to(dtype) + protein['chi_mask'] = protein['torsion_angles_mask'][..., 3:].to(dtype) + + return protein + + +@curry1 +def crop_templates( + protein, + max_templates, + subsample_templates=False, +): + if 'template_mask' in protein: + num_templates = protein['template_mask'].shape[-1] + else: + num_templates = 0 + + # don't sample when there are no templates + if num_templates > 0: + if subsample_templates: + # af2's sampling, min(4, uniform[0, n]) + max_templates = min(max_templates, + np.random.randint(0, num_templates + 1)) + template_idx = torch.tensor( + np.random.choice(num_templates, max_templates, replace=False), + dtype=torch.int64, + ) + else: + # use top templates + template_idx = torch.arange( + min(num_templates, max_templates), dtype=torch.int64) + for k, v in protein.items(): + if k.startswith('template'): + try: + v = v[template_idx] + except Exception as ex: + print(ex.__class__, ex) + print('num_templates', num_templates) + print(k, v.shape) + print('protein:', protein) + print( + 'protein_shape:', + { + k: v.shape + for k, v in protein.items() if 'shape' in 
dir(v) + }, + ) + protein[k] = v + + return protein + + +@curry1 +def crop_to_size_single(protein, crop_size, shape_schema, seed): + """crop to size.""" + num_res = ( + protein['aatype'].shape[0] + if 'aatype' in protein else protein['msa_mask'].shape[1]) + crop_idx = get_single_crop_idx(num_res, crop_size, seed) + protein = apply_crop_idx(protein, shape_schema, crop_idx) + return protein + + +@curry1 +def crop_to_size_multimer(protein, crop_size, shape_schema, seed, + spatial_crop_prob, ca_ca_threshold): + """crop to size.""" + with data_utils.numpy_seed(seed, key='multimer_crop'): + use_spatial_crop = np.random.rand() < spatial_crop_prob + is_distillation = 'is_distillation' in protein and protein[ + 'is_distillation'] == 1 + if is_distillation: + return crop_to_size_single( + crop_size=crop_size, shape_schema=shape_schema, seed=seed)( + protein) + elif use_spatial_crop: + crop_idx = get_spatial_crop_idx(protein, crop_size, seed, + ca_ca_threshold) + else: + crop_idx = get_contiguous_crop_idx(protein, crop_size, seed) + return apply_crop_idx(protein, shape_schema, crop_idx) + + +def get_single_crop_idx(num_res: NumpyDict, crop_size: int, + random_seed: Optional[int]) -> torch.Tensor: + + if num_res < crop_size: + return torch.arange(num_res) + with data_utils.numpy_seed(random_seed): + crop_start = int(np.random.randint(0, num_res - crop_size + 1)) + return torch.arange(crop_start, crop_start + crop_size) + + +def get_crop_sizes_each_chain( + asym_len: torch.Tensor, + crop_size: int, + random_seed: Optional[int] = None, + use_multinomial: bool = False, +) -> torch.Tensor: + """get crop sizes for contiguous crop""" + if not use_multinomial: + with data_utils.numpy_seed( + random_seed, key='multimer_contiguous_perm'): + shuffle_idx = np.random.permutation(len(asym_len)) + num_left = asym_len.sum() + num_budget = torch.tensor(crop_size) + crop_sizes = [0 for _ in asym_len] + for j, idx in enumerate(shuffle_idx): + this_len = asym_len[idx] + num_left -= this_len + # num res at most we can keep in this ent + max_size = min(num_budget, this_len) + # num res at least we shall keep in this ent + min_size = min(this_len, max(0, num_budget - num_left)) + with data_utils.numpy_seed( + random_seed, j, key='multimer_contiguous_crop_size'): + this_crop_size = int( + np.random.randint( + low=int(min_size), high=int(max_size) + 1)) + num_budget -= this_crop_size + crop_sizes[idx] = this_crop_size + crop_sizes = torch.tensor(crop_sizes) + else: # use multinomial + # TODO: better multimer + entity_probs = asym_len / torch.sum(asym_len) + crop_sizes = torch.from_numpy( + np.random.multinomial(crop_size, pvals=entity_probs)) + crop_sizes = torch.min(crop_sizes, asym_len) + return crop_sizes + + +def get_contiguous_crop_idx( + protein: NumpyDict, + crop_size: int, + random_seed: Optional[int] = None, + use_multinomial: bool = False, +) -> torch.Tensor: + + num_res = protein['aatype'].shape[0] + if num_res <= crop_size: + return torch.arange(num_res) + + assert 'asym_len' in protein + asym_len = protein['asym_len'] + + crop_sizes = get_crop_sizes_each_chain(asym_len, crop_size, random_seed, + use_multinomial) + crop_idxs = [] + asym_offset = torch.tensor(0, dtype=torch.int64) + with data_utils.numpy_seed( + random_seed, key='multimer_contiguous_crop_start_idx'): + for ll, csz in zip(asym_len, crop_sizes): + this_start = np.random.randint(0, int(ll - csz) + 1) + crop_idxs.append( + torch.arange(asym_offset + this_start, + asym_offset + this_start + csz)) + asym_offset += ll + + return torch.concat(crop_idxs) + 
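+# Note: the contiguous crop above samples per-chain segment lengths summing to
+# `crop_size`; the spatial crop below instead picks a random inter-chain
+# interface residue and keeps the `crop_size` residues nearest to it by CA-CA
+# distance, falling back to a contiguous crop when no interface is found.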
+ +def get_spatial_crop_idx( + protein: NumpyDict, + crop_size: int, + random_seed: int, + ca_ca_threshold: float, + inf: float = 3e4, +) -> List[int]: + + ca_idx = rc.atom_order['CA'] + ca_coords = protein['all_atom_positions'][..., ca_idx, :] + ca_mask = protein['all_atom_mask'][..., ca_idx].bool() + # if there are not enough atoms to construct interface, use contiguous crop + if (ca_mask.sum(dim=-1) <= 1).all(): + return get_contiguous_crop_idx(protein, crop_size, random_seed) + + pair_mask = ca_mask[..., None] * ca_mask[..., None, :] + ca_distances = get_pairwise_distances(ca_coords) + + interface_candidates = get_interface_candidates(ca_distances, + protein['asym_id'], + pair_mask, ca_ca_threshold) + + if torch.any(interface_candidates): + with data_utils.numpy_seed(random_seed, key='multimer_spatial_crop'): + target_res = int(np.random.choice(interface_candidates)) + else: + return get_contiguous_crop_idx(protein, crop_size, random_seed) + + to_target_distances = ca_distances[target_res] + # set inf to non-position residues + to_target_distances[~ca_mask] = inf + break_tie = ( + torch.arange( + 0, + to_target_distances.shape[-1], + device=to_target_distances.device).float() * 1e-3) + to_target_distances += break_tie + ret = torch.argsort(to_target_distances)[:crop_size] + return ret.sort().values + + +def get_pairwise_distances(coords: torch.Tensor) -> torch.Tensor: + coord_diff = coords.unsqueeze(-2) - coords.unsqueeze(-3) + return torch.sqrt(torch.sum(coord_diff**2, dim=-1)) + + +def get_interface_candidates( + ca_distances: torch.Tensor, + asym_id: torch.Tensor, + pair_mask: torch.Tensor, + ca_ca_threshold, +) -> torch.Tensor: + + in_same_asym = asym_id[..., None] == asym_id[..., None, :] + # set distance in the same entity to zero + ca_distances = ca_distances * (1.0 - in_same_asym.float()) * pair_mask + cnt_interfaces = torch.sum( + (ca_distances > 0) & (ca_distances < ca_ca_threshold), dim=-1) + interface_candidates = cnt_interfaces.nonzero(as_tuple=True)[0] + return interface_candidates + + +def apply_crop_idx(protein, shape_schema, crop_idx): + cropped_protein = {} + for k, v in protein.items(): + if k not in shape_schema: # skip items with unknown shape schema + continue + for i, dim_size in enumerate(shape_schema[k]): + if dim_size == N_RES: + v = torch.index_select(v, i, crop_idx) + cropped_protein[k] = v + return cropped_protein diff --git a/modelscope/models/science/unifold/data/msa_pairing.py b/modelscope/models/science/unifold/data/msa_pairing.py new file mode 100644 index 00000000..cc65962c --- /dev/null +++ b/modelscope/models/science/unifold/data/msa_pairing.py @@ -0,0 +1,526 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
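+# Pairing overview: create_paired_features selects, for each chain, the MSA rows
+# ("_all_seq" features) whose species match rows in the other chains
+# (pair_sequences / _match_rows_by_sequence_similarity); merge_chain_features
+# then concatenates the paired rows across chains and block-diagonalises the
+# unpaired ones.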
+"""Pairing logic for multimer data """ + +import collections +from typing import Dict, Iterable, List, Sequence + +import numpy as np +import pandas as pd +import scipy.linalg + +from .data_ops import NumpyDict +from .residue_constants import restypes_with_x_and_gap + +MSA_GAP_IDX = restypes_with_x_and_gap.index('-') +SEQUENCE_GAP_CUTOFF = 0.5 +SEQUENCE_SIMILARITY_CUTOFF = 0.9 + +MSA_PAD_VALUES = { + 'msa_all_seq': MSA_GAP_IDX, + 'msa_mask_all_seq': 1, + 'deletion_matrix_all_seq': 0, + 'deletion_matrix_int_all_seq': 0, + 'msa': MSA_GAP_IDX, + 'msa_mask': 1, + 'deletion_matrix': 0, + 'deletion_matrix_int': 0, +} + +MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int') +SEQ_FEATURES = ( + 'residue_index', + 'aatype', + 'all_atom_positions', + 'all_atom_mask', + 'seq_mask', + 'between_segment_residues', + 'has_alt_locations', + 'has_hetatoms', + 'asym_id', + 'entity_id', + 'sym_id', + 'entity_mask', + 'deletion_mean', + 'prediction_atom_mask', + 'literature_positions', + 'atom_indices_to_group_indices', + 'rigid_group_default_frame', + # zy + 'num_sym', +) +TEMPLATE_FEATURES = ( + 'template_aatype', + 'template_all_atom_positions', + 'template_all_atom_mask', +) +CHAIN_FEATURES = ('num_alignments', 'seq_length') + + +def create_paired_features(chains: Iterable[NumpyDict], ) -> List[NumpyDict]: + """Returns the original chains with paired NUM_SEQ features. + + Args: + chains: A list of feature dictionaries for each chain. + + Returns: + A list of feature dictionaries with sequence features including only + rows to be paired. + """ + chains = list(chains) + chain_keys = chains[0].keys() + + if len(chains) < 2: + return chains + else: + updated_chains = [] + paired_chains_to_paired_row_indices = pair_sequences(chains) + paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices) + + for chain_num, chain in enumerate(chains): + new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k} + for feature_name in chain_keys: + if feature_name.endswith('_all_seq'): + feats_padded = pad_features(chain[feature_name], + feature_name) + new_chain[feature_name] = feats_padded[ + paired_rows[:, chain_num]] + new_chain['num_alignments_all_seq'] = np.asarray( + len(paired_rows[:, chain_num])) + updated_chains.append(new_chain) + return updated_chains + + +def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray: + """Add a 'padding' row at the end of the features list. + + The padding row will be selected as a 'paired' row in the case of partial + alignment - for the chain that doesn't have paired alignment. + + Args: + feature: The feature to be padded. + feature_name: The name of the feature to be padded. + + Returns: + The feature with an additional padding row. 
+ """ + assert feature.dtype != np.dtype(np.string_) + if feature_name in ( + 'msa_all_seq', + 'msa_mask_all_seq', + 'deletion_matrix_all_seq', + 'deletion_matrix_int_all_seq', + ): + num_res = feature.shape[1] + padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res], + feature.dtype) + elif feature_name == 'msa_species_identifiers_all_seq': + padding = [b''] + else: + return feature + feats_padded = np.concatenate([feature, padding], axis=0) + return feats_padded + + +def _make_msa_df(chain_features: NumpyDict) -> pd.DataFrame: + """Makes dataframe with msa features needed for msa pairing.""" + chain_msa = chain_features['msa_all_seq'] + query_seq = chain_msa[0] + per_seq_similarity = np.sum( + query_seq[None] == chain_msa, axis=-1) / float(len(query_seq)) + per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq)) + msa_df = pd.DataFrame({ + 'msa_species_identifiers': + chain_features['msa_species_identifiers_all_seq'], + 'msa_row': + np.arange(len(chain_features['msa_species_identifiers_all_seq'])), + 'msa_similarity': + per_seq_similarity, + 'gap': + per_seq_gap, + }) + return msa_df + + +def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]: + """Creates mapping from species to msa dataframe of that species.""" + species_lookup = {} + for species, species_df in msa_df.groupby('msa_species_identifiers'): + species_lookup[species] = species_df + return species_lookup + + +def _match_rows_by_sequence_similarity( + this_species_msa_dfs: List[pd.DataFrame], ) -> List[List[int]]: # noqa + """Finds MSA sequence pairings across chains based on sequence similarity. + + Each chain's MSA sequences are first sorted by their sequence similarity to + their respective target sequence. The sequences are then paired, starting + from the sequences most similar to their target sequence. + + Args: + this_species_msa_dfs: a list of dataframes containing MSA features for + sequences for a specific species. + + Returns: + A list of lists, each containing M indices corresponding to paired MSA rows, + where M is the number of chains. + """ + all_paired_msa_rows = [] + + num_seqs = [ + len(species_df) for species_df in this_species_msa_dfs + if species_df is not None + ] + take_num_seqs = np.min(num_seqs) + + # sort_by_similarity = lambda x: x.sort_values( + # 'msa_similarity', axis=0, ascending=False) + + def sort_by_similarity(x): + return x.sort_values('msa_similarity', axis=0, ascending=False) + + for species_df in this_species_msa_dfs: + if species_df is not None: + species_df_sorted = sort_by_similarity(species_df) + msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values + else: + msa_rows = [-1] * take_num_seqs # take the last 'padding' row + all_paired_msa_rows.append(msa_rows) + all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose()) + return all_paired_msa_rows + + +def pair_sequences(examples: List[NumpyDict]) -> Dict[int, np.ndarray]: + """Returns indices for paired MSA sequences across chains.""" + + num_examples = len(examples) + + all_chain_species_dict = [] + common_species = set() + for chain_features in examples: + msa_df = _make_msa_df(chain_features) + species_dict = _create_species_dict(msa_df) + all_chain_species_dict.append(species_dict) + common_species.update(set(species_dict)) + + common_species = sorted(common_species) + common_species.remove(b'') # Remove target sequence species. 
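+    # Row 0 of each chain's MSA is its query sequence, so the all-zero index
+    # vector initialised below always pairs the queries; further pairings are
+    # collected per species, keyed by the number of chains the species spans.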
+ + all_paired_msa_rows = [np.zeros(len(examples), int)] + all_paired_msa_rows_dict = {k: [] for k in range(num_examples)} + all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)] + + for species in common_species: + if not species: + continue + this_species_msa_dfs = [] + species_dfs_present = 0 + for species_dict in all_chain_species_dict: + if species in species_dict: + this_species_msa_dfs.append(species_dict[species]) + species_dfs_present += 1 + else: + this_species_msa_dfs.append(None) + + # Skip species that are present in only one chain. + if species_dfs_present <= 1: + continue + + if np.any( + np.array([ + len(species_df) for species_df in this_species_msa_dfs + if isinstance(species_df, pd.DataFrame) + ]) > 600): + continue + + paired_msa_rows = _match_rows_by_sequence_similarity( + this_species_msa_dfs) + all_paired_msa_rows.extend(paired_msa_rows) + all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows) + all_paired_msa_rows_dict = { + num_examples: np.array(paired_msa_rows) + for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items() + } + return all_paired_msa_rows_dict + + +def reorder_paired_rows( + all_paired_msa_rows_dict: Dict[int, np.ndarray]) -> np.ndarray: + """Creates a list of indices of paired MSA rows across chains. + + Args: + all_paired_msa_rows_dict: a mapping from the number of paired chains to the + paired indices. + + Returns: + a list of lists, each containing indices of paired MSA rows across chains. + The paired-index lists are ordered by: + 1) the number of chains in the paired alignment, i.e, all-chain pairings + will come first. + 2) e-values + """ + all_paired_msa_rows = [] + + for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True): + paired_rows = all_paired_msa_rows_dict[num_pairings] + paired_rows_product = np.abs( + np.array( + [np.prod(rows.astype(np.float64)) for rows in paired_rows])) + paired_rows_sort_index = np.argsort(paired_rows_product) + all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index]) + + return np.array(all_paired_msa_rows) + + +def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray: + """Like scipy.linalg.block_diag but with an optional padding value.""" + ones_arrs = [np.ones_like(x) for x in arrs] + off_diag_mask = 1 - scipy.linalg.block_diag(*ones_arrs) + diag = scipy.linalg.block_diag(*arrs) + diag += (off_diag_mask * pad_value).astype(diag.dtype) + return diag + + +def _correct_post_merged_feats(np_example: NumpyDict, + np_chains_list: Sequence[NumpyDict], + pair_msa_sequences: bool) -> NumpyDict: + """Adds features that need to be computed/recomputed post merging.""" + + np_example['seq_length'] = np.asarray( + np_example['aatype'].shape[0], dtype=np.int32) + np_example['num_alignments'] = np.asarray( + np_example['msa'].shape[0], dtype=np.int32) + + if not pair_msa_sequences: + # Generate a bias that is 1 for the first row of every block in the + # block diagonal MSA - i.e. make sure the cluster stack always includes + # the query sequences for each chain (since the first row is the query + # sequence). + cluster_bias_masks = [] + for chain in np_chains_list: + mask = np.zeros(chain['msa'].shape[0]) + mask[0] = 1 + cluster_bias_masks.append(mask) + np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks) + + # Initialize Bert mask with masked out off diagonals. 
+ msa_masks = [ + np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list + ] + + np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0) + else: + np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0]) + np_example['cluster_bias_mask'][0] = 1 + + # Initialize Bert mask with masked out off diagonals. + msa_masks = [ + np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list + ] + msa_masks_all_seq = [ + np.ones(x['msa_all_seq'].shape, dtype=np.int8) + for x in np_chains_list + ] + + msa_mask_block_diag = block_diag(*msa_masks, pad_value=0) + msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1) + np_example['bert_mask'] = np.concatenate( + [msa_mask_all_seq, msa_mask_block_diag], axis=0) + return np_example + + +def _pad_templates(chains: Sequence[NumpyDict], + max_templates: int) -> Sequence[NumpyDict]: + """For each chain pad the number of templates to a fixed size. + + Args: + chains: A list of protein chains. + max_templates: Each chain will be padded to have this many templates. + + Returns: + The list of chains, updated to have template features padded to + max_templates. + """ + for chain in chains: + for k, v in chain.items(): + if k in TEMPLATE_FEATURES: + padding = np.zeros_like(v.shape) + padding[0] = max_templates - v.shape[0] + padding = [(0, p) for p in padding] + chain[k] = np.pad(v, padding, mode='constant') + return chains + + +def _merge_features_from_multiple_chains( + chains: Sequence[NumpyDict], pair_msa_sequences: bool) -> NumpyDict: + """Merge features from multiple chains. + + Args: + chains: A list of feature dictionaries that we want to merge. + pair_msa_sequences: Whether to concatenate MSA features along the + num_res dimension (if True), or to block diagonalize them (if False). + + Returns: + A feature dictionary for the merged example. + """ + merged_example = {} + for feature_name in chains[0]: + feats = [x[feature_name] for x in chains] + feature_name_split = feature_name.split('_all_seq')[0] + if feature_name_split in MSA_FEATURES: + if pair_msa_sequences or '_all_seq' in feature_name: + merged_example[feature_name] = np.concatenate(feats, axis=1) + if feature_name_split == 'msa': + merged_example['msa_chains_all_seq'] = np.ones( + merged_example[feature_name].shape[0]).reshape(-1, 1) + else: + merged_example[feature_name] = block_diag( + *feats, pad_value=MSA_PAD_VALUES[feature_name]) + if feature_name_split == 'msa': + msa_chains = [] + for i, feat in enumerate(feats): + cur_shape = feat.shape[0] + vals = np.ones(cur_shape) * (i + 2) + msa_chains.append(vals) + merged_example['msa_chains'] = np.concatenate( + msa_chains).reshape(-1, 1) + elif feature_name_split in SEQ_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=0) + elif feature_name_split in TEMPLATE_FEATURES: + merged_example[feature_name] = np.concatenate(feats, axis=1) + elif feature_name_split in CHAIN_FEATURES: + merged_example[feature_name] = np.sum(feats).astype(np.int32) + else: + merged_example[feature_name] = feats[0] + return merged_example + + +def _merge_homomers_dense_msa( + chains: Iterable[NumpyDict]) -> Sequence[NumpyDict]: + """Merge all identical chains, making the resulting MSA dense. + + Args: + chains: An iterable of features for each chain. + + Returns: + A list of feature dictionaries. All features with the same entity_id + will be merged - MSA features will be concatenated along the num_res + dimension - making them dense. 
+ """ + entity_chains = collections.defaultdict(list) + for chain in chains: + entity_id = chain['entity_id'][0] + entity_chains[entity_id].append(chain) + + grouped_chains = [] + for entity_id in sorted(entity_chains): + chains = entity_chains[entity_id] + grouped_chains.append(chains) + chains = [ + _merge_features_from_multiple_chains(chains, pair_msa_sequences=True) + for chains in grouped_chains + ] + return chains + + +def _concatenate_paired_and_unpaired_features(example: NumpyDict) -> NumpyDict: + """Merges paired and block-diagonalised features.""" + features = MSA_FEATURES + ('msa_chains', ) + for feature_name in features: + if feature_name in example: + feat = example[feature_name] + feat_all_seq = example[feature_name + '_all_seq'] + try: + merged_feat = np.concatenate([feat_all_seq, feat], axis=0) + except Exception as ex: + raise Exception( + 'concat failed.', + feature_name, + feat_all_seq.shape, + feat.shape, + ex.__class__, + ex, + ) + example[feature_name] = merged_feat + example['num_alignments'] = np.array( + example['msa'].shape[0], dtype=np.int32) + return example + + +def merge_chain_features(np_chains_list: List[NumpyDict], + pair_msa_sequences: bool, + max_templates: int) -> NumpyDict: + """Merges features for multiple chains to single FeatureDict. + + Args: + np_chains_list: List of FeatureDicts for each chain. + pair_msa_sequences: Whether to merge paired MSAs. + max_templates: The maximum number of templates to include. + + Returns: + Single FeatureDict for entire complex. + """ + np_chains_list = _pad_templates( + np_chains_list, max_templates=max_templates) + np_chains_list = _merge_homomers_dense_msa(np_chains_list) + # Unpaired MSA features will be always block-diagonalised; paired MSA + # features will be concatenated. + np_example = _merge_features_from_multiple_chains( + np_chains_list, pair_msa_sequences=False) + if pair_msa_sequences: + np_example = _concatenate_paired_and_unpaired_features(np_example) + np_example = _correct_post_merged_feats( + np_example=np_example, + np_chains_list=np_chains_list, + pair_msa_sequences=pair_msa_sequences, + ) + + return np_example + + +def deduplicate_unpaired_sequences( + np_chains: List[NumpyDict]) -> List[NumpyDict]: + """Removes unpaired sequences which duplicate a paired sequence.""" + + feature_names = np_chains[0].keys() + msa_features = MSA_FEATURES + cache_msa_features = {} + for chain in np_chains: + entity_id = int(chain['entity_id'][0]) + if entity_id not in cache_msa_features: + sequence_set = set(s.tobytes() for s in chain['msa_all_seq']) + keep_rows = [] + # Go through unpaired MSA seqs and remove any rows that correspond to the + # sequences that are already present in the paired MSA. 
+ for row_num, seq in enumerate(chain['msa']): + if seq.tobytes() not in sequence_set: + keep_rows.append(row_num) + new_msa_features = {} + for feature_name in feature_names: + if feature_name in msa_features: + if keep_rows: + new_msa_features[feature_name] = chain[feature_name][ + keep_rows] + else: + new_shape = list(chain[feature_name].shape) + new_shape[0] = 0 + new_msa_features[feature_name] = np.zeros( + new_shape, dtype=chain[feature_name].dtype) + cache_msa_features[entity_id] = new_msa_features + for feature_name in cache_msa_features[entity_id]: + chain[feature_name] = cache_msa_features[entity_id][feature_name] + chain['num_alignments'] = np.array( + chain['msa'].shape[0], dtype=np.int32) + return np_chains diff --git a/modelscope/models/science/unifold/data/process.py b/modelscope/models/science/unifold/data/process.py new file mode 100644 index 00000000..3987cb1c --- /dev/null +++ b/modelscope/models/science/unifold/data/process.py @@ -0,0 +1,264 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Optional + +import numpy as np +import torch + +from modelscope.models.science.unifold.data import data_ops + + +def nonensembled_fns(common_cfg, mode_cfg): + """Input pipeline data transformers that are not ensembled.""" + v2_feature = common_cfg.v2_feature + operators = [] + if mode_cfg.random_delete_msa: + operators.append( + data_ops.random_delete_msa(common_cfg.random_delete_msa)) + operators.extend([ + data_ops.cast_to_64bit_ints, + data_ops.correct_msa_restypes, + data_ops.squeeze_features, + data_ops.randomly_replace_msa_with_unknown(0.0), + data_ops.make_seq_mask, + data_ops.make_msa_mask, + ]) + operators.append(data_ops.make_hhblits_profile_v2 + if v2_feature else data_ops.make_hhblits_profile) + if common_cfg.use_templates: + operators.extend([ + data_ops.make_template_mask, + data_ops.make_pseudo_beta('template_'), + ]) + operators.append( + data_ops.crop_templates( + max_templates=mode_cfg.max_templates, + subsample_templates=mode_cfg.subsample_templates, + )) + + if common_cfg.use_template_torsion_angles: + operators.extend([ + data_ops.atom37_to_torsion_angles('template_'), + ]) + + operators.append(data_ops.make_atom14_masks) + operators.append(data_ops.make_target_feat) + + return operators + + +def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed): + operators = [] + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates + else: + pad_msa_clusters = mode_cfg.max_msa_clusters + crop_feats = dict(common_cfg.features) + if mode_cfg.fixed_size: + if mode_cfg.crop: + if common_cfg.is_multimer: + crop_fn = data_ops.crop_to_size_multimer( + crop_size=mode_cfg.crop_size, + shape_schema=crop_feats, + seed=crop_and_fix_size_seed, + spatial_crop_prob=mode_cfg.spatial_crop_prob, + ca_ca_threshold=mode_cfg.ca_ca_threshold, + ) + else: + crop_fn = data_ops.crop_to_size_single( + crop_size=mode_cfg.crop_size, + shape_schema=crop_feats, + seed=crop_and_fix_size_seed, + ) + operators.append(crop_fn) + + operators.append(data_ops.select_feat(crop_feats)) + + operators.append( + data_ops.make_fixed_size( + crop_feats, + pad_msa_clusters, + common_cfg.max_extra_msa, + mode_cfg.crop_size, + mode_cfg.max_templates, + )) + return operators + + +def ensembled_fns(common_cfg, mode_cfg): + """Input pipeline data transformers that can be ensembled and averaged.""" + 
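+    # These transforms are re-applied once per (recycling, ensemble) iteration in
+    # process_features below, so MSA sampling, masking and clustering can differ
+    # between iterations, unlike the nonensembled transforms above.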
operators = [] + multimer_mode = common_cfg.is_multimer + v2_feature = common_cfg.v2_feature + # multimer don't use block delete msa + if mode_cfg.block_delete_msa and not multimer_mode: + operators.append( + data_ops.block_delete_msa(common_cfg.block_delete_msa)) + if 'max_distillation_msa_clusters' in mode_cfg: + operators.append( + data_ops.sample_msa_distillation( + mode_cfg.max_distillation_msa_clusters)) + + if common_cfg.reduce_msa_clusters_by_max_templates: + pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates + else: + pad_msa_clusters = mode_cfg.max_msa_clusters + + max_msa_clusters = pad_msa_clusters + max_extra_msa = common_cfg.max_extra_msa + + assert common_cfg.resample_msa_in_recycling + gumbel_sample = common_cfg.gumbel_sample + operators.append( + data_ops.sample_msa( + max_msa_clusters, + keep_extra=True, + gumbel_sample=gumbel_sample, + biased_msa_by_chain=mode_cfg.biased_msa_by_chain, + )) + + if 'masked_msa' in common_cfg: + # Masked MSA should come *before* MSA clustering so that + # the clustering and full MSA profile do not leak information about + # the masked locations and secret corrupted locations. + operators.append( + data_ops.make_masked_msa( + common_cfg.masked_msa, + mode_cfg.masked_msa_replace_fraction, + gumbel_sample=gumbel_sample, + share_mask=mode_cfg.share_mask, + )) + + if common_cfg.msa_cluster_features: + if v2_feature: + operators.append(data_ops.nearest_neighbor_clusters_v2()) + else: + operators.append(data_ops.nearest_neighbor_clusters()) + operators.append(data_ops.summarize_clusters) + + if v2_feature: + operators.append(data_ops.make_msa_feat_v2) + else: + operators.append(data_ops.make_msa_feat) + # Crop after creating the cluster profiles. + if max_extra_msa: + if v2_feature: + operators.append(data_ops.make_extra_msa_feat(max_extra_msa)) + else: + operators.append(data_ops.crop_extra_msa(max_extra_msa)) + else: + operators.append(data_ops.delete_extra_msa) + # operators.append(data_operators.select_feat(common_cfg.recycling_features)) + return operators + + +def process_features(tensors, common_cfg, mode_cfg): + """Based on the config, apply filters and transformations to the data.""" + is_distillation = bool(tensors.get('is_distillation', 0)) + multimer_mode = common_cfg.is_multimer + crop_and_fix_size_seed = int(tensors['crop_and_fix_size_seed']) + crop_fn = crop_and_fix_size_fns( + common_cfg, + mode_cfg, + crop_and_fix_size_seed, + ) + + def wrap_ensemble_fn(data, i): + """Function to be mapped over the ensemble dimension.""" + d = data.copy() + fns = ensembled_fns( + common_cfg, + mode_cfg, + ) + new_d = compose(fns)(d) + if not multimer_mode or is_distillation: + new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d) + return compose(crop_fn)(new_d) + else: # select after crop for spatial cropping + d = compose(crop_fn)(d) + d = data_ops.select_feat(common_cfg.recycling_features)(d) + return d + + nonensembled = nonensembled_fns(common_cfg, mode_cfg) + + if mode_cfg.supervised and (not multimer_mode or is_distillation): + nonensembled.extend(label_transform_fn()) + + tensors = compose(nonensembled)(tensors) + + num_recycling = int(tensors['num_recycling_iters']) + 1 + num_ensembles = mode_cfg.num_ensembles + + ensemble_tensors = map_fn( + lambda x: wrap_ensemble_fn(tensors, x), + torch.arange(num_recycling * num_ensembles), + ) + tensors = compose(crop_fn)(tensors) + # add a dummy dim to align with recycling features + tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors} + 
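+    # ensemble_tensors holds only the recycling features, stacked over
+    # num_recycling * num_ensembles iterations; the update below overwrites those
+    # keys while every other feature keeps the singleton leading dim added above.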
tensors.update(ensemble_tensors) + return tensors + + +@data_ops.curry1 +def compose(x, fs): + for f in fs: + x = f(x) + return x + + +def pad_then_stack(values, ): + if len(values[0].shape) >= 1: + size = max(v.shape[0] for v in values) + new_values = [] + for v in values: + if v.shape[0] < size: + res = values[0].new_zeros(size, *v.shape[1:]) + res[:v.shape[0], ...] = v + else: + res = v + new_values.append(res) + else: + new_values = values + return torch.stack(new_values, dim=0) + + +def map_fn(fun, x): + ensembles = [fun(elem) for elem in x] + features = ensembles[0].keys() + ensembled_dict = {} + for feat in features: + ensembled_dict[feat] = pad_then_stack( + [dict_i[feat] for dict_i in ensembles]) + return ensembled_dict + + +def process_single_label(label: dict, + num_ensemble: Optional[int] = None) -> dict: + assert 'aatype' in label + assert 'all_atom_positions' in label + assert 'all_atom_mask' in label + label = compose(label_transform_fn())(label) + if num_ensemble is not None: + label = { + k: torch.stack([v for _ in range(num_ensemble)]) + for k, v in label.items() + } + return label + + +def process_labels(labels_list, num_ensemble: Optional[int] = None): + return [process_single_label(ll, num_ensemble) for ll in labels_list] + + +def label_transform_fn(): + return [ + data_ops.make_atom14_masks, + data_ops.make_atom14_positions, + data_ops.atom37_to_frames, + data_ops.atom37_to_torsion_angles(''), + data_ops.make_pseudo_beta(''), + data_ops.get_backbone_frames, + data_ops.get_chi_angles, + ] diff --git a/modelscope/models/science/unifold/data/process_multimer.py b/modelscope/models/science/unifold/data/process_multimer.py new file mode 100644 index 00000000..04572d2d --- /dev/null +++ b/modelscope/models/science/unifold/data/process_multimer.py @@ -0,0 +1,417 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
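+# pair_and_merge below drives the multimer feature pipeline: per-chain cleanup in
+# process_unmerged_features, cross-chain MSA pairing and de-duplication via
+# msa_pairing, per-chain MSA/template cropping in crop_chains, merging into a
+# single example (msa_pairing.merge_chain_features) and final masking/filtering
+# in process_final.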
+"""Feature processing logic for multimer data """ + +import collections +from typing import Iterable, List, MutableMapping + +import numpy as np + +from modelscope.models.science.unifold.data import (msa_pairing, + residue_constants) +from .utils import correct_template_restypes + +FeatureDict = MutableMapping[str, np.ndarray] + +REQUIRED_FEATURES = frozenset({ + 'aatype', + 'all_atom_mask', + 'all_atom_positions', + 'all_chains_entity_ids', + 'all_crops_all_chains_mask', + 'all_crops_all_chains_positions', + 'all_crops_all_chains_residue_ids', + 'assembly_num_chains', + 'asym_id', + 'bert_mask', + 'cluster_bias_mask', + 'deletion_matrix', + 'deletion_mean', + 'entity_id', + 'entity_mask', + 'mem_peak', + 'msa', + 'msa_mask', + 'num_alignments', + 'num_templates', + 'queue_size', + 'residue_index', + 'resolution', + 'seq_length', + 'seq_mask', + 'sym_id', + 'template_aatype', + 'template_all_atom_mask', + 'template_all_atom_positions', + # zy added: + 'asym_len', + 'template_sum_probs', + 'num_sym', + 'msa_chains', +}) + +MAX_TEMPLATES = 4 +MSA_CROP_SIZE = 2048 + + +def _is_homomer_or_monomer(chains: Iterable[FeatureDict]) -> bool: + """Checks if a list of chains represents a homomer/monomer example.""" + # Note that an entity_id of 0 indicates padding. + num_unique_chains = len( + np.unique( + np.concatenate([ + np.unique(chain['entity_id'][chain['entity_id'] > 0]) + for chain in chains + ]))) + return num_unique_chains == 1 + + +def pair_and_merge( + all_chain_features: MutableMapping[str, FeatureDict]) -> FeatureDict: + """Runs processing on features to augment, pair and merge. + + Args: + all_chain_features: A MutableMap of dictionaries of features for each chain. + + Returns: + A dictionary of features. + """ + + process_unmerged_features(all_chain_features) + + np_chains_list = all_chain_features + + pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list) + + if pair_msa_sequences: + np_chains_list = msa_pairing.create_paired_features( + chains=np_chains_list) + np_chains_list = msa_pairing.deduplicate_unpaired_sequences( + np_chains_list) + np_chains_list = crop_chains( + np_chains_list, + msa_crop_size=MSA_CROP_SIZE, + pair_msa_sequences=pair_msa_sequences, + max_templates=MAX_TEMPLATES, + ) + np_example = msa_pairing.merge_chain_features( + np_chains_list=np_chains_list, + pair_msa_sequences=pair_msa_sequences, + max_templates=MAX_TEMPLATES, + ) + np_example = process_final(np_example) + return np_example + + +def crop_chains( + chains_list: List[FeatureDict], + msa_crop_size: int, + pair_msa_sequences: bool, + max_templates: int, +) -> List[FeatureDict]: + """Crops the MSAs for a set of chains. + + Args: + chains_list: A list of chains to be cropped. + msa_crop_size: The total number of sequences to crop from the MSA. + pair_msa_sequences: Whether we are operating in sequence-pairing mode. + max_templates: The maximum templates to use per chain. + + Returns: + The chains cropped. + """ + + # Apply the cropping. 
+ cropped_chains = [] + for chain in chains_list: + cropped_chain = _crop_single_chain( + chain, + msa_crop_size=msa_crop_size, + pair_msa_sequences=pair_msa_sequences, + max_templates=max_templates, + ) + cropped_chains.append(cropped_chain) + + return cropped_chains + + +def _crop_single_chain(chain: FeatureDict, msa_crop_size: int, + pair_msa_sequences: bool, + max_templates: int) -> FeatureDict: + """Crops msa sequences to `msa_crop_size`.""" + msa_size = chain['num_alignments'] + + if pair_msa_sequences: + msa_size_all_seq = chain['num_alignments_all_seq'] + msa_crop_size_all_seq = np.minimum(msa_size_all_seq, + msa_crop_size // 2) + + # We reduce the number of un-paired sequences, by the number of times a + # sequence from this chain's MSA is included in the paired MSA. This keeps + # the MSA size for each chain roughly constant. + msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :] + num_non_gapped_pairs = np.sum( + np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)) + num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, + msa_crop_size_all_seq) + + # Restrict the unpaired crop size so that paired+unpaired sequences do not + # exceed msa_seqs_per_chain for each chain. + max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0) + msa_crop_size = np.minimum(msa_size, max_msa_crop_size) + else: + msa_crop_size = np.minimum(msa_size, msa_crop_size) + + include_templates = 'template_aatype' in chain and max_templates + if include_templates: + num_templates = chain['template_aatype'].shape[0] + templates_crop_size = np.minimum(num_templates, max_templates) + + for k in chain: + k_split = k.split('_all_seq')[0] + if k_split in msa_pairing.TEMPLATE_FEATURES: + chain[k] = chain[k][:templates_crop_size, :] + elif k_split in msa_pairing.MSA_FEATURES: + if '_all_seq' in k and pair_msa_sequences: + chain[k] = chain[k][:msa_crop_size_all_seq, :] + else: + chain[k] = chain[k][:msa_crop_size, :] + + chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32) + if include_templates: + chain['num_templates'] = np.asarray( + templates_crop_size, dtype=np.int32) + if pair_msa_sequences: + chain['num_alignments_all_seq'] = np.asarray( + msa_crop_size_all_seq, dtype=np.int32) + return chain + + +def process_final(np_example: FeatureDict) -> FeatureDict: + """Final processing steps in data pipeline, after merging and pairing.""" + np_example = _make_seq_mask(np_example) + np_example = _make_msa_mask(np_example) + np_example = _filter_features(np_example) + return np_example + + +def _make_seq_mask(np_example): + np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32) + return np_example + + +def _make_msa_mask(np_example): + """Mask features are all ones, but will later be zero-padded.""" + + np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.int8) + + seq_mask = (np_example['entity_id'] > 0).astype(np.int8) + np_example['msa_mask'] *= seq_mask[None] + + return np_example + + +def _filter_features(np_example: FeatureDict) -> FeatureDict: + """Filters features of example to only those requested.""" + return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES} + + +def process_unmerged_features(all_chain_features: MutableMapping[str, + FeatureDict]): + """Postprocessing stage for per-chain features before merging.""" + num_chains = len(all_chain_features) + for chain_features in all_chain_features: + # Convert deletion matrices to float. 
+ if 'deletion_matrix_int' in chain_features: + chain_features['deletion_matrix'] = np.asarray( + chain_features.pop('deletion_matrix_int'), dtype=np.float32) + if 'deletion_matrix_int_all_seq' in chain_features: + chain_features['deletion_matrix_all_seq'] = np.asarray( + chain_features.pop('deletion_matrix_int_all_seq'), + dtype=np.float32) + + chain_features['deletion_mean'] = np.mean( + chain_features['deletion_matrix'], axis=0) + + if 'all_atom_positions' not in chain_features: + # Add all_atom_mask and dummy all_atom_positions based on aatype. + all_atom_mask = residue_constants.STANDARD_ATOM_MASK[ + chain_features['aatype']] + chain_features['all_atom_mask'] = all_atom_mask + chain_features['all_atom_positions'] = np.zeros( + list(all_atom_mask.shape) + [3]) + + # Add assembly_num_chains. + chain_features['assembly_num_chains'] = np.asarray(num_chains) + + # Add entity_mask. + for chain_features in all_chain_features: + chain_features['entity_mask'] = ( + chain_features['entity_id'] != # noqa W504 + 0).astype(np.int32) + + +def empty_template_feats(n_res): + return { + 'template_aatype': + np.zeros((0, n_res)).astype(np.int64), + 'template_all_atom_positions': + np.zeros((0, n_res, 37, 3)).astype(np.float32), + 'template_sum_probs': + np.zeros((0, 1)).astype(np.float32), + 'template_all_atom_mask': + np.zeros((0, n_res, 37)).astype(np.float32), + } + + +def convert_monomer_features(monomer_features: FeatureDict) -> FeatureDict: + """Reshapes and modifies monomer features for multimer models.""" + if monomer_features['template_aatype'].shape[0] == 0: + monomer_features.update( + empty_template_feats(monomer_features['aatype'].shape[0])) + converted = {} + unnecessary_leading_dim_feats = { + 'sequence', + 'domain_name', + 'num_alignments', + 'seq_length', + } + for feature_name, feature in monomer_features.items(): + if feature_name in unnecessary_leading_dim_feats: + # asarray ensures it's a np.ndarray. + feature = np.asarray(feature[0], dtype=feature.dtype) + elif feature_name == 'aatype': + # The multimer model performs the one-hot operation itself. + feature = np.argmax(feature, axis=-1).astype(np.int32) + elif feature_name == 'template_aatype': + if feature.shape[0] > 0: + feature = correct_template_restypes(feature) + elif feature_name == 'template_all_atom_masks': + feature_name = 'template_all_atom_mask' + elif feature_name == 'msa': + feature = feature.astype(np.uint8) + + if feature_name.endswith('_mask'): + feature = feature.astype(np.float32) + + converted[feature_name] = feature + + if 'deletion_matrix_int' in monomer_features: + monomer_features['deletion_matrix'] = monomer_features.pop( + 'deletion_matrix_int').astype(np.float32) + + converted.pop( + 'template_sum_probs' + ) # zy: this input is checked to be dirty in shape. TODO: figure out why and make it right. + return converted + + +def int_id_to_str_id(num: int) -> str: + """Encodes a number as a string, using reverse spreadsheet style naming. + + Args: + num: A positive integer. + + Returns: + A string that encodes the positive integer using reverse spreadsheet style, + naming e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the + usual way to encode chain IDs in mmCIF files. + """ + if num <= 0: + raise ValueError(f'Only positive integers allowed, got {num}.') + + num = num - 1 # 1-based indexing. 
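+    # e.g. 1 -> 'A', 26 -> 'Z', 27 -> 'AA', 28 -> 'BA' (reverse spreadsheet naming)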
+ output = [] + while num >= 0: + output.append(chr(num % 26 + ord('A'))) + num = num // 26 - 1 + return ''.join(output) + + +def add_assembly_features(all_chain_features, ): + """Add features to distinguish between chains. + + Args: + all_chain_features: A dictionary which maps chain_id to a dictionary of + features for each chain. + + Returns: + all_chain_features: A dictionary which maps strings of the form + `_` to the corresponding chain features. E.g. two + chains from a homodimer would have keys A_1 and A_2. Two chains from a + heterodimer would have keys A_1 and B_1. + """ + # Group the chains by sequence + seq_to_entity_id = {} + grouped_chains = collections.defaultdict(list) + for chain_features in all_chain_features: + assert 'sequence' in chain_features + seq = str(chain_features['sequence']) + if seq not in seq_to_entity_id: + seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 + grouped_chains[seq_to_entity_id[seq]].append(chain_features) + + new_all_chain_features = [] + chain_id = 1 + for entity_id, group_chain_features in grouped_chains.items(): + num_sym = len(group_chain_features) # zy + for sym_id, chain_features in enumerate(group_chain_features, start=1): + seq_length = chain_features['seq_length'] + chain_features['asym_id'] = chain_id * np.ones(seq_length) + chain_features['sym_id'] = sym_id * np.ones(seq_length) + chain_features['entity_id'] = entity_id * np.ones(seq_length) + chain_features['num_sym'] = num_sym * np.ones(seq_length) + chain_id += 1 + new_all_chain_features.append(chain_features) + + return new_all_chain_features + + +def pad_msa(np_example, min_num_seq): + np_example = dict(np_example) + num_seq = np_example['msa'].shape[0] + if num_seq < min_num_seq: + for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask', + 'msa_chains'): + np_example[feat] = np.pad(np_example[feat], + ((0, min_num_seq - num_seq), (0, 0))) + np_example['cluster_bias_mask'] = np.pad( + np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq), )) + return np_example + + +def post_process(np_example): + np_example = pad_msa(np_example, 512) + no_dim_keys = [ + 'num_alignments', + 'assembly_num_chains', + 'num_templates', + 'seq_length', + 'resolution', + ] + for k in no_dim_keys: + if k in np_example: + np_example[k] = np_example[k].reshape(-1) + return np_example + + +def merge_msas(msa, del_mat, new_msa, new_del_mat): + cur_msa_set = set([tuple(m) for m in msa]) + new_rows = [] + for i, s in enumerate(new_msa): + if tuple(s) not in cur_msa_set: + new_rows.append(i) + ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0) + ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0) + return ret_msa, ret_del_mat diff --git a/modelscope/models/science/unifold/data/protein.py b/modelscope/models/science/unifold/data/protein.py new file mode 100644 index 00000000..42308d04 --- /dev/null +++ b/modelscope/models/science/unifold/data/protein.py @@ -0,0 +1,322 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Protein data type.""" +import dataclasses +import io +from typing import Any, Mapping, Optional + +import numpy as np +from Bio.PDB import PDBParser + +from modelscope.models.science.unifold.data import residue_constants + +FeatureDict = Mapping[str, np.ndarray] +ModelOutput = Mapping[str, Any] # Is a nested dict. + +# Complete sequence of chain IDs supported by the PDB format. +PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' +PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. + + +@dataclasses.dataclass(frozen=True) +class Protein: + """Protein structure representation.""" + + # Cartesian coordinates of atoms in angstroms. The atom types correspond to + # residue_constants.atom_types, i.e. the first three are N, CA, CB. + atom_positions: np.ndarray # [num_res, num_atom_type, 3] + + # Amino-acid type for each residue represented as an integer between 0 and + # 20, where 20 is 'X'. + aatype: np.ndarray # [num_res] + + # Binary float mask to indicate presence of a particular atom. 1.0 if an atom + # is present and 0.0 if not. This should be used for loss masking. + atom_mask: np.ndarray # [num_res, num_atom_type] + + # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. + residue_index: np.ndarray # [num_res] + + # 0-indexed number corresponding to the chain in the protein that this residue + # belongs to. + chain_index: np.ndarray # [num_res] + + # B-factors, or temperature factors, of each residue (in sq. angstroms units), + # representing the displacement of the residue from its ground truth mean + # value. + b_factors: np.ndarray # [num_res, num_atom_type] + + def __post_init__(self): + if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: + raise ValueError( + f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' + 'because these cannot be written to PDB format.') + + +def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: + """Takes a PDB string and constructs a Protein object. + + WARNING: All non-standard residue types will be converted into UNK. All + non-standard atoms will be ignored. + + Args: + pdb_str: The contents of the pdb file + chain_id: If chain_id is specified (e.g. A), then only that chain + is parsed. Otherwise all chains are parsed. + + Returns: + A new `Protein` parsed from the pdb contents. + """ + pdb_fh = io.StringIO(pdb_str) + parser = PDBParser(QUIET=True) + structure = parser.get_structure('none', pdb_fh) + models = list(structure.get_models()) + if len(models) != 1: + raise ValueError( + f'Only single model PDBs are supported. Found {len(models)} models.' + ) + model = models[0] + + atom_positions = [] + aatype = [] + atom_mask = [] + residue_index = [] + chain_ids = [] + b_factors = [] + + for chain in model: + if chain_id is not None and chain.id != chain_id: + continue + for res in chain: + if res.id[2] != ' ': + raise ValueError( + f'PDB contains an insertion code at chain {chain.id} and residue ' + f'index {res.id[1]}. 
These are not supported.') + res_shortname = residue_constants.restype_3to1.get( + res.resname, 'X') + restype_idx = residue_constants.restype_order.get( + res_shortname, residue_constants.restype_num) + pos = np.zeros((residue_constants.atom_type_num, 3)) + mask = np.zeros((residue_constants.atom_type_num, )) + res_b_factors = np.zeros((residue_constants.atom_type_num, )) + for atom in res: + if atom.name not in residue_constants.atom_types: + continue + pos[residue_constants.atom_order[atom.name]] = atom.coord + mask[residue_constants.atom_order[atom.name]] = 1.0 + res_b_factors[residue_constants.atom_order[ + atom.name]] = atom.bfactor + if np.sum(mask) < 0.5: + # If no known atom positions are reported for the residue then skip it. + continue + aatype.append(restype_idx) + atom_positions.append(pos) + atom_mask.append(mask) + residue_index.append(res.id[1]) + chain_ids.append(chain.id) + b_factors.append(res_b_factors) + + # Chain IDs are usually characters so map these to ints. + unique_chain_ids = np.unique(chain_ids) + chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} + chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) + + return Protein( + atom_positions=np.array(atom_positions), + atom_mask=np.array(atom_mask), + aatype=np.array(aatype), + residue_index=np.array(residue_index), + chain_index=chain_index, + b_factors=np.array(b_factors), + ) + + +def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: + chain_end = 'TER' + return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' + f'{chain_name:>1}{residue_index:>4}') + + +def to_pdb(prot: Protein) -> str: + """Converts a `Protein` instance to a PDB string. + + Args: + prot: The protein to convert to PDB. + + Returns: + PDB string. + """ + restypes = residue_constants.restypes + ['X'] + + # res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') + def res_1to3(r): + return residue_constants.restype_1to3.get(restypes[r], 'UNK') + + atom_types = residue_constants.atom_types + + pdb_lines = [] + + atom_mask = prot.atom_mask + aatype = prot.aatype + atom_positions = prot.atom_positions + residue_index = prot.residue_index.astype(np.int32) + chain_index = prot.chain_index.astype(np.int32) + b_factors = prot.b_factors + + if np.any(aatype > residue_constants.restype_num): + raise ValueError('Invalid aatypes.') + + # Construct a mapping from chain integer indices to chain ID strings. + chain_ids = {} + for i in np.unique(chain_index): # np.unique gives sorted output. + if i >= PDB_MAX_CHAINS: + raise ValueError( + f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') + chain_ids[i] = PDB_CHAIN_IDS[i] + + pdb_lines.append('MODEL 1') + atom_index = 1 + last_chain_index = chain_index[0] + # Add all atom sites. + for i in range(aatype.shape[0]): + # Close the previous chain if in a multichain PDB. + if last_chain_index != chain_index[i]: + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[i - 1]), + chain_ids[chain_index[i - 1]], + residue_index[i - 1], + )) + last_chain_index = chain_index[i] + atom_index += 1 # Atom index increases at the TER symbol. + + res_name_3 = res_1to3(aatype[i]) + for atom_name, pos, mask, b_factor in zip(atom_types, + atom_positions[i], + atom_mask[i], b_factors[i]): + if mask < 0.5: + continue + + record_type = 'ATOM' + name = atom_name if len(atom_name) == 4 else f' {atom_name}' + alt_loc = '' + insertion_code = '' + occupancy = 1.00 + element = atom_name[ + 0] # Protein supports only C, N, O, S, this works. 
+ charge = '' + # PDB is a columnar format, every space matters here! + atom_line = ( + f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' + f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' + f'{residue_index[i]:>4}{insertion_code:>1} ' + f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' + f'{occupancy:>6.2f}{b_factor:>6.2f} ' + f'{element:>2}{charge:>2}') + pdb_lines.append(atom_line) + atom_index += 1 + + # Close the final chain. + pdb_lines.append( + _chain_end( + atom_index, + res_1to3(aatype[-1]), + chain_ids[chain_index[-1]], + residue_index[-1], + )) + pdb_lines.append('ENDMDL') + pdb_lines.append('END') + + # Pad all lines to 80 characters. + pdb_lines = [line.ljust(80) for line in pdb_lines] + return '\n'.join(pdb_lines) + '\n' # Add terminating newline. + + +def ideal_atom_mask(prot: Protein) -> np.ndarray: + """Computes an ideal atom mask. + + `Protein.atom_mask` typically is defined according to the atoms that are + reported in the PDB. This function computes a mask according to heavy atoms + that should be present in the given sequence of amino acids. + + Args: + prot: `Protein` whose fields are `numpy.ndarray` objects. + + Returns: + An ideal atom mask. + """ + return residue_constants.STANDARD_ATOM_MASK[prot.aatype] + + +def from_prediction(features: FeatureDict, + result: ModelOutput, + b_factors: Optional[np.ndarray] = None) -> Protein: + """Assembles a protein from a prediction. + + Args: + features: Dictionary holding model inputs. + fold_output: Dictionary holding model outputs. + b_factors: (Optional) B-factors to use for the protein. + + Returns: + A protein instance. + """ + + if 'asym_id' in features: + chain_index = features['asym_id'] - 1 + else: + chain_index = np.zeros_like((features['aatype'])) + + if b_factors is None: + b_factors = np.zeros_like(result['final_atom_mask']) + + return Protein( + aatype=features['aatype'], + atom_positions=result['final_atom_positions'], + atom_mask=result['final_atom_mask'], + residue_index=features['residue_index'] + 1, + chain_index=chain_index, + b_factors=b_factors, + ) + + +def from_feature(features: FeatureDict, + b_factors: Optional[np.ndarray] = None) -> Protein: + """Assembles a standard pdb from input atom positions & mask. + + Args: + features: Dictionary holding model inputs. + b_factors: (Optional) B-factors to use for the protein. + + Returns: + A protein instance. + """ + + if 'asym_id' in features: + chain_index = features['asym_id'] - 1 + else: + chain_index = np.zeros_like((features['aatype'])) + + if b_factors is None: + b_factors = np.zeros_like(features['all_atom_mask']) + + return Protein( + aatype=features['aatype'], + atom_positions=features['all_atom_positions'], + atom_mask=features['all_atom_mask'], + residue_index=features['residue_index'] + 1, + chain_index=chain_index, + b_factors=b_factors, + ) diff --git a/modelscope/models/science/unifold/data/residue_constants.py b/modelscope/models/science/unifold/data/residue_constants.py new file mode 100644 index 00000000..beebfe89 --- /dev/null +++ b/modelscope/models/science/unifold/data/residue_constants.py @@ -0,0 +1,1212 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Constants used in AlphaFold.""" + +import collections +import functools +import os +from typing import List, Mapping, Tuple + +import numpy as np +from unicore.utils import tree_map + +# Distance from one CA to next CA [trans configuration: omega = 180]. +ca_ca = 3.80209737096 + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +chi_angles_atoms = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'NE'], + ['CG', 'CD', 'NE', 'CZ'], + ], + 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']], + 'CYS': [['N', 'CA', 'CB', 'SG']], + 'GLN': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1'], + ], + 'GLU': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'OE1'], + ], + 'GLY': [], + 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']], + 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']], + 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'LYS': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'CD'], + ['CB', 'CG', 'CD', 'CE'], + ['CG', 'CD', 'CE', 'NZ'], + ], + 'MET': [ + ['N', 'CA', 'CB', 'CG'], + ['CA', 'CB', 'CG', 'SD'], + ['CB', 'CG', 'SD', 'CE'], + ], + 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']], + 'SER': [['N', 'CA', 'CB', 'OG']], + 'THR': [['N', 'CA', 'CB', 'OG1']], + 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']], + 'VAL': [['N', 'CA', 'CB', 'CG1']], +} + +# If chi angles given in fixed-length array, this matrix determines how to mask +# them for each AA type. The order is as per restype_order (see below). +chi_angles_mask = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [1.0, 1.0, 1.0, 1.0], # ARG + [1.0, 1.0, 0.0, 0.0], # ASN + [1.0, 1.0, 0.0, 0.0], # ASP + [1.0, 0.0, 0.0, 0.0], # CYS + [1.0, 1.0, 1.0, 0.0], # GLN + [1.0, 1.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [1.0, 1.0, 0.0, 0.0], # HIS + [1.0, 1.0, 0.0, 0.0], # ILE + [1.0, 1.0, 0.0, 0.0], # LEU + [1.0, 1.0, 1.0, 1.0], # LYS + [1.0, 1.0, 1.0, 0.0], # MET + [1.0, 1.0, 0.0, 0.0], # PHE + [1.0, 1.0, 0.0, 0.0], # PRO + [1.0, 0.0, 0.0, 0.0], # SER + [1.0, 0.0, 0.0, 0.0], # THR + [1.0, 1.0, 0.0, 0.0], # TRP + [1.0, 1.0, 0.0, 0.0], # TYR + [1.0, 0.0, 0.0, 0.0], # VAL +] + +# The following chi angles are pi periodic: they can be rotated by a multiple +# of pi without affecting the structure. 
+chi_pi_periodic = [ + [0.0, 0.0, 0.0, 0.0], # ALA + [0.0, 0.0, 0.0, 0.0], # ARG + [0.0, 0.0, 0.0, 0.0], # ASN + [0.0, 1.0, 0.0, 0.0], # ASP + [0.0, 0.0, 0.0, 0.0], # CYS + [0.0, 0.0, 0.0, 0.0], # GLN + [0.0, 0.0, 1.0, 0.0], # GLU + [0.0, 0.0, 0.0, 0.0], # GLY + [0.0, 0.0, 0.0, 0.0], # HIS + [0.0, 0.0, 0.0, 0.0], # ILE + [0.0, 0.0, 0.0, 0.0], # LEU + [0.0, 0.0, 0.0, 0.0], # LYS + [0.0, 0.0, 0.0, 0.0], # MET + [0.0, 1.0, 0.0, 0.0], # PHE + [0.0, 0.0, 0.0, 0.0], # PRO + [0.0, 0.0, 0.0, 0.0], # SER + [0.0, 0.0, 0.0, 0.0], # THR + [0.0, 0.0, 0.0, 0.0], # TRP + [0.0, 1.0, 0.0, 0.0], # TYR + [0.0, 0.0, 0.0, 0.0], # VAL + [0.0, 0.0, 0.0, 0.0], # UNK +] + +# Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi, +# psi and chi angles: +# 0: 'backbone group', +# 1: 'pre-omega-group', (empty) +# 2: 'phi-group', (currently empty, because it defines only hydrogens) +# 3: 'psi-group', +# 4,5,6,7: 'chi1,2,3,4-group' +# The atom positions are relative to the axis-end-atom of the corresponding +# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis +# is defined such that the dihedral-angle-definiting atom (the last entry in +# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate). +# format: [atomname, group_idx, rel_position] +rigid_group_atom_positions = { + 'ALA': [ + ['N', 0, (-0.525, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.529, -0.774, -1.205)], + ['O', 3, (0.627, 1.062, 0.000)], + ], + 'ARG': [ + ['N', 0, (-0.524, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.524, -0.778, -1.209)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.616, 1.390, -0.000)], + ['CD', 5, (0.564, 1.414, 0.000)], + ['NE', 6, (0.539, 1.357, -0.000)], + ['NH1', 7, (0.206, 2.301, 0.000)], + ['NH2', 7, (2.078, 0.978, -0.000)], + ['CZ', 7, (0.758, 1.093, -0.000)], + ], + 'ASN': [ + ['N', 0, (-0.536, 1.357, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.531, -0.787, -1.200)], + ['O', 3, (0.625, 1.062, 0.000)], + ['CG', 4, (0.584, 1.399, 0.000)], + ['ND2', 5, (0.593, -1.188, 0.001)], + ['OD1', 5, (0.633, 1.059, 0.000)], + ], + 'ASP': [ + ['N', 0, (-0.525, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, 0.000, -0.000)], + ['CB', 0, (-0.526, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.593, 1.398, -0.000)], + ['OD1', 5, (0.610, 1.091, 0.000)], + ['OD2', 5, (0.592, -1.101, -0.003)], + ], + 'CYS': [ + ['N', 0, (-0.522, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, 0.000)], + ['CB', 0, (-0.519, -0.773, -1.212)], + ['O', 3, (0.625, 1.062, -0.000)], + ['SG', 4, (0.728, 1.653, 0.000)], + ], + 'GLN': [ + ['N', 0, (-0.526, 1.361, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.779, -1.207)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.615, 1.393, 0.000)], + ['CD', 5, (0.587, 1.399, -0.000)], + ['NE2', 6, (0.593, -1.189, -0.001)], + ['OE1', 6, (0.634, 1.060, 0.000)], + ], + 'GLU': [ + ['N', 0, (-0.528, 1.361, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, -0.000, -0.000)], + ['CB', 0, (-0.526, -0.781, -1.207)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG', 4, (0.615, 1.392, 0.000)], + ['CD', 5, (0.600, 1.397, 0.000)], + ['OE1', 6, (0.607, 1.095, -0.000)], + ['OE2', 6, (0.589, -1.104, -0.001)], + ], + 'GLY': [ + ['N', 0, (-0.572, 1.337, 
0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.517, -0.000, -0.000)], + ['O', 3, (0.626, 1.062, -0.000)], + ], + 'HIS': [ + ['N', 0, (-0.527, 1.360, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.525, -0.778, -1.208)], + ['O', 3, (0.625, 1.063, 0.000)], + ['CG', 4, (0.600, 1.370, -0.000)], + ['CD2', 5, (0.889, -1.021, 0.003)], + ['ND1', 5, (0.744, 1.160, -0.000)], + ['CE1', 5, (2.030, 0.851, 0.002)], + ['NE2', 5, (2.145, -0.466, 0.004)], + ], + 'ILE': [ + ['N', 0, (-0.493, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.536, -0.793, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.534, 1.437, -0.000)], + ['CG2', 4, (0.540, -0.785, -1.199)], + ['CD1', 5, (0.619, 1.391, 0.000)], + ], + 'LEU': [ + ['N', 0, (-0.520, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.773, -1.214)], + ['O', 3, (0.625, 1.063, -0.000)], + ['CG', 4, (0.678, 1.371, 0.000)], + ['CD1', 5, (0.530, 1.430, -0.000)], + ['CD2', 5, (0.535, -0.774, 1.200)], + ], + 'LYS': [ + ['N', 0, (-0.526, 1.362, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, 0.000)], + ['CB', 0, (-0.524, -0.778, -1.208)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.619, 1.390, 0.000)], + ['CD', 5, (0.559, 1.417, 0.000)], + ['CE', 6, (0.560, 1.416, 0.000)], + ['NZ', 7, (0.554, 1.387, 0.000)], + ], + 'MET': [ + ['N', 0, (-0.521, 1.364, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, 0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.210)], + ['O', 3, (0.625, 1.062, -0.000)], + ['CG', 4, (0.613, 1.391, -0.000)], + ['SD', 5, (0.703, 1.695, 0.000)], + ['CE', 6, (0.320, 1.786, -0.000)], + ], + 'PHE': [ + ['N', 0, (-0.518, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, 0.000, -0.000)], + ['CB', 0, (-0.525, -0.776, -1.212)], + ['O', 3, (0.626, 1.062, -0.000)], + ['CG', 4, (0.607, 1.377, 0.000)], + ['CD1', 5, (0.709, 1.195, -0.000)], + ['CD2', 5, (0.706, -1.196, 0.000)], + ['CE1', 5, (2.102, 1.198, -0.000)], + ['CE2', 5, (2.098, -1.201, -0.000)], + ['CZ', 5, (2.794, -0.003, -0.001)], + ], + 'PRO': [ + ['N', 0, (-0.566, 1.351, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, 0.000)], + ['CB', 0, (-0.546, -0.611, -1.293)], + ['O', 3, (0.621, 1.066, 0.000)], + ['CG', 4, (0.382, 1.445, 0.0)], + # ['CD', 5, (0.427, 1.440, 0.0)], + ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger + ], + 'SER': [ + ['N', 0, (-0.529, 1.360, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, -0.000)], + ['CB', 0, (-0.518, -0.777, -1.211)], + ['O', 3, (0.626, 1.062, -0.000)], + ['OG', 4, (0.503, 1.325, 0.000)], + ], + 'THR': [ + ['N', 0, (-0.517, 1.364, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.526, 0.000, -0.000)], + ['CB', 0, (-0.516, -0.793, -1.215)], + ['O', 3, (0.626, 1.062, 0.000)], + ['CG2', 4, (0.550, -0.718, -1.228)], + ['OG1', 4, (0.472, 1.353, 0.000)], + ], + 'TRP': [ + ['N', 0, (-0.521, 1.363, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.525, -0.000, 0.000)], + ['CB', 0, (-0.523, -0.776, -1.212)], + ['O', 3, (0.627, 1.062, 0.000)], + ['CG', 4, (0.609, 1.370, -0.000)], + ['CD1', 5, (0.824, 1.091, 0.000)], + ['CD2', 5, (0.854, -1.148, -0.005)], + ['CE2', 5, (2.186, -0.678, -0.007)], + ['CE3', 5, (0.622, -2.530, -0.007)], + ['NE1', 5, (2.140, 0.690, -0.004)], + ['CH2', 5, (3.028, -2.890, -0.013)], + ['CZ2', 5, (3.283, 
-1.543, -0.011)], + ['CZ3', 5, (1.715, -3.389, -0.011)], + ], + 'TYR': [ + ['N', 0, (-0.522, 1.362, 0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.524, -0.000, -0.000)], + ['CB', 0, (-0.522, -0.776, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG', 4, (0.607, 1.382, -0.000)], + ['CD1', 5, (0.716, 1.195, -0.000)], + ['CD2', 5, (0.713, -1.194, -0.001)], + ['CE1', 5, (2.107, 1.200, -0.002)], + ['CE2', 5, (2.104, -1.201, -0.003)], + ['OH', 5, (4.168, -0.002, -0.005)], + ['CZ', 5, (2.791, -0.001, -0.003)], + ], + 'VAL': [ + ['N', 0, (-0.494, 1.373, -0.000)], + ['CA', 0, (0.000, 0.000, 0.000)], + ['C', 0, (1.527, -0.000, -0.000)], + ['CB', 0, (-0.533, -0.795, -1.213)], + ['O', 3, (0.627, 1.062, -0.000)], + ['CG1', 4, (0.540, 1.429, -0.000)], + ['CG2', 4, (0.533, -0.776, 1.203)], + ], +} + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +residue_atoms = { + 'ALA': ['C', 'CA', 'CB', 'N', 'O'], + 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'], + 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'], + 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'], + 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'], + 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'], + 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'], + 'GLY': ['C', 'CA', 'N', 'O'], + 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'], + 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'], + 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'], + 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'], + 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'], + 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'], + 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'], + 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'], + 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'], + 'TRP': [ + 'C', + 'CA', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE2', + 'CE3', + 'CZ2', + 'CZ3', + 'CH2', + 'N', + 'NE1', + 'O', + ], + 'TYR': + ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O', 'OH'], + 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O'], +} + +# Naming swaps for ambiguous atom names. +# Due to symmetries in the amino acids the naming of atoms is ambiguous in +# 4 of the 20 amino acids. +# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities +# in LEU, VAL and ARG can be resolved by using the 3d constellations of +# the 'ambiguous' atoms and their neighbours) +residue_atom_renaming_swaps = { + 'ASP': { + 'OD1': 'OD2' + }, + 'GLU': { + 'OE1': 'OE2' + }, + 'PHE': { + 'CD1': 'CD2', + 'CE1': 'CE2' + }, + 'TYR': { + 'CD1': 'CD2', + 'CE1': 'CE2' + }, +} + +# Van der Waals radii [Angstroem] of the atoms (from Wikipedia) +van_der_waals_radius = { + 'C': 1.7, + 'N': 1.55, + 'O': 1.52, + 'S': 1.8, +} + +Bond = collections.namedtuple('Bond', + ['atom1_name', 'atom2_name', 'length', 'stddev']) +BondAngle = collections.namedtuple( + 'BondAngle', + ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev']) + + +@functools.lru_cache(maxsize=None) +# def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]], Mapping[ #noqa +# str, List[Bond]], Mapping[str, List[BondAngle]]]: +def load_stereo_chemical_props(): + """Load stereo_chemical_props.txt into a nice structure. + + Load literature values for bond lengths and bond angles and translate + bond angles into the length of the opposite edge of the triangle + ("residue_virtual_bonds"). 
+ + Returns: + residue_bonds: Dict that maps resname -> list of Bond tuples. + residue_virtual_bonds: Dict that maps resname -> list of Bond tuples. + residue_bond_angles: Dict that maps resname -> list of BondAngle tuples. + """ + stereo_chemical_props_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + 'stereo_chemical_props.txt') + with open(stereo_chemical_props_path, 'rt') as f: + stereo_chemical_props = f.read() + lines_iter = iter(stereo_chemical_props.splitlines()) + # Load bond lengths. + residue_bonds = {} + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, length, stddev = line.split() + atom1, atom2 = bond.split('-') + if resname not in residue_bonds: + residue_bonds[resname] = [] + residue_bonds[resname].append( + Bond(atom1, atom2, float(length), float(stddev))) + residue_bonds['UNK'] = [] + + # Load bond angles. + residue_bond_angles = {} + next(lines_iter) # Skip empty line. + next(lines_iter) # Skip header line. + for line in lines_iter: + if line.strip() == '-': + break + bond, resname, angle_degree, stddev_degree = line.split() + atom1, atom2, atom3 = bond.split('-') + if resname not in residue_bond_angles: + residue_bond_angles[resname] = [] + residue_bond_angles[resname].append( + BondAngle( + atom1, + atom2, + atom3, + float(angle_degree) / 180.0 * np.pi, + float(stddev_degree) / 180.0 * np.pi, + )) + residue_bond_angles['UNK'] = [] + + def make_bond_key(atom1_name, atom2_name): + """Unique key to lookup bonds.""" + return '-'.join(sorted([atom1_name, atom2_name])) + + # Translate bond angles into distances ("virtual bonds"). + residue_virtual_bonds = {} + for resname, bond_angles in residue_bond_angles.items(): + # Create a fast lookup dict for bond lengths. + bond_cache = {} + for b in residue_bonds[resname]: + bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b + residue_virtual_bonds[resname] = [] + for ba in bond_angles: + bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)] + bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)] + + # Compute distance between atom1 and atom3 using the law of cosines + # c^2 = a^2 + b^2 - 2ab*cos(gamma). + gamma = ba.angle_rad + length = np.sqrt(bond1.length**2 + bond2.length**2 + - 2 * bond1.length * bond2.length * np.cos(gamma)) + + # Propagation of uncertainty assuming uncorrelated errors. + dl_outer = 0.5 / length + dl_dgamma = (2 * bond1.length * bond2.length + * np.sin(gamma)) * dl_outer + dl_db1 = (2 * bond1.length + - 2 * bond2.length * np.cos(gamma)) * dl_outer + dl_db2 = (2 * bond2.length + - 2 * bond1.length * np.cos(gamma)) * dl_outer + stddev = np.sqrt((dl_dgamma * ba.stddev)**2 + + (dl_db1 * bond1.stddev)**2 + + (dl_db2 * bond2.stddev)**2) + residue_virtual_bonds[resname].append( + Bond(ba.atom1_name, ba.atom3name, length, stddev)) + + return (residue_bonds, residue_virtual_bonds, residue_bond_angles) + + +# Between-residue bond lengths for general bonds (first element) and for Proline +# (second element). +between_res_bond_length_c_n = [1.329, 1.341] +between_res_bond_length_stddev_c_n = [0.014, 0.016] + +# Between-residue cos_angles. +between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315 +between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995 + +# This mapping is used when we need to store atom data in a format that requires +# fixed atom data size for every residue (e.g. a numpy array). 
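# --- Editorial sketch (not part of the patch): the law-of-cosines step above. ---
# `load_stereo_chemical_props` turns each tabulated bond angle into a "virtual bond"
# between its outer atoms. With illustrative numbers (N-CA = 1.459 A, CA-C = 1.525 A,
# N-CA-C = 111.0 deg) the implied N...C distance is roughly 2.46 A.
import numpy as np

a, b, gamma = 1.459, 1.525, np.deg2rad(111.0)
virtual_length = np.sqrt(a**2 + b**2 - 2.0 * a * b * np.cos(gamma))
print(round(float(virtual_length), 3))   # ~2.459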
+atom_types = [ + 'N', + 'CA', + 'C', + 'CB', + 'O', + 'CG', + 'CG1', + 'CG2', + 'OG', + 'OG1', + 'SG', + 'CD', + 'CD1', + 'CD2', + 'ND1', + 'ND2', + 'OD1', + 'OD2', + 'SD', + 'CE', + 'CE1', + 'CE2', + 'CE3', + 'NE', + 'NE1', + 'NE2', + 'OE1', + 'OE2', + 'CH2', + 'NH1', + 'NH2', + 'OH', + 'CZ', + 'CZ2', + 'CZ3', + 'NZ', + 'OXT', +] +atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)} +atom_type_num = len(atom_types) # := 37. + +# A compact atom encoding with 14 columns +# pylint: disable=line-too-long +# pylint: disable=bad-whitespace +restype_name_to_atom14_names = { + 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''], + 'ARG': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD', + 'NE', + 'CZ', + 'NH1', + 'NH2', + '', + '', + '', + ], + 'ASN': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''], + 'ASP': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''], + 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''], + 'GLN': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''], + 'GLU': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''], + 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''], + 'HIS': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'ND1', + 'CD2', + 'CE1', + 'NE2', + '', + '', + '', + '', + ], + 'ILE': + ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''], + 'LEU': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''], + 'LYS': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''], + 'MET': + ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''], + 'PHE': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE1', + 'CE2', + 'CZ', + '', + '', + '', + ], + 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''], + 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''], + 'THR': + ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''], + 'TRP': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'NE1', + 'CE2', + 'CE3', + 'CZ2', + 'CZ3', + 'CH2', + ], + 'TYR': [ + 'N', + 'CA', + 'C', + 'O', + 'CB', + 'CG', + 'CD1', + 'CD2', + 'CE1', + 'CE2', + 'CZ', + 'OH', + '', + '', + ], + 'VAL': + ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''], + 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''], +} +# pylint: enable=line-too-long +# pylint: enable=bad-whitespace + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +restypes = [ + 'A', + 'R', + 'N', + 'D', + 'C', + 'Q', + 'E', + 'G', + 'H', + 'I', + 'L', + 'K', + 'M', + 'F', + 'P', + 'S', + 'T', + 'W', + 'Y', + 'V', +] +restype_order = {restype: i for i, restype in enumerate(restypes)} +restype_num = len(restypes) # := 20. +unk_restype_index = restype_num # Catch-all index for unknown restypes. + +restypes_with_x = restypes + ['X'] +restype_order_with_x = { + restype: i + for i, restype in enumerate(restypes_with_x) +} + + +def sequence_to_onehot(sequence: str, + mapping: Mapping[str, int], + map_unknown_to_x: bool = False) -> np.ndarray: + """Maps the given sequence into a one-hot encoded matrix. + + Args: + sequence: An amino acid sequence. + mapping: A dictionary mapping amino acids to integers. 
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be + mapped to the unknown amino acid 'X'. If the mapping doesn't contain + amino acid 'X', an error will be thrown. If False, any amino acid not in + the mapping will throw an error. + + Returns: + A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of + the sequence. + + Raises: + ValueError: If the mapping doesn't contain values from 0 to + num_unique_aas - 1 without any gaps. + """ + num_entries = max(mapping.values()) + 1 + + if sorted(set(mapping.values())) != list(range(num_entries)): + raise ValueError( + 'The mapping must have values from 0 to num_unique_aas-1 ' + 'without any gaps. Got: %s' % sorted(mapping.values())) + + one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32) + + for aa_index, aa_type in enumerate(sequence): + if map_unknown_to_x: + if aa_type.isalpha() and aa_type.isupper(): + aa_id = mapping.get(aa_type, mapping['X']) + else: + raise ValueError( + f'Invalid character in the sequence: {aa_type}') + else: + aa_id = mapping[aa_type] + one_hot_arr[aa_index, aa_id] = 1 + + return one_hot_arr + + +restype_1to3 = { + 'A': 'ALA', + 'R': 'ARG', + 'N': 'ASN', + 'D': 'ASP', + 'C': 'CYS', + 'Q': 'GLN', + 'E': 'GLU', + 'G': 'GLY', + 'H': 'HIS', + 'I': 'ILE', + 'L': 'LEU', + 'K': 'LYS', + 'M': 'MET', + 'F': 'PHE', + 'P': 'PRO', + 'S': 'SER', + 'T': 'THR', + 'W': 'TRP', + 'Y': 'TYR', + 'V': 'VAL', +} + +# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple +# 1-to-1 mapping of 3 letter names to one letter names. The latter contains +# many more, and less common, three letter names as keys and maps many of these +# to the same one letter name (including 'X' and 'U' which we don't use here). +restype_3to1 = {v: k for k, v in restype_1to3.items()} + +# Define a restype name for all unknown residues. +unk_restype = 'UNK' + +resnames = [restype_1to3[r] for r in restypes] + [unk_restype] +resname_to_idx = {resname: i for i, resname in enumerate(resnames)} + +# The mapping here uses hhblits convention, so that B is mapped to D, J and O +# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the +# remaining 20 amino acids are kept in alphabetical order. +# There are 2 non-amino acid codes, X (representing any amino acid) and +# "-" representing a missing amino acid in an alignment. The id for these +# codes is put at the end (20 and 21) so that they can easily be ignored if +# desired. +HHBLITS_AA_TO_ID = { + 'A': 0, + 'B': 2, + 'C': 1, + 'D': 2, + 'E': 3, + 'F': 4, + 'G': 5, + 'H': 6, + 'I': 7, + 'J': 20, + 'K': 8, + 'L': 9, + 'M': 10, + 'N': 11, + 'O': 20, + 'P': 12, + 'Q': 13, + 'R': 14, + 'S': 15, + 'T': 16, + 'U': 1, + 'V': 17, + 'W': 18, + 'X': 20, + 'Y': 19, + 'Z': 3, + '-': 21, +} + +# Partial inversion of HHBLITS_AA_TO_ID. +ID_TO_HHBLITS_AA = { + 0: 'A', + 1: 'C', # Also U. + 2: 'D', # Also B. + 3: 'E', # Also Z. + 4: 'F', + 5: 'G', + 6: 'H', + 7: 'I', + 8: 'K', + 9: 'L', + 10: 'M', + 11: 'N', + 12: 'P', + 13: 'Q', + 14: 'R', + 15: 'S', + 16: 'T', + 17: 'V', + 18: 'W', + 19: 'Y', + 20: 'X', # Includes J and O. + 21: '-', +} + +restypes_with_x_and_gap = restypes + ['X', '-'] +MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple( + restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i]) + for i in range(len(restypes_with_x_and_gap))) + + +def _make_standard_atom_mask() -> np.ndarray: + """Returns [num_res_types, num_atom_types] mask array.""" + # +1 to account for unknown (all 0s). 
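# --- Editorial sketch (not part of the patch): using `sequence_to_onehot` above. ---
# With `restype_order_with_x` as the mapping and map_unknown_to_x=True, non-standard
# residue letters fall back to the 'X' column instead of raising.
one_hot = sequence_to_onehot('ACDB', restype_order_with_x, map_unknown_to_x=True)
assert one_hot.shape == (4, 21)
assert one_hot[0, restype_order_with_x['A']] == 1
assert one_hot[3, restype_order_with_x['X']] == 1   # 'B' is not a standard restype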
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=np.int32) + for restype, restype_letter in enumerate(restypes): + restype_name = restype_1to3[restype_letter] + atom_names = residue_atoms[restype_name] + for atom_name in atom_names: + atom_type = atom_order[atom_name] + mask[restype, atom_type] = 1 + return mask + + +STANDARD_ATOM_MASK = _make_standard_atom_mask() + + +# A one hot representation for the first and second atoms defining the axis +# of rotation for each chi-angle in each residue. +def chi_angle_atom(atom_index: int) -> np.ndarray: + """Define chi-angle rigid groups via one-hot representations.""" + chi_angles_index = {} + one_hots = [] + + for k, v in chi_angles_atoms.items(): + indices = [atom_types.index(s[atom_index]) for s in v] + indices.extend([-1] * (4 - len(indices))) + chi_angles_index[k] = indices + + for r in restypes: + res3 = restype_1to3[r] + one_hot = np.eye(atom_type_num)[chi_angles_index[res3]] + one_hots.append(one_hot) + + one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`. + one_hot = np.stack(one_hots, axis=0) + one_hot = np.transpose(one_hot, [0, 2, 1]) + + return one_hot + + +chi_atom_1_one_hot = chi_angle_atom(1) +chi_atom_2_one_hot = chi_angle_atom(2) + +# An array like chi_angles_atoms but using indices rather than names. +chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes] +chi_angles_atom_indices = tree_map( + lambda n: atom_order[n], chi_angles_atom_indices, leaf_type=str) +chi_angles_atom_indices = np.array([ + chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms))) + for chi_atoms in chi_angles_atom_indices +]) + +# Mapping from (res_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. +chi_groups_for_atom = collections.defaultdict(list) +for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i)) +chi_groups_for_atom = dict(chi_groups_for_atom) + + +def _make_rigid_transformation_4x4(ex, ey, translation): + """Create a rigid 4x4 transformation matrix from two axes and transl.""" + # Normalize ex. 
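# --- Editorial sketch (not part of the patch): a sanity check on STANDARD_ATOM_MASK. ---
# Each row holds the atom37 occupancy expected for one residue type, so the row sums
# are simply the heavy-atom counts (this is the array `ideal_atom_mask` indexes into).
heavy_atom_counts = STANDARD_ATOM_MASK.sum(axis=-1)
assert heavy_atom_counts[restype_order['A']] == 5    # ALA: N, CA, C, O, CB
assert heavy_atom_counts[restype_order['W']] == 14   # TRP has the largest side chain
assert heavy_atom_counts[restype_num] == 0           # trailing UNK row is all zeros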
+ ex_normalized = ex / np.linalg.norm(ex) + + # make ey perpendicular to ex + ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized + ey_normalized /= np.linalg.norm(ey_normalized) + + # compute ez as cross product + eznorm = np.cross(ex_normalized, ey_normalized) + m = np.stack([ex_normalized, ey_normalized, eznorm, + translation]).transpose() + m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0) + return m + + +# create an array with (restype, atomtype) --> rigid_group_idx +# and an array with (restype, atomtype, coord) for the atom positions +# and compute affine transformation matrices (4,4) from one rigid group to the +# previous group +restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=np.int_) +restype_atom37_mask = np.zeros([21, 37], dtype=np.float32) +restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32) +restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=np.int_) +restype_atom14_mask = np.zeros([21, 14], dtype=np.float32) +restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32) +restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32) + + +def _make_rigid_group_constants(): + """Fill the arrays above.""" + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + for atomname, group_idx, atom_position in rigid_group_atom_positions[ + resname]: + atomtype = atom_order[atomname] + restype_atom37_to_rigid_group[restype, atomtype] = group_idx + restype_atom37_mask[restype, atomtype] = 1 + restype_atom37_rigid_group_positions[restype, + atomtype, :] = atom_position + + atom14idx = restype_name_to_atom14_names[resname].index(atomname) + restype_atom14_to_rigid_group[restype, atom14idx] = group_idx + restype_atom14_mask[restype, atom14idx] = 1 + restype_atom14_rigid_group_positions[restype, + atom14idx, :] = atom_position + + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_positions = { + name: np.array(pos) + for name, _, pos in rigid_group_atom_positions[resname] + } + + # backbone to backbone is the identity transform + restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4) + + # pre-omega-frame to backbone (currently dummy identity matrix) + restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4) + + # phi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['N'] - atom_positions['CA'], + ey=np.array([1.0, 0.0, 0.0]), + translation=atom_positions['N'], + ) + restype_rigid_group_default_frame[restype, 2, :, :] = mat + + # psi-frame to backbone + mat = _make_rigid_transformation_4x4( + ex=atom_positions['C'] - atom_positions['CA'], + ey=atom_positions['CA'] - atom_positions['N'], + translation=atom_positions['C'], + ) + restype_rigid_group_default_frame[restype, 3, :, :] = mat + + # chi1-frame to backbone + if chi_angles_mask[restype][0]: + base_atom_names = chi_angles_atoms[resname][0] + base_atom_positions = [ + atom_positions[name] for name in base_atom_names + ] + mat = _make_rigid_transformation_4x4( + ex=base_atom_positions[2] - base_atom_positions[1], + ey=base_atom_positions[0] - base_atom_positions[1], + translation=base_atom_positions[2], + ) + restype_rigid_group_default_frame[restype, 4, :, :] = mat + + # chi2-frame to chi1-frame + # chi3-frame to chi2-frame + # chi4-frame to chi3-frame + # luckily all rotation axes for the next frame start at (0,0,0) of the + # previous frame + for chi_idx in range(1, 4): + if chi_angles_mask[restype][chi_idx]: + 
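# --- Editorial sketch (not part of the patch): `_make_rigid_transformation_4x4` above. ---
# The helper Gram-Schmidt-orthogonalises (ex, ey), completes the frame with a cross
# product, and packs rotation plus translation into a homogeneous 4x4 matrix.
import numpy as np

m = _make_rigid_transformation_4x4(
    ex=np.array([1.0, 1.0, 0.0]),
    ey=np.array([0.0, 1.0, 0.0]),
    translation=np.array([1.0, 2.0, 3.0]),
)
rot, trans = m[:3, :3], m[:3, 3]
assert np.allclose(rot @ rot.T, np.eye(3), atol=1e-6)   # orthonormal rotation
assert np.allclose(trans, [1.0, 2.0, 3.0])
assert np.allclose(m[3], [0.0, 0.0, 0.0, 1.0])          # homogeneous last row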
axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2] + axis_end_atom_position = atom_positions[axis_end_atom_name] + mat = _make_rigid_transformation_4x4( + ex=axis_end_atom_position, + ey=np.array([-1.0, 0.0, 0.0]), + translation=axis_end_atom_position, + ) + restype_rigid_group_default_frame[restype, + 4 + chi_idx, :, :] = mat + + +_make_rigid_group_constants() + + +def make_atom14_dists_bounds(overlap_tolerance=1.5, + bond_length_tolerance_factor=15): + """compute upper and lower bounds for bonds to assess violations.""" + restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32) + restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32) + residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props() + for restype, restype_letter in enumerate(restypes): + resname = restype_1to3[restype_letter] + atom_list = restype_name_to_atom14_names[resname] + + # create lower and upper bounds for clashes + for atom1_idx, atom1_name in enumerate(atom_list): + if not atom1_name: + continue + atom1_radius = van_der_waals_radius[atom1_name[0]] + for atom2_idx, atom2_name in enumerate(atom_list): + if (not atom2_name) or atom1_idx == atom2_idx: + continue + atom2_radius = van_der_waals_radius[atom2_name[0]] + lower = atom1_radius + atom2_radius - overlap_tolerance + upper = 1e10 + restype_atom14_bond_lower_bound[restype, atom1_idx, + atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, + atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, + atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, + atom1_idx] = upper + + # overwrite lower and upper bounds for bonds and angles + for b in residue_bonds[resname] + residue_virtual_bonds[resname]: + atom1_idx = atom_list.index(b.atom1_name) + atom2_idx = atom_list.index(b.atom2_name) + lower = b.length - bond_length_tolerance_factor * b.stddev + upper = b.length + bond_length_tolerance_factor * b.stddev + restype_atom14_bond_lower_bound[restype, atom1_idx, + atom2_idx] = lower + restype_atom14_bond_lower_bound[restype, atom2_idx, + atom1_idx] = lower + restype_atom14_bond_upper_bound[restype, atom1_idx, + atom2_idx] = upper + restype_atom14_bond_upper_bound[restype, atom2_idx, + atom1_idx] = upper + restype_atom14_bond_stddev[restype, atom1_idx, + atom2_idx] = b.stddev + restype_atom14_bond_stddev[restype, atom2_idx, + atom1_idx] = b.stddev + return { + 'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14) + 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14) + 'stddev': restype_atom14_bond_stddev, # shape (21,14,14) + } + + +def _make_atom14_and_atom37_constants(): + restype_atom14_to_atom37 = [] + restype_atom37_to_atom14 = [] + restype_atom14_mask = [] + + for rt in restypes: + atom_names = restype_name_to_atom14_names[restype_1to3[rt]] + restype_atom14_to_atom37.append([(atom_order[name] if name else 0) + for name in atom_names]) + atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} + restype_atom37_to_atom14.append([ + (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) + for name in atom_types + ]) + + restype_atom14_mask.append([(1.0 if name else 0.0) + for name in atom_names]) + + # Add dummy mapping for restype 'UNK' + restype_atom14_to_atom37.append([0] * 14) + restype_atom37_to_atom14.append([0] * 37) + restype_atom14_mask.append([0.0] * 14) + + restype_atom14_to_atom37 = np.array( + restype_atom14_to_atom37, dtype=np.int32) + 
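# --- Editorial sketch (not part of the patch): using `make_atom14_dists_bounds` above. ---
# The bounds come from stereo_chemical_props.txt; e.g. the ALA N-CA bond (mean 1.459 A,
# stddev 0.020 A) gets a window of +/- 15 standard deviations in atom14 slots 0 and 1.
bounds = make_atom14_dists_bounds(overlap_tolerance=1.5, bond_length_tolerance_factor=15)
assert bounds['lower_bound'].shape == (21, 14, 14)
ala = restype_order['A']
n_idx = restype_name_to_atom14_names['ALA'].index('N')    # 0
ca_idx = restype_name_to_atom14_names['ALA'].index('CA')  # 1
print(bounds['lower_bound'][ala, n_idx, ca_idx],          # ~1.159
      bounds['upper_bound'][ala, n_idx, ca_idx])          # ~1.759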
restype_atom37_to_atom14 = np.array( + restype_atom37_to_atom14, dtype=np.int32) + restype_atom14_mask = np.array(restype_atom14_mask, dtype=np.float32) + + return restype_atom14_to_atom37, restype_atom37_to_atom14, restype_atom14_mask + + +( + restype_atom14_to_atom37, + restype_atom37_to_atom14, + restype_atom14_mask, +) = _make_atom14_and_atom37_constants() + + +def _make_renaming_matrices(): + # As the atom naming is ambiguous for 7 of the 20 amino acids, provide + # alternative ground truth coordinates where the naming is swapped + restype_3 = [restype_1to3[res] for res in restypes] + restype_3 += ['UNK'] + + # Matrices for renaming ambiguous atoms. + all_matrices = {res: np.eye(14) for res in restype_3} + for resname, swap in residue_atom_renaming_swaps.items(): + correspondences = np.arange(14) + for source_atom_swap, target_atom_swap in swap.items(): + source_index = restype_name_to_atom14_names[resname].index( + source_atom_swap) + target_index = restype_name_to_atom14_names[resname].index( + target_atom_swap) + correspondences[source_index] = target_index + correspondences[target_index] = source_index + renaming_matrix = np.zeros((14, 14)) + for index, correspondence in enumerate(correspondences): + renaming_matrix[index, correspondence] = 1.0 + all_matrices[resname] = renaming_matrix + renaming_matrices = np.stack( + [all_matrices[restype] for restype in restype_3]) + return renaming_matrices + + +renaming_matrices = _make_renaming_matrices() + + +def _make_atom14_is_ambiguous(): + # Create an ambiguous atoms mask. shape: (21, 14). + restype_atom14_is_ambiguous = np.zeros((21, 14)) + for resname, swap in residue_atom_renaming_swaps.items(): + for atom_name1, atom_name2 in swap.items(): + restype = restype_order[restype_3to1[resname]] + atom_idx1 = restype_name_to_atom14_names[resname].index(atom_name1) + atom_idx2 = restype_name_to_atom14_names[resname].index(atom_name2) + restype_atom14_is_ambiguous[restype, atom_idx1] = 1 + restype_atom14_is_ambiguous[restype, atom_idx2] = 1 + return restype_atom14_is_ambiguous + + +restype_atom14_is_ambiguous = _make_atom14_is_ambiguous() + + +def get_chi_atom_indices(): + """Returns atom indices needed to compute chi angles for all residue types. + + Returns: + A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are + in the order specified in restypes + unknown residue type + at the end. For chi angles which are not defined on the residue, the + positions indices are by default set to 0. + """ + chi_atom_indices = [] + for residue_name in restypes: + residue_name = restype_1to3[residue_name] + residue_chi_angles = chi_angles_atoms[residue_name] + atom_indices = [] + for chi_angle in residue_chi_angles: + atom_indices.append([atom_order[atom] for atom in chi_angle]) + for _ in range(4 - len(atom_indices)): + atom_indices.append([0, 0, 0, + 0]) # For chi angles not defined on the AA. + chi_atom_indices.append(atom_indices) + + chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue. 
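# --- Editorial sketch (not part of the patch): the ambiguous-atom renaming above. ---
# For the residues with symmetric naming, `renaming_matrices` is a 14x14 permutation
# that swaps the equivalent atom14 slots; every other row stays the identity.
asp14 = restype_name_to_atom14_names['ASP']
od1, od2 = asp14.index('OD1'), asp14.index('OD2')
m = renaming_matrices[resname_to_idx['ASP']]
assert m[od1, od2] == 1.0 and m[od2, od1] == 1.0         # OD1 <-> OD2 swap
assert m[asp14.index('CA'), asp14.index('CA')] == 1.0    # unambiguous atoms unchanged
assert restype_atom14_is_ambiguous[resname_to_idx['ASP'], od1] == 1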
+ + return chi_atom_indices + + +chi_atom_indices = get_chi_atom_indices() diff --git a/modelscope/models/science/unifold/data/stereo_chemical_props.txt b/modelscope/models/science/unifold/data/stereo_chemical_props.txt new file mode 100644 index 00000000..25262efd --- /dev/null +++ b/modelscope/models/science/unifold/data/stereo_chemical_props.txt @@ -0,0 +1,345 @@ +Bond Residue Mean StdDev +CA-CB ALA 1.520 0.021 +N-CA ALA 1.459 0.020 +CA-C ALA 1.525 0.026 +C-O ALA 1.229 0.019 +CA-CB ARG 1.535 0.022 +CB-CG ARG 1.521 0.027 +CG-CD ARG 1.515 0.025 +CD-NE ARG 1.460 0.017 +NE-CZ ARG 1.326 0.013 +CZ-NH1 ARG 1.326 0.013 +CZ-NH2 ARG 1.326 0.013 +N-CA ARG 1.459 0.020 +CA-C ARG 1.525 0.026 +C-O ARG 1.229 0.019 +CA-CB ASN 1.527 0.026 +CB-CG ASN 1.506 0.023 +CG-OD1 ASN 1.235 0.022 +CG-ND2 ASN 1.324 0.025 +N-CA ASN 1.459 0.020 +CA-C ASN 1.525 0.026 +C-O ASN 1.229 0.019 +CA-CB ASP 1.535 0.022 +CB-CG ASP 1.513 0.021 +CG-OD1 ASP 1.249 0.023 +CG-OD2 ASP 1.249 0.023 +N-CA ASP 1.459 0.020 +CA-C ASP 1.525 0.026 +C-O ASP 1.229 0.019 +CA-CB CYS 1.526 0.013 +CB-SG CYS 1.812 0.016 +N-CA CYS 1.459 0.020 +CA-C CYS 1.525 0.026 +C-O CYS 1.229 0.019 +CA-CB GLU 1.535 0.022 +CB-CG GLU 1.517 0.019 +CG-CD GLU 1.515 0.015 +CD-OE1 GLU 1.252 0.011 +CD-OE2 GLU 1.252 0.011 +N-CA GLU 1.459 0.020 +CA-C GLU 1.525 0.026 +C-O GLU 1.229 0.019 +CA-CB GLN 1.535 0.022 +CB-CG GLN 1.521 0.027 +CG-CD GLN 1.506 0.023 +CD-OE1 GLN 1.235 0.022 +CD-NE2 GLN 1.324 0.025 +N-CA GLN 1.459 0.020 +CA-C GLN 1.525 0.026 +C-O GLN 1.229 0.019 +N-CA GLY 1.456 0.015 +CA-C GLY 1.514 0.016 +C-O GLY 1.232 0.016 +CA-CB HIS 1.535 0.022 +CB-CG HIS 1.492 0.016 +CG-ND1 HIS 1.369 0.015 +CG-CD2 HIS 1.353 0.017 +ND1-CE1 HIS 1.343 0.025 +CD2-NE2 HIS 1.415 0.021 +CE1-NE2 HIS 1.322 0.023 +N-CA HIS 1.459 0.020 +CA-C HIS 1.525 0.026 +C-O HIS 1.229 0.019 +CA-CB ILE 1.544 0.023 +CB-CG1 ILE 1.536 0.028 +CB-CG2 ILE 1.524 0.031 +CG1-CD1 ILE 1.500 0.069 +N-CA ILE 1.459 0.020 +CA-C ILE 1.525 0.026 +C-O ILE 1.229 0.019 +CA-CB LEU 1.533 0.023 +CB-CG LEU 1.521 0.029 +CG-CD1 LEU 1.514 0.037 +CG-CD2 LEU 1.514 0.037 +N-CA LEU 1.459 0.020 +CA-C LEU 1.525 0.026 +C-O LEU 1.229 0.019 +CA-CB LYS 1.535 0.022 +CB-CG LYS 1.521 0.027 +CG-CD LYS 1.520 0.034 +CD-CE LYS 1.508 0.025 +CE-NZ LYS 1.486 0.025 +N-CA LYS 1.459 0.020 +CA-C LYS 1.525 0.026 +C-O LYS 1.229 0.019 +CA-CB MET 1.535 0.022 +CB-CG MET 1.509 0.032 +CG-SD MET 1.807 0.026 +SD-CE MET 1.774 0.056 +N-CA MET 1.459 0.020 +CA-C MET 1.525 0.026 +C-O MET 1.229 0.019 +CA-CB PHE 1.535 0.022 +CB-CG PHE 1.509 0.017 +CG-CD1 PHE 1.383 0.015 +CG-CD2 PHE 1.383 0.015 +CD1-CE1 PHE 1.388 0.020 +CD2-CE2 PHE 1.388 0.020 +CE1-CZ PHE 1.369 0.019 +CE2-CZ PHE 1.369 0.019 +N-CA PHE 1.459 0.020 +CA-C PHE 1.525 0.026 +C-O PHE 1.229 0.019 +CA-CB PRO 1.531 0.020 +CB-CG PRO 1.495 0.050 +CG-CD PRO 1.502 0.033 +CD-N PRO 1.474 0.014 +N-CA PRO 1.468 0.017 +CA-C PRO 1.524 0.020 +C-O PRO 1.228 0.020 +CA-CB SER 1.525 0.015 +CB-OG SER 1.418 0.013 +N-CA SER 1.459 0.020 +CA-C SER 1.525 0.026 +C-O SER 1.229 0.019 +CA-CB THR 1.529 0.026 +CB-OG1 THR 1.428 0.020 +CB-CG2 THR 1.519 0.033 +N-CA THR 1.459 0.020 +CA-C THR 1.525 0.026 +C-O THR 1.229 0.019 +CA-CB TRP 1.535 0.022 +CB-CG TRP 1.498 0.018 +CG-CD1 TRP 1.363 0.014 +CG-CD2 TRP 1.432 0.017 +CD1-NE1 TRP 1.375 0.017 +NE1-CE2 TRP 1.371 0.013 +CD2-CE2 TRP 1.409 0.012 +CD2-CE3 TRP 1.399 0.015 +CE2-CZ2 TRP 1.393 0.017 +CE3-CZ3 TRP 1.380 0.017 +CZ2-CH2 TRP 1.369 0.019 +CZ3-CH2 TRP 1.396 0.016 +N-CA TRP 1.459 0.020 +CA-C TRP 1.525 0.026 +C-O TRP 1.229 0.019 +CA-CB TYR 1.535 0.022 +CB-CG TYR 1.512 0.015 +CG-CD1 TYR 1.387 0.013 
+CG-CD2 TYR 1.387 0.013 +CD1-CE1 TYR 1.389 0.015 +CD2-CE2 TYR 1.389 0.015 +CE1-CZ TYR 1.381 0.013 +CE2-CZ TYR 1.381 0.013 +CZ-OH TYR 1.374 0.017 +N-CA TYR 1.459 0.020 +CA-C TYR 1.525 0.026 +C-O TYR 1.229 0.019 +CA-CB VAL 1.543 0.021 +CB-CG1 VAL 1.524 0.021 +CB-CG2 VAL 1.524 0.021 +N-CA VAL 1.459 0.020 +CA-C VAL 1.525 0.026 +C-O VAL 1.229 0.019 +- + +Angle Residue Mean StdDev +N-CA-CB ALA 110.1 1.4 +CB-CA-C ALA 110.1 1.5 +N-CA-C ALA 111.0 2.7 +CA-C-O ALA 120.1 2.1 +N-CA-CB ARG 110.6 1.8 +CB-CA-C ARG 110.4 2.0 +CA-CB-CG ARG 113.4 2.2 +CB-CG-CD ARG 111.6 2.6 +CG-CD-NE ARG 111.8 2.1 +CD-NE-CZ ARG 123.6 1.4 +NE-CZ-NH1 ARG 120.3 0.5 +NE-CZ-NH2 ARG 120.3 0.5 +NH1-CZ-NH2 ARG 119.4 1.1 +N-CA-C ARG 111.0 2.7 +CA-C-O ARG 120.1 2.1 +N-CA-CB ASN 110.6 1.8 +CB-CA-C ASN 110.4 2.0 +CA-CB-CG ASN 113.4 2.2 +CB-CG-ND2 ASN 116.7 2.4 +CB-CG-OD1 ASN 121.6 2.0 +ND2-CG-OD1 ASN 121.9 2.3 +N-CA-C ASN 111.0 2.7 +CA-C-O ASN 120.1 2.1 +N-CA-CB ASP 110.6 1.8 +CB-CA-C ASP 110.4 2.0 +CA-CB-CG ASP 113.4 2.2 +CB-CG-OD1 ASP 118.3 0.9 +CB-CG-OD2 ASP 118.3 0.9 +OD1-CG-OD2 ASP 123.3 1.9 +N-CA-C ASP 111.0 2.7 +CA-C-O ASP 120.1 2.1 +N-CA-CB CYS 110.8 1.5 +CB-CA-C CYS 111.5 1.2 +CA-CB-SG CYS 114.2 1.1 +N-CA-C CYS 111.0 2.7 +CA-C-O CYS 120.1 2.1 +N-CA-CB GLU 110.6 1.8 +CB-CA-C GLU 110.4 2.0 +CA-CB-CG GLU 113.4 2.2 +CB-CG-CD GLU 114.2 2.7 +CG-CD-OE1 GLU 118.3 2.0 +CG-CD-OE2 GLU 118.3 2.0 +OE1-CD-OE2 GLU 123.3 1.2 +N-CA-C GLU 111.0 2.7 +CA-C-O GLU 120.1 2.1 +N-CA-CB GLN 110.6 1.8 +CB-CA-C GLN 110.4 2.0 +CA-CB-CG GLN 113.4 2.2 +CB-CG-CD GLN 111.6 2.6 +CG-CD-OE1 GLN 121.6 2.0 +CG-CD-NE2 GLN 116.7 2.4 +OE1-CD-NE2 GLN 121.9 2.3 +N-CA-C GLN 111.0 2.7 +CA-C-O GLN 120.1 2.1 +N-CA-C GLY 113.1 2.5 +CA-C-O GLY 120.6 1.8 +N-CA-CB HIS 110.6 1.8 +CB-CA-C HIS 110.4 2.0 +CA-CB-CG HIS 113.6 1.7 +CB-CG-ND1 HIS 123.2 2.5 +CB-CG-CD2 HIS 130.8 3.1 +CG-ND1-CE1 HIS 108.2 1.4 +ND1-CE1-NE2 HIS 109.9 2.2 +CE1-NE2-CD2 HIS 106.6 2.5 +NE2-CD2-CG HIS 109.2 1.9 +CD2-CG-ND1 HIS 106.0 1.4 +N-CA-C HIS 111.0 2.7 +CA-C-O HIS 120.1 2.1 +N-CA-CB ILE 110.8 2.3 +CB-CA-C ILE 111.6 2.0 +CA-CB-CG1 ILE 111.0 1.9 +CB-CG1-CD1 ILE 113.9 2.8 +CA-CB-CG2 ILE 110.9 2.0 +CG1-CB-CG2 ILE 111.4 2.2 +N-CA-C ILE 111.0 2.7 +CA-C-O ILE 120.1 2.1 +N-CA-CB LEU 110.4 2.0 +CB-CA-C LEU 110.2 1.9 +CA-CB-CG LEU 115.3 2.3 +CB-CG-CD1 LEU 111.0 1.7 +CB-CG-CD2 LEU 111.0 1.7 +CD1-CG-CD2 LEU 110.5 3.0 +N-CA-C LEU 111.0 2.7 +CA-C-O LEU 120.1 2.1 +N-CA-CB LYS 110.6 1.8 +CB-CA-C LYS 110.4 2.0 +CA-CB-CG LYS 113.4 2.2 +CB-CG-CD LYS 111.6 2.6 +CG-CD-CE LYS 111.9 3.0 +CD-CE-NZ LYS 111.7 2.3 +N-CA-C LYS 111.0 2.7 +CA-C-O LYS 120.1 2.1 +N-CA-CB MET 110.6 1.8 +CB-CA-C MET 110.4 2.0 +CA-CB-CG MET 113.3 1.7 +CB-CG-SD MET 112.4 3.0 +CG-SD-CE MET 100.2 1.6 +N-CA-C MET 111.0 2.7 +CA-C-O MET 120.1 2.1 +N-CA-CB PHE 110.6 1.8 +CB-CA-C PHE 110.4 2.0 +CA-CB-CG PHE 113.9 2.4 +CB-CG-CD1 PHE 120.8 0.7 +CB-CG-CD2 PHE 120.8 0.7 +CD1-CG-CD2 PHE 118.3 1.3 +CG-CD1-CE1 PHE 120.8 1.1 +CG-CD2-CE2 PHE 120.8 1.1 +CD1-CE1-CZ PHE 120.1 1.2 +CD2-CE2-CZ PHE 120.1 1.2 +CE1-CZ-CE2 PHE 120.0 1.8 +N-CA-C PHE 111.0 2.7 +CA-C-O PHE 120.1 2.1 +N-CA-CB PRO 103.3 1.2 +CB-CA-C PRO 111.7 2.1 +CA-CB-CG PRO 104.8 1.9 +CB-CG-CD PRO 106.5 3.9 +CG-CD-N PRO 103.2 1.5 +CA-N-CD PRO 111.7 1.4 +N-CA-C PRO 112.1 2.6 +CA-C-O PRO 120.2 2.4 +N-CA-CB SER 110.5 1.5 +CB-CA-C SER 110.1 1.9 +CA-CB-OG SER 111.2 2.7 +N-CA-C SER 111.0 2.7 +CA-C-O SER 120.1 2.1 +N-CA-CB THR 110.3 1.9 +CB-CA-C THR 111.6 2.7 +CA-CB-OG1 THR 109.0 2.1 +CA-CB-CG2 THR 112.4 1.4 +OG1-CB-CG2 THR 110.0 2.3 +N-CA-C THR 111.0 2.7 +CA-C-O THR 120.1 2.1 +N-CA-CB TRP 110.6 1.8 +CB-CA-C TRP 110.4 2.0 
+CA-CB-CG TRP 113.7 1.9 +CB-CG-CD1 TRP 127.0 1.3 +CB-CG-CD2 TRP 126.6 1.3 +CD1-CG-CD2 TRP 106.3 0.8 +CG-CD1-NE1 TRP 110.1 1.0 +CD1-NE1-CE2 TRP 109.0 0.9 +NE1-CE2-CD2 TRP 107.3 1.0 +CE2-CD2-CG TRP 107.3 0.8 +CG-CD2-CE3 TRP 133.9 0.9 +NE1-CE2-CZ2 TRP 130.4 1.1 +CE3-CD2-CE2 TRP 118.7 1.2 +CD2-CE2-CZ2 TRP 122.3 1.2 +CE2-CZ2-CH2 TRP 117.4 1.0 +CZ2-CH2-CZ3 TRP 121.6 1.2 +CH2-CZ3-CE3 TRP 121.2 1.1 +CZ3-CE3-CD2 TRP 118.8 1.3 +N-CA-C TRP 111.0 2.7 +CA-C-O TRP 120.1 2.1 +N-CA-CB TYR 110.6 1.8 +CB-CA-C TYR 110.4 2.0 +CA-CB-CG TYR 113.4 1.9 +CB-CG-CD1 TYR 121.0 0.6 +CB-CG-CD2 TYR 121.0 0.6 +CD1-CG-CD2 TYR 117.9 1.1 +CG-CD1-CE1 TYR 121.3 0.8 +CG-CD2-CE2 TYR 121.3 0.8 +CD1-CE1-CZ TYR 119.8 0.9 +CD2-CE2-CZ TYR 119.8 0.9 +CE1-CZ-CE2 TYR 119.8 1.6 +CE1-CZ-OH TYR 120.1 2.7 +CE2-CZ-OH TYR 120.1 2.7 +N-CA-C TYR 111.0 2.7 +CA-C-O TYR 120.1 2.1 +N-CA-CB VAL 111.5 2.2 +CB-CA-C VAL 111.4 1.9 +CA-CB-CG1 VAL 110.9 1.5 +CA-CB-CG2 VAL 110.9 1.5 +CG1-CB-CG2 VAL 110.9 1.6 +N-CA-C VAL 111.0 2.7 +CA-C-O VAL 120.1 2.1 +- + +Non-bonded distance Minimum Dist Tolerance +C-C 3.4 1.5 +C-N 3.25 1.5 +C-S 3.5 1.5 +C-O 3.22 1.5 +N-N 3.1 1.5 +N-S 3.35 1.5 +N-O 3.07 1.5 +O-S 3.32 1.5 +O-O 3.04 1.5 +S-S 2.03 1.0 +- diff --git a/modelscope/models/science/unifold/data/utils.py b/modelscope/models/science/unifold/data/utils.py new file mode 100644 index 00000000..2be91ef0 --- /dev/null +++ b/modelscope/models/science/unifold/data/utils.py @@ -0,0 +1,161 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import copy as copy_lib +import functools +import gzip +import pickle +from typing import Any, Dict + +import json +import numpy as np +from scipy import sparse as sp + +from . import residue_constants as rc +from .data_ops import NumpyDict + +# from typing import * + + +def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False): + if deepcopy: + + def decorator(f): + cached_func = functools.lru_cache(maxsize, typed)(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return copy_lib.deepcopy(cached_func(*args, **kwargs)) + + return wrapper + + elif copy: + + def decorator(f): + cached_func = functools.lru_cache(maxsize, typed)(f) + + @functools.wraps(f) + def wrapper(*args, **kwargs): + return copy_lib.copy(cached_func(*args, **kwargs)) + + return wrapper + + else: + decorator = functools.lru_cache(maxsize, typed) + return decorator + + +@lru_cache(maxsize=8, deepcopy=True) +def load_pickle_safe(path: str) -> Dict[str, Any]: + + def load(path): + assert path.endswith('.pkl') or path.endswith( + '.pkl.gz'), f'bad suffix in {path} as pickle file.' + open_fn = gzip.open if path.endswith('.gz') else open + with open_fn(path, 'rb') as f: + return pickle.load(f) + + ret = load(path) + ret = uncompress_features(ret) + return ret + + +@lru_cache(maxsize=8, copy=True) +def load_pickle(path: str) -> Dict[str, Any]: + + def load(path): + assert path.endswith('.pkl') or path.endswith( + '.pkl.gz'), f'bad suffix in {path} as pickle file.' 
+ open_fn = gzip.open if path.endswith('.gz') else open + with open_fn(path, 'rb') as f: + return pickle.load(f) + + ret = load(path) + ret = uncompress_features(ret) + return ret + + +def correct_template_restypes(feature): + """Correct template restype to have the same order as residue_constants.""" + feature = np.argmax(feature, axis=-1).astype(np.int32) + new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + feature = np.take(new_order_list, feature.astype(np.int32), axis=0) + return feature + + +def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict: + feature['msa'] = feature['msa'].astype(np.uint8) + if 'num_alignments' in feature: + feature.pop('num_alignments') + # make_all_seq_key = lambda k: f'{k}_all_seq' if not k.endswith('_all_seq') else k + + def make_all_seq_key(k): + if not k.endswith('_all_seq'): + return f'{k}_all_seq' + return k + + return {make_all_seq_key(k): v for k, v in feature.items()} + + +def to_dense_matrix(spmat_dict: NumpyDict): + spmat = sp.coo_matrix( + (spmat_dict['data'], (spmat_dict['row'], spmat_dict['col'])), + shape=spmat_dict['shape'], + dtype=np.float32, + ) + return spmat.toarray() + + +FEATS_DTYPE = {'msa': np.int32} + + +def uncompress_features(feats: NumpyDict) -> NumpyDict: + if 'sparse_deletion_matrix_int' in feats: + v = feats.pop('sparse_deletion_matrix_int') + v = to_dense_matrix(v) + feats['deletion_matrix'] = v + return feats + + +def filter(feature: NumpyDict, **kwargs) -> NumpyDict: + assert len(kwargs) == 1, f'wrong usage of filter with kwargs: {kwargs}' + if 'desired_keys' in kwargs: + feature = { + k: v + for k, v in feature.items() if k in kwargs['desired_keys'] + } + elif 'required_keys' in kwargs: + for k in kwargs['required_keys']: + assert k in feature, f'cannot find required key {k}.' 
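# --- Editorial sketch (not part of the patch): the sparse feature round trip above. ---
# `uncompress_features` rebuilds a dense deletion matrix from the COO-style dict that
# `compress_features` (defined just below) writes out; array contents are illustrative.
import numpy as np
from scipy import sparse as sp

dense = np.zeros((4, 8), dtype=np.float32)
dense[1, 3] = 2.0
coo = sp.coo_matrix(dense)
packed = {'sparse_deletion_matrix_int': {
    'shape': coo.shape, 'row': coo.row, 'col': coo.col, 'data': coo.data}}
restored = uncompress_features(packed)
assert np.array_equal(restored['deletion_matrix'], dense)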
+ elif 'ignored_keys' in kwargs: + feature = { + k: v + for k, v in feature.items() if k not in kwargs['ignored_keys'] + } + else: + raise AssertionError(f'wrong usage of filter with kwargs: {kwargs}') + return feature + + +def compress_features(features: NumpyDict): + change_dtype = { + 'msa': np.uint8, + } + sparse_keys = ['deletion_matrix_int'] + + compressed_features = {} + for k, v in features.items(): + if k in change_dtype: + v = v.astype(change_dtype[k]) + if k in sparse_keys: + v = sp.coo_matrix(v, dtype=v.dtype) + sp_v = { + 'shape': v.shape, + 'row': v.row, + 'col': v.col, + 'data': v.data + } + k = f'sparse_{k}' + v = sp_v + compressed_features[k] = v + return compressed_features diff --git a/modelscope/models/science/unifold/dataset.py b/modelscope/models/science/unifold/dataset.py new file mode 100644 index 00000000..05803f2c --- /dev/null +++ b/modelscope/models/science/unifold/dataset.py @@ -0,0 +1,514 @@ +import copy +import logging +import os +# from typing import * +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import json +import ml_collections as mlc +import numpy as np +import torch +from unicore.data import UnicoreDataset, data_utils +from unicore.distributed import utils as distributed_utils + +from .data import utils +from .data.data_ops import NumpyDict, TorchDict +from .data.process import process_features, process_labels +from .data.process_multimer import (add_assembly_features, + convert_monomer_features, merge_msas, + pair_and_merge, post_process) + +Rotation = Iterable[Iterable] +Translation = Iterable +Operation = Union[str, Tuple[Rotation, Translation]] +NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]] +TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]] + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + + +def make_data_config( + config: mlc.ConfigDict, + mode: str, + num_res: int, +) -> Tuple[mlc.ConfigDict, List[str]]: + cfg = copy.deepcopy(config) + mode_cfg = cfg[mode] + with cfg.unlocked(): + if mode_cfg.crop_size is None: + mode_cfg.crop_size = num_res + feature_names = cfg.common.unsupervised_features + cfg.common.recycling_features + if cfg.common.use_templates: + feature_names += cfg.common.template_features + if cfg.common.is_multimer: + feature_names += cfg.common.multimer_features + if cfg[mode].supervised: + feature_names += cfg.supervised.supervised_features + + return cfg, feature_names + + +def process_label(all_atom_positions: np.ndarray, + operation: Operation) -> np.ndarray: + if operation == 'I': + return all_atom_positions + rot, trans = operation + rot = np.array(rot).reshape(3, 3) + trans = np.array(trans).reshape(3) + return all_atom_positions @ rot.T + trans + + +@utils.lru_cache(maxsize=8, copy=True) +def load_single_feature( + sequence_id: str, + monomer_feature_dir: str, + uniprot_msa_dir: Optional[str] = None, + is_monomer: bool = False, +) -> NumpyDict: + + monomer_feature = utils.load_pickle( + os.path.join(monomer_feature_dir, f'{sequence_id}.feature.pkl.gz')) + monomer_feature = convert_monomer_features(monomer_feature) + chain_feature = {**monomer_feature} + + if uniprot_msa_dir is not None: + all_seq_feature = utils.load_pickle( + os.path.join(uniprot_msa_dir, f'{sequence_id}.uniprot.pkl.gz')) + if is_monomer: + chain_feature['msa'], chain_feature[ + 'deletion_matrix'] = merge_msas( + chain_feature['msa'], + chain_feature['deletion_matrix'], + all_seq_feature['msa'], + all_seq_feature['deletion_matrix'], + ) # noqa + else: + all_seq_feature = 
utils.convert_all_seq_feature(all_seq_feature) + for key in [ + 'msa_all_seq', + 'msa_species_identifiers_all_seq', + 'deletion_matrix_all_seq', + ]: + chain_feature[key] = all_seq_feature[key] + + return chain_feature + + +def load_single_label( + label_id: str, + label_dir: str, + symmetry_operation: Optional[Operation] = None, +) -> NumpyDict: + label = utils.load_pickle( + os.path.join(label_dir, f'{label_id}.label.pkl.gz')) + if symmetry_operation is not None: + label['all_atom_positions'] = process_label( + label['all_atom_positions'], symmetry_operation) + label = { + k: v + for k, v in label.items() if k in + ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] + } + return label + + +def load( + sequence_ids: List[str], + monomer_feature_dir: str, + uniprot_msa_dir: Optional[str] = None, + label_ids: Optional[List[str]] = None, + label_dir: Optional[str] = None, + symmetry_operations: Optional[List[Operation]] = None, + is_monomer: bool = False, +) -> NumpyExample: + + all_chain_features = [ + load_single_feature(s, monomer_feature_dir, uniprot_msa_dir, + is_monomer) for s in sequence_ids + ] + + if label_ids is not None: + # load labels + assert len(label_ids) == len(sequence_ids) + assert label_dir is not None + if symmetry_operations is None: + symmetry_operations = ['I' for _ in label_ids] + all_chain_labels = [ + load_single_label(ll, label_dir, o) + for ll, o in zip(label_ids, symmetry_operations) + ] + # update labels into features to calculate spatial cropping etc. + [f.update(ll) for f, ll in zip(all_chain_features, all_chain_labels)] + + all_chain_features = add_assembly_features(all_chain_features) + + # get labels back from features, as add_assembly_features may alter the order of inputs. + if label_ids is not None: + all_chain_labels = [{ + k: f[k] + for k in + ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] + } for f in all_chain_features] + else: + all_chain_labels = None + + asym_len = np.array([c['seq_length'] for c in all_chain_features], + dtype=np.int64) + if is_monomer: + all_chain_features = all_chain_features[0] + else: + all_chain_features = pair_and_merge(all_chain_features) + all_chain_features = post_process(all_chain_features) + all_chain_features['asym_len'] = asym_len + + return all_chain_features, all_chain_labels + + +def process( + config: mlc.ConfigDict, + mode: str, + features: NumpyDict, + labels: Optional[List[NumpyDict]] = None, + seed: int = 0, + batch_idx: Optional[int] = None, + data_idx: Optional[int] = None, + is_distillation: bool = False, +) -> TorchExample: + + if mode == 'train': + assert batch_idx is not None + with data_utils.numpy_seed(seed, batch_idx, key='recycling'): + num_iters = np.random.randint( + 0, config.common.max_recycling_iters + 1) + use_clamped_fape = np.random.rand( + ) < config[mode].use_clamped_fape_prob + else: + num_iters = config.common.max_recycling_iters + use_clamped_fape = 1 + + features['num_recycling_iters'] = int(num_iters) + features['use_clamped_fape'] = int(use_clamped_fape) + features['is_distillation'] = int(is_distillation) + if is_distillation and 'msa_chains' in features: + features.pop('msa_chains') + + num_res = int(features['seq_length']) + cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) + + if labels is not None: + features['resolution'] = labels[0]['resolution'].reshape(-1) + + with data_utils.numpy_seed(seed, data_idx, key='protein_feature'): + features['crop_and_fix_size_seed'] = np.random.randint(0, 63355) + features = 
utils.filter(features, desired_keys=feature_names) + features = {k: torch.tensor(v) for k, v in features.items()} + with torch.no_grad(): + features = process_features(features, cfg.common, cfg[mode]) + + if labels is not None: + labels = [{k: torch.tensor(v) for k, v in ll.items()} for ll in labels] + with torch.no_grad(): + labels = process_labels(labels) + + return features, labels + + +def load_and_process( + config: mlc.ConfigDict, + mode: str, + seed: int = 0, + batch_idx: Optional[int] = None, + data_idx: Optional[int] = None, + is_distillation: bool = False, + **load_kwargs, +): + is_monomer = ( + is_distillation + if 'is_monomer' not in load_kwargs else load_kwargs.pop('is_monomer')) + features, labels = load(**load_kwargs, is_monomer=is_monomer) + features, labels = process(config, mode, features, labels, seed, batch_idx, + data_idx, is_distillation) + return features, labels + + +class UnifoldDataset(UnicoreDataset): + + def __init__( + self, + args, + seed, + config, + data_path, + mode='train', + max_step=None, + disable_sd=False, + json_prefix='', + ): + self.path = data_path + + def load_json(filename): + return json.load(open(filename, 'r')) + + sample_weight = load_json( + os.path.join(self.path, + json_prefix + mode + '_sample_weight.json')) + self.multi_label = load_json( + os.path.join(self.path, json_prefix + mode + '_multi_label.json')) + self.inverse_multi_label = self._inverse_map(self.multi_label) + self.sample_weight = {} + for chain in self.inverse_multi_label: + entity = self.inverse_multi_label[chain] + self.sample_weight[chain] = sample_weight[entity] + self.seq_sample_weight = sample_weight + logger.info('load {} chains (unique {} sequences)'.format( + len(self.sample_weight), len(self.seq_sample_weight))) + self.feature_path = os.path.join(self.path, 'pdb_features') + self.label_path = os.path.join(self.path, 'pdb_labels') + sd_sample_weight_path = os.path.join( + self.path, json_prefix + 'sd_train_sample_weight.json') + if mode == 'train' and os.path.isfile( + sd_sample_weight_path) and not disable_sd: + self.sd_sample_weight = load_json(sd_sample_weight_path) + logger.info('load {} self-distillation samples.'.format( + len(self.sd_sample_weight))) + self.sd_feature_path = os.path.join(self.path, 'sd_features') + self.sd_label_path = os.path.join(self.path, 'sd_labels') + else: + self.sd_sample_weight = None + self.batch_size = ( + args.batch_size * distributed_utils.get_data_parallel_world_size() + * args.update_freq[0]) + self.data_len = ( + max_step * self.batch_size + if max_step is not None else len(self.sample_weight)) + self.mode = mode + self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight( + self.seq_sample_weight) + self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( + self.sample_weight) + if self.sd_sample_weight is not None: + ( + self.sd_num_chain, + self.sd_chain_keys, + self.sd_sample_prob, + ) = self.cal_sample_weight(self.sd_sample_weight) + self.config = config.data + self.seed = seed + self.sd_prob = args.sd_prob + + def cal_sample_weight(self, sample_weight): + prot_keys = list(sample_weight.keys()) + sum_weight = sum(sample_weight.values()) + sample_prob = [sample_weight[k] / sum_weight for k in prot_keys] + num_prot = len(prot_keys) + return num_prot, prot_keys, sample_prob + + def sample_chain(self, idx, sample_by_seq=False): + is_distillation = False + if self.mode == 'train': + with data_utils.numpy_seed(self.seed, idx, key='data_sample'): + is_distillation = ((np.random.rand(1)[0] < 
self.sd_prob) + if self.sd_sample_weight is not None else + False) + if is_distillation: + prot_idx = np.random.choice( + self.sd_num_chain, p=self.sd_sample_prob) + label_name = self.sd_chain_keys[prot_idx] + seq_name = label_name + else: + if not sample_by_seq: + prot_idx = np.random.choice( + self.num_chain, p=self.sample_prob) + label_name = self.chain_keys[prot_idx] + seq_name = self.inverse_multi_label[label_name] + else: + seq_idx = np.random.choice( + self.num_seq, p=self.seq_sample_prob) + seq_name = self.seq_keys[seq_idx] + label_name = np.random.choice( + self.multi_label[seq_name]) + else: + label_name = self.chain_keys[idx] + seq_name = self.inverse_multi_label[label_name] + return seq_name, label_name, is_distillation + + def __getitem__(self, idx): + sequence_id, label_id, is_distillation = self.sample_chain( + idx, sample_by_seq=True) + feature_dir, label_dir = ((self.feature_path, + self.label_path) if not is_distillation else + (self.sd_feature_path, self.sd_label_path)) + features, _ = load_and_process( + self.config, + self.mode, + self.seed, + batch_idx=(idx // self.batch_size), + data_idx=idx, + is_distillation=is_distillation, + sequence_ids=[sequence_id], + monomer_feature_dir=feature_dir, + uniprot_msa_dir=None, + label_ids=[label_id], + label_dir=label_dir, + symmetry_operations=None, + is_monomer=True, + ) + return features + + def __len__(self): + return self.data_len + + @staticmethod + def collater(samples): + # first dim is recyling. bsz is at the 2nd dim + return data_utils.collate_dict(samples, dim=1) + + @staticmethod + def _inverse_map(mapping: Dict[str, List[str]]): + inverse_mapping = {} + for ent, refs in mapping.items(): + for ref in refs: + if ref in inverse_mapping: # duplicated ent for this ref. + ent_2 = inverse_mapping[ref] + assert ( + ent == ent_2 + ), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.' 
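# --- Editorial sketch (not part of the patch): the chain/entity bookkeeping above. ---
# `*_multi_label.json` maps a sequence entity to the PDB chains that realise it, and
# `_inverse_map` flips that into the chain -> entity lookup used for sampling weights.
# The identifiers below are made up.
multi_label = {'entity_1': ['1abc_A', '1abc_B'], 'entity_2': ['2xyz_A']}
inverse = UnifoldDataset._inverse_map(multi_label)
assert inverse == {'1abc_A': 'entity_1', '1abc_B': 'entity_1', '2xyz_A': 'entity_2'}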
+ inverse_mapping[ref] = ent + return inverse_mapping + + +class UnifoldMultimerDataset(UnifoldDataset): + + def __init__( + self, + args: mlc.ConfigDict, + seed: int, + config: mlc.ConfigDict, + data_path: str, + mode: str = 'train', + max_step: Optional[int] = None, + disable_sd: bool = False, + json_prefix: str = '', + **kwargs, + ): + super().__init__(args, seed, config, data_path, mode, max_step, + disable_sd, json_prefix) + self.data_path = data_path + self.pdb_assembly = json.load( + open( + os.path.join(self.data_path, + json_prefix + 'pdb_assembly.json'))) + self.pdb_chains = self.get_chains(self.inverse_multi_label) + self.monomer_feature_path = os.path.join(self.data_path, + 'pdb_features') + self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots') + self.label_path = os.path.join(self.data_path, 'pdb_labels') + self.max_chains = args.max_chains + if self.mode == 'train': + self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains( + self.pdb_chains, self.pdb_assembly, self.sample_weight, + self.max_chains) + self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( + self.sample_weight) + + def __getitem__(self, idx): + seq_id, label_id, is_distillation = self.sample_chain(idx) + if is_distillation: + label_ids = [label_id] + sequence_ids = [seq_id] + monomer_feature_path, uniprot_msa_path, label_path = ( + self.sd_feature_path, + None, + self.sd_label_path, + ) + symmetry_operations = None + else: + pdb_id = self.get_pdb_name(label_id) + if pdb_id in self.pdb_assembly and self.mode == 'train': + label_ids = [ + pdb_id + '_' + id + for id in self.pdb_assembly[pdb_id]['chains'] + ] + symmetry_operations = [ + t for t in self.pdb_assembly[pdb_id]['opers'] + ] + else: + label_ids = self.pdb_chains[pdb_id] + symmetry_operations = None + sequence_ids = [ + self.inverse_multi_label[chain_id] for chain_id in label_ids + ] + monomer_feature_path, uniprot_msa_path, label_path = ( + self.monomer_feature_path, + self.uniprot_msa_path, + self.label_path, + ) + + return load_and_process( + self.config, + self.mode, + self.seed, + batch_idx=(idx // self.batch_size), + data_idx=idx, + is_distillation=is_distillation, + sequence_ids=sequence_ids, + monomer_feature_dir=monomer_feature_path, + uniprot_msa_dir=uniprot_msa_path, + label_ids=label_ids, + label_dir=label_path, + symmetry_operations=symmetry_operations, + is_monomer=False, + ) + + @staticmethod + def collater(samples): + # first dim is recyling. 
bsz is at the 2nd dim + if len(samples) <= 0: # tackle empty batch + return None + feats = [s[0] for s in samples] + labs = [s[1] for s in samples if s[1] is not None] + try: + feats = data_utils.collate_dict(feats, dim=1) + except BaseException: + raise ValueError('cannot collate features', feats) + if not labs: + labs = None + return feats, labs + + @staticmethod + def get_pdb_name(chain): + return chain.split('_')[0] + + @staticmethod + def get_chains(canon_chain_map): + pdb_chains = {} + for chain in canon_chain_map: + pdb = UnifoldMultimerDataset.get_pdb_name(chain) + if pdb not in pdb_chains: + pdb_chains[pdb] = [] + pdb_chains[pdb].append(chain) + return pdb_chains + + @staticmethod + def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight, + max_chains): + new_pdb_chains = {} + for chain in pdb_chains: + if chain in pdb_assembly: + size = len(pdb_assembly[chain]['chains']) + if size <= max_chains: + new_pdb_chains[chain] = pdb_chains[chain] + else: + size = len(pdb_chains[chain]) + if size == 1: + new_pdb_chains[chain] = pdb_chains[chain] + new_sample_weight = { + k: sample_weight[k] + for k in sample_weight + if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains + } + logger.info( + f'filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs ' + f'({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) ' + f'by max_chains {max_chains}') + return new_pdb_chains, new_sample_weight diff --git a/modelscope/models/science/unifold/model.py b/modelscope/models/science/unifold/model.py new file mode 100644 index 00000000..6632751a --- /dev/null +++ b/modelscope/models/science/unifold/model.py @@ -0,0 +1,75 @@ +import argparse +import os +from typing import Any + +import torch + +from modelscope.metainfo import Models +from modelscope.models import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .config import model_config +from .modules.alphafold import AlphaFold + +__all__ = ['UnifoldForProteinStructrue'] + + +@MODELS.register_module(Tasks.protein_structure, module_name=Models.unifold) +class UnifoldForProteinStructrue(TorchModel): + + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + '--model-name', + help='choose the model config', + ) + + def __init__(self, **kwargs): + super().__init__() + parser = argparse.ArgumentParser() + parse_comm = [] + for key in kwargs: + parser.add_argument(f'--{key}') + parse_comm.append(f'--{key}') + parse_comm.append(kwargs[key]) + args = parser.parse_args(parse_comm) + base_architecture(args) + self.args = args + config = model_config( + self.args.model_name, + train=True, + ) + self.model = AlphaFold(config) + self.config = config + + # load model state dict + param_path = os.path.join(kwargs['model_dir'], + ModelFile.TORCH_MODEL_BIN_FILE) + state_dict = torch.load(param_path)['ema']['params'] + state_dict = { + '.'.join(k.split('.')[1:]): v + for k, v in state_dict.items() + } + self.model.load_state_dict(state_dict) + + def half(self): + self.model = self.model.half() + return self + + def bfloat16(self): + self.model = self.model.bfloat16() + return self + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + return cls(args) + + def forward(self, batch, **kwargs): + outputs = self.model.forward(batch) + return outputs, self.config.loss + + +def base_architecture(args): + args.model_name = getattr(args, 'model_name', 
'model_2') diff --git a/modelscope/models/science/unifold/modules/alphafold.py b/modelscope/models/science/unifold/modules/alphafold.py new file mode 100644 index 00000000..71a1b310 --- /dev/null +++ b/modelscope/models/science/unifold/modules/alphafold.py @@ -0,0 +1,450 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import torch +import torch.nn as nn +from unicore.utils import tensor_tree_map + +from ..data import residue_constants +from .attentions import gen_msa_attn_mask, gen_tri_attn_mask +from .auxillary_heads import AuxiliaryHeads +from .common import residual +from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder, + TemplateAngleEmbedder, TemplatePairEmbedder) +from .evoformer import EvoformerStack, ExtraMSAStack +from .featurization import (atom14_to_atom37, build_extra_msa_feat, + build_template_angle_feat, + build_template_pair_feat, + build_template_pair_feat_v2, pseudo_beta_fn) +from .structure_module import StructureModule +from .template import (TemplatePairStack, TemplatePointwiseAttention, + TemplateProjection) + + +class AlphaFold(nn.Module): + + def __init__(self, config): + super(AlphaFold, self).__init__() + + self.globals = config.globals + config = config.model + template_config = config.template + extra_msa_config = config.extra_msa + + self.input_embedder = InputEmbedder( + **config['input_embedder'], + use_chain_relative=config.is_multimer, + ) + self.recycling_embedder = RecyclingEmbedder( + **config['recycling_embedder'], ) + if config.template.enabled: + self.template_angle_embedder = TemplateAngleEmbedder( + **template_config['template_angle_embedder'], ) + self.template_pair_embedder = TemplatePairEmbedder( + **template_config['template_pair_embedder'], ) + self.template_pair_stack = TemplatePairStack( + **template_config['template_pair_stack'], ) + else: + self.template_pair_stack = None + self.enable_template_pointwise_attention = template_config[ + 'template_pointwise_attention'].enabled + if self.enable_template_pointwise_attention: + self.template_pointwise_att = TemplatePointwiseAttention( + **template_config['template_pointwise_attention'], ) + else: + self.template_proj = TemplateProjection( + **template_config['template_pointwise_attention'], ) + self.extra_msa_embedder = ExtraMSAEmbedder( + **extra_msa_config['extra_msa_embedder'], ) + self.extra_msa_stack = ExtraMSAStack( + **extra_msa_config['extra_msa_stack'], ) + self.evoformer = EvoformerStack(**config['evoformer_stack'], ) + self.structure_module = StructureModule(**config['structure_module'], ) + + self.aux_heads = AuxiliaryHeads(config['heads'], ) + + self.config = config + self.dtype = torch.float + self.inf = self.globals.inf + if self.globals.alphafold_original_mode: + self.alphafold_original_mode() + + def __make_input_float__(self): + self.input_embedder = self.input_embedder.float() + self.recycling_embedder = self.recycling_embedder.float() + + def half(self): + super().half() + if (not getattr(self, 'inference', False)): + self.__make_input_float__() + self.dtype = torch.half + return self + + def bfloat16(self): + super().bfloat16() + if (not getattr(self, 'inference', False)): + self.__make_input_float__() + self.dtype = torch.bfloat16 + return self + + def alphafold_original_mode(self): + + def set_alphafold_original_mode(module): + if hasattr(module, 'apply_alphafold_original_mode'): + module.apply_alphafold_original_mode() + if 
hasattr(module, 'act'): + module.act = nn.ReLU() + + self.apply(set_alphafold_original_mode) + + def inference_mode(self): + + def set_inference_mode(module): + setattr(module, 'inference', True) + + self.apply(set_inference_mode) + + def __convert_input_dtype__(self, batch): + for key in batch: + # only convert features with mask + if batch[key].dtype != self.dtype and 'mask' in key: + batch[key] = batch[key].type(self.dtype) + return batch + + def embed_templates_pair_core(self, batch, z, pair_mask, + tri_start_attn_mask, tri_end_attn_mask, + templ_dim, multichain_mask_2d): + if self.config.template.template_pair_embedder.v2_feature: + t = build_template_pair_feat_v2( + batch, + inf=self.config.template.inf, + eps=self.config.template.eps, + multichain_mask_2d=multichain_mask_2d, + **self.config.template.distogram, + ) + num_template = t[0].shape[-4] + single_templates = [ + self.template_pair_embedder([x[..., ti, :, :, :] + for x in t], z) + for ti in range(num_template) + ] + else: + t = build_template_pair_feat( + batch, + inf=self.config.template.inf, + eps=self.config.template.eps, + **self.config.template.distogram, + ) + single_templates = [ + self.template_pair_embedder(x, z) + for x in torch.unbind(t, dim=templ_dim) + ] + + t = self.template_pair_stack( + single_templates, + pair_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + templ_dim=templ_dim, + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + return_mean=not self.enable_template_pointwise_attention, + ) + return t + + def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask, + tri_end_attn_mask, templ_dim): + if self.config.template.template_pair_embedder.v2_feature and 'asym_id' in batch: + multichain_mask_2d = ( + batch['asym_id'][..., :, None] == batch['asym_id'][..., + None, :]) + multichain_mask_2d = multichain_mask_2d.unsqueeze(0) + else: + multichain_mask_2d = None + + if self.training or self.enable_template_pointwise_attention: + t = self.embed_templates_pair_core(batch, z, pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, templ_dim, + multichain_mask_2d) + if self.enable_template_pointwise_attention: + t = self.template_pointwise_att( + t, + z, + template_mask=batch['template_mask'], + chunk_size=self.globals.chunk_size, + ) + t_mask = torch.sum( + batch['template_mask'], dim=-1, keepdims=True) > 0 + t_mask = t_mask[..., None, None].type(t.dtype) + t *= t_mask + else: + t = self.template_proj(t, z) + else: + template_aatype_shape = batch['template_aatype'].shape + # template_aatype is either [n_template, n_res] or [1, n_template_, n_res] + batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0 + n_templ = batch['template_aatype'].shape[batch_templ_dim] + + if n_templ <= 0: + t = None + else: + template_batch = { + k: v + for k, v in batch.items() if k.startswith('template_') + } + + def embed_one_template(i): + + def slice_template_tensor(t): + s = [slice(None) for _ in t.shape] + s[batch_templ_dim] = slice(i, i + 1) + return t[s] + + template_feats = tensor_tree_map( + slice_template_tensor, + template_batch, + ) + t = self.embed_templates_pair_core( + template_feats, z, pair_mask, tri_start_attn_mask, + tri_end_attn_mask, templ_dim, multichain_mask_2d) + return t + + t = embed_one_template(0) + # iterate templates one by one + for i in range(1, n_templ): + t += embed_one_template(i) + t /= n_templ + t = self.template_proj(t, z) + return t + + def embed_templates_angle(self, batch): + template_angle_feat, 
template_angle_mask = build_template_angle_feat( + batch, + v2_feature=self.config.template.template_pair_embedder.v2_feature) + t = self.template_angle_embedder(template_angle_feat) + return t, template_angle_mask + + def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev): + batch_dims = feats['target_feat'].shape[:-2] + n = feats['target_feat'].shape[-2] + seq_mask = feats['seq_mask'] + pair_mask = seq_mask[..., None] * seq_mask[..., None, :] + msa_mask = feats['msa_mask'] + + m, z = self.input_embedder( + feats['target_feat'], + feats['msa_feat'], + ) + + if m_1_prev is None: + m_1_prev = m.new_zeros( + (*batch_dims, n, self.config.input_embedder.d_msa), + requires_grad=False, + ) + if z_prev is None: + z_prev = z.new_zeros( + (*batch_dims, n, n, self.config.input_embedder.d_pair), + requires_grad=False, + ) + if x_prev is None: + x_prev = z.new_zeros( + (*batch_dims, n, residue_constants.atom_type_num, 3), + requires_grad=False, + ) + x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None) + + z += self.recycling_embedder.recyle_pos(x_prev) + + m_1_prev_emb, z_prev_emb = self.recycling_embedder( + m_1_prev, + z_prev, + ) + + m[..., 0, :, :] += m_1_prev_emb + + z += z_prev_emb + + z += self.input_embedder.relpos_emb( + feats['residue_index'].long(), + feats.get('sym_id', None), + feats.get('asym_id', None), + feats.get('entity_id', None), + feats.get('num_sym', None), + ) + + m = m.type(self.dtype) + z = z.type(self.dtype) + tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask( + pair_mask, self.inf) + + if self.config.template.enabled: + template_mask = feats['template_mask'] + if torch.any(template_mask): + z = residual( + z, + self.embed_templates_pair( + feats, + z, + pair_mask, + tri_start_attn_mask, + tri_end_attn_mask, + templ_dim=-4, + ), + self.training, + ) + + if self.config.extra_msa.enabled: + a = self.extra_msa_embedder(build_extra_msa_feat(feats)) + extra_msa_row_mask = gen_msa_attn_mask( + feats['extra_msa_mask'], + inf=self.inf, + gen_col_mask=False, + ) + z = self.extra_msa_stack( + a, + z, + msa_mask=feats['extra_msa_mask'], + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + pair_mask=pair_mask, + msa_row_attn_mask=extra_msa_row_mask, + msa_col_attn_mask=None, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + ) + + if self.config.template.embed_angles: + template_1d_feat, template_1d_mask = self.embed_templates_angle( + feats) + m = torch.cat([m, template_1d_feat], dim=-3) + msa_mask = torch.cat([feats['msa_mask'], template_1d_mask], dim=-2) + + msa_row_mask, msa_col_mask = gen_msa_attn_mask( + msa_mask, + inf=self.inf, + ) + + m, z, s = self.evoformer( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_mask, + msa_col_attn_mask=msa_col_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=self.globals.chunk_size, + block_size=self.globals.block_size, + ) + return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb + + def iteration_evoformer_structure_module(self, + batch, + m_1_prev, + z_prev, + x_prev, + cycle_no, + num_recycling, + num_ensembles=1): + z, s = 0, 0 + n_seq = batch['msa_feat'].shape[-3] + assert num_ensembles >= 1 + for ensemble_no in range(num_ensembles): + idx = cycle_no * num_ensembles + ensemble_no + + # fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...] + def fetch_cur_batch(t): + return t[min(t.shape[0] - 1, idx), ...] 
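+ # the leading dim of the batch tensors indexes (recycling cycle x ensemble);
+ # min() clamps the index for tensors whose leading dim holds fewer copies.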
+ + feats = tensor_tree_map(fetch_cur_batch, batch) + m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer( + feats, m_1_prev, z_prev, x_prev) + z += z0 + s += s0 + del z0, s0 + if num_ensembles > 1: + z /= float(num_ensembles) + s /= float(num_ensembles) + + outputs = {} + + outputs['msa'] = m[..., :n_seq, :, :] + outputs['pair'] = z + outputs['single'] = s + + # norm loss + if (not getattr(self, 'inference', + False)) and num_recycling == (cycle_no + 1): + delta_msa = m + delta_msa[..., + 0, :, :] = delta_msa[..., + 0, :, :] - m_1_prev_emb.detach() + delta_pair = z - z_prev_emb.detach() + outputs['delta_msa'] = delta_msa + outputs['delta_pair'] = delta_pair + outputs['msa_norm_mask'] = msa_mask + + outputs['sm'] = self.structure_module( + s, + z, + feats['aatype'], + mask=feats['seq_mask'], + ) + outputs['final_atom_positions'] = atom14_to_atom37( + outputs['sm']['positions'], feats) + outputs['final_atom_mask'] = feats['atom37_atom_exists'] + outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1] + + # use float32 for numerical stability + if (not getattr(self, 'inference', False)): + m_1_prev = m[..., 0, :, :].float() + z_prev = z.float() + x_prev = outputs['final_atom_positions'].float() + else: + m_1_prev = m[..., 0, :, :] + z_prev = z + x_prev = outputs['final_atom_positions'] + + return outputs, m_1_prev, z_prev, x_prev + + def forward(self, batch): + + m_1_prev = batch.get('m_1_prev', None) + z_prev = batch.get('z_prev', None) + x_prev = batch.get('x_prev', None) + + is_grad_enabled = torch.is_grad_enabled() + + num_iters = int(batch['num_recycling_iters']) + 1 + num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters + if self.training: + # don't use ensemble during training + assert num_ensembles == 1 + + # convert dtypes in batch + batch = self.__convert_input_dtype__(batch) + for cycle_no in range(num_iters): + is_final_iter = cycle_no == (num_iters - 1) + with torch.set_grad_enabled(is_grad_enabled and is_final_iter): + ( + outputs, + m_1_prev, + z_prev, + x_prev, + ) = self.iteration_evoformer_structure_module( + batch, + m_1_prev, + z_prev, + x_prev, + cycle_no=cycle_no, + num_recycling=num_iters, + num_ensembles=num_ensembles, + ) + if not is_final_iter: + del outputs + + if 'asym_id' in batch: + outputs['asym_id'] = batch['asym_id'][0, ...] + outputs.update(self.aux_heads(outputs)) + return outputs diff --git a/modelscope/models/science/unifold/modules/attentions.py b/modelscope/models/science/unifold/modules/attentions.py new file mode 100644 index 00000000..d2319079 --- /dev/null +++ b/modelscope/models/science/unifold/modules/attentions.py @@ -0,0 +1,430 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
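+# Attention primitives shared by the Evoformer, extra-MSA and template stacks:
+# gated multi-head attention, global column attention, MSA row/column attention
+# with optional pair bias, and triangle attention, plus helpers that turn 0/1
+# masks into additive (-inf) attention masks.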
+ +from functools import partialmethod +from typing import List, Optional + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm, softmax_dropout +from unicore.utils import permute_final_dims + +from .common import Linear, chunk_layer + + +def gen_attn_mask(mask, neg_inf): + assert neg_inf < -1e4 + attn_mask = torch.zeros_like(mask) + attn_mask[mask == 0] = neg_inf + return attn_mask + + +class Attention(nn.Module): + + def __init__( + self, + q_dim: int, + k_dim: int, + v_dim: int, + head_dim: int, + num_heads: int, + gating: bool = True, + ): + super(Attention, self).__init__() + + self.num_heads = num_heads + total_dim = head_dim * self.num_heads + self.gating = gating + self.linear_q = Linear(q_dim, total_dim, bias=False, init='glorot') + self.linear_k = Linear(k_dim, total_dim, bias=False, init='glorot') + self.linear_v = Linear(v_dim, total_dim, bias=False, init='glorot') + self.linear_o = Linear(total_dim, q_dim, init='final') + self.linear_g = None + if self.gating: + self.linear_g = Linear(q_dim, total_dim, init='gating') + # precompute the 1/sqrt(head_dim) + self.norm = head_dim**-0.5 + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: torch.Tensor = None, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + g = None + if self.linear_g is not None: + # gating, use raw query input + g = self.linear_g(q) + + q = self.linear_q(q) + q *= self.norm + k = self.linear_k(k) + v = self.linear_v(v) + + q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose( + -2, -3).contiguous() + k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose( + -2, -3).contiguous() + v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3) + + attn = torch.matmul(q, k.transpose(-1, -2)) + del q, k + + attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias) + o = torch.matmul(attn, v) + del attn, v + + o = o.transpose(-2, -3).contiguous() + o = o.view(*o.shape[:-2], -1) + + if g is not None: + o = torch.sigmoid(g) * o + + # merge heads + o = nn.functional.linear(o, self.linear_o.weight) + return o + + def get_output_bias(self): + return self.linear_o.bias + + +class GlobalAttention(nn.Module): + + def __init__(self, input_dim, head_dim, num_heads, inf, eps): + super(GlobalAttention, self).__init__() + + self.num_heads = num_heads + self.inf = inf + self.eps = eps + self.linear_q = Linear( + input_dim, head_dim * num_heads, bias=False, init='glorot') + self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot') + self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot') + self.linear_g = Linear(input_dim, head_dim * num_heads, init='gating') + self.linear_o = Linear(head_dim * num_heads, input_dim, init='final') + self.sigmoid = nn.Sigmoid() + # precompute the 1/sqrt(head_dim) + self.norm = head_dim**-0.5 + + def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + + # gating + g = self.sigmoid(self.linear_g(x)) + + k = self.linear_k(x) + v = self.linear_v(x) + + q = torch.sum( + x * mask.unsqueeze(-1), dim=-2) / ( + torch.sum(mask, dim=-1, keepdims=True) + self.eps) + q = self.linear_q(q) + q *= self.norm + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + attn = torch.matmul(q, k.transpose(-1, -2)) + del q, k + + attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :] + attn = softmax_dropout(attn, 0, self.training, mask=attn_mask) + + o = torch.matmul( + attn, + v, + ) + del attn, v + + g = g.view(g.shape[:-1] + (self.num_heads, -1)) + o = o.unsqueeze(-3) * g + del g + 
+ # merge heads + o = o.reshape(o.shape[:-2] + (-1, )) + return self.linear_o(o) + + +def gen_msa_attn_mask(mask, inf, gen_col_mask=True): + row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] + if gen_col_mask: + col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, + None, :] + return row_mask, col_mask + else: + return row_mask + + +class MSAAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + pair_bias=False, + d_pair=None, + ): + super(MSAAttention, self).__init__() + + self.pair_bias = pair_bias + self.layer_norm_m = LayerNorm(d_in) + self.layer_norm_z = None + self.linear_z = None + if self.pair_bias: + self.layer_norm_z = LayerNorm(d_pair) + self.linear_z = Linear( + d_pair, num_heads, bias=False, init='normal') + + self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: int = None, + ) -> torch.Tensor: + + return chunk_layer( + self._attn_forward, + { + 'm': m, + 'mask': mask, + 'bias': bias + }, + chunk_size=chunk_size, + num_batch_dims=len(m.shape[:-2]), + ) + + @torch.jit.ignore + def _attn_chunk_forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = 2560, + ) -> torch.Tensor: + m = self.layer_norm_m(m) + num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size + outputs = [] + for i in range(num_chunk): + chunk_start = i * chunk_size + chunk_end = min(m.shape[-3], chunk_start + chunk_size) + cur_m = m[..., chunk_start:chunk_end, :, :] + cur_mask = ( + mask[..., chunk_start:chunk_end, :, :, :] + if mask is not None else None) + outputs.append( + self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) + return torch.concat(outputs, dim=-3) + + def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): + m = self.layer_norm_m(m) + return self.mha(q=m, k=m, v=m, mask=mask, bias=bias) + + def forward( + self, + m: torch.Tensor, + z: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + bias = None + if self.pair_bias: + z = self.layer_norm_z(z) + bias = ( + permute_final_dims(self.linear_z(z), + (2, 0, 1)).unsqueeze(-4).contiguous()) + + if chunk_size is not None: + m = self._chunk(m, attn_mask, bias, chunk_size) + else: + attn_chunk_size = 2560 + if m.shape[-3] <= attn_chunk_size: + m = self._attn_forward(m, attn_mask, bias) + else: + # reduce the peak memory cost in extra_msa_stack + return self._attn_chunk_forward( + m, attn_mask, bias, chunk_size=attn_chunk_size) + + return m + + def get_output_bias(self): + return self.mha.get_output_bias() + + +class MSARowAttentionWithPairBias(MSAAttention): + + def __init__(self, d_msa, d_pair, d_hid, num_heads): + super(MSARowAttentionWithPairBias, self).__init__( + d_msa, + d_hid, + num_heads, + pair_bias=True, + d_pair=d_pair, + ) + + +class MSAColumnAttention(MSAAttention): + + def __init__(self, d_msa, d_hid, num_heads): + super(MSAColumnAttention, self).__init__( + d_in=d_msa, + d_hid=d_hid, + num_heads=num_heads, + pair_bias=False, + d_pair=None, + ) + + def forward( + self, + m: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + m = m.transpose(-2, -3) + m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size) + m = m.transpose(-2, -3) + + return m + + +class 
MSAColumnGlobalAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + inf=1e9, + eps=1e-10, + ): + super(MSAColumnGlobalAttention, self).__init__() + + self.layer_norm_m = LayerNorm(d_in) + self.global_attention = GlobalAttention( + d_in, + d_hid, + num_heads, + inf=inf, + eps=eps, + ) + + @torch.jit.ignore + def _chunk( + self, + m: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._attn_forward, + { + 'm': m, + 'mask': mask + }, + chunk_size=chunk_size, + num_batch_dims=len(m.shape[:-2]), + ) + + def _attn_forward(self, m, mask): + m = self.layer_norm_m(m) + return self.global_attention(m, mask=mask) + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + m = m.transpose(-2, -3) + mask = mask.transpose(-1, -2) + + if chunk_size is not None: + m = self._chunk(m, mask, chunk_size) + else: + m = self._attn_forward(m, mask=mask) + + m = m.transpose(-2, -3) + return m + + +def gen_tri_attn_mask(mask, inf): + start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] + end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, + None, :] + return start_mask, end_mask + + +class TriangleAttention(nn.Module): + + def __init__( + self, + d_in, + d_hid, + num_heads, + starting, + ): + super(TriangleAttention, self).__init__() + self.starting = starting + self.layer_norm = LayerNorm(d_in) + self.linear = Linear(d_in, num_heads, bias=False, init='normal') + self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + mask: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + chunk_size: int = None, + ) -> torch.Tensor: + return chunk_layer( + self.mha, + { + 'q': x, + 'k': x, + 'v': x, + 'mask': mask, + 'bias': bias + }, + chunk_size=chunk_size, + num_batch_dims=len(x.shape[:-2]), + ) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + if not self.starting: + x = x.transpose(-2, -3) + + x = self.layer_norm(x) + triangle_bias = ( + permute_final_dims(self.linear(x), + (2, 0, 1)).unsqueeze(-4).contiguous()) + + if chunk_size is not None: + x = self._chunk(x, attn_mask, triangle_bias, chunk_size) + else: + x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias) + + if not self.starting: + x = x.transpose(-2, -3) + return x + + def get_output_bias(self): + return self.mha.get_output_bias() + + +class TriangleAttentionStarting(TriangleAttention): + __init__ = partialmethod(TriangleAttention.__init__, starting=True) + + +class TriangleAttentionEnding(TriangleAttention): + __init__ = partialmethod(TriangleAttention.__init__, starting=False) diff --git a/modelscope/models/science/unifold/modules/auxillary_heads.py b/modelscope/models/science/unifold/modules/auxillary_heads.py new file mode 100644 index 00000000..2daf5d55 --- /dev/null +++ b/modelscope/models/science/unifold/modules/auxillary_heads.py @@ -0,0 +1,171 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
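+# Prediction heads applied on top of the Evoformer / structure-module outputs:
+# per-residue pLDDT, distogram, masked-MSA recovery and, when enabled,
+# experimentally-resolved and predicted-aligned-error (pTM/ipTM) heads.
+# Rough usage sketch (local variable names are illustrative only):
+#   heads = AuxiliaryHeads(config['heads'])
+#   aux = heads({'sm': sm_out, 'pair': z, 'msa': m, 'single': s})
+#   aux['plddt']   # per-residue confidence in (0, 1)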
+ +from typing import Dict + +import torch.nn as nn +from unicore.modules import LayerNorm + +from .common import Linear +from .confidence import (predicted_aligned_error, predicted_lddt, + predicted_tm_score) + + +class AuxiliaryHeads(nn.Module): + + def __init__(self, config): + super(AuxiliaryHeads, self).__init__() + + self.plddt = PredictedLDDTHead(**config['plddt'], ) + + self.distogram = DistogramHead(**config['distogram'], ) + + self.masked_msa = MaskedMSAHead(**config['masked_msa'], ) + + if config.experimentally_resolved.enabled: + self.experimentally_resolved = ExperimentallyResolvedHead( + **config['experimentally_resolved'], ) + + if config.pae.enabled: + self.pae = PredictedAlignedErrorHead(**config.pae, ) + + self.config = config + + def forward(self, outputs): + aux_out = {} + plddt_logits = self.plddt(outputs['sm']['single']) + aux_out['plddt_logits'] = plddt_logits + + aux_out['plddt'] = predicted_lddt(plddt_logits.detach()) + + distogram_logits = self.distogram(outputs['pair']) + aux_out['distogram_logits'] = distogram_logits + + masked_msa_logits = self.masked_msa(outputs['msa']) + aux_out['masked_msa_logits'] = masked_msa_logits + + if self.config.experimentally_resolved.enabled: + exp_res_logits = self.experimentally_resolved(outputs['single']) + aux_out['experimentally_resolved_logits'] = exp_res_logits + + if self.config.pae.enabled: + pae_logits = self.pae(outputs['pair']) + aux_out['pae_logits'] = pae_logits + pae_logits = pae_logits.detach() + aux_out.update( + predicted_aligned_error( + pae_logits, + **self.config.pae, + )) + aux_out['ptm'] = predicted_tm_score( + pae_logits, interface=False, **self.config.pae) + + iptm_weight = self.config.pae.get('iptm_weight', 0.0) + if iptm_weight > 0.0: + aux_out['iptm'] = predicted_tm_score( + pae_logits, + interface=True, + asym_id=outputs['asym_id'], + **self.config.pae, + ) + aux_out['iptm+ptm'] = ( + iptm_weight * aux_out['iptm'] + # noqa W504 + (1.0 - iptm_weight) * aux_out['ptm']) + + return aux_out + + +class PredictedLDDTHead(nn.Module): + + def __init__(self, num_bins, d_in, d_hid): + super(PredictedLDDTHead, self).__init__() + + self.num_bins = num_bins + self.d_in = d_in + self.d_hid = d_hid + + self.layer_norm = LayerNorm(self.d_in) + + self.linear_1 = Linear(self.d_in, self.d_hid, init='relu') + self.linear_2 = Linear(self.d_hid, self.d_hid, init='relu') + self.act = nn.GELU() + self.linear_3 = Linear(self.d_hid, self.num_bins, init='final') + + def forward(self, s): + s = self.layer_norm(s) + s = self.linear_1(s) + s = self.act(s) + s = self.linear_2(s) + s = self.act(s) + s = self.linear_3(s) + return s + + +class EnhancedHeadBase(nn.Module): + + def __init__(self, d_in, d_out, disable_enhance_head): + super(EnhancedHeadBase, self).__init__() + if disable_enhance_head: + self.layer_norm = None + self.linear_in = None + else: + self.layer_norm = LayerNorm(d_in) + self.linear_in = Linear(d_in, d_in, init='relu') + self.act = nn.GELU() + self.linear = Linear(d_in, d_out, init='final') + + def apply_alphafold_original_mode(self): + self.layer_norm = None + self.linear_in = None + + def forward(self, x): + if self.layer_norm is not None: + x = self.layer_norm(x) + x = self.act(self.linear_in(x)) + logits = self.linear(x) + return logits + + +class DistogramHead(EnhancedHeadBase): + + def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): + super(DistogramHead, self).__init__( + d_in=d_pair, + d_out=num_bins, + disable_enhance_head=disable_enhance_head, + ) + + def forward(self, x): + logits = 
super().forward(x) + logits = logits + logits.transpose(-2, -3) + return logits + + +class PredictedAlignedErrorHead(EnhancedHeadBase): + + def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): + super(PredictedAlignedErrorHead, self).__init__( + d_in=d_pair, + d_out=num_bins, + disable_enhance_head=disable_enhance_head, + ) + + +class MaskedMSAHead(EnhancedHeadBase): + + def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs): + super(MaskedMSAHead, self).__init__( + d_in=d_msa, + d_out=d_out, + disable_enhance_head=disable_enhance_head, + ) + + +class ExperimentallyResolvedHead(EnhancedHeadBase): + + def __init__(self, d_single, d_out, disable_enhance_head, **kwargs): + super(ExperimentallyResolvedHead, self).__init__( + d_in=d_single, + d_out=d_out, + disable_enhance_head=disable_enhance_head, + ) diff --git a/modelscope/models/science/unifold/modules/common.py b/modelscope/models/science/unifold/modules/common.py new file mode 100644 index 00000000..186f2567 --- /dev/null +++ b/modelscope/models/science/unifold/modules/common.py @@ -0,0 +1,387 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partial +from typing import Any, Callable, Dict, Iterable, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from unicore.modules import LayerNorm +from unicore.utils import tensor_tree_map + + +class Linear(nn.Linear): + + def __init__( + self, + d_in: int, + d_out: int, + bias: bool = True, + init: str = 'default', + ): + super(Linear, self).__init__(d_in, d_out, bias=bias) + + self.use_bias = bias + + if self.use_bias: + with torch.no_grad(): + self.bias.fill_(0) + + if init == 'default': + self._trunc_normal_init(1.0) + elif init == 'relu': + self._trunc_normal_init(2.0) + elif init == 'glorot': + self._glorot_uniform_init() + elif init == 'gating': + self._zero_init(self.use_bias) + elif init == 'normal': + self._normal_init() + elif init == 'final': + self._zero_init(False) + else: + raise ValueError('Invalid init method.') + + def _trunc_normal_init(self, scale=1.0): + # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) 
+ TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 + _, fan_in = self.weight.shape + scale = scale / max(1, fan_in) + std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR + nn.init.trunc_normal_(self.weight, mean=0.0, std=std) + + def _glorot_uniform_init(self): + nn.init.xavier_uniform_(self.weight, gain=1) + + def _zero_init(self, use_bias=True): + with torch.no_grad(): + self.weight.fill_(0.0) + if use_bias: + with torch.no_grad(): + self.bias.fill_(1.0) + + def _normal_init(self): + torch.nn.init.kaiming_normal_(self.weight, nonlinearity='linear') + + +class Transition(nn.Module): + + def __init__(self, d_in, n): + + super(Transition, self).__init__() + + self.d_in = d_in + self.n = n + + self.layer_norm = LayerNorm(self.d_in) + self.linear_1 = Linear(self.d_in, self.n * self.d_in, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.n * self.d_in, d_in, init='final') + + def _transition(self, x): + x = self.layer_norm(x) + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + return x + + @torch.jit.ignore + def _chunk( + self, + x: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + return chunk_layer( + self._transition, + {'x': x}, + chunk_size=chunk_size, + num_batch_dims=len(x.shape[:-2]), + ) + + def forward( + self, + x: torch.Tensor, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + if chunk_size is not None: + x = self._chunk(x, chunk_size) + else: + x = self._transition(x=x) + + return x + + +class OuterProductMean(nn.Module): + + def __init__(self, d_msa, d_pair, d_hid, eps=1e-3): + super(OuterProductMean, self).__init__() + + self.d_msa = d_msa + self.d_pair = d_pair + self.d_hid = d_hid + self.eps = eps + + self.layer_norm = LayerNorm(d_msa) + self.linear_1 = Linear(d_msa, d_hid) + self.linear_2 = Linear(d_msa, d_hid) + self.linear_out = Linear(d_hid**2, d_pair, init='relu') + self.act = nn.GELU() + self.linear_z = Linear(self.d_pair, self.d_pair, init='final') + self.layer_norm_out = LayerNorm(self.d_pair) + + def _opm(self, a, b): + outer = torch.einsum('...bac,...dae->...bdce', a, b) + outer = outer.reshape(outer.shape[:-2] + (-1, )) + outer = self.linear_out(outer) + return outer + + @torch.jit.ignore + def _chunk(self, a: torch.Tensor, b: torch.Tensor, + chunk_size: int) -> torch.Tensor: + a = a.reshape((-1, ) + a.shape[-3:]) + b = b.reshape((-1, ) + b.shape[-3:]) + out = [] + # TODO: optimize this + for a_prime, b_prime in zip(a, b): + outer = chunk_layer( + partial(self._opm, b=b_prime), + {'a': a_prime}, + chunk_size=chunk_size, + num_batch_dims=1, + ) + out.append(outer) + if len(out) == 1: + outer = out[0].unsqueeze(0) + else: + outer = torch.stack(out, dim=0) + outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) + + return outer + + def apply_alphafold_original_mode(self): + self.linear_z = None + self.layer_norm_out = None + + def forward( + self, + m: torch.Tensor, + mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + + m = self.layer_norm(m) + mask = mask.unsqueeze(-1) + if self.layer_norm_out is not None: + # for numerical stability + mask = mask * (mask.size(-2)**-0.5) + a = self.linear_1(m) + b = self.linear_2(m) + if self.training: + a = a * mask + b = b * mask + else: + a *= mask + b *= mask + + a = a.transpose(-2, -3) + b = b.transpose(-2, -3) + + if chunk_size is not None: + z = self._chunk(a, b, chunk_size) + else: + z = self._opm(a, b) + + norm = torch.einsum('...abc,...adc->...bdc', mask, mask) + z /= self.eps + norm + if self.layer_norm_out is not None: + z = 
self.act(z) + z = self.layer_norm_out(z) + z = self.linear_z(z) + return z + + +def residual(residual, x, training): + if training: + return x + residual + else: + residual += x + return residual + + +@torch.jit.script +def fused_bias_dropout_add( + x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + dropmask: torch.Tensor, + prob: float, +) -> torch.Tensor: + return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual + + +@torch.jit.script +def fused_bias_dropout_add_inference( + x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, +) -> torch.Tensor: + residual += bias + x + return residual + + +def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob, + training): + bias = module.get_output_bias() + if training: + shape = list(x.shape) + shape[dropout_shared_dim] = 1 + with torch.no_grad(): + mask = x.new_ones(shape) + return fused_bias_dropout_add(x, bias, residual, mask, prob) + else: + return fused_bias_dropout_add_inference(x, bias, residual) + + +@torch.jit.script +def fused_bias_gated_dropout_add( + x: torch.Tensor, + bias: torch.Tensor, + g: torch.Tensor, + g_bias: torch.Tensor, + residual: torch.Tensor, + dropout_mask: torch.Tensor, + prob: float, +) -> torch.Tensor: + return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout( + dropout_mask, + p=prob, + training=True, + ) + residual + + +def tri_mul_residual( + module, + residual, + outputs, + dropout_shared_dim, + prob, + training, + block_size, +): + if training: + x, g = outputs + bias, g_bias = module.get_output_bias() + shape = list(x.shape) + shape[dropout_shared_dim] = 1 + with torch.no_grad(): + mask = x.new_ones(shape) + return fused_bias_gated_dropout_add( + x, + bias, + g, + g_bias, + residual, + mask, + prob, + ) + elif block_size is None: + x, g = outputs + bias, g_bias = module.get_output_bias() + residual += (torch.sigmoid(g + g_bias) * (x + bias)) + return residual + else: + # gated is not used here + residual += outputs + return residual + + +class SimpleModuleList(nn.ModuleList): + + def __repr__(self): + return str(len(self)) + ' X ...\n' + self[0].__repr__() + + +def chunk_layer( + layer: Callable, + inputs: Dict[str, Any], + chunk_size: int, + num_batch_dims: int, +) -> Any: + # TODO: support inplace add to output + if not (len(inputs) > 0): + raise ValueError('Must provide at least one input') + + def _dict_get_shapes(input): + shapes = [] + if type(input) is torch.Tensor: + shapes.append(input.shape) + elif type(input) is dict: + for v in input.values(): + shapes.extend(_dict_get_shapes(v)) + elif isinstance(input, Iterable): + for v in input: + shapes.extend(_dict_get_shapes(v)) + else: + raise ValueError('Not supported') + + return shapes + + inputs = {k: v for k, v in inputs.items() if v is not None} + initial_dims = [ + shape[:num_batch_dims] for shape in _dict_get_shapes(inputs) + ] + orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) + + flat_batch_dim = 1 + for d in orig_batch_dims: + flat_batch_dim *= d + num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size + + def _flat_inputs(t): + t = t.view(-1, *t.shape[num_batch_dims:]) + assert ( + t.shape[0] == flat_batch_dim or t.shape[0] == 1 + ), 'batch dimension must be 1 or equal to the flat batch dimension' + return t + + flat_inputs = tensor_tree_map(_flat_inputs, inputs) + + out = None + for i in range(num_chunks): + chunk_start = i * chunk_size + chunk_end = min((i + 1) * chunk_size, flat_batch_dim) + + def select_chunk(t): + if t.shape[0] == 1: + return t[0:1] + else: + 
return t[chunk_start:chunk_end] + + chunkes = tensor_tree_map(select_chunk, flat_inputs) + + output_chunk = layer(**chunkes) + + if out is None: + out = tensor_tree_map( + lambda t: t.new_zeros((flat_batch_dim, ) + t.shape[1:]), + output_chunk) + + out_type = type(output_chunk) + if out_type is tuple: + for x, y in zip(out, output_chunk): + x[chunk_start:chunk_end] = y + elif out_type is torch.Tensor: + out[chunk_start:chunk_end] = output_chunk + else: + raise ValueError('Not supported') + + # reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) + def reshape(t): + return t.view(orig_batch_dims + t.shape[1:]) + + out = tensor_tree_map(reshape, out) + + return out diff --git a/modelscope/models/science/unifold/modules/confidence.py b/modelscope/models/science/unifold/modules/confidence.py new file mode 100644 index 00000000..7574689c --- /dev/null +++ b/modelscope/models/science/unifold/modules/confidence.py @@ -0,0 +1,159 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from typing import Dict, Optional, Tuple + +import torch + + +def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor: + """Computes per-residue pLDDT from logits. + Args: + logits: [num_res, num_bins] output from the PredictedLDDTHead. + Returns: + plddt: [num_res] per-residue pLDDT. + """ + num_bins = plddt_logits.shape[-1] + bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1) + bin_width = 1.0 / num_bins + bounds = torch.arange( + start=0.5 * bin_width, + end=1.0, + step=bin_width, + device=plddt_logits.device) + plddt = torch.sum( + bin_probs + * bounds.view(*((1, ) * len(bin_probs.shape[:-1])), *bounds.shape), + dim=-1, + ) + return plddt + + +def compute_bin_values(breaks: torch.Tensor): + """Gets the bin centers from the bin edges. + Args: + breaks: [num_bins - 1] the error bin edges. + Returns: + bin_centers: [num_bins] the error bin centers. + """ + step = breaks[1] - breaks[0] + bin_values = breaks + step / 2 + bin_values = torch.cat([bin_values, (bin_values[-1] + step).unsqueeze(-1)], + dim=0) + return bin_values + + +def compute_predicted_aligned_error( + bin_edges: torch.Tensor, + bin_probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculates expected aligned distance errors for every pair of residues. + Args: + alignment_confidence_breaks: [num_bins - 1] the error bin edges. + aligned_distance_error_probs: [num_res, num_res, num_bins] the predicted + probs for each error bin, for each pair of residues. + Returns: + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. + max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_values = compute_bin_values(bin_edges) + return torch.sum(bin_probs * bin_values, dim=-1) + + +def predicted_aligned_error( + pae_logits: torch.Tensor, + max_bin: int = 31, + num_bins: int = 64, + **kwargs, +) -> Dict[str, torch.Tensor]: + """Computes aligned confidence metrics from logits. + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins - 1] the error bin edges. + Returns: + aligned_confidence_probs: [num_res, num_res, num_bins] the predicted + aligned error probabilities over bins for each residue pair. + predicted_aligned_error: [num_res, num_res] the expected aligned distance + error for each pair of residues. 
+ max_predicted_aligned_error: The maximum predicted error possible. + """ + bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1) + bin_edges = torch.linspace( + 0, max_bin, steps=(num_bins - 1), device=pae_logits.device) + + predicted_aligned_error = compute_predicted_aligned_error( + bin_edges=bin_edges, + bin_probs=bin_probs, + ) + + return { + 'aligned_error_probs_per_bin': bin_probs, + 'predicted_aligned_error': predicted_aligned_error, + } + + +def predicted_tm_score( + pae_logits: torch.Tensor, + residue_weights: Optional[torch.Tensor] = None, + max_bin: int = 31, + num_bins: int = 64, + eps: float = 1e-8, + asym_id: Optional[torch.Tensor] = None, + interface: bool = False, + **kwargs, +) -> torch.Tensor: + """Computes predicted TM alignment or predicted interface TM alignment score. + Args: + logits: [num_res, num_res, num_bins] the logits output from + PredictedAlignedErrorHead. + breaks: [num_bins] the error bins. + residue_weights: [num_res] the per residue weights to use for the + expectation. + asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for + ipTM calculation, i.e. when interface=True. + interface: If True, interface predicted TM score is computed. + Returns: + ptm_score: The predicted TM alignment or the predicted iTM score. + """ + pae_logits = pae_logits.float() + if residue_weights is None: + residue_weights = pae_logits.new_ones(pae_logits.shape[:-2]) + + breaks = torch.linspace( + 0, max_bin, steps=(num_bins - 1), device=pae_logits.device) + + def tm_kernal(nres): + clipped_n = max(nres, 19) + d0 = 1.24 * (clipped_n - 15)**(1.0 / 3.0) - 1.8 + return lambda x: 1.0 / (1.0 + (x / d0)**2) + + def rmsd_kernal(eps): # leave for compute pRMS + return lambda x: 1. / (x + eps) + + bin_centers = compute_bin_values(breaks) + probs = torch.nn.functional.softmax(pae_logits, dim=-1) + tm_per_bin = tm_kernal(nres=pae_logits.shape[-2])(bin_centers) + # tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2)) + # rmsd_per_bin = rmsd_kernal()(bin_centers) + predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) + + pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape) + if interface: + assert asym_id is not None, 'must provide asym_id for iptm calculation.' + pair_mask *= asym_id[..., :, None] != asym_id[..., None, :] + + predicted_tm_term *= pair_mask + + pair_residue_weights = pair_mask * ( + residue_weights[None, :] * residue_weights[:, None]) + normed_residue_mask = pair_residue_weights / ( + eps + pair_residue_weights.sum(dim=-1, keepdim=True)) + + per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) + weighted = per_alignment * residue_weights + ret = per_alignment.gather( + dim=-1, index=weighted.max(dim=-1, + keepdim=True).indices).squeeze(dim=-1) + return ret diff --git a/modelscope/models/science/unifold/modules/embedders.py b/modelscope/models/science/unifold/modules/embedders.py new file mode 100644 index 00000000..84e87e2d --- /dev/null +++ b/modelscope/models/science/unifold/modules/embedders.py @@ -0,0 +1,290 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
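+# Embedders that build the initial MSA and pair representations: InputEmbedder
+# (target/MSA features plus relative-position encoding, optionally chain-relative
+# for multimers), RecyclingEmbedder (re-embeds the previous cycle's outputs),
+# TemplateAngleEmbedder / TemplatePairEmbedder and ExtraMSAEmbedder.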
+ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import one_hot + +from .common import Linear, SimpleModuleList, residual + + +class InputEmbedder(nn.Module): + + def __init__( + self, + tf_dim: int, + msa_dim: int, + d_pair: int, + d_msa: int, + relpos_k: int, + use_chain_relative: bool = False, + max_relative_chain: Optional[int] = None, + **kwargs, + ): + super(InputEmbedder, self).__init__() + + self.tf_dim = tf_dim + self.msa_dim = msa_dim + + self.d_pair = d_pair + self.d_msa = d_msa + + self.linear_tf_z_i = Linear(tf_dim, d_pair) + self.linear_tf_z_j = Linear(tf_dim, d_pair) + self.linear_tf_m = Linear(tf_dim, d_msa) + self.linear_msa_m = Linear(msa_dim, d_msa) + + # RPE stuff + self.relpos_k = relpos_k + self.use_chain_relative = use_chain_relative + self.max_relative_chain = max_relative_chain + if not self.use_chain_relative: + self.num_bins = 2 * self.relpos_k + 1 + else: + self.num_bins = 2 * self.relpos_k + 2 + self.num_bins += 1 # entity id + self.num_bins += 2 * max_relative_chain + 2 + + self.linear_relpos = Linear(self.num_bins, d_pair) + + def _relpos_indices( + self, + res_id: torch.Tensor, + sym_id: Optional[torch.Tensor] = None, + asym_id: Optional[torch.Tensor] = None, + entity_id: Optional[torch.Tensor] = None, + ): + + max_rel_res = self.relpos_k + rp = res_id[..., None] - res_id[..., None, :] + rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res + if not self.use_chain_relative: + return rp + else: + asym_id_same = asym_id[..., :, None] == asym_id[..., None, :] + rp[~asym_id_same] = 2 * max_rel_res + 1 + entity_id_same = entity_id[..., :, None] == entity_id[..., None, :] + rp_entity_id = entity_id_same.type(rp.dtype)[..., None] + + rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :] + + max_rel_chain = self.max_relative_chain + + clipped_rel_chain = torch.clamp( + rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain) + + clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1 + return rp, rp_entity_id, clipped_rel_chain + + def relpos_emb( + self, + res_id: torch.Tensor, + sym_id: Optional[torch.Tensor] = None, + asym_id: Optional[torch.Tensor] = None, + entity_id: Optional[torch.Tensor] = None, + num_sym: Optional[torch.Tensor] = None, + ): + + dtype = self.linear_relpos.weight.dtype + if not self.use_chain_relative: + rp = self._relpos_indices(res_id=res_id) + return self.linear_relpos( + one_hot(rp, num_classes=self.num_bins, dtype=dtype)) + else: + rp, rp_entity_id, rp_rel_chain = self._relpos_indices( + res_id=res_id, + sym_id=sym_id, + asym_id=asym_id, + entity_id=entity_id) + rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype) + rp_entity_id = rp_entity_id.type(dtype) + rp_rel_chain = one_hot( + rp_rel_chain, + num_classes=(2 * self.max_relative_chain + 2), + dtype=dtype) + return self.linear_relpos( + torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1)) + + def forward( + self, + tf: torch.Tensor, + msa: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # [*, N_res, d_pair] + if self.tf_dim == 21: + # multimer use 21 target dim + tf = tf[..., 1:] + # convert type if necessary + tf = tf.type(self.linear_tf_z_i.weight.dtype) + msa = msa.type(self.linear_tf_z_i.weight.dtype) + n_clust = msa.shape[-3] + + msa_emb = self.linear_msa_m(msa) + # target_feat (aatype) into msa representation + tf_m = ( + self.linear_tf_m(tf).unsqueeze(-3).expand( + ((-1, ) * len(tf.shape[:-2]) + # noqa W504 + (n_clust, -1, -1)))) + msa_emb += tf_m + + tf_emb_i = 
self.linear_tf_z_i(tf) + tf_emb_j = self.linear_tf_z_j(tf) + pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] + + return msa_emb, pair_emb + + +class RecyclingEmbedder(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + min_bin: float, + max_bin: float, + num_bins: int, + inf: float = 1e8, + **kwargs, + ): + super(RecyclingEmbedder, self).__init__() + + self.d_msa = d_msa + self.d_pair = d_pair + self.min_bin = min_bin + self.max_bin = max_bin + self.num_bins = num_bins + self.inf = inf + + self.squared_bins = None + + self.linear = Linear(self.num_bins, self.d_pair) + self.layer_norm_m = LayerNorm(self.d_msa) + self.layer_norm_z = LayerNorm(self.d_pair) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + m_update = self.layer_norm_m(m) + z_update = self.layer_norm_z(z) + + return m_update, z_update + + def recyle_pos( + self, + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.squared_bins is None: + bins = torch.linspace( + self.min_bin, + self.max_bin, + self.num_bins, + dtype=torch.float if self.training else x.dtype, + device=x.device, + requires_grad=False, + ) + self.squared_bins = bins**2 + upper = torch.cat( + [self.squared_bins[1:], + self.squared_bins.new_tensor([self.inf])], + dim=-1) + if self.training: + x = x.float() + d = torch.sum( + (x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True) + d = ((d > self.squared_bins) * # noqa W504 + (d < upper)).type(self.linear.weight.dtype) + d = self.linear(d) + return d + + +class TemplateAngleEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + d_out: int, + **kwargs, + ): + super(TemplateAngleEmbedder, self).__init__() + + self.d_out = d_out + self.d_in = d_in + + self.linear_1 = Linear(self.d_in, self.d_out, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.d_out, self.d_out, init='relu') + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x.type(self.linear_1.weight.dtype)) + x = self.act(x) + x = self.linear_2(x) + return x + + +class TemplatePairEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + v2_d_in: list, + d_out: int, + d_pair: int, + v2_feature: bool = False, + **kwargs, + ): + super(TemplatePairEmbedder, self).__init__() + + self.d_out = d_out + self.v2_feature = v2_feature + if self.v2_feature: + self.d_in = v2_d_in + self.linear = SimpleModuleList() + for d_in in self.d_in: + self.linear.append(Linear(d_in, self.d_out, init='relu')) + self.z_layer_norm = LayerNorm(d_pair) + self.z_linear = Linear(d_pair, self.d_out, init='relu') + else: + self.d_in = d_in + self.linear = Linear(self.d_in, self.d_out, init='relu') + + def forward( + self, + x, + z, + ) -> torch.Tensor: + if not self.v2_feature: + x = self.linear(x.type(self.linear.weight.dtype)) + return x + else: + dtype = self.z_linear.weight.dtype + t = self.linear[0](x[0].type(dtype)) + for i, s in enumerate(x[1:]): + t = residual(t, self.linear[i + 1](s.type(dtype)), + self.training) + t = residual(t, self.z_linear(self.z_layer_norm(z)), self.training) + return t + + +class ExtraMSAEmbedder(nn.Module): + + def __init__( + self, + d_in: int, + d_out: int, + **kwargs, + ): + super(ExtraMSAEmbedder, self).__init__() + + self.d_in = d_in + self.d_out = d_out + self.linear = Linear(self.d_in, self.d_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x.type(self.linear.weight.dtype)) diff --git a/modelscope/models/science/unifold/modules/evoformer.py 
b/modelscope/models/science/unifold/modules/evoformer.py new file mode 100644 index 00000000..b0834986 --- /dev/null +++ b/modelscope/models/science/unifold/modules/evoformer.py @@ -0,0 +1,362 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partial +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from unicore.utils import checkpoint_sequential + +from .attentions import (MSAColumnAttention, MSAColumnGlobalAttention, + MSARowAttentionWithPairBias, TriangleAttentionEnding, + TriangleAttentionStarting) +from .common import (Linear, OuterProductMean, SimpleModuleList, Transition, + bias_dropout_residual, residual, tri_mul_residual) +from .triangle_multiplication import (TriangleMultiplicationIncoming, + TriangleMultiplicationOutgoing) + + +class EvoformerIteration(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, + d_hid_pair_att: int, + num_heads_msa: int, + num_heads_pair: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + ): + super(EvoformerIteration, self).__init__() + + self._is_extra_msa_stack = _is_extra_msa_stack + self.outer_product_mean_first = outer_product_mean_first + + self.msa_att_row = MSARowAttentionWithPairBias( + d_msa=d_msa, + d_pair=d_pair, + d_hid=d_hid_msa_att, + num_heads=num_heads_msa, + ) + + if _is_extra_msa_stack: + self.msa_att_col = MSAColumnGlobalAttention( + d_in=d_msa, + d_hid=d_hid_msa_att, + num_heads=num_heads_msa, + inf=inf, + eps=eps, + ) + else: + self.msa_att_col = MSAColumnAttention( + d_msa, + d_hid_msa_att, + num_heads_msa, + ) + + self.msa_transition = Transition( + d_in=d_msa, + n=transition_n, + ) + + self.outer_product_mean = OuterProductMean( + d_msa, + d_pair, + d_hid_opm, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + d_pair, + d_hid_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + d_pair, + d_hid_mul, + ) + + self.tri_att_start = TriangleAttentionStarting( + d_pair, + d_hid_pair_att, + num_heads_pair, + ) + self.tri_att_end = TriangleAttentionEnding( + d_pair, + d_hid_pair_att, + num_heads_pair, + ) + + self.pair_transition = Transition( + d_in=d_pair, + n=transition_n, + ) + + self.row_dropout_share_dim = -3 + self.col_dropout_share_dim = -2 + self.msa_dropout = msa_dropout + self.pair_dropout = pair_dropout + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + msa_row_attn_mask: torch.Tensor, + msa_col_attn_mask: Optional[torch.Tensor], + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: Optional[int] = None, + block_size: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if self.outer_product_mean_first: + z = residual( + z, + self.outer_product_mean( + m, mask=msa_mask, chunk_size=chunk_size), self.training) + + m = bias_dropout_residual( + self.msa_att_row, + m, + self.msa_att_row( + m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.msa_dropout, + self.training, + ) + if self._is_extra_msa_stack: + m = residual( + m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size), + self.training) + else: + m = bias_dropout_residual( + self.msa_att_col, + m, + self.msa_att_col( + m, 
attn_mask=msa_col_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.msa_dropout, + self.training, + ) + m = residual(m, self.msa_transition(m, chunk_size=chunk_size), + self.training) + if not self.outer_product_mean_first: + z = residual( + z, + self.outer_product_mean( + m, mask=msa_mask, chunk_size=chunk_size), self.training) + + z = tri_mul_residual( + self.tri_mul_out, + z, + self.tri_mul_out(z, mask=pair_mask, block_size=block_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + block_size=block_size, + ) + + z = tri_mul_residual( + self.tri_mul_in, + z, + self.tri_mul_in(z, mask=pair_mask, block_size=block_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + block_size=block_size, + ) + + z = bias_dropout_residual( + self.tri_att_start, + z, + self.tri_att_start( + z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.pair_dropout, + self.training, + ) + + z = bias_dropout_residual( + self.tri_att_end, + z, + self.tri_att_end( + z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.pair_dropout, + self.training, + ) + z = residual(z, self.pair_transition(z, chunk_size=chunk_size), + self.training) + return m, z + + +class EvoformerStack(nn.Module): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, + d_hid_pair_att: int, + d_single: int, + num_heads_msa: int, + num_heads_pair: int, + num_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + _is_extra_msa_stack: bool = False, + **kwargs, + ): + super(EvoformerStack, self).__init__() + + self._is_extra_msa_stack = _is_extra_msa_stack + + self.blocks = SimpleModuleList() + + for _ in range(num_blocks): + self.blocks.append( + EvoformerIteration( + d_msa=d_msa, + d_pair=d_pair, + d_hid_msa_att=d_hid_msa_att, + d_hid_opm=d_hid_opm, + d_hid_mul=d_hid_mul, + d_hid_pair_att=d_hid_pair_att, + num_heads_msa=num_heads_msa, + num_heads_pair=num_heads_pair, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + outer_product_mean_first=outer_product_mean_first, + inf=inf, + eps=eps, + _is_extra_msa_stack=_is_extra_msa_stack, + )) + if not self._is_extra_msa_stack: + self.linear = Linear(d_msa, d_single) + else: + self.linear = None + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: torch.Tensor, + pair_mask: torch.Tensor, + msa_row_attn_mask: torch.Tensor, + msa_col_attn_mask: torch.Tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: int, + block_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + blocks = [ + partial( + b, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_attn_mask, + msa_col_attn_mask=msa_col_attn_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size) for b in self.blocks + ] + + m, z = checkpoint_sequential( + blocks, + input=(m, z), + ) + + s = None + if not self._is_extra_msa_stack: + seq_dim = -3 + index = torch.tensor([0], device=m.device) + s = self.linear(torch.index_select(m, dim=seq_dim, index=index)) + s = s.squeeze(seq_dim) + + return m, z, s + + +class ExtraMSAStack(EvoformerStack): + + def __init__( + self, + d_msa: int, + d_pair: int, + d_hid_msa_att: int, + d_hid_opm: int, + d_hid_mul: int, 
+ d_hid_pair_att: int, + num_heads_msa: int, + num_heads_pair: int, + num_blocks: int, + transition_n: int, + msa_dropout: float, + pair_dropout: float, + outer_product_mean_first: bool, + inf: float, + eps: float, + **kwargs, + ): + super(ExtraMSAStack, self).__init__( + d_msa=d_msa, + d_pair=d_pair, + d_hid_msa_att=d_hid_msa_att, + d_hid_opm=d_hid_opm, + d_hid_mul=d_hid_mul, + d_hid_pair_att=d_hid_pair_att, + d_single=None, + num_heads_msa=num_heads_msa, + num_heads_pair=num_heads_pair, + num_blocks=num_blocks, + transition_n=transition_n, + msa_dropout=msa_dropout, + pair_dropout=pair_dropout, + outer_product_mean_first=outer_product_mean_first, + inf=inf, + eps=eps, + _is_extra_msa_stack=True, + ) + + def forward( + self, + m: torch.Tensor, + z: torch.Tensor, + msa_mask: Optional[torch.Tensor] = None, + pair_mask: Optional[torch.Tensor] = None, + msa_row_attn_mask: torch.Tensor = None, + msa_col_attn_mask: torch.Tensor = None, + tri_start_attn_mask: torch.Tensor = None, + tri_end_attn_mask: torch.Tensor = None, + chunk_size: int = None, + block_size: int = None, + ) -> torch.Tensor: + _, z, _ = super().forward( + m, + z, + msa_mask=msa_mask, + pair_mask=pair_mask, + msa_row_attn_mask=msa_row_attn_mask, + msa_col_attn_mask=msa_col_attn_mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size) + return z diff --git a/modelscope/models/science/unifold/modules/featurization.py b/modelscope/models/science/unifold/modules/featurization.py new file mode 100644 index 00000000..b62adc9d --- /dev/null +++ b/modelscope/models/science/unifold/modules/featurization.py @@ -0,0 +1,195 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
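The `EvoformerIteration` above feeds MSA information into the pair representation through `OuterProductMean` (imported from `.common`), whose internals are not part of this diff. As a rough orientation only, the sketch below shows the standard AlphaFold-style outer-product-mean shape flow with made-up stand-in projection matrices; the real module additionally applies layer norm, masking and chunking.

```python
import torch

n_seq, n_res, d_msa, d_hid, d_pair = 8, 16, 32, 12, 24
m = torch.randn(n_seq, n_res, d_msa)            # MSA activations
proj_a = torch.randn(d_msa, d_hid)              # stand-ins for the module's two input projections
proj_b = torch.randn(d_msa, d_hid)
proj_out = torch.randn(d_hid * d_hid, d_pair)   # stand-in for the output projection

# outer product over residue pairs, averaged over the sequence dimension
outer = torch.einsum('sic,sjd->ijcd', m @ proj_a, m @ proj_b) / n_seq
opm_update = outer.reshape(n_res, n_res, -1) @ proj_out
assert opm_update.shape == (n_res, n_res, d_pair)  # residually added to the pair activations z
```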
+ +from typing import Dict + +import torch +import torch.nn as nn +from unicore.utils import batched_gather, one_hot + +from modelscope.models.science.unifold.data import residue_constants as rc +from .frame import Frame + + +def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): + is_gly = aatype == rc.restype_order['G'] + ca_idx = rc.atom_order['CA'] + cb_idx = rc.atom_order['CB'] + pseudo_beta = torch.where( + is_gly[..., None].expand(*((-1, ) * len(is_gly.shape)), 3), + all_atom_positions[..., ca_idx, :], + all_atom_positions[..., cb_idx, :], + ) + + if all_atom_masks is not None: + pseudo_beta_mask = torch.where( + is_gly, + all_atom_masks[..., ca_idx], + all_atom_masks[..., cb_idx], + ) + return pseudo_beta, pseudo_beta_mask + else: + return pseudo_beta + + +def atom14_to_atom37(atom14, batch): + atom37_data = batched_gather( + atom14, + batch['residx_atom37_to_atom14'], + dim=-2, + num_batch_dims=len(atom14.shape[:-2]), + ) + + atom37_data = atom37_data * batch['atom37_atom_exists'][..., None] + + return atom37_data + + +def build_template_angle_feat(template_feats, v2_feature=False): + template_aatype = template_feats['template_aatype'] + torsion_angles_sin_cos = template_feats['template_torsion_angles_sin_cos'] + torsion_angles_mask = template_feats['template_torsion_angles_mask'] + if not v2_feature: + alt_torsion_angles_sin_cos = template_feats[ + 'template_alt_torsion_angles_sin_cos'] + template_angle_feat = torch.cat( + [ + one_hot(template_aatype, 22), + torsion_angles_sin_cos.reshape( + *torsion_angles_sin_cos.shape[:-2], 14), + alt_torsion_angles_sin_cos.reshape( + *alt_torsion_angles_sin_cos.shape[:-2], 14), + torsion_angles_mask, + ], + dim=-1, + ) + template_angle_mask = torsion_angles_mask[..., 2] + else: + chi_mask = torsion_angles_mask[..., 3:] + chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask + chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask + template_angle_feat = torch.cat( + [ + one_hot(template_aatype, 22), + chi_angles_sin, + chi_angles_cos, + chi_mask, + ], + dim=-1, + ) + template_angle_mask = chi_mask[..., 0] + return template_angle_feat, template_angle_mask + + +def build_template_pair_feat( + batch, + min_bin, + max_bin, + num_bins, + eps=1e-20, + inf=1e8, +): + template_mask = batch['template_pseudo_beta_mask'] + template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + + tpb = batch['template_pseudo_beta'] + dgram = torch.sum( + (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = nn.functional.one_hot( + batch['template_aatype'], + rc.restype_num + 2, + ) + + n_res = batch['template_aatype'].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand( + *aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., + None, :].expand(*aatype_one_hot.shape[:-2], + -1, n_res, -1)) + + to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3)) + to_concat.append(template_mask_2d[..., None]) + + act = torch.cat(to_concat, dim=-1) + act = act * template_mask_2d[..., None] + + return act + + +def build_template_pair_feat_v2( + batch, + min_bin, + max_bin, + num_bins, + multichain_mask_2d=None, + eps=1e-20, + inf=1e8, +): + template_mask = batch['template_pseudo_beta_mask'] + 
template_mask_2d = template_mask[..., None] * template_mask[..., None, :] + if multichain_mask_2d is not None: + template_mask_2d *= multichain_mask_2d + + tpb = batch['template_pseudo_beta'] + dgram = torch.sum( + (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) + lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 + upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) + dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) + dgram *= template_mask_2d[..., None] + to_concat = [dgram, template_mask_2d[..., None]] + + aatype_one_hot = one_hot( + batch['template_aatype'], + rc.restype_num + 2, + ) + + n_res = batch['template_aatype'].shape[-1] + to_concat.append(aatype_one_hot[..., None, :, :].expand( + *aatype_one_hot.shape[:-2], n_res, -1, -1)) + to_concat.append(aatype_one_hot[..., + None, :].expand(*aatype_one_hot.shape[:-2], + -1, n_res, -1)) + + n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']] + rigids = Frame.make_transform_from_reference( + n_xyz=batch['template_all_atom_positions'][..., n, :], + ca_xyz=batch['template_all_atom_positions'][..., ca, :], + c_xyz=batch['template_all_atom_positions'][..., c, :], + eps=eps, + ) + points = rigids.get_trans()[..., None, :, :] + rigid_vec = rigids[..., None].invert_apply(points) + + inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) + + t_aa_masks = batch['template_all_atom_mask'] + backbone_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., + c] + backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[..., + None, :] + if multichain_mask_2d is not None: + backbone_mask_2d *= multichain_mask_2d + + inv_distance_scalar = inv_distance_scalar * backbone_mask_2d + unit_vector_data = rigid_vec * inv_distance_scalar[..., None] + to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1)) + to_concat.append(backbone_mask_2d[..., None]) + + return to_concat + + +def build_extra_msa_feat(batch): + msa_1hot = one_hot(batch['extra_msa'], 23) + msa_feat = [ + msa_1hot, + batch['extra_msa_has_deletion'].unsqueeze(-1), + batch['extra_msa_deletion_value'].unsqueeze(-1), + ] + return torch.cat(msa_feat, dim=-1) diff --git a/modelscope/models/science/unifold/modules/frame.py b/modelscope/models/science/unifold/modules/frame.py new file mode 100644 index 00000000..5a0e4d6a --- /dev/null +++ b/modelscope/models/science/unifold/modules/frame.py @@ -0,0 +1,562 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
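`RecyclingEmbedder.recyle_pos` and `build_template_pair_feat` above both discretize pairwise squared distances into a one-hot distogram with a strict lower edge and an open-ended last bin. A minimal standalone check of that binning, using made-up coordinates and bin settings:

```python
import torch

min_bin, max_bin, num_bins, inf = 3.25, 20.75, 15, 1e8
x = torch.tensor([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0], [30.0, 0.0, 0.0]])  # toy pseudo-beta coordinates

d2 = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True)
lower = torch.linspace(min_bin, max_bin, num_bins) ** 2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((d2 > lower) * (d2 < upper)).float()

assert dgram.shape == (3, 3, num_bins)
assert dgram.sum(dim=-1).max() <= 1          # at most one active bin per pair (self-distance 0 hits none)
assert dgram[0, 2].argmax() == num_bins - 1  # 30 A exceeds max_bin -> open-ended last bin
```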
+ +from __future__ import annotations # noqa +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple + +import numpy as np +import torch + + +def zero_translation( + batch_dims: Tuple[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, +) -> torch.Tensor: + trans = torch.zeros((*batch_dims, 3), + dtype=dtype, + device=device, + requires_grad=requires_grad) + return trans + + +# pylint: disable=bad-whitespace +_QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) + +_QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr +_QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii +_QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj +_QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk + +_QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij +_QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik +_QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk + +_QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir +_QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr +_QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr + +_QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9) +_QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT) + +_QUAT_MULTIPLY = np.zeros((4, 4, 4)) +_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], + [0, 0, 0, -1]] + +_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], + [0, 0, -1, 0]] + +_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], + [0, 1, 0, 0]] + +_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], + [1, 0, 0, 0]] + +_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] +_QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC) + + +class Rotation: + + def __init__( + self, + mat: torch.Tensor, + ): + if mat.shape[-2:] != (3, 3): + raise ValueError(f'incorrect rotation shape: {mat.shape}') + self._mat = mat + + @staticmethod + def identity( + shape, + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Rotation: + mat = torch.eye( + 3, dtype=dtype, device=device, requires_grad=requires_grad) + mat = mat.view(*((1, ) * len(shape)), 3, 3) + mat = mat.expand(*shape, -1, -1) + return Rotation(mat) + + @staticmethod + def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return (a.float() @ b.float()).type(a.dtype) + + @staticmethod + def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype) + + def __getitem__(self, index: Any) -> Rotation: + if not isinstance(index, tuple): + index = (index, ) + return Rotation(mat=self._mat[index + (slice(None), slice(None))]) + + def __mul__(self, right: Any) -> Rotation: + if isinstance(right, (int, float)): + return Rotation(mat=self._mat * right) + elif isinstance(right, torch.Tensor): + return Rotation(mat=self._mat * right[..., None, None]) + else: + raise TypeError( + f'multiplicand must be a tensor or a number, got {type(right)}.' 
+ ) + + def __rmul__(self, left: Any) -> Rotation: + return self.__mul__(left) + + def __matmul__(self, other: Rotation) -> Rotation: + new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat) + return Rotation(mat=new_mat) + + @property + def _inv_mat(self): + return self._mat.transpose(-1, -2) + + @property + def rot_mat(self) -> torch.Tensor: + return self._mat + + def invert(self) -> Rotation: + return Rotation(mat=self._inv_mat) + + def apply(self, pts: torch.Tensor) -> torch.Tensor: + return Rotation.mat_mul_vec(self._mat, pts) + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + return Rotation.mat_mul_vec(self._inv_mat, pts) + + # inherit tensor behaviors + @property + def shape(self) -> torch.Size: + s = self._mat.shape[:-2] + return s + + @property + def dtype(self) -> torch.dtype: + return self._mat.dtype + + @property + def device(self) -> torch.device: + return self._mat.device + + @property + def requires_grad(self) -> bool: + return self._mat.requires_grad + + def unsqueeze(self, dim: int) -> Rotation: + if dim >= len(self.shape): + raise ValueError('Invalid dimension') + + rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2) + return Rotation(mat=rot_mats) + + def map_tensor_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Rotation: + mat = self._mat.view(self._mat.shape[:-2] + (9, )) + mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1) + mat = mat.view(mat.shape[:-1] + (3, 3)) + return Rotation(mat=mat) + + @staticmethod + def cat(rs: Sequence[Rotation], dim: int) -> Rotation: + rot_mats = [r.rot_mat for r in rs] + rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) + + return Rotation(mat=rot_mats) + + def cuda(self) -> Rotation: + return Rotation(mat=self._mat.cuda()) + + def to(self, device: Optional[torch.device], + dtype: Optional[torch.dtype]) -> Rotation: + return Rotation(mat=self._mat.to(device=device, dtype=dtype)) + + def type(self, dtype: Optional[torch.dtype]) -> Rotation: + return Rotation(mat=self._mat.type(dtype)) + + def detach(self) -> Rotation: + return Rotation(mat=self._mat.detach()) + + +class Frame: + + def __init__( + self, + rotation: Optional[Rotation], + translation: Optional[torch.Tensor], + ): + if rotation is None and translation is None: + rotation = Rotation.identity((0, )) + translation = zero_translation((0, )) + elif translation is None: + translation = zero_translation(rotation.shape, rotation.dtype, + rotation.device, + rotation.requires_grad) + + elif rotation is None: + rotation = Rotation.identity( + translation.shape[:-1], + translation.dtype, + translation.device, + translation.requires_grad, + ) + + if (rotation.shape != translation.shape[:-1]) or (rotation.device + != # noqa W504 + translation.device): + raise ValueError('RotationMatrix and translation incompatible') + + self._r = rotation + self._t = translation + + @staticmethod + def identity( + shape: Iterable[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Frame: + return Frame( + Rotation.identity(shape, dtype, device, requires_grad), + zero_translation(shape, dtype, device, requires_grad), + ) + + def __getitem__( + self, + index: Any, + ) -> Frame: + if type(index) != tuple: + index = (index, ) + + return Frame( + self._r[index], + self._t[index + (slice(None), )], + ) + + def __mul__( + self, + right: torch.Tensor, + ) -> Frame: + if not (isinstance(right, torch.Tensor)): + raise TypeError('The other multiplicand 
must be a Tensor') + + new_rots = self._r * right + new_trans = self._t * right[..., None] + + return Frame(new_rots, new_trans) + + def __rmul__( + self, + left: torch.Tensor, + ) -> Frame: + return self.__mul__(left) + + @property + def shape(self) -> torch.Size: + s = self._t.shape[:-1] + return s + + @property + def device(self) -> torch.device: + return self._t.device + + def get_rots(self) -> Rotation: + return self._r + + def get_trans(self) -> torch.Tensor: + return self._t + + def compose( + self, + other: Frame, + ) -> Frame: + new_rot = self._r @ other._r + new_trans = self._r.apply(other._t) + self._t + return Frame(new_rot, new_trans) + + def apply( + self, + pts: torch.Tensor, + ) -> torch.Tensor: + rotated = self._r.apply(pts) + return rotated + self._t + + def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: + pts = pts - self._t + return self._r.invert_apply(pts) + + def invert(self) -> Frame: + rot_inv = self._r.invert() + trn_inv = rot_inv.apply(self._t) + + return Frame(rot_inv, -1 * trn_inv) + + def map_tensor_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Frame: + new_rots = self._r.map_tensor_fn(fn) + new_trans = torch.stack( + list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1) + + return Frame(new_rots, new_trans) + + def to_tensor_4x4(self) -> torch.Tensor: + tensor = self._t.new_zeros((*self.shape, 4, 4)) + tensor[..., :3, :3] = self._r.rot_mat + tensor[..., :3, 3] = self._t + tensor[..., 3, 3] = 1 + return tensor + + @staticmethod + def from_tensor_4x4(t: torch.Tensor) -> Frame: + if t.shape[-2:] != (4, 4): + raise ValueError('Incorrectly shaped input tensor') + + rots = Rotation(mat=t[..., :3, :3]) + trans = t[..., :3, 3] + + return Frame(rots, trans) + + @staticmethod + def from_3_points( + p_neg_x_axis: torch.Tensor, + origin: torch.Tensor, + p_xy_plane: torch.Tensor, + eps: float = 1e-8, + ) -> Frame: + p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) + origin = torch.unbind(origin, dim=-1) + p_xy_plane = torch.unbind(p_xy_plane, dim=-1) + + e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] + e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] + + denom = torch.sqrt(sum((c * c for c in e0)) + eps) + e0 = [c / denom for c in e0] + dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) + e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] + denom = torch.sqrt(sum((c * c for c in e1)) + eps) + e1 = [c / denom for c in e1] + e2 = [ + e0[1] * e1[2] - e0[2] * e1[1], + e0[2] * e1[0] - e0[0] * e1[2], + e0[0] * e1[1] - e0[1] * e1[0], + ] + + rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) + rots = rots.reshape(rots.shape[:-1] + (3, 3)) + + rot_obj = Rotation(mat=rots) + + return Frame(rot_obj, torch.stack(origin, dim=-1)) + + def unsqueeze( + self, + dim: int, + ) -> Frame: + if dim >= len(self.shape): + raise ValueError('Invalid dimension') + rots = self._r.unsqueeze(dim) + trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1) + + return Frame(rots, trans) + + @staticmethod + def cat( + Ts: Sequence[Frame], + dim: int, + ) -> Frame: + rots = Rotation.cat([T._r for T in Ts], dim) + trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1) + + return Frame(rots, trans) + + def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame: + return Frame(fn(self._r), self._t) + + def apply_trans_fn(self, fn: Callable[[torch.Tensor], + torch.Tensor]) -> Frame: + return Frame(self._r, fn(self._t)) + + def scale_translation(self, trans_scale_factor: float) -> Frame: + # fn = lambda t: t * trans_scale_factor + def 
fn(t): + return t * trans_scale_factor + + return self.apply_trans_fn(fn) + + def stop_rot_gradient(self) -> Frame: + # fn = lambda r: r.detach() + def fn(r): + return r.detach() + + return self.apply_rot_fn(fn) + + @staticmethod + def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): + input_dtype = ca_xyz.dtype + n_xyz = n_xyz.float() + ca_xyz = ca_xyz.float() + c_xyz = c_xyz.float() + n_xyz = n_xyz - ca_xyz + c_xyz = c_xyz - ca_xyz + + c_x, c_y, d_pair = [c_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + c_x**2 + c_y**2) + sin_c1 = -c_y / norm + cos_c1 = c_x / norm + + c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) + c1_rots[..., 0, 0] = cos_c1 + c1_rots[..., 0, 1] = -1 * sin_c1 + c1_rots[..., 1, 0] = sin_c1 + c1_rots[..., 1, 1] = cos_c1 + c1_rots[..., 2, 2] = 1 + + norm = torch.sqrt(eps + c_x**2 + c_y**2 + d_pair**2) + sin_c2 = d_pair / norm + cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm + + c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + c2_rots[..., 0, 0] = cos_c2 + c2_rots[..., 0, 2] = sin_c2 + c2_rots[..., 1, 1] = 1 + c2_rots[..., 2, 0] = -1 * sin_c2 + c2_rots[..., 2, 2] = cos_c2 + + c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots) + n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz) + + _, n_y, n_z = [n_xyz[..., i] for i in range(3)] + norm = torch.sqrt(eps + n_y**2 + n_z**2) + sin_n = -n_z / norm + cos_n = n_y / norm + + n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) + n_rots[..., 0, 0] = 1 + n_rots[..., 1, 1] = cos_n + n_rots[..., 1, 2] = -1 * sin_n + n_rots[..., 2, 1] = sin_n + n_rots[..., 2, 2] = cos_n + + rots = Rotation.mat_mul_mat(n_rots, c_rots) + + rots = rots.transpose(-1, -2) + rot_obj = Rotation(mat=rots.type(input_dtype)) + + return Frame(rot_obj, ca_xyz.type(input_dtype)) + + def cuda(self) -> Frame: + return Frame(self._r.cuda(), self._t.cuda()) + + @property + def dtype(self) -> torch.dtype: + assert self._r.dtype == self._t.dtype + return self._r.dtype + + def type(self, dtype) -> Frame: + return Frame(self._r.type(dtype), self._t.type(dtype)) + + +class Quaternion: + + def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor): + if quaternion.shape[-1] != 4: + raise ValueError(f'incorrect quaternion shape: {quaternion.shape}') + self._q = quaternion + self._t = translation + + @staticmethod + def identity( + shape: Iterable[int], + dtype: Optional[torch.dtype] = torch.float, + device: Optional[torch.device] = torch.device('cpu'), + requires_grad: bool = False, + ) -> Quaternion: + trans = zero_translation(shape, dtype, device, requires_grad) + quats = torch.zeros((*shape, 4), + dtype=dtype, + device=device, + requires_grad=requires_grad) + with torch.no_grad(): + quats[..., 0] = 1 + return Quaternion(quats, trans) + + def get_quats(self): + return self._q + + def get_trans(self): + return self._t + + def get_rot_mats(self): + quats = self.get_quats() + rot_mats = Quaternion.quat_to_rot(quats) + return rot_mats + + @staticmethod + def quat_to_rot(normalized_quat): + global _QUAT_TO_ROT_tensor + dtype = normalized_quat.dtype + normalized_quat = normalized_quat.float() + if _QUAT_TO_ROT_tensor.device != normalized_quat.device: + _QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to( + normalized_quat.device) + rot_tensor = torch.sum( + _QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None] + * normalized_quat[..., None, :, None], + dim=(-3, -2), + ) + rot_tensor = rot_tensor.type(dtype) + rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3) + return rot_tensor + + @staticmethod + def normalize_quat(quats): + dtype = quats.dtype + 
quats = quats.float() + quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) + quats = quats.type(dtype) + return quats + + @staticmethod + def quat_multiply_by_vec(quat, vec): + dtype = quat.dtype + quat = quat.float() + vec = vec.float() + global _QUAT_MULTIPLY_BY_VEC_tensor + if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device: + _QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to( + quat.device) + mat = _QUAT_MULTIPLY_BY_VEC_tensor + reshaped_mat = mat.view((1, ) * len(quat.shape[:-1]) + mat.shape) + return torch.sum( + reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], + dim=(-3, -2), + ).type(dtype) + + def compose_q_update_vec(self, + q_update_vec: torch.Tensor, + normalize_quats: bool = True) -> torch.Tensor: + quats = self.get_quats() + new_quats = quats + Quaternion.quat_multiply_by_vec( + quats, q_update_vec) + if normalize_quats: + new_quats = Quaternion.normalize_quat(new_quats) + return new_quats + + def compose_update_vec( + self, + update_vec: torch.Tensor, + pre_rot_mat: Rotation, + ) -> Quaternion: + q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:] + new_quats = self.compose_q_update_vec(q_vec) + + trans_update = pre_rot_mat.apply(t_vec) + new_trans = self._t + trans_update + + return Quaternion(new_quats, new_trans) + + def stop_rot_gradient(self) -> Quaternion: + return Quaternion(self._q.detach(), self._t) diff --git a/modelscope/models/science/unifold/modules/structure_module.py b/modelscope/models/science/unifold/modules/structure_module.py new file mode 100644 index 00000000..5d4da30b --- /dev/null +++ b/modelscope/models/science/unifold/modules/structure_module.py @@ -0,0 +1,592 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
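`Quaternion.quat_to_rot` above evaluates the quaternion-to-rotation map as a bilinear contraction against the precomputed `_QUAT_TO_ROT` table. For reference, the sketch below spells out the equivalent closed-form expression for a unit quaternion `(w, x, y, z)` and checks two familiar cases; it illustrates the math and is not a drop-in replacement for the table-based implementation.

```python
import math

import torch


def quat_to_rot_closed_form(q: torch.Tensor) -> torch.Tensor:
    # standard rotation matrix of a unit quaternion (w, x, y, z)
    w, x, y, z = q.unbind(-1)
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    ], dim=-1).view(*q.shape[:-1], 3, 3)


# the identity quaternion used by Quaternion.identity gives the identity rotation
assert torch.allclose(quat_to_rot_closed_form(torch.tensor([1.0, 0.0, 0.0, 0.0])), torch.eye(3))

# a 90-degree rotation about the z axis
c, s = math.cos(math.pi / 4), math.sin(math.pi / 4)
expected = torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
assert torch.allclose(quat_to_rot_closed_form(torch.tensor([c, 0.0, 0.0, s])), expected, atol=1e-6)
```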
+ +import math +from typing import Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm, softmax_dropout +from unicore.utils import dict_multimap, one_hot, permute_final_dims + +from modelscope.models.science.unifold.data.residue_constants import ( + restype_atom14_mask, restype_atom14_rigid_group_positions, + restype_atom14_to_rigid_group, restype_rigid_group_default_frame) +from .attentions import gen_attn_mask +from .common import Linear, SimpleModuleList, residual +from .frame import Frame, Quaternion, Rotation + + +def ipa_point_weights_init_(weights): + with torch.no_grad(): + softplus_inverse_1 = 0.541324854612918 + weights.fill_(softplus_inverse_1) + + +def torsion_angles_to_frames( + frame: Frame, + alpha: torch.Tensor, + aatype: torch.Tensor, + default_frames: torch.Tensor, +): + default_frame = Frame.from_tensor_4x4(default_frames[aatype, ...]) + + bb_rot = alpha.new_zeros((*((1, ) * len(alpha.shape[:-1])), 2)) + bb_rot[..., 1] = 1 + + alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], + dim=-2) + + all_rots = alpha.new_zeros(default_frame.get_rots().rot_mat.shape) + all_rots[..., 0, 0] = 1 + all_rots[..., 1, 1] = alpha[..., 1] + all_rots[..., 1, 2] = -alpha[..., 0] + all_rots[..., 2, 1:] = alpha + + all_rots = Frame(Rotation(mat=all_rots), None) + + all_frames = default_frame.compose(all_rots) + + chi2_frame_to_frame = all_frames[..., 5] + chi3_frame_to_frame = all_frames[..., 6] + chi4_frame_to_frame = all_frames[..., 7] + + chi1_frame_to_bb = all_frames[..., 4] + chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) + chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) + chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) + + all_frames_to_bb = Frame.cat( + [ + all_frames[..., :5], + chi2_frame_to_bb.unsqueeze(-1), + chi3_frame_to_bb.unsqueeze(-1), + chi4_frame_to_bb.unsqueeze(-1), + ], + dim=-1, + ) + + all_frames_to_global = frame[..., None].compose(all_frames_to_bb) + + return all_frames_to_global + + +def frames_and_literature_positions_to_atom14_pos( + frame: Frame, + aatype: torch.Tensor, + default_frames, + group_idx, + atom_mask, + lit_positions, +): + group_mask = group_idx[aatype, ...] + group_mask = one_hot( + group_mask, + num_classes=default_frames.shape[-3], + ) + + t_atoms_to_global = frame[..., None, :] * group_mask + t_atoms_to_global = t_atoms_to_global.map_tensor_fn( + lambda x: torch.sum(x, dim=-1)) + + atom_mask = atom_mask[aatype, ...].unsqueeze(-1) + + lit_positions = lit_positions[aatype, ...] 
+ pred_positions = t_atoms_to_global.apply(lit_positions) + pred_positions = pred_positions * atom_mask + + return pred_positions + + +class SideChainAngleResnetIteration(nn.Module): + + def __init__(self, d_hid): + super(SideChainAngleResnetIteration, self).__init__() + + self.d_hid = d_hid + + self.linear_1 = Linear(self.d_hid, self.d_hid, init='relu') + self.act = nn.GELU() + self.linear_2 = Linear(self.d_hid, self.d_hid, init='final') + + def forward(self, s: torch.Tensor) -> torch.Tensor: + + x = self.act(s) + x = self.linear_1(x) + x = self.act(x) + x = self.linear_2(x) + + return residual(s, x, self.training) + + +class SidechainAngleResnet(nn.Module): + + def __init__(self, d_in, d_hid, num_blocks, num_angles): + super(SidechainAngleResnet, self).__init__() + + self.linear_in = Linear(d_in, d_hid) + self.act = nn.GELU() + self.linear_initial = Linear(d_in, d_hid) + + self.layers = SimpleModuleList() + for _ in range(num_blocks): + self.layers.append(SideChainAngleResnetIteration(d_hid=d_hid)) + + self.linear_out = Linear(d_hid, num_angles * 2) + + def forward(self, s: torch.Tensor, + initial_s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + initial_s = self.linear_initial(self.act(initial_s)) + s = self.linear_in(self.act(s)) + + s = s + initial_s + + for layer in self.layers: + s = layer(s) + + s = self.linear_out(self.act(s)) + + s = s.view(s.shape[:-1] + (-1, 2)) + + unnormalized_s = s + norm_denom = torch.sqrt( + torch.clamp( + torch.sum(s.float()**2, dim=-1, keepdim=True), + min=1e-12, + )) + s = s.float() / norm_denom + + return unnormalized_s, s.type(unnormalized_s.dtype) + + +class InvariantPointAttention(nn.Module): + + def __init__( + self, + d_single: int, + d_pair: int, + d_hid: int, + num_heads: int, + num_qk_points: int, + num_v_points: int, + separate_kv: bool = False, + bias: bool = True, + eps: float = 1e-8, + ): + super(InvariantPointAttention, self).__init__() + + self.d_hid = d_hid + self.num_heads = num_heads + self.num_qk_points = num_qk_points + self.num_v_points = num_v_points + self.eps = eps + + hc = self.d_hid * self.num_heads + self.linear_q = Linear(d_single, hc, bias=bias) + self.separate_kv = separate_kv + if self.separate_kv: + self.linear_k = Linear(d_single, hc, bias=bias) + self.linear_v = Linear(d_single, hc, bias=bias) + else: + self.linear_kv = Linear(d_single, 2 * hc, bias=bias) + + hpq = self.num_heads * self.num_qk_points * 3 + self.linear_q_points = Linear(d_single, hpq) + hpk = self.num_heads * self.num_qk_points * 3 + hpv = self.num_heads * self.num_v_points * 3 + if self.separate_kv: + self.linear_k_points = Linear(d_single, hpk) + self.linear_v_points = Linear(d_single, hpv) + else: + hpkv = hpk + hpv + self.linear_kv_points = Linear(d_single, hpkv) + + self.linear_b = Linear(d_pair, self.num_heads) + + self.head_weights = nn.Parameter(torch.zeros((num_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = self.num_heads * ( + d_pair + self.d_hid + self.num_v_points * 4) + self.linear_out = Linear(concat_out_dim, d_single, init='final') + + self.softplus = nn.Softplus() + + def forward( + self, + s: torch.Tensor, + z: torch.Tensor, + f: Frame, + square_mask: torch.Tensor, + ) -> torch.Tensor: + q = self.linear_q(s) + + q = q.view(q.shape[:-1] + (self.num_heads, -1)) + + if self.separate_kv: + k = self.linear_k(s) + v = self.linear_v(s) + k = k.view(k.shape[:-1] + (self.num_heads, -1)) + v = v.view(v.shape[:-1] + (self.num_heads, -1)) + else: + kv = self.linear_kv(s) + kv = kv.view(kv.shape[:-1] + 
(self.num_heads, -1)) + k, v = torch.split(kv, self.d_hid, dim=-1) + + q_pts = self.linear_q_points(s) + + def process_points(pts, no_points): + shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3) + if self.separate_kv: + # alphafold-multimer uses different layout + pts = pts.view(pts.shape[:-1] + + (self.num_heads, no_points * 3)) + pts = torch.split(pts, pts.shape[-1] // 3, dim=-1) + pts = torch.stack(pts, dim=-1).view(*shape) + pts = f[..., None].apply(pts) + + pts = pts.view(pts.shape[:-2] + (self.num_heads, no_points, 3)) + return pts + + q_pts = process_points(q_pts, self.num_qk_points) + + if self.separate_kv: + k_pts = self.linear_k_points(s) + v_pts = self.linear_v_points(s) + k_pts = process_points(k_pts, self.num_qk_points) + v_pts = process_points(v_pts, self.num_v_points) + else: + kv_pts = self.linear_kv_points(s) + + kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) + kv_pts = torch.stack(kv_pts, dim=-1) + kv_pts = f[..., None].apply(kv_pts) + + kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) + + k_pts, v_pts = torch.split( + kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) + + bias = self.linear_b(z) + + attn = torch.matmul( + permute_final_dims(q, (1, 0, 2)), + permute_final_dims(k, (1, 2, 0)), + ) + + if self.training: + attn = attn * math.sqrt(1.0 / (3 * self.d_hid)) + attn = attn + ( + math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att = pt_att.float()**2 + else: + attn *= math.sqrt(1.0 / (3 * self.d_hid)) + attn += (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) + pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) + pt_att *= pt_att + + pt_att = pt_att.sum(dim=-1) + head_weights = self.softplus(self.head_weights).view( + *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) + head_weights = head_weights * math.sqrt( + 1.0 / (3 * (self.num_qk_points * 9.0 / 2))) + pt_att *= head_weights * (-0.5) + + pt_att = torch.sum(pt_att, dim=-1) + + pt_att = permute_final_dims(pt_att, (2, 0, 1)) + attn += square_mask + attn = softmax_dropout( + attn, 0, self.training, bias=pt_att.type(attn.dtype)) + del pt_att, q_pts, k_pts, bias + o = torch.matmul(attn, v.transpose(-2, -3)).transpose(-2, -3) + o = o.contiguous().view(*o.shape[:-2], -1) + del q, k, v + o_pts = torch.sum( + (attn[..., None, :, :, None] + * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), + dim=-2, + ) + + o_pts = permute_final_dims(o_pts, (2, 0, 3, 1)) + o_pts = f[..., None, None].invert_apply(o_pts) + if self.training: + o_pts_norm = torch.sqrt( + torch.sum(o_pts.float()**2, dim=-1) + self.eps).type( + o_pts.dtype) + else: + o_pts_norm = torch.sqrt(torch.sum(o_pts**2, dim=-1) + + self.eps).type(o_pts.dtype) + + o_pts_norm = o_pts_norm.view(*o_pts_norm.shape[:-2], -1) + + o_pts = o_pts.view(*o_pts.shape[:-3], -1, 3) + + o_pair = torch.matmul(attn.transpose(-2, -3), z) + + o_pair = o_pair.view(*o_pair.shape[:-2], -1) + + s = self.linear_out( + torch.cat((o, *torch.unbind(o_pts, dim=-1), o_pts_norm, o_pair), + dim=-1)) + + return s + + +class BackboneUpdate(nn.Module): + + def __init__(self, d_single): + super(BackboneUpdate, self).__init__() + self.linear = Linear(d_single, 6, init='final') + + def forward(self, s: torch.Tensor): + return self.linear(s) + + +class StructureModuleTransitionLayer(nn.Module): + + def __init__(self, c): + super(StructureModuleTransitionLayer, self).__init__() + + self.linear_1 = Linear(c, c, init='relu') + self.linear_2 = Linear(c, c, init='relu') + self.act = nn.GELU() + 
self.linear_3 = Linear(c, c, init='final') + + def forward(self, s): + s_old = s + s = self.linear_1(s) + s = self.act(s) + s = self.linear_2(s) + s = self.act(s) + s = self.linear_3(s) + + s = residual(s_old, s, self.training) + + return s + + +class StructureModuleTransition(nn.Module): + + def __init__(self, c, num_layers, dropout_rate): + super(StructureModuleTransition, self).__init__() + + self.num_layers = num_layers + self.dropout_rate = dropout_rate + + self.layers = SimpleModuleList() + for _ in range(self.num_layers): + self.layers.append(StructureModuleTransitionLayer(c)) + + self.dropout = nn.Dropout(self.dropout_rate) + self.layer_norm = LayerNorm(c) + + def forward(self, s): + for layer in self.layers: + s = layer(s) + + s = self.dropout(s) + s = self.layer_norm(s) + + return s + + +class StructureModule(nn.Module): + + def __init__( + self, + d_single, + d_pair, + d_ipa, + d_angle, + num_heads_ipa, + num_qk_points, + num_v_points, + dropout_rate, + num_blocks, + no_transition_layers, + num_resnet_blocks, + num_angles, + trans_scale_factor, + separate_kv, + ipa_bias, + epsilon, + inf, + **kwargs, + ): + super(StructureModule, self).__init__() + + self.num_blocks = num_blocks + self.trans_scale_factor = trans_scale_factor + self.default_frames = None + self.group_idx = None + self.atom_mask = None + self.lit_positions = None + self.inf = inf + + self.layer_norm_s = LayerNorm(d_single) + self.layer_norm_z = LayerNorm(d_pair) + + self.linear_in = Linear(d_single, d_single) + + self.ipa = InvariantPointAttention( + d_single, + d_pair, + d_ipa, + num_heads_ipa, + num_qk_points, + num_v_points, + separate_kv=separate_kv, + bias=ipa_bias, + eps=epsilon, + ) + + self.ipa_dropout = nn.Dropout(dropout_rate) + self.layer_norm_ipa = LayerNorm(d_single) + + self.transition = StructureModuleTransition( + d_single, + no_transition_layers, + dropout_rate, + ) + + self.bb_update = BackboneUpdate(d_single) + + self.angle_resnet = SidechainAngleResnet( + d_single, + d_angle, + num_resnet_blocks, + num_angles, + ) + + def forward( + self, + s, + z, + aatype, + mask=None, + ): + if mask is None: + mask = s.new_ones(s.shape[:-1]) + + # generate square mask + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3) + s = self.layer_norm_s(s) + z = self.layer_norm_z(z) + initial_s = s + s = self.linear_in(s) + + quat_encoder = Quaternion.identity( + s.shape[:-1], + s.dtype, + s.device, + requires_grad=False, + ) + backb_to_global = Frame( + Rotation(mat=quat_encoder.get_rot_mats(), ), + quat_encoder.get_trans(), + ) + outputs = [] + for i in range(self.num_blocks): + s = residual(s, self.ipa(s, z, backb_to_global, square_mask), + self.training) + s = self.ipa_dropout(s) + s = self.layer_norm_ipa(s) + s = self.transition(s) + + # update quaternion encoder + # use backb_to_global to avoid quat-to-rot conversion + quat_encoder = quat_encoder.compose_update_vec( + self.bb_update(s), pre_rot_mat=backb_to_global.get_rots()) + + # initial_s is always used to update the backbone + unnormalized_angles, angles = self.angle_resnet(s, initial_s) + + # convert quaternion to rotation matrix + backb_to_global = Frame( + Rotation(mat=quat_encoder.get_rot_mats(), ), + quat_encoder.get_trans(), + ) + if i == self.num_blocks - 1: + all_frames_to_global = self.torsion_angles_to_frames( + backb_to_global.scale_translation(self.trans_scale_factor), + angles, + aatype, + ) + + pred_positions = self.frames_and_literature_positions_to_atom14_pos( + 
all_frames_to_global, + aatype, + ) + + preds = { + 'frames': + backb_to_global.scale_translation( + self.trans_scale_factor).to_tensor_4x4(), + 'unnormalized_angles': + unnormalized_angles, + 'angles': + angles, + } + + outputs.append(preds) + if i < (self.num_blocks - 1): + # stop gradient in iteration + quat_encoder = quat_encoder.stop_rot_gradient() + backb_to_global = backb_to_global.stop_rot_gradient() + + outputs = dict_multimap(torch.stack, outputs) + outputs['sidechain_frames'] = all_frames_to_global.to_tensor_4x4() + outputs['positions'] = pred_positions + outputs['single'] = s + + return outputs + + def _init_residue_constants(self, float_dtype, device): + if self.default_frames is None: + self.default_frames = torch.tensor( + restype_rigid_group_default_frame, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.group_idx is None: + self.group_idx = torch.tensor( + restype_atom14_to_rigid_group, + device=device, + requires_grad=False, + ) + if self.atom_mask is None: + self.atom_mask = torch.tensor( + restype_atom14_mask, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + if self.lit_positions is None: + self.lit_positions = torch.tensor( + restype_atom14_rigid_group_positions, + dtype=float_dtype, + device=device, + requires_grad=False, + ) + + def torsion_angles_to_frames(self, frame, alpha, aatype): + self._init_residue_constants(alpha.dtype, alpha.device) + return torsion_angles_to_frames(frame, alpha, aatype, + self.default_frames) + + def frames_and_literature_positions_to_atom14_pos(self, frame, aatype): + self._init_residue_constants(frame.get_rots().dtype, + frame.get_rots().device) + return frames_and_literature_positions_to_atom14_pos( + frame, + aatype, + self.default_frames, + self.group_idx, + self.atom_mask, + self.lit_positions, + ) diff --git a/modelscope/models/science/unifold/modules/template.py b/modelscope/models/science/unifold/modules/template.py new file mode 100644 index 00000000..49e5bec0 --- /dev/null +++ b/modelscope/models/science/unifold/modules/template.py @@ -0,0 +1,330 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. 
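A small aside on `ipa_point_weights_init_` in the structure module above: the constant `0.541324854612918` is the inverse softplus of 1, i.e. `ln(e - 1)`, so after the `nn.Softplus` in `InvariantPointAttention.forward` each head's point-attention weight starts out at exactly 1. A quick numeric check:

```python
import math

softplus_inverse_1 = math.log(math.e - 1)
assert abs(softplus_inverse_1 - 0.541324854612918) < 1e-9
assert abs(math.log(1 + math.exp(softplus_inverse_1)) - 1.0) < 1e-9  # softplus(ln(e - 1)) == 1
```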
+ +import math +from functools import partial +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import (checkpoint_sequential, permute_final_dims, + tensor_tree_map) + +from .attentions import (Attention, TriangleAttentionEnding, + TriangleAttentionStarting, gen_attn_mask) +from .common import (Linear, SimpleModuleList, Transition, + bias_dropout_residual, chunk_layer, residual, + tri_mul_residual) +from .featurization import build_template_pair_feat_v2 +from .triangle_multiplication import (TriangleMultiplicationIncoming, + TriangleMultiplicationOutgoing) + + +class TemplatePointwiseAttention(nn.Module): + + def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs): + super(TemplatePointwiseAttention, self).__init__() + + self.inf = inf + + self.mha = Attention( + d_pair, + d_template, + d_template, + d_hid, + num_heads, + gating=False, + ) + + def _chunk( + self, + z: torch.Tensor, + t: torch.Tensor, + mask: torch.Tensor, + chunk_size: int, + ) -> torch.Tensor: + mha_inputs = { + 'q': z, + 'k': t, + 'v': t, + 'mask': mask, + } + return chunk_layer( + self.mha, + mha_inputs, + chunk_size=chunk_size, + num_batch_dims=len(z.shape[:-2]), + ) + + def forward( + self, + t: torch.Tensor, + z: torch.Tensor, + template_mask: Optional[torch.Tensor] = None, + chunk_size: Optional[int] = None, + ) -> torch.Tensor: + if template_mask is None: + template_mask = t.new_ones(t.shape[:-3]) + + mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None, + None, :] + z = z.unsqueeze(-2) + + t = permute_final_dims(t, (1, 2, 0, 3)) + + if chunk_size is not None: + z = self._chunk(z, t, mask, chunk_size) + else: + z = self.mha(z, t, t, mask=mask) + + z = z.squeeze(-2) + + return z + + +class TemplateProjection(nn.Module): + + def __init__(self, d_template, d_pair, **kwargs): + super(TemplateProjection, self).__init__() + + self.d_pair = d_pair + self.act = nn.ReLU() + self.output_linear = Linear(d_template, d_pair, init='relu') + + def forward(self, t, z) -> torch.Tensor: + if t is None: + # handle for non-template case + shape = z.shape + shape[-1] = self.d_pair + t = torch.zeros(shape, dtype=z.dtype, device=z.device) + t = self.act(t) + z_t = self.output_linear(t) + return z_t + + +class TemplatePairStackBlock(nn.Module): + + def __init__( + self, + d_template: int, + d_hid_tri_att: int, + d_hid_tri_mul: int, + num_heads: int, + pair_transition_n: int, + dropout_rate: float, + tri_attn_first: bool, + inf: float, + **kwargs, + ): + super(TemplatePairStackBlock, self).__init__() + + self.tri_att_start = TriangleAttentionStarting( + d_template, + d_hid_tri_att, + num_heads, + ) + self.tri_att_end = TriangleAttentionEnding( + d_template, + d_hid_tri_att, + num_heads, + ) + + self.tri_mul_out = TriangleMultiplicationOutgoing( + d_template, + d_hid_tri_mul, + ) + self.tri_mul_in = TriangleMultiplicationIncoming( + d_template, + d_hid_tri_mul, + ) + + self.pair_transition = Transition( + d_template, + pair_transition_n, + ) + self.tri_attn_first = tri_attn_first + self.dropout = dropout_rate + self.row_dropout_share_dim = -3 + self.col_dropout_share_dim = -2 + + def forward( + self, + s: torch.Tensor, + mask: torch.Tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + chunk_size: Optional[int] = None, + block_size: Optional[int] = None, + ): + if self.tri_attn_first: + s = bias_dropout_residual( + self.tri_att_start, + s, + self.tri_att_start( + s, 
attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + ) + + s = bias_dropout_residual( + self.tri_att_end, + s, + self.tri_att_end( + s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.dropout, + self.training, + ) + s = tri_mul_residual( + self.tri_mul_out, + s, + self.tri_mul_out(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = tri_mul_residual( + self.tri_mul_in, + s, + self.tri_mul_in(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + else: + s = tri_mul_residual( + self.tri_mul_out, + s, + self.tri_mul_out(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = tri_mul_residual( + self.tri_mul_in, + s, + self.tri_mul_in(s, mask=mask, block_size=block_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + block_size=block_size, + ) + + s = bias_dropout_residual( + self.tri_att_start, + s, + self.tri_att_start( + s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), + self.row_dropout_share_dim, + self.dropout, + self.training, + ) + + s = bias_dropout_residual( + self.tri_att_end, + s, + self.tri_att_end( + s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), + self.col_dropout_share_dim, + self.dropout, + self.training, + ) + s = residual(s, self.pair_transition( + s, + chunk_size=chunk_size, + ), self.training) + return s + + +class TemplatePairStack(nn.Module): + + def __init__( + self, + d_template, + d_hid_tri_att, + d_hid_tri_mul, + num_blocks, + num_heads, + pair_transition_n, + dropout_rate, + tri_attn_first, + inf=1e9, + **kwargs, + ): + super(TemplatePairStack, self).__init__() + + self.blocks = SimpleModuleList() + for _ in range(num_blocks): + self.blocks.append( + TemplatePairStackBlock( + d_template=d_template, + d_hid_tri_att=d_hid_tri_att, + d_hid_tri_mul=d_hid_tri_mul, + num_heads=num_heads, + pair_transition_n=pair_transition_n, + dropout_rate=dropout_rate, + inf=inf, + tri_attn_first=tri_attn_first, + )) + + self.layer_norm = LayerNorm(d_template) + + def forward( + self, + single_templates: Tuple[torch.Tensor], + mask: torch.tensor, + tri_start_attn_mask: torch.Tensor, + tri_end_attn_mask: torch.Tensor, + templ_dim: int, + chunk_size: int, + block_size: int, + return_mean: bool, + ): + + def one_template(i): + (s, ) = checkpoint_sequential( + functions=[ + partial( + b, + mask=mask, + tri_start_attn_mask=tri_start_attn_mask, + tri_end_attn_mask=tri_end_attn_mask, + chunk_size=chunk_size, + block_size=block_size, + ) for b in self.blocks + ], + input=(single_templates[i], ), + ) + return s + + n_templ = len(single_templates) + if n_templ > 0: + new_single_templates = [one_template(0)] + if return_mean: + t = self.layer_norm(new_single_templates[0]) + for i in range(1, n_templ): + s = one_template(i) + if return_mean: + t = residual(t, self.layer_norm(s), self.training) + else: + new_single_templates.append(s) + + if return_mean: + if n_templ > 0: + t /= n_templ + else: + t = None + else: + t = torch.cat( + [s.unsqueeze(templ_dim) for s in new_single_templates], + dim=templ_dim) + t = self.layer_norm(t) + + return t diff --git a/modelscope/models/science/unifold/modules/triangle_multiplication.py b/modelscope/models/science/unifold/modules/triangle_multiplication.py new file mode 100644 
index 00000000..c4094cd2 --- /dev/null +++ b/modelscope/models/science/unifold/modules/triangle_multiplication.py @@ -0,0 +1,158 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +from functools import partialmethod +from typing import List, Optional + +import torch +import torch.nn as nn +from unicore.modules import LayerNorm +from unicore.utils import permute_final_dims + +from .common import Linear + + +class TriangleMultiplication(nn.Module): + + def __init__(self, d_pair, d_hid, outgoing=True): + super(TriangleMultiplication, self).__init__() + self.outgoing = outgoing + + self.linear_ab_p = Linear(d_pair, d_hid * 2) + self.linear_ab_g = Linear(d_pair, d_hid * 2, init='gating') + + self.linear_g = Linear(d_pair, d_pair, init='gating') + self.linear_z = Linear(d_hid, d_pair, init='final') + + self.layer_norm_in = LayerNorm(d_pair) + self.layer_norm_out = LayerNorm(d_hid) + + self._alphafold_original_mode = False + + def _chunk_2d( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + block_size: int = None, + ) -> torch.Tensor: + + # avoid too small chunk size + # block_size = max(block_size, 256) + new_z = z.new_zeros(z.shape) + dim1 = z.shape[-3] + + def _slice_linear(z, linear: Linear, a=True): + d_hid = linear.bias.shape[0] // 2 + index = 0 if a else d_hid + p = ( + nn.functional.linear(z, linear.weight[index:index + d_hid]) + + linear.bias[index:index + d_hid]) + return p + + def _chunk_projection(z, mask, a=True): + p = _slice_linear(z, self.linear_ab_p, a) * mask + p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a)) + return p + + num_chunk = (dim1 + block_size - 1) // block_size + for i in range(num_chunk): + chunk_start = i * block_size + chunk_end = min(chunk_start + block_size, dim1) + if self.outgoing: + a_chunk = _chunk_projection( + z[..., chunk_start:chunk_end, :, :], + mask[..., chunk_start:chunk_end, :, :], + a=True, + ) + a_chunk = permute_final_dims(a_chunk, (2, 0, 1)) + else: + a_chunk = _chunk_projection( + z[..., :, chunk_start:chunk_end, :], + mask[..., :, chunk_start:chunk_end, :], + a=True, + ) + a_chunk = a_chunk.transpose(-1, -3) + + for j in range(num_chunk): + j_chunk_start = j * block_size + j_chunk_end = min(j_chunk_start + block_size, dim1) + if self.outgoing: + b_chunk = _chunk_projection( + z[..., j_chunk_start:j_chunk_end, :, :], + mask[..., j_chunk_start:j_chunk_end, :, :], + a=False, + ) + b_chunk = b_chunk.transpose(-1, -3) + else: + b_chunk = _chunk_projection( + z[..., :, j_chunk_start:j_chunk_end, :], + mask[..., :, j_chunk_start:j_chunk_end, :], + a=False, + ) + b_chunk = permute_final_dims(b_chunk, (2, 0, 1)) + x_chunk = torch.matmul(a_chunk, b_chunk) + del b_chunk + x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) + x_chunk = self.layer_norm_out(x_chunk) + x_chunk = self.linear_z(x_chunk) + x_chunk *= torch.sigmoid( + self.linear_g(z[..., chunk_start:chunk_end, + j_chunk_start:j_chunk_end, :])) + new_z[..., chunk_start:chunk_end, + j_chunk_start:j_chunk_end, :] = x_chunk + del x_chunk + del a_chunk + return new_z + + def forward( + self, + z: torch.Tensor, + mask: Optional[torch.Tensor] = None, + block_size=None, + ) -> torch.Tensor: + + mask = mask.unsqueeze(-1) + if not self._alphafold_original_mode: + # divided by 1/sqrt(dim) for numerical stability + mask = mask * (mask.shape[-2]**-0.5) + + z = self.layer_norm_in(z) + if not self.training and block_size is not None: + return self._chunk_2d(z, mask, 
block_size=block_size) + + g = nn.functional.linear(z, self.linear_g.weight) + if self.training: + ab = self.linear_ab_p(z) * mask * torch.sigmoid( + self.linear_ab_g(z)) + else: + ab = self.linear_ab_p(z) + ab *= mask + ab *= torch.sigmoid(self.linear_ab_g(z)) + a, b = torch.chunk(ab, 2, dim=-1) + del z, ab + + if self.outgoing: + a = permute_final_dims(a, (2, 0, 1)) + b = b.transpose(-1, -3) + else: + b = permute_final_dims(b, (2, 0, 1)) + a = a.transpose(-1, -3) + x = torch.matmul(a, b) + del a, b + + x = permute_final_dims(x, (1, 2, 0)) + + x = self.layer_norm_out(x) + x = nn.functional.linear(x, self.linear_z.weight) + return x, g + + def get_output_bias(self): + return self.linear_z.bias, self.linear_g.bias + + +class TriangleMultiplicationOutgoing(TriangleMultiplication): + __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=True) + + +class TriangleMultiplicationIncoming(TriangleMultiplication): + __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=False) diff --git a/modelscope/models/science/unifold/msa/__init__.py b/modelscope/models/science/unifold/msa/__init__.py new file mode 100644 index 00000000..2121062c --- /dev/null +++ b/modelscope/models/science/unifold/msa/__init__.py @@ -0,0 +1 @@ +""" Scripts for MSA & template searching. """ diff --git a/modelscope/models/science/unifold/msa/mmcif.py b/modelscope/models/science/unifold/msa/mmcif.py new file mode 100644 index 00000000..cf67239f --- /dev/null +++ b/modelscope/models/science/unifold/msa/mmcif.py @@ -0,0 +1,483 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Parses the mmCIF file format.""" +import collections +import dataclasses +import functools +import io +from typing import Any, Mapping, Optional, Sequence, Tuple + +from absl import logging +from Bio import PDB +from Bio.Data import SCOPData +from Bio.PDB.MMCIFParser import MMCIFParser + +# Type aliases: +ChainId = str +PdbHeader = Mapping[str, Any] +PdbStructure = PDB.Structure.Structure +SeqRes = str +MmCIFDict = Mapping[str, Sequence[str]] + + +@dataclasses.dataclass(frozen=True) +class Monomer: + id: str + num: int + + +# Note - mmCIF format provides no guarantees on the type of author-assigned +# sequence numbers. They need not be integers. +@dataclasses.dataclass(frozen=True) +class AtomSite: + residue_name: str + author_chain_id: str + mmcif_chain_id: str + author_seq_num: str + mmcif_seq_num: int + insertion_code: str + hetatm_atom: str + model_num: int + + +# Used to map SEQRES index to a residue in the structure. +@dataclasses.dataclass(frozen=True) +class ResiduePosition: + chain_id: str + residue_number: int + insertion_code: str + + +@dataclasses.dataclass(frozen=True) +class ResidueAtPosition: + position: Optional[ResiduePosition] + name: str + is_missing: bool + hetflag: str + + +@dataclasses.dataclass(frozen=True) +class MmcifObject: + """Representation of a parsed mmCIF file. + + Contains: + file_id: A meaningful name, e.g. a pdb_id. 
Should be unique amongst all + files being processed. + header: Biopython header. + structure: Biopython structure. + chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. + {'A': 'ABCDEFG'} + seqres_to_structure: Dict; for each chain_id contains a mapping between + SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, 1: ResidueAtPosition, ...}} + raw_string: The raw string used to construct the MmcifObject. + """ + + file_id: str + header: PdbHeader + structure: PdbStructure + chain_to_seqres: Mapping[ChainId, SeqRes] + seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] + raw_string: Any + mmcif_to_author_chain_id: Mapping[ChainId, ChainId] + valid_chains: Mapping[ChainId, str] + + +@dataclasses.dataclass(frozen=True) +class ParsingResult: + """Returned by the parse function. + + Contains: + mmcif_object: A MmcifObject, may be None if no chain could be successfully + parsed. + errors: A dict mapping (file_id, chain_id) to any exception generated. + """ + + mmcif_object: Optional[MmcifObject] + errors: Mapping[Tuple[str, str], Any] + + +class ParseError(Exception): + """An error indicating that an mmCIF file could not be parsed.""" + + +def mmcif_loop_to_list(prefix: str, + parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. + """ + cols = [] + data = [] + for key, value in parsed_info.items(): + if key.startswith(prefix): + cols.append(key) + data.append(value) + + assert all([ + len(xs) == len(data[0]) for xs in data + ]), ('mmCIF error: Not all loops are the same length: %s' % cols) + + return [dict(zip(cols, xs)) for xs in zip(*data)] + + +def mmcif_loop_to_dict( + prefix: str, + index: str, + parsed_info: MmCIFDict, +) -> Mapping[str, Mapping[str, str]]: + """Extracts loop associated with a prefix from mmCIF data as a dictionary. + + Args: + prefix: Prefix shared by each of the data items in the loop. + e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + _entity_poly_seq.mon_id. Should include the trailing period. + index: Which item of loop data should serve as the key. + parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython + parser. + + Returns: + Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + entries = mmcif_loop_to_list(prefix, parsed_info) + return {entry[index]: entry for entry in entries} + + +@functools.lru_cache(16, typed=False) +def fast_parse(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. 
+ """ + errors = {} + try: + parser = MMCIFParser(QUIET=True) + # handle = io.StringIO(mmcif_string) + # full_structure = parser.get_structure('', handle) + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ''): 'No protein chains found in this file.'}) + + mmcif_to_author_chain_id = {} + # seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + mmcif_to_author_chain_id[ + atom.mmcif_chain_id] = atom.author_chain_id + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=None, + chain_to_seqres=None, + seqres_to_structure=None, + raw_string=parsed_info, + mmcif_to_author_chain_id=mmcif_to_author_chain_id, + valid_chains=valid_chains, + ) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +@functools.lru_cache(16, typed=False) +def parse(*, + file_id: str, + mmcif_string: str, + catch_all_errors: bool = True) -> ParsingResult: + """Entry point, parses an mmcif_string. + + Args: + file_id: A string identifier for this file. Should be unique within the + collection of files being processed. + mmcif_string: Contents of an mmCIF file. + catch_all_errors: If True, all exceptions are caught and error messages are + returned as part of the ParsingResult. If False exceptions will be allowed + to propagate. + + Returns: + A ParsingResult. + """ + errors = {} + try: + parser = PDB.MMCIFParser(QUIET=True) + handle = io.StringIO(mmcif_string) + full_structure = parser.get_structure('', handle) + first_model_structure = _get_first_model(full_structure) + # Extract the _mmcif_dict from the parser, which contains useful fields not + # reflected in the Biopython structure. + parsed_info = parser._mmcif_dict # pylint:disable=protected-access + + # Ensure all values are lists, even if singletons. + for key, value in parsed_info.items(): + if not isinstance(value, list): + parsed_info[key] = [value] + + header = _get_header(parsed_info) + + # Determine the protein chains, and their start numbers according to the + # internal mmCIF numbering scheme (likely but not guaranteed to be 1). + valid_chains = _get_protein_chains(parsed_info=parsed_info) + if not valid_chains: + return ParsingResult( + None, {(file_id, ''): 'No protein chains found in this file.'}) + seq_start_num = { + chain_id: min([monomer.num for monomer in seq]) + for chain_id, seq in valid_chains.items() + } + + # Loop over the atoms for which we have coordinates. Populate two mappings: + # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used + # the authors / Biopython). + # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). 
+ mmcif_to_author_chain_id = {} + seq_to_structure_mappings = {} + for atom in _get_atom_site_list(parsed_info): + if atom.model_num != '1': + # We only process the first model at the moment. + continue + + mmcif_to_author_chain_id[ + atom.mmcif_chain_id] = atom.author_chain_id + + if atom.mmcif_chain_id in valid_chains: + hetflag = ' ' + if atom.hetatm_atom == 'HETATM': + # Water atoms are assigned a special hetflag of W in Biopython. We + # need to do the same, so that this hetflag can be used to fetch + # a residue from the Biopython structure by id. + if atom.residue_name in ('HOH', 'WAT'): + hetflag = 'W' + else: + hetflag = 'H_' + atom.residue_name + insertion_code = atom.insertion_code + if not _is_set(atom.insertion_code): + insertion_code = ' ' + position = ResiduePosition( + chain_id=atom.author_chain_id, + residue_number=int(atom.author_seq_num), + insertion_code=insertion_code, + ) + seq_idx = int( + atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] + current = seq_to_structure_mappings.get( + atom.author_chain_id, {}) + current[seq_idx] = ResidueAtPosition( + position=position, + name=atom.residue_name, + is_missing=False, + hetflag=hetflag, + ) + seq_to_structure_mappings[atom.author_chain_id] = current + + # Add missing residue information to seq_to_structure_mappings. + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + current_mapping = seq_to_structure_mappings[author_chain] + for idx, monomer in enumerate(seq_info): + if idx not in current_mapping: + current_mapping[idx] = ResidueAtPosition( + position=None, + name=monomer.id, + is_missing=True, + hetflag=' ') + + author_chain_to_sequence = {} + for chain_id, seq_info in valid_chains.items(): + author_chain = mmcif_to_author_chain_id[chain_id] + seq = [] + for monomer in seq_info: + code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') + seq.append(code if len(code) == 1 else 'X') + seq = ''.join(seq) + author_chain_to_sequence[author_chain] = seq + + mmcif_object = MmcifObject( + file_id=file_id, + header=header, + structure=first_model_structure, + chain_to_seqres=author_chain_to_sequence, + seqres_to_structure=seq_to_structure_mappings, + raw_string=parsed_info, + mmcif_to_author_chain_id=mmcif_to_author_chain_id, + valid_chains=valid_chains, + ) + + return ParsingResult(mmcif_object=mmcif_object, errors=errors) + except Exception as e: # pylint:disable=broad-except + errors[(file_id, '')] = e + if not catch_all_errors: + raise + return ParsingResult(mmcif_object=None, errors=errors) + + +def _get_first_model(structure: PdbStructure) -> PdbStructure: + """Returns the first model in a Biopython structure.""" + return next(structure.get_models()) + + +_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 + + +def get_release_date(parsed_info: MmCIFDict) -> str: + """Returns the oldest revision date.""" + revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] + return min(revision_dates) + + +def _get_header(parsed_info: MmCIFDict) -> PdbHeader: + """Returns a basic header containing method, release date and resolution.""" + header = {} + + experiments = mmcif_loop_to_list('_exptl.', parsed_info) + header['structure_method'] = ','.join( + [experiment['_exptl.method'].lower() for experiment in experiments]) + + # Note: The release_date here corresponds to the oldest revision. We prefer to + # use this for dataset filtering over the deposition_date. 
+ if '_pdbx_audit_revision_history.revision_date' in parsed_info: + header['release_date'] = get_release_date(parsed_info) + else: + logging.warning('Could not determine release_date: %s', + parsed_info['_entry.id']) + + header['resolution'] = 0.00 + for res_key in ( + '_refine.ls_d_res_high', + '_em_3d_reconstruction.resolution', + '_reflns.d_resolution_high', + ): + if res_key in parsed_info: + try: + raw_resolution = parsed_info[res_key][0] + header['resolution'] = float(raw_resolution) + except ValueError: + logging.debug('Invalid resolution format: %s', + parsed_info[res_key]) + + return header + + +def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: + """Returns list of atom sites; contains data not present in the structure.""" + return [ + AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension + parsed_info['_atom_site.label_comp_id'], + parsed_info['_atom_site.auth_asym_id'], + parsed_info['_atom_site.label_asym_id'], + parsed_info['_atom_site.auth_seq_id'], + parsed_info['_atom_site.label_seq_id'], + parsed_info['_atom_site.pdbx_PDB_ins_code'], + parsed_info['_atom_site.group_PDB'], + parsed_info['_atom_site.pdbx_PDB_model_num'], + ) + ] + + +def _get_protein_chains( + *, parsed_info: Mapping[str, + Any]) -> Mapping[ChainId, Sequence[Monomer]]: + """Extracts polymer information for protein chains only. + + Args: + parsed_info: _mmcif_dict produced by the Biopython parser. + + Returns: + A dict mapping mmcif chain id to a list of Monomers. + """ + # Get polymer information for each entity in the structure. + entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) + + polymers = collections.defaultdict(list) + for entity_poly_seq in entity_poly_seqs: + polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( + Monomer( + id=entity_poly_seq['_entity_poly_seq.mon_id'], + num=int(entity_poly_seq['_entity_poly_seq.num']), + )) + + # Get chemical compositions. Will allow us to identify which of these polymers + # are proteins. + chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', + parsed_info) + + # Get chains information for each entity. Necessary so that we can return a + # dict keyed on chain id rather than entity. + struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) + + entity_to_mmcif_chains = collections.defaultdict(list) + for struct_asym in struct_asyms: + chain_id = struct_asym['_struct_asym.id'] + entity_id = struct_asym['_struct_asym.entity_id'] + entity_to_mmcif_chains[entity_id].append(chain_id) + + # Identify and return the valid protein chains. + valid_chains = {} + for entity_id, seq_info in polymers.items(): + chain_ids = entity_to_mmcif_chains[entity_id] + + # Reject polymers without any peptide-like components, such as DNA/RNA. 
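+ # Illustrative note: standard amino-acid monomers typically report a
+ # _chem_comp.type such as 'L-peptide linking', so the substring test below
+ # keeps protein chains while DNA/RNA components fail it.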
+ if any([ + 'peptide' in chem_comps[monomer.id]['_chem_comp.type'] + for monomer in seq_info + ]): + for chain_id in chain_ids: + valid_chains[chain_id] = seq_info + return valid_chains + + +def _is_set(data: str) -> bool: + """Returns False if data is a special mmCIF character indicating 'unset'.""" + return data not in ('.', '?') diff --git a/modelscope/models/science/unifold/msa/msa_identifiers.py b/modelscope/models/science/unifold/msa/msa_identifiers.py new file mode 100644 index 00000000..366239db --- /dev/null +++ b/modelscope/models/science/unifold/msa/msa_identifiers.py @@ -0,0 +1,88 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities for extracting identifiers from MSA sequence descriptions.""" + +import dataclasses +import re +from typing import Optional + +# Sequences coming from UniProtKB database come in the +# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` +# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). +_UNIPROT_PATTERN = re.compile( + r""" + ^ + # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot + (?:tr|sp) + \| + # A primary accession number of the UniProtKB entry. + (?P<AccessionIdentifier>[A-Za-z0-9]{6,10}) + # Occasionally there is a _0 or _1 isoform suffix, which we ignore. + (?:_\d)? + \| + # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic + # protein ID code. + (?:[A-Za-z0-9]+) + _ + # A mnemonic species identification code. + (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5}) + # Small BFD uses a final value after an underscore, which we ignore. + (?:_\d+)? + $ + """, + re.VERBOSE, +) + + +@dataclasses.dataclass(frozen=True) +class Identifiers: + species_id: str = '' + + +def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: + """Gets accession id and species from an msa sequence identifier. + + The sequence identifier has the format specified by + _UNIPROT_PATTERN above. + An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` + + Args: + msa_sequence_identifier: a sequence identifier. + + Returns: + An `Identifiers` instance with a species_id. These + can be empty in the case where no identifier was found. + """ + matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) + if matches: + return Identifiers(species_id=matches.group('SpeciesIdentifier')) + return Identifiers() + + +def _extract_sequence_identifier(description: str) -> Optional[str]: + """Extracts sequence identifier from description. 
Returns None if no match.""" + split_description = description.split() + if split_description: + return split_description[0].partition('/')[0] + else: + return None + + +def get_identifiers(description: str) -> Identifiers: + """Computes extra MSA features from the description.""" + sequence_identifier = _extract_sequence_identifier(description) + if sequence_identifier is None: + return Identifiers() + else: + return _parse_sequence_identifier(sequence_identifier) diff --git a/modelscope/models/science/unifold/msa/parsers.py b/modelscope/models/science/unifold/msa/parsers.py new file mode 100644 index 00000000..bf36c816 --- /dev/null +++ b/modelscope/models/science/unifold/msa/parsers.py @@ -0,0 +1,627 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functions for parsing various file formats.""" +import collections +import dataclasses +import itertools +import re +import string +from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple + +DeletionMatrix = Sequence[Sequence[int]] + + +@dataclasses.dataclass(frozen=True) +class Msa: + """Class representing a parsed MSA file.""" + + sequences: Sequence[str] + deletion_matrix: DeletionMatrix + descriptions: Sequence[str] + + def __post_init__(self): + if not (len(self.sequences) == len(self.deletion_matrix) == len( + self.descriptions)): + raise ValueError( + 'All fields for an MSA must have the same length. ' + f'Got {len(self.sequences)} sequences, ' + f'{len(self.deletion_matrix)} rows in the deletion matrix and ' + f'{len(self.descriptions)} descriptions.') + + def __len__(self): + return len(self.sequences) + + def truncate(self, max_seqs: int): + return Msa( + sequences=self.sequences[:max_seqs], + deletion_matrix=self.deletion_matrix[:max_seqs], + descriptions=self.descriptions[:max_seqs], + ) + + +@dataclasses.dataclass(frozen=True) +class TemplateHit: + """Class representing a template hit.""" + + index: int + name: str + aligned_cols: int + sum_probs: Optional[float] + query: str + hit_sequence: str + indices_query: List[int] + indices_hit: List[int] + + +def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. + """ + sequences = [] + descriptions = [] + index = -1 + for line in fasta_string.splitlines(): + line = line.strip() + if line.startswith('>'): + index += 1 + descriptions.append(line[1:]) # Remove the '>' at the beginning. + sequences.append('') + continue + elif not line: + continue # Skip blank lines. + sequences[index] += line + + return sequences, descriptions + + +def parse_stockholm(stockholm_string: str) -> Msa: + """Parses sequences and deletion matrix from stockholm format alignment. 
+ + Args: + stockholm_string: The string contents of a stockholm file. The first + sequence in the file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + * The names of the targets matched, including the jackhmmer subsequence + suffix. + """ + name_to_sequence = collections.OrderedDict() + for line in stockholm_string.splitlines(): + line = line.strip() + if not line or line.startswith(('#', '//')): + continue + name, sequence = line.split() + if name not in name_to_sequence: + name_to_sequence[name] = '' + name_to_sequence[name] += sequence + + msa = [] + deletion_matrix = [] + + query = '' + keep_columns = [] + for seq_index, sequence in enumerate(name_to_sequence.values()): + if seq_index == 0: + # Gather the columns with gaps from the query + query = sequence + keep_columns = [i for i, res in enumerate(query) if res != '-'] + + # Remove the columns with gaps in the query from all sequences. + aligned_sequence = ''.join([sequence[c] for c in keep_columns]) + + msa.append(aligned_sequence) + + # Count the number of deletions w.r.t. query. + deletion_vec = [] + deletion_count = 0 + for seq_res, query_res in zip(sequence, query): + if seq_res != '-' or query_res != '-': + if query_res == '-': + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + return Msa( + sequences=msa, + deletion_matrix=deletion_matrix, + descriptions=list(name_to_sequence.keys()), + ) + + +def parse_a3m(a3m_string: str) -> Msa: + """Parses sequences and deletion matrix from a3m format alignment. + + Args: + a3m_string: The string contents of a a3m file. The first sequence in the + file should be the query sequence. + + Returns: + A tuple of: + * A list of sequences that have been aligned to the query. These + might contain duplicates. + * The deletion matrix for the alignment as a list of lists. The element + at `deletion_matrix[i][j]` is the number of residues deleted from + the aligned sequence i at residue position j. + * A list of descriptions, one per sequence, from the a3m file. + """ + sequences, descriptions = parse_fasta(a3m_string) + deletion_matrix = [] + for msa_sequence in sequences: + deletion_vec = [] + deletion_count = 0 + for j in msa_sequence: + if j.islower(): + deletion_count += 1 + else: + deletion_vec.append(deletion_count) + deletion_count = 0 + deletion_matrix.append(deletion_vec) + + # Make the MSA matrix out of aligned (deletion-free) sequences. 
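+ # Worked example: for the a3m row 'MKtwV', the lowercase 'tw' mark insertions
+ # relative to the query, so the deletion vector built above is [0, 0, 2] and
+ # the aligned sequence kept below is 'MKV'.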
+ deletion_table = str.maketrans('', '', string.ascii_lowercase) + aligned_sequences = [s.translate(deletion_table) for s in sequences] + return Msa( + sequences=aligned_sequences, + deletion_matrix=deletion_matrix, + descriptions=descriptions, + ) + + +def _convert_sto_seq_to_a3m(query_non_gaps: Sequence[bool], + sto_seq: str) -> Iterable[str]: + for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): + if is_query_res_non_gap: + yield sequence_res + elif sequence_res != '-': + yield sequence_res.lower() + + +def convert_stockholm_to_a3m( + stockholm_format: str, + max_sequences: Optional[int] = None, + remove_first_row_gaps: bool = True, +) -> str: + """Converts MSA in Stockholm format to the A3M format.""" + descriptions = {} + sequences = {} + reached_max_sequences = False + + for line in stockholm_format.splitlines(): + reached_max_sequences = max_sequences and len( + sequences) >= max_sequences + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + seqname, aligned_seq = line.split(maxsplit=1) + if seqname not in sequences: + if reached_max_sequences: + continue + sequences[seqname] = '' + sequences[seqname] += aligned_seq + + for line in stockholm_format.splitlines(): + if line[:4] == '#=GS': + # Description row - example format is: + # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... + columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else '' + if feature != 'DE': + continue + if reached_max_sequences and seqname not in sequences: + continue + descriptions[seqname] = value + if len(descriptions) == len(sequences): + break + + # Convert sto format to a3m line by line + a3m_sequences = {} + if remove_first_row_gaps: + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + query_non_gaps = [res != '-' for res in query_sequence] + for seqname, sto_sequence in sequences.items(): + # Dots are optional in a3m format and are commonly removed. + out_sequence = sto_sequence.replace('.', '') + if remove_first_row_gaps: + out_sequence = ''.join( + _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)) + a3m_sequences[seqname] = out_sequence + + fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" + for k in a3m_sequences) + return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. + + +def _keep_line(line: str, seqnames: Set[str]) -> bool: + """Function to decide which lines to keep.""" + if not line.strip(): + return True + if line.strip() == '//': # End tag + return True + if line.startswith('# STOCKHOLM'): # Start tag + return True + if line.startswith('#=GC RF'): # Reference Annotation Line + return True + if line[:4] == '#=GS': # Description lines - keep if sequence in list. + _, seqname, _ = line.split(maxsplit=2) + return seqname in seqnames + elif line.startswith('#'): # Other markup - filter out + return False + else: # Alignment data - keep if sequence in list. + seqname = line.partition(' ')[0] + return seqname in seqnames + + +def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str: + """Truncates a stockholm file to a maximum number of sequences.""" + seqnames = set() + filtered_lines = [] + for line in stockholm_msa.splitlines(): + if line.strip() and not line.startswith(('#', '//')): + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. 
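+ # Each alignment row is a sequence name followed by whitespace and the
+ # aligned sequence, e.g. 'UniRef90_Q9H5Z4/4-78 MKV..TV-A'; everything before
+ # the first space is the name.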
+ seqname = line.partition(' ')[0] + seqnames.add(seqname) + if len(seqnames) >= max_sequences: + break + + for line in stockholm_msa.splitlines(): + if _keep_line(line, seqnames): + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + '\n' + + +def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str: + """Removes empty columns (dashes-only) from a Stockholm MSA.""" + processed_lines = {} + unprocessed_lines = {} + for i, line in enumerate(stockholm_msa.splitlines()): + if line.startswith('#=GC RF'): + reference_annotation_i = i + reference_annotation_line = line + # Reached the end of this chunk of the alignment. Process chunk. + _, _, first_alignment = line.rpartition(' ') + mask = [] + for j in range(len(first_alignment)): + for _, unprocessed_line in unprocessed_lines.items(): + prefix, _, alignment = unprocessed_line.rpartition(' ') + if alignment[j] != '-': + mask.append(True) + break + else: # Every row contained a hyphen - empty column. + mask.append(False) + # Add reference annotation for processing with mask. + unprocessed_lines[ + reference_annotation_i] = reference_annotation_line + + if not any( + mask + ): # All columns were empty. Output empty lines for chunk. + for line_index in unprocessed_lines: + processed_lines[line_index] = '' + else: + for line_index, unprocessed_line in unprocessed_lines.items(): + prefix, _, alignment = unprocessed_line.rpartition(' ') + masked_alignment = ''.join( + itertools.compress(alignment, mask)) + processed_lines[ + line_index] = f'{prefix} {masked_alignment}' + + # Clear raw_alignments. + unprocessed_lines = {} + elif line.strip() and not line.startswith(('#', '//')): + unprocessed_lines[i] = line + else: + processed_lines[i] = line + return '\n'.join((processed_lines[i] for i in range(len(processed_lines)))) + + +def deduplicate_stockholm_msa(stockholm_msa: str) -> str: + """Remove duplicate sequences (ignoring insertions wrt query).""" + sequence_dict = collections.defaultdict(str) + + # First we must extract all sequences from the MSA. + for line in stockholm_msa.splitlines(): + # Only consider the alignments - ignore reference annotation, empty lines, + # descriptions or markup. + if line.strip() and not line.startswith(('#', '//')): + line = line.strip() + seqname, alignment = line.split() + sequence_dict[seqname] += alignment + + seen_sequences = set() + seqnames = set() + # First alignment is the query. + query_align = next(iter(sequence_dict.values())) + mask = [c != '-' for c in query_align] # Mask is False for insertions. + for seqname, alignment in sequence_dict.items(): + # Apply mask to remove all insertions from the string. 
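+ # itertools.compress keeps only the columns where the query row is ungapped:
+ # e.g. with query 'MK-V' the mask is [True, True, False, True], so a row
+ # 'MKAV' collapses to 'MKV' before the duplicate check below.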
+ masked_alignment = ''.join(itertools.compress(alignment, mask)) + if masked_alignment in seen_sequences: + continue + else: + seen_sequences.add(masked_alignment) + seqnames.add(seqname) + + filtered_lines = [] + for line in stockholm_msa.splitlines(): + if _keep_line(line, seqnames): + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + '\n' + + +def _get_hhr_line_regex_groups(regex_pattern: str, + line: str) -> Sequence[Optional[str]]: + match = re.match(regex_pattern, line) + if match is None: + raise RuntimeError(f'Could not parse query line {line}') + return match.groups() + + +def _update_hhr_residue_indices_list(sequence: str, start_index: int, + indices_list: List[int]): + """Computes the relative indices for each residue with respect to the original sequence.""" + counter = start_index + for symbol in sequence: + if symbol == '-': + indices_list.append(-1) + else: + indices_list.append(counter) + counter += 1 + + +def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: + """Parses the detailed HMM HMM comparison section for a single Hit. + + This works on .hhr files generated from both HHBlits and HHSearch. + + Args: + detailed_lines: A list of lines from a single comparison section between 2 + sequences (which each have their own HMM's) + + Returns: + A dictionary with the information from that detailed comparison section + + Raises: + RuntimeError: If a certain line cannot be processed + """ + # Parse first 2 lines. + number_of_hit = int(detailed_lines[0].split()[-1]) + name_hit = detailed_lines[1][1:] + + # Parse the summary line. + pattern = ( + 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' + ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' + ']*Template_Neff=(.*)') + match = re.match(pattern, detailed_lines[2]) + if match is None: + raise RuntimeError( + 'Could not parse section: %s. Expected this: \n%s to contain summary.' + % (detailed_lines, detailed_lines[2])) + (_, _, _, aligned_cols, _, _, sum_probs, + _) = [float(x) for x in match.groups()] + + # The next section reads the detailed comparisons. These are in a 'human + # readable' format which has a fixed length. The strategy employed is to + # assume that each block starts with the query sequence line, and to parse + # that with a regexp in order to deduce the fixed length used for that block. + query = '' + hit_sequence = '' + indices_query = [] + indices_hit = [] + length_block = None + + for line in detailed_lines[3:]: + # Parse the query sequence line + if (line.startswith('Q ') and not line.startswith('Q ss_dssp') + and not line.startswith('Q ss_pred') + and not line.startswith('Q Consensus')): + # Thus the first 17 characters must be 'Q ', and we can parse + # everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + + # Get the length of the parsed block using the start and finish indices, + # and ensure it is the same as the actual block length. + start = int(groups[0]) - 1 # Make index zero based. + delta_query = groups[1] + end = int(groups[2]) + num_insertions = len([x for x in delta_query if x == '-']) + length_block = end - start + num_insertions + assert length_block == len(delta_query) + + # Update the query sequence and indices list. + query += delta_query + _update_hhr_residue_indices_list(delta_query, start, indices_query) + + elif line.startswith('T '): + # Parse the hit sequence. 
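+ # Annotation rows ('T ss_dssp', 'T ss_pred', 'T Consensus') are skipped
+ # below; only the actual template residue row is parsed.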
+ if (not line.startswith('T ss_dssp') + and not line.startswith('T ss_pred') + and not line.startswith('T Consensus')): + # Thus the first 17 characters must be 'T ', and we can + # parse everything after that. + # start sequence end total_sequence_length + patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' + groups = _get_hhr_line_regex_groups(patt, line[17:]) + start = int(groups[0]) - 1 # Make index zero based. + delta_hit_sequence = groups[1] + assert length_block == len(delta_hit_sequence) + + # Update the hit sequence and indices list. + hit_sequence += delta_hit_sequence + _update_hhr_residue_indices_list(delta_hit_sequence, start, + indices_hit) + + return TemplateHit( + index=number_of_hit, + name=name_hit, + aligned_cols=int(aligned_cols), + sum_probs=sum_probs, + query=query, + hit_sequence=hit_sequence, + indices_query=indices_query, + indices_hit=indices_hit, + ) + + +def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: + """Parses the content of an entire HHR file.""" + lines = hhr_string.splitlines() + + # Each .hhr file starts with a results table, then has a sequence of hit + # "paragraphs", each paragraph starting with a line 'No '. We + # iterate through each paragraph to parse each hit. + + block_starts = [ + i for i, line in enumerate(lines) if line.startswith('No ') + ] + + hits = [] + if block_starts: + block_starts.append(len(lines)) # Add the end of the final block. + for i in range(len(block_starts) - 1): + hits.append( + _parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) + return hits + + +def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: + """Parse target to e-value mapping parsed from Jackhmmer tblout string.""" + e_values = {'query': 0} + lines = [line for line in tblout.splitlines() if line[0] != '#'] + # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are + # space-delimited. Relevant fields are (1) target name: and + # (5) E-value (full sequence) (numbering from 1). + for line in lines: + fields = line.split() + e_value = fields[4] + target_name = fields[0] + e_values[target_name] = float(e_value) + return e_values + + +def _get_indices(sequence: str, start: int) -> List[int]: + """Returns indices for non-gap/insert residues starting at the given index.""" + indices = [] + counter = start + for symbol in sequence: + # Skip gaps but add a placeholder so that the alignment is preserved. + if symbol == '-': + indices.append(-1) + # Skip deleted residues, but increase the counter. + elif symbol.islower(): + counter += 1 + # Normal aligned residue. Increase the counter and append to indices. 
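+ # For instance, _get_indices('AB-cD', start=0) returns [0, 1, -1, 3]: the
+ # gap contributes a -1 placeholder and the lowercase 'c' only advances the
+ # counter.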
+ else: + indices.append(counter) + counter += 1 + return indices + + +@dataclasses.dataclass(frozen=True) +class HitMetadata: + pdb_id: str + chain: str + start: int + end: int + length: int + text: str + + +def _parse_hmmsearch_description(description: str) -> HitMetadata: + """Parses the hmmsearch A3M sequence description line.""" + # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text + # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352 + match = re.match( + r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$', + description.strip(), + ) + + if not match: + raise ValueError(f'Could not parse description: "{description}".') + + return HitMetadata( + pdb_id=match[1], + chain=match[2], + start=int(match[3]), + end=int(match[4]), + length=int(match[5]), + text=match[6], + ) + + +def parse_hmmsearch_a3m(query_sequence: str, + a3m_string: str, + skip_first: bool = True) -> Sequence[TemplateHit]: + """Parses an a3m string produced by hmmsearch. + + Args: + query_sequence: The query sequence. + a3m_string: The a3m string produced by hmmsearch. + skip_first: Whether to skip the first sequence in the a3m string. + + Returns: + A sequence of `TemplateHit` results. + """ + # Zip the descriptions and MSAs together, skip the first query sequence. + parsed_a3m = list(zip(*parse_fasta(a3m_string))) + if skip_first: + parsed_a3m = parsed_a3m[1:] + + indices_query = _get_indices(query_sequence, start=0) + + hits = [] + for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1): + if 'mol:protein' not in hit_description: + continue # Skip non-protein chains. + metadata = _parse_hmmsearch_description(hit_description) + # Aligned columns are only the match states. + aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence]) + indices_hit = _get_indices(hit_sequence, start=metadata.start - 1) + + hit = TemplateHit( + index=i, + name=f'{metadata.pdb_id}_{metadata.chain}', + aligned_cols=aligned_cols, + sum_probs=None, + query=query_sequence, + hit_sequence=hit_sequence.upper(), + indices_query=indices_query, + indices_hit=indices_hit, + ) + hits.append(hit) + + return hits diff --git a/modelscope/models/science/unifold/msa/pipeline.py b/modelscope/models/science/unifold/msa/pipeline.py new file mode 100644 index 00000000..b7889bff --- /dev/null +++ b/modelscope/models/science/unifold/msa/pipeline.py @@ -0,0 +1,282 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Functions for building the input features for the unifold model.""" + +import os +from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union + +import numpy as np +from absl import logging + +from modelscope.models.science.unifold.data import residue_constants +from modelscope.models.science.unifold.msa import (msa_identifiers, parsers, + templates) +from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch, + hmmsearch, jackhmmer) + +FeatureDict = MutableMapping[str, np.ndarray] +TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] + + +def make_sequence_features(sequence: str, description: str, + num_res: int) -> FeatureDict: + """Constructs a feature dict of sequence features.""" + features = {} + features['aatype'] = residue_constants.sequence_to_onehot( + sequence=sequence, + mapping=residue_constants.restype_order_with_x, + map_unknown_to_x=True, + ) + features['between_segment_residues'] = np.zeros((num_res, ), + dtype=np.int32) + features['domain_name'] = np.array([description.encode('utf-8')], + dtype=np.object_) + features['residue_index'] = np.array(range(num_res), dtype=np.int32) + features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) + features['sequence'] = np.array([sequence.encode('utf-8')], + dtype=np.object_) + return features + + +def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: + """Constructs a feature dict of MSA features.""" + if not msas: + raise ValueError('At least one MSA must be provided.') + + int_msa = [] + deletion_matrix = [] + species_ids = [] + seen_sequences = set() + for msa_index, msa in enumerate(msas): + if not msa: + raise ValueError( + f'MSA {msa_index} must contain at least one sequence.') + for sequence_index, sequence in enumerate(msa.sequences): + if sequence in seen_sequences: + continue + seen_sequences.add(sequence) + int_msa.append( + [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) + deletion_matrix.append(msa.deletion_matrix[sequence_index]) + identifiers = msa_identifiers.get_identifiers( + msa.descriptions[sequence_index]) + species_ids.append(identifiers.species_id.encode('utf-8')) + + num_res = len(msas[0].sequences[0]) + num_alignments = len(int_msa) + features = {} + features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) + features['msa'] = np.array(int_msa, dtype=np.int32) + features['num_alignments'] = np.array( + [num_alignments] * num_res, dtype=np.int32) + features['msa_species_identifiers'] = np.array( + species_ids, dtype=np.object_) + return features + + +def run_msa_tool( + msa_runner, + input_fasta_path: str, + msa_out_path: str, + msa_format: str, + use_precomputed_msas: bool, +) -> Mapping[str, Any]: + """Runs an MSA tool, checking if output already exists first.""" + if not use_precomputed_msas or not os.path.exists(msa_out_path): + result = msa_runner.query(input_fasta_path)[0] + with open(msa_out_path, 'w') as f: + f.write(result[msa_format]) + else: + logging.warning('Reading MSA from file %s', msa_out_path) + with open(msa_out_path, 'r') as f: + result = {msa_format: f.read()} + return result + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__( + self, + jackhmmer_binary_path: str, + hhblits_binary_path: str, + uniref90_database_path: str, + mgnify_database_path: str, + bfd_database_path: Optional[str], + uniclust30_database_path: Optional[str], + small_bfd_database_path: Optional[str], + uniprot_database_path: Optional[str], + 
template_searcher: TemplateSearcher, + template_featurizer: templates.TemplateHitFeaturizer, + use_small_bfd: bool, + mgnify_max_hits: int = 501, + uniref_max_hits: int = 10000, + use_precomputed_msas: bool = False, + ): + """Initializes the data pipeline.""" + self._use_small_bfd = use_small_bfd + self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniref90_database_path) + if use_small_bfd: + self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=small_bfd_database_path) + else: + self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( + binary_path=hhblits_binary_path, + databases=[bfd_database_path, uniclust30_database_path], + ) + self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=mgnify_database_path) + self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer( + binary_path=jackhmmer_binary_path, + database_path=uniprot_database_path) + self.template_searcher = template_searcher + self.template_featurizer = template_featurizer + self.mgnify_max_hits = mgnify_max_hits + self.uniref_max_hits = uniref_max_hits + self.use_precomputed_msas = use_precomputed_msas + + def process(self, input_fasta_path: str, + msa_output_dir: str) -> FeatureDict: + """Runs alignment tools on the input sequence and creates features.""" + with open(input_fasta_path) as f: + input_fasta_str = f.read() + input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) + if len(input_seqs) != 1: + raise ValueError( + f'More than one input sequence found in {input_fasta_path}.') + input_sequence = input_seqs[0] + input_description = input_descs[0] + num_res = len(input_sequence) + + uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') + jackhmmer_uniref90_result = run_msa_tool( + self.jackhmmer_uniref90_runner, + input_fasta_path, + uniref90_out_path, + 'sto', + self.use_precomputed_msas, + ) + mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') + jackhmmer_mgnify_result = run_msa_tool( + self.jackhmmer_mgnify_runner, + input_fasta_path, + mgnify_out_path, + 'sto', + self.use_precomputed_msas, + ) + + msa_for_templates = jackhmmer_uniref90_result['sto'] + msa_for_templates = parsers.truncate_stockholm_msa( + msa_for_templates, max_sequences=self.uniref_max_hits) + msa_for_templates = parsers.deduplicate_stockholm_msa( + msa_for_templates) + msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa( + msa_for_templates) + + if self.template_searcher.input_format == 'sto': + pdb_templates_result = self.template_searcher.query( + msa_for_templates) + elif self.template_searcher.input_format == 'a3m': + uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( + msa_for_templates) + pdb_templates_result = self.template_searcher.query( + uniref90_msa_as_a3m) + else: + raise ValueError('Unrecognized template input format: ' + f'{self.template_searcher.input_format}') + + pdb_hits_out_path = os.path.join( + msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}') + with open(pdb_hits_out_path, 'w') as f: + f.write(pdb_templates_result) + + uniref90_msa = parsers.parse_stockholm( + jackhmmer_uniref90_result['sto']) + uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) + mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) + mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) + + pdb_template_hits = self.template_searcher.get_template_hits( + output_string=pdb_templates_result, 
input_sequence=input_sequence) + + if self._use_small_bfd: + bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') + jackhmmer_small_bfd_result = run_msa_tool( + self.jackhmmer_small_bfd_runner, + input_fasta_path, + bfd_out_path, + 'sto', + self.use_precomputed_msas, + ) + bfd_msa = parsers.parse_stockholm( + jackhmmer_small_bfd_result['sto']) + else: + bfd_out_path = os.path.join(msa_output_dir, + 'bfd_uniclust_hits.a3m') + hhblits_bfd_uniclust_result = run_msa_tool( + self.hhblits_bfd_uniclust_runner, + input_fasta_path, + bfd_out_path, + 'a3m', + self.use_precomputed_msas, + ) + bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) + + templates_result = self.template_featurizer.get_templates( + query_sequence=input_sequence, hits=pdb_template_hits) + + sequence_features = make_sequence_features( + sequence=input_sequence, + description=input_description, + num_res=num_res) + + msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa)) + + logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) + logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) + logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) + logging.info( + 'Final (deduplicated) MSA size: %d sequences.', + msa_features['num_alignments'][0], + ) + logging.info( + 'Total number of templates (NB: this can include bad ' + 'templates and is later filtered to top 4): %d.', + templates_result.features['template_domain_names'].shape[0], + ) + + return { + **sequence_features, + **msa_features, + **templates_result.features + } + + def process_uniprot(self, input_fasta_path: str, + msa_output_dir: str) -> FeatureDict: + uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') + uniprot_result = run_msa_tool( + self.jackhmmer_uniprot_runner, + input_fasta_path, + uniprot_path, + 'sto', + self.use_precomputed_msas, + ) + msa = parsers.parse_stockholm(uniprot_result['sto']) + msa = msa.truncate(max_seqs=50000) + all_seq_dict = make_msa_features([msa]) + return all_seq_dict diff --git a/modelscope/models/science/unifold/msa/templates.py b/modelscope/models/science/unifold/msa/templates.py new file mode 100644 index 00000000..fe3bcef9 --- /dev/null +++ b/modelscope/models/science/unifold/msa/templates.py @@ -0,0 +1,1110 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Functions for getting templates and calculating template features.""" +import abc +import dataclasses +import datetime +import functools +import glob +import os +import re +from typing import Any, Dict, Mapping, Optional, Sequence, Tuple + +import numpy as np +from absl import logging + +from modelscope.models.science.unifold.data import residue_constants +from modelscope.models.science.unifold.msa import mmcif, parsers +from modelscope.models.science.unifold.msa.tools import kalign + + +class Error(Exception): + """Base class for exceptions.""" + + +class NoChainsError(Error): + """An error indicating that template mmCIF didn't have any chains.""" + + +class SequenceNotInTemplateError(Error): + """An error indicating that template mmCIF didn't contain the sequence.""" + + +class NoAtomDataInTemplateError(Error): + """An error indicating that template mmCIF didn't contain atom positions.""" + + +class TemplateAtomMaskAllZerosError(Error): + """An error indicating that template mmCIF had all atom positions masked.""" + + +class QueryToTemplateAlignError(Error): + """An error indicating that the query can't be aligned to the template.""" + + +class CaDistanceError(Error): + """An error indicating that a CA atom distance exceeds a threshold.""" + + +class MultipleChainsError(Error): + """An error indicating that multiple chains were found for a given ID.""" + + +# Prefilter exceptions. +class PrefilterError(Exception): + """A base class for template prefilter exceptions.""" + + +class DateError(PrefilterError): + """An error indicating that the hit date was after the max allowed date.""" + + +class AlignRatioError(PrefilterError): + """An error indicating that the hit align ratio to the query was too small.""" + + +class DuplicateError(PrefilterError): + """An error indicating that the hit was an exact subsequence of the query.""" + + +class LengthError(PrefilterError): + """An error indicating that the hit was too short.""" + + +TEMPLATE_FEATURES = { + 'template_aatype': np.float32, + 'template_all_atom_mask': np.float32, + 'template_all_atom_positions': np.float32, + 'template_domain_names': np.object_, + 'template_sequence': np.object_, + 'template_sum_probs': np.float32, +} + + +def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]: + """Returns PDB id and chain id for an HHSearch Hit.""" + # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown. + id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name) + if not id_match: + raise ValueError( + f'hit.name did not start with PDBID_chain: {hit.name}') + pdb_id, chain_id = id_match.group(0).split('_') + return pdb_id.lower(), chain_id + + +def _is_after_cutoff( + pdb_id: str, + release_dates: Mapping[str, datetime.datetime], + release_date_cutoff: Optional[datetime.datetime], +) -> bool: + """Checks if the template date is after the release date cutoff. + + Args: + pdb_id: 4 letter pdb code. + release_dates: Dictionary mapping PDB ids to their structure release dates. + release_date_cutoff: Max release date that is valid for this query. + + Returns: + True if the template release date is after the cutoff, False otherwise. + """ + if release_date_cutoff is None: + raise ValueError('The release_date_cutoff must not be None.') + if pdb_id in release_dates: + return release_dates[pdb_id] > release_date_cutoff + else: + # Since this is just a quick prefilter to reduce the number of mmCIF files + # we need to parse, we don't have to worry about returning True here. 
+ return False + + +def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, Optional[str]]: + """Parses the data file from PDB that lists which pdb_ids are obsolete.""" + with open(obsolete_file_path) as f: + result = {} + for line in f: + line = line.strip() + # Format: Date From To + # 'OBSLTE 06-NOV-19 6G9Y' - Removed, rare + # 'OBSLTE 31-JUL-94 116L 216L' - Replaced, common + # 'OBSLTE 26-SEP-06 2H33 2JM5 2OWI' - Replaced by multiple, rare + if line.startswith('OBSLTE'): + if len(line) > 30: + # Replaced by at least one structure. + from_id = line[20:24].lower() + to_id = line[29:33].lower() + result[from_id] = to_id + elif len(line) == 24: + # Removed. + from_id = line[20:24].lower() + result[from_id] = None + return result + + +def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]: + """Parses release dates file, returns a mapping from PDBs to release dates.""" + if path.endswith('txt'): + release_dates = {} + with open(path, 'r') as f: + for line in f: + pdb_id, date = line.split(':') + date = date.strip() + # Python 3.6 doesn't have datetime.date.fromisoformat() which is about + # 90x faster than strptime. However, splitting the string manually is + # about 10x faster than strptime. + release_dates[pdb_id.strip()] = datetime.datetime( + year=int(date[:4]), + month=int(date[5:7]), + day=int(date[8:10])) + return release_dates + else: + raise ValueError('Invalid format of the release date file %s.' % path) + + +def _assess_hhsearch_hit( + hit: parsers.TemplateHit, + hit_pdb_code: str, + query_sequence: str, + release_dates: Mapping[str, datetime.datetime], + release_date_cutoff: datetime.datetime, + max_subsequence_ratio: float = 0.95, + min_align_ratio: float = 0.1, +) -> bool: + """Determines if template is valid (without parsing the template mmcif file). + + Args: + hit: HhrHit for the template. + hit_pdb_code: The 4 letter pdb code of the template hit. This might be + different from the value in the actual hit since the original pdb might + have become obsolete. + query_sequence: Amino acid sequence of the query. + release_dates: Dictionary mapping pdb codes to their structure release + dates. + release_date_cutoff: Max release date that is valid for this query. + max_subsequence_ratio: Exclude any exact matches with this much overlap. + min_align_ratio: Minimum overlap between the template and query. + + Returns: + True if the hit passed the prefilter. Raises an exception otherwise. + + Raises: + DateError: If the hit date was after the max allowed date. + AlignRatioError: If the hit align ratio to the query was too small. + DuplicateError: If the hit was an exact subsequence of the query. + LengthError: If the hit was too short. + """ + aligned_cols = hit.aligned_cols + align_ratio = aligned_cols / len(query_sequence) + + template_sequence = hit.hit_sequence.replace('-', '') + length_ratio = float(len(template_sequence)) / len(query_sequence) + + # Check whether the template is a large subsequence or duplicate of original + # query. This can happen due to duplicate entries in the PDB database. + duplicate = ( + template_sequence in query_sequence + and length_ratio > max_subsequence_ratio) + + if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): + raise DateError( + f'Date ({release_dates[hit_pdb_code]}) > max template date ' + f'({release_date_cutoff}).') + + if align_ratio <= min_align_ratio: + raise AlignRatioError( + 'Proportion of residues aligned to query too small. 
' + f'Align ratio: {align_ratio}.') + + if duplicate: + raise DuplicateError( + 'Template is an exact subsequence of query with large ' + f'coverage. Length ratio: {length_ratio}.') + + if len(template_sequence) < 10: + raise LengthError( + f'Template too short. Length: {len(template_sequence)}.') + + return True + + +def _find_template_in_pdb( + template_chain_id: str, template_sequence: str, + mmcif_object: mmcif.MmcifObject) -> Tuple[str, str, int]: + """Tries to find the template chain in the given pdb file. + + This method tries the three following things in order: + 1. Tries if there is an exact match in both the chain ID and the sequence. + If yes, the chain sequence is returned. Otherwise: + 2. Tries if there is an exact match only in the sequence. + If yes, the chain sequence is returned. Otherwise: + 3. Tries if there is a fuzzy match (X = wildcard) in the sequence. + If yes, the chain sequence is returned. + If none of these succeed, a SequenceNotInTemplateError is thrown. + + Args: + template_chain_id: The template chain ID. + template_sequence: The template chain sequence. + mmcif_object: The PDB object to search for the template in. + + Returns: + A tuple with: + * The chain sequence that was found to match the template in the PDB object. + * The ID of the chain that is being returned. + * The offset where the template sequence starts in the chain sequence. + + Raises: + SequenceNotInTemplateError: If no match is found after the steps described + above. + """ + # Try if there is an exact match in both the chain ID and the (sub)sequence. + pdb_id = mmcif_object.file_id + chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id) + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found an exact template match %s_%s.', pdb_id, + template_chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, template_chain_id, mapping_offset + + # Try if there is an exact match in the (sub)sequence only. + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + if chain_sequence and (template_sequence in chain_sequence): + logging.info('Found a sequence-only match %s_%s.', pdb_id, + chain_id) + mapping_offset = chain_sequence.find(template_sequence) + return chain_sequence, chain_id, mapping_offset + + # Return a chain sequence that fuzzy matches (X = wildcard) the template. + # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit. + regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence] + regex = re.compile(''.join(regex)) + for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): + match = re.search(regex, chain_sequence) + if match: + logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, + chain_id) + mapping_offset = match.start() + return chain_sequence, chain_id, mapping_offset + + # No hits, raise an error. + raise SequenceNotInTemplateError( + 'Could not find the template sequence in %s_%s. Template sequence: %s, ' + 'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence, + mmcif_object.chain_to_seqres)) + + +def _realign_pdb_template_to_query( + old_template_sequence: str, + template_chain_id: str, + mmcif_object: mmcif.MmcifObject, + old_mapping: Mapping[int, int], + kalign_binary_path: str, +) -> Tuple[str, Mapping[int, int]]: + """Aligns template from the mmcif_object to the query. 
+ + In case PDB70 contains a different version of the template sequence, we need + to perform a realignment to the actual sequence that is in the mmCIF file. + This method performs such realignment, but returns the new sequence and + mapping only if the sequence in the mmCIF file is 90% identical to the old + sequence. + + Note that the old_template_sequence comes from the hit, and contains only that + part of the chain that matches with the query while the new_template_sequence + is the full chain. + + Args: + old_template_sequence: The template sequence that was returned by the PDB + template search (typically done using HHSearch). + template_chain_id: The template chain id was returned by the PDB template + search (typically done using HHSearch). This is used to find the right + chain in the mmcif_object chain_to_seqres mapping. + mmcif_object: A mmcif_object which holds the actual template data. + old_mapping: A mapping from the query sequence to the template sequence. + This mapping will be used to compute the new mapping from the query + sequence to the actual mmcif_object template sequence by aligning the + old_template_sequence and the actual template sequence. + kalign_binary_path: The path to a kalign executable. + + Returns: + A tuple (new_template_sequence, new_query_to_template_mapping) where: + * new_template_sequence is the actual template sequence that was found in + the mmcif_object. + * new_query_to_template_mapping is the new mapping from the query to the + actual template found in the mmcif_object. + + Raises: + QueryToTemplateAlignError: + * If there was an error thrown by the alignment tool. + * Or if the actual template sequence differs by more than 10% from the + old_template_sequence. + """ + aligner = kalign.Kalign(binary_path=kalign_binary_path) + new_template_sequence = mmcif_object.chain_to_seqres.get( + template_chain_id, '') + + # Sometimes the template chain id is unknown. But if there is only a single + # sequence within the mmcif_object, it is safe to assume it is that one. + if not new_template_sequence: + if len(mmcif_object.chain_to_seqres) == 1: + logging.info( + 'Could not find %s in %s, but there is only 1 sequence, so ' + 'using that one.', + template_chain_id, + mmcif_object.file_id, + ) + new_template_sequence = list( + mmcif_object.chain_to_seqres.values())[0] + else: + raise QueryToTemplateAlignError( + f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. ' + 'If there are no mmCIF parsing errors, it is possible it was not a ' + 'protein chain.') + + try: + parsed_a3m = parsers.parse_a3m( + aligner.align([old_template_sequence, new_template_sequence])) + old_aligned_template, new_aligned_template = parsed_a3m.sequences + except Exception as e: + raise QueryToTemplateAlignError( + 'Could not align old template %s to template %s (%s_%s). 
Error: %s' + % ( + old_template_sequence, + new_template_sequence, + mmcif_object.file_id, + template_chain_id, + str(e), + )) + + logging.info( + 'Old aligned template: %s\nNew aligned template: %s', + old_aligned_template, + new_aligned_template, + ) + + old_to_new_template_mapping = {} + old_template_index = -1 + new_template_index = -1 + num_same = 0 + for old_template_aa, new_template_aa in zip(old_aligned_template, + new_aligned_template): + if old_template_aa != '-': + old_template_index += 1 + if new_template_aa != '-': + new_template_index += 1 + if old_template_aa != '-' and new_template_aa != '-': + old_to_new_template_mapping[ + old_template_index] = new_template_index + if old_template_aa == new_template_aa: + num_same += 1 + + # Require at least 90 % sequence identity wrt to the shorter of the sequences. + if (float(num_same) + / min(len(old_template_sequence), len(new_template_sequence)) + < # noqa W504 + 0.9): + raise QueryToTemplateAlignError( + 'Insufficient similarity of the sequence in the database: %s to the ' + 'actual sequence in the mmCIF file %s_%s: %s. We require at least ' + '90 %% similarity wrt to the shorter of the sequences. This is not a ' + 'problem unless you think this is a template that should be included.' + % ( + old_template_sequence, + mmcif_object.file_id, + template_chain_id, + new_template_sequence, + )) + + new_query_to_template_mapping = {} + for query_index, old_template_index in old_mapping.items(): + new_query_to_template_mapping[ + query_index] = old_to_new_template_mapping.get( + old_template_index, -1) + + new_template_sequence = new_template_sequence.replace('-', '') + + return new_template_sequence, new_query_to_template_mapping + + +def _check_residue_distances(all_positions: np.ndarray, + all_positions_mask: np.ndarray, + max_ca_ca_distance: float): + """Checks if the distance between unmasked neighbor residues is ok.""" + ca_position = residue_constants.atom_order['CA'] + prev_is_unmasked = False + prev_calpha = None + for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)): + this_is_unmasked = bool(mask[ca_position]) + if this_is_unmasked: + this_calpha = coords[ca_position] + if prev_is_unmasked: + distance = np.linalg.norm(this_calpha - prev_calpha) + if distance > max_ca_ca_distance: + raise CaDistanceError( + 'The distance between residues %d and %d is %f > limit %f.' + % (i, i + 1, distance, max_ca_ca_distance)) + prev_calpha = this_calpha + prev_is_unmasked = this_is_unmasked + + +def _get_atom_positions( + mmcif_object: mmcif.MmcifObject, auth_chain_id: str, + max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]: + """Gets atom positions and mask from a list of Biopython Residues.""" + num_res = len(mmcif_object.chain_to_seqres[auth_chain_id]) + + relevant_chains = [ + c for c in mmcif_object.structure.get_chains() if c.id == auth_chain_id + ] + if len(relevant_chains) != 1: + raise MultipleChainsError( + f'Expected exactly one chain in structure with id {auth_chain_id}.' 
+ ) + chain = relevant_chains[0] + + all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3]) + all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num], + dtype=np.int64) + for res_index in range(num_res): + pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32) + mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32) + res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][ + res_index] + if not res_at_position.is_missing: + res = chain[( + res_at_position.hetflag, + res_at_position.position.residue_number, + res_at_position.position.insertion_code, + )] + for atom in res.get_atoms(): + atom_name = atom.get_name() + x, y, z = atom.get_coord() + if atom_name in residue_constants.atom_order.keys(): + pos[residue_constants.atom_order[atom_name]] = [x, y, z] + mask[residue_constants.atom_order[atom_name]] = 1.0 + elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE': + # Put the coordinates of the selenium atom in the sulphur column. + pos[residue_constants.atom_order['SD']] = [x, y, z] + mask[residue_constants.atom_order['SD']] = 1.0 + + # Fix naming errors in arginine residues where NH2 is incorrectly + # assigned to be closer to CD than NH1. + cd = residue_constants.atom_order['CD'] + nh1 = residue_constants.atom_order['NH1'] + nh2 = residue_constants.atom_order['NH2'] + if (res.get_resname() == 'ARG' + and all(mask[atom_index] for atom_index in (cd, nh1, nh2)) + and (np.linalg.norm(pos[nh1] - pos[cd]) > # noqa W504 + np.linalg.norm(pos[nh2] - pos[cd]))): + pos[nh1], pos[nh2] = pos[nh2].copy(), pos[nh1].copy() + mask[nh1], mask[nh2] = mask[nh2].copy(), mask[nh1].copy() + + all_positions[res_index] = pos + all_positions_mask[res_index] = mask + _check_residue_distances(all_positions, all_positions_mask, + max_ca_ca_distance) + return all_positions, all_positions_mask + + +def _extract_template_features( + mmcif_object: mmcif.MmcifObject, + pdb_id: str, + mapping: Mapping[int, int], + template_sequence: str, + query_sequence: str, + template_chain_id: str, + kalign_binary_path: str, +) -> Tuple[Dict[str, Any], Optional[str]]: + """Parses atom positions in the target structure and aligns with the query. + + Atoms for each residue in the template structure are indexed to coincide + with their corresponding residue in the query sequence, according to the + alignment mapping provided. + + Args: + mmcif_object: mmcif_parsing.MmcifObject representing the template. + pdb_id: PDB code for the template. + mapping: Dictionary mapping indices in the query sequence to indices in + the template sequence. + template_sequence: String describing the amino acid sequence for the + template protein. + query_sequence: String describing the amino acid sequence for the query + protein. + template_chain_id: String ID describing which chain in the structure proto + should be used. + kalign_binary_path: The path to a kalign executable used for template + realignment. + + Returns: + A tuple with: + * A dictionary containing the extra features derived from the template + protein structure. + * A warning message if the hit was realigned to the actual mmCIF sequence. + Otherwise None. + + Raises: + NoChainsError: If the mmcif object doesn't contain any chains. + SequenceNotInTemplateError: If the given chain id / sequence can't + be found in the mmcif object. + QueryToTemplateAlignError: If the actual template in the mmCIF file + can't be aligned to the query. 
+ NoAtomDataInTemplateError: If the mmcif object doesn't contain + atom positions. + TemplateAtomMaskAllZerosError: If the mmcif object doesn't have any + unmasked residues. + """ + if mmcif_object is None or not mmcif_object.chain_to_seqres: + raise NoChainsError('No chains in PDB: %s_%s' % + (pdb_id, template_chain_id)) + + warning = None + try: + seqres, chain_id, mapping_offset = _find_template_in_pdb( + template_chain_id=template_chain_id, + template_sequence=template_sequence, + mmcif_object=mmcif_object, + ) + except SequenceNotInTemplateError: + # If PDB70 contains a different version of the template, we use the sequence + # from the mmcif_object. + chain_id = template_chain_id + warning = ( + f'The exact sequence {template_sequence} was not found in ' + f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.' + ) + logging.warning(warning) + # This throws an exception if it fails to realign the hit. + seqres, mapping = _realign_pdb_template_to_query( + old_template_sequence=template_sequence, + template_chain_id=template_chain_id, + mmcif_object=mmcif_object, + old_mapping=mapping, + kalign_binary_path=kalign_binary_path, + ) + logging.info( + 'Sequence in %s_%s: %s successfully realigned to %s', + pdb_id, + chain_id, + template_sequence, + seqres, + ) + # The template sequence changed. + template_sequence = seqres + # No mapping offset, the query is aligned to the actual sequence. + mapping_offset = 0 + + try: + # Essentially set to infinity - we don't want to reject templates unless + # they're really really bad. + all_atom_positions, all_atom_mask = _get_atom_positions( + mmcif_object, chain_id, max_ca_ca_distance=150.0) + except (CaDistanceError, KeyError) as ex: + raise NoAtomDataInTemplateError('Could not get atom data (%s_%s): %s' % + (pdb_id, chain_id, str(ex))) from ex + + all_atom_positions = np.split(all_atom_positions, + all_atom_positions.shape[0]) + all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0]) + + output_templates_sequence = [] + templates_all_atom_positions = [] + templates_all_atom_masks = [] + + for _ in query_sequence: + # Residues in the query_sequence that are not in the template_sequence: + templates_all_atom_positions.append( + np.zeros((residue_constants.atom_type_num, 3))) + templates_all_atom_masks.append( + np.zeros(residue_constants.atom_type_num)) + output_templates_sequence.append('-') + + for k, v in mapping.items(): + template_index = v + mapping_offset + templates_all_atom_positions[k] = all_atom_positions[template_index][0] + templates_all_atom_masks[k] = all_atom_masks[template_index][0] + output_templates_sequence[k] = template_sequence[v] + + # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O). + if np.sum(templates_all_atom_masks) < 5: + raise TemplateAtomMaskAllZerosError( + 'Template all atom mask was all zeros: %s_%s. 
Residue range: %d-%d' + % ( + pdb_id, + chain_id, + min(mapping.values()) + mapping_offset, + max(mapping.values()) + mapping_offset, + )) + + output_templates_sequence = ''.join(output_templates_sequence) + + templates_aatype = residue_constants.sequence_to_onehot( + output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID) + + return ( + { + 'template_all_atom_positions': + np.array(templates_all_atom_positions), + 'template_all_atom_mask': np.array(templates_all_atom_masks), + 'template_sequence': output_templates_sequence.encode(), + 'template_aatype': np.array(templates_aatype), + 'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(), + }, + warning, + ) + + +def _build_query_to_hit_index_mapping( + hit_query_sequence: str, + hit_sequence: str, + indices_hit: Sequence[int], + indices_query: Sequence[int], + original_query_sequence: str, +) -> Mapping[int, int]: + """Gets mapping from indices in original query sequence to indices in the hit. + + hit_query_sequence and hit_sequence are two aligned sequences containing gap + characters. hit_query_sequence contains only the part of the original query + sequence that matched the hit. When interpreting the indices from the .hhr, we + need to correct for this to recover a mapping from original query sequence to + the hit sequence. + + Args: + hit_query_sequence: The portion of the query sequence that is in the .hhr + hit + hit_sequence: The portion of the hit sequence that is in the .hhr + indices_hit: The indices for each aminoacid relative to the hit sequence + indices_query: The indices for each aminoacid relative to the original query + sequence + original_query_sequence: String describing the original query sequence. + + Returns: + Dictionary with indices in the original query sequence as keys and indices + in the hit sequence as values. + """ + # If the hit is empty (no aligned residues), return empty mapping + if not hit_query_sequence: + return {} + + # Remove gaps and find the offset of hit.query relative to original query. + hhsearch_query_sequence = hit_query_sequence.replace('-', '') + hit_sequence = hit_sequence.replace('-', '') + hhsearch_query_offset = original_query_sequence.find( + hhsearch_query_sequence) + + # Index of -1 used for gap characters. Subtract the min index ignoring gaps. + min_idx = min(x for x in indices_hit if x > -1) + fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit] + + min_idx = min(x for x in indices_query if x > -1) + fixed_indices_query = [ + x - min_idx if x > -1 else -1 for x in indices_query + ] + + # Zip the corrected indices, ignore case where both seqs have gap characters. 
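+    # Editorial illustration (not from the upstream Uni-Fold/AlphaFold source):
+    # if the original query is 'ABCDEF' and the hit covers 'BCD' (offset 1)
+    # with gap-free hit/query indices [0, 1, 2], the loop below produces
+    # {1: 0, 2: 1, 3: 2}, i.e. original-query index -> hit-sequence index.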
+ mapping = {} + for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit): + if q_t != -1 and q_i != -1: + if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len( + original_query_sequence): + continue + mapping[q_i + hhsearch_query_offset] = q_t + + return mapping + + +@dataclasses.dataclass(frozen=True) +class SingleHitResult: + features: Optional[Mapping[str, Any]] + error: Optional[str] + warning: Optional[str] + + +@functools.lru_cache(16, typed=False) +def _read_file(path): + with open(path, 'r') as f: + file_data = f.read() + return file_data + + +def _process_single_hit( + query_sequence: str, + hit: parsers.TemplateHit, + mmcif_dir: str, + max_template_date: datetime.datetime, + release_dates: Mapping[str, datetime.datetime], + obsolete_pdbs: Mapping[str, Optional[str]], + kalign_binary_path: str, + strict_error_check: bool = False, +) -> SingleHitResult: + """Tries to extract template features from a single HHSearch hit.""" + # Fail hard if we can't get the PDB ID and chain name from the hit. + hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) + + # This hit has been removed (obsoleted) from PDB, skip it. + if hit_pdb_code in obsolete_pdbs and obsolete_pdbs[hit_pdb_code] is None: + return SingleHitResult( + features=None, + error=None, + warning=f'Hit {hit_pdb_code} is obsolete.') + + if hit_pdb_code not in release_dates: + if hit_pdb_code in obsolete_pdbs: + hit_pdb_code = obsolete_pdbs[hit_pdb_code] + + # Pass hit_pdb_code since it might have changed due to the pdb being obsolete. + try: + _assess_hhsearch_hit( + hit=hit, + hit_pdb_code=hit_pdb_code, + query_sequence=query_sequence, + release_dates=release_dates, + release_date_cutoff=max_template_date, + ) + except PrefilterError as e: + msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}' + logging.info(msg) + if strict_error_check and isinstance(e, (DateError, DuplicateError)): + # In strict mode we treat some prefilter cases as errors. + return SingleHitResult(features=None, error=msg, warning=None) + + return SingleHitResult(features=None, error=None, warning=None) + + mapping = _build_query_to_hit_index_mapping(hit.query, hit.hit_sequence, + hit.indices_hit, + hit.indices_query, + query_sequence) + + # The mapping is from the query to the actual hit sequence, so we need to + # remove gaps (which regardless have a missing confidence score). + template_sequence = hit.hit_sequence.replace('-', '') + + cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif') + logging.debug( + 'Reading PDB entry from %s. Query: %s, template: %s', + cif_path, + query_sequence, + template_sequence, + ) + # Fail if we can't find the mmCIF file. + cif_string = _read_file(cif_path) + + parsing_result = mmcif.parse(file_id=hit_pdb_code, mmcif_string=cif_string) + + if parsing_result.mmcif_object is not None: + hit_release_date = datetime.datetime.strptime( + parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d') + if hit_release_date > max_template_date: + error = 'Template %s date (%s) > max template date (%s).' 
% ( + hit_pdb_code, + hit_release_date, + max_template_date, + ) + if strict_error_check: + return SingleHitResult( + features=None, error=error, warning=None) + else: + logging.debug(error) + return SingleHitResult(features=None, error=None, warning=None) + + try: + features, realign_warning = _extract_template_features( + mmcif_object=parsing_result.mmcif_object, + pdb_id=hit_pdb_code, + mapping=mapping, + template_sequence=template_sequence, + query_sequence=query_sequence, + template_chain_id=hit_chain_id, + kalign_binary_path=kalign_binary_path, + ) + if hit.sum_probs is None: + features['template_sum_probs'] = [0] + else: + features['template_sum_probs'] = [hit.sum_probs] + + # It is possible there were some errors when parsing the other chains in the + # mmCIF file, but the template features for the chain we want were still + # computed. In such case the mmCIF parsing errors are not relevant. + return SingleHitResult( + features=features, error=None, warning=realign_warning) + except ( + NoChainsError, + NoAtomDataInTemplateError, + TemplateAtomMaskAllZerosError, + ) as e: + # These 3 errors indicate missing mmCIF experimental data rather than a + # problem with the template search, so turn them into warnings. + warning = ( + '%s_%s (sum_probs: %s, rank: %s): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % ( + hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors, + )) + if strict_error_check: + return SingleHitResult(features=None, error=warning, warning=None) + else: + return SingleHitResult(features=None, error=None, warning=warning) + except Error as e: + error = ( + '%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' + '%s, mmCIF parsing errors: %s' % ( + hit_pdb_code, + hit_chain_id, + hit.sum_probs, + hit.index, + str(e), + parsing_result.errors, + )) + return SingleHitResult(features=None, error=error, warning=None) + + +@dataclasses.dataclass(frozen=True) +class TemplateSearchResult: + features: Mapping[str, Any] + errors: Sequence[str] + warnings: Sequence[str] + + +class TemplateHitFeaturizer(abc.ABC): + """An abstract base class for turning template hits to template features.""" + + def __init__( + self, + mmcif_dir: str, + max_template_date: str, + max_hits: int, + kalign_binary_path: str, + release_dates_path: Optional[str], + obsolete_pdbs_path: Optional[str], + strict_error_check: bool = False, + ): + """Initializes the Template Search. + + Args: + mmcif_dir: Path to a directory with mmCIF structures. Once a template ID + is found by HHSearch, this directory is used to retrieve the template + data. + max_template_date: The maximum date permitted for template structures. No + template with date higher than this date will be returned. In ISO8601 + date format, YYYY-MM-DD. + max_hits: The maximum number of templates that will be returned. + kalign_binary_path: The path to a kalign executable used for template + realignment. + release_dates_path: An optional path to a file with a mapping from PDB IDs + to their release dates. Thanks to this we don't have to redundantly + parse mmCIF files to get that information. + obsolete_pdbs_path: An optional path to a file containing a mapping from + obsolete PDB IDs to the PDB IDs of their replacements. + strict_error_check: If True, then the following will be treated as errors: + * If any template date is after the max_template_date. + * If any template has identical PDB ID to the query. + * If any template is a duplicate of the query. 
+ * Any feature computation errors. + """ + self._mmcif_dir = mmcif_dir + if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')): + logging.error('Could not find CIFs in %s', self._mmcif_dir) + raise ValueError(f'Could not find CIFs in {self._mmcif_dir}') + + try: + self._max_template_date = datetime.datetime.strptime( + max_template_date, '%Y-%m-%d') + except ValueError: + raise ValueError( + 'max_template_date must be set and have format YYYY-MM-DD.') + self._max_hits = max_hits + self._kalign_binary_path = kalign_binary_path + self._strict_error_check = strict_error_check + + if release_dates_path: + logging.info('Using precomputed release dates %s.', + release_dates_path) + self._release_dates = _parse_release_dates(release_dates_path) + else: + self._release_dates = {} + + if obsolete_pdbs_path: + logging.info('Using precomputed obsolete pdbs %s.', + obsolete_pdbs_path) + self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path) + else: + self._obsolete_pdbs = {} + + @abc.abstractmethod + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence.""" + + +class HhsearchHitFeaturizer(TemplateHitFeaturizer): + """A class for turning a3m hits from hhsearch to template features.""" + + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_sequence) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + num_hits = 0 + errors = [] + warnings = [] + + for hit in sorted(hits, key=lambda x: x.sum_probs, reverse=True): + # We got all the templates we wanted, stop processing hits. + if num_hits >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=self._max_template_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path, + ) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.info( + 'Skipped invalid hit %s, error: %s, warning: %s', + hit.name, + result.error, + result.warning, + ) + else: + # Increment the hit counter, since we got features out of this hit. + num_hits += 1 + for k in template_features: + template_features[k].append(result.features[k]) + + for name in template_features: + if num_hits > 0: + template_features[name] = np.stack( + template_features[name], + axis=0).astype(TEMPLATE_FEATURES[name]) + else: + # Make sure the feature has correct dtype even if empty. 
+ template_features[name] = np.array( + [], dtype=TEMPLATE_FEATURES[name]) + + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings) + + +class HmmsearchHitFeaturizer(TemplateHitFeaturizer): + """A class for turning a3m hits from hmmsearch to template features.""" + + def get_templates( + self, query_sequence: str, + hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: + """Computes the templates for given query sequence (more details above).""" + logging.info('Searching for template for: %s', query_sequence) + + template_features = {} + for template_feature_name in TEMPLATE_FEATURES: + template_features[template_feature_name] = [] + + already_seen = set() + errors = [] + warnings = [] + + if not hits or hits[0].sum_probs is None: + sorted_hits = hits + else: + sorted_hits = sorted(hits, key=lambda x: x.sum_probs, reverse=True) + + for hit in sorted_hits: + # We got all the templates we wanted, stop processing hits. + if len(already_seen) >= self._max_hits: + break + + result = _process_single_hit( + query_sequence=query_sequence, + hit=hit, + mmcif_dir=self._mmcif_dir, + max_template_date=self._max_template_date, + release_dates=self._release_dates, + obsolete_pdbs=self._obsolete_pdbs, + strict_error_check=self._strict_error_check, + kalign_binary_path=self._kalign_binary_path, + ) + + if result.error: + errors.append(result.error) + + # There could be an error even if there are some results, e.g. thrown by + # other unparsable chains in the same mmCIF file. + if result.warning: + warnings.append(result.warning) + + if result.features is None: + logging.debug( + 'Skipped invalid hit %s, error: %s, warning: %s', + hit.name, + result.error, + result.warning, + ) + else: + already_seen_key = result.features['template_sequence'] + if already_seen_key in already_seen: + continue + # Increment the hit counter, since we got features out of this hit. + already_seen.add(already_seen_key) + for k in template_features: + template_features[k].append(result.features[k]) + + if already_seen: + for name in template_features: + template_features[name] = np.stack( + template_features[name], + axis=0).astype(TEMPLATE_FEATURES[name]) + else: + num_res = len(query_sequence) + # Construct a default template with all zeros. + template_features = { + 'template_aatype': + np.zeros( + (1, num_res, len( + residue_constants.restypes_with_x_and_gap)), + np.float32, + ), + 'template_all_atom_mask': + np.zeros((1, num_res, residue_constants.atom_type_num), + np.float32), + 'template_all_atom_positions': + np.zeros((1, num_res, residue_constants.atom_type_num, 3), + np.float32), + 'template_domain_names': + np.array([''.encode()], dtype=np.object), + 'template_sequence': + np.array([''.encode()], dtype=np.object), + 'template_sum_probs': + np.array([0], dtype=np.float32), + } + return TemplateSearchResult( + features=template_features, errors=errors, warnings=warnings) diff --git a/modelscope/models/science/unifold/msa/tools/__init__.py b/modelscope/models/science/unifold/msa/tools/__init__.py new file mode 100644 index 00000000..903d0979 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python wrappers for third party tools.""" diff --git a/modelscope/models/science/unifold/msa/tools/hhblits.py b/modelscope/models/science/unifold/msa/tools/hhblits.py new file mode 100644 index 00000000..ee442e39 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hhblits.py @@ -0,0 +1,170 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run HHblits from Python.""" + +import glob +import os +import subprocess +from typing import Any, List, Mapping, Optional, Sequence + +from absl import logging + +from . import utils + +_HHBLITS_DEFAULT_P = 20 +_HHBLITS_DEFAULT_Z = 500 + + +class HHBlits: + """Python wrapper of the HHblits binary.""" + + def __init__( + self, + *, + binary_path: str, + databases: Sequence[str], + n_cpu: int = 4, + n_iter: int = 3, + e_value: float = 0.001, + maxseq: int = 1_000_000, + realign_max: int = 100_000, + maxfilt: int = 100_000, + min_prefilter_hits: int = 1000, + all_seqs: bool = False, + alt: Optional[int] = None, + p: int = _HHBLITS_DEFAULT_P, + z: int = _HHBLITS_DEFAULT_Z, + ): + """Initializes the Python HHblits wrapper. + + Args: + binary_path: The path to the HHblits executable. + databases: A sequence of HHblits database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + n_cpu: The number of CPUs to give HHblits. + n_iter: The number of HHblits iterations. + e_value: The E-value, see HHblits docs for more details. + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. + maxfilt: Max number of hits allowed to pass the 2nd prefilter. + HHblits default: 20000. + min_prefilter_hits: Min number of hits to pass prefilter. + HHblits default: 100. + all_seqs: Return all sequences in the MSA / Do not filter the result MSA. + HHblits default: False. + alt: Show up to this many alternative alignments. + p: Minimum Prob for a hit to be included in the output hhr file. + HHblits default: 20. + z: Hard cap on number of hits reported in the hhr file. + HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. + + Raises: + RuntimeError: If HHblits binary not found within the path. 
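+
+        Example (editorial sketch, not from the upstream source; the binary
+        and database paths are hypothetical placeholders):
+
+            runner = HHBlits(
+                binary_path='/usr/bin/hhblits',
+                databases=['/path/to/uniclust30/uniclust30_2018_08'])
+            results = runner.query('/tmp/query.fasta')
+            a3m_string = results[0]['a3m']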
+ """ + self.binary_path = binary_path + self.databases = databases + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + logging.error('Could not find HHBlits database %s', + database_path) + raise ValueError( + f'Could not find HHBlits database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.maxseq = maxseq + self.realign_max = realign_max + self.maxfilt = maxfilt + self.min_prefilter_hits = min_prefilter_hits + self.all_seqs = all_seqs + self.alt = alt + self.p = p + self.z = z + + def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]: + """Queries the database using HHblits.""" + with utils.tmpdir_manager() as query_tmp_dir: + a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [ + self.binary_path, + '-i', + input_fasta_path, + '-cpu', + str(self.n_cpu), + '-oa3m', + a3m_path, + '-o', + '/dev/null', + '-n', + str(self.n_iter), + '-e', + str(self.e_value), + '-maxseq', + str(self.maxseq), + '-realign_max', + str(self.realign_max), + '-maxfilt', + str(self.maxfilt), + '-min_prefilter_hits', + str(self.min_prefilter_hits), + ] + if self.all_seqs: + cmd += ['-all'] + if self.alt: + cmd += ['-alt', str(self.alt)] + if self.p != _HHBLITS_DEFAULT_P: + cmd += ['-p', str(self.p)] + if self.z != _HHBLITS_DEFAULT_Z: + cmd += ['-Z', str(self.z)] + cmd += db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('HHblits query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Logs have a 15k character limit, so log HHblits error line by line. + logging.error('HHblits failed. HHblits stderr begin:') + for error_line in stderr.decode('utf-8').splitlines(): + if error_line.strip(): + logging.error(error_line.strip()) + logging.error('HHblits stderr end') + raise RuntimeError( + 'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) + + with open(a3m_path) as f: + a3m = f.read() + + raw_output = dict( + a3m=a3m, + output=stdout, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value, + ) + return [raw_output] diff --git a/modelscope/models/science/unifold/msa/tools/hhsearch.py b/modelscope/models/science/unifold/msa/tools/hhsearch.py new file mode 100644 index 00000000..ac7f3b55 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hhsearch.py @@ -0,0 +1,111 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run HHsearch from Python.""" + +import glob +import os +import subprocess +from typing import Sequence + +from absl import logging + +from modelscope.models.science.unifold.msa import parsers +from . 
import utils + + +class HHSearch: + """Python wrapper of the HHsearch binary.""" + + def __init__(self, + *, + binary_path: str, + databases: Sequence[str], + maxseq: int = 1_000_000): + """Initializes the Python HHsearch wrapper. + + Args: + binary_path: The path to the HHsearch executable. + databases: A sequence of HHsearch database paths. This should be the + common prefix for the database files (i.e. up to but not including + _hhm.ffindex etc.) + maxseq: The maximum number of rows in an input alignment. Note that this + parameter is only supported in HHBlits version 3.1 and higher. + + Raises: + RuntimeError: If HHsearch binary not found within the path. + """ + self.binary_path = binary_path + self.databases = databases + self.maxseq = maxseq + + for database_path in self.databases: + if not glob.glob(database_path + '_*'): + logging.error('Could not find HHsearch database %s', + database_path) + raise ValueError( + f'Could not find HHsearch database {database_path}') + + @property + def output_format(self) -> str: + return 'hhr' + + @property + def input_format(self) -> str: + return 'a3m' + + def query(self, a3m: str) -> str: + """Queries the database using HHsearch using a given a3m.""" + with utils.tmpdir_manager() as query_tmp_dir: + input_path = os.path.join(query_tmp_dir, 'query.a3m') + hhr_path = os.path.join(query_tmp_dir, 'output.hhr') + with open(input_path, 'w') as f: + f.write(a3m) + + db_cmd = [] + for db_path in self.databases: + db_cmd.append('-d') + db_cmd.append(db_path) + cmd = [ + self.binary_path, + '-i', + input_path, + '-o', + hhr_path, + '-maxseq', + str(self.maxseq), + ] + db_cmd + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing('HHsearch query'): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + # Stderr is truncated to prevent proto size errors in Beam. + raise RuntimeError( + 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) + + with open(hhr_path) as f: + hhr = f.read() + return hhr + + def get_template_hits( + self, output_string: str, + input_sequence: str) -> Sequence[parsers.TemplateHit]: + """Gets parsed template hits from the raw string output by the tool.""" + del input_sequence # Used by hmmseach but not needed for hhsearch. + return parsers.parse_hhr(output_string) diff --git a/modelscope/models/science/unifold/msa/tools/hmmbuild.py b/modelscope/models/science/unifold/msa/tools/hmmbuild.py new file mode 100644 index 00000000..84f205d6 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hmmbuild.py @@ -0,0 +1,143 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" + +import os +import re +import subprocess + +from absl import logging + +from . 
import utils + + +class Hmmbuild(object): + """Python wrapper of the hmmbuild binary.""" + + def __init__(self, *, binary_path: str, singlemx: bool = False): + """Initializes the Python hmmbuild wrapper. + + Args: + binary_path: The path to the hmmbuild executable. + singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to + just use a common substitution score matrix. + + Raises: + RuntimeError: If hmmbuild binary not found within the path. + """ + self.binary_path = binary_path + self.singlemx = singlemx + + def build_profile_from_sto(self, + sto: str, + model_construction='fast') -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + sto: A string with the aligned sequences in the Stockholm format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + return self._build_profile(sto, model_construction=model_construction) + + def build_profile_from_a3m(self, a3m: str) -> str: + """Builds a HHM for the aligned sequences given as an A3M string. + + Args: + a3m: A string with the aligned sequences in the A3M format. + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + """ + lines = [] + for line in a3m.splitlines(): + if not line.startswith('>'): + line = re.sub('[a-z]+', '', line) # Remove inserted residues. + lines.append(line + '\n') + msa = ''.join(lines) + return self._build_profile(msa, model_construction='fast') + + def _build_profile(self, + msa: str, + model_construction: str = 'fast') -> str: + """Builds a HMM for the aligned sequences given as an MSA string. + + Args: + msa: A string with the aligned sequences, in A3M or STO format. + model_construction: Whether to use reference annotation in the msa to + determine consensus columns ('hand') or default ('fast'). + + Returns: + A string with the profile in the HMM format. + + Raises: + RuntimeError: If hmmbuild fails. + ValueError: If unspecified arguments are provided. 
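+
+        Note (editorial): 'unspecified arguments' refers to a model_construction
+        value other than 'hand' or 'fast'; any other value raises ValueError.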
+ """ + if model_construction not in {'hand', 'fast'}: + raise ValueError( + f'Invalid model_construction {model_construction} - only' + 'hand and fast supported.') + + with utils.tmpdir_manager() as query_tmp_dir: + input_query = os.path.join(query_tmp_dir, 'query.msa') + output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') + + with open(input_query, 'w') as f: + f.write(msa) + + cmd = [self.binary_path] + # If adding flags, we have to do so before the output and input: + + if model_construction == 'hand': + cmd.append(f'--{model_construction}') + if self.singlemx: + cmd.append('--singlemx') + cmd.extend([ + '--amino', + output_hmm_path, + input_query, + ]) + + logging.info('Launching subprocess %s', cmd) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('hmmbuild query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info( + 'hmmbuild stdout:\n%s\n\nstderr:\n%s\n', + stdout.decode('utf-8'), + stderr.decode('utf-8'), + ) + + if retcode: + raise RuntimeError( + 'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_hmm_path, encoding='utf-8') as f: + hmm = f.read() + + return hmm diff --git a/modelscope/models/science/unifold/msa/tools/hmmsearch.py b/modelscope/models/science/unifold/msa/tools/hmmsearch.py new file mode 100644 index 00000000..445970ca --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/hmmsearch.py @@ -0,0 +1,146 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for hmmsearch - search profile against a sequence db.""" + +import os +import subprocess +from typing import Optional, Sequence + +from absl import logging + +from modelscope.models.science.unifold.msa import parsers +from . import hmmbuild, utils + + +class Hmmsearch(object): + """Python wrapper of the hmmsearch binary.""" + + def __init__( + self, + *, + binary_path: str, + hmmbuild_binary_path: str, + database_path: str, + flags: Optional[Sequence[str]] = None, + ): + """Initializes the Python hmmsearch wrapper. + + Args: + binary_path: The path to the hmmsearch executable. + hmmbuild_binary_path: The path to the hmmbuild executable. Used to build + an hmm from an input a3m. + database_path: The path to the hmmsearch database (FASTA format). + flags: List of flags to be used by hmmsearch. + + Raises: + RuntimeError: If hmmsearch binary not found within the path. + """ + self.binary_path = binary_path + self.hmmbuild_runner = hmmbuild.Hmmbuild( + binary_path=hmmbuild_binary_path) + self.database_path = database_path + if flags is None: + # Default hmmsearch run settings. 
+ flags = [ + '--F1', + '0.1', + '--F2', + '0.1', + '--F3', + '0.1', + '--incE', + '100', + '-E', + '100', + '--domE', + '100', + '--incdomE', + '100', + ] + self.flags = flags + + if not os.path.exists(self.database_path): + logging.error('Could not find hmmsearch database %s', + database_path) + raise ValueError( + f'Could not find hmmsearch database {database_path}') + + @property + def output_format(self) -> str: + return 'sto' + + @property + def input_format(self) -> str: + return 'sto' + + def query(self, msa_sto: str) -> str: + """Queries the database using hmmsearch using a given stockholm msa.""" + hmm = self.hmmbuild_runner.build_profile_from_sto( + msa_sto, model_construction='hand') + return self.query_with_hmm(hmm) + + def query_with_hmm(self, hmm: str) -> str: + """Queries the database using hmmsearch using a given hmm.""" + with utils.tmpdir_manager() as query_tmp_dir: + hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') + out_path = os.path.join(query_tmp_dir, 'output.sto') + with open(hmm_input_path, 'w') as f: + f.write(hmm) + + cmd = [ + self.binary_path, + '--noali', # Don't include the alignment in stdout. + '--cpu', + '8', + ] + # If adding flags, we have to do so before the output and input: + if self.flags: + cmd.extend(self.flags) + cmd.extend([ + '-A', + out_path, + hmm_input_path, + self.database_path, + ]) + + logging.info('Launching sub-process %s', cmd) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing( + f'hmmsearch ({os.path.basename(self.database_path)}) query' + ): + stdout, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError( + 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(out_path) as f: + out_msa = f.read() + + return out_msa + + def get_template_hits( + self, output_string: str, + input_sequence: str) -> Sequence[parsers.TemplateHit]: + """Gets parsed template hits from the raw string output by the tool.""" + a3m_string = parsers.convert_stockholm_to_a3m( + output_string, remove_first_row_gaps=False) + template_hits = parsers.parse_hmmsearch_a3m( + query_sequence=input_sequence, + a3m_string=a3m_string, + skip_first=False) + return template_hits diff --git a/modelscope/models/science/unifold/msa/tools/jackhmmer.py b/modelscope/models/science/unifold/msa/tools/jackhmmer.py new file mode 100644 index 00000000..3e29eec9 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/jackhmmer.py @@ -0,0 +1,224 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Library to run Jackhmmer from Python.""" + +import glob +import os +import subprocess +from concurrent import futures +from typing import Any, Callable, Mapping, Optional, Sequence +from urllib import request + +from absl import logging + +from . 
import utils + + +class Jackhmmer: + """Python wrapper of the Jackhmmer binary.""" + + def __init__( + self, + *, + binary_path: str, + database_path: str, + n_cpu: int = 8, + n_iter: int = 1, + e_value: float = 0.0001, + z_value: Optional[int] = None, + get_tblout: bool = False, + filter_f1: float = 0.0005, + filter_f2: float = 0.00005, + filter_f3: float = 0.0000005, + incdom_e: Optional[float] = None, + dom_e: Optional[float] = None, + num_streamed_chunks: Optional[int] = None, + streaming_callback: Optional[Callable[[int], None]] = None, + ): + """Initializes the Python Jackhmmer wrapper. + + Args: + binary_path: The path to the jackhmmer executable. + database_path: The path to the jackhmmer database (FASTA format). + n_cpu: The number of CPUs to give Jackhmmer. + n_iter: The number of Jackhmmer iterations. + e_value: The E-value, see Jackhmmer docs for more details. + z_value: The Z-value, see Jackhmmer docs for more details. + get_tblout: Whether to save tblout string. + filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. + filter_f2: Viterbi pre-filter, set to >1.0 to turn off. + filter_f3: Forward pre-filter, set to >1.0 to turn off. + incdom_e: Domain e-value criteria for inclusion of domains in MSA/next + round. + dom_e: Domain e-value criteria for inclusion in tblout. + num_streamed_chunks: Number of database chunks to stream over. + streaming_callback: Callback function run after each chunk iteration with + the iteration number as argument. + """ + self.binary_path = binary_path + self.database_path = database_path + self.num_streamed_chunks = num_streamed_chunks + + if not os.path.exists( + self.database_path) and num_streamed_chunks is None: + logging.error('Could not find Jackhmmer database %s', + database_path) + raise ValueError( + f'Could not find Jackhmmer database {database_path}') + + self.n_cpu = n_cpu + self.n_iter = n_iter + self.e_value = e_value + self.z_value = z_value + self.filter_f1 = filter_f1 + self.filter_f2 = filter_f2 + self.filter_f3 = filter_f3 + self.incdom_e = incdom_e + self.dom_e = dom_e + self.get_tblout = get_tblout + self.streaming_callback = streaming_callback + + def _query_chunk(self, input_fasta_path: str, + database_path: str) -> Mapping[str, Any]: + """Queries the database chunk using Jackhmmer.""" + with utils.tmpdir_manager() as query_tmp_dir: + sto_path = os.path.join(query_tmp_dir, 'output.sto') + + # The F1/F2/F3 are the expected proportion to pass each of the filtering + # stages (which get progressively more expensive), reducing these + # speeds up the pipeline at the expensive of sensitivity. They are + # currently set very low to make querying Mgnify run in a reasonable + # amount of time. + cmd_flags = [ + # Don't pollute stdout with Jackhmmer output. + '-o', + '/dev/null', + '-A', + sto_path, + '--noali', + '--F1', + str(self.filter_f1), + '--F2', + str(self.filter_f2), + '--F3', + str(self.filter_f3), + '--incE', + str(self.e_value), + # Report only sequences with E-values <= x in per-sequence output. 
+ '-E', + str(self.e_value), + '--cpu', + str(self.n_cpu), + '-N', + str(self.n_iter), + ] + if self.get_tblout: + tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') + cmd_flags.extend(['--tblout', tblout_path]) + + if self.z_value: + cmd_flags.extend(['-Z', str(self.z_value)]) + + if self.dom_e is not None: + cmd_flags.extend(['--domE', str(self.dom_e)]) + + if self.incdom_e is not None: + cmd_flags.extend(['--incdomE', str(self.incdom_e)]) + + cmd = [self.binary_path + ] + cmd_flags + [input_fasta_path, database_path] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + with utils.timing( + f'Jackhmmer ({os.path.basename(database_path)}) query'): + _, stderr = process.communicate() + retcode = process.wait() + + if retcode: + raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' + % stderr.decode('utf-8')) + + # Get e-values for each target name + tbl = '' + if self.get_tblout: + with open(tblout_path) as f: + tbl = f.read() + + with open(sto_path) as f: + sto = f.read() + + raw_output = dict( + sto=sto, + tbl=tbl, + stderr=stderr, + n_iter=self.n_iter, + e_value=self.e_value) + + return raw_output + + def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: + """Queries the database using Jackhmmer.""" + if self.num_streamed_chunks is None: + return [self._query_chunk(input_fasta_path, self.database_path)] + + db_basename = os.path.basename(self.database_path) + + def db_remote_chunk(db_idx): + return f'{self.database_path}.{db_idx}' + + def db_local_chunk(db_idx): + return f'/tmp/ramdisk/{db_basename}.{db_idx}' + + # db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' + # db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' + + # Remove existing files to prevent OOM + for f in glob.glob(db_local_chunk('[0-9]*')): + try: + os.remove(f) + except OSError: + print(f'OSError while deleting {f}') + + # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk + with futures.ThreadPoolExecutor(max_workers=2) as executor: + chunked_output = [] + for i in range(1, self.num_streamed_chunks + 1): + # Copy the chunk locally + if i == 1: + future = executor.submit(request.urlretrieve, + db_remote_chunk(i), + db_local_chunk(i)) + if i < self.num_streamed_chunks: + next_future = executor.submit( + request.urlretrieve, + db_remote_chunk(i + 1), + db_local_chunk(i + 1), + ) + + # Run Jackhmmer with the chunk + future.result() + chunked_output.append( + self._query_chunk(input_fasta_path, db_local_chunk(i))) + + # Remove the local copy of the chunk + os.remove(db_local_chunk(i)) + # Do not set next_future for the last chunk so that this works even for + # databases with only 1 chunk. + if i < self.num_streamed_chunks: + future = next_future + if self.streaming_callback: + self.streaming_callback(i) + return chunked_output diff --git a/modelscope/models/science/unifold/msa/tools/kalign.py b/modelscope/models/science/unifold/msa/tools/kalign.py new file mode 100644 index 00000000..1ea997fa --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/kalign.py @@ -0,0 +1,110 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Python wrapper for Kalign.""" +import os +import subprocess +from typing import Sequence + +from absl import logging + +from . import utils + + +def _to_a3m(sequences: Sequence[str]) -> str: + """Converts sequences to an a3m file.""" + names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] + a3m = [] + for sequence, name in zip(sequences, names): + a3m.append('>' + name + '\n') + a3m.append(sequence + '\n') + return ''.join(a3m) + + +class Kalign: + """Python wrapper of the Kalign binary.""" + + def __init__(self, *, binary_path: str): + """Initializes the Python Kalign wrapper. + + Args: + binary_path: The path to the Kalign binary. + + Raises: + RuntimeError: If Kalign binary not found within the path. + """ + self.binary_path = binary_path + + def align(self, sequences: Sequence[str]) -> str: + """Aligns the sequences and returns the alignment in A3M string. + + Args: + sequences: A list of query sequence strings. The sequences have to be at + least 6 residues long (Kalign requires this). Note that the order in + which you give the sequences might alter the output slightly as + different alignment tree might get constructed. + + Returns: + A string with the alignment in a3m format. + + Raises: + RuntimeError: If Kalign fails. + ValueError: If any of the sequences is less than 6 residues long. + """ + logging.info('Aligning %d sequences', len(sequences)) + + for s in sequences: + if len(s) < 6: + raise ValueError( + 'Kalign requires all sequences to be at least 6 ' + 'residues long. Got %s (%d residues).' % (s, len(s))) + + with utils.tmpdir_manager() as query_tmp_dir: + input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') + output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') + + with open(input_fasta_path, 'w') as f: + f.write(_to_a3m(sequences)) + + cmd = [ + self.binary_path, + '-i', + input_fasta_path, + '-o', + output_a3m_path, + '-format', + 'fasta', + ] + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + + with utils.timing('Kalign query'): + stdout, stderr = process.communicate() + retcode = process.wait() + logging.info( + 'Kalign stdout:\n%s\n\nstderr:\n%s\n', + stdout.decode('utf-8'), + stderr.decode('utf-8'), + ) + + if retcode: + raise RuntimeError( + 'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % + (stdout.decode('utf-8'), stderr.decode('utf-8'))) + + with open(output_a3m_path) as f: + a3m = f.read() + + return a3m diff --git a/modelscope/models/science/unifold/msa/tools/utils.py b/modelscope/models/science/unifold/msa/tools/utils.py new file mode 100644 index 00000000..1c2af936 --- /dev/null +++ b/modelscope/models/science/unifold/msa/tools/utils.py @@ -0,0 +1,40 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common utilities for data pipeline tools.""" +import contextlib +import shutil +import tempfile +import time +from typing import Optional + +from absl import logging + + +@contextlib.contextmanager +def tmpdir_manager(base_dir: Optional[str] = None): + """Context manager that deletes a temporary directory on exit.""" + tmpdir = tempfile.mkdtemp(dir=base_dir) + try: + yield tmpdir + finally: + shutil.rmtree(tmpdir, ignore_errors=True) + + +@contextlib.contextmanager +def timing(msg: str): + logging.info('Started %s', msg) + tic = time.time() + yield + toc = time.time() + logging.info('Finished %s in %.3f seconds', msg, toc - tic) diff --git a/modelscope/models/science/unifold/msa/utils.py b/modelscope/models/science/unifold/msa/utils.py new file mode 100644 index 00000000..50e380d4 --- /dev/null +++ b/modelscope/models/science/unifold/msa/utils.py @@ -0,0 +1,89 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import os +from typing import Mapping, Sequence + +import json +from absl import logging + +from modelscope.models.science.unifold.data import protein + + +def get_chain_id_map( + sequences: Sequence[str], + descriptions: Sequence[str], +): + """ + Makes a mapping from PDB-format chain ID to sequence and description, + and parses the order of multi-chains + """ + unique_seqs = [] + for seq in sequences: + if seq not in unique_seqs: + unique_seqs.append(seq) + + chain_id_map = { + chain_id: { + 'descriptions': [], + 'sequence': seq + } + for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs) + } + chain_order = [] + + for seq, des in zip(sequences, descriptions): + chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)] + chain_id_map[chain_id]['descriptions'].append(des) + chain_order.append(chain_id) + + return chain_id_map, chain_order + + +def divide_multi_chains( + fasta_name: str, + output_dir_base: str, + sequences: Sequence[str], + descriptions: Sequence[str], +): + """ + Divides the multi-chains fasta into several single fasta files and + records multi-chains mapping information. + """ + if len(sequences) != len(descriptions): + raise ValueError('sequences and descriptions must have equal length. ' + f'Got {len(sequences)} != {len(descriptions)}.') + if len(sequences) > protein.PDB_MAX_CHAINS: + raise ValueError( + 'Cannot process more chains than the PDB format supports. 
' + f'Got {len(sequences)} chains.') + + chain_id_map, chain_order = get_chain_id_map(sequences, descriptions) + + output_dir = os.path.join(output_dir_base, fasta_name) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json') + with open(chain_id_map_path, 'w') as f: + json.dump(chain_id_map, f, indent=4, sort_keys=True) + + chain_order_path = os.path.join(output_dir, 'chains.txt') + with open(chain_order_path, 'w') as f: + f.write(' '.join(chain_order)) + + logging.info('Mapping multi-chains fasta with chain order: %s', + ' '.join(chain_order)) + + temp_names = [] + temp_paths = [] + for chain_id in chain_id_map.keys(): + temp_name = fasta_name + '_{}'.format(chain_id) + temp_path = os.path.join(output_dir, temp_name + '.fasta') + des = 'chain_{}'.format(chain_id) + seq = chain_id_map[chain_id]['sequence'] + with open(temp_path, 'w') as f: + f.write('>' + des + '\n' + seq) + temp_names.append(temp_name) + temp_paths.append(temp_path) + return temp_names, temp_paths diff --git a/modelscope/pipelines/science/__init__.py b/modelscope/pipelines/science/__init__.py new file mode 100644 index 00000000..1f81809b --- /dev/null +++ b/modelscope/pipelines/science/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .protein_structure_pipeline import ProteinStructurePipeline + +else: + _import_structure = { + 'protein_structure_pipeline': ['ProteinStructurePipeline'] + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/pipelines/science/protein_structure_pipeline.py b/modelscope/pipelines/science/protein_structure_pipeline.py new file mode 100644 index 00000000..3dc51c72 --- /dev/null +++ b/modelscope/pipelines/science/protein_structure_pipeline.py @@ -0,0 +1,215 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
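+# Editorial note: this module registers ProteinStructurePipeline for
+# Tasks.protein_structure and drives Uni-Fold inference: per-target features
+# are loaded, an inference chunk size is chosen from the sequence length, and
+# each prediction is written out as a PDB string with pLDDT b-factors.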
+ +import os +import time +from typing import Any, Dict, List, Optional, Union + +import json +import numpy as np +import torch +from unicore.utils import tensor_tree_map + +from modelscope.metainfo import Pipelines +from modelscope.models.base import Model +from modelscope.models.science.unifold.config import model_config +from modelscope.models.science.unifold.data import protein, residue_constants +from modelscope.models.science.unifold.dataset import (UnifoldDataset, + load_and_process) +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Pipeline, Tensor +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import Preprocessor, build_preprocessor +from modelscope.utils.constant import Fields, Frameworks, Tasks +from modelscope.utils.device import device_placement +from modelscope.utils.hub import read_config +from modelscope.utils.logger import get_logger + +logger = get_logger() + +__all__ = ['ProteinStructurePipeline'] + + +def automatic_chunk_size(seq_len): + if seq_len < 512: + chunk_size = 256 + elif seq_len < 1024: + chunk_size = 128 + elif seq_len < 2048: + chunk_size = 32 + elif seq_len < 3072: + chunk_size = 16 + else: + chunk_size = 1 + return chunk_size + + +def load_feature_for_one_target( + config, + data_folder, + seed=0, + is_multimer=False, + use_uniprot=False, + symmetry_group=None, +): + if not is_multimer: + uniprot_msa_dir = None + sequence_ids = ['A'] + if use_uniprot: + uniprot_msa_dir = data_folder + + else: + uniprot_msa_dir = data_folder + sequence_ids = open(os.path.join(data_folder, + 'chains.txt')).readline().split() + + if symmetry_group is None: + batch, _ = load_and_process( + config=config.data, + mode='predict', + seed=seed, + batch_idx=None, + data_idx=0, + is_distillation=False, + sequence_ids=sequence_ids, + monomer_feature_dir=data_folder, + uniprot_msa_dir=uniprot_msa_dir, + ) + else: + raise NotImplementedError + batch = UnifoldDataset.collater([batch]) + return batch + + +@PIPELINES.register_module( + Tasks.protein_structure, module_name=Pipelines.protein_structure) +class ProteinStructurePipeline(Pipeline): + + def __init__(self, + model: Union[Model, str], + preprocessor: Optional[Preprocessor] = None, + **kwargs): + """Use `model` and `preprocessor` to create a protein structure pipeline for prediction. + + Args: + model (str or Model): Supply either a local model dir which supported the protein structure task, + or a model id from the model hub, or a torch model instance. + preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for + the model if supplied. 
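+            kwargs: Extra keyword arguments, forwarded to the base Pipeline constructor.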
+ + Example: + >>> from modelscope.pipelines import pipeline + >>> pipeline_ins = pipeline(task='protein-structure', + >>> model='DPTech/uni-fold-monomer') + >>> protein = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + >>> print(pipeline_ins(protein)) + + """ + import copy + model_path = copy.deepcopy(model) if isinstance(model, str) else None + cfg = read_config(model_path) # only model is str + self.cfg = cfg + self.config = model_config( + cfg['pipeline']['model_name']) # alphafold config + model = model if isinstance( + model, Model) else Model.from_pretrained(model_path) + self.postprocessor = cfg.pop('postprocessor', None) + if preprocessor is None: + preprocessor_cfg = cfg.preprocessor + preprocessor = build_preprocessor(preprocessor_cfg, Fields.science) + model.eval() + model.model.inference_mode() + model.model_dir = model_path + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + + def _sanitize_parameters(self, **pipeline_parameters): + return pipeline_parameters, pipeline_parameters, pipeline_parameters + + def _process_single(self, input, *args, **kwargs) -> Dict[str, Any]: + preprocess_params = kwargs.get('preprocess_params', {}) + forward_params = kwargs.get('forward_params', {}) + postprocess_params = kwargs.get('postprocess_params', {}) + out = self.preprocess(input, **preprocess_params) + with device_placement(self.framework, self.device_name): + with torch.no_grad(): + out = self.forward(out, **forward_params) + + out = self.postprocess(out, **postprocess_params) + return out + + def forward(self, inputs: Dict[str, Any], + **forward_params) -> Dict[str, Any]: + plddts = {} + ptms = {} + + output_dir = os.path.join(self.preprocessor.output_dir_base, + inputs['target_id']) + + pdbs = [] + for seed in range(self.cfg['pipeline']['times']): + cur_seed = hash((42, seed)) % 100000 + batch = load_feature_for_one_target( + self.config, + output_dir, + cur_seed, + is_multimer=inputs['is_multimer'], + use_uniprot=inputs['is_multimer'], + symmetry_group=self.preprocessor.symmetry_group, + ) + seq_len = batch['aatype'].shape[-1] + self.model.model.globals.chunk_size = automatic_chunk_size(seq_len) + + with torch.no_grad(): + batch = { + k: torch.as_tensor(v, device='cuda:0') + for k, v in batch.items() + } + out = self.model(batch) + + def to_float(x): + if x.dtype == torch.bfloat16 or x.dtype == torch.half: + return x.float() + else: + return x + + # Toss out the recycling dimensions --- we don't need them anymore + batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch) + batch = tensor_tree_map(to_float, batch) + out = tensor_tree_map(lambda t: t[0, ...], out[0]) + out = tensor_tree_map(to_float, out) + batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch) + out = tensor_tree_map(lambda x: np.array(x.cpu()), out) + + plddt = out['plddt'] + mean_plddt = np.mean(plddt) + plddt_b_factors = np.repeat( + plddt[..., None], residue_constants.atom_type_num, axis=-1) + # TODO: , may need to reorder chains, based on entity_ids + cur_protein = protein.from_prediction( + features=batch, result=out, b_factors=plddt_b_factors) + cur_save_name = (f'{cur_seed}') + plddts[cur_save_name] = str(mean_plddt) + if inputs[ + 'is_multimer'] and self.preprocessor.symmetry_group is None: + ptms[cur_save_name] = str(np.mean(out['iptm+ptm'])) + with open(os.path.join(output_dir, cur_save_name + '.pdb'), + 'w') as f: + f.write(protein.to_pdb(cur_protein)) + pdbs.append(protein.to_pdb(cur_protein)) + + logger.info('plddts:' + str(plddts)) + model_name = 
self.cfg['pipeline']['model_name'] + score_name = f'{model_name}' + plddt_fname = score_name + '_plddt.json' + + with open(os.path.join(output_dir, plddt_fname), 'w') as f: + json.dump(plddts, f, indent=4) + if ptms: + logger.info('ptms' + str(ptms)) + ptm_fname = score_name + '_ptm.json' + with open(os.path.join(output_dir, ptm_fname), 'w') as f: + json.dump(ptms, f, indent=4) + + return pdbs + + def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params): + return inputs diff --git a/modelscope/preprocessors/science/__init__.py b/modelscope/preprocessors/science/__init__.py new file mode 100644 index 00000000..54b24887 --- /dev/null +++ b/modelscope/preprocessors/science/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .unifold import (UniFoldPreprocessor) + +else: + _import_structure = {'unifold': ['UniFoldPreprocessor']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/preprocessors/science/uni_fold.py b/modelscope/preprocessors/science/uni_fold.py new file mode 100644 index 00000000..2a44c885 --- /dev/null +++ b/modelscope/preprocessors/science/uni_fold.py @@ -0,0 +1,569 @@ +# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, +# and is publicly available at https://github.com/dptech-corp/Uni-Fold. + +import gzip +import hashlib +import logging +import os +import pickle +import random +import re +import tarfile +import time +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from unittest import result + +import json +import numpy as np +import requests +import torch +from tqdm import tqdm + +from modelscope.metainfo import Preprocessors +from modelscope.models.science.unifold.data import protein, residue_constants +from modelscope.models.science.unifold.data.protein import PDB_CHAIN_IDS +from modelscope.models.science.unifold.data.utils import compress_features +from modelscope.models.science.unifold.msa import parsers, pipeline, templates +from modelscope.models.science.unifold.msa.tools import hhsearch +from modelscope.models.science.unifold.msa.utils import divide_multi_chains +from modelscope.preprocessors.base import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.constant import Fields + +__all__ = [ + 'UniFoldPreprocessor', +] + +TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]' +DEFAULT_API_SERVER = 'https://api.colabfold.com' + + +def run_mmseqs2( + x, + prefix, + use_env=True, + use_templates=False, + use_pairing=False, + host_url='https://api.colabfold.com') -> Tuple[List[str], List[str]]: + submission_endpoint = 'ticket/pair' if use_pairing else 'ticket/msa' + + def submit(seqs, mode, N=101): + n, query = N, '' + for seq in seqs: + query += f'>{n}\n{seq}\n' + n += 1 + + res = requests.post( + f'{host_url}/{submission_endpoint}', + data={ + 'q': query, + 'mode': mode + }) + try: + out = res.json() + except ValueError: + out = {'status': 'ERROR'} + return out + + def status(ID): + res = requests.get(f'{host_url}/ticket/{ID}') + try: + out = res.json() + except ValueError: + out = {'status': 'ERROR'} + return out + + def download(ID, path): + res = requests.get(f'{host_url}/result/download/{ID}') 
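+        # the result endpoint returns the finished MSA job as a .tar.gz archive;
+        # stream it to `path` so it can be extracted further below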
+ with open(path, 'wb') as out: + out.write(res.content) + + # process input x + seqs = [x] if isinstance(x, str) else x + + mode = 'env' + if use_pairing: + mode = '' + use_templates = False + use_env = False + + # define path + path = f'{prefix}' + if not os.path.isdir(path): + os.mkdir(path) + + # call mmseqs2 api + tar_gz_file = f'{path}/out_{mode}.tar.gz' + N, REDO = 101, True + + # deduplicate and keep track of order + seqs_unique = [] + # TODO this might be slow for large sets + [seqs_unique.append(x) for x in seqs if x not in seqs_unique] + Ms = [N + seqs_unique.index(seq) for seq in seqs] + # lets do it! + if not os.path.isfile(tar_gz_file): + TIME_ESTIMATE = 150 * len(seqs_unique) + with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar: + while REDO: + pbar.set_description('SUBMIT') + + # Resubmit job until it goes through + out = submit(seqs_unique, mode, N) + while out['status'] in ['UNKNOWN', 'RATELIMIT']: + sleep_time = 5 + random.randint(0, 5) + # logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}") + # resubmit + time.sleep(sleep_time) + out = submit(seqs_unique, mode, N) + + if out['status'] == 'ERROR': + error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' + error = error + 'If error persists, please try again an hour later.' + raise Exception(error) + + if out['status'] == 'MAINTENANCE': + raise Exception( + 'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.' + ) + + # wait for job to finish + ID, TIME = out['id'], 0 + pbar.set_description(out['status']) + while out['status'] in ['UNKNOWN', 'RUNNING', 'PENDING']: + t = 5 + random.randint(0, 5) + # logger.error(f"Sleeping for {t}s. Reason: {out['status']}") + time.sleep(t) + out = status(ID) + pbar.set_description(out['status']) + if out['status'] == 'RUNNING': + TIME += t + pbar.update(n=t) + + if out['status'] == 'COMPLETE': + if TIME < TIME_ESTIMATE: + pbar.update(n=(TIME_ESTIMATE - TIME)) + REDO = False + + if out['status'] == 'ERROR': + REDO = False + error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' + error = error + 'If error persists, please try again an hour later.' 
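+                    # unrecoverable server-side failure: stop retrying and surface the message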
+ raise Exception(error) + + # Download results + download(ID, tar_gz_file) + + # prep list of a3m files + if use_pairing: + a3m_files = [f'{path}/pair.a3m'] + else: + a3m_files = [f'{path}/uniref.a3m'] + if use_env: + a3m_files.append(f'{path}/bfd.mgnify30.metaeuk30.smag30.a3m') + + # extract a3m files + if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files): + with tarfile.open(tar_gz_file) as tar_gz: + tar_gz.extractall(path) + + # templates + if use_templates: + templates = {} + + with open(f'{path}/pdb70.m8', 'r') as f: + lines = f.readlines() + for line in lines: + p = line.rstrip().split() + M, pdb, _, _ = p[0], p[1], p[2], p[10] # qid, e_value + M = int(M) + if M not in templates: + templates[M] = [] + templates[M].append(pdb) + + template_paths = {} + for k, TMPL in templates.items(): + TMPL_PATH = f'{prefix}/templates_{k}' + if not os.path.isdir(TMPL_PATH): + os.mkdir(TMPL_PATH) + TMPL_LINE = ','.join(TMPL[:20]) + os.system( + f'curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/' + ) + os.system( + f'cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex' + ) + os.system(f'touch {TMPL_PATH}/pdb70_cs219.ffdata') + template_paths[k] = TMPL_PATH + + # gather a3m lines + a3m_lines = {} + for a3m_file in a3m_files: + update_M, M = True, None + with open(a3m_file, 'r') as f: + lines = f.readlines() + for line in lines: + if len(line) > 0: + if '\x00' in line: + line = line.replace('\x00', '') + update_M = True + if line.startswith('>') and update_M: + M = int(line[1:].rstrip()) + update_M = False + if M not in a3m_lines: + a3m_lines[M] = [] + a3m_lines[M].append(line) + + # return results + + a3m_lines = [''.join(a3m_lines[n]) for n in Ms] + + if use_templates: + template_paths_ = [] + for n in Ms: + if n not in template_paths: + template_paths_.append(None) + # print(f"{n-N}\tno_templates_found") + else: + template_paths_.append(template_paths[n]) + template_paths = template_paths_ + + return (a3m_lines, template_paths) if use_templates else a3m_lines + + +def get_null_template(query_sequence: Union[List[str], str], + num_temp: int = 1) -> Dict[str, Any]: + ln = ( + len(query_sequence) if isinstance(query_sequence, str) else sum( + len(s) for s in query_sequence)) + output_templates_sequence = 'A' * ln + # output_confidence_scores = np.full(ln, 1.0) + + templates_all_atom_positions = np.zeros( + (ln, templates.residue_constants.atom_type_num, 3)) + templates_all_atom_masks = np.zeros( + (ln, templates.residue_constants.atom_type_num)) + templates_aatype = templates.residue_constants.sequence_to_onehot( + output_templates_sequence, + templates.residue_constants.HHBLITS_AA_TO_ID) + template_features = { + 'template_all_atom_positions': + np.tile(templates_all_atom_positions[None], [num_temp, 1, 1, 1]), + 'template_all_atom_masks': + np.tile(templates_all_atom_masks[None], [num_temp, 1, 1]), + 'template_sequence': ['none'.encode()] * num_temp, + 'template_aatype': + np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]), + 'template_domain_names': ['none'.encode()] * num_temp, + 'template_sum_probs': + np.zeros([num_temp], dtype=np.float32), + } + return template_features + + +def get_template(a3m_lines: str, template_path: str, + query_sequence: str) -> Dict[str, Any]: + template_featurizer = templates.HhsearchHitFeaturizer( + mmcif_dir=template_path, + max_template_date='2100-01-01', + max_hits=20, + kalign_binary_path='kalign', + release_dates_path=None, + obsolete_pdbs_path=None, + ) + + hhsearch_pdb70_runner = hhsearch.HHSearch( + 
binary_path='hhsearch', databases=[f'{template_path}/pdb70']) + + hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines) + hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result) + templates_result = template_featurizer.get_templates( + query_sequence=query_sequence, hits=hhsearch_hits) + return dict(templates_result.features) + + +@PREPROCESSORS.register_module( + Fields.science, module_name=Preprocessors.unifold_preprocessor) +class UniFoldPreprocessor(Preprocessor): + + def __init__(self, **cfg): + self.symmetry_group = cfg['symmetry_group'] # "C1" + if not self.symmetry_group: + self.symmetry_group = None + self.MIN_SINGLE_SEQUENCE_LENGTH = 16 # TODO: change to cfg + self.MAX_SINGLE_SEQUENCE_LENGTH = 1000 + self.MAX_MULTIMER_LENGTH = 1000 + self.jobname = 'unifold' + self.output_dir_base = './unifold-predictions' + os.makedirs(self.output_dir_base, exist_ok=True) + + def clean_and_validate_sequence(self, input_sequence: str, min_length: int, + max_length: int) -> str: + clean_sequence = input_sequence.translate( + str.maketrans('', '', ' \n\t')).upper() + aatypes = set(residue_constants.restypes) # 20 standard aatypes. + if not set(clean_sequence).issubset(aatypes): + raise ValueError( + f'Input sequence contains non-amino acid letters: ' + f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard ' + 'amino acids as inputs.') + if len(clean_sequence) < min_length: + raise ValueError( + f'Input sequence is too short: {len(clean_sequence)} amino acids, ' + f'while the minimum is {min_length}') + if len(clean_sequence) > max_length: + raise ValueError( + f'Input sequence is too long: {len(clean_sequence)} amino acids, while ' + f'the maximum is {max_length}. You may be able to run it with the full ' + f'Uni-Fold system depending on your resources (system memory, ' + f'GPU memory).') + return clean_sequence + + def validate_input(self, input_sequences: Sequence[str], + symmetry_group: str, min_length: int, max_length: int, + max_multimer_length: int) -> Tuple[Sequence[str], bool]: + """Validates and cleans input sequences and determines which model to use.""" + sequences = [] + + for input_sequence in input_sequences: + if input_sequence.strip(): + input_sequence = self.clean_and_validate_sequence( + input_sequence=input_sequence, + min_length=min_length, + max_length=max_length) + sequences.append(input_sequence) + + if symmetry_group is not None and symmetry_group != 'C1': + if symmetry_group.startswith( + 'C') and symmetry_group[1:].isnumeric(): + print( + f'Using UF-Symmetry with group {symmetry_group}. If you do not ' + f'want to use UF-Symmetry, please use `C1` and copy the AU ' + f'sequences to the count in the assembly.') + is_multimer = (len(sequences) > 1) + return sequences, is_multimer, symmetry_group + else: + raise ValueError( + f'UF-Symmetry does not support symmetry group ' + f'{symmetry_group} currently. Cyclic groups (Cx) are ' + f'supported only.') + + elif len(sequences) == 1: + print('Using the single-chain model.') + return sequences, False, None + + elif len(sequences) > 1: + total_multimer_length = sum([len(seq) for seq in sequences]) + if total_multimer_length > max_multimer_length: + raise ValueError( + f'The total length of multimer sequences is too long: ' + f'{total_multimer_length}, while the maximum is ' + f'{max_multimer_length}. 
Please use the full AlphaFold ' + f'system for long multimers.') + print(f'Using the multimer model with {len(sequences)} sequences.') + return sequences, True, None + + else: + raise ValueError( + 'No input amino acid sequence provided, please provide at ' + 'least one sequence.') + + def add_hash(self, x, y): + return x + '_' + hashlib.sha1(y.encode()).hexdigest()[:5] + + def get_msa_and_templates( + self, + jobname: str, + query_seqs_unique: Union[str, List[str]], + result_dir: Path, + msa_mode: str, + use_templates: bool, + homooligomers_num: int = 1, + host_url: str = DEFAULT_API_SERVER, + ) -> Tuple[Optional[List[str]], Optional[List[str]], List[str], List[int], + List[Dict[str, Any]]]: + + use_env = msa_mode == 'MMseqs2' + + template_features = [] + if use_templates: + a3m_lines_mmseqs2, template_paths = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_templates=True, + host_url=host_url, + ) + if template_paths is None: + for index in range(0, len(query_seqs_unique)): + template_feature = get_null_template( + query_seqs_unique[index]) + template_features.append(template_feature) + else: + for index in range(0, len(query_seqs_unique)): + if template_paths[index] is not None: + template_feature = get_template( + a3m_lines_mmseqs2[index], + template_paths[index], + query_seqs_unique[index], + ) + if len(template_feature['template_domain_names']) == 0: + template_feature = get_null_template( + query_seqs_unique[index]) + else: + template_feature = get_null_template( + query_seqs_unique[index]) + template_features.append(template_feature) + else: + for index in range(0, len(query_seqs_unique)): + template_feature = get_null_template(query_seqs_unique[index]) + template_features.append(template_feature) + + if msa_mode == 'single_sequence': + a3m_lines = [] + num = 101 + for i, seq in enumerate(query_seqs_unique): + a3m_lines.append('>' + str(num + i) + '\n' + seq) + else: + # find normal a3ms + a3m_lines = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_pairing=False, + host_url=host_url, + ) + if len(query_seqs_unique) > 1: + # find paired a3m if not a homooligomers + paired_a3m_lines = run_mmseqs2( + query_seqs_unique, + str(result_dir.joinpath(jobname)), + use_env, + use_pairing=True, + host_url=host_url, + ) + else: + num = 101 + paired_a3m_lines = [] + for i in range(0, homooligomers_num): + paired_a3m_lines.append('>' + str(num + i) + '\n' + + query_seqs_unique[0] + '\n') + + return ( + a3m_lines, + paired_a3m_lines, + template_features, + ) + + def __call__(self, data: Union[str, Tuple]): + if isinstance(data, str): + data = [data, '', '', ''] + basejobname = ''.join(data) + basejobname = re.sub(r'\W+', '', basejobname) + target_id = self.add_hash(self.jobname, basejobname) + + sequences, is_multimer, _ = self.validate_input( + input_sequences=data, + symmetry_group=self.symmetry_group, + min_length=self.MIN_SINGLE_SEQUENCE_LENGTH, + max_length=self.MAX_SINGLE_SEQUENCE_LENGTH, + max_multimer_length=self.MAX_MULTIMER_LENGTH) + + descriptions = [ + '> ' + target_id + ' seq' + str(ii) + for ii in range(len(sequences)) + ] + + if is_multimer: + divide_multi_chains(target_id, self.output_dir_base, sequences, + descriptions) + + s = [] + for des, seq in zip(descriptions, sequences): + s += [des, seq] + + unique_sequences = [] + [ + unique_sequences.append(x) for x in sequences + if x not in unique_sequences + ] + + if len(unique_sequences) == 1: + homooligomers_num = len(sequences) + else: + 
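+            # chains have distinct sequences, so paired MSAs come from the MMseqs2
+            # pairing mode rather than from duplicating a single chain's MSA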
            homooligomers_num = 1
+
+        with open(f'{self.jobname}.fasta', 'w') as f:
+            f.write('\n'.join(s))
+
+        result_dir = Path(self.output_dir_base)
+        output_dir = os.path.join(self.output_dir_base, target_id)
+
+        # msa_mode = 'single_sequence'
+        msa_mode = 'MMseqs2'
+        use_templates = True
+
+        unpaired_msa, paired_msa, template_results = self.get_msa_and_templates(
+            target_id,
+            unique_sequences,
+            result_dir=result_dir,
+            msa_mode=msa_mode,
+            use_templates=use_templates,
+            homooligomers_num=homooligomers_num)
+
+        features = []
+        pair_features = []
+
+        for idx, seq in enumerate(unique_sequences):
+            chain_id = PDB_CHAIN_IDS[idx]
+            sequence_features = pipeline.make_sequence_features(
+                sequence=seq,
+                description=f'> {self.jobname} seq {chain_id}',
+                num_res=len(seq))
+            monomer_msa = parsers.parse_a3m(unpaired_msa[idx])
+            msa_features = pipeline.make_msa_features([monomer_msa])
+            template_features = template_results[idx]
+            feature_dict = {
+                **sequence_features,
+                **msa_features,
+                **template_features
+            }
+            feature_dict = compress_features(feature_dict)
+            features_output_path = os.path.join(
+                output_dir, '{}.feature.pkl.gz'.format(chain_id))
+            pickle.dump(
+                feature_dict,
+                gzip.GzipFile(features_output_path, 'wb'),
+                protocol=4)
+            features.append(feature_dict)
+
+            if is_multimer:
+                multimer_msa = parsers.parse_a3m(paired_msa[idx])
+                # keep the per-chain paired features in a local variable so the
+                # accumulating `pair_features` list is not overwritten
+                pair_feature = pipeline.make_msa_features([multimer_msa])
+                pair_feature_dict = compress_features(pair_feature)
+                uniprot_output_path = os.path.join(
+                    output_dir, '{}.uniprot.pkl.gz'.format(chain_id))
+                pickle.dump(
+                    pair_feature_dict,
+                    gzip.GzipFile(uniprot_output_path, 'wb'),
+                    protocol=4,
+                )
+                pair_features.append(pair_feature_dict)
+
+        return {
+            'features': features,
+            'pair_features': pair_features,
+            'target_id': target_id,
+            'is_multimer': is_multimer,
+        }
+
+
+if __name__ == '__main__':
+    # symmetry_group is required by __init__; None selects the plain (non-symmetry) path
+    proc = UniFoldPreprocessor(symmetry_group=None)
+    protein_example = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + \
+        'TVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'
+    outputs = proc(protein_example)
+    print(outputs['target_id'], len(outputs['features']))
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 86b7bb7d..45bda324 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -9,6 +9,7 @@ class Fields(object):
     nlp = 'nlp'
     audio = 'audio'
     multi_modal = 'multi-modal'
+    science = 'science'
 
 
 class CVTasks(object):
@@ -151,6 +152,10 @@ class MultiModalTasks(object):
     image_text_retrieval = 'image-text-retrieval'
 
 
+class ScienceTasks(object):
+    protein_structure = 'protein-structure'
+
+
 class TasksIODescriptions(object):
     image_to_image = 'image_to_image',
     images_to_image = 'images_to_image',
@@ -167,7 +172,7 @@ class TasksIODescriptions(object):
     generative_multi_modal_embedding = 'generative_multi_modal_embedding'
 
 
-class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks):
+class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks):
     """ Names for tasks supported by modelscope.
 
     Holds the standard task name to use for identifying different tasks.
@@ -196,6 +201,10 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): getattr(Tasks, attr) for attr in dir(MultiModalTasks) if not attr.startswith('__') ], + Fields.science: [ + getattr(Tasks, attr) for attr in dir(ScienceTasks) + if not attr.startswith('__') + ], } for field, tasks in field_dict.items(): diff --git a/requirements/science.txt b/requirements/science.txt new file mode 100644 index 00000000..72994f72 --- /dev/null +++ b/requirements/science.txt @@ -0,0 +1,6 @@ +iopath +lmdb +ml_collections +scipy +tensorboardX +tokenizers diff --git a/tests/pipelines/test_unifold.py b/tests/pipelines/test_unifold.py new file mode 100644 index 00000000..df35dc5e --- /dev/null +++ b/tests/pipelines/test_unifold.py @@ -0,0 +1,34 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.protein_structure + self.model_id = 'DPTech/uni-fold-monomer' + self.model_id_multimer = 'DPTech/uni-fold-multimer' + + self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD' + self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ + 'NIAALKNHIDKIKPIAMQIYKKYSKNIP' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + model_dir = snapshot_download(self.model_id) + mono_pipeline_ins = pipeline(task=self.task, model=model_dir) + _ = mono_pipeline_ins(self.protein) + + model_dir1 = snapshot_download(self.model_id_multimer) + multi_pipeline_ins = pipeline(task=self.task, model=model_dir1) + _ = multi_pipeline_ins(self.protein_multimer) + + +if __name__ == '__main__': + unittest.main()