yuanzheng.yuanzhen / yingda.chen · 2 years ago
parent · commit bab54bbce8
49 changed files with 14515 additions and 1 deletion
  1. modelscope/metainfo.py (+10 -0)
  2. modelscope/models/science/__init__.py (+21 -0)
  3. modelscope/models/science/unifold/__init__.py (+1 -0)
  4. modelscope/models/science/unifold/config.py (+636 -0)
  5. modelscope/models/science/unifold/data/__init__.py (+14 -0)
  6. modelscope/models/science/unifold/data/data_ops.py (+1397 -0)
  7. modelscope/models/science/unifold/data/msa_pairing.py (+526 -0)
  8. modelscope/models/science/unifold/data/process.py (+264 -0)
  9. modelscope/models/science/unifold/data/process_multimer.py (+417 -0)
  10. modelscope/models/science/unifold/data/protein.py (+322 -0)
  11. modelscope/models/science/unifold/data/residue_constants.py (+1212 -0)
  12. modelscope/models/science/unifold/data/stereo_chemical_props.txt (+345 -0)
  13. modelscope/models/science/unifold/data/utils.py (+161 -0)
  14. modelscope/models/science/unifold/dataset.py (+514 -0)
  15. modelscope/models/science/unifold/model.py (+75 -0)
  16. modelscope/models/science/unifold/modules/alphafold.py (+450 -0)
  17. modelscope/models/science/unifold/modules/attentions.py (+430 -0)
  18. modelscope/models/science/unifold/modules/auxillary_heads.py (+171 -0)
  19. modelscope/models/science/unifold/modules/common.py (+387 -0)
  20. modelscope/models/science/unifold/modules/confidence.py (+159 -0)
  21. modelscope/models/science/unifold/modules/embedders.py (+290 -0)
  22. modelscope/models/science/unifold/modules/evoformer.py (+362 -0)
  23. modelscope/models/science/unifold/modules/featurization.py (+195 -0)
  24. modelscope/models/science/unifold/modules/frame.py (+562 -0)
  25. modelscope/models/science/unifold/modules/structure_module.py (+592 -0)
  26. modelscope/models/science/unifold/modules/template.py (+330 -0)
  27. modelscope/models/science/unifold/modules/triangle_multiplication.py (+158 -0)
  28. modelscope/models/science/unifold/msa/__init__.py (+1 -0)
  29. modelscope/models/science/unifold/msa/mmcif.py (+483 -0)
  30. modelscope/models/science/unifold/msa/msa_identifiers.py (+88 -0)
  31. modelscope/models/science/unifold/msa/parsers.py (+627 -0)
  32. modelscope/models/science/unifold/msa/pipeline.py (+282 -0)
  33. modelscope/models/science/unifold/msa/templates.py (+1110 -0)
  34. modelscope/models/science/unifold/msa/tools/__init__.py (+14 -0)
  35. modelscope/models/science/unifold/msa/tools/hhblits.py (+170 -0)
  36. modelscope/models/science/unifold/msa/tools/hhsearch.py (+111 -0)
  37. modelscope/models/science/unifold/msa/tools/hmmbuild.py (+143 -0)
  38. modelscope/models/science/unifold/msa/tools/hmmsearch.py (+146 -0)
  39. modelscope/models/science/unifold/msa/tools/jackhmmer.py (+224 -0)
  40. modelscope/models/science/unifold/msa/tools/kalign.py (+110 -0)
  41. modelscope/models/science/unifold/msa/tools/utils.py (+40 -0)
  42. modelscope/models/science/unifold/msa/utils.py (+89 -0)
  43. modelscope/pipelines/science/__init__.py (+22 -0)
  44. modelscope/pipelines/science/protein_structure_pipeline.py (+215 -0)
  45. modelscope/preprocessors/science/__init__.py (+20 -0)
  46. modelscope/preprocessors/science/uni_fold.py (+569 -0)
  47. modelscope/utils/constant.py (+10 -1)
  48. requirements/science.txt (+6 -0)
  49. tests/pipelines/test_unifold.py (+34 -0)

modelscope/metainfo.py (+10 -0)

@@ -99,6 +99,10 @@ class Models(object):
team = 'team-multi-modal-similarity'
video_clip = 'video-clip-multi-modal-embedding'

# science models
unifold = 'unifold'
unifold_symmetry = 'unifold-symmetry'


class TaskModels(object):
# nlp task
@@ -266,6 +270,9 @@ class Pipelines(object):
image_text_retrieval = 'image-text-retrieval'
ofa_ocr_recognition = 'ofa-ocr-recognition'

# science tasks
protein_structure = 'unifold-protein-structure'


class Trainers(object):
""" Names for different trainer.
@@ -368,6 +375,9 @@ class Preprocessors(object):
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
mplug_tasks_preprocessor = 'mplug-tasks-preprocessor'

# science preprocessor
unifold_preprocessor = 'unifold-preprocessor'


class Metrics(object):
""" Names for different metrics.


modelscope/models/science/__init__.py (+21 -0)

@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

from .unifold import UnifoldForProteinStructrue

else:
_import_structure = {'unifold': ['UnifoldForProteinStructrue']}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
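
With the lazy-import shim above, the class is only loaded from the .unifold submodule when it is first accessed, so importing modelscope.models.science stays cheap. A small sketch of the expected behaviour, assuming the science requirements are installed:

# Resolved lazily via _import_structure on first access.
from modelscope.models.science import UnifoldForProteinStructrue

# Equivalent eager form:
# from modelscope.models.science.unifold import UnifoldForProteinStructrue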

modelscope/models/science/unifold/__init__.py (+1 -0)

@@ -0,0 +1 @@
from .model import UnifoldForProteinStructrue

modelscope/models/science/unifold/config.py (+636 -0)

@@ -0,0 +1,636 @@
# The Uni-Fold implementation is also open-sourced by its authors under the Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import copy
from typing import Any

import ml_collections as mlc

N_RES = 'number of residues'
N_MSA = 'number of MSA sequences'
N_EXTRA_MSA = 'number of extra MSA sequences'
N_TPL = 'number of templates'

d_pair = mlc.FieldReference(128, field_type=int)
d_msa = mlc.FieldReference(256, field_type=int)
d_template = mlc.FieldReference(64, field_type=int)
d_extra_msa = mlc.FieldReference(64, field_type=int)
d_single = mlc.FieldReference(384, field_type=int)
max_recycling_iters = mlc.FieldReference(3, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(3e4, field_type=float)
use_templates = mlc.FieldReference(True, field_type=bool)
is_multimer = mlc.FieldReference(False, field_type=bool)


def base_config():
return mlc.ConfigDict({
'data': {
'common': {
'features': {
'aatype': [N_RES],
'all_atom_mask': [N_RES, None],
'all_atom_positions': [N_RES, None, None],
'alt_chi_angles': [N_RES, None],
'atom14_alt_gt_exists': [N_RES, None],
'atom14_alt_gt_positions': [N_RES, None, None],
'atom14_atom_exists': [N_RES, None],
'atom14_atom_is_ambiguous': [N_RES, None],
'atom14_gt_exists': [N_RES, None],
'atom14_gt_positions': [N_RES, None, None],
'atom37_atom_exists': [N_RES, None],
'frame_mask': [N_RES],
'true_frame_tensor': [N_RES, None, None],
'bert_mask': [N_MSA, N_RES],
'chi_angles_sin_cos': [N_RES, None, None],
'chi_mask': [N_RES, None],
'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES],
'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES],
'extra_msa': [N_EXTRA_MSA, N_RES],
'extra_msa_mask': [N_EXTRA_MSA, N_RES],
'extra_msa_row_mask': [N_EXTRA_MSA],
'is_distillation': [],
'msa_feat': [N_MSA, N_RES, None],
'msa_mask': [N_MSA, N_RES],
'msa_chains': [N_MSA, None],
'msa_row_mask': [N_MSA],
'num_recycling_iters': [],
'pseudo_beta': [N_RES, None],
'pseudo_beta_mask': [N_RES],
'residue_index': [N_RES],
'residx_atom14_to_atom37': [N_RES, None],
'residx_atom37_to_atom14': [N_RES, None],
'resolution': [],
'rigidgroups_alt_gt_frames': [N_RES, None, None, None],
'rigidgroups_group_exists': [N_RES, None],
'rigidgroups_group_is_ambiguous': [N_RES, None],
'rigidgroups_gt_exists': [N_RES, None],
'rigidgroups_gt_frames': [N_RES, None, None, None],
'seq_length': [],
'seq_mask': [N_RES],
'target_feat': [N_RES, None],
'template_aatype': [N_TPL, N_RES],
'template_all_atom_mask': [N_TPL, N_RES, None],
'template_all_atom_positions': [N_TPL, N_RES, None, None],
'template_alt_torsion_angles_sin_cos': [
N_TPL,
N_RES,
None,
None,
],
'template_frame_mask': [N_TPL, N_RES],
'template_frame_tensor': [N_TPL, N_RES, None, None],
'template_mask': [N_TPL],
'template_pseudo_beta': [N_TPL, N_RES, None],
'template_pseudo_beta_mask': [N_TPL, N_RES],
'template_sum_probs': [N_TPL, None],
'template_torsion_angles_mask': [N_TPL, N_RES, None],
'template_torsion_angles_sin_cos':
[N_TPL, N_RES, None, None],
'true_msa': [N_MSA, N_RES],
'use_clamped_fape': [],
'assembly_num_chains': [1],
'asym_id': [N_RES],
'sym_id': [N_RES],
'entity_id': [N_RES],
'num_sym': [N_RES],
'asym_len': [None],
'cluster_bias_mask': [N_MSA],
},
'masked_msa': {
'profile_prob': 0.1,
'same_prob': 0.1,
'uniform_prob': 0.1,
},
'block_delete_msa': {
'msa_fraction_per_block': 0.3,
'randomize_num_blocks': False,
'num_blocks': 5,
'min_num_msa': 16,
},
'random_delete_msa': {
'max_msa_entry': 1 << 25, # := 33554432
},
'v2_feature':
False,
'gumbel_sample':
False,
'max_extra_msa':
1024,
'msa_cluster_features':
True,
'reduce_msa_clusters_by_max_templates':
True,
'resample_msa_in_recycling':
True,
'template_features': [
'template_all_atom_positions',
'template_sum_probs',
'template_aatype',
'template_all_atom_mask',
],
'unsupervised_features': [
'aatype',
'residue_index',
'msa',
'msa_chains',
'num_alignments',
'seq_length',
'between_segment_residues',
'deletion_matrix',
'num_recycling_iters',
'crop_and_fix_size_seed',
],
'recycling_features': [
'msa_chains',
'msa_mask',
'msa_row_mask',
'bert_mask',
'true_msa',
'msa_feat',
'extra_msa_deletion_value',
'extra_msa_has_deletion',
'extra_msa',
'extra_msa_mask',
'extra_msa_row_mask',
'is_distillation',
],
'multimer_features': [
'assembly_num_chains',
'asym_id',
'sym_id',
'num_sym',
'entity_id',
'asym_len',
'cluster_bias_mask',
],
'use_templates':
use_templates,
'is_multimer':
is_multimer,
'use_template_torsion_angles':
use_templates,
'max_recycling_iters':
max_recycling_iters,
},
'supervised': {
'use_clamped_fape_prob':
1.0,
'supervised_features': [
'all_atom_mask',
'all_atom_positions',
'resolution',
'use_clamped_fape',
'is_distillation',
],
},
'predict': {
'fixed_size': True,
'subsample_templates': False,
'block_delete_msa': False,
'random_delete_msa': True,
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 128,
'max_templates': 4,
'num_ensembles': 2,
'crop': False,
'crop_size': None,
'supervised': False,
'biased_msa_by_chain': False,
'share_mask': False,
},
'eval': {
'fixed_size': True,
'subsample_templates': False,
'block_delete_msa': False,
'random_delete_msa': True,
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 128,
'max_templates': 4,
'num_ensembles': 1,
'crop': False,
'crop_size': None,
'spatial_crop_prob': 0.5,
'ca_ca_threshold': 10.0,
'supervised': True,
'biased_msa_by_chain': False,
'share_mask': False,
},
'train': {
'fixed_size': True,
'subsample_templates': True,
'block_delete_msa': True,
'random_delete_msa': True,
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 128,
'max_templates': 4,
'num_ensembles': 1,
'crop': True,
'crop_size': 256,
'spatial_crop_prob': 0.5,
'ca_ca_threshold': 10.0,
'supervised': True,
'use_clamped_fape_prob': 1.0,
'max_distillation_msa_clusters': 1000,
'biased_msa_by_chain': True,
'share_mask': True,
},
},
'globals': {
'chunk_size': chunk_size,
'block_size': None,
'd_pair': d_pair,
'd_msa': d_msa,
'd_template': d_template,
'd_extra_msa': d_extra_msa,
'd_single': d_single,
'eps': eps,
'inf': inf,
'max_recycling_iters': max_recycling_iters,
'alphafold_original_mode': False,
},
'model': {
'is_multimer': is_multimer,
'input_embedder': {
'tf_dim': 22,
'msa_dim': 49,
'd_pair': d_pair,
'd_msa': d_msa,
'relpos_k': 32,
'max_relative_chain': 2,
},
'recycling_embedder': {
'd_pair': d_pair,
'd_msa': d_msa,
'min_bin': 3.25,
'max_bin': 20.75,
'num_bins': 15,
'inf': 1e8,
},
'template': {
'distogram': {
'min_bin': 3.25,
'max_bin': 50.75,
'num_bins': 39,
},
'template_angle_embedder': {
'd_in': 57,
'd_out': d_msa,
},
'template_pair_embedder': {
'd_in': 88,
'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1],
'd_pair': d_pair,
'd_out': d_template,
'v2_feature': False,
},
'template_pair_stack': {
'd_template': d_template,
'd_hid_tri_att': 16,
'd_hid_tri_mul': 64,
'num_blocks': 2,
'num_heads': 4,
'pair_transition_n': 2,
'dropout_rate': 0.25,
'inf': 1e9,
'tri_attn_first': True,
},
'template_pointwise_attention': {
'enabled': True,
'd_template': d_template,
'd_pair': d_pair,
'd_hid': 16,
'num_heads': 4,
'inf': 1e5,
},
'inf': 1e5,
'eps': 1e-6,
'enabled': use_templates,
'embed_angles': use_templates,
},
'extra_msa': {
'extra_msa_embedder': {
'd_in': 25,
'd_out': d_extra_msa,
},
'extra_msa_stack': {
'd_msa': d_extra_msa,
'd_pair': d_pair,
'd_hid_msa_att': 8,
'd_hid_opm': 32,
'd_hid_mul': 128,
'd_hid_pair_att': 32,
'num_heads_msa': 8,
'num_heads_pair': 4,
'num_blocks': 4,
'transition_n': 4,
'msa_dropout': 0.15,
'pair_dropout': 0.25,
'inf': 1e9,
'eps': 1e-10,
'outer_product_mean_first': False,
},
'enabled': True,
},
'evoformer_stack': {
'd_msa': d_msa,
'd_pair': d_pair,
'd_hid_msa_att': 32,
'd_hid_opm': 32,
'd_hid_mul': 128,
'd_hid_pair_att': 32,
'd_single': d_single,
'num_heads_msa': 8,
'num_heads_pair': 4,
'num_blocks': 48,
'transition_n': 4,
'msa_dropout': 0.15,
'pair_dropout': 0.25,
'inf': 1e9,
'eps': 1e-10,
'outer_product_mean_first': False,
},
'structure_module': {
'd_single': d_single,
'd_pair': d_pair,
'd_ipa': 16,
'd_angle': 128,
'num_heads_ipa': 12,
'num_qk_points': 4,
'num_v_points': 8,
'dropout_rate': 0.1,
'num_blocks': 8,
'no_transition_layers': 1,
'num_resnet_blocks': 2,
'num_angles': 7,
'trans_scale_factor': 10,
'epsilon': 1e-12,
'inf': 1e5,
'separate_kv': False,
'ipa_bias': True,
},
'heads': {
'plddt': {
'num_bins': 50,
'd_in': d_single,
'd_hid': 128,
},
'distogram': {
'd_pair': d_pair,
'num_bins': aux_distogram_bins,
'disable_enhance_head': False,
},
'pae': {
'd_pair': d_pair,
'num_bins': aux_distogram_bins,
'enabled': False,
'iptm_weight': 0.8,
'disable_enhance_head': False,
},
'masked_msa': {
'd_msa': d_msa,
'd_out': 23,
'disable_enhance_head': False,
},
'experimentally_resolved': {
'd_single': d_single,
'd_out': 37,
'enabled': False,
'disable_enhance_head': False,
},
},
},
'loss': {
'distogram': {
'min_bin': 2.3125,
'max_bin': 21.6875,
'num_bins': 64,
'eps': 1e-6,
'weight': 0.3,
},
'experimentally_resolved': {
'eps': 1e-8,
'min_resolution': 0.1,
'max_resolution': 3.0,
'weight': 0.0,
},
'fape': {
'backbone': {
'clamp_distance': 10.0,
'clamp_distance_between_chains': 30.0,
'loss_unit_distance': 10.0,
'loss_unit_distance_between_chains': 20.0,
'weight': 0.5,
'eps': 1e-4,
},
'sidechain': {
'clamp_distance': 10.0,
'length_scale': 10.0,
'weight': 0.5,
'eps': 1e-4,
},
'weight': 1.0,
},
'plddt': {
'min_resolution': 0.1,
'max_resolution': 3.0,
'cutoff': 15.0,
'num_bins': 50,
'eps': 1e-10,
'weight': 0.01,
},
'masked_msa': {
'eps': 1e-8,
'weight': 2.0,
},
'supervised_chi': {
'chi_weight': 0.5,
'angle_norm_weight': 0.01,
'eps': 1e-6,
'weight': 1.0,
},
'violation': {
'violation_tolerance_factor': 12.0,
'clash_overlap_tolerance': 1.5,
'bond_angle_loss_weight': 0.3,
'eps': 1e-6,
'weight': 0.0,
},
'pae': {
'max_bin': 31,
'num_bins': 64,
'min_resolution': 0.1,
'max_resolution': 3.0,
'eps': 1e-8,
'weight': 0.0,
},
'repr_norm': {
'weight': 0.01,
'tolerance': 1.0,
},
'chain_centre_mass': {
'weight': 0.0,
'eps': 1e-8,
},
},
})


def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None):
with c.unlocked():
for k, v in c.items():
if ignore is not None and k == ignore:
continue
if isinstance(v, mlc.ConfigDict):
recursive_set(v, key, value)
elif k == key:
c[k] = value


def model_config(name, train=False):
c = copy.deepcopy(base_config())

def model_2_v2(c):
recursive_set(c, 'v2_feature', True)
recursive_set(c, 'gumbel_sample', True)
c.model.heads.masked_msa.d_out = 22
c.model.structure_module.separate_kv = True
c.model.structure_module.ipa_bias = False
c.model.template.template_angle_embedder.d_in = 34
return c

def multimer(c):
recursive_set(c, 'is_multimer', True)
recursive_set(c, 'max_extra_msa', 1152)
recursive_set(c, 'max_msa_clusters', 128)
recursive_set(c, 'v2_feature', True)
recursive_set(c, 'gumbel_sample', True)
c.model.template.template_angle_embedder.d_in = 34
c.model.template.template_pair_stack.tri_attn_first = False
c.model.template.template_pointwise_attention.enabled = False
c.model.heads.pae.enabled = True
# we forgot to enable this in our training, so it is disabled here
c.model.heads.pae.disable_enhance_head = True
c.model.heads.masked_msa.d_out = 22
c.model.structure_module.separate_kv = True
c.model.structure_module.ipa_bias = False
c.model.structure_module.trans_scale_factor = 20
c.loss.pae.weight = 0.1
c.model.input_embedder.tf_dim = 21
c.data.train.crop_size = 384
c.loss.violation.weight = 0.02
c.loss.chain_centre_mass.weight = 1.0
return c

if name == 'model_1':
pass
elif name == 'model_1_ft':
recursive_set(c, 'max_extra_msa', 5120)
recursive_set(c, 'max_msa_clusters', 512)
c.data.train.crop_size = 384
c.loss.violation.weight = 0.02
elif name == 'model_1_af2':
recursive_set(c, 'max_extra_msa', 5120)
recursive_set(c, 'max_msa_clusters', 512)
c.data.train.crop_size = 384
c.loss.violation.weight = 0.02
c.loss.repr_norm.weight = 0
c.model.heads.experimentally_resolved.enabled = True
c.loss.experimentally_resolved.weight = 0.01
c.globals.alphafold_original_mode = True
elif name == 'model_2':
pass
elif name == 'model_init':
pass
elif name == 'model_init_af2':
c.globals.alphafold_original_mode = True
pass
elif name == 'model_2_ft':
recursive_set(c, 'max_extra_msa', 1024)
recursive_set(c, 'max_msa_clusters', 512)
c.data.train.crop_size = 384
c.loss.violation.weight = 0.02
elif name == 'model_2_af2':
recursive_set(c, 'max_extra_msa', 1024)
recursive_set(c, 'max_msa_clusters', 512)
c.data.train.crop_size = 384
c.loss.violation.weight = 0.02
c.loss.repr_norm.weight = 0
c.model.heads.experimentally_resolved.enabled = True
c.loss.experimentally_resolved.weight = 0.01
c.globals.alphafold_original_mode = True
elif name == 'model_2_v2':
c = model_2_v2(c)
elif name == 'model_2_v2_ft':
c = model_2_v2(c)
recursive_set(c, 'max_extra_msa', 1024)
recursive_set(c, 'max_msa_clusters', 512)
c.data.train.crop_size = 384
c.loss.violation.weight = 0.02
elif name == 'model_3_af2' or name == 'model_4_af2':
recursive_set(c, 'max_extra_msa', 5120)
recursive_set(c, 'max_msa_clusters', 512)
c.data.train.crop_size = 384
c.loss.violation.weight = 0.02
c.loss.repr_norm.weight = 0
c.model.heads.experimentally_resolved.enabled = True
c.loss.experimentally_resolved.weight = 0.01
c.globals.alphafold_original_mode = True
c.model.template.enabled = False
c.model.template.embed_angles = False
recursive_set(c, 'use_templates', False)
recursive_set(c, 'use_template_torsion_angles', False)
elif name == 'model_5_af2':
recursive_set(c, 'max_extra_msa', 1024)
recursive_set(c, 'max_msa_clusters', 512)
c.data.train.crop_size = 384
c.loss.violation.weight = 0.02
c.loss.repr_norm.weight = 0
c.model.heads.experimentally_resolved.enabled = True
c.loss.experimentally_resolved.weight = 0.01
c.globals.alphafold_original_mode = True
c.model.template.enabled = False
c.model.template.embed_angles = False
recursive_set(c, 'use_templates', False)
recursive_set(c, 'use_template_torsion_angles', False)
elif name == 'multimer':
c = multimer(c)
elif name == 'multimer_ft':
c = multimer(c)
recursive_set(c, 'max_extra_msa', 1152)
recursive_set(c, 'max_msa_clusters', 256)
c.data.train.crop_size = 384
c.loss.violation.weight = 0.5
elif name == 'multimer_af2':
recursive_set(c, 'max_extra_msa', 1152)
recursive_set(c, 'max_msa_clusters', 256)
recursive_set(c, 'is_multimer', True)
recursive_set(c, 'v2_feature', True)
recursive_set(c, 'gumbel_sample', True)
c.model.template.template_angle_embedder.d_in = 34
c.model.template.template_pair_stack.tri_attn_first = False
c.model.template.template_pointwise_attention.enabled = False
c.model.heads.pae.enabled = True
c.model.heads.experimentally_resolved.enabled = True
c.model.heads.masked_msa.d_out = 22
c.model.structure_module.separate_kv = True
c.model.structure_module.ipa_bias = False
c.model.structure_module.trans_scale_factor = 20
c.loss.pae.weight = 0.1
c.loss.violation.weight = 0.5
c.loss.experimentally_resolved.weight = 0.01
c.model.input_embedder.tf_dim = 21
c.globals.alphafold_original_mode = True
c.data.train.crop_size = 384
c.loss.repr_norm.weight = 0
c.loss.chain_centre_mass.weight = 1.0
recursive_set(c, 'outer_product_mean_first', True)
else:
raise ValueError(f'invalid --model-name: {name}.')
if train:
c.globals.chunk_size = None
recursive_set(c, 'inf', 3e4)
recursive_set(c, 'eps', 1e-5, 'loss')
return c
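
For context, here is a short sketch of how these helpers are typically combined. The calls below only use functions defined in this file; treating recursive_set as a user-facing override hook is an assumption about intended usage, not code from this commit.

from modelscope.models.science.unifold.config import model_config, recursive_set

config = model_config('multimer')                 # one of the names handled above
print(config.globals.chunk_size)                  # 4, from the chunk_size FieldReference
print(config.model.evoformer_stack.num_blocks)    # 48

# Recursively override every occurrence of a key, e.g. to shrink the extra-MSA stack:
recursive_set(config, 'max_extra_msa', 512)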

modelscope/models/science/unifold/data/__init__.py (+14 -0)

@@ -0,0 +1,14 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data pipeline for model features."""

modelscope/models/science/unifold/data/data_ops.py (+1397 -0)
File diff suppressed because it is too large


modelscope/models/science/unifold/data/msa_pairing.py (+526 -0)

@@ -0,0 +1,526 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data """

import collections
from typing import Dict, Iterable, List, Sequence

import numpy as np
import pandas as pd
import scipy.linalg

from .data_ops import NumpyDict
from .residue_constants import restypes_with_x_and_gap

MSA_GAP_IDX = restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9

MSA_PAD_VALUES = {
'msa_all_seq': MSA_GAP_IDX,
'msa_mask_all_seq': 1,
'deletion_matrix_all_seq': 0,
'deletion_matrix_int_all_seq': 0,
'msa': MSA_GAP_IDX,
'msa_mask': 1,
'deletion_matrix': 0,
'deletion_matrix_int': 0,
}

MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = (
'residue_index',
'aatype',
'all_atom_positions',
'all_atom_mask',
'seq_mask',
'between_segment_residues',
'has_alt_locations',
'has_hetatoms',
'asym_id',
'entity_id',
'sym_id',
'entity_mask',
'deletion_mean',
'prediction_atom_mask',
'literature_positions',
'atom_indices_to_group_indices',
'rigid_group_default_frame',
# zy
'num_sym',
)
TEMPLATE_FEATURES = (
'template_aatype',
'template_all_atom_positions',
'template_all_atom_mask',
)
CHAIN_FEATURES = ('num_alignments', 'seq_length')


def create_paired_features(chains: Iterable[NumpyDict], ) -> List[NumpyDict]:
"""Returns the original chains with paired NUM_SEQ features.

Args:
chains: A list of feature dictionaries for each chain.

Returns:
A list of feature dictionaries with sequence features including only
rows to be paired.
"""
chains = list(chains)
chain_keys = chains[0].keys()

if len(chains) < 2:
return chains
else:
updated_chains = []
paired_chains_to_paired_row_indices = pair_sequences(chains)
paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices)

for chain_num, chain in enumerate(chains):
new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
for feature_name in chain_keys:
if feature_name.endswith('_all_seq'):
feats_padded = pad_features(chain[feature_name],
feature_name)
new_chain[feature_name] = feats_padded[
paired_rows[:, chain_num]]
new_chain['num_alignments_all_seq'] = np.asarray(
len(paired_rows[:, chain_num]))
updated_chains.append(new_chain)
return updated_chains


def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
"""Add a 'padding' row at the end of the features list.

The padding row will be selected as a 'paired' row in the case of partial
alignment - for the chain that doesn't have paired alignment.

Args:
feature: The feature to be padded.
feature_name: The name of the feature to be padded.

Returns:
The feature with an additional padding row.
"""
assert feature.dtype != np.dtype(np.string_)
if feature_name in (
'msa_all_seq',
'msa_mask_all_seq',
'deletion_matrix_all_seq',
'deletion_matrix_int_all_seq',
):
num_res = feature.shape[1]
padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
feature.dtype)
elif feature_name == 'msa_species_identifiers_all_seq':
padding = [b'']
else:
return feature
feats_padded = np.concatenate([feature, padding], axis=0)
return feats_padded


def _make_msa_df(chain_features: NumpyDict) -> pd.DataFrame:
"""Makes dataframe with msa features needed for msa pairing."""
chain_msa = chain_features['msa_all_seq']
query_seq = chain_msa[0]
per_seq_similarity = np.sum(
query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
msa_df = pd.DataFrame({
'msa_species_identifiers':
chain_features['msa_species_identifiers_all_seq'],
'msa_row':
np.arange(len(chain_features['msa_species_identifiers_all_seq'])),
'msa_similarity':
per_seq_similarity,
'gap':
per_seq_gap,
})
return msa_df


def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
"""Creates mapping from species to msa dataframe of that species."""
species_lookup = {}
for species, species_df in msa_df.groupby('msa_species_identifiers'):
species_lookup[species] = species_df
return species_lookup


def _match_rows_by_sequence_similarity(
this_species_msa_dfs: List[pd.DataFrame], ) -> List[List[int]]: # noqa
"""Finds MSA sequence pairings across chains based on sequence similarity.

Each chain's MSA sequences are first sorted by their sequence similarity to
their respective target sequence. The sequences are then paired, starting
from the sequences most similar to their target sequence.

Args:
this_species_msa_dfs: a list of dataframes containing MSA features for
sequences for a specific species.

Returns:
A list of lists, each containing M indices corresponding to paired MSA rows,
where M is the number of chains.
"""
all_paired_msa_rows = []

num_seqs = [
len(species_df) for species_df in this_species_msa_dfs
if species_df is not None
]
take_num_seqs = np.min(num_seqs)

# sort_by_similarity = lambda x: x.sort_values(
# 'msa_similarity', axis=0, ascending=False)

def sort_by_similarity(x):
return x.sort_values('msa_similarity', axis=0, ascending=False)

for species_df in this_species_msa_dfs:
if species_df is not None:
species_df_sorted = sort_by_similarity(species_df)
msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
else:
msa_rows = [-1] * take_num_seqs # take the last 'padding' row
all_paired_msa_rows.append(msa_rows)
all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
return all_paired_msa_rows


def pair_sequences(examples: List[NumpyDict]) -> Dict[int, np.ndarray]:
"""Returns indices for paired MSA sequences across chains."""

num_examples = len(examples)

all_chain_species_dict = []
common_species = set()
for chain_features in examples:
msa_df = _make_msa_df(chain_features)
species_dict = _create_species_dict(msa_df)
all_chain_species_dict.append(species_dict)
common_species.update(set(species_dict))

common_species = sorted(common_species)
common_species.remove(b'') # Remove target sequence species.

all_paired_msa_rows = [np.zeros(len(examples), int)]
all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

for species in common_species:
if not species:
continue
this_species_msa_dfs = []
species_dfs_present = 0
for species_dict in all_chain_species_dict:
if species in species_dict:
this_species_msa_dfs.append(species_dict[species])
species_dfs_present += 1
else:
this_species_msa_dfs.append(None)

# Skip species that are present in only one chain.
if species_dfs_present <= 1:
continue

if np.any(
np.array([
len(species_df) for species_df in this_species_msa_dfs
if isinstance(species_df, pd.DataFrame)
]) > 600):
continue

paired_msa_rows = _match_rows_by_sequence_similarity(
this_species_msa_dfs)
all_paired_msa_rows.extend(paired_msa_rows)
all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
all_paired_msa_rows_dict = {
num_examples: np.array(paired_msa_rows)
for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
}
return all_paired_msa_rows_dict


def reorder_paired_rows(
all_paired_msa_rows_dict: Dict[int, np.ndarray]) -> np.ndarray:
"""Creates a list of indices of paired MSA rows across chains.

Args:
all_paired_msa_rows_dict: a mapping from the number of paired chains to the
paired indices.

Returns:
a list of lists, each containing indices of paired MSA rows across chains.
The paired-index lists are ordered by:
1) the number of chains in the paired alignment, i.e., all-chain pairings
will come first.
2) e-values
"""
all_paired_msa_rows = []

for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
paired_rows = all_paired_msa_rows_dict[num_pairings]
paired_rows_product = np.abs(
np.array(
[np.prod(rows.astype(np.float64)) for rows in paired_rows]))
paired_rows_sort_index = np.argsort(paired_rows_product)
all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])

return np.array(all_paired_msa_rows)


def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
"""Like scipy.linalg.block_diag but with an optional padding value."""
ones_arrs = [np.ones_like(x) for x in arrs]
off_diag_mask = 1 - scipy.linalg.block_diag(*ones_arrs)
diag = scipy.linalg.block_diag(*arrs)
diag += (off_diag_mask * pad_value).astype(diag.dtype)
return diag
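

# Illustrative example (not part of the original source): with a pad value,
# the off-diagonal blocks are filled with pad_value instead of zeros, e.g.
#   block_diag(np.ones((2, 2)), np.ones((1, 3)), pad_value=-1.0)
# gives a (3, 5) array whose top-left 2x2 and bottom-right 1x3 blocks are ones
# and whose remaining entries are all -1.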


def _correct_post_merged_feats(np_example: NumpyDict,
np_chains_list: Sequence[NumpyDict],
pair_msa_sequences: bool) -> NumpyDict:
"""Adds features that need to be computed/recomputed post merging."""

np_example['seq_length'] = np.asarray(
np_example['aatype'].shape[0], dtype=np.int32)
np_example['num_alignments'] = np.asarray(
np_example['msa'].shape[0], dtype=np.int32)

if not pair_msa_sequences:
# Generate a bias that is 1 for the first row of every block in the
# block diagonal MSA - i.e. make sure the cluster stack always includes
# the query sequences for each chain (since the first row is the query
# sequence).
cluster_bias_masks = []
for chain in np_chains_list:
mask = np.zeros(chain['msa'].shape[0])
mask[0] = 1
cluster_bias_masks.append(mask)
np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

# Initialize Bert mask with masked out off diagonals.
msa_masks = [
np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list
]

np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0)
else:
np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
np_example['cluster_bias_mask'][0] = 1

# Initialize Bert mask with masked out off diagonals.
msa_masks = [
np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list
]
msa_masks_all_seq = [
np.ones(x['msa_all_seq'].shape, dtype=np.int8)
for x in np_chains_list
]

msa_mask_block_diag = block_diag(*msa_masks, pad_value=0)
msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
np_example['bert_mask'] = np.concatenate(
[msa_mask_all_seq, msa_mask_block_diag], axis=0)
return np_example


def _pad_templates(chains: Sequence[NumpyDict],
max_templates: int) -> Sequence[NumpyDict]:
"""For each chain pad the number of templates to a fixed size.

Args:
chains: A list of protein chains.
max_templates: Each chain will be padded to have this many templates.

Returns:
The list of chains, updated to have template features padded to
max_templates.
"""
for chain in chains:
for k, v in chain.items():
if k in TEMPLATE_FEATURES:
padding = np.zeros_like(v.shape)
padding[0] = max_templates - v.shape[0]
padding = [(0, p) for p in padding]
chain[k] = np.pad(v, padding, mode='constant')
return chains


def _merge_features_from_multiple_chains(
chains: Sequence[NumpyDict], pair_msa_sequences: bool) -> NumpyDict:
"""Merge features from multiple chains.

Args:
chains: A list of feature dictionaries that we want to merge.
pair_msa_sequences: Whether to concatenate MSA features along the
num_res dimension (if True), or to block diagonalize them (if False).

Returns:
A feature dictionary for the merged example.
"""
merged_example = {}
for feature_name in chains[0]:
feats = [x[feature_name] for x in chains]
feature_name_split = feature_name.split('_all_seq')[0]
if feature_name_split in MSA_FEATURES:
if pair_msa_sequences or '_all_seq' in feature_name:
merged_example[feature_name] = np.concatenate(feats, axis=1)
if feature_name_split == 'msa':
merged_example['msa_chains_all_seq'] = np.ones(
merged_example[feature_name].shape[0]).reshape(-1, 1)
else:
merged_example[feature_name] = block_diag(
*feats, pad_value=MSA_PAD_VALUES[feature_name])
if feature_name_split == 'msa':
msa_chains = []
for i, feat in enumerate(feats):
cur_shape = feat.shape[0]
vals = np.ones(cur_shape) * (i + 2)
msa_chains.append(vals)
merged_example['msa_chains'] = np.concatenate(
msa_chains).reshape(-1, 1)
elif feature_name_split in SEQ_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=0)
elif feature_name_split in TEMPLATE_FEATURES:
merged_example[feature_name] = np.concatenate(feats, axis=1)
elif feature_name_split in CHAIN_FEATURES:
merged_example[feature_name] = np.sum(feats).astype(np.int32)
else:
merged_example[feature_name] = feats[0]
return merged_example


def _merge_homomers_dense_msa(
chains: Iterable[NumpyDict]) -> Sequence[NumpyDict]:
"""Merge all identical chains, making the resulting MSA dense.

Args:
chains: An iterable of features for each chain.

Returns:
A list of feature dictionaries. All features with the same entity_id
will be merged - MSA features will be concatenated along the num_res
dimension - making them dense.
"""
entity_chains = collections.defaultdict(list)
for chain in chains:
entity_id = chain['entity_id'][0]
entity_chains[entity_id].append(chain)

grouped_chains = []
for entity_id in sorted(entity_chains):
chains = entity_chains[entity_id]
grouped_chains.append(chains)
chains = [
_merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
for chains in grouped_chains
]
return chains


def _concatenate_paired_and_unpaired_features(example: NumpyDict) -> NumpyDict:
"""Merges paired and block-diagonalised features."""
features = MSA_FEATURES + ('msa_chains', )
for feature_name in features:
if feature_name in example:
feat = example[feature_name]
feat_all_seq = example[feature_name + '_all_seq']
try:
merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
except Exception as ex:
raise Exception(
'concat failed.',
feature_name,
feat_all_seq.shape,
feat.shape,
ex.__class__,
ex,
)
example[feature_name] = merged_feat
example['num_alignments'] = np.array(
example['msa'].shape[0], dtype=np.int32)
return example


def merge_chain_features(np_chains_list: List[NumpyDict],
pair_msa_sequences: bool,
max_templates: int) -> NumpyDict:
"""Merges features for multiple chains to single FeatureDict.

Args:
np_chains_list: List of FeatureDicts for each chain.
pair_msa_sequences: Whether to merge paired MSAs.
max_templates: The maximum number of templates to include.

Returns:
Single FeatureDict for entire complex.
"""
np_chains_list = _pad_templates(
np_chains_list, max_templates=max_templates)
np_chains_list = _merge_homomers_dense_msa(np_chains_list)
# Unpaired MSA features will always be block-diagonalised; paired MSA
# features will be concatenated.
np_example = _merge_features_from_multiple_chains(
np_chains_list, pair_msa_sequences=False)
if pair_msa_sequences:
np_example = _concatenate_paired_and_unpaired_features(np_example)
np_example = _correct_post_merged_feats(
np_example=np_example,
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences,
)

return np_example


def deduplicate_unpaired_sequences(
np_chains: List[NumpyDict]) -> List[NumpyDict]:
"""Removes unpaired sequences which duplicate a paired sequence."""

feature_names = np_chains[0].keys()
msa_features = MSA_FEATURES
cache_msa_features = {}
for chain in np_chains:
entity_id = int(chain['entity_id'][0])
if entity_id not in cache_msa_features:
sequence_set = set(s.tobytes() for s in chain['msa_all_seq'])
keep_rows = []
# Go through unpaired MSA seqs and remove any rows that correspond to the
# sequences that are already present in the paired MSA.
for row_num, seq in enumerate(chain['msa']):
if seq.tobytes() not in sequence_set:
keep_rows.append(row_num)
new_msa_features = {}
for feature_name in feature_names:
if feature_name in msa_features:
if keep_rows:
new_msa_features[feature_name] = chain[feature_name][
keep_rows]
else:
new_shape = list(chain[feature_name].shape)
new_shape[0] = 0
new_msa_features[feature_name] = np.zeros(
new_shape, dtype=chain[feature_name].dtype)
cache_msa_features[entity_id] = new_msa_features
for feature_name in cache_msa_features[entity_id]:
chain[feature_name] = cache_msa_features[entity_id][feature_name]
chain['num_alignments'] = np.array(
chain['msa'].shape[0], dtype=np.int32)
return np_chains
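
Taken together, the helpers in this file implement cross-chain MSA pairing for multimer inputs. A condensed sketch of the call order is below; it mirrors pair_and_merge in process_multimer.py further down in this commit, and the wrapper function name is illustrative only.

from modelscope.models.science.unifold.data import msa_pairing

def pair_and_merge_chains(np_chains_list, max_templates=4):
    # 1) Pair MSA rows of the same species across chains.
    chains = msa_pairing.create_paired_features(chains=np_chains_list)
    # 2) Drop unpaired rows that duplicate an already-paired sequence.
    chains = msa_pairing.deduplicate_unpaired_sequences(chains)
    # 3) Concatenate paired MSAs and block-diagonalise the unpaired remainder.
    return msa_pairing.merge_chain_features(
        np_chains_list=chains,
        pair_msa_sequences=True,
        max_templates=max_templates)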

modelscope/models/science/unifold/data/process.py (+264 -0)

@@ -0,0 +1,264 @@
# The Uni-Fold implementation is also open-sourced by its authors under the Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Optional

import numpy as np
import torch

from modelscope.models.science.unifold.data import data_ops


def nonensembled_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
v2_feature = common_cfg.v2_feature
operators = []
if mode_cfg.random_delete_msa:
operators.append(
data_ops.random_delete_msa(common_cfg.random_delete_msa))
operators.extend([
data_ops.cast_to_64bit_ints,
data_ops.correct_msa_restypes,
data_ops.squeeze_features,
data_ops.randomly_replace_msa_with_unknown(0.0),
data_ops.make_seq_mask,
data_ops.make_msa_mask,
])
operators.append(data_ops.make_hhblits_profile_v2
if v2_feature else data_ops.make_hhblits_profile)
if common_cfg.use_templates:
operators.extend([
data_ops.make_template_mask,
data_ops.make_pseudo_beta('template_'),
])
operators.append(
data_ops.crop_templates(
max_templates=mode_cfg.max_templates,
subsample_templates=mode_cfg.subsample_templates,
))

if common_cfg.use_template_torsion_angles:
operators.extend([
data_ops.atom37_to_torsion_angles('template_'),
])

operators.append(data_ops.make_atom14_masks)
operators.append(data_ops.make_target_feat)

return operators


def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed):
operators = []
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
pad_msa_clusters = mode_cfg.max_msa_clusters
crop_feats = dict(common_cfg.features)
if mode_cfg.fixed_size:
if mode_cfg.crop:
if common_cfg.is_multimer:
crop_fn = data_ops.crop_to_size_multimer(
crop_size=mode_cfg.crop_size,
shape_schema=crop_feats,
seed=crop_and_fix_size_seed,
spatial_crop_prob=mode_cfg.spatial_crop_prob,
ca_ca_threshold=mode_cfg.ca_ca_threshold,
)
else:
crop_fn = data_ops.crop_to_size_single(
crop_size=mode_cfg.crop_size,
shape_schema=crop_feats,
seed=crop_and_fix_size_seed,
)
operators.append(crop_fn)

operators.append(data_ops.select_feat(crop_feats))

operators.append(
data_ops.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates,
))
return operators


def ensembled_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that can be ensembled and averaged."""
operators = []
multimer_mode = common_cfg.is_multimer
v2_feature = common_cfg.v2_feature
# multimer mode does not use block-wise MSA deletion
if mode_cfg.block_delete_msa and not multimer_mode:
operators.append(
data_ops.block_delete_msa(common_cfg.block_delete_msa))
if 'max_distillation_msa_clusters' in mode_cfg:
operators.append(
data_ops.sample_msa_distillation(
mode_cfg.max_distillation_msa_clusters))

if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
pad_msa_clusters = mode_cfg.max_msa_clusters

max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa

assert common_cfg.resample_msa_in_recycling
gumbel_sample = common_cfg.gumbel_sample
operators.append(
data_ops.sample_msa(
max_msa_clusters,
keep_extra=True,
gumbel_sample=gumbel_sample,
biased_msa_by_chain=mode_cfg.biased_msa_by_chain,
))

if 'masked_msa' in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
operators.append(
data_ops.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction,
gumbel_sample=gumbel_sample,
share_mask=mode_cfg.share_mask,
))

if common_cfg.msa_cluster_features:
if v2_feature:
operators.append(data_ops.nearest_neighbor_clusters_v2())
else:
operators.append(data_ops.nearest_neighbor_clusters())
operators.append(data_ops.summarize_clusters)

if v2_feature:
operators.append(data_ops.make_msa_feat_v2)
else:
operators.append(data_ops.make_msa_feat)
# Crop after creating the cluster profiles.
if max_extra_msa:
if v2_feature:
operators.append(data_ops.make_extra_msa_feat(max_extra_msa))
else:
operators.append(data_ops.crop_extra_msa(max_extra_msa))
else:
operators.append(data_ops.delete_extra_msa)
# operators.append(data_operators.select_feat(common_cfg.recycling_features))
return operators


def process_features(tensors, common_cfg, mode_cfg):
"""Based on the config, apply filters and transformations to the data."""
is_distillation = bool(tensors.get('is_distillation', 0))
multimer_mode = common_cfg.is_multimer
crop_and_fix_size_seed = int(tensors['crop_and_fix_size_seed'])
crop_fn = crop_and_fix_size_fns(
common_cfg,
mode_cfg,
crop_and_fix_size_seed,
)

def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_fns(
common_cfg,
mode_cfg,
)
new_d = compose(fns)(d)
if not multimer_mode or is_distillation:
new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d)
return compose(crop_fn)(new_d)
else: # select after crop for spatial cropping
d = compose(crop_fn)(d)
d = data_ops.select_feat(common_cfg.recycling_features)(d)
return d

nonensembled = nonensembled_fns(common_cfg, mode_cfg)

if mode_cfg.supervised and (not multimer_mode or is_distillation):
nonensembled.extend(label_transform_fn())

tensors = compose(nonensembled)(tensors)

num_recycling = int(tensors['num_recycling_iters']) + 1
num_ensembles = mode_cfg.num_ensembles

ensemble_tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x),
torch.arange(num_recycling * num_ensembles),
)
tensors = compose(crop_fn)(tensors)
# add a dummy dim to align with recycling features
tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors}
tensors.update(ensemble_tensors)
return tensors


@data_ops.curry1
def compose(x, fs):
for f in fs:
x = f(x)
return x


def pad_then_stack(values, ):
if len(values[0].shape) >= 1:
size = max(v.shape[0] for v in values)
new_values = []
for v in values:
if v.shape[0] < size:
res = values[0].new_zeros(size, *v.shape[1:])
res[:v.shape[0], ...] = v
else:
res = v
new_values.append(res)
else:
new_values = values
return torch.stack(new_values, dim=0)


def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = pad_then_stack(
[dict_i[feat] for dict_i in ensembles])
return ensembled_dict


def process_single_label(label: dict,
num_ensemble: Optional[int] = None) -> dict:
assert 'aatype' in label
assert 'all_atom_positions' in label
assert 'all_atom_mask' in label
label = compose(label_transform_fn())(label)
if num_ensemble is not None:
label = {
k: torch.stack([v for _ in range(num_ensemble)])
for k, v in label.items()
}
return label


def process_labels(labels_list, num_ensemble: Optional[int] = None):
return [process_single_label(ll, num_ensemble) for ll in labels_list]


def label_transform_fn():
return [
data_ops.make_atom14_masks,
data_ops.make_atom14_positions,
data_ops.atom37_to_frames,
data_ops.atom37_to_torsion_angles(''),
data_ops.make_pseudo_beta(''),
data_ops.get_backbone_frames,
data_ops.get_chi_angles,
]
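
As a small illustration of the padding behaviour used when stacking per-iteration (recycling/ensemble) features of different MSA depths, the toy tensors below are made up:

import torch
from modelscope.models.science.unifold.data.process import pad_then_stack

a = torch.ones(3, 2)
b = torch.ones(5, 2)
stacked = pad_then_stack([a, b])
print(stacked.shape)  # torch.Size([2, 5, 2]); the shorter tensor is zero-padded along dim 0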

modelscope/models/science/unifold/data/process_multimer.py (+417 -0)

@@ -0,0 +1,417 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature processing logic for multimer data """

import collections
from typing import Iterable, List, MutableMapping

import numpy as np

from modelscope.models.science.unifold.data import (msa_pairing,
residue_constants)
from .utils import correct_template_restypes

FeatureDict = MutableMapping[str, np.ndarray]

REQUIRED_FEATURES = frozenset({
'aatype',
'all_atom_mask',
'all_atom_positions',
'all_chains_entity_ids',
'all_crops_all_chains_mask',
'all_crops_all_chains_positions',
'all_crops_all_chains_residue_ids',
'assembly_num_chains',
'asym_id',
'bert_mask',
'cluster_bias_mask',
'deletion_matrix',
'deletion_mean',
'entity_id',
'entity_mask',
'mem_peak',
'msa',
'msa_mask',
'num_alignments',
'num_templates',
'queue_size',
'residue_index',
'resolution',
'seq_length',
'seq_mask',
'sym_id',
'template_aatype',
'template_all_atom_mask',
'template_all_atom_positions',
# zy added:
'asym_len',
'template_sum_probs',
'num_sym',
'msa_chains',
})

MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048


def _is_homomer_or_monomer(chains: Iterable[FeatureDict]) -> bool:
"""Checks if a list of chains represents a homomer/monomer example."""
# Note that an entity_id of 0 indicates padding.
num_unique_chains = len(
np.unique(
np.concatenate([
np.unique(chain['entity_id'][chain['entity_id'] > 0])
for chain in chains
])))
return num_unique_chains == 1


def pair_and_merge(
all_chain_features: MutableMapping[str, FeatureDict]) -> FeatureDict:
"""Runs processing on features to augment, pair and merge.

Args:
all_chain_features: A MutableMapping of feature dictionaries, one per chain.

Returns:
A dictionary of features.
"""

process_unmerged_features(all_chain_features)

np_chains_list = all_chain_features

pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list)

if pair_msa_sequences:
np_chains_list = msa_pairing.create_paired_features(
chains=np_chains_list)
np_chains_list = msa_pairing.deduplicate_unpaired_sequences(
np_chains_list)
np_chains_list = crop_chains(
np_chains_list,
msa_crop_size=MSA_CROP_SIZE,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES,
)
np_example = msa_pairing.merge_chain_features(
np_chains_list=np_chains_list,
pair_msa_sequences=pair_msa_sequences,
max_templates=MAX_TEMPLATES,
)
np_example = process_final(np_example)
return np_example


def crop_chains(
chains_list: List[FeatureDict],
msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int,
) -> List[FeatureDict]:
"""Crops the MSAs for a set of chains.

Args:
chains_list: A list of chains to be cropped.
msa_crop_size: The total number of sequences to crop from the MSA.
pair_msa_sequences: Whether we are operating in sequence-pairing mode.
max_templates: The maximum templates to use per chain.

Returns:
The cropped chains.
"""

# Apply the cropping.
cropped_chains = []
for chain in chains_list:
cropped_chain = _crop_single_chain(
chain,
msa_crop_size=msa_crop_size,
pair_msa_sequences=pair_msa_sequences,
max_templates=max_templates,
)
cropped_chains.append(cropped_chain)

return cropped_chains


def _crop_single_chain(chain: FeatureDict, msa_crop_size: int,
pair_msa_sequences: bool,
max_templates: int) -> FeatureDict:
"""Crops msa sequences to `msa_crop_size`."""
msa_size = chain['num_alignments']

if pair_msa_sequences:
msa_size_all_seq = chain['num_alignments_all_seq']
msa_crop_size_all_seq = np.minimum(msa_size_all_seq,
msa_crop_size // 2)

# We reduce the number of un-paired sequences by the number of times a
# sequence from this chain's MSA is included in the paired MSA. This keeps
# the MSA size for each chain roughly constant.
msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :]
num_non_gapped_pairs = np.sum(
np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1))
num_non_gapped_pairs = np.minimum(num_non_gapped_pairs,
msa_crop_size_all_seq)

# Restrict the unpaired crop size so that paired+unpaired sequences do not
# exceed msa_seqs_per_chain for each chain.
max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0)
msa_crop_size = np.minimum(msa_size, max_msa_crop_size)
else:
msa_crop_size = np.minimum(msa_size, msa_crop_size)

include_templates = 'template_aatype' in chain and max_templates
if include_templates:
num_templates = chain['template_aatype'].shape[0]
templates_crop_size = np.minimum(num_templates, max_templates)

for k in chain:
k_split = k.split('_all_seq')[0]
if k_split in msa_pairing.TEMPLATE_FEATURES:
chain[k] = chain[k][:templates_crop_size, :]
elif k_split in msa_pairing.MSA_FEATURES:
if '_all_seq' in k and pair_msa_sequences:
chain[k] = chain[k][:msa_crop_size_all_seq, :]
else:
chain[k] = chain[k][:msa_crop_size, :]

chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32)
if include_templates:
chain['num_templates'] = np.asarray(
templates_crop_size, dtype=np.int32)
if pair_msa_sequences:
chain['num_alignments_all_seq'] = np.asarray(
msa_crop_size_all_seq, dtype=np.int32)
return chain


def process_final(np_example: FeatureDict) -> FeatureDict:
"""Final processing steps in data pipeline, after merging and pairing."""
np_example = _make_seq_mask(np_example)
np_example = _make_msa_mask(np_example)
np_example = _filter_features(np_example)
return np_example


def _make_seq_mask(np_example):
np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32)
return np_example


def _make_msa_mask(np_example):
"""Mask features are all ones, but will later be zero-padded."""

np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.int8)

seq_mask = (np_example['entity_id'] > 0).astype(np.int8)
np_example['msa_mask'] *= seq_mask[None]

return np_example


def _filter_features(np_example: FeatureDict) -> FeatureDict:
"""Filters features of example to only those requested."""
return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES}


def process_unmerged_features(all_chain_features: MutableMapping[str,
FeatureDict]):
"""Postprocessing stage for per-chain features before merging."""
num_chains = len(all_chain_features)
for chain_features in all_chain_features:
# Convert deletion matrices to float.
if 'deletion_matrix_int' in chain_features:
chain_features['deletion_matrix'] = np.asarray(
chain_features.pop('deletion_matrix_int'), dtype=np.float32)
if 'deletion_matrix_int_all_seq' in chain_features:
chain_features['deletion_matrix_all_seq'] = np.asarray(
chain_features.pop('deletion_matrix_int_all_seq'),
dtype=np.float32)

chain_features['deletion_mean'] = np.mean(
chain_features['deletion_matrix'], axis=0)

if 'all_atom_positions' not in chain_features:
# Add all_atom_mask and dummy all_atom_positions based on aatype.
all_atom_mask = residue_constants.STANDARD_ATOM_MASK[
chain_features['aatype']]
chain_features['all_atom_mask'] = all_atom_mask
chain_features['all_atom_positions'] = np.zeros(
list(all_atom_mask.shape) + [3])

# Add assembly_num_chains.
chain_features['assembly_num_chains'] = np.asarray(num_chains)

# Add entity_mask.
for chain_features in all_chain_features:
chain_features['entity_mask'] = (
chain_features['entity_id'] != # noqa W504
0).astype(np.int32)


def empty_template_feats(n_res):
return {
'template_aatype':
np.zeros((0, n_res)).astype(np.int64),
'template_all_atom_positions':
np.zeros((0, n_res, 37, 3)).astype(np.float32),
'template_sum_probs':
np.zeros((0, 1)).astype(np.float32),
'template_all_atom_mask':
np.zeros((0, n_res, 37)).astype(np.float32),
}


def convert_monomer_features(monomer_features: FeatureDict) -> FeatureDict:
"""Reshapes and modifies monomer features for multimer models."""
if monomer_features['template_aatype'].shape[0] == 0:
monomer_features.update(
empty_template_feats(monomer_features['aatype'].shape[0]))
converted = {}
unnecessary_leading_dim_feats = {
'sequence',
'domain_name',
'num_alignments',
'seq_length',
}
for feature_name, feature in monomer_features.items():
if feature_name in unnecessary_leading_dim_feats:
# asarray ensures it's a np.ndarray.
feature = np.asarray(feature[0], dtype=feature.dtype)
elif feature_name == 'aatype':
# The multimer model performs the one-hot operation itself.
feature = np.argmax(feature, axis=-1).astype(np.int32)
elif feature_name == 'template_aatype':
if feature.shape[0] > 0:
feature = correct_template_restypes(feature)
elif feature_name == 'template_all_atom_masks':
feature_name = 'template_all_atom_mask'
elif feature_name == 'msa':
feature = feature.astype(np.uint8)

if feature_name.endswith('_mask'):
feature = feature.astype(np.float32)

converted[feature_name] = feature

if 'deletion_matrix_int' in monomer_features:
monomer_features['deletion_matrix'] = monomer_features.pop(
'deletion_matrix_int').astype(np.float32)

converted.pop(
'template_sum_probs'
) # zy: this input is observed to have an inconsistent shape. TODO: figure out why and fix it.
return converted


def int_id_to_str_id(num: int) -> str:
"""Encodes a number as a string, using reverse spreadsheet style naming.

Args:
num: A positive integer.

Returns:
A string that encodes the positive integer using reverse spreadsheet style
naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
usual way to encode chain IDs in mmCIF files.
"""
if num <= 0:
raise ValueError(f'Only positive integers allowed, got {num}.')

num = num - 1 # 1-based indexing.
output = []
while num >= 0:
output.append(chr(num % 26 + ord('A')))
num = num // 26 - 1
return ''.join(output)


def add_assembly_features(all_chain_features, ):
"""Add features to distinguish between chains.

Args:
all_chain_features: A dictionary which maps chain_id to a dictionary of
features for each chain.

Returns:
all_chain_features: A list of per-chain feature dictionaries, each augmented
with `asym_id`, `sym_id`, `entity_id`, and `num_sym` features that
distinguish the chains. E.g. the two chains of a homodimer share an
entity_id but receive sym_id 1 and 2; the chains of a heterodimer get
distinct entity_ids and sym_id 1 each.
"""
# Group the chains by sequence
seq_to_entity_id = {}
grouped_chains = collections.defaultdict(list)
for chain_features in all_chain_features:
assert 'sequence' in chain_features
seq = str(chain_features['sequence'])
if seq not in seq_to_entity_id:
seq_to_entity_id[seq] = len(seq_to_entity_id) + 1
grouped_chains[seq_to_entity_id[seq]].append(chain_features)

new_all_chain_features = []
chain_id = 1
for entity_id, group_chain_features in grouped_chains.items():
num_sym = len(group_chain_features) # zy
for sym_id, chain_features in enumerate(group_chain_features, start=1):
seq_length = chain_features['seq_length']
chain_features['asym_id'] = chain_id * np.ones(seq_length)
chain_features['sym_id'] = sym_id * np.ones(seq_length)
chain_features['entity_id'] = entity_id * np.ones(seq_length)
chain_features['num_sym'] = num_sym * np.ones(seq_length)
chain_id += 1
new_all_chain_features.append(chain_features)

return new_all_chain_features
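

# Illustrative example (not part of the original source): for a heterodimer
# with chains of sequences A and B, the loop above assigns per-residue features
#   chain 1: asym_id=1, entity_id=1, sym_id=1, num_sym=1
#   chain 2: asym_id=2, entity_id=2, sym_id=1, num_sym=1
# whereas a homodimer (two copies of A) shares entity_id=1 and num_sym=2 and
# gets asym_id/sym_id 1 and 2 for the two copies.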


def pad_msa(np_example, min_num_seq):
np_example = dict(np_example)
num_seq = np_example['msa'].shape[0]
if num_seq < min_num_seq:
for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask',
'msa_chains'):
np_example[feat] = np.pad(np_example[feat],
((0, min_num_seq - num_seq), (0, 0)))
np_example['cluster_bias_mask'] = np.pad(
np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq), ))
return np_example


def post_process(np_example):
np_example = pad_msa(np_example, 512)
no_dim_keys = [
'num_alignments',
'assembly_num_chains',
'num_templates',
'seq_length',
'resolution',
]
for k in no_dim_keys:
if k in np_example:
np_example[k] = np_example[k].reshape(-1)
return np_example


def merge_msas(msa, del_mat, new_msa, new_del_mat):
cur_msa_set = set([tuple(m) for m in msa])
new_rows = []
for i, s in enumerate(new_msa):
if tuple(s) not in cur_msa_set:
new_rows.append(i)
ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0)
ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0)
return ret_msa, ret_del_mat
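# Minimal sketch with hypothetical inputs: rows of new_msa that already occur
# verbatim in msa are skipped, so
#   msa = [[0, 1], [2, 3]], new_msa = [[2, 3], [4, 5]]
# yields ret_msa = [[0, 1], [2, 3], [4, 5]], and the matching rows of
# new_del_mat are appended to del_mat in the same way.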

+ 322 - 0   modelscope/models/science/unifold/data/protein.py

@@ -0,0 +1,322 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Protein data type."""
import dataclasses
import io
from typing import Any, Mapping, Optional

import numpy as np
from Bio.PDB import PDBParser

from modelscope.models.science.unifold.data import residue_constants

FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any] # Is a nested dict.

# Complete sequence of chain IDs supported by the PDB format.
PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.


@dataclasses.dataclass(frozen=True)
class Protein:
"""Protein structure representation."""

# Cartesian coordinates of atoms in angstroms. The atom types correspond to
# residue_constants.atom_types, i.e. the first three are N, CA, CB.
atom_positions: np.ndarray # [num_res, num_atom_type, 3]

# Amino-acid type for each residue represented as an integer between 0 and
# 20, where 20 is 'X'.
aatype: np.ndarray # [num_res]

# Binary float mask to indicate presence of a particular atom. 1.0 if an atom
# is present and 0.0 if not. This should be used for loss masking.
atom_mask: np.ndarray # [num_res, num_atom_type]

# Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
residue_index: np.ndarray # [num_res]

# 0-indexed number corresponding to the chain in the protein that this residue
# belongs to.
chain_index: np.ndarray # [num_res]

# B-factors, or temperature factors, of each residue (in sq. angstroms units),
# representing the displacement of the residue from its ground truth mean
# value.
b_factors: np.ndarray # [num_res, num_atom_type]

def __post_init__(self):
if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
raise ValueError(
f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
'because these cannot be written to PDB format.')


def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
"""Takes a PDB string and constructs a Protein object.

WARNING: All non-standard residue types will be converted into UNK. All
non-standard atoms will be ignored.

Args:
pdb_str: The contents of the pdb file
chain_id: If chain_id is specified (e.g. A), then only that chain
is parsed. Otherwise all chains are parsed.

Returns:
A new `Protein` parsed from the pdb contents.
"""
pdb_fh = io.StringIO(pdb_str)
parser = PDBParser(QUIET=True)
structure = parser.get_structure('none', pdb_fh)
models = list(structure.get_models())
if len(models) != 1:
raise ValueError(
f'Only single model PDBs are supported. Found {len(models)} models.'
)
model = models[0]

atom_positions = []
aatype = []
atom_mask = []
residue_index = []
chain_ids = []
b_factors = []

for chain in model:
if chain_id is not None and chain.id != chain_id:
continue
for res in chain:
if res.id[2] != ' ':
raise ValueError(
f'PDB contains an insertion code at chain {chain.id} and residue '
f'index {res.id[1]}. These are not supported.')
res_shortname = residue_constants.restype_3to1.get(
res.resname, 'X')
restype_idx = residue_constants.restype_order.get(
res_shortname, residue_constants.restype_num)
pos = np.zeros((residue_constants.atom_type_num, 3))
mask = np.zeros((residue_constants.atom_type_num, ))
res_b_factors = np.zeros((residue_constants.atom_type_num, ))
for atom in res:
if atom.name not in residue_constants.atom_types:
continue
pos[residue_constants.atom_order[atom.name]] = atom.coord
mask[residue_constants.atom_order[atom.name]] = 1.0
res_b_factors[residue_constants.atom_order[
atom.name]] = atom.bfactor
if np.sum(mask) < 0.5:
# If no known atom positions are reported for the residue then skip it.
continue
aatype.append(restype_idx)
atom_positions.append(pos)
atom_mask.append(mask)
residue_index.append(res.id[1])
chain_ids.append(chain.id)
b_factors.append(res_b_factors)

# Chain IDs are usually characters so map these to ints.
unique_chain_ids = np.unique(chain_ids)
chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

return Protein(
atom_positions=np.array(atom_positions),
atom_mask=np.array(atom_mask),
aatype=np.array(aatype),
residue_index=np.array(residue_index),
chain_index=chain_index,
b_factors=np.array(b_factors),
)


def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
    chain_end = 'TER'
    # PDB TER record: resName starts at column 18, chainID at column 22.
    return (f'{chain_end:<6}{atom_index:>5}      {end_resname:>3} '
            f'{chain_name:>1}{residue_index:>4}')


def to_pdb(prot: Protein) -> str:
"""Converts a `Protein` instance to a PDB string.

Args:
prot: The protein to convert to PDB.

Returns:
PDB string.
"""
restypes = residue_constants.restypes + ['X']

# res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
def res_1to3(r):
return residue_constants.restype_1to3.get(restypes[r], 'UNK')

atom_types = residue_constants.atom_types

pdb_lines = []

atom_mask = prot.atom_mask
aatype = prot.aatype
atom_positions = prot.atom_positions
residue_index = prot.residue_index.astype(np.int32)
chain_index = prot.chain_index.astype(np.int32)
b_factors = prot.b_factors

if np.any(aatype > residue_constants.restype_num):
raise ValueError('Invalid aatypes.')

# Construct a mapping from chain integer indices to chain ID strings.
chain_ids = {}
for i in np.unique(chain_index): # np.unique gives sorted output.
if i >= PDB_MAX_CHAINS:
raise ValueError(
f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
chain_ids[i] = PDB_CHAIN_IDS[i]

    pdb_lines.append('MODEL     1')  # MODEL record: serial number in columns 11-14.
atom_index = 1
last_chain_index = chain_index[0]
# Add all atom sites.
for i in range(aatype.shape[0]):
# Close the previous chain if in a multichain PDB.
if last_chain_index != chain_index[i]:
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[i - 1]),
chain_ids[chain_index[i - 1]],
residue_index[i - 1],
))
last_chain_index = chain_index[i]
atom_index += 1 # Atom index increases at the TER symbol.

res_name_3 = res_1to3(aatype[i])
for atom_name, pos, mask, b_factor in zip(atom_types,
atom_positions[i],
atom_mask[i], b_factors[i]):
if mask < 0.5:
continue

record_type = 'ATOM'
name = atom_name if len(atom_name) == 4 else f' {atom_name}'
alt_loc = ''
insertion_code = ''
occupancy = 1.00
element = atom_name[
0] # Protein supports only C, N, O, S, this works.
charge = ''
# PDB is a columnar format, every space matters here!
            atom_line = (
                f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
                f'{residue_index[i]:>4}{insertion_code:>1}   '
                f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                f'{occupancy:>6.2f}{b_factor:>6.2f}          '
                f'{element:>2}{charge:>2}')
pdb_lines.append(atom_line)
atom_index += 1

# Close the final chain.
pdb_lines.append(
_chain_end(
atom_index,
res_1to3(aatype[-1]),
chain_ids[chain_index[-1]],
residue_index[-1],
))
pdb_lines.append('ENDMDL')
pdb_lines.append('END')

# Pad all lines to 80 characters.
pdb_lines = [line.ljust(80) for line in pdb_lines]
return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
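# Usage sketch (assuming `pdb_str` holds the text of a single-model PDB file):
#   prot = from_pdb_string(pdb_str)   # parse into a Protein
#   pdb_out = to_pdb(prot)            # serialize back to PDB text
# The round trip keeps only the heavy atoms listed in
# residue_constants.atom_types and remaps chain ids to 'A', 'B', ... in sorted
# order of the original ids.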


def ideal_atom_mask(prot: Protein) -> np.ndarray:
"""Computes an ideal atom mask.

`Protein.atom_mask` typically is defined according to the atoms that are
reported in the PDB. This function computes a mask according to heavy atoms
that should be present in the given sequence of amino acids.

Args:
prot: `Protein` whose fields are `numpy.ndarray` objects.

Returns:
An ideal atom mask.
"""
return residue_constants.STANDARD_ATOM_MASK[prot.aatype]


def from_prediction(features: FeatureDict,
result: ModelOutput,
b_factors: Optional[np.ndarray] = None) -> Protein:
"""Assembles a protein from a prediction.

Args:
features: Dictionary holding model inputs.
      result: Dictionary holding model outputs.
b_factors: (Optional) B-factors to use for the protein.

Returns:
A protein instance.
"""

if 'asym_id' in features:
chain_index = features['asym_id'] - 1
else:
chain_index = np.zeros_like((features['aatype']))

if b_factors is None:
b_factors = np.zeros_like(result['final_atom_mask'])

return Protein(
aatype=features['aatype'],
atom_positions=result['final_atom_positions'],
atom_mask=result['final_atom_mask'],
residue_index=features['residue_index'] + 1,
chain_index=chain_index,
b_factors=b_factors,
)


def from_feature(features: FeatureDict,
b_factors: Optional[np.ndarray] = None) -> Protein:
"""Assembles a standard pdb from input atom positions & mask.

Args:
features: Dictionary holding model inputs.
b_factors: (Optional) B-factors to use for the protein.

Returns:
A protein instance.
"""

if 'asym_id' in features:
chain_index = features['asym_id'] - 1
else:
chain_index = np.zeros_like((features['aatype']))

if b_factors is None:
b_factors = np.zeros_like(features['all_atom_mask'])

return Protein(
aatype=features['aatype'],
atom_positions=features['all_atom_positions'],
atom_mask=features['all_atom_mask'],
residue_index=features['residue_index'] + 1,
chain_index=chain_index,
b_factors=b_factors,
)

+ 1212 - 0   modelscope/models/science/unifold/data/residue_constants.py
File diff suppressed because it is too large


+ 345 - 0   modelscope/models/science/unifold/data/stereo_chemical_props.txt

@@ -0,0 +1,345 @@
Bond Residue Mean StdDev
CA-CB ALA 1.520 0.021
N-CA ALA 1.459 0.020
CA-C ALA 1.525 0.026
C-O ALA 1.229 0.019
CA-CB ARG 1.535 0.022
CB-CG ARG 1.521 0.027
CG-CD ARG 1.515 0.025
CD-NE ARG 1.460 0.017
NE-CZ ARG 1.326 0.013
CZ-NH1 ARG 1.326 0.013
CZ-NH2 ARG 1.326 0.013
N-CA ARG 1.459 0.020
CA-C ARG 1.525 0.026
C-O ARG 1.229 0.019
CA-CB ASN 1.527 0.026
CB-CG ASN 1.506 0.023
CG-OD1 ASN 1.235 0.022
CG-ND2 ASN 1.324 0.025
N-CA ASN 1.459 0.020
CA-C ASN 1.525 0.026
C-O ASN 1.229 0.019
CA-CB ASP 1.535 0.022
CB-CG ASP 1.513 0.021
CG-OD1 ASP 1.249 0.023
CG-OD2 ASP 1.249 0.023
N-CA ASP 1.459 0.020
CA-C ASP 1.525 0.026
C-O ASP 1.229 0.019
CA-CB CYS 1.526 0.013
CB-SG CYS 1.812 0.016
N-CA CYS 1.459 0.020
CA-C CYS 1.525 0.026
C-O CYS 1.229 0.019
CA-CB GLU 1.535 0.022
CB-CG GLU 1.517 0.019
CG-CD GLU 1.515 0.015
CD-OE1 GLU 1.252 0.011
CD-OE2 GLU 1.252 0.011
N-CA GLU 1.459 0.020
CA-C GLU 1.525 0.026
C-O GLU 1.229 0.019
CA-CB GLN 1.535 0.022
CB-CG GLN 1.521 0.027
CG-CD GLN 1.506 0.023
CD-OE1 GLN 1.235 0.022
CD-NE2 GLN 1.324 0.025
N-CA GLN 1.459 0.020
CA-C GLN 1.525 0.026
C-O GLN 1.229 0.019
N-CA GLY 1.456 0.015
CA-C GLY 1.514 0.016
C-O GLY 1.232 0.016
CA-CB HIS 1.535 0.022
CB-CG HIS 1.492 0.016
CG-ND1 HIS 1.369 0.015
CG-CD2 HIS 1.353 0.017
ND1-CE1 HIS 1.343 0.025
CD2-NE2 HIS 1.415 0.021
CE1-NE2 HIS 1.322 0.023
N-CA HIS 1.459 0.020
CA-C HIS 1.525 0.026
C-O HIS 1.229 0.019
CA-CB ILE 1.544 0.023
CB-CG1 ILE 1.536 0.028
CB-CG2 ILE 1.524 0.031
CG1-CD1 ILE 1.500 0.069
N-CA ILE 1.459 0.020
CA-C ILE 1.525 0.026
C-O ILE 1.229 0.019
CA-CB LEU 1.533 0.023
CB-CG LEU 1.521 0.029
CG-CD1 LEU 1.514 0.037
CG-CD2 LEU 1.514 0.037
N-CA LEU 1.459 0.020
CA-C LEU 1.525 0.026
C-O LEU 1.229 0.019
CA-CB LYS 1.535 0.022
CB-CG LYS 1.521 0.027
CG-CD LYS 1.520 0.034
CD-CE LYS 1.508 0.025
CE-NZ LYS 1.486 0.025
N-CA LYS 1.459 0.020
CA-C LYS 1.525 0.026
C-O LYS 1.229 0.019
CA-CB MET 1.535 0.022
CB-CG MET 1.509 0.032
CG-SD MET 1.807 0.026
SD-CE MET 1.774 0.056
N-CA MET 1.459 0.020
CA-C MET 1.525 0.026
C-O MET 1.229 0.019
CA-CB PHE 1.535 0.022
CB-CG PHE 1.509 0.017
CG-CD1 PHE 1.383 0.015
CG-CD2 PHE 1.383 0.015
CD1-CE1 PHE 1.388 0.020
CD2-CE2 PHE 1.388 0.020
CE1-CZ PHE 1.369 0.019
CE2-CZ PHE 1.369 0.019
N-CA PHE 1.459 0.020
CA-C PHE 1.525 0.026
C-O PHE 1.229 0.019
CA-CB PRO 1.531 0.020
CB-CG PRO 1.495 0.050
CG-CD PRO 1.502 0.033
CD-N PRO 1.474 0.014
N-CA PRO 1.468 0.017
CA-C PRO 1.524 0.020
C-O PRO 1.228 0.020
CA-CB SER 1.525 0.015
CB-OG SER 1.418 0.013
N-CA SER 1.459 0.020
CA-C SER 1.525 0.026
C-O SER 1.229 0.019
CA-CB THR 1.529 0.026
CB-OG1 THR 1.428 0.020
CB-CG2 THR 1.519 0.033
N-CA THR 1.459 0.020
CA-C THR 1.525 0.026
C-O THR 1.229 0.019
CA-CB TRP 1.535 0.022
CB-CG TRP 1.498 0.018
CG-CD1 TRP 1.363 0.014
CG-CD2 TRP 1.432 0.017
CD1-NE1 TRP 1.375 0.017
NE1-CE2 TRP 1.371 0.013
CD2-CE2 TRP 1.409 0.012
CD2-CE3 TRP 1.399 0.015
CE2-CZ2 TRP 1.393 0.017
CE3-CZ3 TRP 1.380 0.017
CZ2-CH2 TRP 1.369 0.019
CZ3-CH2 TRP 1.396 0.016
N-CA TRP 1.459 0.020
CA-C TRP 1.525 0.026
C-O TRP 1.229 0.019
CA-CB TYR 1.535 0.022
CB-CG TYR 1.512 0.015
CG-CD1 TYR 1.387 0.013
CG-CD2 TYR 1.387 0.013
CD1-CE1 TYR 1.389 0.015
CD2-CE2 TYR 1.389 0.015
CE1-CZ TYR 1.381 0.013
CE2-CZ TYR 1.381 0.013
CZ-OH TYR 1.374 0.017
N-CA TYR 1.459 0.020
CA-C TYR 1.525 0.026
C-O TYR 1.229 0.019
CA-CB VAL 1.543 0.021
CB-CG1 VAL 1.524 0.021
CB-CG2 VAL 1.524 0.021
N-CA VAL 1.459 0.020
CA-C VAL 1.525 0.026
C-O VAL 1.229 0.019
-

Angle Residue Mean StdDev
N-CA-CB ALA 110.1 1.4
CB-CA-C ALA 110.1 1.5
N-CA-C ALA 111.0 2.7
CA-C-O ALA 120.1 2.1
N-CA-CB ARG 110.6 1.8
CB-CA-C ARG 110.4 2.0
CA-CB-CG ARG 113.4 2.2
CB-CG-CD ARG 111.6 2.6
CG-CD-NE ARG 111.8 2.1
CD-NE-CZ ARG 123.6 1.4
NE-CZ-NH1 ARG 120.3 0.5
NE-CZ-NH2 ARG 120.3 0.5
NH1-CZ-NH2 ARG 119.4 1.1
N-CA-C ARG 111.0 2.7
CA-C-O ARG 120.1 2.1
N-CA-CB ASN 110.6 1.8
CB-CA-C ASN 110.4 2.0
CA-CB-CG ASN 113.4 2.2
CB-CG-ND2 ASN 116.7 2.4
CB-CG-OD1 ASN 121.6 2.0
ND2-CG-OD1 ASN 121.9 2.3
N-CA-C ASN 111.0 2.7
CA-C-O ASN 120.1 2.1
N-CA-CB ASP 110.6 1.8
CB-CA-C ASP 110.4 2.0
CA-CB-CG ASP 113.4 2.2
CB-CG-OD1 ASP 118.3 0.9
CB-CG-OD2 ASP 118.3 0.9
OD1-CG-OD2 ASP 123.3 1.9
N-CA-C ASP 111.0 2.7
CA-C-O ASP 120.1 2.1
N-CA-CB CYS 110.8 1.5
CB-CA-C CYS 111.5 1.2
CA-CB-SG CYS 114.2 1.1
N-CA-C CYS 111.0 2.7
CA-C-O CYS 120.1 2.1
N-CA-CB GLU 110.6 1.8
CB-CA-C GLU 110.4 2.0
CA-CB-CG GLU 113.4 2.2
CB-CG-CD GLU 114.2 2.7
CG-CD-OE1 GLU 118.3 2.0
CG-CD-OE2 GLU 118.3 2.0
OE1-CD-OE2 GLU 123.3 1.2
N-CA-C GLU 111.0 2.7
CA-C-O GLU 120.1 2.1
N-CA-CB GLN 110.6 1.8
CB-CA-C GLN 110.4 2.0
CA-CB-CG GLN 113.4 2.2
CB-CG-CD GLN 111.6 2.6
CG-CD-OE1 GLN 121.6 2.0
CG-CD-NE2 GLN 116.7 2.4
OE1-CD-NE2 GLN 121.9 2.3
N-CA-C GLN 111.0 2.7
CA-C-O GLN 120.1 2.1
N-CA-C GLY 113.1 2.5
CA-C-O GLY 120.6 1.8
N-CA-CB HIS 110.6 1.8
CB-CA-C HIS 110.4 2.0
CA-CB-CG HIS 113.6 1.7
CB-CG-ND1 HIS 123.2 2.5
CB-CG-CD2 HIS 130.8 3.1
CG-ND1-CE1 HIS 108.2 1.4
ND1-CE1-NE2 HIS 109.9 2.2
CE1-NE2-CD2 HIS 106.6 2.5
NE2-CD2-CG HIS 109.2 1.9
CD2-CG-ND1 HIS 106.0 1.4
N-CA-C HIS 111.0 2.7
CA-C-O HIS 120.1 2.1
N-CA-CB ILE 110.8 2.3
CB-CA-C ILE 111.6 2.0
CA-CB-CG1 ILE 111.0 1.9
CB-CG1-CD1 ILE 113.9 2.8
CA-CB-CG2 ILE 110.9 2.0
CG1-CB-CG2 ILE 111.4 2.2
N-CA-C ILE 111.0 2.7
CA-C-O ILE 120.1 2.1
N-CA-CB LEU 110.4 2.0
CB-CA-C LEU 110.2 1.9
CA-CB-CG LEU 115.3 2.3
CB-CG-CD1 LEU 111.0 1.7
CB-CG-CD2 LEU 111.0 1.7
CD1-CG-CD2 LEU 110.5 3.0
N-CA-C LEU 111.0 2.7
CA-C-O LEU 120.1 2.1
N-CA-CB LYS 110.6 1.8
CB-CA-C LYS 110.4 2.0
CA-CB-CG LYS 113.4 2.2
CB-CG-CD LYS 111.6 2.6
CG-CD-CE LYS 111.9 3.0
CD-CE-NZ LYS 111.7 2.3
N-CA-C LYS 111.0 2.7
CA-C-O LYS 120.1 2.1
N-CA-CB MET 110.6 1.8
CB-CA-C MET 110.4 2.0
CA-CB-CG MET 113.3 1.7
CB-CG-SD MET 112.4 3.0
CG-SD-CE MET 100.2 1.6
N-CA-C MET 111.0 2.7
CA-C-O MET 120.1 2.1
N-CA-CB PHE 110.6 1.8
CB-CA-C PHE 110.4 2.0
CA-CB-CG PHE 113.9 2.4
CB-CG-CD1 PHE 120.8 0.7
CB-CG-CD2 PHE 120.8 0.7
CD1-CG-CD2 PHE 118.3 1.3
CG-CD1-CE1 PHE 120.8 1.1
CG-CD2-CE2 PHE 120.8 1.1
CD1-CE1-CZ PHE 120.1 1.2
CD2-CE2-CZ PHE 120.1 1.2
CE1-CZ-CE2 PHE 120.0 1.8
N-CA-C PHE 111.0 2.7
CA-C-O PHE 120.1 2.1
N-CA-CB PRO 103.3 1.2
CB-CA-C PRO 111.7 2.1
CA-CB-CG PRO 104.8 1.9
CB-CG-CD PRO 106.5 3.9
CG-CD-N PRO 103.2 1.5
CA-N-CD PRO 111.7 1.4
N-CA-C PRO 112.1 2.6
CA-C-O PRO 120.2 2.4
N-CA-CB SER 110.5 1.5
CB-CA-C SER 110.1 1.9
CA-CB-OG SER 111.2 2.7
N-CA-C SER 111.0 2.7
CA-C-O SER 120.1 2.1
N-CA-CB THR 110.3 1.9
CB-CA-C THR 111.6 2.7
CA-CB-OG1 THR 109.0 2.1
CA-CB-CG2 THR 112.4 1.4
OG1-CB-CG2 THR 110.0 2.3
N-CA-C THR 111.0 2.7
CA-C-O THR 120.1 2.1
N-CA-CB TRP 110.6 1.8
CB-CA-C TRP 110.4 2.0
CA-CB-CG TRP 113.7 1.9
CB-CG-CD1 TRP 127.0 1.3
CB-CG-CD2 TRP 126.6 1.3
CD1-CG-CD2 TRP 106.3 0.8
CG-CD1-NE1 TRP 110.1 1.0
CD1-NE1-CE2 TRP 109.0 0.9
NE1-CE2-CD2 TRP 107.3 1.0
CE2-CD2-CG TRP 107.3 0.8
CG-CD2-CE3 TRP 133.9 0.9
NE1-CE2-CZ2 TRP 130.4 1.1
CE3-CD2-CE2 TRP 118.7 1.2
CD2-CE2-CZ2 TRP 122.3 1.2
CE2-CZ2-CH2 TRP 117.4 1.0
CZ2-CH2-CZ3 TRP 121.6 1.2
CH2-CZ3-CE3 TRP 121.2 1.1
CZ3-CE3-CD2 TRP 118.8 1.3
N-CA-C TRP 111.0 2.7
CA-C-O TRP 120.1 2.1
N-CA-CB TYR 110.6 1.8
CB-CA-C TYR 110.4 2.0
CA-CB-CG TYR 113.4 1.9
CB-CG-CD1 TYR 121.0 0.6
CB-CG-CD2 TYR 121.0 0.6
CD1-CG-CD2 TYR 117.9 1.1
CG-CD1-CE1 TYR 121.3 0.8
CG-CD2-CE2 TYR 121.3 0.8
CD1-CE1-CZ TYR 119.8 0.9
CD2-CE2-CZ TYR 119.8 0.9
CE1-CZ-CE2 TYR 119.8 1.6
CE1-CZ-OH TYR 120.1 2.7
CE2-CZ-OH TYR 120.1 2.7
N-CA-C TYR 111.0 2.7
CA-C-O TYR 120.1 2.1
N-CA-CB VAL 111.5 2.2
CB-CA-C VAL 111.4 1.9
CA-CB-CG1 VAL 110.9 1.5
CA-CB-CG2 VAL 110.9 1.5
CG1-CB-CG2 VAL 110.9 1.6
N-CA-C VAL 111.0 2.7
CA-C-O VAL 120.1 2.1
-

Non-bonded distance Minimum Dist Tolerance
C-C 3.4 1.5
C-N 3.25 1.5
C-S 3.5 1.5
C-O 3.22 1.5
N-N 3.1 1.5
N-S 3.35 1.5
N-O 3.07 1.5
O-S 3.32 1.5
O-O 3.04 1.5
S-S 2.03 1.0
-

+ 161 - 0   modelscope/models/science/unifold/data/utils.py

@@ -0,0 +1,161 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import copy as copy_lib
import functools
import gzip
import pickle
from typing import Any, Dict

import json
import numpy as np
from scipy import sparse as sp

from . import residue_constants as rc
from .data_ops import NumpyDict

# from typing import *


def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False):
if deepcopy:

def decorator(f):
cached_func = functools.lru_cache(maxsize, typed)(f)

@functools.wraps(f)
def wrapper(*args, **kwargs):
return copy_lib.deepcopy(cached_func(*args, **kwargs))

return wrapper

elif copy:

def decorator(f):
cached_func = functools.lru_cache(maxsize, typed)(f)

@functools.wraps(f)
def wrapper(*args, **kwargs):
return copy_lib.copy(cached_func(*args, **kwargs))

return wrapper

else:
decorator = functools.lru_cache(maxsize, typed)
return decorator
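# Usage sketch with a hypothetical function (not part of this file): with
# copy=True the cached value is shallow-copied on every hit, and with
# deepcopy=True it is deep-copied, so callers may mutate the returned object
# without corrupting the cache.
#
#   @lru_cache(maxsize=2, deepcopy=True)
#   def make_feature(n):
#       return {'ids': np.arange(n)}
#
#   a = make_feature(3)
#   a['ids'][0] = -1               # safe: the cached dict is untouched
#   assert make_feature(3)['ids'][0] == 0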


@lru_cache(maxsize=8, deepcopy=True)
def load_pickle_safe(path: str) -> Dict[str, Any]:

def load(path):
assert path.endswith('.pkl') or path.endswith(
'.pkl.gz'), f'bad suffix in {path} as pickle file.'
open_fn = gzip.open if path.endswith('.gz') else open
with open_fn(path, 'rb') as f:
return pickle.load(f)

ret = load(path)
ret = uncompress_features(ret)
return ret


@lru_cache(maxsize=8, copy=True)
def load_pickle(path: str) -> Dict[str, Any]:

def load(path):
assert path.endswith('.pkl') or path.endswith(
'.pkl.gz'), f'bad suffix in {path} as pickle file.'
open_fn = gzip.open if path.endswith('.gz') else open
with open_fn(path, 'rb') as f:
return pickle.load(f)

ret = load(path)
ret = uncompress_features(ret)
return ret


def correct_template_restypes(feature):
"""Correct template restype to have the same order as residue_constants."""
feature = np.argmax(feature, axis=-1).astype(np.int32)
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
feature = np.take(new_order_list, feature.astype(np.int32), axis=0)
return feature


def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict:
feature['msa'] = feature['msa'].astype(np.uint8)
if 'num_alignments' in feature:
feature.pop('num_alignments')
# make_all_seq_key = lambda k: f'{k}_all_seq' if not k.endswith('_all_seq') else k

def make_all_seq_key(k):
if not k.endswith('_all_seq'):
return f'{k}_all_seq'
return k

return {make_all_seq_key(k): v for k, v in feature.items()}


def to_dense_matrix(spmat_dict: NumpyDict):
spmat = sp.coo_matrix(
(spmat_dict['data'], (spmat_dict['row'], spmat_dict['col'])),
shape=spmat_dict['shape'],
dtype=np.float32,
)
return spmat.toarray()


FEATS_DTYPE = {'msa': np.int32}


def uncompress_features(feats: NumpyDict) -> NumpyDict:
if 'sparse_deletion_matrix_int' in feats:
v = feats.pop('sparse_deletion_matrix_int')
v = to_dense_matrix(v)
feats['deletion_matrix'] = v
return feats


def filter(feature: NumpyDict, **kwargs) -> NumpyDict:
assert len(kwargs) == 1, f'wrong usage of filter with kwargs: {kwargs}'
if 'desired_keys' in kwargs:
feature = {
k: v
for k, v in feature.items() if k in kwargs['desired_keys']
}
elif 'required_keys' in kwargs:
for k in kwargs['required_keys']:
assert k in feature, f'cannot find required key {k}.'
elif 'ignored_keys' in kwargs:
feature = {
k: v
for k, v in feature.items() if k not in kwargs['ignored_keys']
}
else:
raise AssertionError(f'wrong usage of filter with kwargs: {kwargs}')
return feature
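# Hedged usage sketch: exactly one of the three keyword arguments is expected.
#   filter(feature, desired_keys={'msa', 'aatype'})    # keep only these keys
#   filter(feature, required_keys=['msa'])             # assert presence, keep all
#   filter(feature, ignored_keys={'template_mask'})    # drop these keys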


def compress_features(features: NumpyDict):
change_dtype = {
'msa': np.uint8,
}
sparse_keys = ['deletion_matrix_int']

compressed_features = {}
for k, v in features.items():
if k in change_dtype:
v = v.astype(change_dtype[k])
if k in sparse_keys:
v = sp.coo_matrix(v, dtype=v.dtype)
sp_v = {
'shape': v.shape,
'row': v.row,
'col': v.col,
'data': v.data
}
k = f'sparse_{k}'
v = sp_v
compressed_features[k] = v
return compressed_features
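# Round-trip sketch (hypothetical variables): compress_features downcasts
# 'msa' to uint8 and stores 'deletion_matrix_int' as a sparse COO dict under
# 'sparse_deletion_matrix_int'; uncompress_features above restores the dense
# matrix under the key 'deletion_matrix'.
#   packed = compress_features({'msa': msa, 'deletion_matrix_int': del_mat})
#   restored = uncompress_features(packed)   # contains 'deletion_matrix'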

+ 514 - 0   modelscope/models/science/unifold/dataset.py

@@ -0,0 +1,514 @@
import copy
import logging
import os
# from typing import *
from typing import Dict, Iterable, List, Optional, Tuple, Union

import json
import ml_collections as mlc
import numpy as np
import torch
from unicore.data import UnicoreDataset, data_utils
from unicore.distributed import utils as distributed_utils

from .data import utils
from .data.data_ops import NumpyDict, TorchDict
from .data.process import process_features, process_labels
from .data.process_multimer import (add_assembly_features,
convert_monomer_features, merge_msas,
pair_and_merge, post_process)

Rotation = Iterable[Iterable]
Translation = Iterable
Operation = Union[str, Tuple[Rotation, Translation]]
NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]]
TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]]

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def make_data_config(
config: mlc.ConfigDict,
mode: str,
num_res: int,
) -> Tuple[mlc.ConfigDict, List[str]]:
cfg = copy.deepcopy(config)
mode_cfg = cfg[mode]
with cfg.unlocked():
if mode_cfg.crop_size is None:
mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features + cfg.common.recycling_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
if cfg.common.is_multimer:
feature_names += cfg.common.multimer_features
if cfg[mode].supervised:
feature_names += cfg.supervised.supervised_features

return cfg, feature_names


def process_label(all_atom_positions: np.ndarray,
operation: Operation) -> np.ndarray:
if operation == 'I':
return all_atom_positions
rot, trans = operation
rot = np.array(rot).reshape(3, 3)
trans = np.array(trans).reshape(3)
return all_atom_positions @ rot.T + trans
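# Minimal sketch of a symmetry operation (hypothetical values): 'I' leaves the
# coordinates untouched, otherwise a (rotation, translation) pair is applied
# as x @ R^T + t, e.g.
#   rot = np.eye(3).tolist()
#   trans = [1.0, 0.0, 0.0]
#   shifted = process_label(all_atom_positions, (rot, trans))  # +1 along x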


@utils.lru_cache(maxsize=8, copy=True)
def load_single_feature(
sequence_id: str,
monomer_feature_dir: str,
uniprot_msa_dir: Optional[str] = None,
is_monomer: bool = False,
) -> NumpyDict:

monomer_feature = utils.load_pickle(
os.path.join(monomer_feature_dir, f'{sequence_id}.feature.pkl.gz'))
monomer_feature = convert_monomer_features(monomer_feature)
chain_feature = {**monomer_feature}

if uniprot_msa_dir is not None:
all_seq_feature = utils.load_pickle(
os.path.join(uniprot_msa_dir, f'{sequence_id}.uniprot.pkl.gz'))
if is_monomer:
chain_feature['msa'], chain_feature[
'deletion_matrix'] = merge_msas(
chain_feature['msa'],
chain_feature['deletion_matrix'],
all_seq_feature['msa'],
all_seq_feature['deletion_matrix'],
) # noqa
else:
all_seq_feature = utils.convert_all_seq_feature(all_seq_feature)
for key in [
'msa_all_seq',
'msa_species_identifiers_all_seq',
'deletion_matrix_all_seq',
]:
chain_feature[key] = all_seq_feature[key]

return chain_feature


def load_single_label(
label_id: str,
label_dir: str,
symmetry_operation: Optional[Operation] = None,
) -> NumpyDict:
label = utils.load_pickle(
os.path.join(label_dir, f'{label_id}.label.pkl.gz'))
if symmetry_operation is not None:
label['all_atom_positions'] = process_label(
label['all_atom_positions'], symmetry_operation)
label = {
k: v
for k, v in label.items() if k in
['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution']
}
return label


def load(
sequence_ids: List[str],
monomer_feature_dir: str,
uniprot_msa_dir: Optional[str] = None,
label_ids: Optional[List[str]] = None,
label_dir: Optional[str] = None,
symmetry_operations: Optional[List[Operation]] = None,
is_monomer: bool = False,
) -> NumpyExample:

all_chain_features = [
load_single_feature(s, monomer_feature_dir, uniprot_msa_dir,
is_monomer) for s in sequence_ids
]

if label_ids is not None:
# load labels
assert len(label_ids) == len(sequence_ids)
assert label_dir is not None
if symmetry_operations is None:
symmetry_operations = ['I' for _ in label_ids]
all_chain_labels = [
load_single_label(ll, label_dir, o)
for ll, o in zip(label_ids, symmetry_operations)
]
# update labels into features to calculate spatial cropping etc.
        for f, ll in zip(all_chain_features, all_chain_labels):
            f.update(ll)

all_chain_features = add_assembly_features(all_chain_features)

# get labels back from features, as add_assembly_features may alter the order of inputs.
if label_ids is not None:
all_chain_labels = [{
k: f[k]
for k in
['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution']
} for f in all_chain_features]
else:
all_chain_labels = None

asym_len = np.array([c['seq_length'] for c in all_chain_features],
dtype=np.int64)
if is_monomer:
all_chain_features = all_chain_features[0]
else:
all_chain_features = pair_and_merge(all_chain_features)
all_chain_features = post_process(all_chain_features)
all_chain_features['asym_len'] = asym_len

return all_chain_features, all_chain_labels


def process(
config: mlc.ConfigDict,
mode: str,
features: NumpyDict,
labels: Optional[List[NumpyDict]] = None,
seed: int = 0,
batch_idx: Optional[int] = None,
data_idx: Optional[int] = None,
is_distillation: bool = False,
) -> TorchExample:

if mode == 'train':
assert batch_idx is not None
with data_utils.numpy_seed(seed, batch_idx, key='recycling'):
num_iters = np.random.randint(
0, config.common.max_recycling_iters + 1)
use_clamped_fape = np.random.rand(
) < config[mode].use_clamped_fape_prob
else:
num_iters = config.common.max_recycling_iters
use_clamped_fape = 1

features['num_recycling_iters'] = int(num_iters)
features['use_clamped_fape'] = int(use_clamped_fape)
features['is_distillation'] = int(is_distillation)
if is_distillation and 'msa_chains' in features:
features.pop('msa_chains')

num_res = int(features['seq_length'])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

if labels is not None:
features['resolution'] = labels[0]['resolution'].reshape(-1)

with data_utils.numpy_seed(seed, data_idx, key='protein_feature'):
features['crop_and_fix_size_seed'] = np.random.randint(0, 63355)
features = utils.filter(features, desired_keys=feature_names)
features = {k: torch.tensor(v) for k, v in features.items()}
with torch.no_grad():
features = process_features(features, cfg.common, cfg[mode])

if labels is not None:
labels = [{k: torch.tensor(v) for k, v in ll.items()} for ll in labels]
with torch.no_grad():
labels = process_labels(labels)

return features, labels


def load_and_process(
config: mlc.ConfigDict,
mode: str,
seed: int = 0,
batch_idx: Optional[int] = None,
data_idx: Optional[int] = None,
is_distillation: bool = False,
**load_kwargs,
):
is_monomer = (
is_distillation
if 'is_monomer' not in load_kwargs else load_kwargs.pop('is_monomer'))
features, labels = load(**load_kwargs, is_monomer=is_monomer)
features, labels = process(config, mode, features, labels, seed, batch_idx,
data_idx, is_distillation)
return features, labels


class UnifoldDataset(UnicoreDataset):

def __init__(
self,
args,
seed,
config,
data_path,
mode='train',
max_step=None,
disable_sd=False,
json_prefix='',
):
self.path = data_path

def load_json(filename):
return json.load(open(filename, 'r'))

sample_weight = load_json(
os.path.join(self.path,
json_prefix + mode + '_sample_weight.json'))
self.multi_label = load_json(
os.path.join(self.path, json_prefix + mode + '_multi_label.json'))
self.inverse_multi_label = self._inverse_map(self.multi_label)
self.sample_weight = {}
for chain in self.inverse_multi_label:
entity = self.inverse_multi_label[chain]
self.sample_weight[chain] = sample_weight[entity]
self.seq_sample_weight = sample_weight
logger.info('load {} chains (unique {} sequences)'.format(
len(self.sample_weight), len(self.seq_sample_weight)))
self.feature_path = os.path.join(self.path, 'pdb_features')
self.label_path = os.path.join(self.path, 'pdb_labels')
sd_sample_weight_path = os.path.join(
self.path, json_prefix + 'sd_train_sample_weight.json')
if mode == 'train' and os.path.isfile(
sd_sample_weight_path) and not disable_sd:
self.sd_sample_weight = load_json(sd_sample_weight_path)
logger.info('load {} self-distillation samples.'.format(
len(self.sd_sample_weight)))
self.sd_feature_path = os.path.join(self.path, 'sd_features')
self.sd_label_path = os.path.join(self.path, 'sd_labels')
else:
self.sd_sample_weight = None
self.batch_size = (
args.batch_size * distributed_utils.get_data_parallel_world_size()
* args.update_freq[0])
self.data_len = (
max_step * self.batch_size
if max_step is not None else len(self.sample_weight))
self.mode = mode
self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight(
self.seq_sample_weight)
self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight(
self.sample_weight)
if self.sd_sample_weight is not None:
(
self.sd_num_chain,
self.sd_chain_keys,
self.sd_sample_prob,
) = self.cal_sample_weight(self.sd_sample_weight)
self.config = config.data
self.seed = seed
self.sd_prob = args.sd_prob

def cal_sample_weight(self, sample_weight):
prot_keys = list(sample_weight.keys())
sum_weight = sum(sample_weight.values())
sample_prob = [sample_weight[k] / sum_weight for k in prot_keys]
num_prot = len(prot_keys)
return num_prot, prot_keys, sample_prob
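    # Hedged example with hypothetical ids: cal_sample_weight({'1abc_A': 1.0,
    # '2xyz_B': 3.0}) returns (2, ['1abc_A', '2xyz_B'], [0.25, 0.75]), and
    # sample_chain() then draws chains proportionally to these probabilities.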

def sample_chain(self, idx, sample_by_seq=False):
is_distillation = False
if self.mode == 'train':
with data_utils.numpy_seed(self.seed, idx, key='data_sample'):
is_distillation = ((np.random.rand(1)[0] < self.sd_prob)
if self.sd_sample_weight is not None else
False)
if is_distillation:
prot_idx = np.random.choice(
self.sd_num_chain, p=self.sd_sample_prob)
label_name = self.sd_chain_keys[prot_idx]
seq_name = label_name
else:
if not sample_by_seq:
prot_idx = np.random.choice(
self.num_chain, p=self.sample_prob)
label_name = self.chain_keys[prot_idx]
seq_name = self.inverse_multi_label[label_name]
else:
seq_idx = np.random.choice(
self.num_seq, p=self.seq_sample_prob)
seq_name = self.seq_keys[seq_idx]
label_name = np.random.choice(
self.multi_label[seq_name])
else:
label_name = self.chain_keys[idx]
seq_name = self.inverse_multi_label[label_name]
return seq_name, label_name, is_distillation

def __getitem__(self, idx):
sequence_id, label_id, is_distillation = self.sample_chain(
idx, sample_by_seq=True)
feature_dir, label_dir = ((self.feature_path,
self.label_path) if not is_distillation else
(self.sd_feature_path, self.sd_label_path))
features, _ = load_and_process(
self.config,
self.mode,
self.seed,
batch_idx=(idx // self.batch_size),
data_idx=idx,
is_distillation=is_distillation,
sequence_ids=[sequence_id],
monomer_feature_dir=feature_dir,
uniprot_msa_dir=None,
label_ids=[label_id],
label_dir=label_dir,
symmetry_operations=None,
is_monomer=True,
)
return features

def __len__(self):
return self.data_len

@staticmethod
def collater(samples):
        # the first dim is recycling; batch size is at the 2nd dim
return data_utils.collate_dict(samples, dim=1)

@staticmethod
def _inverse_map(mapping: Dict[str, List[str]]):
inverse_mapping = {}
for ent, refs in mapping.items():
for ref in refs:
if ref in inverse_mapping: # duplicated ent for this ref.
ent_2 = inverse_mapping[ref]
assert (
ent == ent_2
), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.'
inverse_mapping[ref] = ent
return inverse_mapping
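    # Hedged example with hypothetical ids:
    # _inverse_map({'seq1': ['1abc_A', '1abc_B']}) returns
    # {'1abc_A': 'seq1', '1abc_B': 'seq1'}; a chain referenced by two different
    # entities would trip the assertion above.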


class UnifoldMultimerDataset(UnifoldDataset):

def __init__(
self,
args: mlc.ConfigDict,
seed: int,
config: mlc.ConfigDict,
data_path: str,
mode: str = 'train',
max_step: Optional[int] = None,
disable_sd: bool = False,
json_prefix: str = '',
**kwargs,
):
super().__init__(args, seed, config, data_path, mode, max_step,
disable_sd, json_prefix)
self.data_path = data_path
self.pdb_assembly = json.load(
open(
os.path.join(self.data_path,
json_prefix + 'pdb_assembly.json')))
self.pdb_chains = self.get_chains(self.inverse_multi_label)
self.monomer_feature_path = os.path.join(self.data_path,
'pdb_features')
self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots')
self.label_path = os.path.join(self.data_path, 'pdb_labels')
self.max_chains = args.max_chains
if self.mode == 'train':
self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains(
self.pdb_chains, self.pdb_assembly, self.sample_weight,
self.max_chains)
self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight(
self.sample_weight)

def __getitem__(self, idx):
seq_id, label_id, is_distillation = self.sample_chain(idx)
if is_distillation:
label_ids = [label_id]
sequence_ids = [seq_id]
monomer_feature_path, uniprot_msa_path, label_path = (
self.sd_feature_path,
None,
self.sd_label_path,
)
symmetry_operations = None
else:
pdb_id = self.get_pdb_name(label_id)
if pdb_id in self.pdb_assembly and self.mode == 'train':
label_ids = [
pdb_id + '_' + id
for id in self.pdb_assembly[pdb_id]['chains']
]
symmetry_operations = [
t for t in self.pdb_assembly[pdb_id]['opers']
]
else:
label_ids = self.pdb_chains[pdb_id]
symmetry_operations = None
sequence_ids = [
self.inverse_multi_label[chain_id] for chain_id in label_ids
]
monomer_feature_path, uniprot_msa_path, label_path = (
self.monomer_feature_path,
self.uniprot_msa_path,
self.label_path,
)

return load_and_process(
self.config,
self.mode,
self.seed,
batch_idx=(idx // self.batch_size),
data_idx=idx,
is_distillation=is_distillation,
sequence_ids=sequence_ids,
monomer_feature_dir=monomer_feature_path,
uniprot_msa_dir=uniprot_msa_path,
label_ids=label_ids,
label_dir=label_path,
symmetry_operations=symmetry_operations,
is_monomer=False,
)

@staticmethod
def collater(samples):
        # the first dim is recycling; batch size is at the 2nd dim
if len(samples) <= 0: # tackle empty batch
return None
feats = [s[0] for s in samples]
labs = [s[1] for s in samples if s[1] is not None]
        try:
            feats = data_utils.collate_dict(feats, dim=1)
        except Exception as e:
            raise ValueError('cannot collate features', feats) from e
if not labs:
labs = None
return feats, labs

@staticmethod
def get_pdb_name(chain):
return chain.split('_')[0]

@staticmethod
def get_chains(canon_chain_map):
pdb_chains = {}
for chain in canon_chain_map:
pdb = UnifoldMultimerDataset.get_pdb_name(chain)
if pdb not in pdb_chains:
pdb_chains[pdb] = []
pdb_chains[pdb].append(chain)
return pdb_chains

@staticmethod
def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight,
max_chains):
new_pdb_chains = {}
for chain in pdb_chains:
if chain in pdb_assembly:
size = len(pdb_assembly[chain]['chains'])
if size <= max_chains:
new_pdb_chains[chain] = pdb_chains[chain]
else:
size = len(pdb_chains[chain])
if size == 1:
new_pdb_chains[chain] = pdb_chains[chain]
new_sample_weight = {
k: sample_weight[k]
for k in sample_weight
if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains
}
logger.info(
f'filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs '
f'({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) '
f'by max_chains {max_chains}')
return new_pdb_chains, new_sample_weight

+ 75 - 0   modelscope/models/science/unifold/model.py

@@ -0,0 +1,75 @@
import argparse
import os
from typing import Any

import torch

from modelscope.metainfo import Models
from modelscope.models import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .config import model_config
from .modules.alphafold import AlphaFold

__all__ = ['UnifoldForProteinStructrue']


@MODELS.register_module(Tasks.protein_structure, module_name=Models.unifold)
class UnifoldForProteinStructrue(TorchModel):

@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument(
'--model-name',
help='choose the model config',
)

def __init__(self, **kwargs):
super().__init__()
parser = argparse.ArgumentParser()
parse_comm = []
for key in kwargs:
parser.add_argument(f'--{key}')
parse_comm.append(f'--{key}')
parse_comm.append(kwargs[key])
args = parser.parse_args(parse_comm)
base_architecture(args)
self.args = args
config = model_config(
self.args.model_name,
train=True,
)
self.model = AlphaFold(config)
self.config = config

# load model state dict
param_path = os.path.join(kwargs['model_dir'],
ModelFile.TORCH_MODEL_BIN_FILE)
state_dict = torch.load(param_path)['ema']['params']
state_dict = {
'.'.join(k.split('.')[1:]): v
for k, v in state_dict.items()
}
self.model.load_state_dict(state_dict)

def half(self):
self.model = self.model.half()
return self

def bfloat16(self):
self.model = self.model.bfloat16()
return self

@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
return cls(args)

def forward(self, batch, **kwargs):
outputs = self.model.forward(batch)
return outputs, self.config.loss


def base_architecture(args):
args.model_name = getattr(args, 'model_name', 'model_2')

+ 450 - 0   modelscope/models/science/unifold/modules/alphafold.py

@@ -0,0 +1,450 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import torch
import torch.nn as nn
from unicore.utils import tensor_tree_map

from ..data import residue_constants
from .attentions import gen_msa_attn_mask, gen_tri_attn_mask
from .auxillary_heads import AuxiliaryHeads
from .common import residual
from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder,
TemplateAngleEmbedder, TemplatePairEmbedder)
from .evoformer import EvoformerStack, ExtraMSAStack
from .featurization import (atom14_to_atom37, build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
build_template_pair_feat_v2, pseudo_beta_fn)
from .structure_module import StructureModule
from .template import (TemplatePairStack, TemplatePointwiseAttention,
TemplateProjection)


class AlphaFold(nn.Module):

def __init__(self, config):
super(AlphaFold, self).__init__()

self.globals = config.globals
config = config.model
template_config = config.template
extra_msa_config = config.extra_msa

self.input_embedder = InputEmbedder(
**config['input_embedder'],
use_chain_relative=config.is_multimer,
)
self.recycling_embedder = RecyclingEmbedder(
**config['recycling_embedder'], )
if config.template.enabled:
self.template_angle_embedder = TemplateAngleEmbedder(
**template_config['template_angle_embedder'], )
self.template_pair_embedder = TemplatePairEmbedder(
**template_config['template_pair_embedder'], )
self.template_pair_stack = TemplatePairStack(
**template_config['template_pair_stack'], )
else:
self.template_pair_stack = None
self.enable_template_pointwise_attention = template_config[
'template_pointwise_attention'].enabled
if self.enable_template_pointwise_attention:
self.template_pointwise_att = TemplatePointwiseAttention(
**template_config['template_pointwise_attention'], )
else:
self.template_proj = TemplateProjection(
**template_config['template_pointwise_attention'], )
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config['extra_msa_embedder'], )
self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config['extra_msa_stack'], )
self.evoformer = EvoformerStack(**config['evoformer_stack'], )
self.structure_module = StructureModule(**config['structure_module'], )

self.aux_heads = AuxiliaryHeads(config['heads'], )

self.config = config
self.dtype = torch.float
self.inf = self.globals.inf
if self.globals.alphafold_original_mode:
self.alphafold_original_mode()

def __make_input_float__(self):
self.input_embedder = self.input_embedder.float()
self.recycling_embedder = self.recycling_embedder.float()

def half(self):
super().half()
if (not getattr(self, 'inference', False)):
self.__make_input_float__()
self.dtype = torch.half
return self

def bfloat16(self):
super().bfloat16()
if (not getattr(self, 'inference', False)):
self.__make_input_float__()
self.dtype = torch.bfloat16
return self

def alphafold_original_mode(self):

def set_alphafold_original_mode(module):
if hasattr(module, 'apply_alphafold_original_mode'):
module.apply_alphafold_original_mode()
if hasattr(module, 'act'):
module.act = nn.ReLU()

self.apply(set_alphafold_original_mode)

def inference_mode(self):

def set_inference_mode(module):
setattr(module, 'inference', True)

self.apply(set_inference_mode)

def __convert_input_dtype__(self, batch):
for key in batch:
# only convert features with mask
if batch[key].dtype != self.dtype and 'mask' in key:
batch[key] = batch[key].type(self.dtype)
return batch

def embed_templates_pair_core(self, batch, z, pair_mask,
tri_start_attn_mask, tri_end_attn_mask,
templ_dim, multichain_mask_2d):
if self.config.template.template_pair_embedder.v2_feature:
t = build_template_pair_feat_v2(
batch,
inf=self.config.template.inf,
eps=self.config.template.eps,
multichain_mask_2d=multichain_mask_2d,
**self.config.template.distogram,
)
num_template = t[0].shape[-4]
single_templates = [
self.template_pair_embedder([x[..., ti, :, :, :]
for x in t], z)
for ti in range(num_template)
]
else:
t = build_template_pair_feat(
batch,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram,
)
single_templates = [
self.template_pair_embedder(x, z)
for x in torch.unbind(t, dim=templ_dim)
]

t = self.template_pair_stack(
single_templates,
pair_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
templ_dim=templ_dim,
chunk_size=self.globals.chunk_size,
block_size=self.globals.block_size,
return_mean=not self.enable_template_pointwise_attention,
)
return t

def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask,
tri_end_attn_mask, templ_dim):
if self.config.template.template_pair_embedder.v2_feature and 'asym_id' in batch:
multichain_mask_2d = (
batch['asym_id'][..., :, None] == batch['asym_id'][...,
None, :])
multichain_mask_2d = multichain_mask_2d.unsqueeze(0)
else:
multichain_mask_2d = None

if self.training or self.enable_template_pointwise_attention:
t = self.embed_templates_pair_core(batch, z, pair_mask,
tri_start_attn_mask,
tri_end_attn_mask, templ_dim,
multichain_mask_2d)
if self.enable_template_pointwise_attention:
t = self.template_pointwise_att(
t,
z,
template_mask=batch['template_mask'],
chunk_size=self.globals.chunk_size,
)
t_mask = torch.sum(
batch['template_mask'], dim=-1, keepdims=True) > 0
t_mask = t_mask[..., None, None].type(t.dtype)
t *= t_mask
else:
t = self.template_proj(t, z)
else:
template_aatype_shape = batch['template_aatype'].shape
# template_aatype is either [n_template, n_res] or [1, n_template_, n_res]
batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0
n_templ = batch['template_aatype'].shape[batch_templ_dim]

if n_templ <= 0:
t = None
else:
template_batch = {
k: v
for k, v in batch.items() if k.startswith('template_')
}

def embed_one_template(i):

def slice_template_tensor(t):
s = [slice(None) for _ in t.shape]
s[batch_templ_dim] = slice(i, i + 1)
return t[s]

template_feats = tensor_tree_map(
slice_template_tensor,
template_batch,
)
t = self.embed_templates_pair_core(
template_feats, z, pair_mask, tri_start_attn_mask,
tri_end_attn_mask, templ_dim, multichain_mask_2d)
return t

t = embed_one_template(0)
# iterate templates one by one
for i in range(1, n_templ):
t += embed_one_template(i)
t /= n_templ
t = self.template_proj(t, z)
return t

def embed_templates_angle(self, batch):
template_angle_feat, template_angle_mask = build_template_angle_feat(
batch,
v2_feature=self.config.template.template_pair_embedder.v2_feature)
t = self.template_angle_embedder(template_angle_feat)
return t, template_angle_mask

def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev):
batch_dims = feats['target_feat'].shape[:-2]
n = feats['target_feat'].shape[-2]
seq_mask = feats['seq_mask']
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats['msa_mask']

m, z = self.input_embedder(
feats['target_feat'],
feats['msa_feat'],
)

if m_1_prev is None:
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.d_msa),
requires_grad=False,
)
if z_prev is None:
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.d_pair),
requires_grad=False,
)
if x_prev is None:
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None)

z += self.recycling_embedder.recyle_pos(x_prev)

m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
)

m[..., 0, :, :] += m_1_prev_emb

z += z_prev_emb

z += self.input_embedder.relpos_emb(
feats['residue_index'].long(),
feats.get('sym_id', None),
feats.get('asym_id', None),
feats.get('entity_id', None),
feats.get('num_sym', None),
)

m = m.type(self.dtype)
z = z.type(self.dtype)
tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask(
pair_mask, self.inf)

if self.config.template.enabled:
template_mask = feats['template_mask']
if torch.any(template_mask):
z = residual(
z,
self.embed_templates_pair(
feats,
z,
pair_mask,
tri_start_attn_mask,
tri_end_attn_mask,
templ_dim=-4,
),
self.training,
)

if self.config.extra_msa.enabled:
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
extra_msa_row_mask = gen_msa_attn_mask(
feats['extra_msa_mask'],
inf=self.inf,
gen_col_mask=False,
)
z = self.extra_msa_stack(
a,
z,
msa_mask=feats['extra_msa_mask'],
chunk_size=self.globals.chunk_size,
block_size=self.globals.block_size,
pair_mask=pair_mask,
msa_row_attn_mask=extra_msa_row_mask,
msa_col_attn_mask=None,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
)

if self.config.template.embed_angles:
template_1d_feat, template_1d_mask = self.embed_templates_angle(
feats)
m = torch.cat([m, template_1d_feat], dim=-3)
msa_mask = torch.cat([feats['msa_mask'], template_1d_mask], dim=-2)

msa_row_mask, msa_col_mask = gen_msa_attn_mask(
msa_mask,
inf=self.inf,
)

m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_row_attn_mask=msa_row_mask,
msa_col_attn_mask=msa_col_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=self.globals.chunk_size,
block_size=self.globals.block_size,
)
return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb

def iteration_evoformer_structure_module(self,
batch,
m_1_prev,
z_prev,
x_prev,
cycle_no,
num_recycling,
num_ensembles=1):
z, s = 0, 0
n_seq = batch['msa_feat'].shape[-3]
assert num_ensembles >= 1
for ensemble_no in range(num_ensembles):
idx = cycle_no * num_ensembles + ensemble_no

# fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...]
def fetch_cur_batch(t):
return t[min(t.shape[0] - 1, idx), ...]

feats = tensor_tree_map(fetch_cur_batch, batch)
m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer(
feats, m_1_prev, z_prev, x_prev)
z += z0
s += s0
del z0, s0
if num_ensembles > 1:
z /= float(num_ensembles)
s /= float(num_ensembles)

outputs = {}

outputs['msa'] = m[..., :n_seq, :, :]
outputs['pair'] = z
outputs['single'] = s

# norm loss
if (not getattr(self, 'inference',
False)) and num_recycling == (cycle_no + 1):
delta_msa = m
delta_msa[...,
0, :, :] = delta_msa[...,
0, :, :] - m_1_prev_emb.detach()
delta_pair = z - z_prev_emb.detach()
outputs['delta_msa'] = delta_msa
outputs['delta_pair'] = delta_pair
outputs['msa_norm_mask'] = msa_mask

outputs['sm'] = self.structure_module(
s,
z,
feats['aatype'],
mask=feats['seq_mask'],
)
outputs['final_atom_positions'] = atom14_to_atom37(
outputs['sm']['positions'], feats)
outputs['final_atom_mask'] = feats['atom37_atom_exists']
outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1]

# use float32 for numerical stability
if (not getattr(self, 'inference', False)):
m_1_prev = m[..., 0, :, :].float()
z_prev = z.float()
x_prev = outputs['final_atom_positions'].float()
else:
m_1_prev = m[..., 0, :, :]
z_prev = z
x_prev = outputs['final_atom_positions']

return outputs, m_1_prev, z_prev, x_prev

def forward(self, batch):

m_1_prev = batch.get('m_1_prev', None)
z_prev = batch.get('z_prev', None)
x_prev = batch.get('x_prev', None)

is_grad_enabled = torch.is_grad_enabled()

num_iters = int(batch['num_recycling_iters']) + 1
num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters
if self.training:
# don't use ensemble during training
assert num_ensembles == 1

# convert dtypes in batch
batch = self.__convert_input_dtype__(batch)
for cycle_no in range(num_iters):
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
(
outputs,
m_1_prev,
z_prev,
x_prev,
) = self.iteration_evoformer_structure_module(
batch,
m_1_prev,
z_prev,
x_prev,
cycle_no=cycle_no,
num_recycling=num_iters,
num_ensembles=num_ensembles,
)
if not is_final_iter:
del outputs

if 'asym_id' in batch:
outputs['asym_id'] = batch['asym_id'][0, ...]
outputs.update(self.aux_heads(outputs))
return outputs

+ 430 - 0   modelscope/models/science/unifold/modules/attentions.py

@@ -0,0 +1,430 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from functools import partialmethod
from typing import List, Optional

import torch
import torch.nn as nn
from unicore.modules import LayerNorm, softmax_dropout
from unicore.utils import permute_final_dims

from .common import Linear, chunk_layer


def gen_attn_mask(mask, neg_inf):
assert neg_inf < -1e4
attn_mask = torch.zeros_like(mask)
attn_mask[mask == 0] = neg_inf
return attn_mask
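# Sketch with toy values: for mask = [1., 1., 0.] and neg_inf = -1e9,
# gen_attn_mask returns [0., 0., -1e9]; the result is added to the attention
# logits so that masked positions vanish after the softmax.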


class Attention(nn.Module):

def __init__(
self,
q_dim: int,
k_dim: int,
v_dim: int,
head_dim: int,
num_heads: int,
gating: bool = True,
):
super(Attention, self).__init__()

self.num_heads = num_heads
total_dim = head_dim * self.num_heads
self.gating = gating
self.linear_q = Linear(q_dim, total_dim, bias=False, init='glorot')
self.linear_k = Linear(k_dim, total_dim, bias=False, init='glorot')
self.linear_v = Linear(v_dim, total_dim, bias=False, init='glorot')
self.linear_o = Linear(total_dim, q_dim, init='final')
self.linear_g = None
if self.gating:
self.linear_g = Linear(q_dim, total_dim, init='gating')
# precompute the 1/sqrt(head_dim)
self.norm = head_dim**-0.5

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.Tensor = None,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
g = None
if self.linear_g is not None:
# gating, use raw query input
g = self.linear_g(q)

q = self.linear_q(q)
q *= self.norm
k = self.linear_k(k)
v = self.linear_v(v)

q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose(
-2, -3).contiguous()
k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose(
-2, -3).contiguous()
v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3)

attn = torch.matmul(q, k.transpose(-1, -2))
del q, k

attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias)
o = torch.matmul(attn, v)
del attn, v

o = o.transpose(-2, -3).contiguous()
o = o.view(*o.shape[:-2], -1)

if g is not None:
o = torch.sigmoid(g) * o

# merge heads
o = nn.functional.linear(o, self.linear_o.weight)
return o

def get_output_bias(self):
return self.linear_o.bias
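    # Shape sketch (inferred from the view/transpose calls above, stated here
    # as an assumption): q/k/v enter as [*, Q, q_dim] / [*, K, k_dim] /
    # [*, K, v_dim], are split into [*, num_heads, Q|K, head_dim], and the
    # output is merged back to [*, Q, q_dim]. Note that forward() applies only
    # linear_o.weight; the bias is exposed separately via get_output_bias().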


class GlobalAttention(nn.Module):

def __init__(self, input_dim, head_dim, num_heads, inf, eps):
super(GlobalAttention, self).__init__()

self.num_heads = num_heads
self.inf = inf
self.eps = eps
self.linear_q = Linear(
input_dim, head_dim * num_heads, bias=False, init='glorot')
self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot')
self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot')
self.linear_g = Linear(input_dim, head_dim * num_heads, init='gating')
self.linear_o = Linear(head_dim * num_heads, input_dim, init='final')
self.sigmoid = nn.Sigmoid()
# precompute the 1/sqrt(head_dim)
self.norm = head_dim**-0.5

def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:

# gating
g = self.sigmoid(self.linear_g(x))

k = self.linear_k(x)
v = self.linear_v(x)

q = torch.sum(
x * mask.unsqueeze(-1), dim=-2) / (
torch.sum(mask, dim=-1, keepdims=True) + self.eps)
q = self.linear_q(q)
q *= self.norm
q = q.view(q.shape[:-1] + (self.num_heads, -1))

attn = torch.matmul(q, k.transpose(-1, -2))
del q, k

attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :]
attn = softmax_dropout(attn, 0, self.training, mask=attn_mask)

o = torch.matmul(
attn,
v,
)
del attn, v

g = g.view(g.shape[:-1] + (self.num_heads, -1))
o = o.unsqueeze(-3) * g
del g

# merge heads
o = o.reshape(o.shape[:-2] + (-1, ))
return self.linear_o(o)


def gen_msa_attn_mask(mask, inf, gen_col_mask=True):
row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
if gen_col_mask:
col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None,
None, :]
return row_mask, col_mask
else:
return row_mask


class MSAAttention(nn.Module):

def __init__(
self,
d_in,
d_hid,
num_heads,
pair_bias=False,
d_pair=None,
):
super(MSAAttention, self).__init__()

self.pair_bias = pair_bias
self.layer_norm_m = LayerNorm(d_in)
self.layer_norm_z = None
self.linear_z = None
if self.pair_bias:
self.layer_norm_z = LayerNorm(d_pair)
self.linear_z = Linear(
d_pair, num_heads, bias=False, init='normal')

self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)

@torch.jit.ignore
def _chunk(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
chunk_size: int = None,
) -> torch.Tensor:

return chunk_layer(
self._attn_forward,
{
'm': m,
'mask': mask,
'bias': bias
},
chunk_size=chunk_size,
num_batch_dims=len(m.shape[:-2]),
)

@torch.jit.ignore
def _attn_chunk_forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = 2560,
) -> torch.Tensor:
m = self.layer_norm_m(m)
num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size
outputs = []
for i in range(num_chunk):
chunk_start = i * chunk_size
chunk_end = min(m.shape[-3], chunk_start + chunk_size)
cur_m = m[..., chunk_start:chunk_end, :, :]
cur_mask = (
mask[..., chunk_start:chunk_end, :, :, :]
if mask is not None else None)
outputs.append(
self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias))
return torch.concat(outputs, dim=-3)

def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None):
m = self.layer_norm_m(m)
return self.mha(q=m, k=m, v=m, mask=mask, bias=bias)

def forward(
self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:

bias = None
if self.pair_bias:
z = self.layer_norm_z(z)
bias = (
permute_final_dims(self.linear_z(z),
(2, 0, 1)).unsqueeze(-4).contiguous())

if chunk_size is not None:
m = self._chunk(m, attn_mask, bias, chunk_size)
else:
attn_chunk_size = 2560
if m.shape[-3] <= attn_chunk_size:
m = self._attn_forward(m, attn_mask, bias)
else:
# reduce the peak memory cost in extra_msa_stack
return self._attn_chunk_forward(
m, attn_mask, bias, chunk_size=attn_chunk_size)

return m

def get_output_bias(self):
return self.mha.get_output_bias()


class MSARowAttentionWithPairBias(MSAAttention):

def __init__(self, d_msa, d_pair, d_hid, num_heads):
super(MSARowAttentionWithPairBias, self).__init__(
d_msa,
d_hid,
num_heads,
pair_bias=True,
d_pair=d_pair,
)


class MSAColumnAttention(MSAAttention):

def __init__(self, d_msa, d_hid, num_heads):
super(MSAColumnAttention, self).__init__(
d_in=d_msa,
d_hid=d_hid,
num_heads=num_heads,
pair_bias=False,
d_pair=None,
)

def forward(
self,
m: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
m = m.transpose(-2, -3)
m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size)
m = m.transpose(-2, -3)

return m


class MSAColumnGlobalAttention(nn.Module):

def __init__(
self,
d_in,
d_hid,
num_heads,
inf=1e9,
eps=1e-10,
):
super(MSAColumnGlobalAttention, self).__init__()

self.layer_norm_m = LayerNorm(d_in)
self.global_attention = GlobalAttention(
d_in,
d_hid,
num_heads,
inf=inf,
eps=eps,
)

@torch.jit.ignore
def _chunk(
self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._attn_forward,
{
'm': m,
'mask': mask
},
chunk_size=chunk_size,
num_batch_dims=len(m.shape[:-2]),
)

def _attn_forward(self, m, mask):
m = self.layer_norm_m(m)
return self.global_attention(m, mask=mask)

def forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:

m = m.transpose(-2, -3)
mask = mask.transpose(-1, -2)

if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
else:
m = self._attn_forward(m, mask=mask)

m = m.transpose(-2, -3)
return m


def gen_tri_attn_mask(mask, inf):
start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :]
end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None,
None, :]
return start_mask, end_mask


class TriangleAttention(nn.Module):

def __init__(
self,
d_in,
d_hid,
num_heads,
starting,
):
super(TriangleAttention, self).__init__()
self.starting = starting
self.layer_norm = LayerNorm(d_in)
self.linear = Linear(d_in, num_heads, bias=False, init='normal')
self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads)

@torch.jit.ignore
def _chunk(
self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
) -> torch.Tensor:
return chunk_layer(
self.mha,
{
'q': x,
'k': x,
'v': x,
'mask': mask,
'bias': bias
},
chunk_size=chunk_size,
num_batch_dims=len(x.shape[:-2]),
)

def forward(
self,
x: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
if not self.starting:
x = x.transpose(-2, -3)

x = self.layer_norm(x)
triangle_bias = (
permute_final_dims(self.linear(x),
(2, 0, 1)).unsqueeze(-4).contiguous())

if chunk_size is not None:
x = self._chunk(x, attn_mask, triangle_bias, chunk_size)
else:
x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias)

if not self.starting:
x = x.transpose(-2, -3)
return x

def get_output_bias(self):
return self.mha.get_output_bias()


class TriangleAttentionStarting(TriangleAttention):
__init__ = partialmethod(TriangleAttention.__init__, starting=True)


class TriangleAttentionEnding(TriangleAttention):
__init__ = partialmethod(TriangleAttention.__init__, starting=False)

+ 171  - 0   modelscope/models/science/unifold/modules/auxillary_heads.py

@@ -0,0 +1,171 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Dict

import torch.nn as nn
from unicore.modules import LayerNorm

from .common import Linear
from .confidence import (predicted_aligned_error, predicted_lddt,
predicted_tm_score)


class AuxiliaryHeads(nn.Module):

def __init__(self, config):
super(AuxiliaryHeads, self).__init__()

self.plddt = PredictedLDDTHead(**config['plddt'], )

self.distogram = DistogramHead(**config['distogram'], )

self.masked_msa = MaskedMSAHead(**config['masked_msa'], )

if config.experimentally_resolved.enabled:
self.experimentally_resolved = ExperimentallyResolvedHead(
**config['experimentally_resolved'], )

if config.pae.enabled:
self.pae = PredictedAlignedErrorHead(**config.pae, )

self.config = config

def forward(self, outputs):
aux_out = {}
plddt_logits = self.plddt(outputs['sm']['single'])
aux_out['plddt_logits'] = plddt_logits

aux_out['plddt'] = predicted_lddt(plddt_logits.detach())

distogram_logits = self.distogram(outputs['pair'])
aux_out['distogram_logits'] = distogram_logits

masked_msa_logits = self.masked_msa(outputs['msa'])
aux_out['masked_msa_logits'] = masked_msa_logits

if self.config.experimentally_resolved.enabled:
exp_res_logits = self.experimentally_resolved(outputs['single'])
aux_out['experimentally_resolved_logits'] = exp_res_logits

if self.config.pae.enabled:
pae_logits = self.pae(outputs['pair'])
aux_out['pae_logits'] = pae_logits
pae_logits = pae_logits.detach()
aux_out.update(
predicted_aligned_error(
pae_logits,
**self.config.pae,
))
aux_out['ptm'] = predicted_tm_score(
pae_logits, interface=False, **self.config.pae)

iptm_weight = self.config.pae.get('iptm_weight', 0.0)
if iptm_weight > 0.0:
aux_out['iptm'] = predicted_tm_score(
pae_logits,
interface=True,
asym_id=outputs['asym_id'],
**self.config.pae,
)
aux_out['iptm+ptm'] = (
iptm_weight * aux_out['iptm'] + # noqa W504
(1.0 - iptm_weight) * aux_out['ptm'])

return aux_out


class PredictedLDDTHead(nn.Module):

def __init__(self, num_bins, d_in, d_hid):
super(PredictedLDDTHead, self).__init__()

self.num_bins = num_bins
self.d_in = d_in
self.d_hid = d_hid

self.layer_norm = LayerNorm(self.d_in)

self.linear_1 = Linear(self.d_in, self.d_hid, init='relu')
self.linear_2 = Linear(self.d_hid, self.d_hid, init='relu')
self.act = nn.GELU()
self.linear_3 = Linear(self.d_hid, self.num_bins, init='final')

def forward(self, s):
s = self.layer_norm(s)
s = self.linear_1(s)
s = self.act(s)
s = self.linear_2(s)
s = self.act(s)
s = self.linear_3(s)
return s


class EnhancedHeadBase(nn.Module):

def __init__(self, d_in, d_out, disable_enhance_head):
super(EnhancedHeadBase, self).__init__()
if disable_enhance_head:
self.layer_norm = None
self.linear_in = None
else:
self.layer_norm = LayerNorm(d_in)
self.linear_in = Linear(d_in, d_in, init='relu')
self.act = nn.GELU()
self.linear = Linear(d_in, d_out, init='final')

def apply_alphafold_original_mode(self):
self.layer_norm = None
self.linear_in = None

def forward(self, x):
if self.layer_norm is not None:
x = self.layer_norm(x)
x = self.act(self.linear_in(x))
logits = self.linear(x)
return logits


class DistogramHead(EnhancedHeadBase):

def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
super(DistogramHead, self).__init__(
d_in=d_pair,
d_out=num_bins,
disable_enhance_head=disable_enhance_head,
)

def forward(self, x):
logits = super().forward(x)
logits = logits + logits.transpose(-2, -3)
return logits


class PredictedAlignedErrorHead(EnhancedHeadBase):

def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs):
super(PredictedAlignedErrorHead, self).__init__(
d_in=d_pair,
d_out=num_bins,
disable_enhance_head=disable_enhance_head,
)


class MaskedMSAHead(EnhancedHeadBase):

def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs):
super(MaskedMSAHead, self).__init__(
d_in=d_msa,
d_out=d_out,
disable_enhance_head=disable_enhance_head,
)


class ExperimentallyResolvedHead(EnhancedHeadBase):

def __init__(self, d_single, d_out, disable_enhance_head, **kwargs):
super(ExperimentallyResolvedHead, self).__init__(
d_in=d_single,
d_out=d_out,
disable_enhance_head=disable_enhance_head,
)

+ 387  - 0   modelscope/models/science/unifold/modules/common.py

@@ -0,0 +1,387 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from unicore.modules import LayerNorm
from unicore.utils import tensor_tree_map


class Linear(nn.Linear):

def __init__(
self,
d_in: int,
d_out: int,
bias: bool = True,
init: str = 'default',
):
super(Linear, self).__init__(d_in, d_out, bias=bias)

self.use_bias = bias

if self.use_bias:
with torch.no_grad():
self.bias.fill_(0)

if init == 'default':
self._trunc_normal_init(1.0)
elif init == 'relu':
self._trunc_normal_init(2.0)
elif init == 'glorot':
self._glorot_uniform_init()
elif init == 'gating':
self._zero_init(self.use_bias)
elif init == 'normal':
self._normal_init()
elif init == 'final':
self._zero_init(False)
else:
raise ValueError('Invalid init method.')

def _trunc_normal_init(self, scale=1.0):
# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978
_, fan_in = self.weight.shape
scale = scale / max(1, fan_in)
std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR
nn.init.trunc_normal_(self.weight, mean=0.0, std=std)

def _glorot_uniform_init(self):
nn.init.xavier_uniform_(self.weight, gain=1)

def _zero_init(self, use_bias=True):
with torch.no_grad():
self.weight.fill_(0.0)
if use_bias:
with torch.no_grad():
self.bias.fill_(1.0)

def _normal_init(self):
torch.nn.init.kaiming_normal_(self.weight, nonlinearity='linear')


class Transition(nn.Module):

def __init__(self, d_in, n):

super(Transition, self).__init__()

self.d_in = d_in
self.n = n

self.layer_norm = LayerNorm(self.d_in)
self.linear_1 = Linear(self.d_in, self.n * self.d_in, init='relu')
self.act = nn.GELU()
self.linear_2 = Linear(self.n * self.d_in, d_in, init='final')

def _transition(self, x):
x = self.layer_norm(x)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x

@torch.jit.ignore
def _chunk(
self,
x: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{'x': x},
chunk_size=chunk_size,
num_batch_dims=len(x.shape[:-2]),
)

def forward(
self,
x: torch.Tensor,
chunk_size: Optional[int] = None,
) -> torch.Tensor:

if chunk_size is not None:
x = self._chunk(x, chunk_size)
else:
x = self._transition(x=x)

return x


class OuterProductMean(nn.Module):

def __init__(self, d_msa, d_pair, d_hid, eps=1e-3):
super(OuterProductMean, self).__init__()

self.d_msa = d_msa
self.d_pair = d_pair
self.d_hid = d_hid
self.eps = eps

self.layer_norm = LayerNorm(d_msa)
self.linear_1 = Linear(d_msa, d_hid)
self.linear_2 = Linear(d_msa, d_hid)
self.linear_out = Linear(d_hid**2, d_pair, init='relu')
self.act = nn.GELU()
self.linear_z = Linear(self.d_pair, self.d_pair, init='final')
self.layer_norm_out = LayerNorm(self.d_pair)

def _opm(self, a, b):
outer = torch.einsum('...bac,...dae->...bdce', a, b)
outer = outer.reshape(outer.shape[:-2] + (-1, ))
outer = self.linear_out(outer)
return outer

@torch.jit.ignore
def _chunk(self, a: torch.Tensor, b: torch.Tensor,
chunk_size: int) -> torch.Tensor:
a = a.reshape((-1, ) + a.shape[-3:])
b = b.reshape((-1, ) + b.shape[-3:])
out = []
# TODO: optimize this
for a_prime, b_prime in zip(a, b):
outer = chunk_layer(
partial(self._opm, b=b_prime),
{'a': a_prime},
chunk_size=chunk_size,
num_batch_dims=1,
)
out.append(outer)
if len(out) == 1:
outer = out[0].unsqueeze(0)
else:
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])

return outer

def apply_alphafold_original_mode(self):
self.linear_z = None
self.layer_norm_out = None

def forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:

m = self.layer_norm(m)
mask = mask.unsqueeze(-1)
if self.layer_norm_out is not None:
# for numerical stability
mask = mask * (mask.size(-2)**-0.5)
a = self.linear_1(m)
b = self.linear_2(m)
if self.training:
a = a * mask
b = b * mask
else:
a *= mask
b *= mask

a = a.transpose(-2, -3)
b = b.transpose(-2, -3)

if chunk_size is not None:
z = self._chunk(a, b, chunk_size)
else:
z = self._opm(a, b)

norm = torch.einsum('...abc,...adc->...bdc', mask, mask)
z /= self.eps + norm
if self.layer_norm_out is not None:
z = self.act(z)
z = self.layer_norm_out(z)
z = self.linear_z(z)
return z


def residual(residual, x, training):
if training:
return x + residual
else:
residual += x
return residual


@torch.jit.script
def fused_bias_dropout_add(
x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
dropmask: torch.Tensor,
prob: float,
) -> torch.Tensor:
return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual


@torch.jit.script
def fused_bias_dropout_add_inference(
x: torch.Tensor,
bias: torch.Tensor,
residual: torch.Tensor,
) -> torch.Tensor:
residual += bias + x
return residual


def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob,
training):
bias = module.get_output_bias()
if training:
shape = list(x.shape)
shape[dropout_shared_dim] = 1
with torch.no_grad():
mask = x.new_ones(shape)
return fused_bias_dropout_add(x, bias, residual, mask, prob)
else:
return fused_bias_dropout_add_inference(x, bias, residual)
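
# Note on bias_dropout_residual: the dropout mask has size 1 along
# `dropout_shared_dim`, so whole rows (or columns) of the update are dropped
# together (the shared row/column dropout of the Evoformer), with the module's
# output bias fused into the same op; at inference the add is done in place
# without dropout.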


@torch.jit.script
def fused_bias_gated_dropout_add(
x: torch.Tensor,
bias: torch.Tensor,
g: torch.Tensor,
g_bias: torch.Tensor,
residual: torch.Tensor,
dropout_mask: torch.Tensor,
prob: float,
) -> torch.Tensor:
return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout(
dropout_mask,
p=prob,
training=True,
) + residual


def tri_mul_residual(
module,
residual,
outputs,
dropout_shared_dim,
prob,
training,
block_size,
):
if training:
x, g = outputs
bias, g_bias = module.get_output_bias()
shape = list(x.shape)
shape[dropout_shared_dim] = 1
with torch.no_grad():
mask = x.new_ones(shape)
return fused_bias_gated_dropout_add(
x,
bias,
g,
g_bias,
residual,
mask,
prob,
)
elif block_size is None:
x, g = outputs
bias, g_bias = module.get_output_bias()
residual += (torch.sigmoid(g + g_bias) * (x + bias))
return residual
else:
# gated is not used here
residual += outputs
return residual


class SimpleModuleList(nn.ModuleList):

def __repr__(self):
return str(len(self)) + ' X ...\n' + self[0].__repr__()


def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
num_batch_dims: int,
) -> Any:
# TODO: support inplace add to output
if not (len(inputs) > 0):
raise ValueError('Must provide at least one input')

def _dict_get_shapes(input):
shapes = []
if type(input) is torch.Tensor:
shapes.append(input.shape)
elif type(input) is dict:
for v in input.values():
shapes.extend(_dict_get_shapes(v))
elif isinstance(input, Iterable):
for v in input:
shapes.extend(_dict_get_shapes(v))
else:
raise ValueError('Not supported')

return shapes

inputs = {k: v for k, v in inputs.items() if v is not None}
initial_dims = [
shape[:num_batch_dims] for shape in _dict_get_shapes(inputs)
]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size

def _flat_inputs(t):
t = t.view(-1, *t.shape[num_batch_dims:])
assert (
t.shape[0] == flat_batch_dim or t.shape[0] == 1
), 'batch dimension must be 1 or equal to the flat batch dimension'
return t

flat_inputs = tensor_tree_map(_flat_inputs, inputs)

out = None
for i in range(num_chunks):
chunk_start = i * chunk_size
chunk_end = min((i + 1) * chunk_size, flat_batch_dim)

def select_chunk(t):
if t.shape[0] == 1:
return t[0:1]
else:
return t[chunk_start:chunk_end]

        chunks = tensor_tree_map(select_chunk, flat_inputs)

        output_chunk = layer(**chunks)

if out is None:
out = tensor_tree_map(
lambda t: t.new_zeros((flat_batch_dim, ) + t.shape[1:]),
output_chunk)

out_type = type(output_chunk)
if out_type is tuple:
for x, y in zip(out, output_chunk):
x[chunk_start:chunk_end] = y
elif out_type is torch.Tensor:
out[chunk_start:chunk_end] = output_chunk
else:
raise ValueError('Not supported')

# reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
def reshape(t):
return t.view(orig_batch_dims + t.shape[1:])

out = tensor_tree_map(reshape, out)

return out
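
# Usage sketch (illustrative, mirroring Transition._chunk above): chunk_layer
# slices every input into `chunk_size` pieces along the flattened batch
# dimensions, runs `layer` on each piece and stitches the results back, e.g.
#   out = chunk_layer(self._transition, {'x': x}, chunk_size=chunk_size,
#                     num_batch_dims=len(x.shape[:-2]))
# Inputs whose flattened batch dimension is 1 are broadcast to every chunk.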

+ 159  - 0   modelscope/models/science/unifold/modules/confidence.py

@@ -0,0 +1,159 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Dict, Optional

import torch


def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor:
"""Computes per-residue pLDDT from logits.
Args:
logits: [num_res, num_bins] output from the PredictedLDDTHead.
Returns:
plddt: [num_res] per-residue pLDDT.
"""
num_bins = plddt_logits.shape[-1]
bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1)
bin_width = 1.0 / num_bins
bounds = torch.arange(
start=0.5 * bin_width,
end=1.0,
step=bin_width,
device=plddt_logits.device)
plddt = torch.sum(
bin_probs
* bounds.view(*((1, ) * len(bin_probs.shape[:-1])), *bounds.shape),
dim=-1,
)
return plddt
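
# Usage sketch (illustrative shapes): for plddt_logits of shape
# [num_res, num_bins],
#   plddt = predicted_lddt(plddt_logits)    # -> [num_res], values in (0, 1)
# i.e. the expectation over evenly spaced bin centers; multiply by 100 for the
# conventional 0-100 pLDDT scale.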


def compute_bin_values(breaks: torch.Tensor):
"""Gets the bin centers from the bin edges.
Args:
breaks: [num_bins - 1] the error bin edges.
Returns:
bin_centers: [num_bins] the error bin centers.
"""
step = breaks[1] - breaks[0]
bin_values = breaks + step / 2
bin_values = torch.cat([bin_values, (bin_values[-1] + step).unsqueeze(-1)],
dim=0)
return bin_values


def compute_predicted_aligned_error(
    bin_edges: torch.Tensor,
    bin_probs: torch.Tensor,
) -> torch.Tensor:
    """Calculates expected aligned distance errors for every pair of residues.
    Args:
        bin_edges: [num_bins - 1] the error bin edges.
        bin_probs: [num_res, num_res, num_bins] the predicted probs for each
            error bin, for each pair of residues.
    Returns:
        predicted_aligned_error: [num_res, num_res] the expected aligned
            distance error for each pair of residues.
    """
    bin_values = compute_bin_values(bin_edges)
    return torch.sum(bin_probs * bin_values, dim=-1)


def predicted_aligned_error(
pae_logits: torch.Tensor,
max_bin: int = 31,
num_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits.
Args:
logits: [num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
breaks: [num_bins - 1] the error bin edges.
Returns:
aligned_confidence_probs: [num_res, num_res, num_bins] the predicted
aligned error probabilities over bins for each residue pair.
predicted_aligned_error: [num_res, num_res] the expected aligned distance
error for each pair of residues.
max_predicted_aligned_error: The maximum predicted error possible.
"""
bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1)
bin_edges = torch.linspace(
0, max_bin, steps=(num_bins - 1), device=pae_logits.device)

predicted_aligned_error = compute_predicted_aligned_error(
bin_edges=bin_edges,
bin_probs=bin_probs,
)

return {
'aligned_error_probs_per_bin': bin_probs,
'predicted_aligned_error': predicted_aligned_error,
}
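
# Usage sketch (illustrative shapes): for pae_logits of shape
# [num_res, num_res, num_bins],
#   out = predicted_aligned_error(pae_logits, max_bin=31, num_bins=64)
#   out['aligned_error_probs_per_bin']   # [num_res, num_res, num_bins]
#   out['predicted_aligned_error']       # [num_res, num_res], expected error (Angstrom)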


def predicted_tm_score(
pae_logits: torch.Tensor,
residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31,
num_bins: int = 64,
eps: float = 1e-8,
asym_id: Optional[torch.Tensor] = None,
interface: bool = False,
**kwargs,
) -> torch.Tensor:
"""Computes predicted TM alignment or predicted interface TM alignment score.
Args:
logits: [num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
breaks: [num_bins] the error bins.
residue_weights: [num_res] the per residue weights to use for the
expectation.
asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for
ipTM calculation, i.e. when interface=True.
interface: If True, interface predicted TM score is computed.
Returns:
ptm_score: The predicted TM alignment or the predicted iTM score.
"""
pae_logits = pae_logits.float()
if residue_weights is None:
residue_weights = pae_logits.new_ones(pae_logits.shape[:-2])

breaks = torch.linspace(
0, max_bin, steps=(num_bins - 1), device=pae_logits.device)

def tm_kernal(nres):
clipped_n = max(nres, 19)
d0 = 1.24 * (clipped_n - 15)**(1.0 / 3.0) - 1.8
return lambda x: 1.0 / (1.0 + (x / d0)**2)

    def rmsd_kernal(eps):  # kept for a possible pRMSD computation (currently unused)
return lambda x: 1. / (x + eps)

bin_centers = compute_bin_values(breaks)
probs = torch.nn.functional.softmax(pae_logits, dim=-1)
tm_per_bin = tm_kernal(nres=pae_logits.shape[-2])(bin_centers)
# tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
# rmsd_per_bin = rmsd_kernal()(bin_centers)
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)

pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape)
if interface:
assert asym_id is not None, 'must provide asym_id for iptm calculation.'
pair_mask *= asym_id[..., :, None] != asym_id[..., None, :]

predicted_tm_term *= pair_mask

pair_residue_weights = pair_mask * (
residue_weights[None, :] * residue_weights[:, None])
normed_residue_mask = pair_residue_weights / (
eps + pair_residue_weights.sum(dim=-1, keepdim=True))

per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
ret = per_alignment.gather(
dim=-1, index=weighted.max(dim=-1,
keepdim=True).indices).squeeze(dim=-1)
return ret
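
# The TM kernel above follows the TM-score normalisation
#   d0(N) = 1.24 * (max(N, 19) - 15)^(1/3) - 1.8,  f(x) = 1 / (1 + (x / d0)^2),
# and the returned pTM (or ipTM when interface=True) is the best
# residue-weighted expected alignment score over query residues.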

+ 290  - 0   modelscope/models/science/unifold/modules/embedders.py

@@ -0,0 +1,290 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Optional, Tuple

import torch
import torch.nn as nn
from unicore.modules import LayerNorm
from unicore.utils import one_hot

from .common import Linear, SimpleModuleList, residual


class InputEmbedder(nn.Module):

def __init__(
self,
tf_dim: int,
msa_dim: int,
d_pair: int,
d_msa: int,
relpos_k: int,
use_chain_relative: bool = False,
max_relative_chain: Optional[int] = None,
**kwargs,
):
super(InputEmbedder, self).__init__()

self.tf_dim = tf_dim
self.msa_dim = msa_dim

self.d_pair = d_pair
self.d_msa = d_msa

self.linear_tf_z_i = Linear(tf_dim, d_pair)
self.linear_tf_z_j = Linear(tf_dim, d_pair)
self.linear_tf_m = Linear(tf_dim, d_msa)
self.linear_msa_m = Linear(msa_dim, d_msa)

# RPE stuff
self.relpos_k = relpos_k
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if not self.use_chain_relative:
self.num_bins = 2 * self.relpos_k + 1
else:
self.num_bins = 2 * self.relpos_k + 2
self.num_bins += 1 # entity id
self.num_bins += 2 * max_relative_chain + 2

self.linear_relpos = Linear(self.num_bins, d_pair)

def _relpos_indices(
self,
res_id: torch.Tensor,
sym_id: Optional[torch.Tensor] = None,
asym_id: Optional[torch.Tensor] = None,
entity_id: Optional[torch.Tensor] = None,
):

max_rel_res = self.relpos_k
rp = res_id[..., None] - res_id[..., None, :]
rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res
if not self.use_chain_relative:
return rp
else:
asym_id_same = asym_id[..., :, None] == asym_id[..., None, :]
rp[~asym_id_same] = 2 * max_rel_res + 1
entity_id_same = entity_id[..., :, None] == entity_id[..., None, :]
rp_entity_id = entity_id_same.type(rp.dtype)[..., None]

rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :]

max_rel_chain = self.max_relative_chain

clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain)

clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1
return rp, rp_entity_id, clipped_rel_chain

def relpos_emb(
self,
res_id: torch.Tensor,
sym_id: Optional[torch.Tensor] = None,
asym_id: Optional[torch.Tensor] = None,
entity_id: Optional[torch.Tensor] = None,
num_sym: Optional[torch.Tensor] = None,
):

dtype = self.linear_relpos.weight.dtype
if not self.use_chain_relative:
rp = self._relpos_indices(res_id=res_id)
return self.linear_relpos(
one_hot(rp, num_classes=self.num_bins, dtype=dtype))
else:
rp, rp_entity_id, rp_rel_chain = self._relpos_indices(
res_id=res_id,
sym_id=sym_id,
asym_id=asym_id,
entity_id=entity_id)
rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype)
rp_entity_id = rp_entity_id.type(dtype)
rp_rel_chain = one_hot(
rp_rel_chain,
num_classes=(2 * self.max_relative_chain + 2),
dtype=dtype)
return self.linear_relpos(
torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1))
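
    # Feature-width bookkeeping (matches __init__): with chain-relative
    # encoding the concatenated blocks are (2 * relpos_k + 2) residue-offset
    # bins, one same-entity flag and (2 * max_relative_chain + 2) relative-chain
    # bins, i.e. self.num_bins channels in total.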

def forward(
self,
tf: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# [*, N_res, d_pair]
if self.tf_dim == 21:
            # the multimer model uses a 21-dim target feature
tf = tf[..., 1:]
# convert type if necessary
tf = tf.type(self.linear_tf_z_i.weight.dtype)
msa = msa.type(self.linear_tf_z_i.weight.dtype)
n_clust = msa.shape[-3]

msa_emb = self.linear_msa_m(msa)
        # add target_feat (aatype) into the MSA representation
tf_m = (
self.linear_tf_m(tf).unsqueeze(-3).expand(
((-1, ) * len(tf.shape[:-2]) + # noqa W504
(n_clust, -1, -1))))
msa_emb += tf_m

tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]

return msa_emb, pair_emb


class RecyclingEmbedder(nn.Module):

def __init__(
self,
d_msa: int,
d_pair: int,
min_bin: float,
max_bin: float,
num_bins: int,
inf: float = 1e8,
**kwargs,
):
super(RecyclingEmbedder, self).__init__()

self.d_msa = d_msa
self.d_pair = d_pair
self.min_bin = min_bin
self.max_bin = max_bin
self.num_bins = num_bins
self.inf = inf

self.squared_bins = None

self.linear = Linear(self.num_bins, self.d_pair)
self.layer_norm_m = LayerNorm(self.d_msa)
self.layer_norm_z = LayerNorm(self.d_pair)

def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:

m_update = self.layer_norm_m(m)
z_update = self.layer_norm_z(z)

return m_update, z_update

def recyle_pos(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:

if self.squared_bins is None:
bins = torch.linspace(
self.min_bin,
self.max_bin,
self.num_bins,
dtype=torch.float if self.training else x.dtype,
device=x.device,
requires_grad=False,
)
self.squared_bins = bins**2
upper = torch.cat(
[self.squared_bins[1:],
self.squared_bins.new_tensor([self.inf])],
dim=-1)
if self.training:
x = x.float()
d = torch.sum(
(x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True)
d = ((d > self.squared_bins) * # noqa W504
(d < upper)).type(self.linear.weight.dtype)
d = self.linear(d)
return d


class TemplateAngleEmbedder(nn.Module):

def __init__(
self,
d_in: int,
d_out: int,
**kwargs,
):
super(TemplateAngleEmbedder, self).__init__()

self.d_out = d_out
self.d_in = d_in

self.linear_1 = Linear(self.d_in, self.d_out, init='relu')
self.act = nn.GELU()
self.linear_2 = Linear(self.d_out, self.d_out, init='relu')

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear_1(x.type(self.linear_1.weight.dtype))
x = self.act(x)
x = self.linear_2(x)
return x


class TemplatePairEmbedder(nn.Module):

def __init__(
self,
d_in: int,
v2_d_in: list,
d_out: int,
d_pair: int,
v2_feature: bool = False,
**kwargs,
):
super(TemplatePairEmbedder, self).__init__()

self.d_out = d_out
self.v2_feature = v2_feature
if self.v2_feature:
self.d_in = v2_d_in
self.linear = SimpleModuleList()
for d_in in self.d_in:
self.linear.append(Linear(d_in, self.d_out, init='relu'))
self.z_layer_norm = LayerNorm(d_pair)
self.z_linear = Linear(d_pair, self.d_out, init='relu')
else:
self.d_in = d_in
self.linear = Linear(self.d_in, self.d_out, init='relu')

def forward(
self,
x,
z,
) -> torch.Tensor:
if not self.v2_feature:
x = self.linear(x.type(self.linear.weight.dtype))
return x
else:
dtype = self.z_linear.weight.dtype
t = self.linear[0](x[0].type(dtype))
for i, s in enumerate(x[1:]):
t = residual(t, self.linear[i + 1](s.type(dtype)),
self.training)
t = residual(t, self.z_linear(self.z_layer_norm(z)), self.training)
return t


class ExtraMSAEmbedder(nn.Module):

def __init__(
self,
d_in: int,
d_out: int,
**kwargs,
):
super(ExtraMSAEmbedder, self).__init__()

self.d_in = d_in
self.d_out = d_out
self.linear = Linear(self.d_in, self.d_out)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x.type(self.linear.weight.dtype))

+ 362  - 0   modelscope/models/science/unifold/modules/evoformer.py

@@ -0,0 +1,362 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from functools import partial
from typing import Optional, Tuple

import torch
import torch.nn as nn
from unicore.utils import checkpoint_sequential

from .attentions import (MSAColumnAttention, MSAColumnGlobalAttention,
MSARowAttentionWithPairBias, TriangleAttentionEnding,
TriangleAttentionStarting)
from .common import (Linear, OuterProductMean, SimpleModuleList, Transition,
bias_dropout_residual, residual, tri_mul_residual)
from .triangle_multiplication import (TriangleMultiplicationIncoming,
TriangleMultiplicationOutgoing)


class EvoformerIteration(nn.Module):

def __init__(
self,
d_msa: int,
d_pair: int,
d_hid_msa_att: int,
d_hid_opm: int,
d_hid_mul: int,
d_hid_pair_att: int,
num_heads_msa: int,
num_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
outer_product_mean_first: bool,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
):
super(EvoformerIteration, self).__init__()

self._is_extra_msa_stack = _is_extra_msa_stack
self.outer_product_mean_first = outer_product_mean_first

self.msa_att_row = MSARowAttentionWithPairBias(
d_msa=d_msa,
d_pair=d_pair,
d_hid=d_hid_msa_att,
num_heads=num_heads_msa,
)

if _is_extra_msa_stack:
self.msa_att_col = MSAColumnGlobalAttention(
d_in=d_msa,
d_hid=d_hid_msa_att,
num_heads=num_heads_msa,
inf=inf,
eps=eps,
)
else:
self.msa_att_col = MSAColumnAttention(
d_msa,
d_hid_msa_att,
num_heads_msa,
)

self.msa_transition = Transition(
d_in=d_msa,
n=transition_n,
)

self.outer_product_mean = OuterProductMean(
d_msa,
d_pair,
d_hid_opm,
)

self.tri_mul_out = TriangleMultiplicationOutgoing(
d_pair,
d_hid_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
d_pair,
d_hid_mul,
)

self.tri_att_start = TriangleAttentionStarting(
d_pair,
d_hid_pair_att,
num_heads_pair,
)
self.tri_att_end = TriangleAttentionEnding(
d_pair,
d_hid_pair_att,
num_heads_pair,
)

self.pair_transition = Transition(
d_in=d_pair,
n=transition_n,
)

self.row_dropout_share_dim = -3
self.col_dropout_share_dim = -2
self.msa_dropout = msa_dropout
self.pair_dropout = pair_dropout

def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
msa_row_attn_mask: torch.Tensor,
msa_col_attn_mask: Optional[torch.Tensor],
tri_start_attn_mask: torch.Tensor,
tri_end_attn_mask: torch.Tensor,
chunk_size: Optional[int] = None,
block_size: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

if self.outer_product_mean_first:
z = residual(
z,
self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size), self.training)

m = bias_dropout_residual(
self.msa_att_row,
m,
self.msa_att_row(
m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size),
self.row_dropout_share_dim,
self.msa_dropout,
self.training,
)
if self._is_extra_msa_stack:
m = residual(
m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size),
self.training)
else:
m = bias_dropout_residual(
self.msa_att_col,
m,
self.msa_att_col(
m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size),
self.col_dropout_share_dim,
self.msa_dropout,
self.training,
)
m = residual(m, self.msa_transition(m, chunk_size=chunk_size),
self.training)
if not self.outer_product_mean_first:
z = residual(
z,
self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size), self.training)

z = tri_mul_residual(
self.tri_mul_out,
z,
self.tri_mul_out(z, mask=pair_mask, block_size=block_size),
self.row_dropout_share_dim,
self.pair_dropout,
self.training,
block_size=block_size,
)

z = tri_mul_residual(
self.tri_mul_in,
z,
self.tri_mul_in(z, mask=pair_mask, block_size=block_size),
self.row_dropout_share_dim,
self.pair_dropout,
self.training,
block_size=block_size,
)

z = bias_dropout_residual(
self.tri_att_start,
z,
self.tri_att_start(
z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
self.row_dropout_share_dim,
self.pair_dropout,
self.training,
)

z = bias_dropout_residual(
self.tri_att_end,
z,
self.tri_att_end(
z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
self.col_dropout_share_dim,
self.pair_dropout,
self.training,
)
z = residual(z, self.pair_transition(z, chunk_size=chunk_size),
self.training)
return m, z


class EvoformerStack(nn.Module):

def __init__(
self,
d_msa: int,
d_pair: int,
d_hid_msa_att: int,
d_hid_opm: int,
d_hid_mul: int,
d_hid_pair_att: int,
d_single: int,
num_heads_msa: int,
num_heads_pair: int,
num_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
outer_product_mean_first: bool,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
**kwargs,
):
super(EvoformerStack, self).__init__()

self._is_extra_msa_stack = _is_extra_msa_stack

self.blocks = SimpleModuleList()

for _ in range(num_blocks):
self.blocks.append(
EvoformerIteration(
d_msa=d_msa,
d_pair=d_pair,
d_hid_msa_att=d_hid_msa_att,
d_hid_opm=d_hid_opm,
d_hid_mul=d_hid_mul,
d_hid_pair_att=d_hid_pair_att,
num_heads_msa=num_heads_msa,
num_heads_pair=num_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
outer_product_mean_first=outer_product_mean_first,
inf=inf,
eps=eps,
_is_extra_msa_stack=_is_extra_msa_stack,
))
if not self._is_extra_msa_stack:
self.linear = Linear(d_msa, d_single)
else:
self.linear = None

def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
msa_row_attn_mask: torch.Tensor,
msa_col_attn_mask: torch.Tensor,
tri_start_attn_mask: torch.Tensor,
tri_end_attn_mask: torch.Tensor,
chunk_size: int,
block_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_row_attn_mask=msa_row_attn_mask,
msa_col_attn_mask=msa_col_attn_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=chunk_size,
block_size=block_size) for b in self.blocks
]

m, z = checkpoint_sequential(
blocks,
input=(m, z),
)

s = None
if not self._is_extra_msa_stack:
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
s = s.squeeze(seq_dim)

return m, z, s


class ExtraMSAStack(EvoformerStack):

def __init__(
self,
d_msa: int,
d_pair: int,
d_hid_msa_att: int,
d_hid_opm: int,
d_hid_mul: int,
d_hid_pair_att: int,
num_heads_msa: int,
num_heads_pair: int,
num_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
outer_product_mean_first: bool,
inf: float,
eps: float,
**kwargs,
):
super(ExtraMSAStack, self).__init__(
d_msa=d_msa,
d_pair=d_pair,
d_hid_msa_att=d_hid_msa_att,
d_hid_opm=d_hid_opm,
d_hid_mul=d_hid_mul,
d_hid_pair_att=d_hid_pair_att,
d_single=None,
num_heads_msa=num_heads_msa,
num_heads_pair=num_heads_pair,
num_blocks=num_blocks,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
outer_product_mean_first=outer_product_mean_first,
inf=inf,
eps=eps,
_is_extra_msa_stack=True,
)

def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
        msa_row_attn_mask: Optional[torch.Tensor] = None,
        msa_col_attn_mask: Optional[torch.Tensor] = None,
        tri_start_attn_mask: Optional[torch.Tensor] = None,
        tri_end_attn_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
        block_size: Optional[int] = None,
) -> torch.Tensor:
_, z, _ = super().forward(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
msa_row_attn_mask=msa_row_attn_mask,
msa_col_attn_mask=msa_col_attn_mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=chunk_size,
block_size=block_size)
return z

+ 195  - 0   modelscope/models/science/unifold/modules/featurization.py

@@ -0,0 +1,195 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from typing import Dict

import torch
import torch.nn as nn
from unicore.utils import batched_gather, one_hot

from modelscope.models.science.unifold.data import residue_constants as rc
from .frame import Frame


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
is_gly = aatype == rc.restype_order['G']
ca_idx = rc.atom_order['CA']
cb_idx = rc.atom_order['CB']
pseudo_beta = torch.where(
is_gly[..., None].expand(*((-1, ) * len(is_gly.shape)), 3),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :],
)

if all_atom_masks is not None:
pseudo_beta_mask = torch.where(
is_gly,
all_atom_masks[..., ca_idx],
all_atom_masks[..., cb_idx],
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
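
# pseudo_beta_fn selects the C-beta coordinates for every residue except
# glycine, which has no C-beta, so its C-alpha is used instead; the optional
# atom mask is selected the same way.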


def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
batch['residx_atom37_to_atom14'],
dim=-2,
num_batch_dims=len(atom14.shape[:-2]),
)

atom37_data = atom37_data * batch['atom37_atom_exists'][..., None]

return atom37_data


def build_template_angle_feat(template_feats, v2_feature=False):
template_aatype = template_feats['template_aatype']
torsion_angles_sin_cos = template_feats['template_torsion_angles_sin_cos']
torsion_angles_mask = template_feats['template_torsion_angles_mask']
if not v2_feature:
alt_torsion_angles_sin_cos = template_feats[
'template_alt_torsion_angles_sin_cos']
template_angle_feat = torch.cat(
[
one_hot(template_aatype, 22),
torsion_angles_sin_cos.reshape(
*torsion_angles_sin_cos.shape[:-2], 14),
alt_torsion_angles_sin_cos.reshape(
*alt_torsion_angles_sin_cos.shape[:-2], 14),
torsion_angles_mask,
],
dim=-1,
)
template_angle_mask = torsion_angles_mask[..., 2]
else:
chi_mask = torsion_angles_mask[..., 3:]
chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask
chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask
template_angle_feat = torch.cat(
[
one_hot(template_aatype, 22),
chi_angles_sin,
chi_angles_cos,
chi_mask,
],
dim=-1,
)
template_angle_mask = chi_mask[..., 0]
return template_angle_feat, template_angle_mask


def build_template_pair_feat(
batch,
min_bin,
max_bin,
num_bins,
eps=1e-20,
inf=1e8,
):
template_mask = batch['template_pseudo_beta_mask']
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]

tpb = batch['template_pseudo_beta']
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True)
lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)

to_concat = [dgram, template_mask_2d[..., None]]

aatype_one_hot = nn.functional.one_hot(
batch['template_aatype'],
rc.restype_num + 2,
)

n_res = batch['template_aatype'].shape[-1]
to_concat.append(aatype_one_hot[..., None, :, :].expand(
*aatype_one_hot.shape[:-2], n_res, -1, -1))
to_concat.append(aatype_one_hot[...,
None, :].expand(*aatype_one_hot.shape[:-2],
-1, n_res, -1))

to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3))
to_concat.append(template_mask_2d[..., None])

act = torch.cat(to_concat, dim=-1)
act = act * template_mask_2d[..., None]

return act


def build_template_pair_feat_v2(
batch,
min_bin,
max_bin,
num_bins,
multichain_mask_2d=None,
eps=1e-20,
inf=1e8,
):
template_mask = batch['template_pseudo_beta_mask']
template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
if multichain_mask_2d is not None:
template_mask_2d *= multichain_mask_2d

tpb = batch['template_pseudo_beta']
dgram = torch.sum(
(tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True)
lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2
upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1)
dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype)
dgram *= template_mask_2d[..., None]
to_concat = [dgram, template_mask_2d[..., None]]

aatype_one_hot = one_hot(
batch['template_aatype'],
rc.restype_num + 2,
)

n_res = batch['template_aatype'].shape[-1]
to_concat.append(aatype_one_hot[..., None, :, :].expand(
*aatype_one_hot.shape[:-2], n_res, -1, -1))
to_concat.append(aatype_one_hot[...,
None, :].expand(*aatype_one_hot.shape[:-2],
-1, n_res, -1))

n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']]
rigids = Frame.make_transform_from_reference(
n_xyz=batch['template_all_atom_positions'][..., n, :],
ca_xyz=batch['template_all_atom_positions'][..., ca, :],
c_xyz=batch['template_all_atom_positions'][..., c, :],
eps=eps,
)
points = rigids.get_trans()[..., None, :, :]
rigid_vec = rigids[..., None].invert_apply(points)

inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1))

t_aa_masks = batch['template_all_atom_mask']
backbone_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[...,
c]
backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[...,
None, :]
if multichain_mask_2d is not None:
backbone_mask_2d *= multichain_mask_2d

inv_distance_scalar = inv_distance_scalar * backbone_mask_2d
unit_vector_data = rigid_vec * inv_distance_scalar[..., None]
to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1))
to_concat.append(backbone_mask_2d[..., None])

return to_concat


def build_extra_msa_feat(batch):
msa_1hot = one_hot(batch['extra_msa'], 23)
msa_feat = [
msa_1hot,
batch['extra_msa_has_deletion'].unsqueeze(-1),
batch['extra_msa_deletion_value'].unsqueeze(-1),
]
return torch.cat(msa_feat, dim=-1)
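
# Channel bookkeeping: 23 (one-hot over residue types, including unknown, gap
# and the masked-MSA token) + 1 (has_deletion) + 1 (deletion_value) = 25
# channels, the d_in typically configured for ExtraMSAEmbedder.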

+ 562  - 0   modelscope/models/science/unifold/modules/frame.py

@@ -0,0 +1,562 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from __future__ import annotations # noqa
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple

import numpy as np
import torch


def zero_translation(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = torch.float,
device: Optional[torch.device] = torch.device('cpu'),
requires_grad: bool = False,
) -> torch.Tensor:
trans = torch.zeros((*batch_dims, 3),
dtype=dtype,
device=device,
requires_grad=requires_grad)
return trans


# pylint: disable=bad-whitespace
_QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32)

_QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr
_QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii
_QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj
_QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk

_QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij
_QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik
_QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk

_QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir
_QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr
_QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr

_QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9)
_QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT)

_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0],
[0, 0, 0, -1]]

_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1],
[0, 0, -1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0],
[0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0],
[1, 0, 0, 0]]

_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC)
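
# The tables above encode quaternion operations as bilinear forms: quat_to_rot
# computes R(q)[a, b] = sum_{i,j} q_i * q_j * _QUAT_TO_ROT[i, j, 3 * a + b]
# (the identity quaternion (1, 0, 0, 0) therefore yields the identity matrix),
# and quat_multiply_by_vec contracts _QUAT_MULTIPLY_BY_VEC against a
# pure-vector (zero scalar part) quaternion update.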


class Rotation:

def __init__(
self,
mat: torch.Tensor,
):
if mat.shape[-2:] != (3, 3):
raise ValueError(f'incorrect rotation shape: {mat.shape}')
self._mat = mat

@staticmethod
def identity(
shape,
dtype: Optional[torch.dtype] = torch.float,
device: Optional[torch.device] = torch.device('cpu'),
requires_grad: bool = False,
) -> Rotation:
mat = torch.eye(
3, dtype=dtype, device=device, requires_grad=requires_grad)
mat = mat.view(*((1, ) * len(shape)), 3, 3)
mat = mat.expand(*shape, -1, -1)
return Rotation(mat)

@staticmethod
def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return (a.float() @ b.float()).type(a.dtype)

@staticmethod
def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype)

def __getitem__(self, index: Any) -> Rotation:
if not isinstance(index, tuple):
index = (index, )
return Rotation(mat=self._mat[index + (slice(None), slice(None))])

def __mul__(self, right: Any) -> Rotation:
if isinstance(right, (int, float)):
return Rotation(mat=self._mat * right)
elif isinstance(right, torch.Tensor):
return Rotation(mat=self._mat * right[..., None, None])
else:
raise TypeError(
f'multiplicand must be a tensor or a number, got {type(right)}.'
)

def __rmul__(self, left: Any) -> Rotation:
return self.__mul__(left)

def __matmul__(self, other: Rotation) -> Rotation:
new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat)
return Rotation(mat=new_mat)

@property
def _inv_mat(self):
return self._mat.transpose(-1, -2)

@property
def rot_mat(self) -> torch.Tensor:
return self._mat

def invert(self) -> Rotation:
return Rotation(mat=self._inv_mat)

def apply(self, pts: torch.Tensor) -> torch.Tensor:
return Rotation.mat_mul_vec(self._mat, pts)

def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
return Rotation.mat_mul_vec(self._inv_mat, pts)

# inherit tensor behaviors
@property
def shape(self) -> torch.Size:
s = self._mat.shape[:-2]
return s

@property
def dtype(self) -> torch.dtype:
return self._mat.dtype

@property
def device(self) -> torch.device:
return self._mat.device

@property
def requires_grad(self) -> bool:
return self._mat.requires_grad

def unsqueeze(self, dim: int) -> Rotation:
if dim >= len(self.shape):
raise ValueError('Invalid dimension')

rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2)
return Rotation(mat=rot_mats)

def map_tensor_fn(self, fn: Callable[[torch.Tensor],
torch.Tensor]) -> Rotation:
mat = self._mat.view(self._mat.shape[:-2] + (9, ))
mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1)
mat = mat.view(mat.shape[:-1] + (3, 3))
return Rotation(mat=mat)

@staticmethod
def cat(rs: Sequence[Rotation], dim: int) -> Rotation:
rot_mats = [r.rot_mat for r in rs]
rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)

return Rotation(mat=rot_mats)

def cuda(self) -> Rotation:
return Rotation(mat=self._mat.cuda())

def to(self, device: Optional[torch.device],
dtype: Optional[torch.dtype]) -> Rotation:
return Rotation(mat=self._mat.to(device=device, dtype=dtype))

def type(self, dtype: Optional[torch.dtype]) -> Rotation:
return Rotation(mat=self._mat.type(dtype))

def detach(self) -> Rotation:
return Rotation(mat=self._mat.detach())


class Frame:

def __init__(
self,
rotation: Optional[Rotation],
translation: Optional[torch.Tensor],
):
if rotation is None and translation is None:
rotation = Rotation.identity((0, ))
translation = zero_translation((0, ))
elif translation is None:
translation = zero_translation(rotation.shape, rotation.dtype,
rotation.device,
rotation.requires_grad)

elif rotation is None:
rotation = Rotation.identity(
translation.shape[:-1],
translation.dtype,
translation.device,
translation.requires_grad,
)

if (rotation.shape != translation.shape[:-1]) or (rotation.device
!= # noqa W504
translation.device):
raise ValueError('RotationMatrix and translation incompatible')

self._r = rotation
self._t = translation

@staticmethod
def identity(
shape: Iterable[int],
dtype: Optional[torch.dtype] = torch.float,
device: Optional[torch.device] = torch.device('cpu'),
requires_grad: bool = False,
) -> Frame:
return Frame(
Rotation.identity(shape, dtype, device, requires_grad),
zero_translation(shape, dtype, device, requires_grad),
)

def __getitem__(
self,
index: Any,
) -> Frame:
if type(index) != tuple:
index = (index, )

return Frame(
self._r[index],
self._t[index + (slice(None), )],
)

def __mul__(
self,
right: torch.Tensor,
) -> Frame:
if not (isinstance(right, torch.Tensor)):
raise TypeError('The other multiplicand must be a Tensor')

new_rots = self._r * right
new_trans = self._t * right[..., None]

return Frame(new_rots, new_trans)

def __rmul__(
self,
left: torch.Tensor,
) -> Frame:
return self.__mul__(left)

@property
def shape(self) -> torch.Size:
s = self._t.shape[:-1]
return s

@property
def device(self) -> torch.device:
return self._t.device

def get_rots(self) -> Rotation:
return self._r

def get_trans(self) -> torch.Tensor:
return self._t

def compose(
self,
other: Frame,
) -> Frame:
new_rot = self._r @ other._r
new_trans = self._r.apply(other._t) + self._t
return Frame(new_rot, new_trans)

def apply(
self,
pts: torch.Tensor,
) -> torch.Tensor:
rotated = self._r.apply(pts)
return rotated + self._t

def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
pts = pts - self._t
return self._r.invert_apply(pts)

def invert(self) -> Frame:
rot_inv = self._r.invert()
trn_inv = rot_inv.apply(self._t)

return Frame(rot_inv, -1 * trn_inv)

def map_tensor_fn(self, fn: Callable[[torch.Tensor],
torch.Tensor]) -> Frame:
new_rots = self._r.map_tensor_fn(fn)
new_trans = torch.stack(
list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1)

return Frame(new_rots, new_trans)

def to_tensor_4x4(self) -> torch.Tensor:
tensor = self._t.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self._r.rot_mat
tensor[..., :3, 3] = self._t
tensor[..., 3, 3] = 1
return tensor

@staticmethod
def from_tensor_4x4(t: torch.Tensor) -> Frame:
if t.shape[-2:] != (4, 4):
raise ValueError('Incorrectly shaped input tensor')

rots = Rotation(mat=t[..., :3, :3])
trans = t[..., :3, 3]

return Frame(rots, trans)

@staticmethod
def from_3_points(
p_neg_x_axis: torch.Tensor,
origin: torch.Tensor,
p_xy_plane: torch.Tensor,
eps: float = 1e-8,
) -> Frame:
p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
origin = torch.unbind(origin, dim=-1)
p_xy_plane = torch.unbind(p_xy_plane, dim=-1)

e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]

denom = torch.sqrt(sum((c * c for c in e0)) + eps)
e0 = [c / denom for c in e0]
dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
denom = torch.sqrt(sum((c * c for c in e1)) + eps)
e1 = [c / denom for c in e1]
e2 = [
e0[1] * e1[2] - e0[2] * e1[1],
e0[2] * e1[0] - e0[0] * e1[2],
e0[0] * e1[1] - e0[1] * e1[0],
]

rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
rots = rots.reshape(rots.shape[:-1] + (3, 3))

rot_obj = Rotation(mat=rots)

return Frame(rot_obj, torch.stack(origin, dim=-1))
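
    # from_3_points builds an orthonormal frame by Gram-Schmidt: e0 points from
    # p_neg_x_axis towards the origin, e1 is the component of
    # (p_xy_plane - origin) orthogonal to e0, and e2 = e0 x e1; the three
    # vectors form the columns of the rotation, and the translation is the
    # origin point itself.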

def unsqueeze(
self,
dim: int,
) -> Frame:
if dim >= len(self.shape):
raise ValueError('Invalid dimension')
rots = self._r.unsqueeze(dim)
trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1)

return Frame(rots, trans)

@staticmethod
def cat(
Ts: Sequence[Frame],
dim: int,
) -> Frame:
rots = Rotation.cat([T._r for T in Ts], dim)
trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1)

return Frame(rots, trans)

def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame:
return Frame(fn(self._r), self._t)

def apply_trans_fn(self, fn: Callable[[torch.Tensor],
torch.Tensor]) -> Frame:
return Frame(self._r, fn(self._t))

def scale_translation(self, trans_scale_factor: float) -> Frame:
# fn = lambda t: t * trans_scale_factor
def fn(t):
return t * trans_scale_factor

return self.apply_trans_fn(fn)

def stop_rot_gradient(self) -> Frame:
# fn = lambda r: r.detach()
def fn(r):
return r.detach()

return self.apply_rot_fn(fn)

@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
input_dtype = ca_xyz.dtype
n_xyz = n_xyz.float()
ca_xyz = ca_xyz.float()
c_xyz = c_xyz.float()
n_xyz = n_xyz - ca_xyz
c_xyz = c_xyz - ca_xyz

        c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
        norm = torch.sqrt(eps + c_x**2 + c_y**2)
        sin_c1 = -c_y / norm
        cos_c1 = c_x / norm

        c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
        c1_rots[..., 0, 0] = cos_c1
        c1_rots[..., 0, 1] = -1 * sin_c1
        c1_rots[..., 1, 0] = sin_c1
        c1_rots[..., 1, 1] = cos_c1
        c1_rots[..., 2, 2] = 1

        norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2)
        sin_c2 = c_z / norm
cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm

c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
c2_rots[..., 0, 0] = cos_c2
c2_rots[..., 0, 2] = sin_c2
c2_rots[..., 1, 1] = 1
c2_rots[..., 2, 0] = -1 * sin_c2
c2_rots[..., 2, 2] = cos_c2

c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots)
n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz)

_, n_y, n_z = [n_xyz[..., i] for i in range(3)]
norm = torch.sqrt(eps + n_y**2 + n_z**2)
sin_n = -n_z / norm
cos_n = n_y / norm

n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
n_rots[..., 0, 0] = 1
n_rots[..., 1, 1] = cos_n
n_rots[..., 1, 2] = -1 * sin_n
n_rots[..., 2, 1] = sin_n
n_rots[..., 2, 2] = cos_n

rots = Rotation.mat_mul_mat(n_rots, c_rots)

rots = rots.transpose(-1, -2)
rot_obj = Rotation(mat=rots.type(input_dtype))

return Frame(rot_obj, ca_xyz.type(input_dtype))

def cuda(self) -> Frame:
return Frame(self._r.cuda(), self._t.cuda())

@property
def dtype(self) -> torch.dtype:
assert self._r.dtype == self._t.dtype
return self._r.dtype

def type(self, dtype) -> Frame:
return Frame(self._r.type(dtype), self._t.type(dtype))


class Quaternion:

def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor):
if quaternion.shape[-1] != 4:
raise ValueError(f'incorrect quaternion shape: {quaternion.shape}')
self._q = quaternion
self._t = translation

@staticmethod
def identity(
shape: Iterable[int],
dtype: Optional[torch.dtype] = torch.float,
device: Optional[torch.device] = torch.device('cpu'),
requires_grad: bool = False,
) -> Quaternion:
trans = zero_translation(shape, dtype, device, requires_grad)
quats = torch.zeros((*shape, 4),
dtype=dtype,
device=device,
requires_grad=requires_grad)
with torch.no_grad():
quats[..., 0] = 1
return Quaternion(quats, trans)

def get_quats(self):
return self._q

def get_trans(self):
return self._t

def get_rot_mats(self):
quats = self.get_quats()
rot_mats = Quaternion.quat_to_rot(quats)
return rot_mats

@staticmethod
def quat_to_rot(normalized_quat):
global _QUAT_TO_ROT_tensor
dtype = normalized_quat.dtype
normalized_quat = normalized_quat.float()
if _QUAT_TO_ROT_tensor.device != normalized_quat.device:
_QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to(
normalized_quat.device)
rot_tensor = torch.sum(
_QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None]
* normalized_quat[..., None, :, None],
dim=(-3, -2),
)
rot_tensor = rot_tensor.type(dtype)
rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3)
return rot_tensor

@staticmethod
def normalize_quat(quats):
dtype = quats.dtype
quats = quats.float()
quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
quats = quats.type(dtype)
return quats

@staticmethod
def quat_multiply_by_vec(quat, vec):
dtype = quat.dtype
quat = quat.float()
vec = vec.float()
global _QUAT_MULTIPLY_BY_VEC_tensor
if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device:
_QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to(
quat.device)
mat = _QUAT_MULTIPLY_BY_VEC_tensor
reshaped_mat = mat.view((1, ) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None],
dim=(-3, -2),
).type(dtype)

def compose_q_update_vec(self,
q_update_vec: torch.Tensor,
normalize_quats: bool = True) -> torch.Tensor:
quats = self.get_quats()
new_quats = quats + Quaternion.quat_multiply_by_vec(
quats, q_update_vec)
if normalize_quats:
new_quats = Quaternion.normalize_quat(new_quats)
return new_quats

def compose_update_vec(
self,
update_vec: torch.Tensor,
pre_rot_mat: Rotation,
) -> Quaternion:
q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:]
new_quats = self.compose_q_update_vec(q_vec)

trans_update = pre_rot_mat.apply(t_vec)
new_trans = self._t + trans_update

return Quaternion(new_quats, new_trans)

def stop_rot_gradient(self) -> Quaternion:
return Quaternion(self._q.detach(), self._t)

+ 592  - 0   modelscope/models/science/unifold/modules/structure_module.py

@@ -0,0 +1,592 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import math
from typing import Tuple

import torch
import torch.nn as nn
from unicore.modules import LayerNorm, softmax_dropout
from unicore.utils import dict_multimap, one_hot, permute_final_dims

from modelscope.models.science.unifold.data.residue_constants import (
restype_atom14_mask, restype_atom14_rigid_group_positions,
restype_atom14_to_rigid_group, restype_rigid_group_default_frame)
from .attentions import gen_attn_mask
from .common import Linear, SimpleModuleList, residual
from .frame import Frame, Quaternion, Rotation


def ipa_point_weights_init_(weights):
with torch.no_grad():
softplus_inverse_1 = 0.541324854612918
weights.fill_(softplus_inverse_1)


def torsion_angles_to_frames(
frame: Frame,
alpha: torch.Tensor,
aatype: torch.Tensor,
default_frames: torch.Tensor,
):
default_frame = Frame.from_tensor_4x4(default_frames[aatype, ...])

bb_rot = alpha.new_zeros((*((1, ) * len(alpha.shape[:-1])), 2))
bb_rot[..., 1] = 1

alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha],
dim=-2)

all_rots = alpha.new_zeros(default_frame.get_rots().rot_mat.shape)
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha

all_rots = Frame(Rotation(mat=all_rots), None)

all_frames = default_frame.compose(all_rots)

chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
chi4_frame_to_frame = all_frames[..., 7]

chi1_frame_to_bb = all_frames[..., 4]
chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

all_frames_to_bb = Frame.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
chi3_frame_to_bb.unsqueeze(-1),
chi4_frame_to_bb.unsqueeze(-1),
],
dim=-1,
)

all_frames_to_global = frame[..., None].compose(all_frames_to_bb)

return all_frames_to_global
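
# For reference (not part of the diff): the `all_rots` block above builds, for each
# torsion angle with components alpha = (sin t, cos t), the x-axis rotation
#   R_x(t) = [[1, 0, 0], [0, cos t, -sin t], [0, sin t, cos t]],
# which is then composed onto the residue's default rigid-group frame. A
# standalone sketch of that construction:
import torch


def x_rotation_from_sin_cos(sin_t: torch.Tensor, cos_t: torch.Tensor) -> torch.Tensor:
    rot = torch.zeros(sin_t.shape + (3, 3), dtype=sin_t.dtype, device=sin_t.device)
    rot[..., 0, 0] = 1.0
    rot[..., 1, 1] = cos_t
    rot[..., 1, 2] = -sin_t
    rot[..., 2, 1] = sin_t
    rot[..., 2, 2] = cos_t
    return rot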


def frames_and_literature_positions_to_atom14_pos(
frame: Frame,
aatype: torch.Tensor,
default_frames,
group_idx,
atom_mask,
lit_positions,
):
group_mask = group_idx[aatype, ...]
group_mask = one_hot(
group_mask,
num_classes=default_frames.shape[-3],
)

t_atoms_to_global = frame[..., None, :] * group_mask
t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
lambda x: torch.sum(x, dim=-1))

atom_mask = atom_mask[aatype, ...].unsqueeze(-1)

lit_positions = lit_positions[aatype, ...]
pred_positions = t_atoms_to_global.apply(lit_positions)
pred_positions = pred_positions * atom_mask

return pred_positions


class SideChainAngleResnetIteration(nn.Module):

def __init__(self, d_hid):
super(SideChainAngleResnetIteration, self).__init__()

self.d_hid = d_hid

self.linear_1 = Linear(self.d_hid, self.d_hid, init='relu')
self.act = nn.GELU()
self.linear_2 = Linear(self.d_hid, self.d_hid, init='final')

def forward(self, s: torch.Tensor) -> torch.Tensor:

x = self.act(s)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)

return residual(s, x, self.training)


class SidechainAngleResnet(nn.Module):

def __init__(self, d_in, d_hid, num_blocks, num_angles):
super(SidechainAngleResnet, self).__init__()

self.linear_in = Linear(d_in, d_hid)
self.act = nn.GELU()
self.linear_initial = Linear(d_in, d_hid)

self.layers = SimpleModuleList()
for _ in range(num_blocks):
self.layers.append(SideChainAngleResnetIteration(d_hid=d_hid))

self.linear_out = Linear(d_hid, num_angles * 2)

def forward(self, s: torch.Tensor,
initial_s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

initial_s = self.linear_initial(self.act(initial_s))
s = self.linear_in(self.act(s))

s = s + initial_s

for layer in self.layers:
s = layer(s)

s = self.linear_out(self.act(s))

s = s.view(s.shape[:-1] + (-1, 2))

unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s.float()**2, dim=-1, keepdim=True),
min=1e-12,
))
s = s.float() / norm_denom

return unnormalized_s, s.type(unnormalized_s.dtype)


class InvariantPointAttention(nn.Module):

def __init__(
self,
d_single: int,
d_pair: int,
d_hid: int,
num_heads: int,
num_qk_points: int,
num_v_points: int,
separate_kv: bool = False,
bias: bool = True,
eps: float = 1e-8,
):
super(InvariantPointAttention, self).__init__()

self.d_hid = d_hid
self.num_heads = num_heads
self.num_qk_points = num_qk_points
self.num_v_points = num_v_points
self.eps = eps

hc = self.d_hid * self.num_heads
self.linear_q = Linear(d_single, hc, bias=bias)
self.separate_kv = separate_kv
if self.separate_kv:
self.linear_k = Linear(d_single, hc, bias=bias)
self.linear_v = Linear(d_single, hc, bias=bias)
else:
self.linear_kv = Linear(d_single, 2 * hc, bias=bias)

hpq = self.num_heads * self.num_qk_points * 3
self.linear_q_points = Linear(d_single, hpq)
hpk = self.num_heads * self.num_qk_points * 3
hpv = self.num_heads * self.num_v_points * 3
if self.separate_kv:
self.linear_k_points = Linear(d_single, hpk)
self.linear_v_points = Linear(d_single, hpv)
else:
hpkv = hpk + hpv
self.linear_kv_points = Linear(d_single, hpkv)

self.linear_b = Linear(d_pair, self.num_heads)

self.head_weights = nn.Parameter(torch.zeros((num_heads)))
ipa_point_weights_init_(self.head_weights)

concat_out_dim = self.num_heads * (
d_pair + self.d_hid + self.num_v_points * 4)
self.linear_out = Linear(concat_out_dim, d_single, init='final')

self.softplus = nn.Softplus()

def forward(
self,
s: torch.Tensor,
z: torch.Tensor,
f: Frame,
square_mask: torch.Tensor,
) -> torch.Tensor:
q = self.linear_q(s)

q = q.view(q.shape[:-1] + (self.num_heads, -1))

if self.separate_kv:
k = self.linear_k(s)
v = self.linear_v(s)
k = k.view(k.shape[:-1] + (self.num_heads, -1))
v = v.view(v.shape[:-1] + (self.num_heads, -1))
else:
kv = self.linear_kv(s)
kv = kv.view(kv.shape[:-1] + (self.num_heads, -1))
k, v = torch.split(kv, self.d_hid, dim=-1)

q_pts = self.linear_q_points(s)

def process_points(pts, no_points):
shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3)
if self.separate_kv:
# alphafold-multimer uses different layout
pts = pts.view(pts.shape[:-1]
+ (self.num_heads, no_points * 3))
pts = torch.split(pts, pts.shape[-1] // 3, dim=-1)
pts = torch.stack(pts, dim=-1).view(*shape)
pts = f[..., None].apply(pts)

pts = pts.view(pts.shape[:-2] + (self.num_heads, no_points, 3))
return pts

q_pts = process_points(q_pts, self.num_qk_points)

if self.separate_kv:
k_pts = self.linear_k_points(s)
v_pts = self.linear_v_points(s)
k_pts = process_points(k_pts, self.num_qk_points)
v_pts = process_points(v_pts, self.num_v_points)
else:
kv_pts = self.linear_kv_points(s)

kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = f[..., None].apply(kv_pts)

kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3))

k_pts, v_pts = torch.split(
kv_pts, [self.num_qk_points, self.num_v_points], dim=-2)

bias = self.linear_b(z)

attn = torch.matmul(
permute_final_dims(q, (1, 0, 2)),
permute_final_dims(k, (1, 2, 0)),
)

if self.training:
attn = attn * math.sqrt(1.0 / (3 * self.d_hid))
attn = attn + (
math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1)))
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att.float()**2
else:
attn *= math.sqrt(1.0 / (3 * self.d_hid))
attn += (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1)))
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att *= pt_att

pt_att = pt_att.sum(dim=-1)
head_weights = self.softplus(self.head_weights).view(
*((1, ) * len(pt_att.shape[:-2]) + (-1, 1)))
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.num_qk_points * 9.0 / 2)))
pt_att *= head_weights * (-0.5)

pt_att = torch.sum(pt_att, dim=-1)

pt_att = permute_final_dims(pt_att, (2, 0, 1))
attn += square_mask
attn = softmax_dropout(
attn, 0, self.training, bias=pt_att.type(attn.dtype))
del pt_att, q_pts, k_pts, bias
o = torch.matmul(attn, v.transpose(-2, -3)).transpose(-2, -3)
o = o.contiguous().view(*o.shape[:-2], -1)
del q, k, v
o_pts = torch.sum(
(attn[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]),
dim=-2,
)

o_pts = permute_final_dims(o_pts, (2, 0, 3, 1))
o_pts = f[..., None, None].invert_apply(o_pts)
if self.training:
o_pts_norm = torch.sqrt(
torch.sum(o_pts.float()**2, dim=-1) + self.eps).type(
o_pts.dtype)
else:
o_pts_norm = torch.sqrt(torch.sum(o_pts**2, dim=-1)
+ self.eps).type(o_pts.dtype)

o_pts_norm = o_pts_norm.view(*o_pts_norm.shape[:-2], -1)

o_pts = o_pts.view(*o_pts.shape[:-3], -1, 3)

o_pair = torch.matmul(attn.transpose(-2, -3), z)

o_pair = o_pair.view(*o_pair.shape[:-2], -1)

s = self.linear_out(
torch.cat((o, *torch.unbind(o_pts, dim=-1), o_pts_norm, o_pair),
dim=-1))

return s
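
# A quick sanity check (not part of the module): the point-attention scaling used
# above equals AlphaFold's per-term weight sqrt(1/3) times the point weight
# sqrt(2 / (9 * num_qk_points)), since 1 / (3 * n * 9 / 2) == (1/3) * (2 / (9 * n)).
import math

n_qk_points = 4  # assumed value, for illustration only
assert math.isclose(
    math.sqrt(1.0 / (3 * (n_qk_points * 9.0 / 2))),
    math.sqrt(1.0 / 3) * math.sqrt(2.0 / (9 * n_qk_points)),
)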


class BackboneUpdate(nn.Module):

def __init__(self, d_single):
super(BackboneUpdate, self).__init__()
self.linear = Linear(d_single, 6, init='final')

def forward(self, s: torch.Tensor):
return self.linear(s)


class StructureModuleTransitionLayer(nn.Module):

def __init__(self, c):
super(StructureModuleTransitionLayer, self).__init__()

self.linear_1 = Linear(c, c, init='relu')
self.linear_2 = Linear(c, c, init='relu')
self.act = nn.GELU()
self.linear_3 = Linear(c, c, init='final')

def forward(self, s):
s_old = s
s = self.linear_1(s)
s = self.act(s)
s = self.linear_2(s)
s = self.act(s)
s = self.linear_3(s)

s = residual(s_old, s, self.training)

return s


class StructureModuleTransition(nn.Module):

def __init__(self, c, num_layers, dropout_rate):
super(StructureModuleTransition, self).__init__()

self.num_layers = num_layers
self.dropout_rate = dropout_rate

self.layers = SimpleModuleList()
for _ in range(self.num_layers):
self.layers.append(StructureModuleTransitionLayer(c))

self.dropout = nn.Dropout(self.dropout_rate)
self.layer_norm = LayerNorm(c)

def forward(self, s):
for layer in self.layers:
s = layer(s)

s = self.dropout(s)
s = self.layer_norm(s)

return s


class StructureModule(nn.Module):

def __init__(
self,
d_single,
d_pair,
d_ipa,
d_angle,
num_heads_ipa,
num_qk_points,
num_v_points,
dropout_rate,
num_blocks,
no_transition_layers,
num_resnet_blocks,
num_angles,
trans_scale_factor,
separate_kv,
ipa_bias,
epsilon,
inf,
**kwargs,
):
super(StructureModule, self).__init__()

self.num_blocks = num_blocks
self.trans_scale_factor = trans_scale_factor
self.default_frames = None
self.group_idx = None
self.atom_mask = None
self.lit_positions = None
self.inf = inf

self.layer_norm_s = LayerNorm(d_single)
self.layer_norm_z = LayerNorm(d_pair)

self.linear_in = Linear(d_single, d_single)

self.ipa = InvariantPointAttention(
d_single,
d_pair,
d_ipa,
num_heads_ipa,
num_qk_points,
num_v_points,
separate_kv=separate_kv,
bias=ipa_bias,
eps=epsilon,
)

self.ipa_dropout = nn.Dropout(dropout_rate)
self.layer_norm_ipa = LayerNorm(d_single)

self.transition = StructureModuleTransition(
d_single,
no_transition_layers,
dropout_rate,
)

self.bb_update = BackboneUpdate(d_single)

self.angle_resnet = SidechainAngleResnet(
d_single,
d_angle,
num_resnet_blocks,
num_angles,
)

def forward(
self,
s,
z,
aatype,
mask=None,
):
if mask is None:
mask = s.new_ones(s.shape[:-1])

# generate square mask
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3)
s = self.layer_norm_s(s)
z = self.layer_norm_z(z)
initial_s = s
s = self.linear_in(s)

quat_encoder = Quaternion.identity(
s.shape[:-1],
s.dtype,
s.device,
requires_grad=False,
)
backb_to_global = Frame(
Rotation(mat=quat_encoder.get_rot_mats(), ),
quat_encoder.get_trans(),
)
outputs = []
for i in range(self.num_blocks):
s = residual(s, self.ipa(s, z, backb_to_global, square_mask),
self.training)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)

# update quaternion encoder
# use backb_to_global to avoid quat-to-rot conversion
quat_encoder = quat_encoder.compose_update_vec(
self.bb_update(s), pre_rot_mat=backb_to_global.get_rots())

# initial_s is always used to update the backbone
unnormalized_angles, angles = self.angle_resnet(s, initial_s)

# convert quaternion to rotation matrix
backb_to_global = Frame(
Rotation(mat=quat_encoder.get_rot_mats(), ),
quat_encoder.get_trans(),
)
if i == self.num_blocks - 1:
all_frames_to_global = self.torsion_angles_to_frames(
backb_to_global.scale_translation(self.trans_scale_factor),
angles,
aatype,
)

pred_positions = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global,
aatype,
)

preds = {
'frames':
backb_to_global.scale_translation(
self.trans_scale_factor).to_tensor_4x4(),
'unnormalized_angles':
unnormalized_angles,
'angles':
angles,
}

outputs.append(preds)
if i < (self.num_blocks - 1):
# stop gradient in iteration
quat_encoder = quat_encoder.stop_rot_gradient()
backb_to_global = backb_to_global.stop_rot_gradient()

outputs = dict_multimap(torch.stack, outputs)
outputs['sidechain_frames'] = all_frames_to_global.to_tensor_4x4()
outputs['positions'] = pred_positions
outputs['single'] = s

return outputs

def _init_residue_constants(self, float_dtype, device):
if self.default_frames is None:
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.group_idx is None:
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if self.atom_mask is None:
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.lit_positions is None:
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
)

def torsion_angles_to_frames(self, frame, alpha, aatype):
self._init_residue_constants(alpha.dtype, alpha.device)
return torsion_angles_to_frames(frame, alpha, aatype,
self.default_frames)

def frames_and_literature_positions_to_atom14_pos(self, frame, aatype):
self._init_residue_constants(frame.get_rots().dtype,
frame.get_rots().device)
return frames_and_literature_positions_to_atom14_pos(
frame,
aatype,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
)
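
# A minimal usage sketch (not from the diff). The hyperparameters below are
# illustrative AlphaFold-like values, not necessarily the settings shipped in the
# ModelScope Uni-Fold configs.
import torch

sm = StructureModule(
    d_single=384, d_pair=128, d_ipa=16, d_angle=128,
    num_heads_ipa=12, num_qk_points=4, num_v_points=8,
    dropout_rate=0.1, num_blocks=8, no_transition_layers=1,
    num_resnet_blocks=2, num_angles=7, trans_scale_factor=10.0,
    separate_kv=False, ipa_bias=True, epsilon=1e-8, inf=1e5,
)
s = torch.randn(1, 64, 384)      # single representation: [batch, num_res, d_single]
z = torch.randn(1, 64, 64, 128)  # pair representation: [batch, num_res, num_res, d_pair]
aatype = torch.zeros(1, 64, dtype=torch.long)
out = sm(s, z, aatype)
# out['frames']: [num_blocks, batch, num_res, 4, 4] backbone frames per iteration
# out['positions']: [batch, num_res, 14, 3] atom14 coordinates from the final block
# out['angles']: [num_blocks, batch, num_res, num_angles, 2] sin/cos torsion angles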

+ 330
- 0
modelscope/models/science/unifold/modules/template.py

@@ -0,0 +1,330 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import math
from functools import partial
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
from unicore.modules import LayerNorm
from unicore.utils import (checkpoint_sequential, permute_final_dims,
tensor_tree_map)

from .attentions import (Attention, TriangleAttentionEnding,
TriangleAttentionStarting, gen_attn_mask)
from .common import (Linear, SimpleModuleList, Transition,
bias_dropout_residual, chunk_layer, residual,
tri_mul_residual)
from .featurization import build_template_pair_feat_v2
from .triangle_multiplication import (TriangleMultiplicationIncoming,
TriangleMultiplicationOutgoing)


class TemplatePointwiseAttention(nn.Module):

def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs):
super(TemplatePointwiseAttention, self).__init__()

self.inf = inf

self.mha = Attention(
d_pair,
d_template,
d_template,
d_hid,
num_heads,
gating=False,
)

def _chunk(
self,
z: torch.Tensor,
t: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
mha_inputs = {
'q': z,
'k': t,
'v': t,
'mask': mask,
}
return chunk_layer(
self.mha,
mha_inputs,
chunk_size=chunk_size,
num_batch_dims=len(z.shape[:-2]),
)

def forward(
self,
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
if template_mask is None:
template_mask = t.new_ones(t.shape[:-3])

mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None,
None, :]
z = z.unsqueeze(-2)

t = permute_final_dims(t, (1, 2, 0, 3))

if chunk_size is not None:
z = self._chunk(z, t, mask, chunk_size)
else:
z = self.mha(z, t, t, mask=mask)

z = z.squeeze(-2)

return z


class TemplateProjection(nn.Module):

def __init__(self, d_template, d_pair, **kwargs):
super(TemplateProjection, self).__init__()

self.d_pair = d_pair
self.act = nn.ReLU()
self.output_linear = Linear(d_template, d_pair, init='relu')

def forward(self, t, z) -> torch.Tensor:
if t is None:
            # handle the non-template case: build an all-zero template
            # activation (torch.Size is immutable, so copy it to a list first)
            shape = list(z.shape)
            shape[-1] = self.d_pair
            t = torch.zeros(shape, dtype=z.dtype, device=z.device)
t = self.act(t)
z_t = self.output_linear(t)
return z_t


class TemplatePairStackBlock(nn.Module):

def __init__(
self,
d_template: int,
d_hid_tri_att: int,
d_hid_tri_mul: int,
num_heads: int,
pair_transition_n: int,
dropout_rate: float,
tri_attn_first: bool,
inf: float,
**kwargs,
):
super(TemplatePairStackBlock, self).__init__()

self.tri_att_start = TriangleAttentionStarting(
d_template,
d_hid_tri_att,
num_heads,
)
self.tri_att_end = TriangleAttentionEnding(
d_template,
d_hid_tri_att,
num_heads,
)

self.tri_mul_out = TriangleMultiplicationOutgoing(
d_template,
d_hid_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
d_template,
d_hid_tri_mul,
)

self.pair_transition = Transition(
d_template,
pair_transition_n,
)
self.tri_attn_first = tri_attn_first
self.dropout = dropout_rate
self.row_dropout_share_dim = -3
self.col_dropout_share_dim = -2

def forward(
self,
s: torch.Tensor,
mask: torch.Tensor,
tri_start_attn_mask: torch.Tensor,
tri_end_attn_mask: torch.Tensor,
chunk_size: Optional[int] = None,
block_size: Optional[int] = None,
):
if self.tri_attn_first:
s = bias_dropout_residual(
self.tri_att_start,
s,
self.tri_att_start(
s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
)

s = bias_dropout_residual(
self.tri_att_end,
s,
self.tri_att_end(
s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
self.col_dropout_share_dim,
self.dropout,
self.training,
)
s = tri_mul_residual(
self.tri_mul_out,
s,
self.tri_mul_out(s, mask=mask, block_size=block_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
block_size=block_size,
)

s = tri_mul_residual(
self.tri_mul_in,
s,
self.tri_mul_in(s, mask=mask, block_size=block_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
block_size=block_size,
)
else:
s = tri_mul_residual(
self.tri_mul_out,
s,
self.tri_mul_out(s, mask=mask, block_size=block_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
block_size=block_size,
)

s = tri_mul_residual(
self.tri_mul_in,
s,
self.tri_mul_in(s, mask=mask, block_size=block_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
block_size=block_size,
)

s = bias_dropout_residual(
self.tri_att_start,
s,
self.tri_att_start(
s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size),
self.row_dropout_share_dim,
self.dropout,
self.training,
)

s = bias_dropout_residual(
self.tri_att_end,
s,
self.tri_att_end(
s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size),
self.col_dropout_share_dim,
self.dropout,
self.training,
)
s = residual(s, self.pair_transition(
s,
chunk_size=chunk_size,
), self.training)
return s


class TemplatePairStack(nn.Module):

def __init__(
self,
d_template,
d_hid_tri_att,
d_hid_tri_mul,
num_blocks,
num_heads,
pair_transition_n,
dropout_rate,
tri_attn_first,
inf=1e9,
**kwargs,
):
super(TemplatePairStack, self).__init__()

self.blocks = SimpleModuleList()
for _ in range(num_blocks):
self.blocks.append(
TemplatePairStackBlock(
d_template=d_template,
d_hid_tri_att=d_hid_tri_att,
d_hid_tri_mul=d_hid_tri_mul,
num_heads=num_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
inf=inf,
tri_attn_first=tri_attn_first,
))

self.layer_norm = LayerNorm(d_template)

def forward(
self,
single_templates: Tuple[torch.Tensor],
        mask: torch.Tensor,
tri_start_attn_mask: torch.Tensor,
tri_end_attn_mask: torch.Tensor,
templ_dim: int,
chunk_size: int,
block_size: int,
return_mean: bool,
):

def one_template(i):
(s, ) = checkpoint_sequential(
functions=[
partial(
b,
mask=mask,
tri_start_attn_mask=tri_start_attn_mask,
tri_end_attn_mask=tri_end_attn_mask,
chunk_size=chunk_size,
block_size=block_size,
) for b in self.blocks
],
input=(single_templates[i], ),
)
return s

n_templ = len(single_templates)
if n_templ > 0:
new_single_templates = [one_template(0)]
if return_mean:
t = self.layer_norm(new_single_templates[0])
for i in range(1, n_templ):
s = one_template(i)
if return_mean:
t = residual(t, self.layer_norm(s), self.training)
else:
new_single_templates.append(s)

if return_mean:
if n_templ > 0:
t /= n_templ
else:
t = None
else:
t = torch.cat(
[s.unsqueeze(templ_dim) for s in new_single_templates],
dim=templ_dim)
t = self.layer_norm(t)

return t

+ 158
- 0
modelscope/models/science/unifold/modules/triangle_multiplication.py

@@ -0,0 +1,158 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

from functools import partialmethod
from typing import List, Optional

import torch
import torch.nn as nn
from unicore.modules import LayerNorm
from unicore.utils import permute_final_dims

from .common import Linear


class TriangleMultiplication(nn.Module):

def __init__(self, d_pair, d_hid, outgoing=True):
super(TriangleMultiplication, self).__init__()
self.outgoing = outgoing

self.linear_ab_p = Linear(d_pair, d_hid * 2)
self.linear_ab_g = Linear(d_pair, d_hid * 2, init='gating')

self.linear_g = Linear(d_pair, d_pair, init='gating')
self.linear_z = Linear(d_hid, d_pair, init='final')

self.layer_norm_in = LayerNorm(d_pair)
self.layer_norm_out = LayerNorm(d_hid)

self._alphafold_original_mode = False

def _chunk_2d(
self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
block_size: int = None,
) -> torch.Tensor:

# avoid too small chunk size
# block_size = max(block_size, 256)
new_z = z.new_zeros(z.shape)
dim1 = z.shape[-3]

def _slice_linear(z, linear: Linear, a=True):
d_hid = linear.bias.shape[0] // 2
index = 0 if a else d_hid
p = (
nn.functional.linear(z, linear.weight[index:index + d_hid])
+ linear.bias[index:index + d_hid])
return p

def _chunk_projection(z, mask, a=True):
p = _slice_linear(z, self.linear_ab_p, a) * mask
p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a))
return p

num_chunk = (dim1 + block_size - 1) // block_size
for i in range(num_chunk):
chunk_start = i * block_size
chunk_end = min(chunk_start + block_size, dim1)
if self.outgoing:
a_chunk = _chunk_projection(
z[..., chunk_start:chunk_end, :, :],
mask[..., chunk_start:chunk_end, :, :],
a=True,
)
a_chunk = permute_final_dims(a_chunk, (2, 0, 1))
else:
a_chunk = _chunk_projection(
z[..., :, chunk_start:chunk_end, :],
mask[..., :, chunk_start:chunk_end, :],
a=True,
)
a_chunk = a_chunk.transpose(-1, -3)

for j in range(num_chunk):
j_chunk_start = j * block_size
j_chunk_end = min(j_chunk_start + block_size, dim1)
if self.outgoing:
b_chunk = _chunk_projection(
z[..., j_chunk_start:j_chunk_end, :, :],
mask[..., j_chunk_start:j_chunk_end, :, :],
a=False,
)
b_chunk = b_chunk.transpose(-1, -3)
else:
b_chunk = _chunk_projection(
z[..., :, j_chunk_start:j_chunk_end, :],
mask[..., :, j_chunk_start:j_chunk_end, :],
a=False,
)
b_chunk = permute_final_dims(b_chunk, (2, 0, 1))
x_chunk = torch.matmul(a_chunk, b_chunk)
del b_chunk
x_chunk = permute_final_dims(x_chunk, (1, 2, 0))
x_chunk = self.layer_norm_out(x_chunk)
x_chunk = self.linear_z(x_chunk)
x_chunk *= torch.sigmoid(
self.linear_g(z[..., chunk_start:chunk_end,
j_chunk_start:j_chunk_end, :]))
new_z[..., chunk_start:chunk_end,
j_chunk_start:j_chunk_end, :] = x_chunk
del x_chunk
del a_chunk
return new_z

def forward(
self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
block_size=None,
) -> torch.Tensor:

mask = mask.unsqueeze(-1)
if not self._alphafold_original_mode:
            # scale the mask by 1/sqrt(dim) for numerical stability
mask = mask * (mask.shape[-2]**-0.5)

z = self.layer_norm_in(z)
if not self.training and block_size is not None:
return self._chunk_2d(z, mask, block_size=block_size)

g = nn.functional.linear(z, self.linear_g.weight)
if self.training:
ab = self.linear_ab_p(z) * mask * torch.sigmoid(
self.linear_ab_g(z))
else:
ab = self.linear_ab_p(z)
ab *= mask
ab *= torch.sigmoid(self.linear_ab_g(z))
a, b = torch.chunk(ab, 2, dim=-1)
del z, ab

if self.outgoing:
a = permute_final_dims(a, (2, 0, 1))
b = b.transpose(-1, -3)
else:
b = permute_final_dims(b, (2, 0, 1))
a = a.transpose(-1, -3)
x = torch.matmul(a, b)
del a, b

x = permute_final_dims(x, (1, 2, 0))

x = self.layer_norm_out(x)
x = nn.functional.linear(x, self.linear_z.weight)
return x, g

def get_output_bias(self):
return self.linear_z.bias, self.linear_g.bias


class TriangleMultiplicationOutgoing(TriangleMultiplication):
__init__ = partialmethod(TriangleMultiplication.__init__, outgoing=True)


class TriangleMultiplicationIncoming(TriangleMultiplication):
__init__ = partialmethod(TriangleMultiplication.__init__, outgoing=False)
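
# A reference sketch (not part of the diff) of what the fused matmul in
# TriangleMultiplication.forward computes per hidden channel, ignoring masking,
# gating, chunking and the output projection:
#   outgoing: z'_ij = sum_k a_ik * b_jk
#   incoming: z'_ij = sum_k a_ki * b_kj
import torch


def triangle_update_reference(a: torch.Tensor, b: torch.Tensor, outgoing: bool) -> torch.Tensor:
    # a, b: [..., N, N, C] projected (and gated) pair activations
    if outgoing:
        return torch.einsum('...ikc,...jkc->...ijc', a, b)
    return torch.einsum('...kic,...kjc->...ijc', a, b)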

+ 1
- 0
modelscope/models/science/unifold/msa/__init__.py

@@ -0,0 +1 @@
""" Scripts for MSA & template searching. """

+ 483
- 0
modelscope/models/science/unifold/msa/mmcif.py

@@ -0,0 +1,483 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Parses the mmCIF file format."""
import collections
import dataclasses
import functools
import io
from typing import Any, Mapping, Optional, Sequence, Tuple

from absl import logging
from Bio import PDB
from Bio.Data import SCOPData
from Bio.PDB.MMCIF2Dict import MMCIF2Dict

# Type aliases:
ChainId = str
PdbHeader = Mapping[str, Any]
PdbStructure = PDB.Structure.Structure
SeqRes = str
MmCIFDict = Mapping[str, Sequence[str]]


@dataclasses.dataclass(frozen=True)
class Monomer:
id: str
num: int


# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@dataclasses.dataclass(frozen=True)
class AtomSite:
residue_name: str
author_chain_id: str
mmcif_chain_id: str
author_seq_num: str
mmcif_seq_num: int
insertion_code: str
hetatm_atom: str
model_num: int


# Used to map SEQRES index to a residue in the structure.
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
chain_id: str
residue_number: int
insertion_code: str


@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
position: Optional[ResiduePosition]
name: str
is_missing: bool
hetflag: str


@dataclasses.dataclass(frozen=True)
class MmcifObject:
"""Representation of a parsed mmCIF file.

Contains:
file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
files being processed.
header: Biopython header.
structure: Biopython structure.
chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g.
{'A': 'ABCDEFG'}
seqres_to_structure: Dict; for each chain_id contains a mapping between
SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, 1: ResidueAtPosition, ...}}
raw_string: The raw string used to construct the MmcifObject.
"""

file_id: str
header: PdbHeader
structure: PdbStructure
chain_to_seqres: Mapping[ChainId, SeqRes]
seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
raw_string: Any
mmcif_to_author_chain_id: Mapping[ChainId, ChainId]
valid_chains: Mapping[ChainId, str]


@dataclasses.dataclass(frozen=True)
class ParsingResult:
"""Returned by the parse function.

Contains:
mmcif_object: A MmcifObject, may be None if no chain could be successfully
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""

mmcif_object: Optional[MmcifObject]
errors: Mapping[Tuple[str, str], Any]


class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""


def mmcif_loop_to_list(prefix: str,
parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.

Reference for loop_ in mmCIF:
http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.

Returns:
Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
"""
cols = []
data = []
for key, value in parsed_info.items():
if key.startswith(prefix):
cols.append(key)
data.append(value)

assert all([
len(xs) == len(data[0]) for xs in data
]), ('mmCIF error: Not all loops are the same length: %s' % cols)

return [dict(zip(cols, xs)) for xs in zip(*data)]
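
# Illustrative usage (made-up values, not part of the module): each column of a
# loop_ block in parsed_info is zipped into one dict per row.
example_parsed_info = {
    '_entity_poly_seq.entity_id': ['1', '1'],
    '_entity_poly_seq.num': ['1', '2'],
    '_entity_poly_seq.mon_id': ['MET', 'ALA'],
}
assert mmcif_loop_to_list('_entity_poly_seq.', example_parsed_info) == [
    {'_entity_poly_seq.entity_id': '1', '_entity_poly_seq.num': '1',
     '_entity_poly_seq.mon_id': 'MET'},
    {'_entity_poly_seq.entity_id': '1', '_entity_poly_seq.num': '2',
     '_entity_poly_seq.mon_id': 'ALA'},
]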


def mmcif_loop_to_dict(
prefix: str,
index: str,
parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.

Args:
prefix: Prefix shared by each of the data items in the loop.
e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num,
_entity_poly_seq.mon_id. Should include the trailing period.
index: Which item of loop data should serve as the key.
parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython
parser.

Returns:
Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop,
indexed by the index column.
"""
entries = mmcif_loop_to_list(prefix, parsed_info)
return {entry[index]: entry for entry in entries}


@functools.lru_cache(16, typed=False)
def fast_parse(*,
file_id: str,
mmcif_string: str,
catch_all_errors: bool = True) -> ParsingResult:
"""Entry point, parses an mmcif_string.

Args:
file_id: A string identifier for this file. Should be unique within the
collection of files being processed.
mmcif_string: Contents of an mmCIF file.
catch_all_errors: If True, all exceptions are caught and error messages are
returned as part of the ParsingResult. If False exceptions will be allowed
to propagate.

Returns:
A ParsingResult.
"""
errors = {}
try:
        # Build only the mmCIF key/value dictionary; skip constructing the full
        # Biopython structure, which is the expensive step done by parse() below.
        handle = io.StringIO(mmcif_string)
        parsed_info = MMCIF2Dict(handle)

# Ensure all values are lists, even if singletons.
for key, value in parsed_info.items():
if not isinstance(value, list):
parsed_info[key] = [value]

header = _get_header(parsed_info)

# Determine the protein chains, and their start numbers according to the
# internal mmCIF numbering scheme (likely but not guaranteed to be 1).
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ''): 'No protein chains found in this file.'})

mmcif_to_author_chain_id = {}
# seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != '1':
# We only process the first model at the moment.
continue
mmcif_to_author_chain_id[
atom.mmcif_chain_id] = atom.author_chain_id

mmcif_object = MmcifObject(
file_id=file_id,
header=header,
structure=None,
chain_to_seqres=None,
seqres_to_structure=None,
raw_string=parsed_info,
mmcif_to_author_chain_id=mmcif_to_author_chain_id,
valid_chains=valid_chains,
)

return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, '')] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)


@functools.lru_cache(16, typed=False)
def parse(*,
file_id: str,
mmcif_string: str,
catch_all_errors: bool = True) -> ParsingResult:
"""Entry point, parses an mmcif_string.

Args:
file_id: A string identifier for this file. Should be unique within the
collection of files being processed.
mmcif_string: Contents of an mmCIF file.
catch_all_errors: If True, all exceptions are caught and error messages are
returned as part of the ParsingResult. If False exceptions will be allowed
to propagate.

Returns:
A ParsingResult.
"""
errors = {}
try:
parser = PDB.MMCIFParser(QUIET=True)
handle = io.StringIO(mmcif_string)
full_structure = parser.get_structure('', handle)
first_model_structure = _get_first_model(full_structure)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
parsed_info = parser._mmcif_dict # pylint:disable=protected-access

# Ensure all values are lists, even if singletons.
for key, value in parsed_info.items():
if not isinstance(value, list):
parsed_info[key] = [value]

header = _get_header(parsed_info)

# Determine the protein chains, and their start numbers according to the
# internal mmCIF numbering scheme (likely but not guaranteed to be 1).
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ''): 'No protein chains found in this file.'})
seq_start_num = {
chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()
}

# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
# the authors / Biopython).
# -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition).
mmcif_to_author_chain_id = {}
seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != '1':
# We only process the first model at the moment.
continue

mmcif_to_author_chain_id[
atom.mmcif_chain_id] = atom.author_chain_id

if atom.mmcif_chain_id in valid_chains:
hetflag = ' '
if atom.hetatm_atom == 'HETATM':
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if atom.residue_name in ('HOH', 'WAT'):
hetflag = 'W'
else:
hetflag = 'H_' + atom.residue_name
insertion_code = atom.insertion_code
if not _is_set(atom.insertion_code):
insertion_code = ' '
position = ResiduePosition(
chain_id=atom.author_chain_id,
residue_number=int(atom.author_seq_num),
insertion_code=insertion_code,
)
seq_idx = int(
atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
current = seq_to_structure_mappings.get(
atom.author_chain_id, {})
current[seq_idx] = ResidueAtPosition(
position=position,
name=atom.residue_name,
is_missing=False,
hetflag=hetflag,
)
seq_to_structure_mappings[atom.author_chain_id] = current

# Add missing residue information to seq_to_structure_mappings.
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
current_mapping = seq_to_structure_mappings[author_chain]
for idx, monomer in enumerate(seq_info):
if idx not in current_mapping:
current_mapping[idx] = ResidueAtPosition(
position=None,
name=monomer.id,
is_missing=True,
hetflag=' ')

author_chain_to_sequence = {}
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
seq = []
for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
seq.append(code if len(code) == 1 else 'X')
seq = ''.join(seq)
author_chain_to_sequence[author_chain] = seq

mmcif_object = MmcifObject(
file_id=file_id,
header=header,
structure=first_model_structure,
chain_to_seqres=author_chain_to_sequence,
seqres_to_structure=seq_to_structure_mappings,
raw_string=parsed_info,
mmcif_to_author_chain_id=mmcif_to_author_chain_id,
valid_chains=valid_chains,
)

return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, '')] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)


def _get_first_model(structure: PdbStructure) -> PdbStructure:
"""Returns the first model in a Biopython structure."""
return next(structure.get_models())


_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21


def get_release_date(parsed_info: MmCIFDict) -> str:
"""Returns the oldest revision date."""
revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
return min(revision_dates)


def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
header = {}

experiments = mmcif_loop_to_list('_exptl.', parsed_info)
header['structure_method'] = ','.join(
[experiment['_exptl.method'].lower() for experiment in experiments])

# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if '_pdbx_audit_revision_history.revision_date' in parsed_info:
header['release_date'] = get_release_date(parsed_info)
else:
logging.warning('Could not determine release_date: %s',
parsed_info['_entry.id'])

header['resolution'] = 0.00
for res_key in (
'_refine.ls_d_res_high',
'_em_3d_reconstruction.resolution',
'_reflns.d_resolution_high',
):
if res_key in parsed_info:
try:
raw_resolution = parsed_info[res_key][0]
header['resolution'] = float(raw_resolution)
except ValueError:
logging.debug('Invalid resolution format: %s',
parsed_info[res_key])

return header


def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
"""Returns list of atom sites; contains data not present in the structure."""
return [
AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension
parsed_info['_atom_site.label_comp_id'],
parsed_info['_atom_site.auth_asym_id'],
parsed_info['_atom_site.label_asym_id'],
parsed_info['_atom_site.auth_seq_id'],
parsed_info['_atom_site.label_seq_id'],
parsed_info['_atom_site.pdbx_PDB_ins_code'],
parsed_info['_atom_site.group_PDB'],
parsed_info['_atom_site.pdbx_PDB_model_num'],
)
]


def _get_protein_chains(
*, parsed_info: Mapping[str,
Any]) -> Mapping[ChainId, Sequence[Monomer]]:
"""Extracts polymer information for protein chains only.

Args:
parsed_info: _mmcif_dict produced by the Biopython parser.

Returns:
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)

polymers = collections.defaultdict(list)
for entity_poly_seq in entity_poly_seqs:
polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
Monomer(
id=entity_poly_seq['_entity_poly_seq.mon_id'],
num=int(entity_poly_seq['_entity_poly_seq.num']),
))

# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id',
parsed_info)

# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)

entity_to_mmcif_chains = collections.defaultdict(list)
for struct_asym in struct_asyms:
chain_id = struct_asym['_struct_asym.id']
entity_id = struct_asym['_struct_asym.entity_id']
entity_to_mmcif_chains[entity_id].append(chain_id)

# Identify and return the valid protein chains.
valid_chains = {}
for entity_id, seq_info in polymers.items():
chain_ids = entity_to_mmcif_chains[entity_id]

# Reject polymers without any peptide-like components, such as DNA/RNA.
if any([
'peptide' in chem_comps[monomer.id]['_chem_comp.type']
for monomer in seq_info
]):
for chain_id in chain_ids:
valid_chains[chain_id] = seq_info
return valid_chains


def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?')

+ 88
- 0
modelscope/models/science/unifold/msa/msa_identifiers.py

@@ -0,0 +1,88 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for extracting identifiers from MSA sequence descriptions."""

import dataclasses
import re
from typing import Optional

# Sequences coming from UniProtKB database come in the
# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
_UNIPROT_PATTERN = re.compile(
r"""
^
# UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
(?:tr|sp)
\|
# A primary accession number of the UniProtKB entry.
(?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
# Occasionally there is a _0 or _1 isoform suffix, which we ignore.
(?:_\d)?
\|
# TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
# protein ID code.
(?:[A-Za-z0-9]+)
_
# A mnemonic species identification code.
(?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
# Small BFD uses a final value after an underscore, which we ignore.
(?:_\d+)?
$
""",
re.VERBOSE,
)


@dataclasses.dataclass(frozen=True)
class Identifiers:
species_id: str = ''


def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets accession id and species from an msa sequence identifier.

    The sequence identifier has the format specified by _UNIPROT_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`

Args:
msa_sequence_identifier: a sequence identifier.

Returns:
An `Identifiers` instance with a species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
return Identifiers(species_id=matches.group('SpeciesIdentifier'))
return Identifiers()


def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description. Returns None if no match."""
split_description = description.split()
if split_description:
return split_description[0].partition('/')[0]
else:
return None


def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if sequence_identifier is None:
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)
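
# Illustrative usage (made-up descriptions, not part of the module): only the
# UniProt species mnemonic is retained; descriptions that do not match the
# pattern yield an empty species_id.
assert get_identifiers(
    'tr|A0A146SKV9|A0A146SKV9_FUNHE/10-50 hypothetical protein').species_id == 'FUNHE'
assert get_identifiers('some non-UniProt description').species_id == ''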

+ 627
- 0
modelscope/models/science/unifold/msa/parsers.py

@@ -0,0 +1,627 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for parsing various file formats."""
import collections
import dataclasses
import itertools
import re
import string
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple

DeletionMatrix = Sequence[Sequence[int]]


@dataclasses.dataclass(frozen=True)
class Msa:
"""Class representing a parsed MSA file."""

sequences: Sequence[str]
deletion_matrix: DeletionMatrix
descriptions: Sequence[str]

def __post_init__(self):
if not (len(self.sequences) == len(self.deletion_matrix) == len(
self.descriptions)):
raise ValueError(
'All fields for an MSA must have the same length. '
f'Got {len(self.sequences)} sequences, '
f'{len(self.deletion_matrix)} rows in the deletion matrix and '
f'{len(self.descriptions)} descriptions.')

def __len__(self):
return len(self.sequences)

def truncate(self, max_seqs: int):
return Msa(
sequences=self.sequences[:max_seqs],
deletion_matrix=self.deletion_matrix[:max_seqs],
descriptions=self.descriptions[:max_seqs],
)


@dataclasses.dataclass(frozen=True)
class TemplateHit:
"""Class representing a template hit."""

index: int
name: str
aligned_cols: int
sum_probs: Optional[float]
query: str
hit_sequence: str
indices_query: List[int]
indices_hit: List[int]


def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
"""Parses FASTA string and returns list of strings with amino-acid sequences.

Arguments:
fasta_string: The string contents of a FASTA file.

Returns:
A tuple of two lists:
* A list of sequences.
* A list of sequence descriptions taken from the comment lines. In the
same order as the sequences.
"""
sequences = []
descriptions = []
index = -1
for line in fasta_string.splitlines():
line = line.strip()
if line.startswith('>'):
index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append('')
continue
elif not line:
continue # Skip blank lines.
sequences[index] += line

return sequences, descriptions
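
# Illustrative usage with a made-up two-record FASTA string (not part of the
# module): sequences are concatenated across wrapped lines, and descriptions
# keep everything after the '>'.
seqs, descs = parse_fasta('>seq1 first example\nMKT\nAVL\n>seq2\nGGG\n')
assert seqs == ['MKTAVL', 'GGG']
assert descs == ['seq1 first example', 'seq2']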


def parse_stockholm(stockholm_string: str) -> Msa:
"""Parses sequences and deletion matrix from stockholm format alignment.

Args:
stockholm_string: The string contents of a stockholm file. The first
sequence in the file should be the query sequence.

Returns:
      An Msa object with the following fields:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* The names of the targets matched, including the jackhmmer subsequence
suffix.
"""
name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines():
line = line.strip()
if not line or line.startswith(('#', '//')):
continue
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence[name] = ''
name_to_sequence[name] += sequence

msa = []
deletion_matrix = []

query = ''
keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0:
# Gather the columns with gaps from the query
query = sequence
keep_columns = [i for i, res in enumerate(query) if res != '-']

# Remove the columns with gaps in the query from all sequences.
aligned_sequence = ''.join([sequence[c] for c in keep_columns])

msa.append(aligned_sequence)

# Count the number of deletions w.r.t. query.
deletion_vec = []
deletion_count = 0
for seq_res, query_res in zip(sequence, query):
if seq_res != '-' or query_res != '-':
if query_res == '-':
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)

return Msa(
sequences=msa,
deletion_matrix=deletion_matrix,
descriptions=list(name_to_sequence.keys()),
)
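
# Illustrative usage on a toy alignment (not part of the module): columns that
# are gaps in the query row are dropped from every sequence, and residues
# aligned to those columns are counted in the deletion matrix.
toy_sto = (
    '# STOCKHOLM 1.0\n'
    'query MKV-A\n'
    'hit1  MRVCA\n'
    '//\n'
)
toy_sto_msa = parse_stockholm(toy_sto)
assert toy_sto_msa.sequences == ['MKVA', 'MRVA']
assert toy_sto_msa.deletion_matrix == [[0, 0, 0, 0], [0, 0, 0, 1]]
assert toy_sto_msa.descriptions == ['query', 'hit1']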


def parse_a3m(a3m_string: str) -> Msa:
"""Parses sequences and deletion matrix from a3m format alignment.

Args:
a3m_string: The string contents of a a3m file. The first sequence in the
file should be the query sequence.

Returns:
      An Msa object with the following fields:
* A list of sequences that have been aligned to the query. These
might contain duplicates.
* The deletion matrix for the alignment as a list of lists. The element
at `deletion_matrix[i][j]` is the number of residues deleted from
the aligned sequence i at residue position j.
* A list of descriptions, one per sequence, from the a3m file.
"""
sequences, descriptions = parse_fasta(a3m_string)
deletion_matrix = []
for msa_sequence in sequences:
deletion_vec = []
deletion_count = 0
for j in msa_sequence:
if j.islower():
deletion_count += 1
else:
deletion_vec.append(deletion_count)
deletion_count = 0
deletion_matrix.append(deletion_vec)

# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return Msa(
sequences=aligned_sequences,
deletion_matrix=deletion_matrix,
descriptions=descriptions,
)
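
# Illustrative usage on a toy a3m (not part of the module): lowercase letters
# are insertions relative to the query; they are stripped from the aligned
# sequences and counted in the deletion matrix at the next aligned column.
toy_a3m_msa = parse_a3m('>query\nMKV\n>hit\nM-kV\n')
assert toy_a3m_msa.sequences == ['MKV', 'M-V']
assert toy_a3m_msa.deletion_matrix == [[0, 0, 0], [0, 0, 1]]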


def _convert_sto_seq_to_a3m(query_non_gaps: Sequence[bool],
sto_seq: str) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap:
yield sequence_res
elif sequence_res != '-':
yield sequence_res.lower()


def convert_stockholm_to_a3m(
stockholm_format: str,
max_sequences: Optional[int] = None,
remove_first_row_gaps: bool = True,
) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False

for line in stockholm_format.splitlines():
reached_max_sequences = max_sequences and len(
sequences) >= max_sequences
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ''
sequences[seqname] += aligned_seq

for line in stockholm_format.splitlines():
if line[:4] == '#=GS':
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ''
if feature != 'DE':
continue
if reached_max_sequences and seqname not in sequences:
continue
descriptions[seqname] = value
if len(descriptions) == len(sequences):
break

# Convert sto format to a3m line by line
a3m_sequences = {}
if remove_first_row_gaps:
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence]
for seqname, sto_sequence in sequences.items():
# Dots are optional in a3m format and are commonly removed.
out_sequence = sto_sequence.replace('.', '')
if remove_first_row_gaps:
out_sequence = ''.join(
_convert_sto_seq_to_a3m(query_non_gaps, out_sequence))
a3m_sequences[seqname] = out_sequence

fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences)
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.


def _keep_line(line: str, seqnames: Set[str]) -> bool:
"""Function to decide which lines to keep."""
if not line.strip():
return True
if line.strip() == '//': # End tag
return True
if line.startswith('# STOCKHOLM'): # Start tag
return True
if line.startswith('#=GC RF'): # Reference Annotation Line
return True
if line[:4] == '#=GS': # Description lines - keep if sequence in list.
_, seqname, _ = line.split(maxsplit=2)
return seqname in seqnames
elif line.startswith('#'): # Other markup - filter out
return False
else: # Alignment data - keep if sequence in list.
seqname = line.partition(' ')[0]
return seqname in seqnames


def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str:
"""Truncates a stockholm file to a maximum number of sequences."""
seqnames = set()
filtered_lines = []
for line in stockholm_msa.splitlines():
if line.strip() and not line.startswith(('#', '//')):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname = line.partition(' ')[0]
seqnames.add(seqname)
if len(seqnames) >= max_sequences:
break

for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)

return '\n'.join(filtered_lines) + '\n'


def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str:
"""Removes empty columns (dashes-only) from a Stockholm MSA."""
processed_lines = {}
unprocessed_lines = {}
for i, line in enumerate(stockholm_msa.splitlines()):
if line.startswith('#=GC RF'):
reference_annotation_i = i
reference_annotation_line = line
# Reached the end of this chunk of the alignment. Process chunk.
_, _, first_alignment = line.rpartition(' ')
mask = []
for j in range(len(first_alignment)):
for _, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
if alignment[j] != '-':
mask.append(True)
break
else: # Every row contained a hyphen - empty column.
mask.append(False)
# Add reference annotation for processing with mask.
unprocessed_lines[
reference_annotation_i] = reference_annotation_line

if not any(
mask
): # All columns were empty. Output empty lines for chunk.
for line_index in unprocessed_lines:
processed_lines[line_index] = ''
else:
for line_index, unprocessed_line in unprocessed_lines.items():
prefix, _, alignment = unprocessed_line.rpartition(' ')
masked_alignment = ''.join(
itertools.compress(alignment, mask))
processed_lines[
line_index] = f'{prefix} {masked_alignment}'

# Clear raw_alignments.
unprocessed_lines = {}
elif line.strip() and not line.startswith(('#', '//')):
unprocessed_lines[i] = line
else:
processed_lines[i] = line
return '\n'.join((processed_lines[i] for i in range(len(processed_lines))))


def deduplicate_stockholm_msa(stockholm_msa: str) -> str:
"""Remove duplicate sequences (ignoring insertions wrt query)."""
sequence_dict = collections.defaultdict(str)

# First we must extract all sequences from the MSA.
for line in stockholm_msa.splitlines():
# Only consider the alignments - ignore reference annotation, empty lines,
# descriptions or markup.
if line.strip() and not line.startswith(('#', '//')):
line = line.strip()
seqname, alignment = line.split()
sequence_dict[seqname] += alignment

seen_sequences = set()
seqnames = set()
# First alignment is the query.
query_align = next(iter(sequence_dict.values()))
mask = [c != '-' for c in query_align] # Mask is False for insertions.
for seqname, alignment in sequence_dict.items():
# Apply mask to remove all insertions from the string.
masked_alignment = ''.join(itertools.compress(alignment, mask))
if masked_alignment in seen_sequences:
continue
else:
seen_sequences.add(masked_alignment)
seqnames.add(seqname)

filtered_lines = []
for line in stockholm_msa.splitlines():
if _keep_line(line, seqnames):
filtered_lines.append(line)

return '\n'.join(filtered_lines) + '\n'


def _get_hhr_line_regex_groups(regex_pattern: str,
line: str) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
if match is None:
raise RuntimeError(f'Could not parse query line {line}')
return match.groups()


def _update_hhr_residue_indices_list(sequence: str, start_index: int,
indices_list: List[int]):
"""Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index
for symbol in sequence:
if symbol == '-':
indices_list.append(-1)
else:
indices_list.append(counter)
counter += 1


def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
"""Parses the detailed HMM HMM comparison section for a single Hit.

This works on .hhr files generated from both HHBlits and HHSearch.

Args:
detailed_lines: A list of lines from a single comparison section between 2
sequences (which each have their own HMM's)

Returns:
A dictionary with the information from that detailed comparison section

Raises:
RuntimeError: If a certain line cannot be processed
"""
# Parse first 2 lines.
number_of_hit = int(detailed_lines[0].split()[-1])
name_hit = detailed_lines[1][1:]

# Parse the summary line.
pattern = (
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
']*Template_Neff=(.*)')
match = re.match(pattern, detailed_lines[2])
if match is None:
raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.'
% (detailed_lines, detailed_lines[2]))
(_, _, _, aligned_cols, _, _, sum_probs,
_) = [float(x) for x in match.groups()]

# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block.
query = ''
hit_sequence = ''
indices_query = []
indices_hit = []
length_block = None

for line in detailed_lines[3:]:
# Parse the query sequence line
if (line.startswith('Q ') and not line.startswith('Q ss_dssp')
and not line.startswith('Q ss_pred')
and not line.startswith('Q Consensus')):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
groups = _get_hhr_line_regex_groups(patt, line[17:])

# Get the length of the parsed block using the start and finish indices,
# and ensure it is the same as the actual block length.
start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1]
end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-'])
length_block = end - start + num_insertions
assert length_block == len(delta_query)

# Update the query sequence and indices list.
query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query)

elif line.startswith('T '):
# Parse the hit sequence.
if (not line.startswith('T ss_dssp')
and not line.startswith('T ss_pred')
and not line.startswith('T Consensus')):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1]
assert length_block == len(delta_hit_sequence)

# Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(delta_hit_sequence, start,
indices_hit)

return TemplateHit(
index=number_of_hit,
name=name_hit,
aligned_cols=int(aligned_cols),
sum_probs=sum_probs,
query=query,
hit_sequence=hit_sequence,
indices_query=indices_query,
indices_hit=indices_hit,
)


def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
"""Parses the content of an entire HHR file."""
lines = hhr_string.splitlines()

# Each .hhr file starts with a results table, then has a sequence of hit
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.

block_starts = [
i for i, line in enumerate(lines) if line.startswith('No ')
]

hits = []
if block_starts:
block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1):
hits.append(
_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
return hits
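
# Usage sketch (illustrative only; 'output.hhr' is a placeholder path to an
# existing HHsearch/HHblits result file, not something shipped with the repo):
#
#     with open('output.hhr') as f:
#         hits = parse_hhr(f.read())
#     best_hit = max(hits, key=lambda hit: hit.aligned_cols)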


def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {'query': 0}
lines = [line for line in tblout.splitlines() if line[0] != '#']
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1).
for line in lines:
fields = line.split()
e_value = fields[4]
target_name = fields[0]
e_values[target_name] = float(e_value)
return e_values


def _get_indices(sequence: str, start: int) -> List[int]:
"""Returns indices for non-gap/insert residues starting at the given index."""
indices = []
counter = start
for symbol in sequence:
# Skip gaps but add a placeholder so that the alignment is preserved.
if symbol == '-':
indices.append(-1)
# Skip deleted residues, but increase the counter.
elif symbol.islower():
counter += 1
# Normal aligned residue. Increase the counter and append to indices.
else:
indices.append(counter)
counter += 1
return indices
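
# For example (behaviour of the helper above): gaps map to -1, and lowercase
# (deleted) residues advance the counter without emitting an index:
#
#     _get_indices('AB-cD', start=0)  # -> [0, 1, -1, 3]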


@dataclasses.dataclass(frozen=True)
class HitMetadata:
pdb_id: str
chain: str
start: int
end: int
length: int
text: str


def _parse_hmmsearch_description(description: str) -> HitMetadata:
"""Parses the hmmsearch A3M sequence description line."""
# Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text
# Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352
match = re.match(
r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$',
description.strip(),
)

if not match:
raise ValueError(f'Could not parse description: "{description}".')

return HitMetadata(
pdb_id=match[1],
chain=match[2],
start=int(match[3]),
end=int(match[4]),
length=int(match[5]),
text=match[6],
)


def parse_hmmsearch_a3m(query_sequence: str,
a3m_string: str,
skip_first: bool = True) -> Sequence[TemplateHit]:
"""Parses an a3m string produced by hmmsearch.

Args:
query_sequence: The query sequence.
a3m_string: The a3m string produced by hmmsearch.
skip_first: Whether to skip the first sequence in the a3m string.

Returns:
A sequence of `TemplateHit` results.
"""
# Zip the descriptions and MSAs together, skip the first query sequence.
parsed_a3m = list(zip(*parse_fasta(a3m_string)))
if skip_first:
parsed_a3m = parsed_a3m[1:]

indices_query = _get_indices(query_sequence, start=0)

hits = []
for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1):
if 'mol:protein' not in hit_description:
continue # Skip non-protein chains.
metadata = _parse_hmmsearch_description(hit_description)
# Aligned columns are only the match states.
aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence])
indices_hit = _get_indices(hit_sequence, start=metadata.start - 1)

hit = TemplateHit(
index=i,
name=f'{metadata.pdb_id}_{metadata.chain}',
aligned_cols=aligned_cols,
sum_probs=None,
query=query_sequence,
hit_sequence=hit_sequence.upper(),
indices_query=indices_query,
indices_hit=indices_hit,
)
hits.append(hit)

return hits
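
# Usage sketch (illustrative; assumes `query_seq` and an hmmsearch-produced A3M
# string named `a3m_str` are already available):
#
#     hits = parse_hmmsearch_a3m(query_sequence=query_seq, a3m_string=a3m_str,
#                                skip_first=False)
#     names = [hit.name for hit in hits]  # e.g. '4pqx_A'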

+282 -0  modelscope/models/science/unifold/msa/pipeline.py

@@ -0,0 +1,282 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for building the input features for the unifold model."""

import os
from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union

import numpy as np
from absl import logging

from modelscope.models.science.unifold.data import residue_constants
from modelscope.models.science.unifold.msa import (msa_identifiers, parsers,
templates)
from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch,
hmmsearch, jackhmmer)

FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]


def make_sequence_features(sequence: str, description: str,
num_res: int) -> FeatureDict:
"""Constructs a feature dict of sequence features."""
features = {}
features['aatype'] = residue_constants.sequence_to_onehot(
sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True,
)
features['between_segment_residues'] = np.zeros((num_res, ),
dtype=np.int32)
features['domain_name'] = np.array([description.encode('utf-8')],
dtype=np.object_)
features['residue_index'] = np.array(range(num_res), dtype=np.int32)
features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
features['sequence'] = np.array([sequence.encode('utf-8')],
dtype=np.object_)
return features
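
# Shape sketch for the features above (illustrative; the one-hot width follows
# residue_constants.restype_order_with_x, i.e. 20 standard residues plus X):
#
#     feats = make_sequence_features('MKT', 'query', num_res=3)
#     feats['aatype'].shape      # -> (3, 21)
#     feats['residue_index']     # -> array([0, 1, 2], dtype=int32)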


def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError('At least one MSA must be provided.')

int_msa = []
deletion_matrix = []
species_ids = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(
f'MSA {msa_index} must contain at least one sequence.')
for sequence_index, sequence in enumerate(msa.sequences):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
int_msa.append(
[residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence])
deletion_matrix.append(msa.deletion_matrix[sequence_index])
identifiers = msa_identifiers.get_identifiers(
msa.descriptions[sequence_index])
species_ids.append(identifiers.species_id.encode('utf-8'))

num_res = len(msas[0].sequences[0])
num_alignments = len(int_msa)
features = {}
features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
features['msa'] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array(
[num_alignments] * num_res, dtype=np.int32)
features['msa_species_identifiers'] = np.array(
species_ids, dtype=np.object_)
return features


def run_msa_tool(
msa_runner,
input_fasta_path: str,
msa_out_path: str,
msa_format: str,
use_precomputed_msas: bool,
) -> Mapping[str, Any]:
"""Runs an MSA tool, checking if output already exists first."""
if not use_precomputed_msas or not os.path.exists(msa_out_path):
result = msa_runner.query(input_fasta_path)[0]
with open(msa_out_path, 'w') as f:
f.write(result[msa_format])
else:
logging.warning('Reading MSA from file %s', msa_out_path)
with open(msa_out_path, 'r') as f:
result = {msa_format: f.read()}
return result
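
# Usage sketch (paths and the runner are placeholders; any of the runners
# constructed in DataPipeline below can be passed in):
#
#     result = run_msa_tool(jackhmmer_runner, 'query.fasta',
#                           'uniref90_hits.sto', 'sto',
#                           use_precomputed_msas=True)
#     msa = parsers.parse_stockholm(result['sto'])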


class DataPipeline:
"""Runs the alignment tools and assembles the input features."""

def __init__(
self,
jackhmmer_binary_path: str,
hhblits_binary_path: str,
uniref90_database_path: str,
mgnify_database_path: str,
bfd_database_path: Optional[str],
uniclust30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
uniprot_database_path: Optional[str],
template_searcher: TemplateSearcher,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000,
use_precomputed_msas: bool = False,
):
"""Initializes the data pipeline."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path)
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path)
else:
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniclust30_database_path],
)
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path)
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniprot_database_path)
self.template_searcher = template_searcher
self.template_featurizer = template_featurizer
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.use_precomputed_msas = use_precomputed_msas

def process(self, input_fasta_path: str,
msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'Expected exactly one input sequence in {input_fasta_path}, '
f'got {len(input_seqs)}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)

uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
jackhmmer_uniref90_result = run_msa_tool(
self.jackhmmer_uniref90_runner,
input_fasta_path,
uniref90_out_path,
'sto',
self.use_precomputed_msas,
)
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto')
jackhmmer_mgnify_result = run_msa_tool(
self.jackhmmer_mgnify_runner,
input_fasta_path,
mgnify_out_path,
'sto',
self.use_precomputed_msas,
)

msa_for_templates = jackhmmer_uniref90_result['sto']
msa_for_templates = parsers.truncate_stockholm_msa(
msa_for_templates, max_sequences=self.uniref_max_hits)
msa_for_templates = parsers.deduplicate_stockholm_msa(
msa_for_templates)
msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa(
msa_for_templates)

if self.template_searcher.input_format == 'sto':
pdb_templates_result = self.template_searcher.query(
msa_for_templates)
elif self.template_searcher.input_format == 'a3m':
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
msa_for_templates)
pdb_templates_result = self.template_searcher.query(
uniref90_msa_as_a3m)
else:
raise ValueError('Unrecognized template input format: '
f'{self.template_searcher.input_format}')

pdb_hits_out_path = os.path.join(
msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}')
with open(pdb_hits_out_path, 'w') as f:
f.write(pdb_templates_result)

uniref90_msa = parsers.parse_stockholm(
jackhmmer_uniref90_result['sto'])
uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits)
mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto'])
mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits)

pdb_template_hits = self.template_searcher.get_template_hits(
output_string=pdb_templates_result, input_sequence=input_sequence)

if self._use_small_bfd:
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto')
jackhmmer_small_bfd_result = run_msa_tool(
self.jackhmmer_small_bfd_runner,
input_fasta_path,
bfd_out_path,
'sto',
self.use_precomputed_msas,
)
bfd_msa = parsers.parse_stockholm(
jackhmmer_small_bfd_result['sto'])
else:
bfd_out_path = os.path.join(msa_output_dir,
'bfd_uniclust_hits.a3m')
hhblits_bfd_uniclust_result = run_msa_tool(
self.hhblits_bfd_uniclust_runner,
input_fasta_path,
bfd_out_path,
'a3m',
self.use_precomputed_msas,
)
bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m'])

templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence, hits=pdb_template_hits)

sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res)

msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa))

logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa))
logging.info('BFD MSA size: %d sequences.', len(bfd_msa))
logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa))
logging.info(
'Final (deduplicated) MSA size: %d sequences.',
msa_features['num_alignments'][0],
)
logging.info(
'Total number of templates (NB: this can include bad '
'templates and is later filtered to top 4): %d.',
templates_result.features['template_domain_names'].shape[0],
)

return {
**sequence_features,
**msa_features,
**templates_result.features
}

def process_uniprot(self, input_fasta_path: str,
msa_output_dir: str) -> FeatureDict:
uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto')
uniprot_result = run_msa_tool(
self.jackhmmer_uniprot_runner,
input_fasta_path,
uniprot_path,
'sto',
self.use_precomputed_msas,
)
msa = parsers.parse_stockholm(uniprot_result['sto'])
msa = msa.truncate(max_seqs=50000)
all_seq_dict = make_msa_features([msa])
return all_seq_dict
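
# Construction sketch (all binary and database paths are placeholders, and the
# hmmsearch-based template searcher is only one of the two supported options):
#
#     template_searcher = hmmsearch.Hmmsearch(
#         binary_path='hmmsearch',
#         hmmbuild_binary_path='hmmbuild',
#         database_path='/path/to/pdb_seqres.txt')
#     data_pipeline = DataPipeline(
#         jackhmmer_binary_path='jackhmmer',
#         hhblits_binary_path='hhblits',
#         uniref90_database_path='/path/to/uniref90.fasta',
#         mgnify_database_path='/path/to/mgnify.fasta',
#         bfd_database_path=None,
#         uniclust30_database_path=None,
#         small_bfd_database_path='/path/to/small_bfd.fasta',
#         uniprot_database_path='/path/to/uniprot.fasta',
#         template_searcher=template_searcher,
#         template_featurizer=template_featurizer,  # a templates.TemplateHitFeaturizer
#         use_small_bfd=True)
#     features = data_pipeline.process('query.fasta', msa_output_dir='./msas')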

+1110 -0  modelscope/models/science/unifold/msa/templates.py  (file diff suppressed because it is too large)


+14 -0  modelscope/models/science/unifold/msa/tools/__init__.py

@@ -0,0 +1,14 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python wrappers for third party tools."""

+170 -0  modelscope/models/science/unifold/msa/tools/hhblits.py

@@ -0,0 +1,170 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHblits from Python."""

import glob
import os
import subprocess
from typing import Any, List, Mapping, Optional, Sequence

from absl import logging

from . import utils

_HHBLITS_DEFAULT_P = 20
_HHBLITS_DEFAULT_Z = 500


class HHBlits:
"""Python wrapper of the HHblits binary."""

def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 4,
n_iter: int = 3,
e_value: float = 0.001,
maxseq: int = 1_000_000,
realign_max: int = 100_000,
maxfilt: int = 100_000,
min_prefilter_hits: int = 1000,
all_seqs: bool = False,
alt: Optional[int] = None,
p: int = _HHBLITS_DEFAULT_P,
z: int = _HHBLITS_DEFAULT_Z,
):
"""Initializes the Python HHblits wrapper.

Args:
binary_path: The path to the HHblits executable.
databases: A sequence of HHblits database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to give HHblits.
n_iter: The number of HHblits iterations.
e_value: The E-value, see HHblits docs for more details.
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500.
maxfilt: Max number of hits allowed to pass the 2nd prefilter.
HHblits default: 20000.
min_prefilter_hits: Min number of hits to pass prefilter.
HHblits default: 100.
all_seqs: Return all sequences in the MSA / Do not filter the result MSA.
HHblits default: False.
alt: Show up to this many alternative alignments.
p: Minimum Prob for a hit to be included in the output hhr file.
HHblits default: 20.
z: Hard cap on number of hits reported in the hhr file.
HHblits default: 500. NB: The relevant HHblits flag is -Z not -z.

Raises:
RuntimeError: If HHblits binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases

for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHBlits database %s',
database_path)
raise ValueError(
f'Could not find HHBlits database {database_path}')

self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.maxseq = maxseq
self.realign_max = realign_max
self.maxfilt = maxfilt
self.min_prefilter_hits = min_prefilter_hits
self.all_seqs = all_seqs
self.alt = alt
self.p = p
self.z = z

def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager() as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append(db_path)
cmd = [
self.binary_path,
'-i',
input_fasta_path,
'-cpu',
str(self.n_cpu),
'-oa3m',
a3m_path,
'-o',
'/dev/null',
'-n',
str(self.n_iter),
'-e',
str(self.e_value),
'-maxseq',
str(self.maxseq),
'-realign_max',
str(self.realign_max),
'-maxfilt',
str(self.maxfilt),
'-min_prefilter_hits',
str(self.min_prefilter_hits),
]
if self.all_seqs:
cmd += ['-all']
if self.alt:
cmd += ['-alt', str(self.alt)]
if self.p != _HHBLITS_DEFAULT_P:
cmd += ['-p', str(self.p)]
if self.z != _HHBLITS_DEFAULT_Z:
cmd += ['-Z', str(self.z)]
cmd += db_cmd

logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

with utils.timing('HHblits query'):
stdout, stderr = process.communicate()
retcode = process.wait()

if retcode:
# Logs have a 15k character limit, so log HHblits error line by line.
logging.error('HHblits failed. HHblits stderr begin:')
for error_line in stderr.decode('utf-8').splitlines():
if error_line.strip():
logging.error(error_line.strip())
logging.error('HHblits stderr end')
raise RuntimeError(
'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' %
(stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))

with open(a3m_path) as f:
a3m = f.read()

raw_output = dict(
a3m=a3m,
output=stdout,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value,
)
return [raw_output]
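
# Usage sketch (binary and database prefixes are placeholders; each database
# argument is the common prefix of the HH-suite database files):
#
#     runner = HHBlits(binary_path='hhblits',
#                      databases=['/path/to/uniclust30/UniRef30_2021_03'])
#     a3m = runner.query('query.fasta')[0]['a3m']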

+111 -0  modelscope/models/science/unifold/msa/tools/hhsearch.py

@@ -0,0 +1,111 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run HHsearch from Python."""

import glob
import os
import subprocess
from typing import Sequence

from absl import logging

from modelscope.models.science.unifold.msa import parsers
from . import utils


class HHSearch:
"""Python wrapper of the HHsearch binary."""

def __init__(self,
*,
binary_path: str,
databases: Sequence[str],
maxseq: int = 1_000_000):
"""Initializes the Python HHsearch wrapper.

Args:
binary_path: The path to the HHsearch executable.
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.

Raises:
RuntimeError: If HHsearch binary not found within the path.
"""
self.binary_path = binary_path
self.databases = databases
self.maxseq = maxseq

for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHsearch database %s',
database_path)
raise ValueError(
f'Could not find HHsearch database {database_path}')

@property
def output_format(self) -> str:
return 'hhr'

@property
def input_format(self) -> str:
return 'a3m'

def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager() as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, 'query.a3m')
hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
with open(input_path, 'w') as f:
f.write(a3m)

db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append(db_path)
cmd = [
self.binary_path,
'-i',
input_path,
'-o',
hhr_path,
'-maxseq',
str(self.maxseq),
] + db_cmd

logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing('HHsearch query'):
stdout, stderr = process.communicate()
retcode = process.wait()

if retcode:
# Stderr is truncated to prevent proto size errors in Beam.
raise RuntimeError(
'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' %
(stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))

with open(hhr_path) as f:
hhr = f.read()
return hhr

def get_template_hits(
self, output_string: str,
input_sequence: str) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
del input_sequence # Used by hmmseach but not needed for hhsearch.
return parsers.parse_hhr(output_string)
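
# Usage sketch (placeholder paths; the query is an a3m string, e.g. a UniRef90
# MSA converted via parsers.convert_stockholm_to_a3m):
#
#     searcher = HHSearch(binary_path='hhsearch',
#                         databases=['/path/to/pdb70/pdb70'])
#     hits = searcher.get_template_hits(searcher.query(a3m_string),
#                                       input_sequence=query_sequence)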

+143 -0  modelscope/models/science/unifold/msa/tools/hmmbuild.py

@@ -0,0 +1,143 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""

import os
import re
import subprocess

from absl import logging

from . import utils


class Hmmbuild(object):
"""Python wrapper of the hmmbuild binary."""

def __init__(self, *, binary_path: str, singlemx: bool = False):
"""Initializes the Python hmmbuild wrapper.

Args:
binary_path: The path to the hmmbuild executable.
singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to
just use a common substitution score matrix.

Raises:
RuntimeError: If hmmbuild binary not found within the path.
"""
self.binary_path = binary_path
self.singlemx = singlemx

def build_profile_from_sto(self,
sto: str,
model_construction='fast') -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.

Args:
sto: A string with the aligned sequences in the Stockholm format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').

Returns:
A string with the profile in the HMM format.

Raises:
RuntimeError: If hmmbuild fails.
"""
return self._build_profile(sto, model_construction=model_construction)

def build_profile_from_a3m(self, a3m: str) -> str:
"""Builds a HHM for the aligned sequences given as an A3M string.

Args:
a3m: A string with the aligned sequences in the A3M format.

Returns:
A string with the profile in the HMM format.

Raises:
RuntimeError: If hmmbuild fails.
"""
lines = []
for line in a3m.splitlines():
if not line.startswith('>'):
line = re.sub('[a-z]+', '', line) # Remove inserted residues.
lines.append(line + '\n')
msa = ''.join(lines)
return self._build_profile(msa, model_construction='fast')

def _build_profile(self,
msa: str,
model_construction: str = 'fast') -> str:
"""Builds a HMM for the aligned sequences given as an MSA string.

Args:
msa: A string with the aligned sequences, in A3M or STO format.
model_construction: Whether to use reference annotation in the msa to
determine consensus columns ('hand') or default ('fast').

Returns:
A string with the profile in the HMM format.

Raises:
RuntimeError: If hmmbuild fails.
ValueError: If unspecified arguments are provided.
"""
if model_construction not in {'hand', 'fast'}:
raise ValueError(
f'Invalid model_construction {model_construction} - only '
"'hand' and 'fast' are supported.")

with utils.tmpdir_manager() as query_tmp_dir:
input_query = os.path.join(query_tmp_dir, 'query.msa')
output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')

with open(input_query, 'w') as f:
f.write(msa)

cmd = [self.binary_path]
# If adding flags, we have to do so before the output and input:

if model_construction == 'hand':
cmd.append(f'--{model_construction}')
if self.singlemx:
cmd.append('--singlemx')
cmd.extend([
'--amino',
output_hmm_path,
input_query,
])

logging.info('Launching subprocess %s', cmd)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

with utils.timing('hmmbuild query'):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info(
'hmmbuild stdout:\n%s\n\nstderr:\n%s\n',
stdout.decode('utf-8'),
stderr.decode('utf-8'),
)

if retcode:
raise RuntimeError(
'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' %
(stdout.decode('utf-8'), stderr.decode('utf-8')))

with open(output_hmm_path, encoding='utf-8') as f:
hmm = f.read()

return hmm
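
# Usage sketch (placeholder binary path; `sto_string` is any Stockholm MSA):
#
#     builder = Hmmbuild(binary_path='hmmbuild')
#     hmm_profile = builder.build_profile_from_sto(sto_string,
#                                                  model_construction='hand')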

+146 -0  modelscope/models/science/unifold/msa/tools/hmmsearch.py

@@ -0,0 +1,146 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for hmmsearch - search profile against a sequence db."""

import os
import subprocess
from typing import Optional, Sequence

from absl import logging

from modelscope.models.science.unifold.msa import parsers
from . import hmmbuild, utils


class Hmmsearch(object):
"""Python wrapper of the hmmsearch binary."""

def __init__(
self,
*,
binary_path: str,
hmmbuild_binary_path: str,
database_path: str,
flags: Optional[Sequence[str]] = None,
):
"""Initializes the Python hmmsearch wrapper.

Args:
binary_path: The path to the hmmsearch executable.
hmmbuild_binary_path: The path to the hmmbuild executable. Used to build
an hmm from an input a3m.
database_path: The path to the hmmsearch database (FASTA format).
flags: List of flags to be used by hmmsearch.

Raises:
RuntimeError: If hmmsearch binary not found within the path.
"""
self.binary_path = binary_path
self.hmmbuild_runner = hmmbuild.Hmmbuild(
binary_path=hmmbuild_binary_path)
self.database_path = database_path
if flags is None:
# Default hmmsearch run settings.
flags = [
'--F1',
'0.1',
'--F2',
'0.1',
'--F3',
'0.1',
'--incE',
'100',
'-E',
'100',
'--domE',
'100',
'--incdomE',
'100',
]
self.flags = flags

if not os.path.exists(self.database_path):
logging.error('Could not find hmmsearch database %s',
database_path)
raise ValueError(
f'Could not find hmmsearch database {database_path}')

@property
def output_format(self) -> str:
return 'sto'

@property
def input_format(self) -> str:
return 'sto'

def query(self, msa_sto: str) -> str:
"""Queries the database using hmmsearch using a given stockholm msa."""
hmm = self.hmmbuild_runner.build_profile_from_sto(
msa_sto, model_construction='hand')
return self.query_with_hmm(hmm)

def query_with_hmm(self, hmm: str) -> str:
"""Queries the database using hmmsearch using a given hmm."""
with utils.tmpdir_manager() as query_tmp_dir:
hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm')
out_path = os.path.join(query_tmp_dir, 'output.sto')
with open(hmm_input_path, 'w') as f:
f.write(hmm)

cmd = [
self.binary_path,
'--noali', # Don't include the alignment in stdout.
'--cpu',
'8',
]
# If adding flags, we have to do so before the output and input:
if self.flags:
cmd.extend(self.flags)
cmd.extend([
'-A',
out_path,
hmm_input_path,
self.database_path,
])

logging.info('Launching sub-process %s', cmd)
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing(
f'hmmsearch ({os.path.basename(self.database_path)}) query'
):
stdout, stderr = process.communicate()
retcode = process.wait()

if retcode:
raise RuntimeError(
'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' %
(stdout.decode('utf-8'), stderr.decode('utf-8')))

with open(out_path) as f:
out_msa = f.read()

return out_msa

def get_template_hits(
self, output_string: str,
input_sequence: str) -> Sequence[parsers.TemplateHit]:
"""Gets parsed template hits from the raw string output by the tool."""
a3m_string = parsers.convert_stockholm_to_a3m(
output_string, remove_first_row_gaps=False)
template_hits = parsers.parse_hmmsearch_a3m(
query_sequence=input_sequence,
a3m_string=a3m_string,
skip_first=False)
return template_hits
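
# Usage sketch (placeholder paths; mirrors how DataPipeline uses this class
# when template_searcher.input_format == 'sto'):
#
#     searcher = Hmmsearch(binary_path='hmmsearch',
#                          hmmbuild_binary_path='hmmbuild',
#                          database_path='/path/to/pdb_seqres.txt')
#     sto_out = searcher.query(msa_for_templates)
#     hits = searcher.get_template_hits(sto_out, input_sequence=query_sequence)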

+224 -0  modelscope/models/science/unifold/msa/tools/jackhmmer.py

@@ -0,0 +1,224 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library to run Jackhmmer from Python."""

import glob
import os
import subprocess
from concurrent import futures
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request

from absl import logging

from . import utils


class Jackhmmer:
"""Python wrapper of the Jackhmmer binary."""

def __init__(
self,
*,
binary_path: str,
database_path: str,
n_cpu: int = 8,
n_iter: int = 1,
e_value: float = 0.0001,
z_value: Optional[int] = None,
get_tblout: bool = False,
filter_f1: float = 0.0005,
filter_f2: float = 0.00005,
filter_f3: float = 0.0000005,
incdom_e: Optional[float] = None,
dom_e: Optional[float] = None,
num_streamed_chunks: Optional[int] = None,
streaming_callback: Optional[Callable[[int], None]] = None,
):
"""Initializes the Python Jackhmmer wrapper.

Args:
binary_path: The path to the jackhmmer executable.
database_path: The path to the jackhmmer database (FASTA format).
n_cpu: The number of CPUs to give Jackhmmer.
n_iter: The number of Jackhmmer iterations.
e_value: The E-value, see Jackhmmer docs for more details.
z_value: The Z-value, see Jackhmmer docs for more details.
get_tblout: Whether to save tblout string.
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
filter_f3: Forward pre-filter, set to >1.0 to turn off.
incdom_e: Domain e-value criteria for inclusion of domains in MSA/next
round.
dom_e: Domain e-value criteria for inclusion in tblout.
num_streamed_chunks: Number of database chunks to stream over.
streaming_callback: Callback function run after each chunk iteration with
the iteration number as argument.
"""
self.binary_path = binary_path
self.database_path = database_path
self.num_streamed_chunks = num_streamed_chunks

if not os.path.exists(
self.database_path) and num_streamed_chunks is None:
logging.error('Could not find Jackhmmer database %s',
database_path)
raise ValueError(
f'Could not find Jackhmmer database {database_path}')

self.n_cpu = n_cpu
self.n_iter = n_iter
self.e_value = e_value
self.z_value = z_value
self.filter_f1 = filter_f1
self.filter_f2 = filter_f2
self.filter_f3 = filter_f3
self.incdom_e = incdom_e
self.dom_e = dom_e
self.get_tblout = get_tblout
self.streaming_callback = streaming_callback

def _query_chunk(self, input_fasta_path: str,
database_path: str) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager() as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, 'output.sto')

# The F1/F2/F3 are the expected proportion to pass each of the filtering
# stages (which get progressively more expensive), reducing these
# speeds up the pipeline at the expensive of sensitivity. They are
# currently set very low to make querying Mgnify run in a reasonable
# amount of time.
cmd_flags = [
# Don't pollute stdout with Jackhmmer output.
'-o',
'/dev/null',
'-A',
sto_path,
'--noali',
'--F1',
str(self.filter_f1),
'--F2',
str(self.filter_f2),
'--F3',
str(self.filter_f3),
'--incE',
str(self.e_value),
# Report only sequences with E-values <= x in per-sequence output.
'-E',
str(self.e_value),
'--cpu',
str(self.n_cpu),
'-N',
str(self.n_iter),
]
if self.get_tblout:
tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
cmd_flags.extend(['--tblout', tblout_path])

if self.z_value:
cmd_flags.extend(['-Z', str(self.z_value)])

if self.dom_e is not None:
cmd_flags.extend(['--domE', str(self.dom_e)])

if self.incdom_e is not None:
cmd_flags.extend(['--incdomE', str(self.incdom_e)])

cmd = [self.binary_path
] + cmd_flags + [input_fasta_path, database_path]

logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing(
f'Jackhmmer ({os.path.basename(database_path)}) query'):
_, stderr = process.communicate()
retcode = process.wait()

if retcode:
raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n'
% stderr.decode('utf-8'))

# Get e-values for each target name
tbl = ''
if self.get_tblout:
with open(tblout_path) as f:
tbl = f.read()

with open(sto_path) as f:
sto = f.read()

raw_output = dict(
sto=sto,
tbl=tbl,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)

return raw_output

def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
"""Queries the database using Jackhmmer."""
if self.num_streamed_chunks is None:
return [self._query_chunk(input_fasta_path, self.database_path)]

db_basename = os.path.basename(self.database_path)

def db_remote_chunk(db_idx):
return f'{self.database_path}.{db_idx}'

def db_local_chunk(db_idx):
return f'/tmp/ramdisk/{db_basename}.{db_idx}'

# db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
# db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'

# Remove existing files to prevent OOM
for f in glob.glob(db_local_chunk('[0-9]*')):
try:
os.remove(f)
except OSError:
print(f'OSError while deleting {f}')

# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with futures.ThreadPoolExecutor(max_workers=2) as executor:
chunked_output = []
for i in range(1, self.num_streamed_chunks + 1):
# Copy the chunk locally
if i == 1:
future = executor.submit(request.urlretrieve,
db_remote_chunk(i),
db_local_chunk(i))
if i < self.num_streamed_chunks:
next_future = executor.submit(
request.urlretrieve,
db_remote_chunk(i + 1),
db_local_chunk(i + 1),
)

# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
self._query_chunk(input_fasta_path, db_local_chunk(i)))

# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
# Do not set next_future for the last chunk so that this works even for
# databases with only 1 chunk.
if i < self.num_streamed_chunks:
future = next_future
if self.streaming_callback:
self.streaming_callback(i)
return chunked_output
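
# Usage sketch (placeholder paths; a single, non-streamed database):
#
#     runner = Jackhmmer(binary_path='jackhmmer',
#                        database_path='/path/to/uniref90.fasta')
#     sto = runner.query('query.fasta')[0]['sto']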

+110 -0  modelscope/models/science/unifold/msa/tools/kalign.py

@@ -0,0 +1,110 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Python wrapper for Kalign."""
import os
import subprocess
from typing import Sequence

from absl import logging

from . import utils


def _to_a3m(sequences: Sequence[str]) -> str:
"""Converts sequences to an a3m file."""
names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
a3m = []
for sequence, name in zip(sequences, names):
a3m.append('>' + name + '\n')
a3m.append(sequence + '\n')
return ''.join(a3m)


class Kalign:
"""Python wrapper of the Kalign binary."""

def __init__(self, *, binary_path: str):
"""Initializes the Python Kalign wrapper.

Args:
binary_path: The path to the Kalign binary.

Raises:
RuntimeError: If Kalign binary not found within the path.
"""
self.binary_path = binary_path

def align(self, sequences: Sequence[str]) -> str:
"""Aligns the sequences and returns the alignment in A3M string.

Args:
sequences: A list of query sequence strings. The sequences have to be at
least 6 residues long (Kalign requires this). Note that the order in
which you give the sequences might alter the output slightly as
different alignment tree might get constructed.

Returns:
A string with the alignment in a3m format.

Raises:
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
"""
logging.info('Aligning %d sequences', len(sequences))

for s in sequences:
if len(s) < 6:
raise ValueError(
'Kalign requires all sequences to be at least 6 '
'residues long. Got %s (%d residues).' % (s, len(s)))

with utils.tmpdir_manager() as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')

with open(input_fasta_path, 'w') as f:
f.write(_to_a3m(sequences))

cmd = [
self.binary_path,
'-i',
input_fasta_path,
'-o',
output_a3m_path,
'-format',
'fasta',
]

logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

with utils.timing('Kalign query'):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info(
'Kalign stdout:\n%s\n\nstderr:\n%s\n',
stdout.decode('utf-8'),
stderr.decode('utf-8'),
)

if retcode:
raise RuntimeError(
'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' %
(stdout.decode('utf-8'), stderr.decode('utf-8')))

with open(output_a3m_path) as f:
a3m = f.read()

return a3m
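
# Usage sketch (placeholder binary path; every sequence must be at least six
# residues long, as enforced above):
#
#     aligner = Kalign(binary_path='kalign')
#     a3m = aligner.align(['MKTAYIAKQR', 'MKTAYIAKQC', 'MKTAYIAQR'])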

+40 -0  modelscope/models/science/unifold/msa/tools/utils.py

@@ -0,0 +1,40 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common utilities for data pipeline tools."""
import contextlib
import shutil
import tempfile
import time
from typing import Optional

from absl import logging


@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
"""Context manager that deletes a temporary directory on exit."""
tmpdir = tempfile.mkdtemp(dir=base_dir)
try:
yield tmpdir
finally:
shutil.rmtree(tmpdir, ignore_errors=True)


@contextlib.contextmanager
def timing(msg: str):
logging.info('Started %s', msg)
tic = time.time()
yield
toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic)
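
# Usage sketch of the two helpers above (illustrative):
#
#     with timing('toy query'):
#         with tmpdir_manager() as tmp_dir:
#             ...  # write temporary inputs under tmp_dir; it is removed on exit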

+89 -0  modelscope/models/science/unifold/msa/utils.py

@@ -0,0 +1,89 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import os
from typing import Mapping, Sequence

import json
from absl import logging

from modelscope.models.science.unifold.data import protein


def get_chain_id_map(
sequences: Sequence[str],
descriptions: Sequence[str],
):
"""
Makes a mapping from PDB-format chain ID to sequence and description,
and parses the order of multi-chains
"""
unique_seqs = []
for seq in sequences:
if seq not in unique_seqs:
unique_seqs.append(seq)

chain_id_map = {
chain_id: {
'descriptions': [],
'sequence': seq
}
for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs)
}
chain_order = []

for seq, des in zip(sequences, descriptions):
chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)]
chain_id_map[chain_id]['descriptions'].append(des)
chain_order.append(chain_id)

return chain_id_map, chain_order
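
# For example (hypothetical inputs): repeated copies of the same sequence share
# a chain ID, and the original chain order is preserved:
#
#     chain_id_map, chain_order = get_chain_id_map(
#         ['SEQA', 'SEQB', 'SEQA'], ['chain 1', 'chain 2', 'chain 3'])
#     # chain_id_map['A']['descriptions'] -> ['chain 1', 'chain 3']
#     # chain_order                       -> ['A', 'B', 'A']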


def divide_multi_chains(
fasta_name: str,
output_dir_base: str,
sequences: Sequence[str],
descriptions: Sequence[str],
):
"""
Divides the multi-chains fasta into several single fasta files and
records multi-chains mapping information.
"""
if len(sequences) != len(descriptions):
raise ValueError('sequences and descriptions must have equal length. '
f'Got {len(sequences)} != {len(descriptions)}.')
if len(sequences) > protein.PDB_MAX_CHAINS:
raise ValueError(
'Cannot process more chains than the PDB format supports. '
f'Got {len(sequences)} chains.')

chain_id_map, chain_order = get_chain_id_map(sequences, descriptions)

output_dir = os.path.join(output_dir_base, fasta_name)
if not os.path.exists(output_dir):
os.makedirs(output_dir)

chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json')
with open(chain_id_map_path, 'w') as f:
json.dump(chain_id_map, f, indent=4, sort_keys=True)

chain_order_path = os.path.join(output_dir, 'chains.txt')
with open(chain_order_path, 'w') as f:
f.write(' '.join(chain_order))

logging.info('Mapping multi-chains fasta with chain order: %s',
' '.join(chain_order))

temp_names = []
temp_paths = []
for chain_id in chain_id_map.keys():
temp_name = fasta_name + '_{}'.format(chain_id)
temp_path = os.path.join(output_dir, temp_name + '.fasta')
des = 'chain_{}'.format(chain_id)
seq = chain_id_map[chain_id]['sequence']
with open(temp_path, 'w') as f:
f.write('>' + des + '\n' + seq)
temp_names.append(temp_name)
temp_paths.append(temp_path)
return temp_names, temp_paths

+22 -0  modelscope/pipelines/science/__init__.py

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .protein_structure_pipeline import ProteinStructurePipeline

else:
_import_structure = {
'protein_structure_pipeline': ['ProteinStructurePipeline']
}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+215 -0  modelscope/pipelines/science/protein_structure_pipeline.py

@@ -0,0 +1,215 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import time
from typing import Any, Dict, List, Optional, Union

import json
import numpy as np
import torch
from unicore.utils import tensor_tree_map

from modelscope.metainfo import Pipelines
from modelscope.models.base import Model
from modelscope.models.science.unifold.config import model_config
from modelscope.models.science.unifold.data import protein, residue_constants
from modelscope.models.science.unifold.dataset import (UnifoldDataset,
load_and_process)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import Preprocessor, build_preprocessor
from modelscope.utils.constant import Fields, Frameworks, Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.hub import read_config
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['ProteinStructurePipeline']


def automatic_chunk_size(seq_len):
if seq_len < 512:
chunk_size = 256
elif seq_len < 1024:
chunk_size = 128
elif seq_len < 2048:
chunk_size = 32
elif seq_len < 3072:
chunk_size = 16
else:
chunk_size = 1
return chunk_size
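
# For example (chunk size shrinks as the sequence grows, to bound peak memory):
#
#     automatic_chunk_size(400)   # -> 256
#     automatic_chunk_size(1500)  # -> 32
#     automatic_chunk_size(4000)  # -> 1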


def load_feature_for_one_target(
config,
data_folder,
seed=0,
is_multimer=False,
use_uniprot=False,
symmetry_group=None,
):
if not is_multimer:
uniprot_msa_dir = None
sequence_ids = ['A']
if use_uniprot:
uniprot_msa_dir = data_folder

else:
uniprot_msa_dir = data_folder
sequence_ids = open(os.path.join(data_folder,
'chains.txt')).readline().split()

if symmetry_group is None:
batch, _ = load_and_process(
config=config.data,
mode='predict',
seed=seed,
batch_idx=None,
data_idx=0,
is_distillation=False,
sequence_ids=sequence_ids,
monomer_feature_dir=data_folder,
uniprot_msa_dir=uniprot_msa_dir,
)
else:
raise NotImplementedError
batch = UnifoldDataset.collater([batch])
return batch


@PIPELINES.register_module(
Tasks.protein_structure, module_name=Pipelines.protein_structure)
class ProteinStructurePipeline(Pipeline):

def __init__(self,
model: Union[Model, str],
preprocessor: Optional[Preprocessor] = None,
**kwargs):
"""Use `model` and `preprocessor` to create a protein structure pipeline for prediction.

Args:
model (str or Model): Supply either a local model dir which supported the protein structure task,
or a model id from the model hub, or a torch model instance.
preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
the model if supplied.

Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline(task='protein-structure',
>>> model='DPTech/uni-fold-monomer')
>>> protein = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC'
>>> print(pipeline_ins(protein))

"""
import copy
model_path = copy.deepcopy(model) if isinstance(model, str) else None
cfg = read_config(model_path)  # assumes `model` was given as a model-dir path or model id (str)
self.cfg = cfg
self.config = model_config(
cfg['pipeline']['model_name']) # alphafold config
model = model if isinstance(
model, Model) else Model.from_pretrained(model_path)
self.postprocessor = cfg.pop('postprocessor', None)
if preprocessor is None:
preprocessor_cfg = cfg.preprocessor
preprocessor = build_preprocessor(preprocessor_cfg, Fields.science)
model.eval()
model.model.inference_mode()
model.model_dir = model_path
super().__init__(model=model, preprocessor=preprocessor, **kwargs)

def _sanitize_parameters(self, **pipeline_parameters):
return pipeline_parameters, pipeline_parameters, pipeline_parameters

def _process_single(self, input, *args, **kwargs) -> Dict[str, Any]:
preprocess_params = kwargs.get('preprocess_params', {})
forward_params = kwargs.get('forward_params', {})
postprocess_params = kwargs.get('postprocess_params', {})
out = self.preprocess(input, **preprocess_params)
with device_placement(self.framework, self.device_name):
with torch.no_grad():
out = self.forward(out, **forward_params)

out = self.postprocess(out, **postprocess_params)
return out

def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
plddts = {}
ptms = {}

output_dir = os.path.join(self.preprocessor.output_dir_base,
inputs['target_id'])

pdbs = []
for seed in range(self.cfg['pipeline']['times']):
cur_seed = hash((42, seed)) % 100000
batch = load_feature_for_one_target(
self.config,
output_dir,
cur_seed,
is_multimer=inputs['is_multimer'],
use_uniprot=inputs['is_multimer'],
symmetry_group=self.preprocessor.symmetry_group,
)
seq_len = batch['aatype'].shape[-1]
self.model.model.globals.chunk_size = automatic_chunk_size(seq_len)

with torch.no_grad():
batch = {
k: torch.as_tensor(v, device='cuda:0')
for k, v in batch.items()
}
out = self.model(batch)

def to_float(x):
if x.dtype == torch.bfloat16 or x.dtype == torch.half:
return x.float()
else:
return x

# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch)
batch = tensor_tree_map(to_float, batch)
out = tensor_tree_map(lambda t: t[0, ...], out[0])
out = tensor_tree_map(to_float, out)
batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch)
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)

plddt = out['plddt']
mean_plddt = np.mean(plddt)
plddt_b_factors = np.repeat(
plddt[..., None], residue_constants.atom_type_num, axis=-1)
# TODO: may need to reorder chains based on entity_ids
cur_protein = protein.from_prediction(
features=batch, result=out, b_factors=plddt_b_factors)
cur_save_name = str(cur_seed)
plddts[cur_save_name] = str(mean_plddt)
if inputs[
'is_multimer'] and self.preprocessor.symmetry_group is None:
ptms[cur_save_name] = str(np.mean(out['iptm+ptm']))
with open(os.path.join(output_dir, cur_save_name + '.pdb'),
'w') as f:
f.write(protein.to_pdb(cur_protein))
pdbs.append(protein.to_pdb(cur_protein))

logger.info('plddts:' + str(plddts))
model_name = self.cfg['pipeline']['model_name']
score_name = f'{model_name}'
plddt_fname = score_name + '_plddt.json'

with open(os.path.join(output_dir, plddt_fname), 'w') as f:
json.dump(plddts, f, indent=4)
if ptms:
logger.info('ptms:' + str(ptms))
ptm_fname = score_name + '_ptm.json'
with open(os.path.join(output_dir, ptm_fname), 'w') as f:
json.dump(ptms, f, indent=4)

return pdbs

def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params):
return inputs

+20 -0  modelscope/preprocessors/science/__init__.py

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
from .unifold import (UniFoldPreprocessor)

else:
_import_structure = {'unifold': ['UniFoldPreprocessor']}

import sys

sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

+569 -0  modelscope/preprocessors/science/uni_fold.py

@@ -0,0 +1,569 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.

import gzip
import hashlib
import logging
import os
import pickle
import random
import re
import tarfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import json
import numpy as np
import requests
import torch
from tqdm import tqdm

from modelscope.metainfo import Preprocessors
from modelscope.models.science.unifold.data import protein, residue_constants
from modelscope.models.science.unifold.data.protein import PDB_CHAIN_IDS
from modelscope.models.science.unifold.data.utils import compress_features
from modelscope.models.science.unifold.msa import parsers, pipeline, templates
from modelscope.models.science.unifold.msa.tools import hhsearch
from modelscope.models.science.unifold.msa.utils import divide_multi_chains
from modelscope.preprocessors.base import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.constant import Fields

__all__ = [
'UniFoldPreprocessor',
]

TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'
DEFAULT_API_SERVER = 'https://api.colabfold.com'


def run_mmseqs2(
x,
prefix,
use_env=True,
use_templates=False,
use_pairing=False,
host_url='https://api.colabfold.com') -> Tuple[List[str], List[str]]:
submission_endpoint = 'ticket/pair' if use_pairing else 'ticket/msa'

def submit(seqs, mode, N=101):
n, query = N, ''
for seq in seqs:
query += f'>{n}\n{seq}\n'
n += 1

res = requests.post(
f'{host_url}/{submission_endpoint}',
data={
'q': query,
'mode': mode
})
try:
out = res.json()
except ValueError:
out = {'status': 'ERROR'}
return out

def status(ID):
res = requests.get(f'{host_url}/ticket/{ID}')
try:
out = res.json()
except ValueError:
out = {'status': 'ERROR'}
return out

def download(ID, path):
res = requests.get(f'{host_url}/result/download/{ID}')
with open(path, 'wb') as out:
out.write(res.content)

# process input x
seqs = [x] if isinstance(x, str) else x

mode = 'env'
if use_pairing:
mode = ''
use_templates = False
use_env = False

# define path
path = f'{prefix}'
if not os.path.isdir(path):
os.mkdir(path)

# call mmseqs2 api
tar_gz_file = f'{path}/out_{mode}.tar.gz'
N, REDO = 101, True

# deduplicate and keep track of order
seqs_unique = []
# TODO this might be slow for large sets
[seqs_unique.append(x) for x in seqs if x not in seqs_unique]
Ms = [N + seqs_unique.index(seq) for seq in seqs]
# lets do it!
if not os.path.isfile(tar_gz_file):
TIME_ESTIMATE = 150 * len(seqs_unique)
with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
while REDO:
pbar.set_description('SUBMIT')

# Resubmit job until it goes through
out = submit(seqs_unique, mode, N)
while out['status'] in ['UNKNOWN', 'RATELIMIT']:
sleep_time = 5 + random.randint(0, 5)
# logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
# resubmit
time.sleep(sleep_time)
out = submit(seqs_unique, mode, N)

if out['status'] == 'ERROR':
error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.'
error = error + ' If the error persists, please try again an hour later.'
raise Exception(error)

if out['status'] == 'MAINTENANCE':
raise Exception(
'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.'
)

# wait for job to finish
ID, TIME = out['id'], 0
pbar.set_description(out['status'])
while out['status'] in ['UNKNOWN', 'RUNNING', 'PENDING']:
t = 5 + random.randint(0, 5)
# logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
time.sleep(t)
out = status(ID)
pbar.set_description(out['status'])
if out['status'] == 'RUNNING':
TIME += t
pbar.update(n=t)

if out['status'] == 'COMPLETE':
if TIME < TIME_ESTIMATE:
pbar.update(n=(TIME_ESTIMATE - TIME))
REDO = False

if out['status'] == 'ERROR':
REDO = False
error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.'
error = error + ' If the error persists, please try again an hour later.'
raise Exception(error)

# Download results
download(ID, tar_gz_file)

# prep list of a3m files
if use_pairing:
a3m_files = [f'{path}/pair.a3m']
else:
a3m_files = [f'{path}/uniref.a3m']
if use_env:
a3m_files.append(f'{path}/bfd.mgnify30.metaeuk30.smag30.a3m')

# extract a3m files
if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
with tarfile.open(tar_gz_file) as tar_gz:
tar_gz.extractall(path)

# templates
if use_templates:
templates = {}

with open(f'{path}/pdb70.m8', 'r') as f:
lines = f.readlines()
for line in lines:
p = line.rstrip().split()
M, pdb, _, _ = p[0], p[1], p[2], p[10] # qid, e_value
M = int(M)
if M not in templates:
templates[M] = []
templates[M].append(pdb)

template_paths = {}
for k, TMPL in templates.items():
TMPL_PATH = f'{prefix}/templates_{k}'
if not os.path.isdir(TMPL_PATH):
os.mkdir(TMPL_PATH)
TMPL_LINE = ','.join(TMPL[:20])
os.system(
f'curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/'
)
os.system(
f'cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex'
)
os.system(f'touch {TMPL_PATH}/pdb70_cs219.ffdata')
template_paths[k] = TMPL_PATH

# gather a3m lines
a3m_lines = {}
for a3m_file in a3m_files:
update_M, M = True, None
with open(a3m_file, 'r') as f:
lines = f.readlines()
for line in lines:
if len(line) > 0:
if '\x00' in line:
line = line.replace('\x00', '')
update_M = True
if line.startswith('>') and update_M:
M = int(line[1:].rstrip())
update_M = False
if M not in a3m_lines:
a3m_lines[M] = []
a3m_lines[M].append(line)

# return results

a3m_lines = [''.join(a3m_lines[n]) for n in Ms]

if use_templates:
template_paths_ = []
for n in Ms:
if n not in template_paths:
template_paths_.append(None)
# print(f"{n-N}\tno_templates_found")
else:
template_paths_.append(template_paths[n])
template_paths = template_paths_

return (a3m_lines, template_paths) if use_templates else a3m_lines
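
# Usage sketch (illustrative; requires network access to the public ColabFold
# API at DEFAULT_API_SERVER, and both the sequence and prefix are placeholders):
#
#     a3m_lines = run_mmseqs2('MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ',
#                             prefix='./mmseqs2_out', use_templates=False)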


def get_null_template(query_sequence: Union[List[str], str],
num_temp: int = 1) -> Dict[str, Any]:
ln = (
len(query_sequence) if isinstance(query_sequence, str) else sum(
len(s) for s in query_sequence))
output_templates_sequence = 'A' * ln
# output_confidence_scores = np.full(ln, 1.0)

templates_all_atom_positions = np.zeros(
(ln, templates.residue_constants.atom_type_num, 3))
templates_all_atom_masks = np.zeros(
(ln, templates.residue_constants.atom_type_num))
templates_aatype = templates.residue_constants.sequence_to_onehot(
output_templates_sequence,
templates.residue_constants.HHBLITS_AA_TO_ID)
template_features = {
'template_all_atom_positions':
np.tile(templates_all_atom_positions[None], [num_temp, 1, 1, 1]),
'template_all_atom_masks':
np.tile(templates_all_atom_masks[None], [num_temp, 1, 1]),
'template_sequence': ['none'.encode()] * num_temp,
'template_aatype':
np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]),
'template_domain_names': ['none'.encode()] * num_temp,
'template_sum_probs':
np.zeros([num_temp], dtype=np.float32),
}
return template_features


def get_template(a3m_lines: str, template_path: str,
query_sequence: str) -> Dict[str, Any]:
template_featurizer = templates.HhsearchHitFeaturizer(
mmcif_dir=template_path,
max_template_date='2100-01-01',
max_hits=20,
kalign_binary_path='kalign',
release_dates_path=None,
obsolete_pdbs_path=None,
)

hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path='hhsearch', databases=[f'{template_path}/pdb70'])

hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines)
hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result)
templates_result = template_featurizer.get_templates(
query_sequence=query_sequence, hits=hhsearch_hits)
return dict(templates_result.features)


@PREPROCESSORS.register_module(
Fields.science, module_name=Preprocessors.unifold_preprocessor)
class UniFoldPreprocessor(Preprocessor):

def __init__(self, **cfg):
self.symmetry_group = cfg.get('symmetry_group')  # e.g. "C1"
if not self.symmetry_group:
self.symmetry_group = None
self.MIN_SINGLE_SEQUENCE_LENGTH = 16 # TODO: change to cfg
self.MAX_SINGLE_SEQUENCE_LENGTH = 1000
self.MAX_MULTIMER_LENGTH = 1000
self.jobname = 'unifold'
self.output_dir_base = './unifold-predictions'
os.makedirs(self.output_dir_base, exist_ok=True)

def clean_and_validate_sequence(self, input_sequence: str, min_length: int,
max_length: int) -> str:
clean_sequence = input_sequence.translate(
str.maketrans('', '', ' \n\t')).upper()
aatypes = set(residue_constants.restypes) # 20 standard aatypes.
if not set(clean_sequence).issubset(aatypes):
raise ValueError(
f'Input sequence contains non-amino acid letters: '
f'{set(clean_sequence) - aatypes}. Uni-Fold only supports the 20 standard '
'amino acids as inputs.')
if len(clean_sequence) < min_length:
raise ValueError(
f'Input sequence is too short: {len(clean_sequence)} amino acids, '
f'while the minimum is {min_length}')
if len(clean_sequence) > max_length:
raise ValueError(
f'Input sequence is too long: {len(clean_sequence)} amino acids, while '
f'the maximum is {max_length}. You may be able to run it with the full '
f'Uni-Fold system depending on your resources (system memory, '
f'GPU memory).')
return clean_sequence

def validate_input(self, input_sequences: Sequence[str],
symmetry_group: str, min_length: int, max_length: int,
max_multimer_length: int) -> Tuple[Sequence[str], bool, Optional[str]]:
"""Validates and cleans input sequences and determines which model to use."""
sequences = []

for input_sequence in input_sequences:
if input_sequence.strip():
input_sequence = self.clean_and_validate_sequence(
input_sequence=input_sequence,
min_length=min_length,
max_length=max_length)
sequences.append(input_sequence)

if symmetry_group is not None and symmetry_group != 'C1':
if symmetry_group.startswith(
'C') and symmetry_group[1:].isnumeric():
print(
f'Using UF-Symmetry with group {symmetry_group}. If you do not '
f'want to use UF-Symmetry, please use `C1` and replicate the '
f'asymmetric-unit (AU) sequences to match the full assembly.')
is_multimer = (len(sequences) > 1)
return sequences, is_multimer, symmetry_group
else:
raise ValueError(
f'UF-Symmetry does not support symmetry group '
f'{symmetry_group} currently. Only cyclic groups (Cx) are '
f'supported.')

elif len(sequences) == 1:
print('Using the single-chain model.')
return sequences, False, None

elif len(sequences) > 1:
total_multimer_length = sum([len(seq) for seq in sequences])
if total_multimer_length > max_multimer_length:
raise ValueError(
f'The total length of multimer sequences is too long: '
f'{total_multimer_length}, while the maximum is '
f'{max_multimer_length}. Please use the full Uni-Fold '
f'system for long multimers.')
print(f'Using the multimer model with {len(sequences)} sequences.')
return sequences, True, None

else:
raise ValueError(
'No input amino acid sequence provided, please provide at '
'least one sequence.')

def add_hash(self, x, y):
return x + '_' + hashlib.sha1(y.encode()).hexdigest()[:5]

def get_msa_and_templates(
self,
jobname: str,
query_seqs_unique: Union[str, List[str]],
result_dir: Path,
msa_mode: str,
use_templates: bool,
homooligomers_num: int = 1,
host_url: str = DEFAULT_API_SERVER,
) -> Tuple[List[str], List[str], List[Dict[str, Any]]]:

use_env = msa_mode == 'MMseqs2'  # full MMseqs2 mode also searches the environmental (BFD/MGnify) databases

template_features = []
if use_templates:
a3m_lines_mmseqs2, template_paths = run_mmseqs2(
query_seqs_unique,
str(result_dir.joinpath(jobname)),
use_env,
use_templates=True,
host_url=host_url,
)
if template_paths is None:
for index in range(0, len(query_seqs_unique)):
template_feature = get_null_template(
query_seqs_unique[index])
template_features.append(template_feature)
else:
for index in range(0, len(query_seqs_unique)):
if template_paths[index] is not None:
template_feature = get_template(
a3m_lines_mmseqs2[index],
template_paths[index],
query_seqs_unique[index],
)
if len(template_feature['template_domain_names']) == 0:
template_feature = get_null_template(
query_seqs_unique[index])
else:
template_feature = get_null_template(
query_seqs_unique[index])
template_features.append(template_feature)
else:
for index in range(0, len(query_seqs_unique)):
template_feature = get_null_template(query_seqs_unique[index])
template_features.append(template_feature)

if msa_mode == 'single_sequence':
a3m_lines = []
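# fabricate minimal single-sequence a3m entries; headers are numbered from 101,
# consistent with the homooligomer branch below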
num = 101
for i, seq in enumerate(query_seqs_unique):
a3m_lines.append('>' + str(num + i) + '\n' + seq)
else:
# find normal a3ms
a3m_lines = run_mmseqs2(
query_seqs_unique,
str(result_dir.joinpath(jobname)),
use_env,
use_pairing=False,
host_url=host_url,
)
if len(query_seqs_unique) > 1:
# find paired a3m if not a homooligomer
paired_a3m_lines = run_mmseqs2(
query_seqs_unique,
str(result_dir.joinpath(jobname)),
use_env,
use_pairing=True,
host_url=host_url,
)
else:
num = 101
paired_a3m_lines = []
for i in range(0, homooligomers_num):
paired_a3m_lines.append('>' + str(num + i) + '\n'
+ query_seqs_unique[0] + '\n')

return (
a3m_lines,
paired_a3m_lines,
template_features,
)

def __call__(self, data: Union[str, Tuple]):
if isinstance(data, str):
data = [data, '', '', '']
basejobname = ''.join(data)
basejobname = re.sub(r'\W+', '', basejobname)
target_id = self.add_hash(self.jobname, basejobname)

sequences, is_multimer, _ = self.validate_input(
input_sequences=data,
symmetry_group=self.symmetry_group,
min_length=self.MIN_SINGLE_SEQUENCE_LENGTH,
max_length=self.MAX_SINGLE_SEQUENCE_LENGTH,
max_multimer_length=self.MAX_MULTIMER_LENGTH)

descriptions = [
'> ' + target_id + ' seq' + str(ii)
for ii in range(len(sequences))
]

if is_multimer:
divide_multi_chains(target_id, self.output_dir_base, sequences,
descriptions)

s = []
for des, seq in zip(descriptions, sequences):
s += [des, seq]

# deduplicate while preserving first-occurrence order
unique_sequences = list(dict.fromkeys(sequences))

if len(unique_sequences) == 1:
homooligomers_num = len(sequences)
else:
homooligomers_num = 1

with open(f'{self.jobname}.fasta', 'w') as f:
f.write('\n'.join(s))

result_dir = Path(self.output_dir_base)
output_dir = os.path.join(self.output_dir_base, target_id)
os.makedirs(output_dir, exist_ok=True)

# msa_mode = 'single_sequence'
msa_mode = 'MMseqs2'
use_templates = True

unpaired_msa, paired_msa, template_results = self.get_msa_and_templates(
target_id,
unique_sequences,
result_dir=result_dir,
msa_mode=msa_mode,
use_templates=use_templates,
homooligomers_num=homooligomers_num)

features = []
pair_features = []

for idx, seq in enumerate(unique_sequences):
chain_id = PDB_CHAIN_IDS[idx]
sequence_features = pipeline.make_sequence_features(
sequence=seq,
description=f'> {self.jobname} seq {chain_id}',
num_res=len(seq))
monomer_msa = parsers.parse_a3m(unpaired_msa[idx])
msa_features = pipeline.make_msa_features([monomer_msa])
template_features = template_results[idx]
feature_dict = {
**sequence_features,
**msa_features,
**template_features
}
feature_dict = compress_features(feature_dict)
features_output_path = os.path.join(
output_dir, '{}.feature.pkl.gz'.format(chain_id))
pickle.dump(
feature_dict,
gzip.GzipFile(features_output_path, 'wb'),
protocol=4)
features.append(feature_dict)

if is_multimer:
multimer_msa = parsers.parse_a3m(paired_msa[idx])
pair_msa_features = pipeline.make_msa_features([multimer_msa])
pair_feature_dict = compress_features(pair_msa_features)
uniprot_output_path = os.path.join(
output_dir, '{}.uniprot.pkl.gz'.format(chain_id))
pickle.dump(
pair_feature_dict,
gzip.GzipFile(uniprot_output_path, 'wb'),
protocol=4,
)
pair_features.append(pair_feature_dict)

# return features, pair_features, target_id
return {
'features': features,
'pair_features': pair_features,
'target_id': target_id,
'is_multimer': is_multimer,
}


if __name__ == '__main__':
proc = UniFoldPreprocessor(symmetry_group='C1')
protein_example = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + \
'TVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI'
outputs = proc(protein_example)
print(outputs['target_id'], outputs['is_multimer'])

+ 10
- 1
modelscope/utils/constant.py

@@ -9,6 +9,7 @@ class Fields(object):
 nlp = 'nlp'
 audio = 'audio'
 multi_modal = 'multi-modal'
+science = 'science'
 
 
 class CVTasks(object):
@@ -151,6 +152,10 @@ class MultiModalTasks(object):
 image_text_retrieval = 'image-text-retrieval'
 
 
+class ScienceTasks(object):
+    protein_structure = 'protein-structure'
+
+
 class TasksIODescriptions(object):
 image_to_image = 'image_to_image',
 images_to_image = 'images_to_image',
@@ -167,7 +172,7 @@ class TasksIODescriptions(object):
 generative_multi_modal_embedding = 'generative_multi_modal_embedding'
 
 
-class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks):
+class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks):
 """ Names for tasks supported by modelscope.
 
 Holds the standard task name to use for identifying different tasks.
@@ -196,6 +201,10 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks):
 getattr(Tasks, attr) for attr in dir(MultiModalTasks)
 if not attr.startswith('__')
 ],
+Fields.science: [
+    getattr(Tasks, attr) for attr in dir(ScienceTasks)
+    if not attr.startswith('__')
+],
 }
 
 for field, tasks in field_dict.items():


+ 6
- 0
requirements/science.txt

@@ -0,0 +1,6 @@
iopath
lmdb
ml_collections
scipy
tensorboardX
tokenizers

+ 34
- 0
tests/pipelines/test_unifold.py

@@ -0,0 +1,34 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck):

def setUp(self) -> None:
self.task = Tasks.protein_structure
self.model_id = 'DPTech/uni-fold-monomer'
self.model_id_multimer = 'DPTech/uni-fold-multimer'

self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD'
self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \
'NIAALKNHIDKIKPIAMQIYKKYSKNIP'

@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
model_dir = snapshot_download(self.model_id)
mono_pipeline_ins = pipeline(task=self.task, model=model_dir)
_ = mono_pipeline_ins(self.protein)

model_dir1 = snapshot_download(self.model_id_multimer)
multi_pipeline_ins = pipeline(task=self.task, model=model_dir1)
_ = multi_pipeline_ins(self.protein_multimer)


if __name__ == '__main__':
unittest.main()
