* mglm init
* add mglm requirements

Co-authored-by: Yufeng <zhuyufeng@gmail.com>
Co-authored-by: wenmeng.zwm <wenmeng.zwm@alibaba-inc.com>
| @@ -82,6 +82,7 @@ class Models(object): | |||||
| bert_for_ds = 'bert-for-document-segmentation' | bert_for_ds = 'bert-for-document-segmentation' | ||||
| ponet = 'ponet' | ponet = 'ponet' | ||||
| T5 = 'T5' | T5 = 'T5' | ||||
| mglm = 'mglm' | |||||
| bloom = 'bloom' | bloom = 'bloom' | ||||
| # audio models | # audio models | ||||
| @@ -251,6 +252,7 @@ class Pipelines(object): | |||||
| relation_extraction = 'relation-extraction' | relation_extraction = 'relation-extraction' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| mglm_text_summarization = 'mglm-text-summarization' | |||||
| translation_en_to_de = 'translation_en_to_de' # keep it underscore | translation_en_to_de = 'translation_en_to_de' # keep it underscore | ||||
| translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | ||||
| translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | ||||
| @@ -376,6 +378,7 @@ class Preprocessors(object): | |||||
| re_tokenizer = 're-tokenizer' | re_tokenizer = 're-tokenizer' | ||||
| document_segmentation = 'document-segmentation' | document_segmentation = 'document-segmentation' | ||||
| feature_extraction = 'feature-extraction' | feature_extraction = 'feature-extraction' | ||||
| mglm_summarization = 'mglm-summarization' | |||||
| sentence_piece = 'sentence-piece' | sentence_piece = 'sentence-piece' | ||||
| # audio preprocessor | # audio preprocessor | ||||
| @@ -35,6 +35,7 @@ if TYPE_CHECKING: | |||||
| SbertTokenizerFast, | SbertTokenizerFast, | ||||
| ) | ) | ||||
| from .T5 import T5ForConditionalGeneration | from .T5 import T5ForConditionalGeneration | ||||
| from .mglm import MGLMForTextSummarization | |||||
| from .task_models import ( | from .task_models import ( | ||||
| FeatureExtractionModel, | FeatureExtractionModel, | ||||
| InformationExtractionModel, | InformationExtractionModel, | ||||
| @@ -106,6 +107,7 @@ else: | |||||
| ], | ], | ||||
| 'sentence_embedding': ['SentenceEmbedding'], | 'sentence_embedding': ['SentenceEmbedding'], | ||||
| 'T5': ['T5ForConditionalGeneration'], | 'T5': ['T5ForConditionalGeneration'], | ||||
| 'mglm': ['MGLMForTextSummarization'], | |||||
| 'gpt_neo': ['GPTNeoModel'], | 'gpt_neo': ['GPTNeoModel'], | ||||
| 'bloom': ['BloomModel'], | 'bloom': ['BloomModel'], | ||||
| } | } | ||||
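With the registry entries above in place, the new model is reachable through the standard ModelScope pipeline factory. A minimal usage sketch (the task string and the model id below are illustrative assumptions, not taken from this diff):

    from modelscope.pipelines import pipeline
    from modelscope.metainfo import Pipelines

    # '<mglm-summarization-model-id>' is a hypothetical placeholder for an mGLM checkpoint
    summarizer = pipeline(
        task='text-summarization',
        pipeline_name=Pipelines.mglm_text_summarization,
        model='<mglm-summarization-model-id>')
    print(summarizer('A long multilingual document to be summarized ...'))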
| @@ -0,0 +1,22 @@ | |||||
| # Modified by Zhipu.AI | |||||
| # Original Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .mglm_for_text_summarization import MGLMForTextSummarization | |||||
| else: | |||||
| _import_structure = { | |||||
| 'mglm_for_text_summarization': ['MGLMForTextSummarization'], | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
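This follows the lazy-import convention used elsewhere in modelscope: the package module object is replaced by a LazyImportModule, so the torch-heavy submodule is only imported when one of its attributes is first accessed. A rough sketch of the observable behaviour (the import path is assumed for illustration):

    # importing the package does not yet load the summarization submodule
    from modelscope.models.nlp import mglm

    # mglm_for_text_summarization is imported on first attribute access
    model_cls = mglm.MGLMForTextSummarization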
| @@ -0,0 +1,793 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """argparser configuration""" | |||||
| import argparse | |||||
| import os | |||||
| import deepspeed | |||||
| import json | |||||
| import torch | |||||
| from .utils import get_hostname | |||||
| def add_model_config_args(parser): | |||||
| """Model arguments""" | |||||
| group = parser.add_argument_group('model', 'model configuration') | |||||
| group.add_argument( | |||||
| '--transformer-xl', | |||||
| action='store_true', | |||||
| help='use transformer-xl for training') | |||||
| group.add_argument( | |||||
| '--pretrained-bert', | |||||
| action='store_true', | |||||
| help='use a pretrained bert-large-uncased model instead ' | |||||
| 'of initializing from scratch. See ' | |||||
| '--tokenizer-model-type to specify which pretrained ' | |||||
| 'BERT model to use') | |||||
| group.add_argument( | |||||
| '--encoder-decoder', | |||||
| action='store_true', | |||||
| help='use the encoder-decoder architecture for blocklm') | |||||
| group.add_argument( | |||||
| '--attention-dropout', | |||||
| type=float, | |||||
| default=0.1, | |||||
| help='dropout probability for attention weights') | |||||
| group.add_argument( | |||||
| '--num-attention-heads', | |||||
| type=int, | |||||
| default=16, | |||||
| help='num of transformer attention heads') | |||||
| group.add_argument( | |||||
| '--hidden-size', type=int, default=1024, help='transformer hidden size') | |||||
| group.add_argument( | |||||
| '--intermediate-size', | |||||
| type=int, | |||||
| default=None, | |||||
| help='transformer embedding dimension for FFN; ' | |||||
| 'set to 4*`--hidden-size` if it is None') | |||||
| group.add_argument( | |||||
| '--num-layers', type=int, default=24, help='num decoder layers') | |||||
| group.add_argument( | |||||
| '--layernorm-epsilon', | |||||
| type=float, | |||||
| default=1e-5, | |||||
| help='layer norm epsilon') | |||||
| group.add_argument( | |||||
| '--hidden-dropout', | |||||
| type=float, | |||||
| default=0.1, | |||||
| help='dropout probability for hidden state transformer') | |||||
| group.add_argument( | |||||
| '--output-dropout', | |||||
| type=float, | |||||
| default=0.1, | |||||
| help='dropout probability for pooled output') | |||||
| group.add_argument( | |||||
| '--max-position-embeddings', | |||||
| type=int, | |||||
| default=512, | |||||
| help='maximum number of position embeddings to use') | |||||
| group.add_argument( | |||||
| '--vocab-size', | |||||
| type=int, | |||||
| default=250112, | |||||
| help='vocab size to use for non-character-level ' | |||||
| 'tokenization. This value will only be used when ' | |||||
| 'creating a tokenizer') | |||||
| group.add_argument( | |||||
| '--deep-init', | |||||
| action='store_true', | |||||
| help='initialize bert model similar to gpt2 model. ' | |||||
| 'scales initialization of projection layers by a ' | |||||
| 'factor of 1/sqrt(2N). Necessary to train bert ' | |||||
| 'models larger than BERT-Large.') | |||||
| group.add_argument( | |||||
| '--make-vocab-size-divisible-by', | |||||
| type=int, | |||||
| default=128, | |||||
| help='Pad the vocab size to be divisible by this value. ' | |||||
| 'This is added for computational efficiency reasons.') | |||||
| group.add_argument( | |||||
| '--cpu-optimizer', action='store_true', help='Run optimizer on CPU') | |||||
| group.add_argument( | |||||
| '--cpu_torch_adam', | |||||
| action='store_true', | |||||
| help='Use Torch Adam as optimizer on CPU.') | |||||
| return parser | |||||
| def add_fp16_config_args(parser): | |||||
| """Mixed precision arguments.""" | |||||
| group = parser.add_argument_group('fp16', 'fp16 configurations') | |||||
| group.add_argument( | |||||
| '--fp16', action='store_true', help='Run model in fp16 mode') | |||||
| group.add_argument( | |||||
| '--fp32-embedding', action='store_true', help='embedding in fp32') | |||||
| group.add_argument( | |||||
| '--fp32-layernorm', action='store_true', help='layer norm in fp32') | |||||
| group.add_argument( | |||||
| '--fp32-tokentypes', | |||||
| action='store_true', | |||||
| help='embedding token types in fp32') | |||||
| group.add_argument( | |||||
| '--fp32-allreduce', action='store_true', help='all-reduce in fp32') | |||||
| group.add_argument( | |||||
| '--hysteresis', | |||||
| type=int, | |||||
| default=2, | |||||
| help='hysteresis for dynamic loss scaling') | |||||
| group.add_argument( | |||||
| '--loss-scale', | |||||
| type=float, | |||||
| default=None, | |||||
| help='Static loss scaling, positive power of 2 ' | |||||
| 'values can improve fp16 convergence. If None, dynamic ' | |||||
| 'loss scaling is used.') | |||||
| group.add_argument( | |||||
| '--loss-scale-window', | |||||
| type=float, | |||||
| default=1000, | |||||
| help='Window over which to raise/lower dynamic scale') | |||||
| group.add_argument( | |||||
| '--min-scale', | |||||
| type=float, | |||||
| default=1, | |||||
| help='Minimum loss scale for dynamic loss scale') | |||||
| group.add_argument('--attention-scale', type=float, default=1.0) | |||||
| return parser | |||||
| def add_training_args(parser): | |||||
| """Training arguments.""" | |||||
| group = parser.add_argument_group('train', 'training configurations') | |||||
| group.add_argument( | |||||
| '--experiment-name', | |||||
| type=str, | |||||
| default='gpt-345M', | |||||
| help='The experiment name for summary and checkpoint') | |||||
| group.add_argument( | |||||
| '--batch-size', type=int, default=4, help='Data Loader batch size') | |||||
| group.add_argument( | |||||
| '--gradient-accumulation-steps', | |||||
| type=int, | |||||
| default=1, | |||||
| help='Number of gradient accumulation steps') | |||||
| group.add_argument( | |||||
| '--weight-decay', | |||||
| type=float, | |||||
| default=0.01, | |||||
| help='weight decay coefficient for L2 regularization') | |||||
| group.add_argument( | |||||
| '--checkpoint-activations', | |||||
| action='store_true', | |||||
| help='checkpoint activation to allow for training ' | |||||
| 'with larger models and sequences') | |||||
| group.add_argument( | |||||
| '--checkpoint-num-layers', | |||||
| type=int, | |||||
| default=1, | |||||
| help='chunk size (number of layers) for checkpointing') | |||||
| group.add_argument( | |||||
| '--deepspeed-activation-checkpointing', | |||||
| action='store_true', | |||||
| help='uses activation checkpointing from deepspeed') | |||||
| group.add_argument( | |||||
| '--epochs', | |||||
| type=int, | |||||
| default=None, | |||||
| help='Number of finetuning epochs. Zero results in evaluation only.') | |||||
| group.add_argument( | |||||
| '--clip-grad', type=float, default=1.0, help='gradient clipping') | |||||
| group.add_argument( | |||||
| '--train-iters', | |||||
| type=int, | |||||
| default=0, | |||||
| help='total number of iterations to train over all training runs') | |||||
| group.add_argument('--label-smoothing', type=float, default=0.0) | |||||
| group.add_argument( | |||||
| '--log-interval', type=int, default=100, help='report interval') | |||||
| group.add_argument( | |||||
| '--summary-dir', | |||||
| type=str, | |||||
| default='', | |||||
| help='The directory to store the summary') | |||||
| group.add_argument('--seed', type=int, default=1234, help='random seed') | |||||
| # Batch producer arguments | |||||
| group.add_argument( | |||||
| '--reset-position-ids', | |||||
| action='store_true', | |||||
| help='Reset position ids after end-of-document token.') | |||||
| group.add_argument( | |||||
| '--reset-attention-mask', | |||||
| action='store_true', | |||||
| help='Reset self attention mask after ' | |||||
| 'end-of-document token.') | |||||
| # Learning rate. | |||||
| group.add_argument( | |||||
| '--lr-decay-iters', | |||||
| type=int, | |||||
| default=None, | |||||
| help='number of iterations to decay LR over,' | |||||
| ' If None defaults to `--train-iters`*`--epochs`') | |||||
| group.add_argument( | |||||
| '--lr-decay-style', | |||||
| type=str, | |||||
| default='linear', | |||||
| choices=['constant', 'linear', 'cosine', 'exponential'], | |||||
| help='learning rate decay function') | |||||
| group.add_argument('--lr-decay-ratio', type=float, default=0.1) | |||||
| group.add_argument( | |||||
| '--lr', type=float, default=1.0e-4, help='initial learning rate') | |||||
| group.add_argument( | |||||
| '--warmup', | |||||
| type=float, | |||||
| default=0.01, | |||||
| help='percentage of data to warmup on (.01 = 1%% of all ' | |||||
| 'training iters). Default 0.01') | |||||
| group.add_argument( | |||||
| '--switch-linear', | |||||
| action='store_true', | |||||
| help='Switch to linear decay for cosine decay') | |||||
| # model checkpointing | |||||
| group.add_argument( | |||||
| '--save', | |||||
| type=str, | |||||
| default=None, | |||||
| help='Output directory to save checkpoints to.') | |||||
| group.add_argument('--new-save-directory', action='store_true') | |||||
| group.add_argument( | |||||
| '--save-epoch', | |||||
| type=int, | |||||
| default=1, | |||||
| help='number of epochs between saves') | |||||
| group.add_argument( | |||||
| '--save-interval', | |||||
| type=int, | |||||
| default=5000, | |||||
| help='number of iterations between saves') | |||||
| group.add_argument( | |||||
| '--no-save-optim', | |||||
| action='store_true', | |||||
| help='Do not save current optimizer.') | |||||
| group.add_argument( | |||||
| '--no-save-rng', | |||||
| action='store_true', | |||||
| help='Do not save current rng state.') | |||||
| group.add_argument( | |||||
| '--load', | |||||
| type=str, | |||||
| default=None, | |||||
| help='Path to a directory containing a model checkpoint.') | |||||
| group.add_argument( | |||||
| '--no-load-optim', | |||||
| action='store_true', | |||||
| help='Do not load optimizer when loading checkpoint.') | |||||
| group.add_argument( | |||||
| '--no-load-rng', | |||||
| action='store_true', | |||||
| help='Do not load rng state when loading checkpoint.') | |||||
| group.add_argument( | |||||
| '--no-load-lr-scheduler', | |||||
| action='store_true', | |||||
| help='Do not load lr scheduler when loading checkpoint.') | |||||
| group.add_argument( | |||||
| '--no-deepspeed-load', | |||||
| action='store_true', | |||||
| help='Do not use deepspeed when loading checkpoint') | |||||
| group.add_argument( | |||||
| '--finetune', | |||||
| action='store_true', | |||||
| help='Load model for finetuning. Do not load optimizer ' | |||||
| 'or rng state from checkpoint and set iteration to 0. ' | |||||
| 'Assumed when loading a release checkpoint.') | |||||
| group.add_argument( | |||||
| '--resume-dataloader', | |||||
| action='store_true', | |||||
| help='Resume the dataloader when resuming training. ' | |||||
| 'Does not apply to tfrecords dataloader, try resuming ' | |||||
| 'with a different seed in this case.') | |||||
| # distributed training args | |||||
| group.add_argument( | |||||
| '--distributed-backend', | |||||
| default='nccl', | |||||
| help= | |||||
| 'which backend to use for distributed training. One of [gloo, nccl]', | |||||
| choices=['nccl', 'gloo']) | |||||
| group.add_argument( | |||||
| '--DDP-impl', | |||||
| default='torch', | |||||
| choices=['local', 'torch', 'none'], | |||||
| help='which DistributedDataParallel implementation to use.') | |||||
| group.add_argument( | |||||
| '--local_rank', | |||||
| type=int, | |||||
| default=None, | |||||
| help='local rank passed from distributed launcher') | |||||
| # BlockLM training args | |||||
| group.add_argument( | |||||
| '--block-lm', | |||||
| action='store_true', | |||||
| help='whether use the BlockLM pre-training') | |||||
| group.add_argument( | |||||
| '--masked-lm', | |||||
| action='store_true', | |||||
| help='whether to use the mlm objective') | |||||
| group.add_argument('--bert-prob', type=float, default=0.5) | |||||
| group.add_argument('--gpt-infill-prob', type=float, default=0.5) | |||||
| group.add_argument('--gpt-min-ratio', type=float, default=0.5) | |||||
| group.add_argument('--gap-sentence-prob', type=float, default=0.0) | |||||
| group.add_argument('--gap-sentence-ratio', type=float, default=0.15) | |||||
| group.add_argument('--avg-block-length', type=int, default=3) | |||||
| group.add_argument('--short-seq-prob', type=float, default=0.0) | |||||
| group.add_argument('--single-span-prob', type=float, default=0.0) | |||||
| group.add_argument( | |||||
| '--task-mask', | |||||
| action='store_true', | |||||
| help='Use different mask for generation and blank filling') | |||||
| group.add_argument( | |||||
| '--no-shuffle-block', | |||||
| action='store_true', | |||||
| help='Do not shuffle the blocks when filling the blank') | |||||
| group.add_argument( | |||||
| '--no-block-position', | |||||
| action='store_true', | |||||
| help='Use (rough) absolute positions instead of block positions') | |||||
| group.add_argument( | |||||
| '--sentinel-token', | |||||
| action='store_true', | |||||
| help='Use sentinel (mask) tokens to replace 2d position encoding') | |||||
| group.add_argument('--block-mask-prob', type=float, default=0.0) | |||||
| group.add_argument('--context-mask-ratio', type=float, default=0.0) | |||||
| group.add_argument( | |||||
| '--random-position', | |||||
| action='store_true', | |||||
| help='Use random start position to cover all the position embeddings') | |||||
| return parser | |||||
| def add_evaluation_args(parser): | |||||
| """Evaluation arguments.""" | |||||
| group = parser.add_argument_group('validation', | |||||
| 'validation configurations') | |||||
| group.add_argument( | |||||
| '--eval-batch-size', | |||||
| type=int, | |||||
| default=None, | |||||
| help='Data Loader batch size for evaluation datasets. ' | |||||
| 'Defaults to `--batch-size`') | |||||
| group.add_argument( | |||||
| '--eval-iters', | |||||
| type=int, | |||||
| default=100, | |||||
| help='number of iterations to run for evaluation on ' | |||||
| 'the validation/test set') | |||||
| group.add_argument( | |||||
| '--eval-interval', | |||||
| type=int, | |||||
| default=1000, | |||||
| help='interval between running evaluation on validation set') | |||||
| group.add_argument( | |||||
| '--eval-epoch', | |||||
| type=int, | |||||
| default=1, | |||||
| help='number of epochs between running evaluation on validation set') | |||||
| group.add_argument( | |||||
| '--eval-seq-length', | |||||
| type=int, | |||||
| default=None, | |||||
| help='Maximum sequence length to process for ' | |||||
| 'evaluation. Defaults to `--seq-length`') | |||||
| group.add_argument( | |||||
| '--eval-max-preds-per-seq', | |||||
| type=int, | |||||
| default=None, | |||||
| help='Maximum number of predictions to use for ' | |||||
| 'evaluation. Defaults to ' | |||||
| 'math.ceil(`--eval-seq-length`*.15/10)*10') | |||||
| group.add_argument('--overlapping-eval', type=int, default=32) | |||||
| return parser | |||||
| def add_text_generate_args(parser): | |||||
| """Text generate arguments.""" | |||||
| group = parser.add_argument_group('Text generation', 'configurations') | |||||
| group.add_argument('--temperature', type=float, default=1.0) | |||||
| group.add_argument('--top_p', type=float, default=0.0) | |||||
| group.add_argument('--top_k', type=int, default=0) | |||||
| group.add_argument('--out-seq-length', type=int, default=256) | |||||
| group.add_argument('--num-beams', type=int, default=1) | |||||
| group.add_argument('--length-penalty', type=float, default=0.0) | |||||
| group.add_argument('--no-repeat-ngram-size', type=int, default=0) | |||||
| group.add_argument('--min-tgt-length', type=int, default=0) | |||||
| group.add_argument('--select-topk', action='store_true') | |||||
| group.add_argument('--blank-maskratio', type=float, default=0.1) | |||||
| return parser | |||||
| def add_data_args(parser): | |||||
| """Train/valid/test data arguments.""" | |||||
| group = parser.add_argument_group('data', 'data configurations') | |||||
| group.add_argument( | |||||
| '--model-parallel-size', | |||||
| type=int, | |||||
| default=1, | |||||
| help='size of the model parallel.') | |||||
| group.add_argument( | |||||
| '--shuffle', | |||||
| action='store_true', | |||||
| help='Shuffle data. Shuffling is deterministic ' | |||||
| 'based on seed and current epoch.') | |||||
| group.add_argument('--filter-english', action='store_true') | |||||
| group.add_argument( | |||||
| '--train-data', | |||||
| nargs='+', | |||||
| default=None, | |||||
| help='Whitespace separated filenames or corpora names ' | |||||
| 'for training.') | |||||
| group.add_argument( | |||||
| '--valid-data', | |||||
| nargs='*', | |||||
| default=None, | |||||
| help="""Filename for validation data.""") | |||||
| group.add_argument( | |||||
| '--test-data', | |||||
| nargs='*', | |||||
| default=None, | |||||
| help="""Filename for testing""") | |||||
| group.add_argument( | |||||
| '--data-dir', | |||||
| type=str, | |||||
| default=None, | |||||
| help='The data path to all the data files') | |||||
| group.add_argument( | |||||
| '--input-data-sizes-file', | |||||
| type=str, | |||||
| default='sizes.txt', | |||||
| help='the filename containing all the shards sizes') | |||||
| group.add_argument( | |||||
| '--delim', default=',', help='delimiter used to parse csv data files') | |||||
| group.add_argument( | |||||
| '--text-key', | |||||
| default='sentence', | |||||
| help='key to use to extract text from json/csv') | |||||
| group.add_argument( | |||||
| '--eval-text-key', | |||||
| default=None, | |||||
| help='key to use to extract text from ' | |||||
| 'json/csv evaluation datasets') | |||||
| group.add_argument( | |||||
| '--split', | |||||
| default='1000,1,1', | |||||
| help='comma-separated list of proportions for training,' | |||||
| ' validation, and test split') | |||||
| group.add_argument( | |||||
| '--no-lazy-loader', | |||||
| action='store_true', | |||||
| help='whether to lazy read the data set') | |||||
| group.add_argument('--half-lazy-loader', action='store_true') | |||||
| group.add_argument( | |||||
| '--loader-scatter', | |||||
| type=int, | |||||
| default=None, | |||||
| help='Number of scatters to use for dataloaders') | |||||
| group.add_argument( | |||||
| '--loose-json', | |||||
| action='store_true', | |||||
| help='Use loose json (one json-formatted string per ' | |||||
| 'newline), instead of tight json (data file is one ' | |||||
| 'json string)') | |||||
| group.add_argument( | |||||
| '--presplit-sentences', | |||||
| action='store_true', | |||||
| help='Dataset content consists of documents where ' | |||||
| 'each document consists of newline separated sentences') | |||||
| group.add_argument( | |||||
| '--num-workers', | |||||
| type=int, | |||||
| default=2, | |||||
| help="""Number of workers to use for dataloading""") | |||||
| group.add_argument( | |||||
| '--tokenizer-model-type', | |||||
| type=str, | |||||
| default=None, | |||||
| help="Model type to use for sentencepiece tokenization \ | |||||
| (one of ['bpe', 'char', 'unigram', 'word']) or \ | |||||
| bert vocab to use for BertWordPieceTokenizer (one of \ | |||||
| ['bert-large-uncased', 'bert-large-cased', etc.])") | |||||
| group.add_argument( | |||||
| '--tokenizer-path', | |||||
| type=str, | |||||
| default='tokenizer.model', | |||||
| help='path used to save/load sentencepiece tokenization ' | |||||
| 'models') | |||||
| group.add_argument( | |||||
| '--tokenizer-type', | |||||
| type=str, | |||||
| default='BertWordPieceTokenizer', | |||||
| choices=[ | |||||
| 'CharacterLevelTokenizer', 'SentencePieceTokenizer', | |||||
| 'BertWordPieceTokenizer', 'GPT2BPETokenizer', 'ChineseSPTokenizer' | |||||
| ], | |||||
| help='what type of tokenizer to use') | |||||
| group.add_argument('--no-pre-tokenize', action='store_true') | |||||
| group.add_argument( | |||||
| '--cache-dir', | |||||
| default=None, | |||||
| type=str, | |||||
| help='Where to store pre-trained BERT downloads') | |||||
| group.add_argument( | |||||
| '--use-tfrecords', | |||||
| action='store_true', | |||||
| help='load `--train-data`, `--valid-data`, ' | |||||
| '`--test-data` from BERT tf records instead of ' | |||||
| 'normal data pipeline') | |||||
| group.add_argument( | |||||
| '--seq-length', | |||||
| type=int, | |||||
| default=512, | |||||
| help='Maximum sequence length to process') | |||||
| group.add_argument( | |||||
| '--mem-length', | |||||
| type=int, | |||||
| default=0, | |||||
| help='The memory length to preserve') | |||||
| group.add_argument( | |||||
| '--max-preds-per-seq', | |||||
| type=int, | |||||
| default=None, | |||||
| help='Maximum number of predictions to use per sequence. ' | |||||
| 'Defaults to math.ceil(`--seq-length`*.15/10)*10. ' | |||||
| 'MUST BE SPECIFIED IF `--use-tfrecords` is True.') | |||||
| group.add_argument('--non-sentence-start', type=float, default=0.0) | |||||
| group.add_argument( | |||||
| '--sample-one-document', | |||||
| action='store_true', | |||||
| help='only sample one document in one sample') | |||||
| group.add_argument( | |||||
| '--load-splits', | |||||
| type=str, | |||||
| default=None, | |||||
| help='The path to load split indices from') | |||||
| group.add_argument( | |||||
| '--save-splits', | |||||
| type=str, | |||||
| default=None, | |||||
| help='The path to save split indices to') | |||||
| group.add_argument( | |||||
| '--save-test-data', | |||||
| type=str, | |||||
| default=None, | |||||
| help='The path to save the test data') | |||||
| group.add_argument( | |||||
| '--multi-task-data', | |||||
| nargs='*', | |||||
| default=None, | |||||
| help='Downstream task names for multi-task pre-training') | |||||
| group.add_argument( | |||||
| '--multi-task-ratio', | |||||
| type=float, | |||||
| default=0.0, | |||||
| help='Ratio for multi-task pre-training') | |||||
| group.add_argument('--multi-seq-length', type=int, default=None) | |||||
| group.add_argument('--multi-batch-size', type=int, default=None) | |||||
| return parser | |||||
| def add_finetune_config_args(parser): | |||||
| group = parser.add_argument_group('finetune', 'finetune configurations') | |||||
| group.add_argument('--task', type=str, help='Task name.') | |||||
| group.add_argument( | |||||
| '--load-pretrained', | |||||
| type=str, | |||||
| help='Load pretrained model', | |||||
| default=None) | |||||
| group.add_argument( | |||||
| '--pool-token', | |||||
| type=str, | |||||
| choices=['start', 'pad', 'cls'], | |||||
| help='The token to pool the sequence representation', | |||||
| default='cls') | |||||
| group.add_argument( | |||||
| '--cloze-eval', | |||||
| action='store_true', | |||||
| help='Evaluation dataset with cloze task') | |||||
| group.add_argument( | |||||
| '--multi-token', | |||||
| action='store_true', | |||||
| help='Use multi token for cloze evaluation') | |||||
| group.add_argument( | |||||
| '--segment-length', | |||||
| type=int, | |||||
| default=0, | |||||
| help='The maximum segment length for cloze evaluation') | |||||
| group.add_argument( | |||||
| '--loss-func', | |||||
| type=str, | |||||
| choices=['cross_entropy', 'hinge', 'generative', 'mix'], | |||||
| default='cross_entropy') | |||||
| group.add_argument('--block-lm-ratio', type=float, default=0.0) | |||||
| group.add_argument( | |||||
| '--adapet', | |||||
| action='store_true', | |||||
| help='Use the decoupled cross entropy loss in AdaPET') | |||||
| group.add_argument('--pattern-id', type=int, default=0) | |||||
| group.add_argument( | |||||
| '--fast-decode', | |||||
| action='store_true', | |||||
| help= | |||||
| 'Fast decode for multi-token cloze. Can only be used without checkpoint activation.' | |||||
| ) | |||||
| group.add_argument('--few-superglue', action='store_true') | |||||
| group.add_argument( | |||||
| '--eval-valid', | |||||
| action='store_true', | |||||
| help='Whether evaluate on the valid set') | |||||
| group.add_argument('--validation-metric', type=str, default=None) | |||||
| group.add_argument( | |||||
| '--unidirectional', | |||||
| action='store_true', | |||||
| help='Use the left to right language model') | |||||
| group.add_argument('--src-seq-length', type=int, default=None) | |||||
| group.add_argument('--tgt-seq-length', type=int, default=None) | |||||
| group.add_argument('--adam-beta1', type=float, default=0.9) | |||||
| group.add_argument('--adam-beta2', type=float, default=0.999) | |||||
| group.add_argument('--adam-eps', type=float, default=1e-8) | |||||
| group.add_argument( | |||||
| '--optimizer', type=str, choices=['adam', 'adafactor'], default='adam') | |||||
| group.add_argument('--wsc-negative', action='store_true') | |||||
| group.add_argument('--overwrite', action='store_true') | |||||
| group.add_argument('--no-validation', action='store_true') | |||||
| # Continuous prompt arguments | |||||
| group.add_argument( | |||||
| '--continuous-prompt', | |||||
| action='store_true', | |||||
| help='Use continuous prompt for PET') | |||||
| group.add_argument('--num-prompt-tokens', type=int, default=0) | |||||
| group.add_argument( | |||||
| '--prompt-func', default='lstm', choices=['lstm', 'mlp', 'none']) | |||||
| group.add_argument( | |||||
| '--freeze-transformer', action='store_true', default=False) | |||||
| group.add_argument('--tune-prefix-layers', type=int, default=None) | |||||
| group.add_argument('--prefix-prompt', type=int, default=0) | |||||
| group.add_argument('--prompt-init', action='store_true', default=False) | |||||
| return parser | |||||
| def get_args(): | |||||
| """Parse all the args.""" | |||||
| parser = argparse.ArgumentParser(description='PyTorch BERT Model') | |||||
| parser = add_model_config_args(parser) | |||||
| parser = add_fp16_config_args(parser) | |||||
| parser = add_training_args(parser) | |||||
| parser = add_evaluation_args(parser) | |||||
| parser = add_text_generate_args(parser) | |||||
| parser = add_data_args(parser) | |||||
| parser = add_finetune_config_args(parser) | |||||
| # Include DeepSpeed configuration arguments | |||||
| parser = deepspeed.add_config_arguments(parser) | |||||
| args = parser.parse_args(args=[]) | |||||
| if not args.train_data and not args.data_dir: | |||||
| print('WARNING: No training data specified') | |||||
| args.cuda = torch.cuda.is_available() | |||||
| args.rank = int(os.getenv('RANK', '0')) | |||||
| args.world_size = int(os.getenv('WORLD_SIZE', '1')) | |||||
| if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: | |||||
| mpi_define_env(args) | |||||
| elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'): | |||||
| # We are using (OpenMPI) mpirun for launching distributed data parallel processes | |||||
| local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')) | |||||
| local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE')) | |||||
| # Possibly running with Slurm | |||||
| num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1')) | |||||
| nodeid = int(os.getenv('SLURM_NODEID', '0')) | |||||
| args.local_rank = local_rank | |||||
| args.rank = nodeid * local_size + local_rank | |||||
| args.world_size = num_nodes * local_size | |||||
| args.model_parallel_size = min(args.model_parallel_size, args.world_size) | |||||
| if args.rank == 0: | |||||
| print('using world size: {} and model-parallel size: {} '.format( | |||||
| args.world_size, args.model_parallel_size)) | |||||
| args.dynamic_loss_scale = False | |||||
| if args.loss_scale is None: | |||||
| args.dynamic_loss_scale = True | |||||
| if args.rank == 0: | |||||
| print(' > using dynamic loss scaling') | |||||
| # The args fp32_* or fp16_* meant to be active when the | |||||
| # args fp16 is set. So the default behaviour should all | |||||
| # be false. | |||||
| if not args.fp16: | |||||
| args.fp32_embedding = False | |||||
| args.fp32_tokentypes = False | |||||
| args.fp32_layernorm = False | |||||
| if hasattr(args, 'deepspeed' | |||||
| ) and args.deepspeed and args.deepspeed_config is not None: | |||||
| with open(args.deepspeed_config) as file: | |||||
| deepspeed_config = json.load(file) | |||||
| if 'train_micro_batch_size_per_gpu' in deepspeed_config: | |||||
| args.batch_size = deepspeed_config[ | |||||
| 'train_micro_batch_size_per_gpu'] | |||||
| if 'gradient_accumulation_steps' in deepspeed_config: | |||||
| args.gradient_accumulation_steps = deepspeed_config[ | |||||
| 'gradient_accumulation_steps'] | |||||
| else: | |||||
| args.gradient_accumulation_steps = 1 | |||||
| if 'optimizer' in deepspeed_config: | |||||
| optimizer_params_config = deepspeed_config['optimizer'].get( | |||||
| 'params', {}) | |||||
| args.lr = optimizer_params_config.get('lr', args.lr) | |||||
| args.weight_decay = optimizer_params_config.get( | |||||
| 'weight_decay', args.weight_decay) | |||||
| return args | |||||
| def mpi_define_env(args): | |||||
| from mpi4py import MPI | |||||
| comm = MPI.COMM_WORLD | |||||
| rank = comm.Get_rank() | |||||
| world_size = comm.Get_size() | |||||
| master_addr = None | |||||
| if rank == 0: | |||||
| master_addr = get_hostname() | |||||
| master_addr = comm.bcast(master_addr, root=0) | |||||
| # Determine local rank by assuming hostnames are unique | |||||
| proc_name = MPI.Get_processor_name() | |||||
| all_procs = comm.allgather(proc_name) | |||||
| local_rank = sum([i == proc_name for i in all_procs[:rank]]) | |||||
| os.environ['RANK'] = str(rank) | |||||
| os.environ['WORLD_SIZE'] = str(world_size) | |||||
| args.local_rank = local_rank | |||||
| args.world_size = world_size | |||||
| args.rank = rank | |||||
| os.environ['MASTER_ADDR'] = master_addr | |||||
| os.environ[ | |||||
| 'MASTER_PORT'] = '29500' # TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 | |||||
| print( | |||||
| 'Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}' | |||||
| .format(os.environ['RANK'], args.local_rank, os.environ['WORLD_SIZE'], | |||||
| os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])) | |||||
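Because get_args() calls parse_args(args=[]), it never reads the real command line; defaults are materialized and then overridden in code or by the DeepSpeed JSON. A hedged sketch of that flow (the module path is an assumption made only for illustration):

    from modelscope.models.nlp.mglm.arguments import get_args  # import path assumed

    args = get_args()                 # all defaults, no CLI flags required
    args.fp16 = True                  # flip individual fields programmatically
    args.num_layers, args.hidden_size = 24, 1024
    # if args.deepspeed and args.deepspeed_config point at a JSON file,
    # train_micro_batch_size_per_gpu, gradient_accumulation_steps and
    # optimizer.params.{lr, weight_decay} from that file take precedence,
    # as implemented in get_args() above.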
| @@ -0,0 +1,625 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import copy | |||||
| import math | |||||
| import random | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.utils.data | |||||
| from scipy.stats import poisson | |||||
| from . import mpu | |||||
| from .utils import print_rank_0 | |||||
| def rindex(lst, val, start=None): | |||||
| if start is None: | |||||
| start = len(lst) - 1 | |||||
| for i in range(start, -1, -1): | |||||
| if lst[i] == val: | |||||
| return i | |||||
| return -1 | |||||
| def index_in_list(lst, val, start=None): | |||||
| if start is None: | |||||
| start = 0 | |||||
| for i in range(start, len(lst)): | |||||
| if lst[i] == val: | |||||
| return i | |||||
| return -1 | |||||
| class ConstructBlockStrategy: | |||||
| def __init__(self, | |||||
| args, | |||||
| tokenizer, | |||||
| max_seq_length, | |||||
| bert_prob=1.0, | |||||
| gap_sentence_prob=0.0, | |||||
| gpt_infill_prob=0.5, | |||||
| gpt_min_ratio=0.5, | |||||
| bert_ratio=0.15, | |||||
| gap_sentence_ratio=0.15, | |||||
| average_block_length=3, | |||||
| max_block_length=40, | |||||
| block_mask_prob=0.0, | |||||
| context_mask_ratio=0.0, | |||||
| context_mask_range=3, | |||||
| short_seq_prob=0.0, | |||||
| single_span_prob=0.0, | |||||
| block_position_encoding=True, | |||||
| encoder_decoder=False, | |||||
| shuffle_blocks=True, | |||||
| sentinel_token=False, | |||||
| task_mask=False, | |||||
| random_position=False, | |||||
| masked_lm=False): | |||||
| self.eod_token = args.eod_token | |||||
| self.tokenizer = tokenizer | |||||
| self.count = 0 | |||||
| self.max_seq_length = max_seq_length | |||||
| self.rank = mpu.get_data_parallel_rank() | |||||
| self.world_size = mpu.get_data_parallel_world_size() | |||||
| # self.rank = 0 | |||||
| # self.world_size = 1 | |||||
| assert 0.0 <= bert_prob <= 1.0 | |||||
| self.bert_prob = bert_prob | |||||
| self.gap_sentence_prob = gap_sentence_prob | |||||
| self.gpt_prob = 1 - bert_prob - gap_sentence_prob | |||||
| assert self.gpt_prob >= -1e-10 | |||||
| self.infill_prob = gpt_infill_prob | |||||
| self.gpt_min_ratio = gpt_min_ratio | |||||
| self.bert_ratio = bert_ratio | |||||
| self.gap_sentence_ratio = gap_sentence_ratio | |||||
| self.block_length_distribution = [ | |||||
| poisson.pmf(i, average_block_length) | |||||
| for i in range(1, max_block_length) | |||||
| ] | |||||
| self.block_mask_prob = block_mask_prob | |||||
| self.context_mask_ratio = context_mask_ratio | |||||
| self.context_mask_range = context_mask_range | |||||
| self.short_seq_prob = short_seq_prob | |||||
| self.single_span_prob = single_span_prob | |||||
| self.block_position_encoding = block_position_encoding | |||||
| self.encoder_decoder = encoder_decoder | |||||
| self.shuffle_blocks = shuffle_blocks | |||||
| self.sentinel_token = sentinel_token | |||||
| self.generation_mask = 'gMASK' if task_mask else 'MASK' | |||||
| self.generation_mask = self.tokenizer.get_command( | |||||
| self.generation_mask).Id | |||||
| self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK' | |||||
| self.gap_sentence_mask = self.tokenizer.get_command( | |||||
| self.gap_sentence_mask).Id | |||||
| self.random_position = random_position | |||||
| self.masked_lm = masked_lm | |||||
| print_rank_0( | |||||
| f'BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}' # noqa | |||||
| ) | |||||
| print_rank_0( | |||||
| f'generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}' # noqa | |||||
| ) | |||||
| print_rank_0( | |||||
| f'block length distribution {self.block_length_distribution}') | |||||
| print_rank_0( | |||||
| f'block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}' | |||||
| ) | |||||
| def contains_sentence_end(self, tok): | |||||
| tok = self.tokenizer.IdToToken(tok) | |||||
| if '.' in tok: | |||||
| return True | |||||
| if '?' in tok: | |||||
| return True | |||||
| if '!' in tok: | |||||
| return True | |||||
| if ';' in tok: | |||||
| return True | |||||
| if ':' in tok: | |||||
| return True | |||||
| if '。' in tok: | |||||
| return True | |||||
| if '?' in tok: | |||||
| return True | |||||
| if '!' in tok: | |||||
| return True | |||||
| if ';' in tok: | |||||
| return True | |||||
| if '…' in tok: | |||||
| return True | |||||
| if '\n' in tok: | |||||
| return True | |||||
| return False | |||||
| @staticmethod | |||||
| def sample_spans(span_lengths, total_length, rng, offset=0): | |||||
| blank_length = total_length - sum(span_lengths) | |||||
| m = blank_length - len(span_lengths) + 1 | |||||
| places = [rng.randrange(m + 1) for _ in range(len(span_lengths))] | |||||
| places.sort() | |||||
| spans = [] | |||||
| for place, span_length in zip(places, span_lengths): | |||||
| start = offset + place | |||||
| end = offset + place + span_length | |||||
| spans.append((start, end)) | |||||
| offset += span_length + 1 | |||||
| return spans | |||||
| def sample_span_in_document(self, tokens, masked_lengths, rng): | |||||
| rng.shuffle(masked_lengths) | |||||
| mask_spans = [] | |||||
| mask_index = 0 | |||||
| indices = [-1] + np.where(tokens == self.eod_token)[0].tolist() | |||||
| last_index = len(tokens) | |||||
| documents = [] | |||||
| for index in reversed(indices): | |||||
| start_index = index | |||||
| if start_index + 1 < len(tokens) and tokens[ | |||||
| start_index + 1] == self.tokenizer.get_command('ENC').Id: | |||||
| start_index += 1 | |||||
| length = last_index - start_index - 1 | |||||
| if last_index == len(tokens) and length > 0: | |||||
| length -= 1 | |||||
| documents.append((start_index + 1, length)) | |||||
| last_index = index | |||||
| documents.sort(key=lambda x: x[1]) | |||||
| for i, (offset, length) in enumerate(documents): | |||||
| if i == len(documents) - 1: | |||||
| current_masked_length, current_count = 0, 0 | |||||
| while mask_index + current_count < len( | |||||
| masked_lengths | |||||
| ) and masked_lengths[ | |||||
| mask_index + # noqa | |||||
| current_count] + current_masked_length + current_count <= length: | |||||
| current_masked_length += masked_lengths[mask_index | |||||
| + current_count] | |||||
| current_count += 1 | |||||
| if current_count > 0: | |||||
| spans = self.sample_spans( | |||||
| masked_lengths[mask_index:mask_index + current_count], | |||||
| length, | |||||
| rng, | |||||
| offset=offset) | |||||
| mask_spans += spans | |||||
| if mask_index + current_count < len(masked_lengths) - 1: | |||||
| print(length, masked_lengths[mask_index:], | |||||
| masked_lengths[:mask_index], indices) | |||||
| else: | |||||
| current_masked_total = int(length * self.bert_ratio) | |||||
| current_masked_length, current_count = 0, 0 | |||||
| while mask_index + current_count < len( | |||||
| masked_lengths | |||||
| ) and masked_lengths[ | |||||
| mask_index + # noqa | |||||
| current_count] + current_masked_length <= current_masked_total: | |||||
| current_masked_length += masked_lengths[mask_index | |||||
| + current_count] | |||||
| current_count += 1 | |||||
| if current_count > 0: | |||||
| spans = self.sample_spans( | |||||
| masked_lengths[mask_index:mask_index + current_count], | |||||
| length, | |||||
| rng, | |||||
| offset=offset) | |||||
| mask_spans += spans | |||||
| mask_index += current_count | |||||
| return mask_spans | |||||
| def make_masked_data(self, | |||||
| tokens, | |||||
| loss_masks, | |||||
| attention_mask, | |||||
| block_spans, | |||||
| rng, | |||||
| task='bert'): | |||||
| position_ids = np.arange(len(tokens), dtype=np.long) | |||||
| targets = copy.deepcopy(tokens) | |||||
| mask_id = self.tokenizer.get_command('MASK').Id | |||||
| mlm_masks = np.zeros(len(tokens), dtype=np.long) | |||||
| for start, end in block_spans: | |||||
| for idx in range(start, end): | |||||
| tokens[idx] = mask_id | |||||
| mlm_masks[start:end] = 1 | |||||
| loss_masks = loss_masks * mlm_masks | |||||
| return tokens, targets, loss_masks, position_ids | |||||
| def make_block_data(self, | |||||
| tokens, | |||||
| loss_masks, | |||||
| attention_mask, | |||||
| block_spans, | |||||
| rng, | |||||
| task='bert'): | |||||
| text_length = len(tokens) | |||||
| position_ids = np.ones(len(tokens), dtype=np.long) | |||||
| for start, end in block_spans: | |||||
| position_ids[start + 1:end] = 0 | |||||
| position_ids = np.cumsum(position_ids) - 1 | |||||
| if self.random_position and position_ids[-1] < self.max_seq_length - 1: | |||||
| position_bias = self.max_seq_length - position_ids[-1] | |||||
| position_bias = rng.randrange(0, position_bias) | |||||
| position_ids = position_ids + position_bias | |||||
| if self.encoder_decoder or not self.shuffle_blocks: | |||||
| block_spans.sort(key=lambda x: x[0]) | |||||
| else: | |||||
| rng.shuffle(block_spans) | |||||
| if self.sentinel_token: | |||||
| block_spans = [(start, end, idx) | |||||
| for idx, (start, end) in enumerate(block_spans)] | |||||
| else: | |||||
| block_spans = [(start, end, 0) for start, end in block_spans] | |||||
| target_tokens, target_position_ids, target_block_position_ids, targets = [], [], [], [] | |||||
| for start, end, idx in block_spans: | |||||
| sop_token = 'sop' if idx == 0 else f'sop{idx}' | |||||
| target_tokens.append([self.tokenizer.get_command(sop_token).Id]) | |||||
| span_tokens = copy.deepcopy(tokens[start:end]) | |||||
| if self.block_mask_prob > 0.0 and task == 'bert': | |||||
| for sub_idx in range(len(span_tokens)): | |||||
| if random.random() < self.block_mask_prob: | |||||
| span_tokens[sub_idx] = self.tokenizer.get_command( | |||||
| 'dBLOCK').Id | |||||
| target_tokens.append(span_tokens) | |||||
| targets.append(tokens[start:end]) | |||||
| targets.append([self.tokenizer.get_command('eop').Id]) | |||||
| if not self.sentinel_token: | |||||
| target_position_id = position_ids[start:end] | |||||
| target_position_ids.append(target_position_id) | |||||
| target_position_ids.append([target_position_id[0]]) | |||||
| else: | |||||
| target_position_ids.append([self.max_seq_length] * # noqa | |||||
| (end - start + 1)) | |||||
| if self.block_position_encoding: | |||||
| target_block_position_ids.append( | |||||
| np.arange(1, end - start + 2, dtype=np.long)) | |||||
| else: | |||||
| target_block_position_ids.append([1] * (end - start + 1)) | |||||
| block_spans.sort(key=lambda x: x[0]) | |||||
| source_tokens, source_position_ids, local_spans = [], [], [] | |||||
| last, current_length = 0, 0 | |||||
| for start, end, idx in block_spans: | |||||
| if task == 'generation': | |||||
| mask_id = self.generation_mask | |||||
| elif task == 'gap_sentence': | |||||
| mask_id = self.gap_sentence_mask | |||||
| else: | |||||
| mask_token = 'MASK' if idx == 0 else f'MASK{idx}' | |||||
| mask_id = self.tokenizer.get_command(mask_token).Id | |||||
| local_spans.append((current_length, current_length + start - last)) | |||||
| source_tokens.append(tokens[last:start]) | |||||
| source_tokens.append([mask_id]) | |||||
| source_position_ids.append(position_ids[last:start]) | |||||
| source_position_ids.append([position_ids[start]]) | |||||
| current_length += start - last + 1 | |||||
| last = end | |||||
| if last < len(tokens): | |||||
| local_spans.append( | |||||
| (current_length, current_length + len(tokens) - last)) | |||||
| source_tokens.append(tokens[last:]) | |||||
| source_position_ids.append(position_ids[last:]) | |||||
| source_length = sum(map(len, source_tokens)) | |||||
| if attention_mask is not None: | |||||
| assert source_length == attention_mask | |||||
| if target_tokens and self.eod_token in np.concatenate( | |||||
| target_tokens).tolist(): | |||||
| print('Found EOS in target', self.tokenizer.DecodeIds(tokens)) | |||||
| raise RuntimeError | |||||
| if self.encoder_decoder: | |||||
| target_tokens = target_tokens + [ | |||||
| self.tokenizer.get_command('eop').Id | |||||
| ] | |||||
| loss_masks = np.ones(len(target_tokens), dtype=np.long) | |||||
| return source_tokens, target_tokens, loss_masks | |||||
| else: | |||||
| tokens = np.concatenate(source_tokens + target_tokens) | |||||
| if task == 'bert' and self.context_mask_ratio > 0: | |||||
| mask_candidates = set() | |||||
| for start, end in local_spans: | |||||
| if start != 0: | |||||
| local_end = min(end, start + self.context_mask_range) | |||||
| mask_candidates.update(range(start, local_end)) | |||||
| if end != 0: | |||||
| local_start = max(start, end - self.context_mask_range) | |||||
| mask_candidates.update(range(local_start, end)) | |||||
| mask_pos = rng.sample( | |||||
| mask_candidates, | |||||
| int(self.context_mask_ratio * text_length)) | |||||
| for pos in mask_pos: | |||||
| tokens[pos] = self.tokenizer.get_command('dBLOCK').Id | |||||
| targets = np.concatenate(source_tokens + targets) | |||||
| loss_masks = np.ones(len(tokens), dtype=np.long) | |||||
| loss_masks[:source_length] = 0 | |||||
| position_ids = np.concatenate(source_position_ids | |||||
| + target_position_ids) | |||||
| block_position_ids = np.concatenate( | |||||
| [np.zeros(source_length, dtype=np.long)] | |||||
| + target_block_position_ids) | |||||
| position_ids = np.stack([position_ids, block_position_ids], axis=0) | |||||
| if attention_mask is not None: | |||||
| return tokens, targets, loss_masks, position_ids | |||||
| else: | |||||
| return tokens, targets, loss_masks, position_ids, source_length | |||||
| def generate_blank_data(self, | |||||
| sample, | |||||
| masked_lengths, | |||||
| attention_mask, | |||||
| rng, | |||||
| task='bert'): | |||||
| rng.shuffle(masked_lengths) | |||||
| tokens, loss_masks = sample['text'], sample['loss_mask'] | |||||
| assert tokens[0] == self.tokenizer.get_command('ENC').Id | |||||
| block_spans = self.sample_span_in_document(tokens, masked_lengths, rng) | |||||
| if len(block_spans) < len(masked_lengths): | |||||
| return None | |||||
| if self.masked_lm: | |||||
| data = self.make_masked_data(tokens, loss_masks, attention_mask, | |||||
| block_spans, rng) | |||||
| else: | |||||
| data = self.make_block_data( | |||||
| tokens, | |||||
| loss_masks, | |||||
| attention_mask, | |||||
| block_spans, | |||||
| rng, | |||||
| task=task) | |||||
| return data | |||||
| def split_samples(self, samples, rng): | |||||
| target_length = rng.randrange(32, self.max_seq_length - 1) | |||||
| num_splits = (self.max_seq_length - 1) // target_length | |||||
| new_samples = [] | |||||
| cls_id = self.tokenizer.get_command('ENC').Id | |||||
| eos_id = self.tokenizer.get_command('eos').Id | |||||
| for sample in samples: | |||||
| tokens, loss_masks = sample['text'][1:], sample['loss_mask'][1:] | |||||
| for _ in range(num_splits): | |||||
| if target_length >= len(tokens): | |||||
| new_tokens, new_loss_masks = tokens, loss_masks | |||||
| else: | |||||
| random_start = rng.randrange(0, | |||||
| len(tokens) - target_length) | |||||
| while random_start > 0 and ( | |||||
| tokens[random_start] == eos_id or # noqa | |||||
| not (self.contains_sentence_end( # noqa | |||||
| tokens[random_start - 1]) or # noqa | |||||
| tokens[random_start - 1] == eos_id)): # noqa | |||||
| random_start -= 1 | |||||
| random_end = random_start + target_length | |||||
| while random_end > random_start and not ( | |||||
| self.contains_sentence_end(tokens[random_end - 1]) | |||||
| or tokens[random_end - 1] == eos_id): | |||||
| random_end -= 1 | |||||
| if random_end - random_start < target_length // 2: | |||||
| random_end = random_start + target_length | |||||
| new_tokens, new_loss_masks = tokens[ | |||||
| random_start:random_end], loss_masks[ | |||||
| random_start:random_end] | |||||
| new_tokens = np.concatenate(([cls_id], new_tokens)) | |||||
| new_loss_masks = np.concatenate(([0], new_loss_masks)) | |||||
| new_samples.append({ | |||||
| 'text': new_tokens, | |||||
| 'loss_mask': new_loss_masks | |||||
| }) | |||||
| return new_samples | |||||
| def construct_blocks(self, samples): | |||||
| worker_info = torch.utils.data.get_worker_info() | |||||
| if worker_info is not None: | |||||
| worker_id, num_workers = worker_info.id, worker_info.num_workers | |||||
| else: | |||||
| worker_id, num_workers = 0, 1 | |||||
| rng = random.Random((self.count * num_workers + worker_id) | |||||
| * self.world_size + self.rank) | |||||
| self.count += 1 | |||||
| token_batch, target_batch, loss_mask_batch, position_id_batch = [], [], [], [] | |||||
| source_batch, target_batch = [], [] | |||||
| if rng.random() < self.short_seq_prob: | |||||
| samples = self.split_samples(samples, rng) | |||||
| rand = rng.random() | |||||
| single_span = rand < self.single_span_prob | |||||
| rand = 0.0 if single_span else rng.random() | |||||
| attention_mask = [] | |||||
| if rand < self.bert_prob: | |||||
| mode = 'bert' | |||||
| for sample in samples: | |||||
| if single_span: | |||||
| masked_lengths = [ | |||||
| rng.choices( | |||||
| range(1, | |||||
| len(self.block_length_distribution) + 1), | |||||
| weights=self.block_length_distribution)[0] | |||||
| ] | |||||
| masked_count = masked_lengths[0] | |||||
| else: | |||||
| masked_lengths, masked_count = [], 0 | |||||
| while masked_count < int( | |||||
| self.bert_ratio * len(sample['text'])): | |||||
| block_length = rng.choices( | |||||
| range(1, | |||||
| len(self.block_length_distribution) + 1), | |||||
| weights=self.block_length_distribution)[0] | |||||
| masked_lengths.append(block_length) | |||||
| masked_count += block_length | |||||
| if self.masked_lm: | |||||
| sep = len(sample['text']) | |||||
| else: | |||||
| sep = len( | |||||
| sample['text']) - masked_count + len(masked_lengths) | |||||
| data = self.generate_blank_data( | |||||
| sample, masked_lengths, sep, rng, task='bert') | |||||
| if data is not None: | |||||
| if self.encoder_decoder: | |||||
| source_tokens, target_tokens, loss_masks = data | |||||
| source_batch.append(source_tokens) | |||||
| target_batch.append(target_tokens) | |||||
| loss_mask_batch.append(loss_masks) | |||||
| else: | |||||
| tokens, targets, loss_masks, position_ids = data | |||||
| token_batch.append(tokens) | |||||
| target_batch.append(targets) | |||||
| loss_mask_batch.append(loss_masks) | |||||
| position_id_batch.append(position_ids) | |||||
| attention_mask.append(sep) | |||||
| elif rand < self.bert_prob + self.gap_sentence_prob: | |||||
| mode = 'sentence' | |||||
| for sample in samples: | |||||
| tokens, loss_masks = sample['text'], sample['loss_mask'] | |||||
| sentence_spans = [] | |||||
| last_index = 1 if tokens[0] == self.tokenizer.get_command( | |||||
| 'ENC').Id else 0 | |||||
| for i in range(len(tokens)): | |||||
| if self.contains_sentence_end(tokens[i]): | |||||
| if last_index < i + 1: | |||||
| sentence_spans.append((last_index, i + 1)) | |||||
| last_index = i + 1 | |||||
| elif tokens[i] == self.tokenizer.get_command('eos').Id: | |||||
| last_index = i + 1 | |||||
| if last_index < len(tokens): | |||||
| sentence_spans.append((last_index, len(tokens))) | |||||
| if not sentence_spans and torch.distributed.get_rank() == 0: | |||||
| try: | |||||
| print(self.tokenizer.DecodeIds(tokens[1:])) | |||||
| except IndexError: | |||||
| print(tokens[1:]) | |||||
| rng.shuffle(sentence_spans) | |||||
| block_spans, block_length = [], 0 | |||||
| for start, end in sentence_spans: | |||||
| block_spans.append((start, end)) | |||||
| block_length += end - start | |||||
| if block_length >= int( | |||||
| self.gap_sentence_ratio * len(tokens)): | |||||
| break | |||||
| data = self.make_block_data( | |||||
| tokens, | |||||
| loss_masks, | |||||
| None, | |||||
| block_spans, | |||||
| rng, | |||||
| task='gap_sentence') | |||||
| tokens, targets, loss_masks, position_ids, sep = data | |||||
| token_batch.append(tokens) | |||||
| target_batch.append(targets) | |||||
| loss_mask_batch.append(loss_masks) | |||||
| position_id_batch.append(position_ids) | |||||
| attention_mask.append(sep) | |||||
| else: | |||||
| # start_indices = [index_in_list(sample['loss_mask'], 1) for sample in samples] | |||||
| # end_indices = [rindex(sample['loss_mask'], 1) for sample in samples] | |||||
| # start_index, end_index = max(start_indices), min(end_indices) - self.min_generation_length | |||||
| # if end_index < start_index + 1: | |||||
| # end_index = start_index + 1 | |||||
| # division = rng.randrange(start_index, end_index) | |||||
| mode = 'gpt' | |||||
| max_generation_length = rng.randint( | |||||
| int(self.gpt_min_ratio | |||||
| * min(map(lambda x: len(x['text']), samples))), | |||||
| max(map(lambda x: len(x['text']), samples)) - 2) | |||||
| for sample in samples: | |||||
| generation_length = min(max_generation_length, | |||||
| len(sample['text']) - 2) | |||||
| attention_mask.append( | |||||
| len(sample['text']) - generation_length + 1) | |||||
| multiple_doc = index_in_list( | |||||
| sample['text'], | |||||
| self.tokenizer.get_command('eos').Id) not in [ | |||||
| -1, len(sample['text']) - 1 | |||||
| ] # noqa | |||||
| if multiple_doc or rng.random() < self.infill_prob: | |||||
| division = len(sample['text']) - generation_length | |||||
| tokens, loss_masks = sample['text'], sample['loss_mask'] | |||||
| source_tokens, target_tokens = tokens[:division], tokens[ | |||||
| division:] | |||||
| target_masks = loss_masks[division:] | |||||
| tokens = np.concatenate((source_tokens, [ | |||||
| self.generation_mask, | |||||
| self.tokenizer.get_command('sop').Id | |||||
| ], target_tokens[:-1])) | |||||
| targets = np.concatenate( | |||||
| (source_tokens, [self.generation_mask], target_tokens)) | |||||
| loss_masks = np.concatenate( | |||||
| (np.zeros(len(source_tokens) + 1, | |||||
| dtype=np.long), target_masks)) | |||||
| token_batch.append(tokens) | |||||
| target_batch.append(targets) | |||||
| loss_mask_batch.append(loss_masks) | |||||
| position_ids = np.arange( | |||||
| len(source_tokens) + len(target_tokens) + 1, | |||||
| dtype=np.long) | |||||
| position_ids[len(source_tokens) + 1:] = len(source_tokens) | |||||
| if self.block_position_encoding: | |||||
| block_position_ids = np.concatenate( | |||||
| (np.zeros(len(source_tokens), dtype=np.long), | |||||
| np.arange(len(target_tokens) + 1, dtype=np.long))) | |||||
| else: | |||||
| block_position_ids = np.concatenate( | |||||
| (np.zeros(len(source_tokens) + 1, dtype=np.long), | |||||
| np.ones(len(target_tokens) + 1, dtype=np.long))) | |||||
| position_id_batch.append( | |||||
| np.stack([position_ids, block_position_ids], axis=0)) | |||||
| else: | |||||
| tokens, targets, loss_masks, position_ids = self.generate_blank_data( | |||||
| sample, [generation_length], | |||||
| attention_mask[-1], | |||||
| rng, | |||||
| task='generation') | |||||
| token_batch.append(tokens) | |||||
| target_batch.append(targets) | |||||
| loss_mask_batch.append(loss_masks) | |||||
| position_id_batch.append(position_ids) | |||||
| if tokens is None: | |||||
| print(sample, generation_length, multiple_doc) | |||||
| if self.encoder_decoder: | |||||
| return { | |||||
| 'text': torch.tensor(source_batch, dtype=torch.long), | |||||
| 'target': torch.tensor(target_batch, dtype=torch.long), | |||||
| 'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long) | |||||
| } | |||||
| else: | |||||
| token_batch, target_batch, loss_mask_batch, position_id_batch = self.pad_batch( | |||||
| token_batch, target_batch, loss_mask_batch, position_id_batch) | |||||
| return { | |||||
| 'text': torch.tensor(token_batch, dtype=torch.long), | |||||
| 'target': torch.tensor(target_batch, dtype=torch.long), | |||||
| 'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long), | |||||
| 'position_id': | |||||
| torch.tensor(position_id_batch, dtype=torch.long), | |||||
| 'attention_mask': | |||||
| torch.tensor(attention_mask, dtype=torch.long), | |||||
| 'mode': mode | |||||
| } | |||||
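| # pad_batch right-pads a ragged batch: token, target and loss-mask sequences are | |||||
| # extended with zeros to the longest length in the batch, and the (2, seq_len) | |||||
| # position-id arrays are padded along their second axis to match. | |||||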
| @staticmethod | |||||
| def pad_batch(token_batch, target_batch, loss_mask_batch, | |||||
| position_id_batch): | |||||
| seq_lengths = list(map(len, token_batch)) | |||||
| if seq_lengths.count(seq_lengths[0]) != len(seq_lengths): | |||||
| max_length = max(seq_lengths) | |||||
| token_batch = [ | |||||
| np.concatenate( | |||||
| (tokens, np.zeros(max_length - len(tokens), | |||||
| dtype=np.long))) | |||||
| for tokens in token_batch | |||||
| ] | |||||
| target_batch = [ | |||||
| np.concatenate( | |||||
| (targets, | |||||
| np.zeros(max_length - len(targets), dtype=np.long))) | |||||
| for targets in target_batch | |||||
| ] | |||||
| loss_mask_batch = [ | |||||
| np.concatenate( | |||||
| (loss_masks, | |||||
| np.zeros(max_length - len(loss_masks), dtype=np.long))) | |||||
| for loss_masks in loss_mask_batch | |||||
| ] | |||||
| position_id_batch = [ | |||||
| np.concatenate((position_ids, | |||||
| np.zeros( | |||||
| (2, max_length - position_ids.shape[1]), | |||||
| dtype=np.long)), | |||||
| axis=1) for position_ids in position_id_batch | |||||
| ] | |||||
| return token_batch, target_batch, loss_mask_batch, position_id_batch | |||||
| @@ -0,0 +1,513 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """parses arguments and preps data loader""" | |||||
| import copy | |||||
| import os | |||||
| import random | |||||
| from bisect import bisect_right | |||||
| from itertools import accumulate | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.utils.data | |||||
| from . import data_utils, mpu | |||||
| from .blocklm_utils import ConstructBlockStrategy | |||||
| from .data_utils.tokenization import make_tokenizer | |||||
| from .utils import print_rank_0 | |||||
| class MultiTaskDataset(torch.utils.data.Dataset): | |||||
| def __init__(self, | |||||
| tasks, | |||||
| datasets, | |||||
| reweight=True, | |||||
| temperature=0.8, | |||||
| max_limit=200000): | |||||
| super(MultiTaskDataset, self).__init__() | |||||
| self.tasks = tasks | |||||
| self.datasets = datasets | |||||
| self.reweight = reweight | |||||
| self.temperature = temperature | |||||
| self.lens = [len(dataset) for dataset in datasets] | |||||
| self.weights = np.array( | |||||
| [min(length, max_limit)**temperature for length in self.lens]) | |||||
| self.total_len = sum(self.lens) | |||||
| self.cumulative_lens = list(accumulate(self.lens)) | |||||
| if self.reweight: | |||||
| print_rank_0(list(zip(self.tasks, self.lens, self.weights))) | |||||
| else: | |||||
| print_rank_0(list(zip(self.tasks, self.lens))) | |||||
| self.weights /= self.weights.sum() | |||||
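| # Temperature-based task reweighting: each dataset is sampled with probability | |||||
| # proportional to min(len, max_limit) ** temperature. For example, with lengths | |||||
| # 100 and 10000 and temperature 0.8 the weights come out to roughly 0.024 and | |||||
| # 0.976, so small tasks are sampled more often than their raw share (~0.01). | |||||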
| def __len__(self): | |||||
| return self.total_len * 1000 | |||||
| @staticmethod | |||||
| def pet_wrapper(data): | |||||
| text = data['text'] | |||||
| loss_mask = data['logit_mask'] | |||||
| target = data['target'] | |||||
| attention_mask = data['mask'] | |||||
| position_id = data['position'] | |||||
| label = data['label'] | |||||
| if len(text.shape) == 2: | |||||
| text = text[label] | |||||
| loss_mask = loss_mask[label] | |||||
| target = target[label] | |||||
| attention_mask = attention_mask[label] | |||||
| position_id = position_id[label] | |||||
| else: | |||||
| target = target[label] | |||||
| if not target.shape: | |||||
| target = target.repeat(len(text)) | |||||
| return { | |||||
| 'text': text, | |||||
| 'target': target, | |||||
| 'loss_mask': loss_mask, | |||||
| 'position_id': position_id, | |||||
| 'attention_mask': attention_mask | |||||
| } | |||||
| def __getitem__(self, idx): | |||||
| if self.reweight: | |||||
| rng = random.Random(idx) | |||||
| rng = np.random.RandomState( | |||||
| seed=[rng.randint(0, 2**32 - 1) for _ in range(16)]) | |||||
| dataset_idx = rng.choice( | |||||
| np.arange(len(self.datasets)), p=self.weights) | |||||
| dataset = self.datasets[dataset_idx] | |||||
| sample_idx = rng.choice(np.arange(len(dataset))) | |||||
| item = self.datasets[dataset_idx][sample_idx] | |||||
| else: | |||||
| dataset_idx = bisect_right(self.cumulative_lens, idx) | |||||
| if dataset_idx == 0: | |||||
| sample_idx = idx | |||||
| else: | |||||
| sample_idx = idx - self.cumulative_lens[dataset_idx - 1] | |||||
| item = self.datasets[dataset_idx][sample_idx] | |||||
| item = self.pet_wrapper(item) | |||||
| return item | |||||
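| # DataConfig holds script-level defaults for the data pipeline: apply() copies | |||||
| # any missing defaults onto `args` and then builds the loaders via make_loaders. | |||||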
| class DataConfig: | |||||
| def __init__(self, defaults=None): | |||||
| super(DataConfig, self).__init__() | |||||
| if defaults is None: | |||||
| defaults = {} | |||||
| self.defaults = defaults | |||||
| def apply(self, args, tokenizer): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('configuring data') | |||||
| self.apply_defaults(args) | |||||
| return make_loaders(args, tokenizer) | |||||
| def set_defaults(self, **kwargs): | |||||
| for k, v in kwargs.items(): | |||||
| self.defaults[k] = v | |||||
| def apply_defaults(self, args): | |||||
| for k, v in self.defaults.items(): | |||||
| k = k.replace('-', '_') | |||||
| if not hasattr(args, k): | |||||
| setattr(args, k, v) | |||||
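| # prepare_tokenizer builds the tokenizer, pads the vocabulary up to a multiple of | |||||
| # args.make_vocab_size_divisible_by on model-parallel rank 0, and broadcasts the | |||||
| # padded size and end-of-document token id so that args.vocab_size and | |||||
| # args.eod_token agree across model-parallel ranks. | |||||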
| def prepare_tokenizer(args): | |||||
| add_sentinel_token = 0 | |||||
| if args.sentinel_token: | |||||
| add_sentinel_token = args.max_position_embeddings | |||||
| tokenizer = make_tokenizer( | |||||
| args.tokenizer_type, | |||||
| None, | |||||
| args.tokenizer_path, | |||||
| args.vocab_size, | |||||
| args.tokenizer_model_type, | |||||
| add_block_symbols=args.block_lm, | |||||
| cache_dir=args.cache_dir, | |||||
| add_sentinel_token=add_sentinel_token, | |||||
| add_task_mask=args.task_mask, | |||||
| add_decoder_mask=args.block_mask_prob > 0.0 | |||||
| or args.context_mask_ratio > 0.0) | |||||
| if mpu.get_model_parallel_rank() == 0: | |||||
| num_tokens = tokenizer.num_tokens | |||||
| eod_token = tokenizer.get_command('eos').Id | |||||
| assert eod_token == tokenizer.get_command('pad').Id | |||||
| before = num_tokens | |||||
| after = before | |||||
| multiple = args.make_vocab_size_divisible_by | |||||
| while (after % multiple) != 0: | |||||
| after += 1 | |||||
| print_rank_0('> padded vocab (size: {}) with {} dummy ' | |||||
| 'tokens (new size: {})'.format(before, after - before, | |||||
| after)) | |||||
| print_rank_0('> found end-of-document token: {}'.format(eod_token)) | |||||
| token_counts = torch.cuda.LongTensor([after, eod_token]) | |||||
| else: | |||||
| token_counts = torch.cuda.LongTensor([0, 0]) | |||||
| # Broadcast num tokens. | |||||
| torch.distributed.broadcast( | |||||
| token_counts, | |||||
| mpu.get_model_parallel_src_rank(), | |||||
| group=mpu.get_model_parallel_group()) | |||||
| num_tokens = token_counts[0].item() | |||||
| eod_token = token_counts[1].item() | |||||
| args.vocab_size, args.eod_token = num_tokens, eod_token | |||||
| return tokenizer | |||||
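| # make_data_loader wraps a dataset with the distributed batch sampler for the | |||||
| # current data-parallel rank; when block_collate is set, batches are collated by | |||||
| # ConstructBlockStrategy so blank-infilling spans are constructed on the fly. | |||||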
| def make_data_loader(dataset, | |||||
| tokenizer, | |||||
| batch_size, | |||||
| num_iters, | |||||
| args, | |||||
| shuffle=False, | |||||
| block_collate=False): | |||||
| world_size = torch.distributed.get_world_size( | |||||
| group=mpu.get_data_parallel_group()) | |||||
| rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group()) | |||||
| if args.loader_scatter is not None: | |||||
| rank = rank // args.loader_scatter | |||||
| world_size = world_size // args.loader_scatter | |||||
| batch_size = batch_size // args.loader_scatter | |||||
| distributed = world_size > 1 | |||||
| if args.transformer_xl: | |||||
| batch_sampler = data_utils.samplers.DistributedSequentialSampler( | |||||
| len(dataset), num_iters, batch_size, rank, world_size) | |||||
| else: | |||||
| if shuffle: | |||||
| sampler = data_utils.samplers.RandomSampler( | |||||
| dataset, | |||||
| replacement=True, | |||||
| num_samples=batch_size * args.train_iters | |||||
| * args.gradient_accumulation_steps) | |||||
| else: | |||||
| sampler = torch.utils.data.SequentialSampler(dataset) | |||||
| drop_last = distributed | |||||
| # the GPUs in the same model parallel group receive the same data | |||||
| if distributed: | |||||
| batch_sampler = data_utils.samplers.DistributedBatchSampler( | |||||
| sampler, | |||||
| batch_size, | |||||
| drop_last, | |||||
| rank, | |||||
| world_size, | |||||
| gradient_accumulation_steps=args.gradient_accumulation_steps) | |||||
| else: | |||||
| batch_sampler = torch.utils.data.BatchSampler( | |||||
| sampler, batch_size, drop_last) | |||||
| collate_fn = None | |||||
| if block_collate: | |||||
| collate_fn = ConstructBlockStrategy( | |||||
| args, | |||||
| tokenizer, | |||||
| args.seq_length, | |||||
| bert_prob=args.bert_prob, | |||||
| gap_sentence_prob=args.gap_sentence_prob, | |||||
| gap_sentence_ratio=args.gap_sentence_ratio, | |||||
| gpt_infill_prob=args.gpt_infill_prob, | |||||
| average_block_length=args.avg_block_length, | |||||
| gpt_min_ratio=args.gpt_min_ratio, | |||||
| block_mask_prob=args.block_mask_prob, | |||||
| context_mask_ratio=args.context_mask_ratio, | |||||
| short_seq_prob=args.short_seq_prob, | |||||
| single_span_prob=args.single_span_prob, | |||||
| shuffle_blocks=not args.no_shuffle_block, | |||||
| block_position_encoding=not args.no_block_position, | |||||
| sentinel_token=args.sentinel_token, | |||||
| encoder_decoder=args.encoder_decoder, | |||||
| task_mask=args.task_mask, | |||||
| random_position=args.random_position, | |||||
| masked_lm=args.masked_lm).construct_blocks | |||||
| data_loader = torch.utils.data.DataLoader( | |||||
| dataset, | |||||
| batch_sampler=batch_sampler, | |||||
| num_workers=args.num_workers, | |||||
| pin_memory=True, | |||||
| collate_fn=collate_fn) | |||||
| return data_loader | |||||
| def make_tfrecord_loaders(args): | |||||
| """Load train/val/test dataset from shuffled TFRecords""" | |||||
| from .data_utils import tf_dl  # noqa: F401  (registers data_utils.tf_dl for the calls below) | |||||
| data_set_args = { | |||||
| 'batch_size': args.batch_size, | |||||
| 'max_seq_len': args.seq_length, | |||||
| 'max_preds_per_seq': args.max_preds_per_seq, | |||||
| 'train': True, | |||||
| 'num_workers': max(args.num_workers, 1), | |||||
| 'seed': args.seed + args.rank + 1, | |||||
| 'threaded_dl': args.num_workers > 0 | |||||
| } | |||||
| train = data_utils.tf_dl.TFRecordDataLoader(args.train_data, | |||||
| **data_set_args) | |||||
| data_set_args['train'] = False | |||||
| if args.eval_seq_length is not None: | |||||
| data_set_args['max_seq_len'] = args.eval_seq_length | |||||
| if args.eval_max_preds_per_seq is not None: | |||||
| data_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq | |||||
| valid = None | |||||
| if args.valid_data is not None: | |||||
| valid = data_utils.tf_dl.TFRecordDataLoader(args.valid_data, | |||||
| **data_set_args) | |||||
| test = None | |||||
| if args.test_data is not None: | |||||
| test = data_utils.tf_dl.TFRecordDataLoader(args.test_data, | |||||
| **data_set_args) | |||||
| tokenizer = data_utils.make_tokenizer( | |||||
| args.tokenizer_type, | |||||
| train, | |||||
| args.tokenizer_path, | |||||
| args.vocab_size, | |||||
| args.tokenizer_model_type, | |||||
| cache_dir=args.cache_dir) | |||||
| return (train, valid, test), tokenizer | |||||
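| # make_loaders builds the train/valid/test DataLoaders: it scales batch sizes by | |||||
| # the data-parallel world size, splits a single --train-data source according to | |||||
| # get_split(args), and builds valid/test from --valid-data / --test-data when | |||||
| # those paths are given. | |||||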
| def make_loaders(args, tokenizer): | |||||
| """makes training/val/test""" | |||||
| if args.use_tfrecords: | |||||
| return make_tfrecord_loaders(args) | |||||
| world_size = torch.distributed.get_world_size( | |||||
| group=mpu.get_data_parallel_group()) | |||||
| if args.loader_scatter is not None: | |||||
| assert world_size % args.loader_scatter == 0 | |||||
| batch_size = args.batch_size * world_size | |||||
| eval_batch_size = batch_size | |||||
| if args.eval_batch_size is not None: | |||||
| eval_batch_size = args.eval_batch_size * world_size | |||||
| seq_length = args.seq_length | |||||
| if seq_length < 0: | |||||
| seq_length = seq_length * world_size | |||||
| eval_seq_length = args.eval_seq_length | |||||
| if eval_seq_length is not None and eval_seq_length < 0: | |||||
| eval_seq_length = eval_seq_length * world_size | |||||
| split = get_split(args) | |||||
| data_set_args = { | |||||
| 'path': args.train_data, | |||||
| 'seq_length': seq_length, | |||||
| 'mem_length': args.mem_length, | |||||
| 'delim': args.delim, | |||||
| 'text_key': args.text_key, | |||||
| 'label_key': 'label', | |||||
| 'ds_type': args.data_set_type, | |||||
| 'split': split, | |||||
| 'loose': args.loose_json, | |||||
| 'max_preds_per_seq': args.max_preds_per_seq, | |||||
| 'presplit_sentences': args.presplit_sentences, | |||||
| 'sample_one_document': args.sample_one_document, | |||||
| 'filter_english': args.filter_english, | |||||
| 'pre_tokenize': not args.no_pre_tokenize, | |||||
| 'tokenizer': tokenizer, | |||||
| 'save_splits': args.save_splits, | |||||
| 'load_splits': args.load_splits, | |||||
| 'save_test_data': args.save_test_data, | |||||
| 'no_lazy_loader': args.no_lazy_loader, | |||||
| 'loader_scatter': args.loader_scatter, | |||||
| 'data_parallel_rank': mpu.get_data_parallel_rank(), | |||||
| 'non_sentence_start': args.non_sentence_start, | |||||
| 'half_lazy_loader': args.half_lazy_loader | |||||
| } | |||||
| eval_set_args = copy.copy(data_set_args) | |||||
| eval_set_args['split'] = [1.] | |||||
| # if optional eval args were set then replace their | |||||
| # equivalent values in the arg dict | |||||
| if eval_seq_length: | |||||
| eval_set_args['seq_length'] = eval_seq_length | |||||
| if args.eval_max_preds_per_seq: | |||||
| eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq | |||||
| if args.eval_text_key is not None: | |||||
| eval_set_args['text_key'] = args.eval_text_key | |||||
| # make datasets splits and tokenizer | |||||
| train, valid, test = None, None, None | |||||
| if args.train_data is not None: | |||||
| train = data_utils.make_dataset(**data_set_args) | |||||
| if data_utils.should_split(split): | |||||
| train, valid, test = train | |||||
| eval_set_args['tokenizer'] = tokenizer | |||||
| # make training and val dataset if necessary | |||||
| if valid is None and args.valid_data is not None: | |||||
| eval_set_args['path'] = args.valid_data | |||||
| valid = data_utils.make_dataset(**eval_set_args) | |||||
| eval_set_args['tokenizer'] = tokenizer | |||||
| if test is None and args.test_data is not None: | |||||
| eval_set_args['path'] = args.test_data | |||||
| test = data_utils.make_dataset(**eval_set_args) | |||||
| # wrap datasets with data loader | |||||
| use_block = args.block_lm or args.encoder_decoder | |||||
| if train is not None and args.batch_size > 0: | |||||
| train = make_data_loader( | |||||
| train, | |||||
| tokenizer, | |||||
| batch_size, | |||||
| args.train_iters, | |||||
| args, | |||||
| shuffle=args.shuffle, | |||||
| block_collate=use_block) | |||||
| args.do_train = True | |||||
| else: | |||||
| args.do_train = False | |||||
| eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size | |||||
| if valid is not None: | |||||
| valid = make_data_loader( | |||||
| valid, | |||||
| tokenizer, | |||||
| eval_batch_size, | |||||
| args.train_iters, | |||||
| args, | |||||
| shuffle=args.shuffle, | |||||
| block_collate=use_block) | |||||
| args.do_valid = True | |||||
| else: | |||||
| args.do_valid = False | |||||
| if test is not None: | |||||
| test = make_data_loader( | |||||
| test, | |||||
| tokenizer, | |||||
| eval_batch_size, | |||||
| len(test) // eval_batch_size + 1, | |||||
| args, | |||||
| shuffle=args.shuffle, | |||||
| block_collate=use_block) | |||||
| args.do_test = True | |||||
| else: | |||||
| args.do_test = False | |||||
| return train, valid, test | |||||
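| # build_multi_task_dataset wraps the SuperGLUE-style tasks listed in | |||||
| # args.multi_task_data into temperature-sampled MultiTaskDataset instances for | |||||
| # train and dev; it assumes a SuperGlueDataset wrapper is importable in this | |||||
| # module's scope. | |||||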
| def build_multi_task_dataset(args, tokenizer): | |||||
| task_dirs = { | |||||
| 'mnli': 'MNLI', | |||||
| 'cola': 'CoLA', | |||||
| 'mrpc': 'MRPC', | |||||
| 'qnli': 'QNLI', | |||||
| 'qqp': 'QQP', | |||||
| 'sst2': 'SST-2', | |||||
| 'agnews': 'Agnews', | |||||
| 'yelp-polarity': 'yelp_review_polarity_csv', | |||||
| 'yelp-full': 'yelp_review_full_csv', | |||||
| 'yahoo': 'Yahoo', | |||||
| 'squad': 'SQuAD', | |||||
| 'race': 'RACE' | |||||
| } | |||||
| train, valid = None, None | |||||
| if mpu.get_model_parallel_rank() == 0: | |||||
| multi_seq_length = args.seq_length | |||||
| if args.multi_seq_length is not None: | |||||
| multi_seq_length = args.multi_seq_length | |||||
| train_datasets, valid_datasets = [], [] | |||||
| for task in args.multi_task_data: | |||||
| task = task.lower() | |||||
| data_dir = os.path.join(args.data_dir, task_dirs[task]) | |||||
| train_datasets.append( | |||||
| SuperGlueDataset( | |||||
| args, | |||||
| task, | |||||
| data_dir, | |||||
| multi_seq_length, | |||||
| 'train', | |||||
| tokenizer, | |||||
| pattern_ensemble=True)) | |||||
| valid_datasets.append( | |||||
| SuperGlueDataset( | |||||
| args, | |||||
| task, | |||||
| data_dir, | |||||
| multi_seq_length, | |||||
| 'dev', | |||||
| tokenizer, | |||||
| pattern_ensemble=True)) | |||||
| train = MultiTaskDataset(args.multi_task_data, train_datasets) | |||||
| valid = MultiTaskDataset(args.multi_task_data, valid_datasets) | |||||
| world_size = torch.distributed.get_world_size( | |||||
| group=mpu.get_data_parallel_group()) | |||||
| multi_batch_size = args.batch_size * world_size | |||||
| if args.multi_batch_size is not None: | |||||
| multi_batch_size = args.multi_batch_size * world_size | |||||
| train = make_data_loader( | |||||
| train, | |||||
| tokenizer, | |||||
| multi_batch_size, | |||||
| args.train_iters, | |||||
| args, | |||||
| shuffle=True) | |||||
| valid = make_data_loader( | |||||
| valid, | |||||
| tokenizer, | |||||
| multi_batch_size, | |||||
| args.train_iters, | |||||
| args, | |||||
| shuffle=True) | |||||
| return train, valid | |||||
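| # get_split turns the --split argument into three fractions that sum to 1. For | |||||
| # example (sketch): '900,50,50' -> [0.9, 0.05, 0.05], and '0.8' -> [0.8, 0.2, 0.0] | |||||
| # because the remainder is appended as the validation share; shares are zeroed | |||||
| # out when explicit --valid-data / --test-data paths are supplied. | |||||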
| def get_split(args): | |||||
| """ | |||||
| Get dataset splits from a comma-separated string list | |||||
| """ | |||||
| splits = [] | |||||
| if args.split.find(',') != -1: | |||||
| splits = [float(s) for s in args.split.split(',')] | |||||
| elif args.split.find('/') != -1: | |||||
| splits = [float(s) for s in args.split.split('/')] | |||||
| else: | |||||
| splits = [float(args.split)] | |||||
| split_total = sum(splits) | |||||
| if split_total < 1.: | |||||
| splits.append(1 - split_total) | |||||
| while len(splits) < 3: | |||||
| splits.append(0.) | |||||
| splits = splits[:3] | |||||
| if args.valid_data is not None: | |||||
| splits[1] = 0. | |||||
| if args.test_data is not None: | |||||
| splits[2] = 0. | |||||
| final_sum = sum(splits) | |||||
| return [s / final_sum for s in splits] | |||||
| def configure_data(): | |||||
| """add cmdline flags for configuring datasets""" | |||||
| # These are options that are used by data_utils, but are either | |||||
| # deprecated or not meant to be exposed to the command line user. | |||||
| # These options are intended to be set in code by specific scripts. | |||||
| defaults = { | |||||
| 'world_size': 1, | |||||
| 'rank': -1, | |||||
| 'persist_state': 0, | |||||
| 'lazy': False, | |||||
| 'transpose': False, | |||||
| 'data_set_type': 'supervised', | |||||
| 'seq_length': 256, | |||||
| 'eval_seq_length': 256, | |||||
| 'samples_per_shard': 100 | |||||
| } | |||||
| return DataConfig(defaults=defaults) | |||||
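| # Typical wiring (a sketch only; the argument values are illustrative and `args` | |||||
| # is assumed to come from the training argument parser): | |||||
| #   data_config = configure_data() | |||||
| #   data_config.set_defaults(data_set_type='Block', transpose=False) | |||||
| #   tokenizer = prepare_tokenizer(args) | |||||
| #   train_loader, valid_loader, test_loader = data_config.apply(args, tokenizer) | |||||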
| @@ -0,0 +1,341 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """utils for creating datasets""" | |||||
| import math | |||||
| import os | |||||
| import random | |||||
| import time | |||||
| import torch | |||||
| from . import corpora | |||||
| from .datasets import (BertSentencepairDataset, BlockDataset, ConcatDataset, | |||||
| GPT2Dataset, ShuffleDataset, SplitDataset, XLDataset, | |||||
| split_ds) | |||||
| from .lazy_loader import (LazyLoader, LazyWriter, exists_lazy, exists_scatter, | |||||
| get_scatter_path) | |||||
| from .samplers import DistributedBatchSampler | |||||
| from .tokenization import (BertWordPieceTokenizer, CharacterLevelTokenizer, | |||||
| CommandToken, GPT2BPETokenizer, Tokenization, | |||||
| Tokenizer, make_tokenizer) | |||||
| TRAIN_DATA = 0 | |||||
| VAL_DATA = 1 | |||||
| TEST_DATA = 2 | |||||
| def should_split(split): | |||||
| """ | |||||
| given split proportions checks if should split | |||||
| Examples: | |||||
| >>> should_split([10,0,0]) | |||||
| False | |||||
| >>> should_split([1,.1,.2]) | |||||
| True | |||||
| """ | |||||
| return max(split) / sum(split) != 1. | |||||
| def get_ext(path): | |||||
| """gets path extension""" | |||||
| return os.path.splitext(path)[1] | |||||
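| # get_dataset resolves a named corpus from corpora.NAMED_CORPORA, tokenizes it | |||||
| # once into a lazy on-disk cache (LazyWriter/LazyLoader), optionally shards the | |||||
| # cache into loader_scatter pieces so each rank memory-maps only its own shard, | |||||
| # and returns a PromptDataset or KeyDataset view over the loaders. | |||||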
| def get_dataset(name, | |||||
| tokenizer, | |||||
| pre_tokenize, | |||||
| data_parallel_rank, | |||||
| loader_scatter=None, | |||||
| no_lazy_loader=False, | |||||
| half_lazy_loader=False): | |||||
| """gets dataset object based on keyword args and file at `path`""" | |||||
| global_rank = torch.distributed.get_rank() | |||||
| if not supported_corpus(name): | |||||
| raise NotImplementedError('dataset %s is not supported' % name) | |||||
| dataset = corpora.NAMED_CORPORA[name] | |||||
| path = dataset.PATH | |||||
| if issubclass(dataset, corpora.PromptReader): | |||||
| if not (exists_lazy(path, data_type='prompt') | |||||
| and exists_lazy(path, data_type='text')) and not ( | |||||
| loader_scatter is not None and exists_scatter( | |||||
| path, data_type='prompt', scatter_num=loader_scatter) | |||||
| and exists_scatter( | |||||
| path, data_type='text', scatter_num=loader_scatter)): | |||||
| # create cached version of dataset for lazy loading if it doesn't exist | |||||
| if global_rank == 0: | |||||
| print(f'Creating lazy loader for dataset {name}') | |||||
| prompt_writer = LazyWriter( | |||||
| path, data_type='prompt', is_array=pre_tokenize) | |||||
| text_writer = LazyWriter( | |||||
| path, data_type='text', is_array=pre_tokenize) | |||||
| writers = {'prompt': prompt_writer, 'text': text_writer} | |||||
| reader = dataset( | |||||
| writers=writers, | |||||
| tokenizer=tokenizer, | |||||
| tokenize=pre_tokenize) | |||||
| reader.process() | |||||
| prompt_writer.close() | |||||
| text_writer.close() | |||||
| else: | |||||
| while not os.path.exists( | |||||
| LazyWriter.get_len_path(path, data_type='prompt')): | |||||
| time.sleep(1) | |||||
| map_fn = (lambda x: x.tolist()) if pre_tokenize else None | |||||
| if loader_scatter is not None: | |||||
| if not (exists_scatter( | |||||
| path, data_type='prompt', scatter_num=loader_scatter) | |||||
| and exists_scatter( | |||||
| path, data_type='text', scatter_num=loader_scatter)): | |||||
| if global_rank == 0: | |||||
| print(f'Creating scatter loader for dataset {name}') | |||||
| prompts = LazyLoader( | |||||
| path, | |||||
| data_type='prompt', | |||||
| map_fn=map_fn, | |||||
| mem_map=True, | |||||
| is_array=pre_tokenize) | |||||
| texts = LazyLoader( | |||||
| path, | |||||
| data_type='text', | |||||
| map_fn=map_fn, | |||||
| mem_map=True, | |||||
| is_array=pre_tokenize) | |||||
| indices = list(range(len(texts))) | |||||
| random.shuffle(indices) | |||||
| segment_length = (len(indices) - 1) // loader_scatter + 1 | |||||
| for i in range(loader_scatter): | |||||
| scatter_path = get_scatter_path(path, scatter_rank=i) | |||||
| prompt_writer = LazyWriter( | |||||
| scatter_path, | |||||
| data_type='prompt', | |||||
| is_array=pre_tokenize) | |||||
| text_writer = LazyWriter( | |||||
| scatter_path, | |||||
| data_type='text', | |||||
| is_array=pre_tokenize) | |||||
| for idx in indices[i * segment_length:(i + 1) | |||||
| * segment_length]: | |||||
| prompt_writer.write(prompts[idx]) | |||||
| text_writer.write(texts[idx]) | |||||
| prompt_writer.close() | |||||
| text_writer.close() | |||||
| else: | |||||
| while not (exists_scatter( | |||||
| path, data_type='prompt', | |||||
| scatter_num=loader_scatter) and exists_scatter( | |||||
| path, | |||||
| data_type='text', | |||||
| scatter_num=loader_scatter)): | |||||
| time.sleep(1) | |||||
| scatter_path = get_scatter_path( | |||||
| path, scatter_rank=data_parallel_rank % loader_scatter) | |||||
| print(f'Rank {global_rank} is using scatter from {scatter_path}') | |||||
| prompts = LazyLoader( | |||||
| scatter_path, | |||||
| data_type='prompt', | |||||
| map_fn=map_fn, | |||||
| mem_map=True, | |||||
| is_array=pre_tokenize, | |||||
| load_memory=no_lazy_loader, | |||||
| half_load=half_lazy_loader) | |||||
| texts = LazyLoader( | |||||
| scatter_path, | |||||
| data_type='text', | |||||
| map_fn=map_fn, | |||||
| mem_map=True, | |||||
| is_array=pre_tokenize, | |||||
| load_memory=no_lazy_loader, | |||||
| half_load=half_lazy_loader) | |||||
| else: | |||||
| prompts = LazyLoader( | |||||
| path, | |||||
| data_type='prompt', | |||||
| map_fn=map_fn, | |||||
| mem_map=True, | |||||
| is_array=pre_tokenize, | |||||
| load_memory=no_lazy_loader, | |||||
| half_load=half_lazy_loader) | |||||
| texts = LazyLoader( | |||||
| path, | |||||
| data_type='text', | |||||
| map_fn=map_fn, | |||||
| mem_map=True, | |||||
| is_array=pre_tokenize, | |||||
| load_memory=no_lazy_loader, | |||||
| half_load=half_lazy_loader) | |||||
| text = corpora.PromptDataset( | |||||
| prompt_loader=prompts, | |||||
| text_loader=texts, | |||||
| tokenizer=tokenizer, | |||||
| to_tokenize=not pre_tokenize) | |||||
| if loader_scatter is None: | |||||
| if global_rank == 0: | |||||
| print(f'Create dataset {name} with {len(text)} documents') | |||||
| for i in range(10): | |||||
| rand_id = i if i < 5 else random.randrange(len(text)) | |||||
| sample_tokens = text[rand_id]['tokens'][:1024] | |||||
| print(sample_tokens) | |||||
| print(tokenizer.DecodeIds(sample_tokens).encode('utf-8')) | |||||
| else: | |||||
| for scatter_id in range(loader_scatter): | |||||
| if data_parallel_rank % loader_scatter == scatter_id and data_parallel_rank // loader_scatter == 0: | |||||
| print( | |||||
| f'Create dataset {name} at scatter {scatter_id} with {len(text)} documents' | |||||
| ) | |||||
| for i in range(10): | |||||
| sample_tokens = text[i]['tokens'][:1024] | |||||
| print(sample_tokens) | |||||
| print(tokenizer.DecodeIds(sample_tokens)) | |||||
| torch.distributed.barrier() | |||||
| return text | |||||
| elif issubclass(dataset, corpora.KeyReader): | |||||
| if not (exists_lazy(path, data_type='text') | |||||
| and exists_lazy(path, data_type='mask')): | |||||
| # create cached version of dataset for lazy loading if it doesn't exist | |||||
| if global_rank == 0: | |||||
| text_writer = LazyWriter( | |||||
| path, data_type='text', is_array=pre_tokenize) | |||||
| mask_writer = LazyWriter(path, data_type='mask', is_array=True) | |||||
| writers = {'mask': mask_writer, 'text': text_writer} | |||||
| dataset( | |||||
| writers=writers, | |||||
| tokenizer=tokenizer, | |||||
| tokenize=pre_tokenize) | |||||
| mask_writer.close() | |||||
| text_writer.close() | |||||
| else: | |||||
| while not os.path.exists( | |||||
| LazyWriter.get_len_path(path, data_type='mask')): | |||||
| time.sleep(1) | |||||
| map_fn = (lambda x: x.tolist()) if pre_tokenize else None | |||||
| masks = LazyLoader( | |||||
| path, data_type='mask', map_fn=map_fn, mem_map=True, is_array=True) | |||||
| texts = LazyLoader( | |||||
| path, | |||||
| data_type='text', | |||||
| map_fn=map_fn, | |||||
| mem_map=True, | |||||
| is_array=pre_tokenize) | |||||
| text = corpora.KeyDataset( | |||||
| mask_loader=masks, | |||||
| text_loader=texts, | |||||
| tokenizer=tokenizer, | |||||
| to_tokenize=not pre_tokenize) | |||||
| return text | |||||
| def supported_corpus(corpus_name): | |||||
| """checks if corpus name is defined in `corpora.py`""" | |||||
| return corpus_name in corpora.NAMED_CORPORA | |||||
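| # make_dataset is the entry point used by make_loaders above: it fetches one or | |||||
| # more named corpora, concatenates them, optionally splits them into | |||||
| # train/val/test with split_ds, and wraps each split for the requested ds_type | |||||
| # ('bert', 'gpt-xl', 'gpt2' or 'block'). | |||||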
| def make_dataset(path, | |||||
| seq_length, | |||||
| mem_length, | |||||
| shuffle=True, | |||||
| split=None, | |||||
| tokenizer=None, | |||||
| sample_one_document=False, | |||||
| pre_tokenize=False, | |||||
| ds_type='', | |||||
| save_splits=None, | |||||
| load_splits=None, | |||||
| save_test_data=None, | |||||
| no_lazy_loader=False, | |||||
| loader_scatter=None, | |||||
| data_parallel_rank=None, | |||||
| filter_english=False, | |||||
| non_sentence_start=0.0, | |||||
| half_lazy_loader=False, | |||||
| **kwargs): | |||||
| """function to create datasets+tokenizers for common options""" | |||||
| if split is None: | |||||
| split = [1.] | |||||
| # get one or multiple datasets and concatenate | |||||
| if isinstance(path, str): | |||||
| ds = get_dataset( | |||||
| path, | |||||
| tokenizer=tokenizer, | |||||
| pre_tokenize=pre_tokenize, | |||||
| no_lazy_loader=no_lazy_loader, | |||||
| loader_scatter=loader_scatter, | |||||
| data_parallel_rank=data_parallel_rank, | |||||
| half_lazy_loader=half_lazy_loader) | |||||
| else: | |||||
| ds = [ | |||||
| get_dataset( | |||||
| p, | |||||
| tokenizer=tokenizer, | |||||
| pre_tokenize=pre_tokenize, | |||||
| no_lazy_loader=no_lazy_loader, | |||||
| loader_scatter=loader_scatter, | |||||
| data_parallel_rank=data_parallel_rank, | |||||
| half_lazy_loader=half_lazy_loader) for p in path | |||||
| ] | |||||
| ds = ConcatDataset(ds) | |||||
| # Split dataset into train/val/test (and wrap bert dataset) | |||||
| def wrap_dataset(dataset): | |||||
| if ds_type.lower() == 'bert': | |||||
| presplit_sentences = kwargs[ | |||||
| 'presplit_sentences'] if 'presplit_sentences' in kwargs else False | |||||
| dataset = BertSentencepairDataset( | |||||
| dataset, | |||||
| max_seq_len=seq_length, | |||||
| presplit_sentences=presplit_sentences) | |||||
| elif ds_type.lower() == 'gpt-xl': | |||||
| assert pre_tokenize | |||||
| dataset = XLDataset( | |||||
| dataset, | |||||
| tokenizer, | |||||
| max_seq_len=seq_length, | |||||
| mem_len=mem_length, | |||||
| sample_across_doc=not sample_one_document) | |||||
| elif ds_type.lower() == 'gpt2': | |||||
| dataset = GPT2Dataset( | |||||
| dataset, | |||||
| tokenizer, | |||||
| max_seq_len=seq_length, | |||||
| sample_across_doc=not sample_one_document) | |||||
| elif ds_type.lower() == 'block': | |||||
| dataset = BlockDataset( | |||||
| dataset, | |||||
| tokenizer, | |||||
| max_seq_len=seq_length, | |||||
| sample_across_doc=not sample_one_document, | |||||
| filter_english=filter_english, | |||||
| non_sentence_start=non_sentence_start) | |||||
| return dataset | |||||
| if should_split(split): | |||||
| ds = split_ds( | |||||
| ds, | |||||
| split, | |||||
| shuffle=shuffle, | |||||
| save_splits=save_splits, | |||||
| load_splits=load_splits) | |||||
| if save_test_data is not None and torch.distributed.get_rank() == 0: | |||||
| test_ds = ds[-1] | |||||
| with open(save_test_data, 'w', encoding='utf-8') as output: | |||||
| for data in test_ds: | |||||
| text = data['tokens'] | |||||
| text = tokenizer.DecodeIds(text) | |||||
| output.write(text) | |||||
| output.write('\n') | |||||
| print(f'Write test data to {save_test_data}') | |||||
| ds = [wrap_dataset(d) if d is not None else None for d in ds] | |||||
| else: | |||||
| ds = wrap_dataset(ds) | |||||
| return ds | |||||
| @@ -0,0 +1,583 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """several datasets with preset arguments""" | |||||
| import os | |||||
| import random | |||||
| from collections import defaultdict | |||||
| from multiprocessing import Process, Queue | |||||
| from queue import Empty | |||||
| import json | |||||
| import tqdm | |||||
| from torch.utils import data | |||||
| from modelscope.models.nlp.mglm.utils import print_rank_0 | |||||
| from .datasets import csv_dataset, json_dataset | |||||
| from .lazy_loader import LazyLoader | |||||
| NUM_PROCESSES = 100 | |||||
| def punctuation_standardization(string: str): | |||||
| punctuation_dict = { | |||||
| '\u201c': "\"", | |||||
| '\u201d': "\"", | |||||
| '\u2019': "'", | |||||
| '\u2018': "'", | |||||
| '\u2013': '-' | |||||
| } | |||||
| for key, value in punctuation_dict.items(): | |||||
| string = string.replace(key, value) | |||||
| return string | |||||
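| # e.g. punctuation_standardization('\u201cGLM\u201d \u2013 blank infilling') | |||||
| # returns '"GLM" - blank infilling' (curly quotes and en-dashes become ASCII). | |||||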
| class KeyDataset(data.Dataset): | |||||
| def __init__(self, text_loader, mask_loader, **kwargs): | |||||
| self.texts = text_loader | |||||
| self.masks = mask_loader | |||||
| self.is_lazy = False | |||||
| if isinstance(self.texts, LazyLoader) and isinstance( | |||||
| self.masks, LazyLoader): | |||||
| self.text_lens = self.texts.lens | |||||
| self.is_lazy = True | |||||
| def get_text_len(self, idx): | |||||
| return self.text_lens[idx] | |||||
| def __getitem__(self, index): | |||||
| text = self.texts[index] | |||||
| mask_length = self.masks[index] | |||||
| mask = [] | |||||
| for i, length in enumerate(mask_length): | |||||
| if i % 2 == 0: | |||||
| mask += [0] * length | |||||
| else: | |||||
| mask += [1] * length | |||||
| assert len(text) == len(mask) | |||||
| return {'tokens': text, 'loss_masks': mask} | |||||
| def __len__(self): | |||||
| return len(self.texts) | |||||
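| # PromptDataset yields one document per item: prompt and text token ids are | |||||
| # concatenated, with a loss mask of 0 over the prompt and 1 over the text so | |||||
| # that only the text span contributes to the language-modelling loss. | |||||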
| class PromptDataset(data.Dataset): | |||||
| def __init__(self, | |||||
| prompt_loader, | |||||
| text_loader, | |||||
| tokenizer=None, | |||||
| to_tokenize=False, | |||||
| **kwargs): | |||||
| self.prompts = prompt_loader | |||||
| self.texts = text_loader | |||||
| self.tokenizer = tokenizer | |||||
| self.to_tokenize = to_tokenize | |||||
| if isinstance(self.prompts, LazyLoader) and isinstance( | |||||
| self.texts, LazyLoader): | |||||
| self.prompt_lens = self.prompts.lens | |||||
| self.text_lens = self.texts.lens | |||||
| self.is_lazy = True | |||||
| def get_text_len(self, idx): | |||||
| return self.prompt_lens[idx] + self.text_lens[idx] | |||||
| def __getitem__(self, index): | |||||
| prompt = self.prompts[index] | |||||
| text = self.texts[index] | |||||
| if self.to_tokenize: | |||||
| prompt = self.tokenizer.EncodeAsIds(prompt).tokenization | |||||
| text = self.tokenizer.EncodeAsIds(text).tokenization | |||||
| return { | |||||
| 'tokens': prompt + text, | |||||
| 'loss_masks': [0] * len(prompt) + [1] * len(text) | |||||
| } | |||||
| def __len__(self): | |||||
| return len(self.prompts) | |||||
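| # DataReader and its subclasses implement the offline tokenization pipeline: a | |||||
| # reader process streams raw records into task_queue, NUM_PROCESSES workers | |||||
| # tokenize them, and the parent drains done_queue and writes the resulting | |||||
| # (prompt, text) pairs through the LazyWriter objects passed in as `writers`. | |||||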
| class DataReader: | |||||
| PATH = None | |||||
| assert_str = None | |||||
| reserve_punct = False | |||||
| split_row = True | |||||
| TASK_QUEUE_LIMIT = 10000000 | |||||
| DONE_QUEUE_LIMIT = 10000000 | |||||
| def tokenize_worker(self, input, output, info, tokenizer, tokenize): | |||||
| raise NotImplementedError | |||||
| def print_info(self, info): | |||||
| pass | |||||
| def __init__(self, writers, tokenizer=None, tokenize=False, **kwargs): | |||||
| print(self.PATH) | |||||
| print(self.assert_str) | |||||
| assert os.path.exists(self.PATH), self.assert_str | |||||
| print_rank_0(f'Creating dataset from {self.PATH}') | |||||
| self.tokenizer = tokenizer | |||||
| self.tokenize = tokenize | |||||
| self.writers = writers | |||||
| def process(self): | |||||
| if os.path.isdir(self.PATH): | |||||
| paths = [ | |||||
| os.path.join(top, name) for top, _, names in os.walk(self.PATH) | |||||
| for name in names | |||||
| ] | |||||
| # paths = [entry.path for entry in os.scandir(self.PATH) if | |||||
| # not entry.is_dir() and not entry.name.endswith("bz2")] | |||||
| else: | |||||
| paths = [self.PATH] | |||||
| task_queue, done_queue, info_queue = Queue( | |||||
| maxsize=self.TASK_QUEUE_LIMIT), Queue( | |||||
| maxsize=self.DONE_QUEUE_LIMIT), Queue() | |||||
| processes = [] | |||||
| for i in range(NUM_PROCESSES): | |||||
| process = Process( | |||||
| target=self.tokenize_worker, | |||||
| args=(task_queue, done_queue, info_queue, self.tokenizer, | |||||
| self.tokenize)) | |||||
| process.start() | |||||
| processes.append(process) | |||||
| def read_input_to_queue(): | |||||
| for path in paths: | |||||
| print_rank_0(f'Start reading {path}') | |||||
| with open(path) as file: | |||||
| items = json.load(file) | |||||
| for item in items: | |||||
| task_queue.put(item) | |||||
| # if self.split_row: | |||||
| # for row in file: | |||||
| # task_queue.put(row) | |||||
| # else: | |||||
| # items = json.load(file) | |||||
| # for item in items["RECORDS"]: | |||||
| # task_queue.put(item) | |||||
| print_rank_0('Read input complete') | |||||
| for i in range(len(processes)): | |||||
| task_queue.put('STOP') | |||||
| process = Process(target=read_input_to_queue) | |||||
| process.start() | |||||
| count = len(processes) | |||||
| progress_bar = tqdm.tqdm() | |||||
| while True: | |||||
| data = done_queue.get() | |||||
| if data == 'COMPLETE': | |||||
| count -= 1 | |||||
| if count == 0: | |||||
| break | |||||
| else: | |||||
| self.write_result(data, self.writers) | |||||
| progress_bar.update() | |||||
| progress_bar.close() | |||||
| self.print_info(info_queue) | |||||
| @staticmethod | |||||
| def write_result(data, writers): | |||||
| raise NotImplementedError | |||||
| @staticmethod | |||||
| def get_token_count(contents): | |||||
| return sum(map(len, contents)) | |||||
| @classmethod | |||||
| def process_sample(cls, text, tokenizer, tokenize): | |||||
| if isinstance(text, str) and tokenize: | |||||
| if not cls.reserve_punct: | |||||
| text = punctuation_standardization(text) | |||||
| text = tokenizer.EncodeAsIds(text).tokenization if text else [] | |||||
| return text | |||||
| @staticmethod | |||||
| def trim_field(content, max_length): | |||||
| if len(content) > max_length: | |||||
| content = content[:max_length] | |||||
| content += '......' | |||||
| return content | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| raise NotImplementedError | |||||
| class PromptReader(DataReader): | |||||
| is_json = True | |||||
| def tokenize_worker(self, input, output, info, tokenizer, tokenize): | |||||
| for row in iter(input.get, 'STOP'): | |||||
| if row: | |||||
| if self.is_json: | |||||
| row = row.rstrip() | |||||
| row = json.loads(row) | |||||
| prompts, texts = self.process_line(row, tokenizer, tokenize) | |||||
| for prompt, text in zip(prompts, texts): | |||||
| output.put((prompt, text)) | |||||
| output.put('COMPLETE') | |||||
| @staticmethod | |||||
| def write_result(data, writers): | |||||
| prompt, text = data | |||||
| writers['prompt'].write(prompt) | |||||
| writers['text'].write(text) | |||||
| class KeyReader(DataReader): | |||||
| PATH = '/root/data/wikipedia/wiki-key.txt' | |||||
| assert_str = 'make sure to set PATH for wikipedia data_utils/corpora.py' | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| keys, contents = data['key'], data['content'] | |||||
| assert len(keys) == len(contents) | |||||
| for i in range(1, len(keys)): | |||||
| keys[i] = ' ' + keys[i] | |||||
| contents = [' ' + content for content in contents] | |||||
| keys = [tokenizer.EncodeAsIds(key).tokenization for key in keys] | |||||
| contents = [ | |||||
| tokenizer.EncodeAsIds(content).tokenization for content in contents | |||||
| ] | |||||
| summary = sum(keys, []) | |||||
| summary_prefix = self.process_sample('Summary: ', tokenizer, tokenize) | |||||
| summary_mask = [len(summary_prefix), len(summary)] | |||||
| summary = summary_prefix + summary | |||||
| text, text_mask = [], [] | |||||
| for key, content in zip(keys, contents): | |||||
| content = content + [tokenizer.get_command('eop').Id] | |||||
| text += key | |||||
| text += content | |||||
| text_mask.append(len(key)) | |||||
| text_mask.append(len(content)) | |||||
| return (summary, summary_mask), (text, text_mask) | |||||
| def tokenize_worker(self, input, output, info, tokenizer, tokenize): | |||||
| for row in iter(input.get, 'STOP'): | |||||
| data = json.loads(row) | |||||
| summary, content = self.process_line(data, tokenizer, tokenize) | |||||
| output.put((summary, content)) | |||||
| output.put('COMPLETE') | |||||
| @staticmethod | |||||
| def write_result(data, writers): | |||||
| summary, content = data | |||||
| writers['text'].write(summary[0]) | |||||
| writers['mask'].write(summary[1]) | |||||
| writers['text'].write(content[0]) | |||||
| writers['mask'].write(content[1]) | |||||
| class zhihu(PromptReader): | |||||
| PATH = '/dataset/fd5061f6/data/tokenize_data/zhihu.lazy' | |||||
| reserve_punct = True | |||||
| assert_str = 'make sure to set PATH for zhihu data_utils/corpora.py' | |||||
| qtitle_prefix = '问题:' | |||||
| qcontent_prefix = '问题描述:' | |||||
| user_prefix = '回答用户:' | |||||
| answer_prefix = ' 回答:' | |||||
| # qtitle_prefix = [] | |||||
| # qcontent_prefix = [] | |||||
| # user_prefix = [] | |||||
| # answer_prefix = [] | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| prompts, texts = [], [] | |||||
| ans_length = len(data.get('ans-content', '')) | |||||
| ans_up = data.get('ans-up-num', '') | |||||
| ans_up = int(ans_up) if ans_up else 0 | |||||
| if ans_length > 100 or ans_up > 1000: | |||||
| qtitle = data['q_title'] | |||||
| qcontent = data['q-content'] | |||||
| if qcontent is None: | |||||
| qcontent = '' | |||||
| qcontent = self.trim_field(qcontent, max_length=100) | |||||
| user = data.get('user-signature', '') | |||||
| prompt = self.qtitle_prefix + qtitle + self.qcontent_prefix + qcontent + self.user_prefix + user + self.answer_prefix # noqa | |||||
| text = data['ans-content'] | |||||
| prompt, text = self.process_sample(prompt, tokenizer, | |||||
| tokenize), self.process_sample( | |||||
| text, tokenizer, tokenize) | |||||
| prompts.append(prompt) | |||||
| texts.append(text) | |||||
| # prompt = data["q_title"] + data["q-content"] + data["user-signature"] | |||||
| # text = data["ans-content"] | |||||
| # prompts.append(prompt) | |||||
| # texts.append(text) | |||||
| return prompts, texts | |||||
| class zhidao(PromptReader): | |||||
| PATH = '/root/data/zhidao/zhidao' | |||||
| reserve_punct = True | |||||
| assert_str = 'make sure to set PATH for zhidao data_utils/corpora.py' | |||||
| qtitle_prefix = '问题:' | |||||
| qcontent_prefix = '问题描述:' | |||||
| answer_prefix = '回答:' | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| if 'title' not in data: | |||||
| return [], [] | |||||
| prompts, texts = [], [] | |||||
| qtitle = data['title'] | |||||
| qcontent = data.get('content', '') | |||||
| qcontent = self.trim_field(qcontent, max_length=100) | |||||
| prompt = self.qtitle_prefix + qtitle + self.qcontent_prefix + qcontent + self.answer_prefix | |||||
| prompt = self.process_sample(prompt, tokenizer, tokenize) | |||||
| if 'best_answer' in data: | |||||
| text = data['best_answer']['content'] | |||||
| if len(text) > 10: | |||||
| text = self.process_sample(text, tokenizer, tokenize) | |||||
| prompts.append(prompt) | |||||
| texts.append(text) | |||||
| for answer in data.get('other_answers', []): | |||||
| text = answer['content'] | |||||
| if len(text) > 100: | |||||
| text = self.process_sample(text, tokenizer, tokenize) | |||||
| prompts.append(prompt) | |||||
| texts.append(text) | |||||
| return prompts, texts | |||||
| class baike(PromptReader): | |||||
| PATH = '/dataset/fd5061f6/data/tokenize_data/baike.lazy' | |||||
| reserve_punct = True | |||||
| assert_str = 'make sure to set PATH for baike data_utils/corpora.py' | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| prompts, texts = [], [] | |||||
| text = data.get('title', '') + data.get('abstract', '') + data.get( | |||||
| 'content', '') | |||||
| if text: | |||||
| p, t = self.process_sample('', tokenizer, | |||||
| tokenize), self.process_sample( | |||||
| text, tokenizer, tokenize) | |||||
| prompts.append(p) | |||||
| texts.append(t) | |||||
| return prompts, texts | |||||
| class wikipedia(PromptReader): | |||||
| """ | |||||
| dataset for wikipedia with arguments configured for convenience | |||||
| command line usage: `--train-data wikipedia` | |||||
| """ | |||||
| # PATH = '/dataset/data/wiki.txt' | |||||
| PATH = '/root/data/bert_data/wiki.txt' | |||||
| assert_str = 'make sure to set PATH for wikipedia data_utils/corpora.py' | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| text = data['text'] | |||||
| prompt, text = self.process_sample('', tokenizer, | |||||
| tokenize), self.process_sample( | |||||
| text, tokenizer, tokenize) | |||||
| return [prompt], [text] | |||||
| class TestDataset(PromptReader): | |||||
| PATH = '/root/data/test.json' | |||||
| assert_str = 'make sure to set PATH for the test dataset in data_utils/corpora.py' | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| prompt, text = data['prompt'], data['text'] | |||||
| prompt, text = self.process_sample(prompt, tokenizer, | |||||
| tokenize), self.process_sample( | |||||
| text, tokenizer, tokenize) | |||||
| return [prompt], [text] | |||||
| class OpenWebText(PromptReader): | |||||
| PATH = '/dataset/fd5061f6/english_data/openwebtext2' | |||||
| assert_str = 'make sure to set PATH for openwebtext data_utils/corpora.py' | |||||
| def __init__(self, *args, **kwargs): | |||||
| import fasttext | |||||
| super().__init__(*args, **kwargs) | |||||
| self.model = fasttext.load_model( | |||||
| '/dataset/fd5061f6/english_data/lid.176.bin') | |||||
| print_rank_0('Load language detection model') | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| text = data['text'] | |||||
| if len(text) > 100: | |||||
| lang = self.model.predict(text.replace('\n', ''))[0][0] | |||||
| if lang == '__label__en': | |||||
| prompt, text = self.process_sample( | |||||
| '', tokenizer, | |||||
| tokenize), self.process_sample(text, tokenizer, tokenize) | |||||
| return [prompt], [text] | |||||
| return [], [] | |||||
| class CCNews(PromptReader): | |||||
| PATH = '/mnt/cc_news.json' | |||||
| assert_str = 'make sure to set PATH for cc-news data_utils/corpora.py' | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| text = '' | |||||
| title = data.get('title', None) | |||||
| description = data.get('description', None) | |||||
| maintext = data.get('maintext', None) | |||||
| if title: | |||||
| text += title.strip() + ' ' | |||||
| if description and (not maintext | |||||
| or not maintext.startswith(description)): | |||||
| text += description.strip() + ' ' | |||||
| if maintext: | |||||
| text += maintext | |||||
| if len(text) > 100: | |||||
| prompt, text = self.process_sample('', tokenizer, | |||||
| tokenize), self.process_sample( | |||||
| text, tokenizer, tokenize) | |||||
| return [prompt], [text] | |||||
| else: | |||||
| return [], [] | |||||
| class BertData(PromptReader): | |||||
| is_json = False | |||||
| PATH = '/dataset/fd5061f6/english_data/wikibook' | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| if data: | |||||
| prompt, text = '', data | |||||
| prompt, text = self.process_sample(prompt, tokenizer, | |||||
| tokenize), self.process_sample( | |||||
| text, tokenizer, tokenize) | |||||
| return [prompt], [text] | |||||
| else: | |||||
| return [], [] | |||||
| class Pile(PromptReader): | |||||
| is_json = True | |||||
| PATH = '/mnt/train' | |||||
| filtered_sources = [ | |||||
| 'Github', 'StackExchange', 'DM Mathematics', 'Ubuntu IRC', 'EuroParl', | |||||
| 'YoutubeSubtitles', 'Enron Emails' | |||||
| ] | |||||
| downsample_sources = {'PubMed Central': 0.3, 'ArXiv': 0.3, 'FreeLaw': 0.3} | |||||
| def print_info(self, info): | |||||
| total_dict = defaultdict(int) | |||||
| while True: | |||||
| try: | |||||
| source_dict = info.get(block=False) | |||||
| for source, length in source_dict.items(): | |||||
| total_dict[source] += length | |||||
| except Empty: | |||||
| break | |||||
| print_rank_0(total_dict) | |||||
| def tokenize_worker(self, input, output, info, tokenizer, tokenize): | |||||
| source_dict = defaultdict(int) | |||||
| for row in iter(input.get, 'STOP'): | |||||
| row = row.rstrip() | |||||
| if row: | |||||
| if self.is_json: | |||||
| row = json.loads(row) | |||||
| prompts, texts, source = self.process_line( | |||||
| row, tokenizer, tokenize) | |||||
| length = 0 | |||||
| for prompt, text in zip(prompts, texts): | |||||
| length += len(text) | |||||
| output.put((prompt, text)) | |||||
| if source: | |||||
| source_dict[source] += length | |||||
| output.put('COMPLETE') | |||||
| info.put(source_dict) | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| source = data['meta'].get('pile_set_name', None) | |||||
| text = data.get('text', None) | |||||
| if source and text: | |||||
| if source in self.filtered_sources: | |||||
| return [], [], None | |||||
| elif source in self.downsample_sources and random.random( | |||||
| ) > self.downsample_sources[source]: | |||||
| return [], [], None | |||||
| else: | |||||
| prompt, text = self.process_sample( | |||||
| '', tokenizer, | |||||
| tokenize), self.process_sample(text, tokenizer, tokenize) | |||||
| return [prompt], [text], source | |||||
| else: | |||||
| return [], [], None | |||||
| class Stories(PromptReader): | |||||
| is_json = True | |||||
| PATH = '/dataset/fd5061f6/english_data/stories_31G.jsonl' | |||||
| def process_line(self, data, tokenizer, tokenize): | |||||
| text = data.get('text', None) | |||||
| if text: | |||||
| prompt, text = self.process_sample('', tokenizer, | |||||
| tokenize), self.process_sample( | |||||
| text, tokenizer, tokenize) | |||||
| return [prompt], [text] | |||||
| else: | |||||
| return [], [] | |||||
| class BertBaseData(BertData): | |||||
| PATH = '/root/data/formatted_one_article_per_line' | |||||
| class BertLargeData(BertData): | |||||
| PATH = '/dataset/c07bd62b/cognitive/zhengxiao/formatted_one_article_per_line_large' | |||||
| class WuDaoCorpus(PromptReader): | |||||
| # PATH = "/dataset/fd5061f6/chinese_data/WuDao" | |||||
| PATH = '/wudao' | |||||
| is_json = False | |||||
| reserve_punct = True | |||||
| split_row = False | |||||
| def process_line(self, item, tokenizer, tokenize): | |||||
| prompts, texts = [], [] | |||||
| text = '' | |||||
| title = item.get('title', None) | |||||
| content = item.get('content', None) | |||||
| if title: | |||||
| text += title.strip() + ' ' | |||||
| if content: | |||||
| text += content | |||||
| if len(text) > 100: | |||||
| prompt, text = self.process_sample('', tokenizer, | |||||
| tokenize), self.process_sample( | |||||
| text, tokenizer, tokenize) | |||||
| prompts.append(prompt) | |||||
| texts.append(text) | |||||
| return prompts, texts | |||||
| NAMED_CORPORA = { | |||||
| 'wikipedia': wikipedia, | |||||
| 'wikipedia-key': KeyReader, | |||||
| 'openwebtext': OpenWebText, | |||||
| 'zhihu': zhihu, | |||||
| 'zhidao': zhidao, | |||||
| 'baike': baike, | |||||
| 'test': TestDataset, | |||||
| 'wikibook': BertData, | |||||
| 'bert-base': BertBaseData, | |||||
| 'bert-large': BertLargeData, | |||||
| 'cc-news': CCNews, | |||||
| 'pile': Pile, | |||||
| 'stories': Stories, | |||||
| 'wudao': WuDaoCorpus | |||||
| } | |||||
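| # Adding a corpus means subclassing PromptReader and registering it in | |||||
| # NAMED_CORPORA; a minimal sketch (path and field names here are hypothetical): | |||||
| # | |||||
| #   class MyCorpus(PromptReader): | |||||
| #       PATH = '/data/my_corpus.jsonl' | |||||
| #       assert_str = 'make sure to set PATH for my_corpus in data_utils/corpora.py' | |||||
| # | |||||
| #       def process_line(self, data, tokenizer, tokenize): | |||||
| #           text = data.get('text', '') | |||||
| #           if not text: | |||||
| #               return [], [] | |||||
| #           prompt = self.process_sample('', tokenizer, tokenize) | |||||
| #           text = self.process_sample(text, tokenizer, tokenize) | |||||
| #           return [prompt], [text] | |||||
| # | |||||
| #   NAMED_CORPORA['my_corpus'] = MyCorpus | |||||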
| @@ -0,0 +1,71 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import glob | |||||
| import os | |||||
| import json | |||||
| import nltk | |||||
| nltk.download('punkt') | |||||
| class NLTKSegmenter: | |||||
| def __init__(self): | |||||
| pass | |||||
| @staticmethod | |||||
| def segment_string(article): | |||||
| return nltk.tokenize.sent_tokenize(article) | |||||
| wiki_path = 'data/extracted' | |||||
| output_path = 'formatted/wiki-key.txt' | |||||
| segmenter = NLTKSegmenter() | |||||
| with open(output_path, 'w') as output: | |||||
| for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False): | |||||
| for filename in glob.glob( | |||||
| os.path.join(dirname, 'wiki_*'), recursive=True): | |||||
| print(filename) | |||||
| article_lines = [] | |||||
| article_open = False | |||||
| with open(filename, mode='r', newline='\n') as file: | |||||
| for line in file: | |||||
| line = line.rstrip() | |||||
| if '<doc id=' in line: | |||||
| article_open = True | |||||
| elif '</doc>' in line: | |||||
| key_sentences, contents = [], [] | |||||
| key, content = None, [] | |||||
| for sentences in article_lines[1:]: | |||||
| if len(sentences) > 1: | |||||
| if key: | |||||
| if len(content) > 0 or len(contents) == 0: | |||||
| key_sentences.append(key) | |||||
| contents.append(content) | |||||
| else: | |||||
| contents[-1].append(key) | |||||
| key, content = None, [] | |||||
| key_sentences.append(sentences[0]) | |||||
| contents.append(sentences[1:]) | |||||
| elif len(sentences) > 0: | |||||
| if key: | |||||
| content.append(sentences[0]) | |||||
| else: | |||||
| key = sentences[0] | |||||
| if key: | |||||
| if len(content) > 0 or len(contents) == 0: | |||||
| key_sentences.append(key) | |||||
| contents.append(content) | |||||
| else: | |||||
| contents[-1].append(key) | |||||
| contents = [' '.join(content) for content in contents] | |||||
| article = {'key': key_sentences, 'content': contents} | |||||
| output.write(json.dumps(article)) | |||||
| output.write('\n') | |||||
| article_open = False | |||||
| article_lines = [] | |||||
| else: | |||||
| if article_open and line: | |||||
| sentences = segmenter.segment_string(line) | |||||
| article_lines.append(sentences) | |||||
| @@ -0,0 +1,256 @@ | |||||
| # Modified by Zhipu.AI | |||||
| # This file is provided as is from: | |||||
| # https://github.com/huggingface/pytorch-pretrained-BERT | |||||
| # Please refer to their repository for copyright. | |||||
| """ | |||||
| Utilities for working with the local dataset cache. | |||||
| This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp | |||||
| Copyright by the AllenNLP authors. | |||||
| """ | |||||
| from __future__ import (absolute_import, division, print_function, | |||||
| unicode_literals) | |||||
| import logging | |||||
| import os | |||||
| import shutil | |||||
| import sys | |||||
| import tempfile | |||||
| from functools import wraps | |||||
| from hashlib import sha256 | |||||
| from io import open | |||||
| from urllib.parse import urlparse | |||||
| import boto3 | |||||
| import json | |||||
| import requests | |||||
| from botocore.exceptions import ClientError | |||||
| from tqdm import tqdm | |||||
| try: | |||||
| from pathlib import Path | |||||
| PYTORCH_PRETRAINED_BERT_CACHE = Path( | |||||
| os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', | |||||
| Path.home() / '.pytorch_pretrained_bert')) | |||||
| except (AttributeError, ImportError): | |||||
| PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( | |||||
| 'PYTORCH_PRETRAINED_BERT_CACHE', | |||||
| os.path.join(os.path.expanduser('~'), '.pytorch_pretrained_bert')) | |||||
| logger = logging.getLogger(__name__) # pylint: disable=invalid-name | |||||
| def url_to_filename(url, etag=None): | |||||
| """ | |||||
| Convert `url` into a hashed filename in a repeatable way. | |||||
| If `etag` is specified, append its hash to the url's, delimited | |||||
| by a period. | |||||
| """ | |||||
| url_bytes = url.encode('utf-8') | |||||
| url_hash = sha256(url_bytes) | |||||
| filename = url_hash.hexdigest() | |||||
| if etag: | |||||
| etag_bytes = etag.encode('utf-8') | |||||
| etag_hash = sha256(etag_bytes) | |||||
| filename += '.' + etag_hash.hexdigest() | |||||
| return filename | |||||
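| The cache-key scheme above, spelled out as a hedged stdlib-only sketch (the URL and ETag are illustrative values): the filename is the sha256 hex digest of the URL, with the ETag digest appended when one is available. | |||||
| from hashlib import sha256 | |||||
| url, etag = 'https://example.com/model.bin', '"abc"'  # illustrative values | |||||
| key = sha256(url.encode('utf-8')).hexdigest() + '.' + sha256(etag.encode('utf-8')).hexdigest() | |||||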
| def filename_to_url(filename, cache_dir=None): | |||||
| """ | |||||
| Return the url and etag (which may be ``None``) stored for `filename`. | |||||
| Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. | |||||
| """ | |||||
| if cache_dir is None: | |||||
| cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | |||||
| if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | |||||
| cache_dir = str(cache_dir) | |||||
| cache_path = os.path.join(cache_dir, filename) | |||||
| if not os.path.exists(cache_path): | |||||
| raise EnvironmentError('file {} not found'.format(cache_path)) | |||||
| meta_path = cache_path + '.json' | |||||
| if not os.path.exists(meta_path): | |||||
| raise EnvironmentError('file {} not found'.format(meta_path)) | |||||
| with open(meta_path, encoding='utf-8') as meta_file: | |||||
| metadata = json.load(meta_file) | |||||
| url = metadata['url'] | |||||
| etag = metadata['etag'] | |||||
| return url, etag | |||||
| def cached_path(url_or_filename, cache_dir=None): | |||||
| """ | |||||
| Given something that might be a URL (or might be a local path), | |||||
| determine which. If it's a URL, download the file and cache it, and | |||||
| return the path to the cached file. If it's already a local path, | |||||
| make sure the file exists and then return the path. | |||||
| """ | |||||
| if cache_dir is None: | |||||
| cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | |||||
| if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): | |||||
| url_or_filename = str(url_or_filename) | |||||
| if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | |||||
| cache_dir = str(cache_dir) | |||||
| parsed = urlparse(url_or_filename) | |||||
| if parsed.scheme in ('http', 'https', 's3'): | |||||
| # URL, so get it from the cache (downloading if necessary) | |||||
| return get_from_cache(url_or_filename, cache_dir) | |||||
| elif os.path.exists(url_or_filename): | |||||
| # File, and it exists. | |||||
| return url_or_filename | |||||
| elif parsed.scheme == '': | |||||
| # File, but it doesn't exist. | |||||
| raise EnvironmentError('file {} not found'.format(url_or_filename)) | |||||
| else: | |||||
| # Something unknown | |||||
| raise ValueError( | |||||
| 'unable to parse {} as a URL or as a local path'.format( | |||||
| url_or_filename)) | |||||
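| A usage sketch of the dispatch above (assuming cached_path is in scope; the URL is illustrative and the download branch needs network access): local paths come back unchanged, URLs are fetched once into PYTORCH_PRETRAINED_BERT_CACHE. | |||||
| path = cached_path('vocab.txt')                      # existing local file -> returned as-is | |||||
| path = cached_path('https://example.com/vocab.txt')  # URL -> path of the cached copy | |||||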
| def split_s3_path(url): | |||||
| """Split a full s3 path into the bucket name and path.""" | |||||
| parsed = urlparse(url) | |||||
| if not parsed.netloc or not parsed.path: | |||||
| raise ValueError('bad s3 path {}'.format(url)) | |||||
| bucket_name = parsed.netloc | |||||
| s3_path = parsed.path | |||||
| # Remove '/' at beginning of path. | |||||
| if s3_path.startswith('/'): | |||||
| s3_path = s3_path[1:] | |||||
| return bucket_name, s3_path | |||||
| def s3_request(func): | |||||
| """ | |||||
| Wrapper function for s3 requests in order to create more helpful error | |||||
| messages. | |||||
| """ | |||||
| @wraps(func) | |||||
| def wrapper(url, *args, **kwargs): | |||||
| try: | |||||
| return func(url, *args, **kwargs) | |||||
| except ClientError as exc: | |||||
| if int(exc.response['Error']['Code']) == 404: | |||||
| raise EnvironmentError('file {} not found'.format(url)) | |||||
| else: | |||||
| raise | |||||
| return wrapper | |||||
| @s3_request | |||||
| def s3_etag(url): | |||||
| """Check ETag on S3 object.""" | |||||
| s3_resource = boto3.resource('s3') | |||||
| bucket_name, s3_path = split_s3_path(url) | |||||
| s3_object = s3_resource.Object(bucket_name, s3_path) | |||||
| return s3_object.e_tag | |||||
| @s3_request | |||||
| def s3_get(url, temp_file): | |||||
| """Pull a file directly from S3.""" | |||||
| s3_resource = boto3.resource('s3') | |||||
| bucket_name, s3_path = split_s3_path(url) | |||||
| s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) | |||||
| def http_get(url, temp_file): | |||||
| req = requests.get(url, stream=True) | |||||
| content_length = req.headers.get('Content-Length') | |||||
| total = int(content_length) if content_length is not None else None | |||||
| progress = tqdm(unit='B', total=total) | |||||
| for chunk in req.iter_content(chunk_size=1024): | |||||
| if chunk: # filter out keep-alive new chunks | |||||
| progress.update(len(chunk)) | |||||
| temp_file.write(chunk) | |||||
| progress.close() | |||||
| def get_from_cache(url, cache_dir=None): | |||||
| """ | |||||
| Given a URL, look for the corresponding dataset in the local cache. | |||||
| If it's not there, download it. Then return the path to the cached file. | |||||
| """ | |||||
| if cache_dir is None: | |||||
| cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | |||||
| if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | |||||
| cache_dir = str(cache_dir) | |||||
| if not os.path.exists(cache_dir): | |||||
| os.makedirs(cache_dir) | |||||
| # Get eTag to add to filename, if it exists. | |||||
| if url.startswith('s3://'): | |||||
| etag = s3_etag(url) | |||||
| else: | |||||
| response = requests.head(url, allow_redirects=True) | |||||
| if response.status_code != 200: | |||||
| raise IOError( | |||||
| 'HEAD request failed for url {} with status code {}'.format( | |||||
| url, response.status_code)) | |||||
| etag = response.headers.get('ETag') | |||||
| filename = url_to_filename(url, etag) | |||||
| # get cache path to put the file | |||||
| cache_path = os.path.join(cache_dir, filename) | |||||
| if not os.path.exists(cache_path): | |||||
| # Download to temporary file, then copy to cache dir once finished. | |||||
| # Otherwise you get corrupt cache entries if the download gets interrupted. | |||||
| with tempfile.NamedTemporaryFile() as temp_file: | |||||
| logger.info('%s not found in cache, downloading to %s', url, | |||||
| temp_file.name) | |||||
| # GET file object | |||||
| if url.startswith('s3://'): | |||||
| s3_get(url, temp_file) | |||||
| else: | |||||
| http_get(url, temp_file) | |||||
| # we are copying the file before closing it, so flush to avoid truncation | |||||
| temp_file.flush() | |||||
| # shutil.copyfileobj() starts at the current position, so go to the start | |||||
| temp_file.seek(0) | |||||
| logger.info('copying %s to cache at %s', temp_file.name, | |||||
| cache_path) | |||||
| with open(cache_path, 'wb') as cache_file: | |||||
| shutil.copyfileobj(temp_file, cache_file) | |||||
| logger.info('creating metadata file for %s', cache_path) | |||||
| meta = {'url': url, 'etag': etag} | |||||
| meta_path = cache_path + '.json' | |||||
| with open(meta_path, 'w', encoding='utf-8') as meta_file: | |||||
| json.dump(meta, meta_file) | |||||
| logger.info('removing temp file %s', temp_file.name) | |||||
| return cache_path | |||||
| def read_set_from_file(filename): | |||||
| ''' | |||||
| Extract a de-duped collection (set) of text from a file. | |||||
| Expected file format is one item per line. | |||||
| ''' | |||||
| collection = set() | |||||
| with open(filename, 'r', encoding='utf-8') as file_: | |||||
| for line in file_: | |||||
| collection.add(line.rstrip()) | |||||
| return collection | |||||
| def get_file_extension(path, dot=True, lower=True): | |||||
| ext = os.path.splitext(path)[1] | |||||
| ext = ext if dot else ext[1:] | |||||
| return ext.lower() if lower else ext | |||||
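| A tiny illustration of get_file_extension (a sketch, not part of the diff): | |||||
| print(get_file_extension('data/corpus.TXT'))             # '.txt' | |||||
| print(get_file_extension('data/corpus.TXT', dot=False))  # 'txt' | |||||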
| @@ -0,0 +1,286 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """utils for loading text from disk""" | |||||
| import mmap | |||||
| import os | |||||
| import pickle as pkl | |||||
| import time | |||||
| from itertools import accumulate | |||||
| import numpy as np | |||||
| import torch | |||||
| from torch.multiprocessing import Lock | |||||
| def get_lazy_path(path): | |||||
| """ | |||||
| Gets directory path where lazy files are stored. | |||||
| """ | |||||
| return os.path.splitext(path)[0] + '.lazy' | |||||
| def exists_lazy(path, data_type='data'): | |||||
| """ | |||||
| Check if we've already made a lazy version of this file for the `data_type` field. | |||||
| """ | |||||
| if not os.path.exists(get_lazy_path(path)): | |||||
| return False | |||||
| contents = os.listdir(get_lazy_path(path)) | |||||
| if data_type not in contents: | |||||
| return False | |||||
| if data_type + '.len.pkl' not in contents: | |||||
| return False | |||||
| return True | |||||
| def get_scatter_path(path, scatter_rank): | |||||
| path = os.path.splitext(path)[0] + '.scatter' | |||||
| scatter_path = os.path.join(path, str(scatter_rank)) | |||||
| return scatter_path | |||||
| def exists_scatter(path, scatter_num=64, data_type='data'): | |||||
| for i in range(scatter_num): | |||||
| scatter_path = get_scatter_path(path, scatter_rank=i) | |||||
| if not exists_lazy(scatter_path, data_type=data_type): | |||||
| return False | |||||
| return True | |||||
| class LazyWriter: | |||||
| def __init__(self, | |||||
| path, | |||||
| data_type, | |||||
| is_array=False, | |||||
| array_data_type=np.int32): | |||||
| lazypath = get_lazy_path(path) | |||||
| if not os.path.exists(lazypath): | |||||
| os.makedirs(lazypath) | |||||
| self.datapath = os.path.join(lazypath, data_type) | |||||
| self.lenpath = os.path.join(lazypath, data_type + '.len.pkl') | |||||
| self.array_data_type = array_data_type | |||||
| self.output = open(self.datapath, 'wb') | |||||
| self.lengths = [] | |||||
| self.is_array = is_array | |||||
| @staticmethod | |||||
| def get_len_path(path, data_type): | |||||
| lazypath = get_lazy_path(path) | |||||
| return os.path.join(lazypath, data_type + '.len.pkl') | |||||
| def write(self, s): | |||||
| if isinstance(s, dict): | |||||
| s = s['text'] | |||||
| if self.is_array: | |||||
| encoded = np.array( | |||||
| s, dtype=self.array_data_type).tobytes(order='C') | |||||
| self.output.write(encoded) | |||||
| self.lengths.append(len(s)) | |||||
| else: | |||||
| encoded = s.encode('utf-8') | |||||
| self.output.write(encoded) | |||||
| self.lengths.append(len(encoded)) | |||||
| def close(self): | |||||
| self.output.close() | |||||
| with open(self.lenpath, 'wb') as f: | |||||
| pkl.dump(self.lengths, f) | |||||
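| A minimal sketch, assuming LazyWriter above is in scope: writing two text records creates corpus.lazy/text (the concatenated bytes) plus corpus.lazy/text.len.pkl (the per-record lengths). | |||||
| writer = LazyWriter('corpus.json', data_type='text', is_array=False) | |||||
| writer.write('first document') | |||||
| writer.write({'text': 'second document'})  # dicts are reduced to their 'text' field | |||||
| writer.close() | |||||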
| def split_strings(strings, start, chr_lens): | |||||
| """ | |||||
| Split strings based on string lengths and given start. | |||||
| """ | |||||
| return [ | |||||
| strings[i - start:j - start] | |||||
| for i, j in zip([start] + chr_lens[:-1], chr_lens) | |||||
| ] | |||||
| class ProcessorTokenizer: | |||||
| """ | |||||
| callable class that runs a preprocessing, as well as tokenization step, | |||||
| on input text. | |||||
| """ | |||||
| def __init__(self, tokenizer, process_fn=None): | |||||
| self.tokenizer = tokenizer | |||||
| self.process_fn = process_fn | |||||
| def __call__(self, string): | |||||
| if self.tokenizer is not None: | |||||
| string = self.tokenizer(string, process_fn=self.process_fn) | |||||
| elif self.process_fn is not None: | |||||
| string = self.process_fn(string) | |||||
| return string | |||||
| class LazyLoader(object): | |||||
| """ | |||||
| Arguments: | |||||
| path: path to directory where array entries are concatenated into one big string file | |||||
| and the .len file are located | |||||
| data_type (str): Some datasets have multiple fields that are stored in different paths. | |||||
| `data_type` specifies which of these fields to load in this class | |||||
| mem_map (boolean): Specifies whether to memory map file `path` | |||||
| map_fn (callable): Fetched strings are passed through map_fn before being returned. | |||||
| Example of lazy loader directory structure: | |||||
| file.json | |||||
| file.lazy/ | |||||
| data_type1 | |||||
| data_type1.len.pkl | |||||
| data_type2 | |||||
| data_type2.len.pkl | |||||
| """ | |||||
| def __init__(self, | |||||
| path, | |||||
| data_type='data', | |||||
| mem_map=False, | |||||
| map_fn=None, | |||||
| is_array=False, | |||||
| array_data_type=np.int32, | |||||
| load_memory=False, | |||||
| half_load=False): | |||||
| lazypath = get_lazy_path(path) | |||||
| datapath = os.path.join(lazypath, data_type) | |||||
| # get file where array entries are concatenated into one big string | |||||
| self._file = open(datapath, 'rb') | |||||
| self.file = self._file | |||||
| self.is_array = is_array | |||||
| self.array_data_type = array_data_type | |||||
| # memory map file if necessary | |||||
| lenpath = os.path.join(lazypath, data_type + '.len.pkl') | |||||
| self.lens = pkl.load(open(lenpath, 'rb')) | |||||
| if half_load: | |||||
| self.lens = self.lens[:2 * len(self.lens) // 3] | |||||
| self.ends = list(accumulate(self.lens)) | |||||
| self.dumb_ends = list(self.ends) | |||||
| self.mem_map = mem_map | |||||
| self.load_memory = load_memory | |||||
| if self.load_memory: | |||||
| data_type_size = np.dtype(self.array_data_type).itemsize | |||||
| if half_load: | |||||
| self.file = self.file.read(sum(self.lens) * data_type_size) | |||||
| else: | |||||
| self.file = self.file.read() | |||||
| self.file = np.ndarray( | |||||
| shape=(len(self.file) // data_type_size, ), | |||||
| dtype=array_data_type, | |||||
| buffer=self.file, | |||||
| order='C') | |||||
| elif self.mem_map: | |||||
| if is_array: | |||||
| if self.ends[-1] == 0: | |||||
| self.file = np.array([], dtype=array_data_type) | |||||
| else: | |||||
| self.file = np.memmap( | |||||
| self.file, dtype=array_data_type, mode='r', order='C') | |||||
| else: | |||||
| if self.ends[-1] == 0: | |||||
| self.file = bytearray() | |||||
| else: | |||||
| self.file = mmap.mmap( | |||||
| self.file.fileno(), 0, prot=mmap.PROT_READ) | |||||
| self.read_lock = Lock() | |||||
| self.process_fn = map_fn | |||||
| self.map_fn = map_fn | |||||
| self._tokenizer = None | |||||
| self.is_lazy = True | |||||
| def SetTokenizer(self, tokenizer): | |||||
| """ | |||||
| logic to set and remove (set to None) tokenizer. | |||||
| combines preprocessing/tokenization into one callable. | |||||
| """ | |||||
| if tokenizer is None: | |||||
| if not hasattr(self, '_tokenizer'): | |||||
| self._tokenizer = tokenizer | |||||
| else: | |||||
| self._tokenizer = tokenizer | |||||
| self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn) | |||||
| def GetTokenizer(self): | |||||
| return self._tokenizer | |||||
| def __getitem__(self, index): | |||||
| """ | |||||
| read file and splice strings based on string ending array `self.ends` | |||||
| """ | |||||
| if not isinstance(index, slice): | |||||
| if index == 0: | |||||
| start = 0 | |||||
| else: | |||||
| start = self.ends[index - 1] | |||||
| end = self.ends[index] | |||||
| rtn = self.file_read(start, end) | |||||
| if self.map_fn is not None: | |||||
| rtn = self.map_fn(rtn) | |||||
| else: | |||||
| # if slice, fetch strings with 1 diskread and then splice in memory | |||||
| chr_lens = self.ends[index] | |||||
| if index.start == 0 or index.start is None: | |||||
| start = 0 | |||||
| else: | |||||
| start = self.ends[index.start - 1] | |||||
| stop = chr_lens[-1] | |||||
| strings = self.file_read(start, stop) | |||||
| rtn = split_strings(strings, start, chr_lens) | |||||
| if self.map_fn is not None: | |||||
| rtn = [self.map_fn(s) for s in rtn] | |||||
| return rtn | |||||
| def __len__(self): | |||||
| return len(self.ends) | |||||
| def file_read(self, start=0, end=None): | |||||
| """read specified portion of file""" | |||||
| data_type_size = np.dtype(self.array_data_type).itemsize | |||||
| # atomic reads to avoid race conditions with multiprocess dataloader | |||||
| self.read_lock.acquire() | |||||
| if not self.mem_map and not self.load_memory: | |||||
| # seek to start of file read | |||||
| if self.is_array: | |||||
| start = start * data_type_size | |||||
| end = end * data_type_size if end is not None else None | |||||
| self.file.seek(start) | |||||
| # read to end of file if no end point provided | |||||
| if end is None: | |||||
| rtn = self.file.read() | |||||
| # else read amount needed to reach end point | |||||
| else: | |||||
| rtn = self.file.read(end - start) | |||||
| if self.is_array: | |||||
| rtn = np.ndarray( | |||||
| shape=(len(rtn) // data_type_size, ), | |||||
| dtype=self.array_data_type, | |||||
| buffer=rtn, | |||||
| order='C') | |||||
| else: | |||||
| rtn = rtn.decode('utf-8', 'ignore') | |||||
| else: | |||||
| rtn = self.file[start:end] | |||||
| if self.is_array: | |||||
| rtn = rtn.copy() | |||||
| else: | |||||
| rtn = rtn.decode('utf-8', 'strict') | |||||
| self.read_lock.release() | |||||
| # TODO: @raulp figure out mem map byte string bug | |||||
| # if mem map'd need to decode byte string to string | |||||
| # # rtn = str(rtn) | |||||
| # if self.mem_map: | |||||
| # rtn = rtn.decode('unicode_escape') | |||||
| return rtn | |||||
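| A companion sketch, assuming LazyLoader above is in scope and the 'corpus.json' records from the LazyWriter sketch were written earlier: integer indexing returns one record, slicing fetches several with a single read. | |||||
| loader = LazyLoader('corpus.json', data_type='text', mem_map=False) | |||||
| print(len(loader))   # 2 | |||||
| print(loader[0])     # 'first document' | |||||
| print(loader[0:2])   # both records, spliced in memory by split_strings | |||||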
| @@ -0,0 +1,190 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """batch samplers that work with either random or sequential data samplers""" | |||||
| import math | |||||
| import os | |||||
| import sys | |||||
| import numpy as np | |||||
| import torch | |||||
| from torch.utils import data | |||||
| class RandomSampler(data.sampler.Sampler): | |||||
| r""" | |||||
| Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, | |||||
| but this class lets the user set an epoch like DistributedSampler | |||||
| Samples elements randomly. If without replacement, then sample from a shuffled dataset. | |||||
| If with replacement, then user can specify ``num_samples`` to draw. | |||||
| Arguments: | |||||
| data_source (Dataset): dataset to sample from | |||||
| num_samples (int): number of samples to draw, default=len(dataset) | |||||
| replacement (bool): samples are drawn with replacement if ``True``, default=False | |||||
| """ | |||||
| def __init__(self, data_source, replacement=False, num_samples=None): | |||||
| super(RandomSampler, self).__init__(data_source) | |||||
| self.data_source = data_source | |||||
| self.replacement = replacement | |||||
| self._num_samples = num_samples | |||||
| self.epoch = -1 | |||||
| if self._num_samples is not None and replacement is False: | |||||
| raise ValueError( | |||||
| 'With replacement=False, num_samples should not be specified, ' | |||||
| 'since a random permute will be performed.') | |||||
| if not isinstance(self.num_samples, int) or self.num_samples <= 0: | |||||
| raise ValueError('num_samples should be a positive integer ' | |||||
| 'value, but got num_samples={}'.format( | |||||
| self.num_samples)) | |||||
| if not isinstance(self.replacement, bool): | |||||
| raise ValueError('replacement should be a boolean value, but got ' | |||||
| 'replacement={}'.format(self.replacement)) | |||||
| @property | |||||
| def num_samples(self): | |||||
| # dataset size might change at runtime | |||||
| if self._num_samples is None: | |||||
| return len(self.data_source) | |||||
| return self._num_samples | |||||
| def __iter__(self): | |||||
| n = len(self.data_source) | |||||
| g = torch.Generator() | |||||
| if self.epoch >= 0: | |||||
| g.manual_seed(self.epoch) | |||||
| if self.replacement: | |||||
| for _ in range(self.num_samples // 32): | |||||
| yield from torch.randint( | |||||
| high=n, size=(32, ), dtype=torch.int64, | |||||
| generator=g).tolist() | |||||
| yield from torch.randint( | |||||
| high=n, | |||||
| size=(self.num_samples % 32, ), | |||||
| dtype=torch.int64, | |||||
| generator=g).tolist() | |||||
| else: | |||||
| yield from torch.randperm(n, generator=g).tolist() | |||||
| def __len__(self): | |||||
| return self.num_samples | |||||
| def set_epoch(self, epoch): | |||||
| self.epoch = epoch | |||||
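| A small sketch, assuming RandomSampler above is in scope: set_epoch seeds the generator so the shuffle is reproducible per epoch, mirroring DistributedSampler. | |||||
| sampler = RandomSampler(list(range(8))) | |||||
| sampler.set_epoch(3) | |||||
| print(list(sampler))  # the same permutation of 0..7 every time epoch 3 is replayed | |||||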
| class DistributedSequentialSampler(data.sampler.Sampler): | |||||
| def __init__(self, | |||||
| num_samples, | |||||
| train_iters, | |||||
| batch_size, | |||||
| rank=-1, | |||||
| world_size=2): | |||||
| super().__init__(num_samples) | |||||
| if rank == -1: | |||||
| rank = 0 | |||||
| world_size = 1 | |||||
| self.num_samples = num_samples | |||||
| self.rank = rank | |||||
| self.world_size = world_size | |||||
| self.start_iter = 0 | |||||
| self.train_iters = train_iters | |||||
| self.batch_size = batch_size | |||||
| self.batch_bias = [ | |||||
| i * (num_samples // batch_size) for i in range(batch_size) | |||||
| ] | |||||
| def __iter__(self): | |||||
| for idx in range(self.start_iter, self.train_iters * 10): | |||||
| batch = [(idx + bias) % self.num_samples | |||||
| for bias in self.batch_bias] | |||||
| tbatch = self._batch(batch) | |||||
| yield tbatch | |||||
| def __len__(self): | |||||
| return self.train_iters | |||||
| def _batch(self, batch): | |||||
| """extracts samples only pertaining to this worker's batch""" | |||||
| start = self.rank * self.batch_size // self.world_size | |||||
| end = (self.rank + 1) * self.batch_size // self.world_size | |||||
| return batch[start:end] | |||||
| class DistributedBatchSampler(data.sampler.BatchSampler): | |||||
| """ | |||||
| similar to normal implementation of distributed sampler, except implementation is at the | |||||
| batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary | |||||
| data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. | |||||
| """ | |||||
| def __init__(self, | |||||
| sampler, | |||||
| batch_size, | |||||
| drop_last, | |||||
| rank=-1, | |||||
| world_size=2, | |||||
| wrap_last=False, | |||||
| gradient_accumulation_steps=None): | |||||
| super(DistributedBatchSampler, self).__init__(sampler, batch_size, | |||||
| drop_last) | |||||
| if rank == -1: | |||||
| assert False, 'should not be here' | |||||
| self.rank = rank | |||||
| self.world_size = world_size | |||||
| self.sampler.wrap_around = 0 | |||||
| self.wrap_around = 0 | |||||
| self.wrap_last = wrap_last | |||||
| self.start_iter = 0 | |||||
| self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps # noqa | |||||
| def __iter__(self): | |||||
| batch = [] | |||||
| i = 0 | |||||
| for idx in self.data_iterator(self.sampler, wrap_around=False): | |||||
| batch.append(idx) | |||||
| if len(batch) == self.batch_size: | |||||
| tbatch = self._batch(batch) | |||||
| if i >= self.start_iter * self.effective_batch_size: | |||||
| yield tbatch | |||||
| self.start_iter = 0 | |||||
| i += len(batch) | |||||
| batch = [] | |||||
| batch_len = len(batch) | |||||
| if batch_len > 0 and not self.drop_last: | |||||
| if self.wrap_last: | |||||
| self.sampler.wrap_around -= (self.batch_size) | |||||
| self.wrap_around += (len(batch)) | |||||
| self.wrap_around %= self.batch_size | |||||
| yield self._batch(batch) | |||||
| if self.wrap_last: | |||||
| self.sampler.wrap_around += self.batch_size | |||||
| def data_iterator(self, _iter, wrap_around=False): | |||||
| """iterates through data and handles wrap around""" | |||||
| for i, idx in enumerate(_iter): | |||||
| if i < self.wrap_around % self.batch_size: | |||||
| continue | |||||
| if wrap_around: | |||||
| self.wrap_around += 1 | |||||
| self.wrap_around %= self.batch_size | |||||
| yield idx | |||||
| def _batch(self, batch): | |||||
| """extracts samples only pertaining to this worker's batch""" | |||||
| start = self.rank * self.batch_size // self.world_size | |||||
| end = (self.rank + 1) * self.batch_size // self.world_size | |||||
| return batch[start:end] | |||||
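| A hedged end-to-end sketch, assuming the samplers above are in scope: wrapping a base sampler so each of two ranks receives its share of every global batch. | |||||
| base = RandomSampler(list(range(16))) | |||||
| batch_sampler = DistributedBatchSampler(base, batch_size=4, drop_last=True, rank=0, world_size=2) | |||||
| print(list(batch_sampler))  # four batches, each holding rank 0's 2 of the 4 global indices | |||||
| # the result can be passed to torch.utils.data.DataLoader(..., batch_sampler=batch_sampler) | |||||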
| @@ -0,0 +1,158 @@ | |||||
| # Modified by Zhipu.AI | |||||
| """ | |||||
| from https://github.com/openai/gpt-2/, changed for chinese | |||||
| """ | |||||
| import os # yapf: disable | |||||
| """ | |||||
| SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation | |||||
| systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements | |||||
| subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.] and unigram language model [Kudo]) with the | |||||
| extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end | |||||
| system that does not depend on language-specific pre/postprocessing. | |||||
| https://github.com/google/sentencepiece | |||||
| pip install sentencepiece | |||||
| or git clone https://github.com/google/sentencepiece.git | |||||
| python setup.py install | |||||
| """ | |||||
| def get_pairs(word): | |||||
| pairs = set() | |||||
| prev_char = word[0] | |||||
| for char in word[1:]: | |||||
| pairs.add((prev_char, char)) | |||||
| prev_char = char | |||||
| return pairs | |||||
| class Encoder: | |||||
| def __init__(self, encoder, bpe_merges): | |||||
| self.encoder = encoder | |||||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||||
| self.cache = {} | |||||
| self.max_len = 0 | |||||
| def bpe(self, token): | |||||
| if token in self.cache: | |||||
| return self.cache[token] | |||||
| word = tuple(token) | |||||
| pairs = get_pairs(word) | |||||
| if not pairs: | |||||
| return token | |||||
| while True: | |||||
| bigram = min( | |||||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||||
| if bigram not in self.bpe_ranks: | |||||
| break | |||||
| first, second = bigram | |||||
| new_word = [] | |||||
| i = 0 | |||||
| while i < len(word): | |||||
| try: | |||||
| j = word.index(first, i) | |||||
| new_word.extend(word[i:j]) | |||||
| i = j | |||||
| except: # noqa | |||||
| new_word.extend(word[i:]) | |||||
| break | |||||
| if word[i] == first and i < len(word) - 1 and word[ | |||||
| i + 1] == second: | |||||
| new_word.append(first + second) | |||||
| i += 2 | |||||
| else: | |||||
| new_word.append(word[i]) | |||||
| i += 1 | |||||
| new_word = tuple(new_word) | |||||
| word = new_word | |||||
| if len(word) == 1: | |||||
| break | |||||
| else: | |||||
| pairs = get_pairs(word) | |||||
| word = ' '.join(word) | |||||
| self.cache[token] = word | |||||
| return word | |||||
| def encode(self, text): | |||||
| return [self.encoder.get(token, 1) for token in self.tokenize(text)] | |||||
| def decode(self, tokens): | |||||
| text = ''.join([self.decoder[token] for token in tokens]) | |||||
| return text | |||||
| def tokenize(self, text): | |||||
| bpe_tokens = [] | |||||
| bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' ')) | |||||
| return bpe_tokens | |||||
| def convert_tokens_to_ids(self, tokens): | |||||
| return [self.encoder.get(token, 1) for token in tokens] | |||||
| class Encoder_SP: | |||||
| def __init__(self, model_path): | |||||
| import sentencepiece as spm | |||||
| self.sp = spm.SentencePieceProcessor() | |||||
| self.sp.Load(model_path) | |||||
| def encode(self, text): | |||||
| """ | |||||
| text="...." | |||||
| """ | |||||
| return self.sp.EncodeAsIds(text) | |||||
| def decode(self, tokens): | |||||
| """ | |||||
| tokens=[x1,x2,...] | |||||
| """ | |||||
| text = [int(token) for token in tokens] | |||||
| # print(text) | |||||
| return self.sp.DecodeIds(text) | |||||
| def tokenize(self, text): | |||||
| return self.sp.EncodeAsPieces(text) | |||||
| def convert_tokens_to_ids(self, tokens): | |||||
| return [self.sp.PieceToId(token) for token in tokens] | |||||
| def convert_token_to_id(self, token): | |||||
| return self.sp.PieceToId(token) | |||||
| def convert_id_to_token(self, idx): | |||||
| return self.sp.IdToPiece(idx) | |||||
| def get_encoder(encoder_file, bpe_file): | |||||
| import json | |||||
| filepath, filename = os.path.split(encoder_file) | |||||
| shortname, extension = os.path.splitext(filename) | |||||
| if ('.model' == extension) and (bpe_file == ''): | |||||
| return Encoder_SP(encoder_file) | |||||
| else: | |||||
| with open(encoder_file, 'r', encoding='utf-8') as f: | |||||
| encoder = json.load(f) | |||||
| with open(bpe_file, 'r', encoding='utf-8') as f: | |||||
| bpe_data = f.read() | |||||
| bpe_merges = [ | |||||
| tuple(merge_str.split()) | |||||
| for merge_str in bpe_data.split('\n')[1:-1] | |||||
| ] | |||||
| return Encoder( | |||||
| encoder=encoder, | |||||
| bpe_merges=bpe_merges, | |||||
| ) | |||||
| def from_pretrained(model_path): | |||||
| return get_encoder(model_path + '/tokenizer/mglm250k/mglm250k-uni.model', | |||||
| '') | |||||
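| A usage sketch (hedged: it requires the mGLM SentencePiece model file on disk under <model_dir>/tokenizer/mglm250k/, where <model_dir> is a placeholder): | |||||
| sp_tokenizer = from_pretrained('<model_dir>')  # returns an Encoder_SP wrapper | |||||
| ids = sp_tokenizer.encode('多语言文本摘要') | |||||
| print(sp_tokenizer.decode(ids)) | |||||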
| @@ -0,0 +1,359 @@ | |||||
| # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Tokenization classes for OpenAI GPT.""" | |||||
| from __future__ import (absolute_import, division, print_function, | |||||
| unicode_literals) | |||||
| import logging | |||||
| import os | |||||
| import sys | |||||
| from io import open | |||||
| import json | |||||
| import regex as re | |||||
| from .file_utils import cached_path | |||||
| try: | |||||
| from functools import lru_cache | |||||
| except ImportError: | |||||
| # Just a dummy decorator to get the checks to run on python2 | |||||
| # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. | |||||
| def lru_cache(): | |||||
| return lambda func: func | |||||
| logger = logging.getLogger(__name__) | |||||
| PRETRAINED_VOCAB_ARCHIVE_MAP = { | |||||
| 'gpt2': '.pytorch_pretrained_bert/gpt2-vocab.json', | |||||
| 'roberta': '.pytorch_pretrained_bert/roberta-vocab.json' | |||||
| } | |||||
| PRETRAINED_MERGES_ARCHIVE_MAP = { | |||||
| 'gpt2': '.pytorch_pretrained_bert/gpt2-merges.txt', | |||||
| 'roberta': '.pytorch_pretrained_bert/roberta-merges.txt' | |||||
| } | |||||
| PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { | |||||
| 'gpt2': 1024, | |||||
| } | |||||
| VOCAB_NAME = 'vocab.json' | |||||
| MERGES_NAME = 'merges.txt' | |||||
| SPECIAL_TOKENS_NAME = 'special_tokens.txt' | |||||
| @lru_cache() | |||||
| def bytes_to_unicode(): | |||||
| """ | |||||
| Returns a list of utf-8 bytes and a corresponding list of unicode strings. | |||||
| The reversible bpe codes work on unicode strings. | |||||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||||
| This is a significant percentage of your normal, say, 32K bpe vocab. | |||||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||||
| """ | |||||
| _chr = unichr if sys.version_info[0] == 2 else chr | |||||
| bs = list(range(ord('!'), | |||||
| ord('~') + 1)) + list(range( | |||||
| ord('¡'), | |||||
| ord('¬') + 1)) + list(range(ord('®'), | |||||
| ord('ÿ') + 1)) | |||||
| cs = bs[:] | |||||
| n = 0 | |||||
| for b in range(2**8): | |||||
| if b not in bs: | |||||
| bs.append(b) | |||||
| cs.append(2**8 + n) | |||||
| n += 1 | |||||
| cs = [_chr(n) for n in cs] | |||||
| return dict(zip(bs, cs)) | |||||
| def get_pairs(word): | |||||
| """Return set of symbol pairs in a word. | |||||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||||
| """ | |||||
| pairs = set() | |||||
| prev_char = word[0] | |||||
| for char in word[1:]: | |||||
| pairs.add((prev_char, char)) | |||||
| prev_char = char | |||||
| return pairs | |||||
| class GPT2Tokenizer(object): | |||||
| """ | |||||
| GPT-2 BPE tokenizer. Peculiarities: | |||||
| - Byte-level BPE | |||||
| """ | |||||
| @classmethod | |||||
| def from_pretrained(cls, | |||||
| pretrained_model_name_or_path, | |||||
| cache_dir=None, | |||||
| *inputs, | |||||
| **kwargs): | |||||
| """ | |||||
| Instantiate a PreTrainedBertModel from a pre-trained model file. | |||||
| Download and cache the pre-trained model file if needed. | |||||
| """ | |||||
| if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: | |||||
| vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[ | |||||
| pretrained_model_name_or_path] | |||||
| merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[ | |||||
| pretrained_model_name_or_path] | |||||
| special_tokens_file = None | |||||
| else: | |||||
| vocab_file = os.path.join(pretrained_model_name_or_path, | |||||
| VOCAB_NAME) | |||||
| merges_file = os.path.join(pretrained_model_name_or_path, | |||||
| MERGES_NAME) | |||||
| special_tokens_file = os.path.join(pretrained_model_name_or_path, | |||||
| SPECIAL_TOKENS_NAME) | |||||
| if not os.path.exists(special_tokens_file): | |||||
| special_tokens_file = None | |||||
| else: | |||||
| logger.info('loading special tokens file {}'.format( | |||||
| special_tokens_file)) | |||||
| # redirect to the cache, if necessary | |||||
| # try: | |||||
| # resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) | |||||
| # resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) | |||||
| # except EnvironmentError: | |||||
| # logger.error( | |||||
| # "Model name '{}' was not found in model name list ({}). " | |||||
| # "We assumed '{}' was a path or url but couldn't find files {} and {} " | |||||
| # "at this path or url.".format( | |||||
| # pretrained_model_name_or_path, | |||||
| # ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), | |||||
| # pretrained_model_name_or_path, | |||||
| # vocab_file, merges_file)) | |||||
| # return None | |||||
| # if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: | |||||
| # logger.info("loading vocabulary file {}".format(vocab_file)) | |||||
| # logger.info("loading merges file {}".format(merges_file)) | |||||
| # else: | |||||
| # logger.info("loading vocabulary file {} from cache at {}".format( | |||||
| # vocab_file, resolved_vocab_file)) | |||||
| # logger.info("loading merges file {} from cache at {}".format( | |||||
| # merges_file, resolved_merges_file)) | |||||
| resolved_vocab_file = vocab_file | |||||
| resolved_merges_file = merges_file | |||||
| logger.info('loading vocabulary file {}'.format(vocab_file)) | |||||
| logger.info('loading merges file {}'.format(merges_file)) | |||||
| if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: | |||||
| # if we're using a pretrained model, ensure the tokenizer won't index sequences longer | |||||
| # than the number of positional embeddings | |||||
| max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ | |||||
| pretrained_model_name_or_path] | |||||
| kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) | |||||
| # Instantiate tokenizer. | |||||
| if special_tokens_file and 'special_tokens' not in kwargs: | |||||
| special_tokens = open( | |||||
| special_tokens_file, encoding='utf-8').read().split('\n')[:-1] | |||||
| else: | |||||
| special_tokens = kwargs.pop('special_tokens', []) | |||||
| tokenizer = cls( | |||||
| resolved_vocab_file, | |||||
| resolved_merges_file, | |||||
| special_tokens=special_tokens, | |||||
| *inputs, | |||||
| **kwargs) | |||||
| return tokenizer | |||||
| def __init__(self, | |||||
| vocab_file, | |||||
| merges_file, | |||||
| errors='replace', | |||||
| special_tokens=None, | |||||
| max_len=None): | |||||
| self.max_len = max_len if max_len is not None else int(1e12) | |||||
| self.encoder = json.load(open(vocab_file)) | |||||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||||
| self.errors = errors # how to handle errors in decoding | |||||
| self.byte_encoder = bytes_to_unicode() | |||||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||||
| bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] | |||||
| bpe_merges = [tuple(merge.split()) for merge in bpe_data] | |||||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||||
| self.cache = {} | |||||
| # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions | |||||
| self.pat = re.compile( | |||||
| r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" | |||||
| ) | |||||
| self.special_tokens = {} | |||||
| self.special_tokens_decoder = {} | |||||
| self.set_special_tokens(special_tokens) | |||||
| def __len__(self): | |||||
| return len(self.encoder) + len(self.special_tokens) | |||||
| def set_special_tokens(self, special_tokens): | |||||
| """ Add a list of additional tokens to the encoder. | |||||
| The additional tokens are indexed starting from the last index of the | |||||
| current vocabulary in the order of the `special_tokens` list. | |||||
| """ | |||||
| if not special_tokens: | |||||
| self.special_tokens = {} | |||||
| self.special_tokens_decoder = {} | |||||
| return | |||||
| self.special_tokens = dict((tok, len(self.encoder) + i) | |||||
| for i, tok in enumerate(special_tokens)) | |||||
| self.special_tokens_decoder = { | |||||
| v: k | |||||
| for k, v in self.special_tokens.items() | |||||
| } | |||||
| logger.info('Special tokens {}'.format(self.special_tokens)) | |||||
| def bpe(self, token): | |||||
| if token in self.cache: | |||||
| return self.cache[token] | |||||
| word = tuple(token) | |||||
| pairs = get_pairs(word) | |||||
| if not pairs: | |||||
| return token | |||||
| while True: | |||||
| bigram = min( | |||||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||||
| if bigram not in self.bpe_ranks: | |||||
| break | |||||
| first, second = bigram | |||||
| new_word = [] | |||||
| i = 0 | |||||
| while i < len(word): | |||||
| try: | |||||
| j = word.index(first, i) | |||||
| new_word.extend(word[i:j]) | |||||
| i = j | |||||
| except: # noqa | |||||
| new_word.extend(word[i:]) | |||||
| break | |||||
| if word[i] == first and i < len(word) - 1 and word[ | |||||
| i + 1] == second: | |||||
| new_word.append(first + second) | |||||
| i += 2 | |||||
| else: | |||||
| new_word.append(word[i]) | |||||
| i += 1 | |||||
| new_word = tuple(new_word) | |||||
| word = new_word | |||||
| if len(word) == 1: | |||||
| break | |||||
| else: | |||||
| pairs = get_pairs(word) | |||||
| word = ' '.join(word) | |||||
| self.cache[token] = word | |||||
| return word | |||||
| def tokenize(self, text): | |||||
| """ Tokenize a string. """ | |||||
| bpe_tokens = [] | |||||
| for token in re.findall(self.pat, text): | |||||
| if sys.version_info[0] == 2: | |||||
| token = ''.join(self.byte_encoder[ord(b)] for b in token) | |||||
| else: | |||||
| token = ''.join(self.byte_encoder[b] | |||||
| for b in token.encode('utf-8')) | |||||
| bpe_tokens.extend( | |||||
| bpe_token for bpe_token in self.bpe(token).split(' ')) | |||||
| return bpe_tokens | |||||
| def convert_tokens_to_ids(self, tokens): | |||||
| """ Converts a sequence of tokens into ids using the vocab. """ | |||||
| ids = [] | |||||
| if isinstance(tokens, str) or (sys.version_info[0] == 2 | |||||
| and isinstance(tokens, unicode)): | |||||
| if tokens in self.special_tokens: | |||||
| return self.special_tokens[tokens] | |||||
| else: | |||||
| return self.encoder.get(tokens, 0) | |||||
| for token in tokens: | |||||
| if token in self.special_tokens: | |||||
| ids.append(self.special_tokens[token]) | |||||
| else: | |||||
| ids.append(self.encoder.get(token, 0)) | |||||
| if len(ids) > self.max_len: | |||||
| logger.warning( | |||||
| 'Token indices sequence length is longer than the specified maximum ' | |||||
| ' sequence length for this OpenAI GPT model ({} > {}). Running this' | |||||
| ' sequence through the model will result in indexing errors'. | |||||
| format(len(ids), self.max_len)) | |||||
| return ids | |||||
| def convert_ids_to_tokens(self, ids, skip_special_tokens=False): | |||||
| """Converts a sequence of ids in BPE tokens using the vocab.""" | |||||
| tokens = [] | |||||
| for i in ids: | |||||
| if i in self.special_tokens_decoder: | |||||
| if not skip_special_tokens: | |||||
| tokens.append(self.special_tokens_decoder[i]) | |||||
| else: | |||||
| tokens.append(self.decoder[i]) | |||||
| return tokens | |||||
| def encode(self, text): | |||||
| return self.convert_tokens_to_ids(self.tokenize(text)) | |||||
| def decode(self, tokens): | |||||
| text = ''.join([self.decoder[token] for token in tokens]) | |||||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||||
| 'utf-8', errors=self.errors) | |||||
| return text | |||||
| def save_vocabulary(self, vocab_path): | |||||
| """Save the tokenizer vocabulary and merge files to a directory.""" | |||||
| if not os.path.isdir(vocab_path): | |||||
| logger.error('Vocabulary path ({}) should be a directory'.format( | |||||
| vocab_path)) | |||||
| return | |||||
| vocab_file = os.path.join(vocab_path, VOCAB_NAME) | |||||
| merge_file = os.path.join(vocab_path, MERGES_NAME) | |||||
| special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) | |||||
| with open(vocab_file, 'w', encoding='utf-8') as f: | |||||
| f.write(json.dumps(self.encoder, ensure_ascii=False)) | |||||
| index = 0 | |||||
| with open(merge_file, 'w', encoding='utf-8') as writer: | |||||
| writer.write(u'#version: 0.2\n') | |||||
| for bpe_tokens, token_index in sorted( | |||||
| self.bpe_ranks.items(), key=lambda kv: kv[1]): | |||||
| if index != token_index: | |||||
| logger.warning( | |||||
| 'Saving vocabulary to {}: BPE merge indices are not consecutive.' | |||||
| ' Please check that the tokenizer is not corrupted!'. | |||||
| format(merge_file)) | |||||
| index = token_index | |||||
| writer.write(' '.join(bpe_tokens) + u'\n') | |||||
| index += 1 | |||||
| index = len(self.encoder) | |||||
| with open(special_tokens_file, 'w', encoding='utf-8') as writer: | |||||
| for token, token_index in sorted( | |||||
| self.special_tokens.items(), key=lambda kv: kv[1]): | |||||
| if index != token_index: | |||||
| logger.warning( | |||||
| 'Saving special tokens vocabulary to {}: BPE indices are not consecutive.' | |||||
| ' Please check that the tokenizer is not corrupted!'. | |||||
| format(special_tokens_file)) | |||||
| index = token_index | |||||
| writer.write(token + u'\n') | |||||
| index += 1 | |||||
| return vocab_file, merge_file, special_tokens_file | |||||
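| A round-trip sketch, assuming a local directory holding vocab.json and merges.txt in the GPT-2 format (the path is a placeholder): | |||||
| tok = GPT2Tokenizer.from_pretrained('/path/to/gpt2-vocab') | |||||
| ids = tok.encode('Hello world') | |||||
| print(tok.convert_ids_to_tokens(ids)) | |||||
| print(tok.decode(ids))  # round-trips back to 'Hello world' | |||||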
| @@ -0,0 +1,408 @@ | |||||
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py""" # noqa | |||||
| from __future__ import (absolute_import, division, print_function, | |||||
| unicode_literals) | |||||
| import collections | |||||
| import logging | |||||
| import os | |||||
| import unicodedata | |||||
| from io import open | |||||
| from .file_utils import cached_path | |||||
| logger = logging.getLogger(__name__) | |||||
| PRETRAINED_VOCAB_ARCHIVE_MAP = { | |||||
| 'bert-base-uncased': | |||||
| '.pytorch_pretrained_bert/bert-base-uncased-vocab.txt', | |||||
| 'bert-large-uncased': | |||||
| '.pytorch_pretrained_bert/bert-large-uncased-vocab.txt', | |||||
| 'bert-base-cased': | |||||
| '.pytorch_pretrained_bert/bert-base-cased-vocab.txt', | |||||
| 'bert-large-cased': | |||||
| '.pytorch_pretrained_bert/bert-large-cased-vocab.txt', | |||||
| 'bert-base-multilingual-uncased': | |||||
| 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt', | |||||
| 'bert-base-multilingual-cased': | |||||
| 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt', | |||||
| 'bert-base-chinese': | |||||
| 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt', | |||||
| } | |||||
| PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { | |||||
| 'bert-base-uncased': 512, | |||||
| 'bert-large-uncased': 512, | |||||
| 'bert-base-cased': 512, | |||||
| 'bert-large-cased': 512, | |||||
| 'bert-base-multilingual-uncased': 512, | |||||
| 'bert-base-multilingual-cased': 512, | |||||
| 'bert-base-chinese': 512, | |||||
| } | |||||
| VOCAB_NAME = 'vocab.txt' | |||||
| def load_vocab(vocab_file): | |||||
| """Loads a vocabulary file into a dictionary.""" | |||||
| vocab = collections.OrderedDict() | |||||
| index = 0 | |||||
| with open(vocab_file, 'r', encoding='utf-8') as reader: | |||||
| while True: | |||||
| token = reader.readline() | |||||
| if not token: | |||||
| break | |||||
| token = token.strip() | |||||
| vocab[token] = index | |||||
| index += 1 | |||||
| return vocab | |||||
| def whitespace_tokenize(text): | |||||
| """Runs basic whitespace cleaning and splitting on a piece of text.""" | |||||
| text = text.strip() | |||||
| if not text: | |||||
| return [] | |||||
| tokens = text.split() | |||||
| return tokens | |||||
| class BertTokenizer(object): | |||||
| """Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||||
| def __init__(self, | |||||
| vocab_file, | |||||
| do_lower_case=True, | |||||
| max_len=None, | |||||
| do_basic_tokenize=True, | |||||
| never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): | |||||
| """Constructs a BertTokenizer. | |||||
| Args: | |||||
| vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||||
| do_lower_case: Whether to lower case the input | |||||
| Only has an effect when do_wordpiece_only=False | |||||
| do_basic_tokenize: Whether to do basic tokenization before wordpiece. | |||||
| max_len: An artificial maximum length to truncate tokenized sequences to; | |||||
| Effective maximum length is always the minimum of this | |||||
| value (if specified) and the underlying BERT model's | |||||
| sequence length. | |||||
| never_split: List of tokens which will never be split during tokenization. | |||||
| Only has an effect when do_wordpiece_only=False | |||||
| """ | |||||
| if not os.path.isfile(vocab_file): | |||||
| raise ValueError( | |||||
| "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " | |||||
| 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' | |||||
| .format(vocab_file)) | |||||
| self.vocab = load_vocab(vocab_file) | |||||
| self.ids_to_tokens = collections.OrderedDict([ | |||||
| (ids, tok) for tok, ids in self.vocab.items() | |||||
| ]) | |||||
| self.do_basic_tokenize = do_basic_tokenize | |||||
| if do_basic_tokenize: | |||||
| self.basic_tokenizer = BasicTokenizer( | |||||
| do_lower_case=do_lower_case, never_split=never_split) | |||||
| self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||||
| self.max_len = max_len if max_len is not None else int(1e12) | |||||
| def tokenize(self, text): | |||||
| if self.do_basic_tokenize: | |||||
| split_tokens = [] | |||||
| for token in self.basic_tokenizer.tokenize(text): | |||||
| for sub_token in self.wordpiece_tokenizer.tokenize(token): | |||||
| split_tokens.append(sub_token) | |||||
| else: | |||||
| split_tokens = self.wordpiece_tokenizer.tokenize(text) | |||||
| return split_tokens | |||||
| def convert_tokens_to_ids(self, tokens): | |||||
| """Converts a sequence of tokens into ids using the vocab.""" | |||||
| ids = [] | |||||
| for token in tokens: | |||||
| ids.append(self.vocab[token]) | |||||
| if len(ids) > self.max_len: | |||||
| logger.warning( | |||||
| 'Token indices sequence length is longer than the specified maximum ' | |||||
| ' sequence length for this BERT model ({} > {}). Running this' | |||||
| ' sequence through BERT will result in indexing errors'.format( | |||||
| len(ids), self.max_len)) | |||||
| return ids | |||||
| def convert_ids_to_tokens(self, ids): | |||||
| """Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||||
| tokens = [] | |||||
| for i in ids: | |||||
| tokens.append(self.ids_to_tokens[i]) | |||||
| return tokens | |||||
| @classmethod | |||||
| def from_pretrained(cls, | |||||
| pretrained_model_name_or_path, | |||||
| cache_dir=None, | |||||
| *inputs, | |||||
| **kwargs): | |||||
| """ | |||||
| Instantiate a PreTrainedBertModel from a pre-trained model file. | |||||
| Download and cache the pre-trained model file if needed. | |||||
| """ | |||||
| if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: | |||||
| vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[ | |||||
| pretrained_model_name_or_path] | |||||
| else: | |||||
| vocab_file = pretrained_model_name_or_path | |||||
| if os.path.isdir(vocab_file): | |||||
| vocab_file = os.path.join(vocab_file, VOCAB_NAME) | |||||
| # redirect to the cache, if necessary | |||||
| try: | |||||
| resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) | |||||
| except EnvironmentError: | |||||
| logger.error( | |||||
| "Model name '{}' was not found in model name list ({}). " | |||||
| "We assumed '{}' was a path or url but couldn't find any file " | |||||
| 'associated to this path or url.'.format( | |||||
| pretrained_model_name_or_path, | |||||
| ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), | |||||
| vocab_file)) | |||||
| return None | |||||
| if resolved_vocab_file == vocab_file: | |||||
| logger.info('loading vocabulary file {}'.format(vocab_file)) | |||||
| else: | |||||
| logger.info('loading vocabulary file {} from cache at {}'.format( | |||||
| vocab_file, resolved_vocab_file)) | |||||
| if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: | |||||
| # if we're using a pretrained model, ensure the tokenizer won't index sequences longer | |||||
| # than the number of positional embeddings | |||||
| max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ | |||||
| pretrained_model_name_or_path] | |||||
| kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) | |||||
| # Instantiate tokenizer. | |||||
| tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) | |||||
| return tokenizer | |||||
| class BasicTokenizer(object): | |||||
| """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||||
| def __init__(self, | |||||
| do_lower_case=True, | |||||
| never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): | |||||
| """Constructs a BasicTokenizer. | |||||
| Args: | |||||
| do_lower_case: Whether to lower case the input. | |||||
| """ | |||||
| self.do_lower_case = do_lower_case | |||||
| self.never_split = never_split | |||||
| def tokenize(self, text): | |||||
| """Tokenizes a piece of text.""" | |||||
| text = self._clean_text(text) | |||||
| # This was added on November 1st, 2018 for the multilingual and Chinese | |||||
| # models. This is also applied to the English models now, but it doesn't | |||||
| # matter since the English models were not trained on any Chinese data | |||||
| # and generally don't have any Chinese data in them (there are Chinese | |||||
| # characters in the vocabulary because Wikipedia does have some Chinese | |||||
| # words in the English Wikipedia.). | |||||
| text = self._tokenize_chinese_chars(text) | |||||
| orig_tokens = whitespace_tokenize(text) | |||||
| split_tokens = [] | |||||
| for token in orig_tokens: | |||||
| if self.do_lower_case and token not in self.never_split: | |||||
| token = token.lower() | |||||
| token = self._run_strip_accents(token) | |||||
| split_tokens.extend(self._run_split_on_punc(token)) | |||||
| output_tokens = whitespace_tokenize(' '.join(split_tokens)) | |||||
| return output_tokens | |||||
| def _run_strip_accents(self, text): | |||||
| """Strips accents from a piece of text.""" | |||||
| text = unicodedata.normalize('NFD', text) | |||||
| output = [] | |||||
| for char in text: | |||||
| cat = unicodedata.category(char) | |||||
| if cat == 'Mn': | |||||
| continue | |||||
| output.append(char) | |||||
| return ''.join(output) | |||||
| def _run_split_on_punc(self, text): | |||||
| """Splits punctuation on a piece of text.""" | |||||
| if text in self.never_split: | |||||
| return [text] | |||||
| chars = list(text) | |||||
| i = 0 | |||||
| start_new_word = True | |||||
| output = [] | |||||
| while i < len(chars): | |||||
| char = chars[i] | |||||
| if _is_punctuation(char): | |||||
| output.append([char]) | |||||
| start_new_word = True | |||||
| else: | |||||
| if start_new_word: | |||||
| output.append([]) | |||||
| start_new_word = False | |||||
| output[-1].append(char) | |||||
| i += 1 | |||||
| return [''.join(x) for x in output] | |||||
| def _tokenize_chinese_chars(self, text): | |||||
| """Adds whitespace around any CJK character.""" | |||||
| output = [] | |||||
| for char in text: | |||||
| cp = ord(char) | |||||
| if self._is_chinese_char(cp): | |||||
| output.append(' ') | |||||
| output.append(char) | |||||
| output.append(' ') | |||||
| else: | |||||
| output.append(char) | |||||
| return ''.join(output) | |||||
| def _is_chinese_char(self, cp): | |||||
| """Checks whether CP is the codepoint of a CJK character.""" | |||||
| # This defines a "chinese character" as anything in the CJK Unicode block: | |||||
| # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||||
| # | |||||
| # Note that the CJK Unicode block is NOT all Japanese and Korean characters, | |||||
| # despite its name. The modern Korean Hangul alphabet is a different block, | |||||
| # as is Japanese Hiragana and Katakana. Those alphabets are used to write | |||||
| # space-separated words, so they are not treated specially and handled | |||||
| # like all of the other languages. | |||||
| if ((cp >= 0x4E00 and cp <= 0x9FFF) or # noqa | |||||
| (cp >= 0x3400 and cp <= 0x4DBF) or # noqa | |||||
| (cp >= 0x20000 and cp <= 0x2A6DF) or # noqa | |||||
| (cp >= 0x2A700 and cp <= 0x2B73F) or # noqa | |||||
| (cp >= 0x2B740 and cp <= 0x2B81F) or # noqa | |||||
| (cp >= 0x2B820 and cp <= 0x2CEAF) or # noqa | |||||
| (cp >= 0xF900 and cp <= 0xFAFF) or # noqa | |||||
| (cp >= 0x2F800 and cp <= 0x2FA1F)): # noqa | |||||
| return True | |||||
| return False | |||||
| def _clean_text(self, text): | |||||
| """Performs invalid character removal and whitespace cleanup on text.""" | |||||
| output = [] | |||||
| for char in text: | |||||
| cp = ord(char) | |||||
| if cp == 0 or cp == 0xfffd or _is_control(char): | |||||
| continue | |||||
| if _is_whitespace(char): | |||||
| output.append(' ') | |||||
| else: | |||||
| output.append(char) | |||||
| return ''.join(output) | |||||
| class WordpieceTokenizer(object): | |||||
| """Runs WordPiece tokenization.""" | |||||
| def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): | |||||
| self.vocab = vocab | |||||
| self.unk_token = unk_token | |||||
| self.max_input_chars_per_word = max_input_chars_per_word | |||||
| def tokenize(self, text): | |||||
| """Tokenizes a piece of text into its word pieces. | |||||
| This uses a greedy longest-match-first algorithm to perform tokenization | |||||
| using the given vocabulary. | |||||
| For example: | |||||
| input = "unaffable" | |||||
| output = ["un", "##aff", "##able"] | |||||
| Args: | |||||
| text: A single token or whitespace separated tokens. This should have | |||||
| already been passed through `BasicTokenizer`. | |||||
| Returns: | |||||
| A list of wordpiece tokens. | |||||
| """ | |||||
| output_tokens = [] | |||||
| for token in whitespace_tokenize(text): | |||||
| chars = list(token) | |||||
| if len(chars) > self.max_input_chars_per_word: | |||||
| output_tokens.append(self.unk_token) | |||||
| continue | |||||
| is_bad = False | |||||
| start = 0 | |||||
| sub_tokens = [] | |||||
| while start < len(chars): | |||||
| end = len(chars) | |||||
| cur_substr = None | |||||
| while start < end: | |||||
| substr = ''.join(chars[start:end]) | |||||
| if start > 0: | |||||
| substr = '##' + substr | |||||
| if substr in self.vocab: | |||||
| cur_substr = substr | |||||
| break | |||||
| end -= 1 | |||||
| if cur_substr is None: | |||||
| is_bad = True | |||||
| break | |||||
| sub_tokens.append(cur_substr) | |||||
| start = end | |||||
| if is_bad: | |||||
| output_tokens.append(self.unk_token) | |||||
| else: | |||||
| output_tokens.extend(sub_tokens) | |||||
| return output_tokens | |||||
| def _is_whitespace(char): | |||||
| """Checks whether `chars` is a whitespace character.""" | |||||
| # \t, \n, and \r are technically control characters but we treat them | |||||
| # as whitespace since they are generally considered as such. | |||||
| if char == ' ' or char == '\t' or char == '\n' or char == '\r': | |||||
| return True | |||||
| cat = unicodedata.category(char) | |||||
| if cat == 'Zs': | |||||
| return True | |||||
| return False | |||||
| def _is_control(char): | |||||
| """Checks whether `chars` is a control character.""" | |||||
| # These are technically control characters but we count them as whitespace | |||||
| # characters. | |||||
| if char == '\t' or char == '\n' or char == '\r': | |||||
| return False | |||||
| cat = unicodedata.category(char) | |||||
| if cat.startswith('C'): | |||||
| return True | |||||
| return False | |||||
| def _is_punctuation(char): | |||||
| """Checks whether `chars` is a punctuation character.""" | |||||
| cp = ord(char) | |||||
| # We treat all non-letter/number ASCII as punctuation. | |||||
| # Characters such as "^", "$", and "`" are not in the Unicode | |||||
| # Punctuation class but we treat them as punctuation anyways, for | |||||
| # consistency. | |||||
| if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) | |||||
| or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): | |||||
| return True | |||||
| cat = unicodedata.category(char) | |||||
| if cat.startswith('P'): | |||||
| return True | |||||
| return False | |||||
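As a rough illustration of how the two tokenizer classes above compose (a hypothetical sketch: the tiny vocabulary is invented, and the constructor defaults are assumed to match the standard BERT tokenizer this file derives from):

# Hypothetical sketch; the vocabulary below is invented for illustration only.
basic = BasicTokenizer(do_lower_case=True)
vocab = {'hello': 0, 'un': 1, '##aff': 2, '##able': 3, '[UNK]': 4}
wordpiece = WordpieceTokenizer(vocab=vocab)

tokens = []
for word in basic.tokenize('Hello unaffable world'):   # -> ['hello', 'unaffable', 'world']
    tokens.extend(wordpiece.tokenize(word))
# 'hello' is in the vocab, 'unaffable' splits greedily into ['un', '##aff', '##able'],
# and 'world' matches no piece at its first character, so it falls back to '[UNK]'.
print(tokens)   # ['hello', 'un', '##aff', '##able', '[UNK]']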
| @@ -0,0 +1,20 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from .fp16 import * # noqa | |||||
| from .fp16util import (BN_convert_float, FP16Model, clip_grad_norm, | |||||
| convert_module, convert_network, | |||||
| master_params_to_model_params, | |||||
| model_grads_to_master_grads, network_to_half, | |||||
| prep_param_lists, to_python_float, tofp16) | |||||
| from .loss_scaler import * # noqa | |||||
| @@ -0,0 +1,660 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Stable version of apex FP16 Optimizer""" | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||||
| from torch.autograd import Variable | |||||
| from torch.nn.parameter import Parameter | |||||
| from .fp16util import (clip_grad_norm, master_params_to_model_params, | |||||
| model_grads_to_master_grads) | |||||
| from .loss_scaler import DynamicLossScaler, LossScaler | |||||
| FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) | |||||
| HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) | |||||
| def conversion_helper(val, conversion): | |||||
| """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" | |||||
| if not isinstance(val, (tuple, list)): | |||||
| return conversion(val) | |||||
| rtn = [conversion_helper(v, conversion) for v in val] | |||||
| if isinstance(val, tuple): | |||||
| rtn = tuple(rtn) | |||||
| return rtn | |||||
| def fp32_to_fp16(val): | |||||
| """Convert fp32 `val` to fp16""" | |||||
| def half_conversion(val): | |||||
| val_typecheck = val | |||||
| if isinstance(val_typecheck, (Parameter, Variable)): | |||||
| val_typecheck = val.data | |||||
| if isinstance(val_typecheck, FLOAT_TYPES): | |||||
| val = val.half() | |||||
| return val | |||||
| return conversion_helper(val, half_conversion) | |||||
| def fp16_to_fp32(val): | |||||
| """Convert fp16 `val` to fp32""" | |||||
| def float_conversion(val): | |||||
| val_typecheck = val | |||||
| if isinstance(val_typecheck, (Parameter, Variable)): | |||||
| val_typecheck = val.data | |||||
| if isinstance(val_typecheck, HALF_TYPES): | |||||
| val = val.float() | |||||
| return val | |||||
| return conversion_helper(val, float_conversion) | |||||
| class FP16_Module(nn.Module): | |||||
| def __init__(self, module): | |||||
| super(FP16_Module, self).__init__() | |||||
| self.add_module('module', module.half()) | |||||
| def forward(self, *inputs, **kwargs): | |||||
| return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) | |||||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||||
| return self.module.named_parameters(prefix=prefix, recurse=recurse) | |||||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||||
| return self.module.state_dict(destination, prefix, keep_vars) | |||||
| def load_state_dict(self, state_dict, strict=True): | |||||
| return self.module.load_state_dict(state_dict, strict=strict) | |||||
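Taken together, the conversion helpers and FP16_Module behave roughly as in the sketch below (assuming a CUDA device and a PyTorch version where the legacy torch.cuda.FloatTensor / HalfTensor isinstance checks above still apply):

# Minimal sketch; the wrapped module and shapes are arbitrary.
linear = torch.nn.Linear(4, 2).cuda()
fp16_linear = FP16_Module(linear)      # weights are cast to half precision

x = torch.randn(3, 4).cuda()           # fp32 input
y = fp16_linear(x)                     # fp32_to_fp16 on the way in, fp16_to_fp32 on the way out
# y is fp32 even though the matmul itself ran in fp16.

# conversion_helper recurses through nested tuples/lists, so structured inputs
# such as (hidden_states, (key, value)) are converted element by element.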
| # TODO: Update overflow check + downscale to use Carl's fused kernel. | |||||
| class FP16_Optimizer(object): | |||||
| """ | |||||
| :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, | |||||
| and manage static or dynamic loss scaling and master weights in a manner transparent to the user. | |||||
| For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, | |||||
| and changing the call to ``backward``. | |||||
| Example:: | |||||
| model = torch.nn.Linear(D_in, D_out).cuda().half() | |||||
| optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||||
| # Name the FP16_Optimizer instance to replace the existing optimizer | |||||
| # (recommended but not required): | |||||
| optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||||
| ... | |||||
| # loss.backward() becomes: | |||||
| optimizer.backward(loss) | |||||
| ... | |||||
| Example with dynamic loss scaling:: | |||||
| ... | |||||
| optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) | |||||
| # optional arg to control dynamic loss scaling behavior | |||||
| # dynamic_loss_args={'scale_window' : 500}) | |||||
| # Usually, dynamic_loss_args is not necessary. | |||||
| Args: | |||||
| init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. | |||||
| static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. | |||||
| dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. | |||||
| dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. | |||||
| verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. | |||||
| ``init_optimizer`` is expected to have been constructed in the ordinary way. | |||||
| It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be | |||||
| named to replace ``init_optimizer``, for two reasons: | |||||
| First, it means that references to the same name | |||||
| later in the file will not have to change. | |||||
| Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to | |||||
| modify ``init_optimizer``. If you do choose a unique name for the new | |||||
| :class:`FP16_Optimizer` instance, you should only work with this new instance, | |||||
| because the preexisting optimizer might no longer behave as expected. | |||||
| ``init_optimizer`` may be any Pytorch optimizer. | |||||
| It may contain a mixture of fp16 and fp32 parameters organized into any number of | |||||
| ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will | |||||
| ingest these ``param_groups`` and remember them. | |||||
| Calls to :: | |||||
| loss.backward() | |||||
| must be replaced with :: | |||||
| optimizer.backward(loss) | |||||
| because :class:`FP16_Optimizer` requires ownership of the backward pass to implement | |||||
| loss scaling and copies to master gradients. | |||||
| .. note:: | |||||
| Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients | |||||
| are downscaled before being applied. This means that adjusting the loss scale, or using | |||||
| dynamic loss scaling, should not require retuning the learning rate or any other | |||||
| hyperparameters. | |||||
| **Advanced options** | |||||
| **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. | |||||
| See docstring for :attr:`step`. | |||||
| **Gradient clipping**: Use :attr:`clip_master_grads`. | |||||
| **Multiple losses**: If your model accumulates gradients from multiple losses, | |||||
| this can be made more efficient by supplying ``update_master_grads=False`` | |||||
| to :attr:`backward`. See docstring for :attr:`backward`. | |||||
| **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: | |||||
| print(optimizer.loss_scale) | |||||
| optimizer.loss_scale = new_loss_scale | |||||
| For static loss scaling, manually adjusting the loss scale over time is a reasonable | |||||
| thing to do. During later epochs, gradients may become smaller, and a | |||||
| higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss | |||||
| scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting | |||||
| the loss scale is not recommended. | |||||
| **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in | |||||
| Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` | |||||
| should still work as intended. | |||||
| """ # noqa | |||||
| def __init__(self, | |||||
| init_optimizer, | |||||
| static_loss_scale=1.0, | |||||
| dynamic_loss_scale=False, | |||||
| dynamic_loss_args=None, | |||||
| verbose=False): | |||||
| if not torch.cuda.is_available(): | |||||
| raise SystemError('Cannot use fp16 without CUDA.') | |||||
| self.verbose = verbose | |||||
| self.optimizer = init_optimizer | |||||
| # init_state_dict sets up an alternative way to cast per-param state tensors. | |||||
| # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. | |||||
| # init_state_dict = init_optimizer.state_dict() | |||||
| self.fp16_groups = [] | |||||
| self.fp32_from_fp16_groups = [] | |||||
| self.fp32_from_fp32_groups = [] | |||||
| for i, param_group in enumerate(self.optimizer.param_groups): | |||||
| self.maybe_print( | |||||
| 'FP16_Optimizer processing param group {}:'.format(i)) | |||||
| fp16_params_this_group = [] | |||||
| fp32_params_this_group = [] | |||||
| fp32_from_fp16_params_this_group = [] | |||||
| for i, param in enumerate(param_group['params']): | |||||
| if param.requires_grad: | |||||
| if param.type() == 'torch.cuda.HalfTensor': | |||||
| self.maybe_print( | |||||
| 'FP16_Optimizer received torch.cuda.HalfTensor with {}' | |||||
| .format(param.size())) | |||||
| fp16_params_this_group.append(param) | |||||
| master_param = param.detach().clone().float() | |||||
| master_param.requires_grad = True | |||||
| # Copy the model parallel flag. | |||||
| master_param.model_parallel = param.model_parallel | |||||
| param_group['params'][i] = master_param | |||||
| fp32_from_fp16_params_this_group.append(master_param) | |||||
| # Reset existing state dict key to the new master param. | |||||
| # We still need to recast per-param state tensors, if any, to FP32. | |||||
| if param in self.optimizer.state: | |||||
| self.optimizer.state[ | |||||
| master_param] = self.optimizer.state.pop(param) | |||||
| elif param.type() == 'torch.cuda.FloatTensor': | |||||
| self.maybe_print( | |||||
| 'FP16_Optimizer received torch.cuda.FloatTensor with {}' | |||||
| .format(param.size())) | |||||
| fp32_params_this_group.append(param) | |||||
| param_group['params'][i] = param | |||||
| else: | |||||
| raise TypeError( | |||||
| 'Wrapped parameters must be either ' | |||||
| 'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' | |||||
| 'Received {}'.format(param.type())) | |||||
| self.fp16_groups.append(fp16_params_this_group) | |||||
| self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) | |||||
| self.fp32_from_fp32_groups.append(fp32_params_this_group) | |||||
| # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors | |||||
| self.optimizer.load_state_dict(self.optimizer.state_dict()) | |||||
| # alternative way to cast per-param state tensors: | |||||
| # self.optimizer.load_state_dict(init_state_dict) | |||||
| if dynamic_loss_scale: | |||||
| self.dynamic_loss_scale = True | |||||
| if dynamic_loss_args is not None: | |||||
| self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) | |||||
| else: | |||||
| self.loss_scaler = DynamicLossScaler() | |||||
| else: | |||||
| self.dynamic_loss_scale = False | |||||
| self.loss_scaler = LossScaler(static_loss_scale) | |||||
| self.overflow = False | |||||
| self.first_closure_call_this_step = True | |||||
| self.clip_grad_norm = clip_grad_norm | |||||
| def maybe_print(self, msg): | |||||
| if self.verbose: | |||||
| print(msg) | |||||
| def __getstate__(self): | |||||
| raise RuntimeError( | |||||
| 'FP16_Optimizer should be serialized using state_dict().') | |||||
| def __setstate__(self, state): | |||||
| raise RuntimeError( | |||||
| 'FP16_Optimizer should be deserialized using load_state_dict().') | |||||
| def zero_grad(self, set_grads_to_None=False): | |||||
| """ | |||||
| Zero fp32 and fp16 parameter grads. | |||||
| """ | |||||
| # In principle, only the .grad attributes of the model params need to be zeroed, | |||||
| # because gradients are copied into the FP32 master params. However, we zero | |||||
| # all gradients owned by the optimizer, just to be safe: | |||||
| for group in self.optimizer.param_groups: | |||||
| for p in group['params']: | |||||
| if set_grads_to_None: | |||||
| p.grad = None | |||||
| else: | |||||
| if p.grad is not None: | |||||
| p.grad.detach_() | |||||
| p.grad.zero_() | |||||
| # Zero fp16 gradients owned by the model: | |||||
| for fp16_group in self.fp16_groups: | |||||
| for param in fp16_group: | |||||
| if set_grads_to_None: | |||||
| param.grad = None | |||||
| else: | |||||
| if param.grad is not None: | |||||
| param.grad.detach_( | |||||
| ) # as in torch.optim.optimizer.zero_grad() | |||||
| param.grad.zero_() | |||||
| def _check_overflow(self): | |||||
| params = [] | |||||
| for group in self.fp16_groups: | |||||
| for param in group: | |||||
| params.append(param) | |||||
| for group in self.fp32_from_fp32_groups: | |||||
| for param in group: | |||||
| params.append(param) | |||||
| self.overflow = self.loss_scaler.has_overflow(params) | |||||
| def _update_scale(self, has_overflow=False): | |||||
| self.loss_scaler.update_scale(has_overflow) | |||||
| def _master_params_to_model_params(self): | |||||
| for fp16_group, fp32_from_fp16_group in zip( | |||||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||||
| master_params_to_model_params(fp16_group, fp32_from_fp16_group) | |||||
| def _model_params_to_master_params(self): | |||||
| for fp16_group, fp32_from_fp16_group in zip( | |||||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||||
| master_params_to_model_params(fp32_from_fp16_group, fp16_group) | |||||
| # To consider: Integrate distributed with this wrapper by registering a hook on each variable | |||||
| # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. | |||||
| def _model_grads_to_master_grads(self): | |||||
| for fp16_group, fp32_from_fp16_group in zip( | |||||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||||
| model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) | |||||
| def _downscale_master(self): | |||||
| if self.loss_scale != 1.0: | |||||
| for group in self.optimizer.param_groups: | |||||
| for param in group['params']: | |||||
| if param.grad is not None: | |||||
| param.grad.data.mul_(1. / self.loss_scale) | |||||
| def clip_master_grads(self, max_norm, norm_type=2): | |||||
| """ | |||||
| Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. | |||||
| Args: | |||||
| max_norm (float or int): max norm of the gradients | |||||
| norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for | |||||
| infinity norm. | |||||
| Returns: | |||||
| Total norm of the current fp32 gradients (viewed as a single vector). | |||||
| .. warning:: | |||||
| Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). | |||||
| """ # noqa | |||||
| if not self.overflow: | |||||
| fp32_params = [] | |||||
| for param_group in self.optimizer.param_groups: | |||||
| for param in param_group['params']: | |||||
| fp32_params.append(param) | |||||
| return self.clip_grad_norm(fp32_params, max_norm, norm_type) | |||||
| else: | |||||
| return -1 | |||||
| def state_dict(self): | |||||
| """ | |||||
| Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. | |||||
| This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict | |||||
| of the contained Pytorch optimizer. | |||||
| Example:: | |||||
| checkpoint = {} | |||||
| checkpoint['model'] = model.state_dict() | |||||
| checkpoint['optimizer'] = optimizer.state_dict() | |||||
| torch.save(checkpoint, "saved.pth") | |||||
| """ | |||||
| state_dict = {} | |||||
| state_dict['loss_scaler'] = self.loss_scaler | |||||
| state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale | |||||
| state_dict['overflow'] = self.overflow | |||||
| state_dict[ | |||||
| 'first_closure_call_this_step'] = self.first_closure_call_this_step | |||||
| state_dict['optimizer_state_dict'] = self.optimizer.state_dict() | |||||
| state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups | |||||
| return state_dict | |||||
| def load_state_dict(self, state_dict): | |||||
| """ | |||||
| Loads a state_dict created by an earlier call to state_dict(). | |||||
| If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, | |||||
| whose parameters in turn came from ``model``, it is expected that the user | |||||
| will call ``model.load_state_dict()`` before | |||||
| ``fp16_optimizer_instance.load_state_dict()`` is called. | |||||
| Example:: | |||||
| model = torch.nn.Linear(D_in, D_out).cuda().half() | |||||
| optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||||
| optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||||
| ... | |||||
| checkpoint = torch.load("saved.pth") | |||||
| model.load_state_dict(checkpoint['model']) | |||||
| optimizer.load_state_dict(checkpoint['optimizer']) | |||||
| """ | |||||
| # I think it should actually be ok to reload the optimizer before the model. | |||||
| self.loss_scaler = state_dict['loss_scaler'] | |||||
| self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] | |||||
| self.overflow = state_dict['overflow'] | |||||
| self.first_closure_call_this_step = state_dict[ | |||||
| 'first_closure_call_this_step'] | |||||
| self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) | |||||
| # At this point, the optimizer's references to the model's fp32 parameters are up to date. | |||||
| # The optimizer's hyperparameters and internal buffers are also up to date. | |||||
| # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still | |||||
| # out of date. There are two options. | |||||
| # 1: Refresh the master params from the model's fp16 params. | |||||
| # This requires less storage but incurs precision loss. | |||||
| # 2: Save and restore the fp32 master copies separately. | |||||
| # We choose option 2. | |||||
| # | |||||
| # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device | |||||
| # of their associated parameters, because it's possible those buffers might not exist yet in | |||||
| # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been | |||||
| # constructed in the same way as the one whose state_dict we are loading, the same master params | |||||
| # are guaranteed to exist, so we can just copy_() from the saved master params. | |||||
| for current_group, saved_group in zip(self.fp32_from_fp16_groups, | |||||
| state_dict['fp32_from_fp16']): | |||||
| for current, saved in zip(current_group, saved_group): | |||||
| current.data.copy_(saved.data) | |||||
| def step(self, closure=None): # could add clip option. | |||||
| """ | |||||
| If no closure is supplied, :attr:`step` should be called after | |||||
| ``fp16_optimizer_obj.backward(loss)``. | |||||
| :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to | |||||
| :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params | |||||
| originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run | |||||
| another forward pass using their model. | |||||
| If a closure is supplied, :attr:`step` may be called without a prior call to | |||||
| :attr:`backward(loss)`. | |||||
| This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. | |||||
| However, the user should take care that any ``loss.backward()`` call within the closure | |||||
| has been replaced by ``fp16_optimizer_obj.backward(loss)``. | |||||
| Args: | |||||
| closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. | |||||
| Example with closure:: | |||||
| # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an | |||||
| # existing pytorch optimizer. | |||||
| for input, target in dataset: | |||||
| def closure(): | |||||
| optimizer.zero_grad() | |||||
| output = model(input) | |||||
| loss = loss_fn(output, target) | |||||
| # loss.backward() becomes: | |||||
| optimizer.backward(loss) | |||||
| return loss | |||||
| optimizer.step(closure) | |||||
| .. warning:: | |||||
| Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. | |||||
| .. _`ordinary Pytorch optimizer use`: | |||||
| http://pytorch.org/docs/master/optim.html#optimizer-step-closure | |||||
| """ # noqa | |||||
| scale = self.loss_scaler.loss_scale | |||||
| self._update_scale(self.overflow) | |||||
| if self.overflow: | |||||
| self.maybe_print( | |||||
| 'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' | |||||
| .format(scale, self.loss_scale)) | |||||
| return | |||||
| if closure is not None: | |||||
| retval = self._step_with_closure(closure) | |||||
| else: | |||||
| retval = self.optimizer.step() | |||||
| self._master_params_to_model_params() | |||||
| return retval | |||||
| def _step_with_closure(self, closure): | |||||
| def wrapped_closure(): | |||||
| # helpful for debugging | |||||
| # print("Calling wrapped_closure, first_closure_call_this_step = {}" | |||||
| # .format(self.first_closure_call_this_step)) | |||||
| if self.first_closure_call_this_step: | |||||
| # We expect that the fp16 params are initially fresh on entering self.step(), | |||||
| # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() | |||||
| # is called within self.optimizer.step(). | |||||
| self.first_closure_call_this_step = False | |||||
| else: | |||||
| # If self.optimizer.step() internally calls wrapped_closure more than once, | |||||
| # it may update the fp32 params after each call. However, self.optimizer | |||||
| # doesn't know about the fp16 params at all. If the fp32 params get updated, | |||||
| # we can't rely on self.optimizer to refresh the fp16 params. We need | |||||
| # to handle that manually: | |||||
| self._master_params_to_model_params() | |||||
| # Our API expects the user to give us ownership of the backward() call by | |||||
| # replacing all calls to loss.backward() with optimizer.backward(loss). | |||||
| # This requirement holds whether or not the call to backward() is made within a closure. | |||||
| # If the user is properly calling optimizer.backward(loss) within "closure," | |||||
| # calling closure() here will give the fp32 master params fresh gradients | |||||
| # for the optimizer to play with, so all wrapped_closure needs to do is call | |||||
| # closure() and return the loss. | |||||
| temp_loss = closure() | |||||
| while (self.overflow): | |||||
| scale = self.loss_scaler.loss_scale | |||||
| self._update_scale(self.overflow) | |||||
| self.maybe_print( | |||||
| 'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' | |||||
| 'reducing to {}'.format(scale, self.loss_scale)) | |||||
| temp_loss = closure() | |||||
| return temp_loss | |||||
| retval = self.optimizer.step(wrapped_closure) | |||||
| self.first_closure_call_this_step = True | |||||
| return retval | |||||
| def backward(self, loss, update_master_grads=True, retain_graph=False): | |||||
| """ | |||||
| :attr:`backward` performs the following conceptual steps: | |||||
| 1. fp32_loss = loss.float() (see first Note below) | |||||
| 2. scaled_loss = fp32_loss*loss_scale | |||||
| 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). | |||||
| 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. | |||||
| 5. Finally, master grads are divided by loss_scale. | |||||
| In this way, after :attr:`backward`, the master params have fresh gradients, | |||||
| and :attr:`step` may be called. | |||||
| .. note:: | |||||
| :attr:`backward` internally converts the loss to fp32 before applying the loss scale. | |||||
| This provides some additional safety against overflow if the user has supplied an | |||||
| fp16 loss value. | |||||
| However, for maximum overflow safety, the user should | |||||
| compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to | |||||
| :attr:`backward`. | |||||
| .. warning:: | |||||
| The gradients found in a model's leaves after the call to | |||||
| :attr:`backward` should not be regarded as valid in general, | |||||
| because it's possible | |||||
| they have been scaled (and in the case of dynamic loss scaling, | |||||
| the scale factor may change over time). | |||||
| If the user wants to inspect gradients after a call to :attr:`backward`, | |||||
| only the master gradients should be regarded as valid. These can be retrieved via | |||||
| :attr:`inspect_master_grad_data()`. | |||||
| Args: | |||||
| loss: The loss output by the user's model. loss may be either float or half (but see first Note above). | |||||
| update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. | |||||
| retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). | |||||
| Example:: | |||||
| # Ordinary operation: | |||||
| optimizer.backward(loss) | |||||
| # Naive operation with multiple losses (technically valid, but less efficient): | |||||
| # fp32 grads will be correct after the second call, but | |||||
| # the first call incurs an unnecessary fp16->fp32 grad copy. | |||||
| optimizer.backward(loss1) | |||||
| optimizer.backward(loss2) | |||||
| # More efficient way to handle multiple losses: | |||||
| # The fp16->fp32 grad copy is delayed until fp16 grads from all | |||||
| # losses have been accumulated. | |||||
| optimizer.backward(loss1, update_master_grads=False) | |||||
| optimizer.backward(loss2, update_master_grads=False) | |||||
| optimizer.update_master_grads() | |||||
| """ # noqa | |||||
| # To consider: try multiple backward passes using retain_graph=True to find | |||||
| # a loss scale that works. After you find a loss scale that works, do a final dummy | |||||
| # backward pass with retain_graph=False to tear down the graph. Doing this would avoid | |||||
| # discarding the iteration, but probably wouldn't improve overall efficiency. | |||||
| self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) | |||||
| if update_master_grads: | |||||
| self.update_master_grads() | |||||
| def update_master_grads(self): | |||||
| """ | |||||
| Copy the ``.grad`` attribute from stored references to fp16 parameters to | |||||
| the ``.grad`` attribute of the fp32 master parameters that are directly | |||||
| updated by the optimizer. :attr:`update_master_grads` only needs to be called if | |||||
| ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. | |||||
| """ # noqa | |||||
| if self.dynamic_loss_scale: | |||||
| self._check_overflow() | |||||
| if self.overflow: return # noqa | |||||
| self._model_grads_to_master_grads() | |||||
| self._downscale_master() | |||||
| def inspect_master_grad_data(self): | |||||
| """ | |||||
| When running with :class:`FP16_Optimizer`, | |||||
| ``.grad`` attributes of a model's fp16 leaves should not be | |||||
| regarded as truthful, because they might be scaled. | |||||
| After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, | |||||
| the fp32 master params' ``.grad`` | |||||
| attributes will contain valid gradients properly divided by the loss scale. However, | |||||
| because :class:`FP16_Optimizer` flattens some parameters, accessing them may be | |||||
| nonintuitive. :attr:`inspect_master_grad_data` | |||||
| allows those gradients to be viewed with shapes corresponding to their associated model leaves. | |||||
| Returns: | |||||
| List of lists (one list for each parameter group). The list for each parameter group | |||||
| is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. | |||||
| """ | |||||
| if self.overflow: | |||||
| print( | |||||
| 'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' | |||||
| 'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' | |||||
| ) | |||||
| return None | |||||
| else: | |||||
| # The optimizer owns only references to master params. | |||||
| master_grads_data = [] | |||||
| for param_group in self.optimizer.param_groups: | |||||
| master_grads_this_group = [] | |||||
| for param in param_group['params']: | |||||
| if param.grad is not None: | |||||
| master_grads_this_group.append(param.grad.data) | |||||
| else: | |||||
| master_grads_this_group.append(None) | |||||
| master_grads_data.append(master_grads_this_group) | |||||
| return master_grads_data | |||||
| # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" | |||||
| def _get_loss_scale(self): | |||||
| return self.loss_scaler.loss_scale | |||||
| def _set_loss_scale(self, value): | |||||
| self.loss_scaler.cur_scale = value | |||||
| loss_scale = property(_get_loss_scale, _set_loss_scale) | |||||
| # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" | |||||
| def _get_state(self): | |||||
| return self.optimizer.state | |||||
| def _set_state(self, value): | |||||
| self.optimizer.state = value | |||||
| state = property(_get_state, _set_state) | |||||
| # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" | |||||
| # (for example, to adjust the learning rate) | |||||
| def _get_param_groups(self): | |||||
| return self.optimizer.param_groups | |||||
| def _set_param_groups(self, value): | |||||
| self.optimizer.param_groups = value | |||||
| param_groups = property(_get_param_groups, _set_param_groups) | |||||
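Tying the wrapper together, one training iteration might look like the sketch below. Everything named here other than FP16_Module and FP16_Optimizer (MyModel, loss_fn, loader) is a placeholder, and dynamic loss scaling plus clip_master_grads additionally assume that torch.distributed and mGLM's mpu model-parallel state have been initialized, since the overflow check and gradient clipping above go through mpu.

# Hedged sketch of one fp16 training step; MyModel, loss_fn and loader are placeholders.
model = FP16_Module(MyModel().cuda())
for p in model.parameters():
    p.model_parallel = False           # the constructor above expects this Megatron-style flag

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

for batch, target in loader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch.cuda()), target.cuda())
    optimizer.backward(loss)                       # replaces loss.backward()
    grad_norm = optimizer.clip_master_grads(1.0)   # returns -1 if the fp16 grads overflowed
    optimizer.step()                               # silently skipped when an overflow was detected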
| @@ -0,0 +1,220 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||||
| from torch.autograd import Variable | |||||
| from modelscope.models.nlp.mglm import mpu | |||||
| class tofp16(nn.Module): | |||||
| """ | |||||
| Utility module that implements:: | |||||
| def forward(self, input): | |||||
| return input.half() | |||||
| """ | |||||
| def __init__(self): | |||||
| super(tofp16, self).__init__() | |||||
| def forward(self, input): | |||||
| return input.half() | |||||
| def BN_convert_float(module): | |||||
| """ | |||||
| Utility function for network_to_half(). | |||||
| Retained for legacy purposes. | |||||
| """ | |||||
| if isinstance( | |||||
| module, | |||||
| torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: | |||||
| module.float() | |||||
| for child in module.children(): | |||||
| BN_convert_float(child) | |||||
| return module | |||||
| def network_to_half(network): | |||||
| """ | |||||
| Convert model to half precision in a batchnorm-safe way. | |||||
| Retained for legacy purposes. It is recommended to use FP16Model. | |||||
| """ | |||||
| return nn.Sequential(tofp16(), BN_convert_float(network.half())) | |||||
| def convert_module(module, dtype): | |||||
| """ | |||||
| Converts a module's immediate parameters and buffers to dtype. | |||||
| """ | |||||
| for param in module.parameters(recurse=False): | |||||
| if param is not None: | |||||
| if param.data.dtype.is_floating_point: | |||||
| param.data = param.data.to(dtype=dtype) | |||||
| if param._grad is not None and param._grad.data.dtype.is_floating_point: | |||||
| param._grad.data = param._grad.data.to(dtype=dtype) | |||||
| for buf in module.buffers(recurse=False): | |||||
| if buf is not None and buf.data.dtype.is_floating_point: | |||||
| buf.data = buf.data.to(dtype=dtype) | |||||
| def convert_network(network, dtype): | |||||
| """ | |||||
| Converts a network's parameters and buffers to dtype. | |||||
| """ | |||||
| for module in network.modules(): | |||||
| if isinstance(module, torch.nn.modules.batchnorm._BatchNorm | |||||
| ) and module.affine is True: | |||||
| continue | |||||
| convert_module(module, dtype) | |||||
| return network | |||||
| class FP16Model(nn.Module): | |||||
| """ | |||||
| Convert model to half precision in a batchnorm-safe way. | |||||
| """ | |||||
| def __init__(self, network): | |||||
| super(FP16Model, self).__init__() | |||||
| self.network = convert_network(network, dtype=torch.half) | |||||
| def forward(self, *inputs): | |||||
| inputs = tuple(t.half() for t in inputs) | |||||
| return self.network(*inputs) | |||||
| def backwards_debug_hook(grad): | |||||
| raise RuntimeError( | |||||
| 'master_params received a gradient in the backward pass!') | |||||
| def prep_param_lists(model, flat_master=False): | |||||
| """ | |||||
| Creates a list of FP32 master parameters for a given model, as in | |||||
| `Training Neural Networks with Mixed Precision: Real Examples`_. | |||||
| Args: | |||||
| model (torch.nn.Module): Existing Pytorch model | |||||
| flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. | |||||
| Returns: | |||||
| A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. | |||||
| Example:: | |||||
| model_params, master_params = prep_param_lists(model) | |||||
| .. warning:: | |||||
| Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. | |||||
| .. _`Training Neural Networks with Mixed Precision: Real Examples`: | |||||
| http://on-demand.gputechconf.com/gtc/2018/video/S81012/ | |||||
| """ # noqa | |||||
| model_params = [ | |||||
| param for param in model.parameters() if param.requires_grad | |||||
| ] | |||||
| if flat_master: | |||||
| # Give the user some more useful error messages | |||||
| try: | |||||
| # flatten_dense_tensors returns a contiguous flat array. | |||||
| # http://pytorch.org/docs/master/_modules/torch/_utils.html | |||||
| master_params = _flatten_dense_tensors( | |||||
| [param.data for param in model_params]).float() | |||||
| except: # noqa | |||||
| print( | |||||
| 'Error in prep_param_lists: model may contain a mixture of parameters ' | |||||
| 'of different types. Use flat_master=False, or use FP16_Optimizer.' | |||||
| ) | |||||
| raise | |||||
| master_params = torch.nn.Parameter(master_params) | |||||
| master_params.requires_grad = True | |||||
| # master_params.register_hook(backwards_debug_hook) | |||||
| if master_params.grad is None: | |||||
| master_params.grad = master_params.new(*master_params.size()) | |||||
| return model_params, [master_params] | |||||
| else: | |||||
| master_params = [ | |||||
| param.clone().float().detach() for param in model_params | |||||
| ] | |||||
| for param in master_params: | |||||
| param.requires_grad = True | |||||
| return model_params, master_params | |||||
| def model_grads_to_master_grads(model_params, | |||||
| master_params, | |||||
| flat_master=False): | |||||
| """ | |||||
| Copy model gradients to master gradients. | |||||
| Args: | |||||
| model_params: List of model parameters created by :func:`prep_param_lists`. | |||||
| master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. | |||||
| """ # noqa | |||||
| if flat_master: | |||||
| # The flattening may incur one more deep copy than is necessary. | |||||
| master_params[0].grad.data.copy_( | |||||
| _flatten_dense_tensors([p.grad.data for p in model_params])) | |||||
| else: | |||||
| for model, master in zip(model_params, master_params): | |||||
| if model.grad is not None: | |||||
| if master.grad is None: | |||||
| master.grad = Variable( | |||||
| master.data.new(*master.data.size())) | |||||
| master.grad.data.copy_(model.grad.data) | |||||
| else: | |||||
| master.grad = None | |||||
| def master_params_to_model_params(model_params, | |||||
| master_params, | |||||
| flat_master=False): | |||||
| """ | |||||
| Copy master parameters to model parameters. | |||||
| Args: | |||||
| model_params: List of model parameters created by :func:`prep_param_lists`. | |||||
| master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. | |||||
| """ # noqa | |||||
| if flat_master: | |||||
| for model, master in zip( | |||||
| model_params, | |||||
| _unflatten_dense_tensors(master_params[0].data, model_params)): | |||||
| model.data.copy_(master) | |||||
| else: | |||||
| for model, master in zip(model_params, master_params): | |||||
| model.data.copy_(master.data) | |||||
| # Backward compatibility fixes | |||||
| def to_python_float(t): | |||||
| if hasattr(t, 'item'): | |||||
| return t.item() | |||||
| else: | |||||
| return t[0] | |||||
| TORCH_MAJOR = int(torch.__version__.split('.')[0]) | |||||
| TORCH_MINOR = int(torch.__version__.split('.')[1]) | |||||
| clip_grad_norm = mpu.clip_grad_norm | |||||
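When the full FP16_Optimizer wrapper is not used, the helpers above can be chained by hand. A rough sketch with static loss scaling (MyModel, loss_fn, inputs and targets are placeholders):

# Hedged sketch of the manual master-parameter flow using the helpers above.
loss_scale = 128.0
model = network_to_half(MyModel().cuda())
model_params, master_params = prep_param_lists(model)
optimizer = torch.optim.SGD(master_params, lr=1e-3)

loss = loss_fn(model(inputs), targets)
(loss.float() * loss_scale).backward()                        # scale up to avoid fp16 underflow
model_grads_to_master_grads(model_params, master_params)      # fp16 grads -> fp32 master grads
for p in master_params:
    if p.grad is not None:
        p.grad.data.mul_(1.0 / loss_scale)                    # undo the scaling on the master grads
optimizer.step()
master_params_to_model_params(model_params, master_params)    # copy updated fp32 weights back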
| @@ -0,0 +1,245 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import torch | |||||
| from modelscope.models.nlp.mglm import mpu | |||||
| # item() is a recent addition, so this helps with backward compatibility. | |||||
| def to_python_float(t): | |||||
| if hasattr(t, 'item'): | |||||
| return t.item() | |||||
| else: | |||||
| return t[0] | |||||
| class LossScaler: | |||||
| """ | |||||
| Class that manages a static loss scale. This class is intended to interact with | |||||
| :class:`FP16_Optimizer`, and should not be directly manipulated by the user. | |||||
| Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to | |||||
| :class:`FP16_Optimizer`'s constructor. | |||||
| Args: | |||||
| scale (float, optional, default=1.0): The loss scale. | |||||
| """ | |||||
| def __init__(self, scale=1): | |||||
| self.cur_scale = scale | |||||
| # `params` is a list / generator of torch.Variable | |||||
| def has_overflow(self, params): | |||||
| return False | |||||
| # `x` is a torch.Tensor | |||||
| def _has_inf_or_nan(x): | |||||
| return False | |||||
| def update_scale(self, overflow): | |||||
| pass | |||||
| @property | |||||
| def loss_scale(self): | |||||
| return self.cur_scale | |||||
| def scale_gradient(self, module, grad_in, grad_out): | |||||
| return tuple(self.loss_scale * g for g in grad_in) | |||||
| def backward(self, loss, retain_graph=False): | |||||
| scaled_loss = loss * self.loss_scale | |||||
| scaled_loss.backward(retain_graph=retain_graph) | |||||
| class DynamicLossScaler: | |||||
| """ | |||||
| Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` | |||||
| indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of | |||||
| :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` | |||||
| operates, because the default options can be changed using the | |||||
| ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. | |||||
| Loss scaling is designed to combat the problem of underflowing gradients encountered at long | |||||
| times when training fp16 networks. Dynamic loss scaling begins by attempting a very high loss | |||||
| scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are | |||||
| encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has | |||||
| occurred. | |||||
| :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, | |||||
| and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. | |||||
| If a certain number of iterations occur without overflowing gradients detected, | |||||
| :class:`DynamicLossScaler` increases the loss scale once more. | |||||
| In this way :class:`DynamicLossScaler` attempts to "ride the edge" of | |||||
| always using the highest loss scale possible without incurring overflow. | |||||
| Args: | |||||
| init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` | |||||
| scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. | |||||
| scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. | |||||
| """ # noqa | |||||
| def __init__(self, | |||||
| init_scale=2**32, | |||||
| scale_factor=2., | |||||
| scale_window=1000, | |||||
| min_scale=1, | |||||
| delayed_shift=1, | |||||
| consecutive_hysteresis=False): | |||||
| self.cur_scale = init_scale | |||||
| self.cur_iter = 0 | |||||
| self.last_overflow_iter = -1 | |||||
| self.scale_factor = scale_factor | |||||
| self.scale_window = scale_window | |||||
| self.min_scale = min_scale | |||||
| self.delayed_shift = delayed_shift | |||||
| self.cur_hysteresis = delayed_shift | |||||
| self.consecutive_hysteresis = consecutive_hysteresis | |||||
| # `params` is a list / generator of torch.Variable | |||||
| def has_overflow_serial(self, params): | |||||
| for p in params: | |||||
| if p.grad is not None and DynamicLossScaler._has_inf_or_nan( | |||||
| p.grad.data): | |||||
| return True | |||||
| return False | |||||
| def has_overflow(self, params): | |||||
| overflow = self.has_overflow_serial(params) | |||||
| # Since each model parallel GPU carries only part of the model, | |||||
| # make sure overflow flag is synced across all the model parallel GPUs | |||||
| overflow_gpu = torch.cuda.ByteTensor([overflow]) | |||||
| torch.distributed.all_reduce( | |||||
| overflow_gpu, | |||||
| op=torch.distributed.ReduceOp.MAX, | |||||
| group=mpu.get_model_parallel_group()) | |||||
| overflow = overflow_gpu[0].item() | |||||
| return bool(overflow) | |||||
| # `x` is a torch.Tensor | |||||
| def _has_inf_or_nan(x): | |||||
| try: | |||||
| # if x is half, the .float() incurs an additional deep copy, but it's necessary if | |||||
| # Pytorch's .sum() creates a one-element tensor of the same type as x | |||||
| # (which is true for some recent versions of pytorch). | |||||
| cpu_sum = float(x.float().sum()) | |||||
| # More efficient version that can be used if .sum() returns a Python scalar | |||||
| # cpu_sum = float(x.sum()) | |||||
| except RuntimeError as instance: | |||||
| # We want to check if instance is actually an overflow exception. | |||||
| # RuntimeError could come from a different error. | |||||
| # If so, we still want the exception to propagate. | |||||
| if 'value cannot be converted' not in instance.args[0]: | |||||
| raise | |||||
| return True | |||||
| else: | |||||
| if cpu_sum == float( | |||||
| 'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: | |||||
| return True | |||||
| return False | |||||
| # `overflow` is boolean indicating whether the gradient overflowed | |||||
| def update_scale(self, overflow): | |||||
| if not hasattr(self, 'min_scale'): | |||||
| self.min_scale = 1 | |||||
| if not hasattr(self, 'delayed_shift'): | |||||
| self.delayed_shift = 1 | |||||
| if not hasattr(self, 'cur_hysteresis'): | |||||
| self.cur_hysteresis = 1 | |||||
| if not hasattr(self, 'consecutive_hysteresis'): | |||||
| self.consecutive_hysteresis = True | |||||
| if overflow: | |||||
| # self.cur_scale /= self.scale_factor | |||||
| if self.delayed_shift == 1 or self.cur_hysteresis == 1: | |||||
| self.cur_scale = max(self.cur_scale / self.scale_factor, | |||||
| self.min_scale) | |||||
| else: | |||||
| self.cur_hysteresis -= 1 | |||||
| self.last_overflow_iter = self.cur_iter | |||||
| else: | |||||
| if self.consecutive_hysteresis: | |||||
| self.cur_hysteresis = self.delayed_shift | |||||
| if (self.cur_iter | |||||
| - self.last_overflow_iter) % self.scale_window == 0: | |||||
| if not self.consecutive_hysteresis: | |||||
| self.cur_hysteresis = self.delayed_shift | |||||
| self.cur_scale *= self.scale_factor | |||||
| self.cur_iter += 1 | |||||
| @property | |||||
| def loss_scale(self): | |||||
| return self.cur_scale | |||||
| def scale_gradient(self, module, grad_in, grad_out): | |||||
| return tuple(self.loss_scale * g for g in grad_in) | |||||
| def backward(self, loss, retain_graph=False): | |||||
| scaled_loss = loss * self.loss_scale | |||||
| scaled_loss.backward(retain_graph=retain_graph) | |||||
| ############################################################## | |||||
| # Example usage below here -- assuming it's in a separate file | |||||
| ############################################################## | |||||
| """ | |||||
| TO-DO separate out into an example. | |||||
| if __name__ == "__main__": | |||||
| import torch | |||||
| from torch.autograd import Variable | |||||
| from dynamic_loss_scaler import DynamicLossScaler | |||||
| # N is batch size; D_in is input dimension; | |||||
| # H is hidden dimension; D_out is output dimension. | |||||
| N, D_in, H, D_out = 64, 1000, 100, 10 | |||||
| # Create random Tensors to hold inputs and outputs, and wrap them in Variables. | |||||
| x = Variable(torch.randn(N, D_in), requires_grad=False) | |||||
| y = Variable(torch.randn(N, D_out), requires_grad=False) | |||||
| w1 = Variable(torch.randn(D_in, H), requires_grad=True) | |||||
| w2 = Variable(torch.randn(H, D_out), requires_grad=True) | |||||
| parameters = [w1, w2] | |||||
| learning_rate = 1e-6 | |||||
| optimizer = torch.optim.SGD(parameters, lr=learning_rate) | |||||
| loss_scaler = DynamicLossScaler() | |||||
| for t in range(500): | |||||
| y_pred = x.mm(w1).clamp(min=0).mm(w2) | |||||
| loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale | |||||
| print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) | |||||
| print('Iter {} scaled loss: {}'.format(t, loss.data[0])) | |||||
| print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) | |||||
| # Run backprop | |||||
| optimizer.zero_grad() | |||||
| loss.backward() | |||||
| # Check for overflow | |||||
| has_overflow = DynamicLossScaler.has_overflow(parameters) | |||||
| # If no overflow, unscale grad and update as usual | |||||
| if not has_overflow: | |||||
| for param in parameters: | |||||
| param.grad.data.mul_(1. / loss_scaler.loss_scale) | |||||
| optimizer.step() | |||||
| # Otherwise, don't do anything -- ie, skip iteration | |||||
| else: | |||||
| print('OVERFLOW!') | |||||
| # Update loss scale for next iteration | |||||
| loss_scaler.update_scale(has_overflow) | |||||
| """ | |||||
| @@ -0,0 +1,483 @@ | |||||
| # Copyright 2020 The HuggingFace Inc. team | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from abc import ABC, abstractmethod | |||||
| from collections import UserDict | |||||
| from typing import Iterable, List, Optional, Tuple | |||||
| import torch | |||||
| PROCESS_INPUTS_DOCSTRING = r""" | |||||
| Args: | |||||
| input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): | |||||
| Indices of input sequence tokens in the vocabulary. | |||||
| Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See | |||||
| :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||||
| details. | |||||
| `What are input IDs? <../glossary.html#input-ids>`__ | |||||
| next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`): | |||||
| Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses. | |||||
| next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): | |||||
| :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses. | |||||
| next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): | |||||
| Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond. | |||||
| pad_token_id (:obj:`int`, `optional`): | |||||
| The id of the `padding` token. | |||||
| eos_token_id (:obj:`int`, `optional`): | |||||
| The id of the `end-of-sequence` token. | |||||
| Return: | |||||
| :obj:`UserDict`: A dictionary composed of the fields as defined above: | |||||
| - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated | |||||
| scores of all non-finished beams. | |||||
| - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens | |||||
| to be added to the non-finished beam_hypotheses. | |||||
| - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices | |||||
| indicating to which beam the next tokens shall be added. | |||||
| """ | |||||
| FINALIZE_INPUTS_DOCSTRING = r""" | |||||
| Args: | |||||
| input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): | |||||
| Indices of input sequence tokens in the vocabulary. | |||||
| Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See | |||||
| :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||||
| details. | |||||
| `What are input IDs? <../glossary.html#input-ids>`__ | |||||
| final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): | |||||
| The final scores of all non-finished beams. | |||||
| final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): | |||||
| The last tokens to be added to the non-finished beam_hypotheses. | |||||
| final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): | |||||
| The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added. | |||||
| pad_token_id (:obj:`int`, `optional`): | |||||
| The id of the `padding` token. | |||||
| eos_token_id (:obj:`int`, `optional`): | |||||
| The id of the `end-of-sequence` token. | |||||
| Return: | |||||
| :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated | |||||
| sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all | |||||
| batches finished early due to the :obj:`eos_token_id`. | |||||
| """ | |||||
| class BeamScorer(ABC): | |||||
| """ | |||||
| Abstract base class for all beam scorers that are used for :meth:`~transformers.PreTrainedModel.beam_search` and | |||||
| :meth:`~transformers.PreTrainedModel.beam_sample`. | |||||
| """ | |||||
| @abstractmethod | |||||
| def process(self, input_ids: torch.LongTensor, | |||||
| next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, | |||||
| next_indices: torch.LongTensor, | |||||
| **kwargs) -> Tuple[torch.Tensor]: | |||||
| raise NotImplementedError('This is an abstract method.') | |||||
| @abstractmethod | |||||
| def finalize(self, input_ids: torch.LongTensor, | |||||
| next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, | |||||
| next_indices: torch.LongTensor, **kwargs) -> torch.LongTensor: | |||||
| raise NotImplementedError('This is an abstract method.') | |||||
| class BeamSearchScorer(BeamScorer): | |||||
| r""" | |||||
| :class:`transformers.BeamScorer` implementing standard beam search decoding. | |||||
| Adapted in part from `Facebook's XLM beam search code | |||||
| <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__. | |||||
| Args: | |||||
| batch_size (:obj:`int`): | |||||
| Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. | |||||
| max_length (:obj:`int`): | |||||
| The maximum length of the sequence to be generated. | |||||
| num_beams (:obj:`int`): | |||||
| Number of beams for beam search. | |||||
| device (:obj:`torch.device`): | |||||
| Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of | |||||
| :obj:`BeamSearchScorer` will be allocated. | |||||
| length_penalty (:obj:`float`, `optional`, defaults to 1.0): | |||||
| Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the | |||||
| model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer | |||||
| sequences. | |||||
| do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
| Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. | |||||
| num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): | |||||
| The number of beam hypotheses that shall be returned upon calling | |||||
| :meth:`~transformers.BeamSearchScorer.finalize`. | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| batch_size: int, | |||||
| max_length: int, | |||||
| num_beams: int, | |||||
| device: torch.device, | |||||
| length_penalty: Optional[float] = 1.0, | |||||
| do_early_stopping: Optional[bool] = False, | |||||
| num_beam_hyps_to_keep: Optional[int] = 1, | |||||
| ): | |||||
| self.max_length = max_length | |||||
| self.num_beams = num_beams | |||||
| self.device = device | |||||
| self.length_penalty = length_penalty | |||||
| self.do_early_stopping = do_early_stopping | |||||
| self.num_beam_hyps_to_keep = num_beam_hyps_to_keep | |||||
| self._is_init = False | |||||
| self._beam_hyps = [ | |||||
| BeamHypotheses( | |||||
| num_beams=self.num_beams, | |||||
| max_length=self.max_length, | |||||
| length_penalty=self.length_penalty, | |||||
| early_stopping=self.do_early_stopping, | |||||
| ) for _ in range(batch_size) | |||||
| ] | |||||
| self._done = torch.tensor([False for _ in range(batch_size)], | |||||
| dtype=torch.bool, | |||||
| device=self.device) | |||||
| # if not isinstance(num_beams, int) or num_beams <= 1: | |||||
| # raise ValueError( | |||||
| # ) | |||||
| @property | |||||
| def is_done(self) -> bool: | |||||
| return self._done.all() | |||||
| def process(self, | |||||
| input_ids: torch.LongTensor, | |||||
| next_scores: torch.FloatTensor, | |||||
| next_tokens: torch.LongTensor, | |||||
| next_indices: torch.LongTensor, | |||||
| pad_token_id: Optional[int] = None, | |||||
| eos_token_id: Optional[int] = None, | |||||
| mems=None) -> Tuple[torch.Tensor]: | |||||
| cur_len = input_ids.shape[-1] | |||||
| batch_size = len(self._beam_hyps) | |||||
| assert batch_size == (input_ids.shape[0] // self.num_beams) | |||||
| if isinstance(eos_token_id, int): | |||||
| eos_token_id = [eos_token_id] | |||||
| device = next_scores.device | |||||
| next_beam_scores = torch.zeros((batch_size, self.num_beams), | |||||
| dtype=next_scores.dtype, | |||||
| device=device) | |||||
| next_beam_tokens = torch.zeros((batch_size, self.num_beams), | |||||
| dtype=next_tokens.dtype, | |||||
| device=device) | |||||
| next_beam_indices = torch.zeros((batch_size, self.num_beams), | |||||
| dtype=next_indices.dtype, | |||||
| device=device) | |||||
| for batch_idx, beam_hyp in enumerate(self._beam_hyps): | |||||
| if self._done[batch_idx]: | |||||
| assert ( | |||||
| len(beam_hyp) >= self.num_beams | |||||
| ), 'Batch can only be done if at least {} beams have been generated'.format( | |||||
| self.num_beams) | |||||
| assert ( | |||||
| eos_token_id is not None and pad_token_id is not None | |||||
| ), 'generated beams >= num_beams -> eos_token_id and pad_token have to be defined' | |||||
| # pad the batch | |||||
| next_beam_scores[batch_idx, :] = 0 | |||||
| next_beam_tokens[batch_idx, :] = pad_token_id | |||||
| next_beam_indices[batch_idx, :] = 0 | |||||
| continue | |||||
| # next tokens for this sentence | |||||
| beam_idx = 0 | |||||
| for beam_token_rank, (next_token, next_score, | |||||
| next_index) in enumerate( | |||||
| zip(next_tokens[batch_idx], | |||||
| next_scores[batch_idx], | |||||
| next_indices[batch_idx])): | |||||
| batch_beam_idx = batch_idx * self.num_beams + next_index | |||||
| # add to generated hypotheses if end of sentence | |||||
| if (eos_token_id is not None) and (next_token.item() | |||||
| in eos_token_id): | |||||
| # if beam_token does not belong to top num_beams tokens, it should not be added | |||||
| is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams | |||||
| if is_beam_token_worse_than_top_num_beams: | |||||
| continue | |||||
| beam_hyp.add( | |||||
| input_ids[batch_beam_idx].clone(), | |||||
| next_score.item(), | |||||
| mems=[mem[[next_index.item()]] | |||||
| for mem in mems] if mems else None) | |||||
| else: | |||||
| # add next predicted token since it is not eos_token | |||||
| next_beam_scores[batch_idx, beam_idx] = next_score | |||||
| next_beam_tokens[batch_idx, beam_idx] = next_token | |||||
| next_beam_indices[batch_idx, beam_idx] = batch_beam_idx | |||||
| beam_idx += 1 | |||||
| # once the beam for next step is full, don't add more tokens to it. | |||||
| if beam_idx == self.num_beams: | |||||
| break | |||||
| if beam_idx < self.num_beams: | |||||
| raise ValueError( | |||||
| f'At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected.' # noqa | |||||
| ) # noqa | |||||
| # Check if we are done so that we can save a pad step if all(done) | |||||
| self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( | |||||
| next_scores[batch_idx].max().item(), cur_len) | |||||
| return UserDict({ | |||||
| 'next_beam_scores': next_beam_scores.view(-1), | |||||
| 'next_beam_tokens': next_beam_tokens.view(-1), | |||||
| 'next_beam_indices': next_beam_indices.view(-1), | |||||
| }) | |||||
| def finalize(self, | |||||
| input_ids: torch.LongTensor, | |||||
| final_beam_scores: torch.FloatTensor, | |||||
| final_beam_tokens: torch.LongTensor, | |||||
| final_beam_indices: torch.LongTensor, | |||||
| pad_token_id: Optional[int] = None, | |||||
| eos_token_id: Optional[int] = None, | |||||
| mems=None) -> Tuple[torch.LongTensor, List[torch.Tensor]]: | |||||
| batch_size = len(self._beam_hyps) | |||||
| # finalize all open beam hypotheses and add to generated hypotheses | |||||
| for batch_idx, beam_hyp in enumerate(self._beam_hyps): | |||||
| if self._done[batch_idx]: | |||||
| continue | |||||
| # need to add best num_beams hypotheses to generated hyps | |||||
| for beam_id in range(self.num_beams): | |||||
| batch_beam_idx = batch_idx * self.num_beams + beam_id | |||||
| final_score = final_beam_scores[batch_beam_idx].item() | |||||
| final_tokens = input_ids[batch_beam_idx] | |||||
| beam_hyp.add( | |||||
| final_tokens, | |||||
| final_score, | |||||
| mems=[mem[[batch_beam_idx]] | |||||
| for mem in mems] if mems else None) | |||||
| # select the best hypotheses | |||||
| sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) | |||||
| best = [] | |||||
| # retrieve best hypotheses | |||||
| for i, beam_hyp in enumerate(self._beam_hyps): | |||||
| sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) | |||||
| for j in range(self.num_beam_hyps_to_keep): | |||||
| best_hyp, mems = sorted_hyps.pop()[1:] | |||||
| sent_lengths[self.num_beam_hyps_to_keep * i | |||||
| + j] = len(best_hyp) | |||||
| best.append((best_hyp, mems)) | |||||
| # prepare for adding eos | |||||
| sent_max_len = min(sent_lengths.max().item(), self.max_length) | |||||
| decoded: torch.LongTensor = input_ids.new( | |||||
| batch_size * self.num_beam_hyps_to_keep, sent_max_len) | |||||
| # shorter batches are padded if needed | |||||
| if sent_lengths.min().item() != sent_lengths.max().item(): | |||||
| assert pad_token_id is not None, '`pad_token_id` has to be defined' | |||||
| decoded.fill_(pad_token_id) | |||||
| # fill with hypotheses and eos_token_id if the latter fits in | |||||
| mems = [] | |||||
| for i, (hypo, mem) in enumerate(best): | |||||
| decoded[i, :sent_lengths[i]] = hypo | |||||
| if sent_lengths[i] < sent_max_len: | |||||
| decoded[i, sent_lengths[i]] = eos_token_id | |||||
| mems.append(mem) | |||||
| mems = [ | |||||
| torch.cat([mem[i] for mem in mems], dim=0) | |||||
| for i in range(len(mems[0])) | |||||
| ] if mems and mems[0] else None | |||||
| return decoded, mems | |||||
| class BeamHypotheses: | |||||
| def __init__(self, num_beams: int, max_length: int, length_penalty: float, | |||||
| early_stopping: bool): | |||||
| """ | |||||
| Initialize n-best list of hypotheses. | |||||
| """ | |||||
| self.max_length = max_length - 1 # ignoring bos_token | |||||
| self.length_penalty = length_penalty | |||||
| self.early_stopping = early_stopping | |||||
| self.num_beams = num_beams | |||||
| self.beams = [] | |||||
| self.worst_score = 1e9 | |||||
| def __len__(self): | |||||
| """ | |||||
| Number of hypotheses in the list. | |||||
| """ | |||||
| return len(self.beams) | |||||
| def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None): | |||||
| """ | |||||
| Add a new hypothesis to the list. | |||||
| """ | |||||
| score = sum_logprobs / (max(hyp.shape[-1], 1)**self.length_penalty) | |||||
| if len(self) < self.num_beams or score > self.worst_score: | |||||
| self.beams.append((score, hyp, mems)) | |||||
| if len(self) > self.num_beams: | |||||
| sorted_next_scores = sorted([ | |||||
| (s, idx) for idx, (s, _, _) in enumerate(self.beams) | |||||
| ]) | |||||
| del self.beams[sorted_next_scores[0][1]] | |||||
| self.worst_score = sorted_next_scores[1][0] | |||||
| else: | |||||
| self.worst_score = min(score, self.worst_score) | |||||
| def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: | |||||
| """ | |||||
| If there are enough hypotheses and none of the hypotheses being generated can become better than the worst | |||||
| one in the heap, then we are done with this sentence. | |||||
| """ | |||||
| if len(self) < self.num_beams: | |||||
| return False | |||||
| elif self.early_stopping: | |||||
| return True | |||||
| else: | |||||
| cur_score = best_sum_logprobs / cur_len**self.length_penalty | |||||
| ret = self.worst_score >= cur_score | |||||
| return ret | |||||
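As a side note on the scoring above: `add` ranks hypotheses by `sum_logprobs / len(hyp) ** length_penalty`, so with a positive penalty a longer hypothesis can outrank a shorter one despite a lower total log-probability. A small illustrative sketch (the numbers are made up, not model outputs):

# Illustrative only: how BeamHypotheses.add scores two finished hypotheses.
length_penalty = 1.0
len_a, sum_logprobs_a = 3, -4.2   # short hypothesis
len_b, sum_logprobs_b = 5, -5.5   # longer hypothesis, lower total log-prob

score_a = sum_logprobs_a / max(len_a, 1) ** length_penalty   # -1.40
score_b = sum_logprobs_b / max(len_b, 1) ** length_penalty   # -1.10
# score_b > score_a: once the score is normalised by length the longer
# hypothesis wins, and this is exactly the value compared against worst_score.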
| class LogitsProcessor(ABC): | |||||
| """Abstract base class for all logit processors that can be applied during generation.""" | |||||
| def __call__(self, input_ids: torch.LongTensor, | |||||
| scores: torch.FloatTensor) -> torch.FloatTensor: | |||||
| """Torch method for processing logits.""" | |||||
| raise NotImplementedError( | |||||
| f'{self.__class__} is an abstract class. Only classes inheriting this class can be called.' | |||||
| ) | |||||
| class LogitsProcessorList(list): | |||||
| """ | |||||
| This class can be used to create a list of :class:`~transformers.LogitsProcessor` or | |||||
| :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from | |||||
| list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or | |||||
| :class:`~transformers.LogitsWarper` to the inputs. | |||||
| """ | |||||
| def __call__(self, input_ids: torch.LongTensor, | |||||
| scores: torch.FloatTensor) -> torch.FloatTensor: | |||||
| for processor in self: | |||||
| scores = processor(input_ids, scores) | |||||
| return scores | |||||
| class MinLengthLogitsProcessor(LogitsProcessor): | |||||
| r""" | |||||
| :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. | |||||
| Args: | |||||
| min_length (:obj:`int`): | |||||
| The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. | |||||
| eos_token_id (:obj:`int`): | |||||
| The id of the `end-of-sequence` token. | |||||
| """ | |||||
| def __init__(self, min_length: int, eos_token_id: int): | |||||
| if not isinstance(min_length, int) or min_length < 0: | |||||
| raise ValueError( | |||||
| f'`min_length` has to be a non-negative integer, but is {min_length}' | |||||
| ) | |||||
| if not isinstance(eos_token_id, int) or eos_token_id < 0: | |||||
| raise ValueError( | |||||
| f'`eos_token_id` has to be a positive integer, but is {eos_token_id}' | |||||
| ) | |||||
| self.min_length = min_length | |||||
| self.eos_token_id = eos_token_id | |||||
| def __call__(self, input_ids: torch.LongTensor, | |||||
| scores: torch.FloatTensor) -> torch.FloatTensor: | |||||
| cur_len = input_ids.shape[-1] | |||||
| if cur_len < self.min_length: | |||||
| scores[:, self.eos_token_id] = -float('inf') | |||||
| return scores | |||||
| class NoRepeatNGramLogitsProcessor(LogitsProcessor): | |||||
| r""" | |||||
| :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq | |||||
| <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__. | |||||
| Args: | |||||
| ngram_size (:obj:`int`): | |||||
| All ngrams of size :obj:`ngram_size` can only occur once. | |||||
| """ | |||||
| def __init__(self, ngram_size: int): | |||||
| if not isinstance(ngram_size, int) or ngram_size <= 0: | |||||
| raise ValueError( | |||||
| f'`ngram_size` has to be a strictly positive integer, but is {ngram_size}' | |||||
| ) | |||||
| self.ngram_size = ngram_size | |||||
| def __call__(self, input_ids: torch.LongTensor, | |||||
| scores: torch.FloatTensor) -> torch.FloatTensor: | |||||
| num_batch_hypotheses = scores.shape[0] | |||||
| cur_len = input_ids.shape[-1] | |||||
| banned_batch_tokens = self._calc_banned_ngram_tokens( | |||||
| input_ids, num_batch_hypotheses, cur_len) | |||||
| for i, banned_tokens in enumerate(banned_batch_tokens): | |||||
| scores[i, banned_tokens] = -float('inf') | |||||
| return scores | |||||
| def _calc_banned_ngram_tokens(self, prev_input_ids: torch.Tensor, | |||||
| num_hypos: int, | |||||
| cur_len: int) -> List[Iterable[int]]: | |||||
| """Copied from fairseq for no_repeat_ngram in beam_search""" | |||||
| if cur_len + 1 < self.ngram_size: | |||||
| # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet | |||||
| return [[] for _ in range(num_hypos)] | |||||
| generated_ngrams = [{} for _ in range(num_hypos)] | |||||
| for idx in range(num_hypos): | |||||
| gen_tokens = prev_input_ids[idx].tolist() | |||||
| generated_ngram = generated_ngrams[idx] | |||||
| for ngram in zip(*[gen_tokens[i:] | |||||
| for i in range(self.ngram_size)]): | |||||
| prev_ngram_tuple = tuple(ngram[:-1]) | |||||
| generated_ngram[prev_ngram_tuple] = generated_ngram.get( | |||||
| prev_ngram_tuple, []) + [ngram[-1]] | |||||
| def _get_generated_ngrams(hypo_idx): | |||||
| # Before decoding the next token, prevent decoding of ngrams that have already appeared | |||||
| start_idx = cur_len + 1 - self.ngram_size | |||||
| ngram_idx = tuple(prev_input_ids[hypo_idx, | |||||
| start_idx:cur_len].tolist()) | |||||
| return generated_ngrams[hypo_idx].get(ngram_idx, []) | |||||
| banned_tokens = [ | |||||
| _get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos) | |||||
| ] | |||||
| return banned_tokens | |||||
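To show how these pieces compose, here is a minimal sketch of chaining the two processors defined above (dummy tensors, assuming the classes are importable from this module):

# Minimal sketch: chain the processors above on a single hypothesis.
import torch

processors = LogitsProcessorList([
    MinLengthLogitsProcessor(min_length=5, eos_token_id=2),
    NoRepeatNGramLogitsProcessor(ngram_size=2),
])

input_ids = torch.tensor([[0, 7, 9, 7]])   # 4 tokens generated so far
scores = torch.randn(1, 32)                # dummy logits over a toy vocab of 32
scores = processors(input_ids, scores)
# EOS (id 2) is set to -inf because cur_len < min_length, and token 9 is banned
# because the bigram (7, 9) already occurred and the current prefix is 7 again.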
| @@ -0,0 +1,469 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import os | |||||
| import random | |||||
| from os import path as osp | |||||
| from typing import Dict | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models.base import Tensor, TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.utils.config import Config | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from . import mpu | |||||
| from .arguments import get_args | |||||
| from .generation_utils import BeamSearchScorer | |||||
| from .train_utils import get_model | |||||
| from .utils import load_checkpoint | |||||
| __all__ = ['MGLMForTextSummarization'] | |||||
| def setup_args(args): | |||||
| args.block_lm = True | |||||
| args.task_mask = True | |||||
| args.cloze_eval = True | |||||
| args.num_layers = 24 | |||||
| args.hidden_size = 1536 | |||||
| args.num_attention_heads = 16 | |||||
| args.max_position_embeddings = 1024 | |||||
| args.tokenizer_type = 'ChineseSPTokenizer' | |||||
| args.load_pretrained = '' | |||||
| args.DDP_impl = 'none' | |||||
| args.model_parallel_size = 1 | |||||
| args.fp16 = True | |||||
| args.cache_dir = 'cache' | |||||
| args.out_seq_length = 200 | |||||
| args.seq_length = 512 | |||||
| args.temperature = 0.9 | |||||
| args.top_k = 2 | |||||
| args.top_p = 0.8 | |||||
| args.frequency_penalty = 0.1 | |||||
| args.presence_penalty = 0.1 | |||||
| args.mem_length = args.seq_length + args.mem_length - 1 | |||||
| return args | |||||
| def setup_model(args): | |||||
| """Setup model and optimizer.""" | |||||
| model = get_model(args, model_type='generation') | |||||
| if args.load_pretrained is not None: | |||||
| args.no_load_optim = True | |||||
| args.load = args.load_pretrained | |||||
| _ = load_checkpoint(model, None, None, args) | |||||
| return model | |||||
| def set_random_seed(seed): | |||||
| """Set random seed for reproducability.""" | |||||
| if seed is not None and seed > 0: | |||||
| random.seed(seed) | |||||
| np.random.seed(seed) | |||||
| torch.manual_seed(seed) | |||||
| mpu.model_parallel_cuda_manual_seed(seed) | |||||
| def get_masks_and_position_ids(data, | |||||
| eod_token, | |||||
| reset_position_ids, | |||||
| reset_attention_mask, | |||||
| loss_mask=None, | |||||
| attention_mask=None, | |||||
| set_loss_mask=False, | |||||
| mem_length=None): | |||||
| # Extract batch size and sequence length. | |||||
| batch_size, seq_length = data.size() | |||||
| # Attention mask (lower triangular). | |||||
| if mem_length: | |||||
| if attention_mask is None: | |||||
| attention_mask = torch.ones( | |||||
| (1, seq_length, seq_length + mem_length), device=data.device) | |||||
| attention_mask = torch.tril( | |||||
| torch.triu(attention_mask, 1 - seq_length + mem_length), | |||||
| mem_length) | |||||
| else: | |||||
| if reset_attention_mask: | |||||
| att_mask_batch = batch_size | |||||
| else: | |||||
| att_mask_batch = 1 | |||||
| if attention_mask is None: | |||||
| attention_mask = torch.ones( | |||||
| (att_mask_batch, seq_length, seq_length), device=data.device) | |||||
| attention_mask = torch.tril(attention_mask) | |||||
| attention_mask = attention_mask.unsqueeze(1) | |||||
| # Loss mask. | |||||
| if loss_mask is None: | |||||
| loss_mask = torch.ones( | |||||
| data.size(), dtype=torch.float, device=data.device) | |||||
| # Position ids. | |||||
| position_ids = torch.arange( | |||||
| seq_length, dtype=torch.long, device=data.device) | |||||
| position_ids = position_ids.unsqueeze(0).expand_as(data) | |||||
| if set_loss_mask: | |||||
| loss_mask[data == eod_token] = 0.0 | |||||
| # We need to clone as the ids will be modified based on batch index. | |||||
| if reset_position_ids: | |||||
| position_ids = position_ids.clone() | |||||
| if reset_position_ids or reset_attention_mask: | |||||
| # Loop through the batches: | |||||
| for b in range(batch_size): | |||||
| # Find indices where the EOD token is. | |||||
| eod_index = position_ids[b, data[b] == eod_token] | |||||
| # Detach indices from positions if going to modify positions. | |||||
| if reset_position_ids: | |||||
| eod_index = eod_index.clone() | |||||
| # Loop through EOD indices: | |||||
| prev_index = 0 | |||||
| for j in range(eod_index.size()[0]): | |||||
| i = eod_index[j] | |||||
| # Mask attention loss. | |||||
| if reset_attention_mask: | |||||
| attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 | |||||
| # Reset positions. | |||||
| if reset_position_ids: | |||||
| position_ids[b, (i + 1):] -= (i + 1 - prev_index) | |||||
| prev_index = i + 1 | |||||
| return attention_mask, loss_mask, position_ids | |||||
| def initialize_distributed(args): | |||||
| """Initialize torch.distributed.""" | |||||
| # Manually set the device ids. | |||||
| device = args.rank % torch.cuda.device_count() | |||||
| if args.local_rank is not None: | |||||
| device = args.local_rank | |||||
| torch.cuda.set_device(device) | |||||
| # Call the init process | |||||
| init_method = 'tcp://' | |||||
| args.master_ip = os.getenv('MASTER_ADDR', 'localhost') | |||||
| args.master_port = os.getenv('MASTER_PORT', '6000') | |||||
| init_method += args.master_ip + ':' + args.master_port | |||||
| torch.distributed.init_process_group( | |||||
| backend=args.distributed_backend, | |||||
| world_size=args.world_size, | |||||
| rank=args.rank, | |||||
| init_method=init_method) | |||||
| # Set the model-parallel / data-parallel communicators. | |||||
| mpu.initialize_model_parallel(args.model_parallel_size) | |||||
| # Optional DeepSpeed Activation Checkpointing Features | |||||
| # | |||||
| if hasattr( | |||||
| args, 'deepspeed' | |||||
| ) and args.deepspeed and args.deepspeed_activation_checkpointing: | |||||
| set_deepspeed_activation_checkpointing(args) | |||||
| def get_batch(context_tokens, device, args): | |||||
| tokens = context_tokens | |||||
| tokens = tokens.view(args.batch_size, -1).contiguous() | |||||
| tokens = tokens.to(device) | |||||
| # Get the masks and position ids. | |||||
| if args.block_lm: | |||||
| attention_mask = torch.tensor([tokens.size(1)], | |||||
| device=device, | |||||
| dtype=torch.long) | |||||
| position_ids = torch.arange( | |||||
| tokens.size(1), device=device, dtype=torch.long) | |||||
| if not args.no_block_position: | |||||
| block_position_ids = torch.zeros( | |||||
| tokens.size(1), device=device, dtype=torch.long) | |||||
| position_ids = torch.stack((position_ids, block_position_ids), | |||||
| dim=0) | |||||
| position_ids = position_ids.unsqueeze(0) | |||||
| else: | |||||
| attention_mask, loss_mask, position_ids = get_masks_and_position_ids( | |||||
| tokens, | |||||
| args.eod_token, | |||||
| reset_position_ids=False, | |||||
| reset_attention_mask=False, | |||||
| set_loss_mask=False, | |||||
| mem_length=args.mem_length) | |||||
| return tokens, attention_mask, position_ids | |||||
| def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): | |||||
| # This function has been mostly taken from huggingface conversational ai code at | |||||
| # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 | |||||
| if top_k > 0: | |||||
| # Remove all tokens with a probability less than the last token of the top-k | |||||
| indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||||
| None] | |||||
| logits[indices_to_remove] = filter_value | |||||
| if top_p > 0.0: | |||||
| # convert to 1D | |||||
| logits = logits.view(logits.size()[1]).contiguous() | |||||
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||||
| cumulative_probs = torch.cumsum( | |||||
| F.softmax(sorted_logits, dim=-1), dim=-1) | |||||
| # Remove tokens with cumulative probability above the threshold | |||||
| sorted_indices_to_remove = cumulative_probs > top_p | |||||
| # Shift the indices to the right to also keep the first token above the threshold | |||||
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||||
| ..., :-1].clone() | |||||
| sorted_indices_to_remove[..., 0] = 0 | |||||
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |||||
| logits[indices_to_remove] = filter_value | |||||
| # going back to 2D | |||||
| logits = logits.view(1, -1).contiguous() | |||||
| return logits | |||||
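For intuition, a tiny hedged example of the filtering above (made-up logits, assuming top_k_logits as defined here):

# Made-up logits; with top_k=2 only the two largest survive, the rest become -inf.
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = top_k_logits(logits.clone(), top_k=2, top_p=0.0)
probs = F.softmax(filtered, dim=-1)   # sampling can now only return token 0 or 1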
| def sample_sequence(model, | |||||
| tokenizer, | |||||
| context_tokens, | |||||
| context_length, | |||||
| args, | |||||
| device, | |||||
| mems=None, | |||||
| end_tokens=None): | |||||
| if not args.block_lm: | |||||
| context_tokens, attention_mask, position_ids = get_batch( | |||||
| context_tokens, device, args) | |||||
| tokens = torch.empty((args.num_beams, 0), | |||||
| device=context_tokens.device, | |||||
| dtype=torch.long) | |||||
| else: | |||||
| tokens = context_tokens.new_full((1, 1), | |||||
| tokenizer.get_command('sop').Id) | |||||
| counter = 0 | |||||
| if mems is None: | |||||
| mems = [] | |||||
| if end_tokens is None: | |||||
| end_tokens = [args.eod_token] | |||||
| last_beam_num = 1 | |||||
| output_tokens_list = [] | |||||
| generated_tokens_list = [] | |||||
| while counter < args.out_seq_length: | |||||
| if counter == 0 and not args.block_lm: | |||||
| next_token_logits, *mems = model(context_tokens, position_ids, | |||||
| attention_mask, *mems) | |||||
| else: | |||||
| if args.block_lm: | |||||
| if args.no_block_position: | |||||
| position_ids = context_tokens.new_full( | |||||
| (last_beam_num, 1), context_length + counter) | |||||
| else: | |||||
| position_ids = context_tokens.new_ones(last_beam_num, 2, 1) | |||||
| position_ids[:, 0] = context_length | |||||
| position_ids[:, 1] = counter + 1 | |||||
| attention_mask = context_tokens.new_zeros( | |||||
| [1], device=context_tokens.device, dtype=torch.long) | |||||
| else: | |||||
| position_ids = context_tokens.new_ones((last_beam_num, 1)) * ( | |||||
| context_length + counter - 1) | |||||
| attention_mask = context_tokens.new_ones( | |||||
| last_beam_num, | |||||
| 1, | |||||
| 1, | |||||
| args.mem_length + 1, | |||||
| device=context_tokens.device, | |||||
| dtype=torch.float) | |||||
| last_token = tokens[:, -1:] | |||||
| next_token_logits, *mems = model(last_token, position_ids, | |||||
| attention_mask, *mems) | |||||
| next_token_logits = next_token_logits[:, -1] | |||||
| next_token_logits /= args.temperature | |||||
| frequency_count = torch.zeros(next_token_logits.shape) | |||||
| for tk in output_tokens_list: | |||||
| frequency_count[0][tk] += 1 | |||||
| next_token_logits -= (args.frequency_penalty | |||||
| * frequency_count).to(device) | |||||
| next_token_logits -= ( | |||||
| args.presence_penalty * # noqa | |||||
| (frequency_count > 0)).to(device) | |||||
| next_token_logits = top_k_logits( | |||||
| next_token_logits, top_k=args.top_k, top_p=args.top_p) | |||||
| log_probs = F.softmax(next_token_logits, dim=-1) | |||||
| prev = torch.multinomial(log_probs, num_samples=1)[0] | |||||
| is_end = prev.item() in end_tokens | |||||
| if is_end: | |||||
| break | |||||
| decode_tokens = tokenizer.DecodeIds([prev.item()]) # noqa | |||||
| generated_tokens_list.append(prev.item()) | |||||
| prev = prev.view(1, 1) | |||||
| tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1) | |||||
| counter += 1 | |||||
| output_tokens_list = tokens.view(-1).contiguous() | |||||
| return torch.cat((context_tokens, tokens), dim=1), mems | |||||
| def read_context(tokenizer, args, context): | |||||
| terminate_runs, skip_run = 0, 0 # noqa | |||||
| if mpu.get_model_parallel_rank() == 0: | |||||
| while True: | |||||
| # raw_text = input("\nContext prompt (stop to exit) >>> ") | |||||
| raw_text = context | |||||
| if not raw_text: | |||||
| print('Prompt should not be empty!') | |||||
| break | |||||
| # if raw_text == "stop": | |||||
| # terminate_runs = 1 | |||||
| # break | |||||
| generation_mask = '[gMASK]' if args.task_mask else '[MASK]' | |||||
| if args.block_lm and 'MASK]' not in raw_text: | |||||
| raw_text += ' ' + generation_mask | |||||
| # output.write(raw_text) | |||||
| context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization | |||||
| if args.block_lm: | |||||
| context_tokens = [tokenizer.get_command('ENC').Id | |||||
| ] + context_tokens | |||||
| if not raw_text.endswith('[gMASK]'): | |||||
| context_tokens = context_tokens + [ | |||||
| tokenizer.get_command('eos').Id | |||||
| ] | |||||
| context_length = len(context_tokens) | |||||
| if context_length >= args.seq_length: | |||||
| print('\nContext length', context_length, | |||||
| '\nPlease provide a context shorter than the window length!') | |||||
| break | |||||
| break | |||||
| else: | |||||
| context_length = 0 | |||||
| terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) | |||||
| torch.distributed.broadcast( | |||||
| terminate_runs_tensor, | |||||
| mpu.get_model_parallel_src_rank(), | |||||
| group=mpu.get_model_parallel_group()) | |||||
| terminate_runs = terminate_runs_tensor[0].item() | |||||
| if terminate_runs == 1: | |||||
| return terminate_runs, None, None, None | |||||
| context_length_tensor = torch.cuda.LongTensor([context_length]) | |||||
| torch.distributed.broadcast( | |||||
| context_length_tensor, | |||||
| mpu.get_model_parallel_src_rank(), | |||||
| group=mpu.get_model_parallel_group()) | |||||
| context_length = context_length_tensor[0].item() | |||||
| if mpu.get_model_parallel_rank() == 0: | |||||
| context_tokens_tensor = torch.cuda.LongTensor(context_tokens) | |||||
| else: | |||||
| context_tokens_tensor = torch.cuda.LongTensor([0] * context_length) | |||||
| torch.distributed.broadcast( | |||||
| context_tokens_tensor, | |||||
| mpu.get_model_parallel_src_rank(), | |||||
| group=mpu.get_model_parallel_group()) | |||||
| if mpu.get_model_parallel_rank() != 0: | |||||
| raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist()) | |||||
| return terminate_runs, raw_text, context_tokens_tensor, context_length | |||||
| @MODELS.register_module(Tasks.text_summarization, module_name=Models.mglm) | |||||
| class MGLMForTextSummarization(TorchModel): | |||||
| def __init__(self, model_dir: str, *args, **kwargs): | |||||
| """initialize the text summarization model from the `model_dir` path. | |||||
| Args: | |||||
| model_dir (str): the model path. | |||||
| """ | |||||
| super().__init__(model_dir, *args, **kwargs) | |||||
| from .configure_data import prepare_tokenizer | |||||
| # Disable CuDNN. | |||||
| torch.backends.cudnn.enabled = False | |||||
| # Arguments. | |||||
| self.args = setup_args(get_args()) | |||||
| self.args.load_pretrained = model_dir | |||||
| # Pytorch distributed. | |||||
| try: | |||||
| initialize_distributed(self.args) | |||||
| except RuntimeError: | |||||
| print('process group already initialized') | |||||
| # Random seeds for reproducibility. | |||||
| set_random_seed(self.args.seed) | |||||
| # setting default batch size to 1 | |||||
| self.args.batch_size = 1 | |||||
| self.args.tokenizer_path = model_dir | |||||
| self.tokenizer = prepare_tokenizer(self.args) | |||||
| self.model = setup_model(self.args) | |||||
| self.cfg = Config.from_file( | |||||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | |||||
| def forward(self, input: Dict[str, str]) -> Dict[str, str]: | |||||
| pass | |||||
| def generate(self, input: Dict[str, str]) -> Dict[str, str]: | |||||
| model = self.model | |||||
| tokenizer = self.tokenizer | |||||
| args = self.args | |||||
| device = torch.cuda.current_device() | |||||
| model.eval() | |||||
| context = input['text'] + self.cfg.model.prompt | |||||
| with torch.no_grad(): | |||||
| terminate_runs, raw_text, context_tokens_tensor, context_length = read_context( | |||||
| tokenizer, args, context) | |||||
| mems = [] | |||||
| tokens, attention_mask, position_ids = get_batch( | |||||
| context_tokens_tensor, device, args) | |||||
| mask_tokens = ['MASK', 'sMASK', 'gMASK' | |||||
| ] if args.task_mask else ['MASK'] | |||||
| mask_tokens = [ | |||||
| tokenizer.get_command(token).Id for token in mask_tokens | |||||
| ] | |||||
| end_tokens = [tokenizer.get_command('eop').Id, args.eod_token] | |||||
| mask_positions = [] | |||||
| for token in mask_tokens: | |||||
| mask_positions += (context_tokens_tensor == token).nonzero( | |||||
| as_tuple=True)[0].tolist() | |||||
| mask_positions.sort() | |||||
| if args.no_block_position: | |||||
| for mask_position in mask_positions: | |||||
| position_ids[0, mask_position + 1:] += args.out_seq_length | |||||
| _, *mems = model(tokens, position_ids, attention_mask, *mems) | |||||
| for mask_position in mask_positions: | |||||
| if args.no_block_position: | |||||
| position = position_ids[0, mask_position].item() | |||||
| else: | |||||
| position = mask_position | |||||
| tokens, mems = sample_sequence( | |||||
| model, | |||||
| tokenizer, | |||||
| tokens, | |||||
| position, | |||||
| args, | |||||
| device, | |||||
| mems=mems, | |||||
| end_tokens=end_tokens) | |||||
| output_tokens_list = tokens.view(-1).contiguous() | |||||
| trim_decode_tokens = tokenizer.DecodeIds( | |||||
| output_tokens_list.tolist()) | |||||
| res = trim_decode_tokens.split('<|startofpiece|>')[-1] | |||||
| print(res) | |||||
| return {OutputKeys.TEXT: res} | |||||
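For orientation, a hedged usage sketch of this class; the model directory and input text are placeholders (not values from this diff), and a CUDA device plus the single-process distributed setup performed in `__init__` are assumed:

# Hedged sketch: paths and text below are placeholders only.
from modelscope.outputs import OutputKeys
from modelscope.models.nlp.mglm import MGLMForTextSummarization

model = MGLMForTextSummarization('/path/to/local/mglm-summarization-checkpoint')
result = model.generate({'text': 'Document text to be summarized goes here.'})
print(result[OutputKeys.TEXT])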
| @@ -0,0 +1,20 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from .distributed import (DistributedDataParallel, | |||||
| PyTorchDistributedDataParallel) | |||||
| from .downstream import (GLMForMultiTokenCloze, GLMForMultiTokenClozeFast, | |||||
| GLMForSequenceClassification, GLMForSingleTokenCloze) | |||||
| from .modeling_glm import (GLMModel, | |||||
| glm_get_params_for_weight_decay_optimization) | |||||
| @@ -0,0 +1,127 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import torch | |||||
| import torch.distributed as dist | |||||
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||||
| from torch.autograd import Variable | |||||
| from torch.nn.modules import Module | |||||
| from torch.nn.parallel.distributed import DistributedDataParallel as DDP | |||||
| from modelscope.models.nlp.mglm import mpu | |||||
| class PyTorchDistributedDataParallel(DDP): | |||||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||||
| return self.module.named_parameters(prefix=prefix, recurse=recurse) | |||||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||||
| sd = self.module.state_dict(destination, prefix, keep_vars) | |||||
| return sd | |||||
| def load_state_dict(self, state_dict, strict=True): | |||||
| return self.module.load_state_dict(state_dict, strict=strict) | |||||
| class DistributedDataParallel(Module): | |||||
| def __init__(self, module): | |||||
| super(DistributedDataParallel, self).__init__() | |||||
| self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False | |||||
| self.module = module | |||||
| self.data_parallel_group = mpu.get_data_parallel_group() | |||||
| src_rank = mpu.get_model_parallel_rank() | |||||
| for p in self.module.parameters(): | |||||
| if torch.is_tensor(p): | |||||
| dist.broadcast(p, src_rank, group=self.data_parallel_group) | |||||
| def allreduce_params(reduce_after=True, | |||||
| no_scale=False, | |||||
| fp32_allreduce=False): | |||||
| if (self.needs_reduction): | |||||
| self.needs_reduction = False | |||||
| buckets = {} | |||||
| for name, param in self.module.named_parameters(): | |||||
| if param.requires_grad and param.grad is not None: | |||||
| tp = (param.data.type()) | |||||
| if tp not in buckets: | |||||
| buckets[tp] = [] | |||||
| buckets[tp].append(param) | |||||
| if self.warn_on_half: | |||||
| if torch.cuda.HalfTensor in buckets: | |||||
| print( | |||||
| 'WARNING: gloo dist backend for half parameters may be extremely slow. It is recommended to use the NCCL backend in this case.' # noqa | |||||
| ) | |||||
| self.warn_on_half = False | |||||
| for tp in buckets: | |||||
| bucket = buckets[tp] | |||||
| grads = [param.grad.data for param in bucket] | |||||
| coalesced = _flatten_dense_tensors(grads) | |||||
| if fp32_allreduce: | |||||
| coalesced = coalesced.float() | |||||
| if not no_scale and not reduce_after: | |||||
| coalesced /= dist.get_world_size( | |||||
| group=self.data_parallel_group) | |||||
| dist.all_reduce(coalesced, group=self.data_parallel_group) | |||||
| torch.cuda.synchronize() | |||||
| if not no_scale and reduce_after: | |||||
| coalesced /= dist.get_world_size( | |||||
| group=self.data_parallel_group) | |||||
| for buf, synced in zip( | |||||
| grads, _unflatten_dense_tensors(coalesced, grads)): | |||||
| buf.copy_(synced) | |||||
| self.hook_handles = [] | |||||
| self.hooks = [] | |||||
| for param in list(self.module.parameters()): | |||||
| def allreduce_hook(*unused): | |||||
| Variable._execution_engine.queue_callback(allreduce_params) | |||||
| self.allreduce_params = allreduce_params | |||||
| def forward(self, *inputs, **kwargs): | |||||
| self.needs_reduction = True | |||||
| return self.module(*inputs, **kwargs) | |||||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||||
| sd = self.module.state_dict(destination, prefix, keep_vars) | |||||
| return sd | |||||
| def load_state_dict(self, state_dict, strict=True): | |||||
| return self.module.load_state_dict(state_dict, strict=strict) | |||||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||||
| return self.module.named_parameters(prefix=prefix, recurse=recurse) | |||||
| ''' | |||||
| def _sync_buffers(self): | |||||
| buffers = list(self.module._all_buffers()) | |||||
| if len(buffers) > 0: | |||||
| # cross-node buffer sync | |||||
| flat_buffers = _flatten_dense_tensors(buffers) | |||||
| dist.broadcast(flat_buffers, 0) | |||||
| for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): | |||||
| buf.copy_(synced) | |||||
| def train(self, mode=True): | |||||
| # Clear NCCL communicator and CUDA event cache of the default group ID, | |||||
| # These cache will be recreated at the later call. This is currently a | |||||
| # work-around for a potential NCCL deadlock. | |||||
| if dist._backend == dist.dist_backend.NCCL: | |||||
| dist._clear_group_cache() | |||||
| super(DistributedDataParallel, self).train(mode) | |||||
| self.module.train(mode) | |||||
| ''' | |||||
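A hedged single-process sketch of the thin wrapper above (gloo backend, world size 1, CPU-only toy module); real training would take rank and world size from the launcher rather than hard-coding them:

# Hedged sketch: single process, gloo backend, CPU-only toy module.
import os
import torch
import torch.distributed as dist

os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

ddp_model = PyTorchDistributedDataParallel(torch.nn.Linear(8, 2))
out = ddp_model(torch.randn(4, 8))
state = ddp_model.state_dict()   # wrapped module's state dict, no 'module.' prefix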
| @@ -0,0 +1,242 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| """Multiple choice model.""" | |||||
| import torch | |||||
| import torch.nn | |||||
| from .modeling_glm import GLMModel | |||||
| class GLMForMultiTokenCloze(torch.nn.Module): | |||||
| def __init__(self, | |||||
| language_model: GLMModel, | |||||
| take_softmax=True, | |||||
| length_penalty=0.0): | |||||
| super(GLMForMultiTokenCloze, self).__init__() | |||||
| self.model = language_model | |||||
| self.take_softmax = take_softmax | |||||
| self.length_penalty = length_penalty | |||||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||||
| # [h.remove() for h in self.hook_handles] | |||||
| sd = self.model.state_dict(destination, prefix, keep_vars) | |||||
| return sd | |||||
| def load_state_dict(self, state_dict, strict=True): | |||||
| return self.model.load_state_dict(state_dict, strict=strict) | |||||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||||
| return self.model.named_parameters(prefix=prefix, recurse=recurse) | |||||
| def forward(self, | |||||
| input_ids, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| target_ids=None, | |||||
| logit_mask=None, | |||||
| prompt_pos=None): | |||||
| if target_ids is None: | |||||
| return self.model(input_ids, position_ids, attention_mask) | |||||
| num_choices = None | |||||
| if len(input_ids.shape) == 3: | |||||
| batch_size, num_choices = input_ids.shape[:2] | |||||
| input_ids = input_ids.reshape(-1, input_ids.size(-1)) | |||||
| attention_mask = attention_mask.reshape(-1, | |||||
| *attention_mask.size()[2:]) | |||||
| position_ids = position_ids.reshape(-1, *position_ids.size()[2:]) | |||||
| target_ids = target_ids.reshape(-1, target_ids.size(-1)) | |||||
| logit_mask = logit_mask.reshape(-1, logit_mask.size(-1)) | |||||
| if prompt_pos is not None: | |||||
| prompt_pos = prompt_pos.reshape(-1, prompt_pos.size(-1)) | |||||
| outputs, *mems = self.model( | |||||
| input_ids, position_ids, attention_mask, prompt_pos=prompt_pos) | |||||
| if self.take_softmax: | |||||
| outputs = torch.nn.functional.log_softmax(outputs, dim=-1) | |||||
| # select the target logits | |||||
| batch_ids = torch.arange( | |||||
| target_ids.size(0), dtype=torch.long, device=target_ids.device) | |||||
| batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids) | |||||
| seq_ids = torch.arange( | |||||
| target_ids.size(-1), dtype=torch.long, device=target_ids.device) | |||||
| seq_ids = seq_ids.unsqueeze(0).expand_as(target_ids) | |||||
| logits = outputs[batch_ids, seq_ids, target_ids] | |||||
| logits = (logits * logit_mask).sum(dim=1) | |||||
| if self.length_penalty > 0.0: | |||||
| logits = logits / logit_mask.sum(dim=1)**self.length_penalty | |||||
| if num_choices is not None: | |||||
| logits = logits.view(-1, num_choices) | |||||
| return (logits, *mems) | |||||
| class GLMForMultiTokenClozeFast(torch.nn.Module): | |||||
| def __init__(self, language_model, take_softmax=True, length_penalty=0.0): | |||||
| super(GLMForMultiTokenClozeFast, self).__init__() | |||||
| self.model = language_model | |||||
| self.take_softmax = take_softmax | |||||
| self.length_penalty = length_penalty | |||||
| def forward(self, input_ids, position_ids, attention_mask, dec_input_ids, | |||||
| dec_position_ids, dec_attention_mask, dec_target_ids, | |||||
| dec_logit_mask): | |||||
| # encoder | |||||
| outputs, *mems = self.model( | |||||
| input_ids, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| return_memory=True, | |||||
| detach_memory=False) | |||||
| batch_size, num_choices, max_dec_len = dec_input_ids.size() | |||||
| max_enc_len = input_ids.size(-1) | |||||
| enc_mems = [] | |||||
| for hidden in mems: | |||||
| hidden = hidden.unsqueeze(1).expand(-1, num_choices, -1, | |||||
| -1).reshape( | |||||
| batch_size * num_choices, | |||||
| *hidden.size()[1:]) | |||||
| enc_mems.append(hidden) | |||||
| def build_dec_mask_matrix(seq_length, sep, memory_length=0): | |||||
| m = enc_mems[0].new_ones((1, seq_length, seq_length)) | |||||
| m = torch.tril(m) | |||||
| # sep = dec_attention_mask | |||||
| ids = torch.arange( | |||||
| memory_length, device=sep.device, dtype=sep.dtype).view(1, -1) | |||||
| mask = ids < sep.view(-1, 1) # batch * mem | |||||
| mask = mask.unsqueeze(1).float().expand(-1, seq_length, -1) | |||||
| m = m.expand(batch_size * num_choices, -1, -1) | |||||
| m = torch.cat((mask, m), dim=2) | |||||
| m = m.unsqueeze(1) | |||||
| return m | |||||
| dec_input_ids = dec_input_ids.reshape(-1, max_dec_len) | |||||
| dec_position_ids = dec_position_ids.reshape( | |||||
| -1, | |||||
| *dec_position_ids.size()[2:]) | |||||
| # dec_attention_mask = dec_attention_mask.reshape(-1, *dec_attention_mask.size()[2:]).unsqueeze(1) | |||||
| dec_attention_mask = build_dec_mask_matrix( | |||||
| max_dec_len, dec_attention_mask.reshape(-1), max_enc_len) | |||||
| dec_target_ids = dec_target_ids.reshape(-1, dec_target_ids.size(-1)) | |||||
| dec_logit_mask = dec_logit_mask.reshape(-1, dec_logit_mask.size(-1)) | |||||
| outputs, *mems = self.model(dec_input_ids, dec_position_ids, | |||||
| dec_attention_mask, *enc_mems) | |||||
| if self.take_softmax: | |||||
| outputs = torch.nn.functional.log_softmax(outputs, dim=-1) | |||||
| batch_ids = torch.arange( | |||||
| dec_target_ids.size(0), | |||||
| dtype=torch.long, | |||||
| device=dec_target_ids.device) | |||||
| batch_ids = batch_ids.unsqueeze(1).expand_as(dec_target_ids) | |||||
| seq_ids = torch.arange( | |||||
| dec_target_ids.size(-1), | |||||
| dtype=torch.long, | |||||
| device=dec_target_ids.device) | |||||
| seq_ids = seq_ids.unsqueeze(0).expand_as(dec_target_ids) | |||||
| logits = outputs[batch_ids, seq_ids, dec_target_ids] | |||||
| logits = (logits * dec_logit_mask).sum(dim=1) | |||||
| if self.length_penalty > 0.0: | |||||
| logits = logits / dec_logit_mask.sum(dim=1)**self.length_penalty | |||||
| if num_choices is not None: | |||||
| logits = logits.view(-1, num_choices) | |||||
| return (logits, *mems) | |||||
| class GLMForSingleTokenCloze(torch.nn.Module): | |||||
| def __init__(self, language_model, take_softmax=False): | |||||
| super().__init__() | |||||
| self.model = language_model | |||||
| self.take_softmax = take_softmax | |||||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||||
| # [h.remove() for h in self.hook_handles] | |||||
| sd = self.model.state_dict(destination, prefix, keep_vars) | |||||
| return sd | |||||
| def load_state_dict(self, state_dict, strict=True): | |||||
| return self.model.load_state_dict(state_dict, strict=strict) | |||||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||||
| return self.model.named_parameters(prefix=prefix, recurse=recurse) | |||||
| def forward(self, | |||||
| input_ids, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| target_ids=None, | |||||
| logit_mask=None, | |||||
| prompt_pos=None): | |||||
| if target_ids is None: | |||||
| return self.model(input_ids, position_ids, attention_mask) | |||||
| assert len(input_ids.shape) == 2 | |||||
| outputs, *mems = self.model( | |||||
| input_ids, position_ids, attention_mask, prompt_pos=prompt_pos) | |||||
| batch_ids = torch.arange( | |||||
| outputs.size(0), | |||||
| dtype=attention_mask.dtype, | |||||
| device=attention_mask.device) | |||||
| target_logits = outputs[batch_ids, attention_mask] | |||||
| if self.take_softmax: | |||||
| target_prob = torch.nn.functional.log_softmax( | |||||
| target_logits, dim=-1) | |||||
| else: | |||||
| target_prob = target_logits | |||||
| batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids) | |||||
| output = target_prob[batch_ids, target_ids] | |||||
| return (output, target_logits, *mems) | |||||
| class GLMForSequenceClassification(torch.nn.Module): | |||||
| def __init__(self, | |||||
| language_model, | |||||
| hidden_size, | |||||
| hidden_dropout, | |||||
| pool_token, | |||||
| num_class=1): | |||||
| super().__init__() | |||||
| self.pool_token = pool_token | |||||
| self.model = language_model | |||||
| self.num_class = num_class | |||||
| # Multi-choice head. | |||||
| self.pool_layer = torch.nn.Linear(hidden_size, hidden_size) | |||||
| self.multichoice_dropout = torch.nn.Dropout(hidden_dropout) | |||||
| self.multichoice_head = torch.nn.Linear(hidden_size, num_class) | |||||
| def forward(self, input_ids, position_ids, attention_mask): | |||||
| num_choices = None | |||||
| if len(input_ids.shape) == 3: | |||||
| assert self.num_class == 1 | |||||
| batch_size, num_choices = input_ids.shape[:2] | |||||
| input_ids = input_ids.reshape(-1, input_ids.size(-1)) | |||||
| attention_mask = attention_mask.reshape(-1, | |||||
| *attention_mask.size()[2:]) | |||||
| position_ids = position_ids.reshape(-1, *position_ids.size()[2:]) | |||||
| outputs, *mems = self.model(input_ids, position_ids, attention_mask) | |||||
| if self.pool_token == 'start': | |||||
| output = outputs[torch.arange( | |||||
| outputs.size(0), | |||||
| dtype=attention_mask.dtype, | |||||
| device=attention_mask.device), attention_mask] | |||||
| elif self.pool_token == 'pad': | |||||
| output = outputs[torch.arange( | |||||
| outputs.size(0), | |||||
| dtype=attention_mask.dtype, | |||||
| device=attention_mask.device), attention_mask - 1] | |||||
| elif self.pool_token == 'cls': | |||||
| output = outputs[:, 0] | |||||
| else: | |||||
| raise NotImplementedError | |||||
| output = torch.tanh(self.pool_layer(output)) | |||||
| multichoice_output = self.multichoice_dropout(output) | |||||
| logits = self.multichoice_head(multichoice_output) | |||||
| if num_choices is not None: | |||||
| logits = logits.view(-1, num_choices) | |||||
| return (logits, *mems) | |||||
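The target-logit selection used by the cloze heads above is plain advanced indexing; a toy illustration with random tensors (not real model outputs):

# Toy tensors only: gathers outputs[b, s, target_ids[b, s]] via advanced indexing.
import torch

outputs = torch.randn(2, 4, 10)                  # (batch, seq_len, vocab)
target_ids = torch.tensor([[3, 1, 7, 0],
                           [2, 2, 5, 9]])         # (batch, seq_len)
batch_ids = torch.arange(2).unsqueeze(1).expand_as(target_ids)
seq_ids = torch.arange(4).unsqueeze(0).expand_as(target_ids)

logits = outputs[batch_ids, seq_ids, target_ids]  # shape (2, 4)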
| @@ -0,0 +1,245 @@ | |||||
| # Modified by Zhipu.AI | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """GPT-2 model.""" | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| from modelscope.models.nlp.mglm import mpu | |||||
| from modelscope.models.nlp.mglm.model.prompt import PromptSpell | |||||
| from modelscope.models.nlp.mglm.utils import print_rank_0 | |||||
| def init_method_normal(std=0.02): | |||||
| """Init method based on normal distribution. | |||||
| This is only used for embeddings. The transformer has its | |||||
| own initializer. | |||||
| """ | |||||
| def init_(tensor): | |||||
| return torch.nn.init.normal_(tensor, mean=0.0, std=std) | |||||
| return init_ | |||||
| class GLMModel(torch.nn.Module): | |||||
| """GLM Language model. | |||||
| The output of the forward method is the logits (parallel or | |||||
| serial, depending on the `parallel_output` flag). | |||||
| """ | |||||
| def __init__( | |||||
| self, | |||||
| num_layers, | |||||
| vocab_size, | |||||
| hidden_size, | |||||
| num_attention_heads, | |||||
| embedding_dropout_prob, | |||||
| attention_dropout_prob, | |||||
| output_dropout_prob, | |||||
| max_sequence_length, | |||||
| max_memory_length, | |||||
| checkpoint_activations, | |||||
| checkpoint_num_layers=1, | |||||
| parallel_output=True, | |||||
| relative_encoding=False, | |||||
| block_position_encoding=False, | |||||
| output_predict=True, | |||||
| spell_length=None, | |||||
| spell_func='lstm', | |||||
| attention_scale=1.0, | |||||
| ): | |||||
| super(GLMModel, self).__init__() | |||||
| self.parallel_output = parallel_output | |||||
| self.output_predict = output_predict | |||||
| self.hidden_size = hidden_size | |||||
| init_method = init_method_normal(std=0.02) | |||||
| # Word embeddings (parallel). | |||||
| self.word_embeddings = mpu.VocabParallelEmbedding( | |||||
| vocab_size, hidden_size, init_method=init_method) | |||||
| # Transformer | |||||
| self.transformer = mpu.GPT2ParallelTransformer( | |||||
| num_layers, | |||||
| hidden_size, | |||||
| num_attention_heads, | |||||
| max_sequence_length, | |||||
| max_memory_length, | |||||
| embedding_dropout_prob, | |||||
| attention_dropout_prob, | |||||
| output_dropout_prob, | |||||
| checkpoint_activations, | |||||
| checkpoint_num_layers, | |||||
| attention_scale=attention_scale, | |||||
| relative_encoding=relative_encoding, | |||||
| block_position_encoding=block_position_encoding) | |||||
| if spell_length is not None: | |||||
| self.prompt_spell = PromptSpell(spell_length, self.hidden_size, | |||||
| spell_func) | |||||
| def freeze_transformer(self, tune_prefix_layers=None): | |||||
| log_str = 'Freeze transformer' | |||||
| self.word_embeddings.requires_grad_(False) | |||||
| self.transformer.requires_grad_(False) | |||||
| if tune_prefix_layers is not None: | |||||
| log_str += f' tune {tune_prefix_layers} prefix layers' | |||||
| for i in range(tune_prefix_layers): | |||||
| self.transformer.layers[i].requires_grad_(True) | |||||
| print_rank_0(log_str) | |||||
| def forward(self, | |||||
| input_ids, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| *mems, | |||||
| return_memory=False, | |||||
| detach_memory=True, | |||||
| prompt_pos=None): | |||||
| # Embeddings. | |||||
| batch_size = input_ids.size(0) | |||||
| words_embeddings = self.word_embeddings(input_ids) | |||||
| embeddings = words_embeddings | |||||
| if prompt_pos is not None: | |||||
| embeddings = embeddings.clone() | |||||
| prompt_embeds = self.prompt_spell() | |||||
| batch_index = torch.arange( | |||||
| batch_size, device=input_ids.device).unsqueeze(1) | |||||
| embeddings[batch_index, prompt_pos] = prompt_embeds | |||||
| # Transformer. | |||||
| transformer_output = self.transformer( | |||||
| embeddings, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| mems, | |||||
| return_memory=return_memory, | |||||
| detach_memory=detach_memory) | |||||
| logits, hidden_layers = transformer_output | |||||
| outputs = hidden_layers | |||||
| if self.output_predict: | |||||
| # Parallel logits. | |||||
| logits_parallel = mpu.copy_to_model_parallel_region(logits) | |||||
| logits_parallel = F.linear(logits_parallel, | |||||
| self.word_embeddings.weight) | |||||
| if self.parallel_output: | |||||
| return (logits_parallel, *outputs) | |||||
| return (mpu.gather_from_model_parallel_region(logits_parallel), | |||||
| *outputs) | |||||
| else: | |||||
| return (logits, *outputs) | |||||
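When `output_predict` is set, the forward pass above projects the final hidden states back onto the word embedding matrix instead of keeping a separate output head. A minimal single-process sketch of that weight-tied projection, using a plain `torch.nn.Embedding` in place of the vocab-parallel table and made-up sizes:

import torch
import torch.nn.functional as F

batch, seq, hidden, vocab = 2, 8, 16, 100            # hypothetical sizes
hidden_states = torch.randn(batch, seq, hidden)       # stands in for the transformer output
word_embeddings = torch.nn.Embedding(vocab, hidden)

# Weight tying: F.linear multiplies by the transposed embedding matrix,
# giving one logit per vocabulary entry.
logits = F.linear(hidden_states, word_embeddings.weight)
assert logits.shape == (batch, seq, vocab)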
| class EncoderDecoder(torch.nn.Module): | |||||
| """Seq2Seq Transformer Model | |||||
| The output of the forward method is the logits (parallel or serial depending on the `parallel_output` flag). | |||||
| """ | |||||
| def __init__(self, | |||||
| num_layers, | |||||
| vocab_size, | |||||
| hidden_size, | |||||
| num_attention_heads, | |||||
| embedding_dropout_prob, | |||||
| attention_dropout_prob, | |||||
| output_dropout_prob, | |||||
| max_sequence_length, | |||||
| max_memory_length, | |||||
| checkpoint_activations, | |||||
| checkpoint_num_layers=1, | |||||
| parallel_output=True, | |||||
| output_predict=True): | |||||
| super(EncoderDecoder, self).__init__() | |||||
| self.parallel_output = parallel_output | |||||
| self.output_predict = output_predict | |||||
| init_method = init_method_normal(std=0.02) | |||||
| # Word embeddings (parallel). | |||||
| self.word_embeddings = mpu.VocabParallelEmbedding( | |||||
| vocab_size, hidden_size, init_method=init_method) | |||||
| # Transformer | |||||
| self.encoder = mpu.GPT2ParallelTransformer( | |||||
| num_layers, hidden_size, num_attention_heads, max_sequence_length, | |||||
| max_memory_length, embedding_dropout_prob, attention_dropout_prob, | |||||
| output_dropout_prob, checkpoint_activations, checkpoint_num_layers) | |||||
| self.decoder = mpu.GPT2ParallelTransformer( | |||||
| num_layers, | |||||
| hidden_size, | |||||
| num_attention_heads, | |||||
| max_sequence_length, | |||||
| max_memory_length, | |||||
| embedding_dropout_prob, | |||||
| attention_dropout_prob, | |||||
| output_dropout_prob, | |||||
| checkpoint_activations, | |||||
| checkpoint_num_layers, | |||||
| use_decoder_layer=True) | |||||
| def forward(self, source_ids, target_ids, source_position_ids, | |||||
| target_position_ids, source_mask, target_mask): | |||||
| # Embeddings. | |||||
| source_embeddings = self.word_embeddings(source_ids) | |||||
| target_embeddings = self.word_embeddings(target_ids) | |||||
| # Transformer. | |||||
| encoder_output, _ = self.encoder(source_embeddings, | |||||
| source_position_ids, source_mask) | |||||
| decoder_output, _ = self.decoder(target_embeddings, | |||||
| target_position_ids, target_mask) | |||||
| if self.output_predict: | |||||
| # Parallel logits. | |||||
| output_parallel = mpu.copy_to_model_parallel_region(decoder_output) | |||||
| logits_parallel = F.linear(output_parallel, | |||||
| self.word_embeddings.weight) | |||||
| if self.parallel_output: | |||||
| return (logits_parallel, ) | |||||
| return (mpu.gather_from_model_parallel_region(logits_parallel), ) | |||||
| else: | |||||
| return (decoder_output, ) | |||||
| def glm_get_params_for_weight_decay_optimization(module): | |||||
| weight_decay_params = {'params': []} | |||||
| no_weight_decay_params = {'params': [], 'weight_decay': 0.0} | |||||
| for module_ in module.modules(): | |||||
| if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)): | |||||
| no_weight_decay_params['params'].extend([ | |||||
| p for p in list(module_._parameters.values()) | |||||
| if p is not None and p.requires_grad | |||||
| ]) | |||||
| else: | |||||
| weight_decay_params['params'].extend([ | |||||
| p for n, p in list(module_._parameters.items()) | |||||
| if p is not None and p.requires_grad and n != 'bias' | |||||
| ]) | |||||
| no_weight_decay_params['params'].extend([ | |||||
| p for n, p in list(module_._parameters.items()) | |||||
| if p is not None and p.requires_grad and n == 'bias' | |||||
| ]) | |||||
| return weight_decay_params, no_weight_decay_params | |||||
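The two dictionaries returned above are intended to be passed straight to a PyTorch optimizer as parameter groups. A small sketch of that usage, substituting plain `torch.nn.LayerNorm` for `mpu.LayerNorm` so it runs without the parallel stack (the model and hyperparameters are made up):

import torch

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))

# Same grouping rule as glm_get_params_for_weight_decay_optimization:
# LayerNorm parameters and biases get weight_decay=0, everything else decays.
decay = {'params': []}
no_decay = {'params': [], 'weight_decay': 0.0}
for module_ in model.modules():
    if isinstance(module_, torch.nn.LayerNorm):
        no_decay['params'].extend(module_.parameters(recurse=False))
    else:
        for name, p in module_._parameters.items():
            if p is not None:
                (no_decay if name == 'bias' else decay)['params'].append(p)

optimizer = torch.optim.AdamW([decay, no_decay], lr=1e-4, weight_decay=0.01)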
| @@ -0,0 +1,59 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import random | |||||
| import torch | |||||
| class PromptSpell(torch.nn.Module): | |||||
| def __init__(self, spell_length, hidden_size, spell_func): | |||||
| super(PromptSpell, self).__init__() | |||||
| self.spell_length = spell_length | |||||
| self.hidden_size = hidden_size | |||||
| self.spell_embeddings = torch.nn.Embedding(self.spell_length, | |||||
| self.hidden_size) | |||||
| self.spell_func = spell_func | |||||
| if self.spell_func == 'lstm': | |||||
| self.lstm_head = torch.nn.LSTM( | |||||
| input_size=self.hidden_size, | |||||
| hidden_size=self.hidden_size, | |||||
| num_layers=2, | |||||
| # dropout=self.lstm_dropout, | |||||
| bidirectional=True, | |||||
| batch_first=True) # .to(torch.device("cuda")) | |||||
| self.mlp_head = torch.nn.Sequential( | |||||
| torch.nn.Linear(2 * self.hidden_size, self.hidden_size), | |||||
| torch.nn.ReLU(), | |||||
| torch.nn.Linear(self.hidden_size, self.hidden_size)) | |||||
| elif self.spell_func == 'mlp': | |||||
| self.mlp_head = torch.nn.Sequential( | |||||
| torch.nn.Linear(self.hidden_size, self.hidden_size), | |||||
| torch.nn.ReLU(), | |||||
| torch.nn.Linear(self.hidden_size, self.hidden_size)) | |||||
| elif self.spell_func != 'none': | |||||
| raise NotImplementedError('Prompt function ' + self.spell_func) | |||||
| def init_embedding(self, word_embeddings=None, task_tokens=None): | |||||
| num_words = 5000 | |||||
| with torch.no_grad(): | |||||
| for i in range(self.spell_length): | |||||
| rand_token = random.randrange(num_words) | |||||
| if task_tokens is None: | |||||
| target_embedding = word_embeddings[rand_token] | |||||
| else: | |||||
| word_embedding = word_embeddings[rand_token] | |||||
| task_token = random.choice(task_tokens) | |||||
| task_embedding = word_embeddings[task_token] | |||||
| ratio = random.random() | |||||
| target_embedding = word_embedding * ratio + task_embedding * ( | |||||
| 1 - ratio) | |||||
| self.spell_embeddings.weight.data[i] = target_embedding | |||||
| def forward(self): | |||||
| prompt_embeds = self.spell_embeddings.weight.unsqueeze(0) | |||||
| if self.spell_func == 'lstm': | |||||
| prompt_embeds = self.lstm_head(prompt_embeds)[0] | |||||
| if self.spell_func == 'lstm' or self.spell_func == 'mlp': | |||||
| prompt_embeds = self.mlp_head(prompt_embeds) | |||||
| return prompt_embeds | |||||
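Assuming the package layout referenced earlier (`modelscope.models.nlp.mglm.model.prompt`) is installed, the module can be exercised on its own; the sizes below are arbitrary:

import torch
from modelscope.models.nlp.mglm.model.prompt import PromptSpell

spell = PromptSpell(spell_length=8, hidden_size=64, spell_func='mlp')
prompt_embeds = spell()                  # reparameterized prompt, shape [1, 8, 64]
print(prompt_embeds.shape)

# Optionally warm-start the prompt from existing word embeddings.
word_embeddings = torch.randn(5000, 64)
spell.init_embedding(word_embeddings)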
| @@ -0,0 +1,37 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Model parallel utility interface.""" | |||||
| from .cross_entropy import vocab_parallel_cross_entropy | |||||
| from .data import broadcast_data | |||||
| from .grads import clip_grad_norm | |||||
| from .initialize import (destroy_model_parallel, get_data_parallel_group, | |||||
| get_data_parallel_rank, get_data_parallel_world_size, | |||||
| get_model_parallel_group, get_model_parallel_rank, | |||||
| get_model_parallel_src_rank, | |||||
| get_model_parallel_world_size, | |||||
| initialize_model_parallel, | |||||
| model_parallel_is_initialized) | |||||
| from .layers import (ColumnParallelLinear, ParallelEmbedding, | |||||
| RowParallelLinear, VocabParallelEmbedding) | |||||
| from .mappings import (copy_to_model_parallel_region, | |||||
| gather_from_model_parallel_region, | |||||
| reduce_from_model_parallel_region, | |||||
| scatter_to_model_parallel_region) | |||||
| from .random import (checkpoint, get_cuda_rng_tracker, | |||||
| model_parallel_cuda_manual_seed, | |||||
| partition_activations_in_checkpoint) | |||||
| from .transformer import (BertParallelSelfAttention, | |||||
| BertParallelTransformerLayer, | |||||
| GPT2ParallelTransformer, LayerNorm) | |||||
| @@ -0,0 +1,110 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import torch | |||||
| from .initialize import (get_model_parallel_group, get_model_parallel_rank, | |||||
| get_model_parallel_world_size) | |||||
| from .utils import VocabUtility | |||||
| class _VocabParallelCrossEntropy(torch.autograd.Function): | |||||
| @staticmethod | |||||
| def forward(ctx, vocab_parallel_logits, target): | |||||
| # Copy so the input remains unchanged. | |||||
| logits = vocab_parallel_logits.clone() | |||||
| # Maximum value along vocab dimension across all GPUs. | |||||
| logits_max = torch.max(logits, dim=-1)[0] | |||||
| torch.distributed.all_reduce( | |||||
| logits_max, | |||||
| op=torch.distributed.ReduceOp.MAX, | |||||
| group=get_model_parallel_group()) | |||||
| # Subtract the maximum value. | |||||
| logits.sub_(logits_max.unsqueeze(dim=-1)) | |||||
| # Sum of exponential of logits along vocab dimension across all GPUs. | |||||
| exp_logits = logits.exp() | |||||
| sum_exp_logits = exp_logits.sum(dim=-1) | |||||
| torch.distributed.all_reduce( | |||||
| sum_exp_logits, | |||||
| op=torch.distributed.ReduceOp.SUM, | |||||
| group=get_model_parallel_group()) | |||||
| # Get the partition's vocab indices | |||||
| get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size | |||||
| partition_vocab_size = vocab_parallel_logits.size()[-1] | |||||
| rank = get_model_parallel_rank() | |||||
| world_size = get_model_parallel_world_size() | |||||
| vocab_start_index, vocab_end_index = get_vocab_range( | |||||
| partition_vocab_size, rank, world_size) | |||||
| # Create a mask of valid vocab ids (1 means it needs to be masked). | |||||
| target_mask = (target < vocab_start_index) | ( | |||||
| target >= vocab_end_index) | |||||
| masked_target = target.clone() - vocab_start_index | |||||
| masked_target[target_mask] = 0 | |||||
| # Get predicted-logits = logits[target]. | |||||
| # For Simplicity, we convert logits to a 2-D tensor with size | |||||
| # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. | |||||
| logits_2d = logits.view(-1, partition_vocab_size) | |||||
| masked_target_1d = masked_target.view(-1) | |||||
| arange_1d = torch.arange( | |||||
| start=0, end=logits_2d.size()[0], device=logits_2d.device) | |||||
| predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] | |||||
| predicted_logits = predicted_logits_1d.view_as(target) | |||||
| predicted_logits[target_mask] = 0.0 | |||||
| # All reduce is needed to get the chunks from other GPUs. | |||||
| torch.distributed.all_reduce( | |||||
| predicted_logits, | |||||
| op=torch.distributed.ReduceOp.SUM, | |||||
| group=get_model_parallel_group()) | |||||
| # Loss = log(sum(exp(logits))) - predicted-logit. | |||||
| loss = torch.log(sum_exp_logits) - predicted_logits | |||||
| # Store softmax, target-mask and masked-target for backward pass. | |||||
| exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) | |||||
| ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) | |||||
| return loss | |||||
| @staticmethod | |||||
| def backward(ctx, grad_output): | |||||
| # Retrieve tensors from the forward path. | |||||
| softmax, target_mask, masked_target_1d = ctx.saved_tensors | |||||
| # All the inputs have softmax as their gradient. | |||||
| grad_input = softmax | |||||
| # For simplicity, work with the 2D gradient. | |||||
| partition_vocab_size = softmax.size()[-1] | |||||
| grad_2d = grad_input.view(-1, partition_vocab_size) | |||||
| # Add the gradient from matching classes. | |||||
| arange_1d = torch.arange( | |||||
| start=0, end=grad_2d.size()[0], device=grad_2d.device) | |||||
| grad_2d[arange_1d, | |||||
| masked_target_1d] -= (1.0 - target_mask.view(-1).float()) | |||||
| # Finally elementwise multiplication with the output gradients. | |||||
| grad_input.mul_(grad_output.unsqueeze(dim=-1)) | |||||
| return grad_input, None | |||||
| def vocab_parallel_cross_entropy(vocab_parallel_logits, target): | |||||
| """Helper function for the cross entropy.""" | |||||
| return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) | |||||
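With a model parallel world size of 1 the loss above collapses to ordinary token-level cross entropy (log-sum-exp of the logits minus the logit of the target). A single-process check of that identity, with made-up shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                      # [tokens, vocab]
target = torch.randint(0, 10, (4,))

logits_max = logits.max(dim=-1)[0]
shifted = logits - logits_max.unsqueeze(-1)       # same max-subtraction as the kernel
loss = torch.log(shifted.exp().sum(dim=-1)) - shifted[torch.arange(4), target]

print(torch.allclose(loss, F.cross_entropy(logits, target, reduction='none'), atol=1e-6))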
| @@ -0,0 +1,117 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import torch | |||||
| from .initialize import (get_model_parallel_group, get_model_parallel_rank, | |||||
| get_model_parallel_src_rank) | |||||
| _MAX_DATA_DIM = 5 | |||||
| def _check_data_types(keys, data, target_dtype): | |||||
| """Check that all the keys have the same target data type.""" | |||||
| for key in keys: | |||||
| assert data[key].dtype == target_dtype, '{} has data type {} which '\ | |||||
| 'is different than {}'.format(key, data[key].dtype, target_dtype) | |||||
| def _build_key_size_numel_dictionaries(keys, data): | |||||
| """Build the size on rank 0 and broadcast.""" | |||||
| max_dim = _MAX_DATA_DIM | |||||
| sizes = [0 for _ in range(max_dim) for _ in keys] | |||||
| # Pack the sizes on rank zero. | |||||
| if get_model_parallel_rank() == 0: | |||||
| offset = 0 | |||||
| for key in keys: | |||||
| assert data[key].dim( | |||||
| ) < max_dim, 'you should increase MAX_DATA_DIM' | |||||
| size = data[key].size() | |||||
| for i, s in enumerate(size): | |||||
| sizes[i + offset] = s | |||||
| offset += max_dim | |||||
| # Move to GPU and broadcast. | |||||
| sizes_cuda = torch.cuda.LongTensor(sizes) | |||||
| torch.distributed.broadcast( | |||||
| sizes_cuda, | |||||
| get_model_parallel_src_rank(), | |||||
| group=get_model_parallel_group()) | |||||
| # Move back to cpu and unpack. | |||||
| sizes_cpu = sizes_cuda.cpu() | |||||
| key_size = {} | |||||
| key_numel = {} | |||||
| total_numel = 0 | |||||
| offset = 0 | |||||
| for key in keys: | |||||
| i = 0 | |||||
| size = [] | |||||
| numel = 1 | |||||
| while sizes_cpu[offset + i] > 0: | |||||
| this_size = sizes_cpu[offset + i] | |||||
| size.append(this_size) | |||||
| numel *= this_size | |||||
| i += 1 | |||||
| key_size[key] = size | |||||
| key_numel[key] = numel | |||||
| total_numel += numel | |||||
| offset += max_dim | |||||
| return key_size, key_numel, total_numel | |||||
| def broadcast_data(keys, data, datatype): | |||||
| """Broadcast data from rank zero of each model parallel group to the | |||||
| members of the same model parallel group. | |||||
| Arguments: | |||||
| keys: list of keys in the data dictionary to be broadcast | |||||
| data: data dictionary of string keys and cpu tensor values. | |||||
| datatype: torch data type of all tensors in data associated | |||||
| with keys. | |||||
| """ | |||||
| # Build (key, size) and (key, number of elements) dictionaries along | |||||
| # with the total number of elements on all ranks. | |||||
| key_size, key_numel, total_numel = _build_key_size_numel_dictionaries( | |||||
| keys, data) | |||||
| # Pack on rank zero. | |||||
| if get_model_parallel_rank() == 0: | |||||
| # Check that all keys have the same data type. | |||||
| _check_data_types(keys, data, datatype) | |||||
| # Flatten the data associated with the keys | |||||
| flatten_data = torch.cat( | |||||
| [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() | |||||
| else: | |||||
| flatten_data = torch.empty( | |||||
| total_numel, device=torch.cuda.current_device(), dtype=datatype) | |||||
| # Broadcast | |||||
| torch.distributed.broadcast( | |||||
| flatten_data, | |||||
| get_model_parallel_src_rank(), | |||||
| group=get_model_parallel_group()) | |||||
| # Unpack | |||||
| output = {} | |||||
| offset = 0 | |||||
| for key in keys: | |||||
| size = key_size[key] | |||||
| numel = key_numel[key] | |||||
| output[key] = flatten_data.narrow(0, offset, numel).view(size) | |||||
| offset += numel | |||||
| return output | |||||
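`broadcast_data` ships a single flattened buffer and then carves it back into per-key tensors using the broadcast sizes. The pack/unpack arithmetic on its own (keys and shapes below are invented; the real call additionally needs an initialized model parallel group):

import torch

data = {'text': torch.arange(6).view(2, 3),
        'loss_mask': torch.ones(2, 3, dtype=torch.int64)}
keys = ['text', 'loss_mask']

flat = torch.cat([data[k].contiguous().view(-1) for k in keys], dim=0)

output, offset = {}, 0
for k in keys:
    numel = data[k].numel()
    output[k] = flat.narrow(0, offset, numel).view(data[k].size())
    offset += numel

print(all(torch.equal(output[k], data[k]) for k in keys))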
| @@ -0,0 +1,72 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # Parts of the code here are adapted from PyTorch | |||||
| # repo: https://github.com/pytorch/pytorch | |||||
| import torch | |||||
| from torch._six import inf | |||||
| from .initialize import get_model_parallel_group, get_model_parallel_rank | |||||
| def clip_grad_norm(parameters, max_norm, norm_type=2): | |||||
| """Clips gradient norm of an iterable of parameters. | |||||
| This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and | |||||
| added functionality to handle model parallel parameters. Note that | |||||
| the gradients are modified in place. | |||||
| Arguments: | |||||
| parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a | |||||
| single Tensor that will have gradients normalized | |||||
| max_norm (float or int): max norm of the gradients | |||||
| norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for | |||||
| infinity norm. | |||||
| Returns: | |||||
| Total norm of the parameters (viewed as a single vector). | |||||
| """ | |||||
| if isinstance(parameters, torch.Tensor): | |||||
| parameters = [parameters] | |||||
| parameters = list(filter(lambda p: p.grad is not None, parameters)) | |||||
| max_norm = float(max_norm) | |||||
| norm_type = float(norm_type) | |||||
| if norm_type == inf: | |||||
| total_norm = max(p.grad.data.abs().max() for p in parameters) | |||||
| total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) | |||||
| # Take max across all GPUs. | |||||
| torch.distributed.all_reduce( | |||||
| total_norm_cuda, | |||||
| op=torch.distributed.ReduceOp.MAX, | |||||
| group=get_model_parallel_group()) | |||||
| total_norm = total_norm_cuda[0].item() | |||||
| else: | |||||
| total_norm = 0 | |||||
| for p in parameters: | |||||
| if p.model_parallel or (get_model_parallel_rank() == 0): | |||||
| param_norm = p.grad.data.norm(norm_type) | |||||
| total_norm += param_norm.item()**norm_type | |||||
| # Sum across all model parallel GPUs. | |||||
| total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) | |||||
| torch.distributed.all_reduce( | |||||
| total_norm_cuda, | |||||
| op=torch.distributed.ReduceOp.SUM, | |||||
| group=get_model_parallel_group()) | |||||
| total_norm = total_norm_cuda[0].item()**(1. / norm_type) | |||||
| clip_coef = max_norm / (total_norm + 1e-6) | |||||
| if clip_coef < 1: | |||||
| for p in parameters: | |||||
| p.grad.data.mul_(clip_coef) | |||||
| return total_norm | |||||
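The clipping rule reduces to: compute the global norm of all gradients, then scale every gradient by `max_norm / (total_norm + 1e-6)` when that norm is too large. A single-process sketch with arbitrary tensors (the model parallel version additionally all-reduces the norm across ranks):

import torch

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
for p in params:
    p.grad = torch.randn(4)

max_norm, norm_type = 1.0, 2.0
total_norm = sum(p.grad.data.norm(norm_type) ** norm_type for p in params) ** (1.0 / norm_type)
clip_coef = max_norm / (total_norm + 1e-6)
if clip_coef < 1:
    for p in params:
        p.grad.data.mul_(clip_coef)

clipped = sum(p.grad.data.norm(norm_type) ** norm_type for p in params) ** (1.0 / norm_type)
print(float(total_norm), float(clipped))   # clipped is at most ~max_norm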
| @@ -0,0 +1,130 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Model and data parallel groups.""" | |||||
| import torch | |||||
| from .utils import ensure_divisibility | |||||
| # Model parallel group that the current rank belongs to. | |||||
| _MODEL_PARALLEL_GROUP = None | |||||
| # Data parallel group that the current rank belongs to. | |||||
| _DATA_PARALLEL_GROUP = None | |||||
| def initialize_model_parallel(model_parallel_size_): | |||||
| """ | |||||
| Initialize model data parallel groups. | |||||
| Arguments: | |||||
| model_parallel_size: number of GPUs used to parallelize model. | |||||
| Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we | |||||
| use 2 GPUs to parallelize the model. The present function will | |||||
| create 4 model parallel groups and 2 data parallel groups as: | |||||
| 4 model parallel groups: | |||||
| [g0, g1], [g2, g3], [g4, g5], [g6, g7] | |||||
| 2 data parallel groups: | |||||
| [g0, g2, g4, g6], [g1, g3, g5, g7] | |||||
| Note that for efficiency, the caller should make sure adjacent ranks | |||||
| are on the same DGX box. For example if we are using 2 DGX-1 boxes | |||||
| with a total of 16 GPUs, rank 0 to 7 belong to the first box and | |||||
| ranks 8 to 15 belong to the second box. | |||||
| """ | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> initializing model parallel with size {}'.format( | |||||
| model_parallel_size_)) | |||||
| # Get world size and rank. Ensure some consistencies. | |||||
| assert torch.distributed.is_initialized() | |||||
| world_size = torch.distributed.get_world_size() | |||||
| model_parallel_size = min(model_parallel_size_, world_size) | |||||
| ensure_divisibility(world_size, model_parallel_size) | |||||
| rank = torch.distributed.get_rank() | |||||
| # Build the data parallel groups. | |||||
| global _DATA_PARALLEL_GROUP | |||||
| assert _DATA_PARALLEL_GROUP is None, \ | |||||
| 'data parallel group is already initialized' | |||||
| for i in range(model_parallel_size): | |||||
| ranks = range(i, world_size, model_parallel_size) | |||||
| group = torch.distributed.new_group(ranks) | |||||
| if i == (rank % model_parallel_size): | |||||
| _DATA_PARALLEL_GROUP = group | |||||
| # Build the model parallel groups. | |||||
| global _MODEL_PARALLEL_GROUP | |||||
| assert _MODEL_PARALLEL_GROUP is None, \ | |||||
| 'model parallel group is already initialized' | |||||
| for i in range(world_size // model_parallel_size): | |||||
| ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) | |||||
| group = torch.distributed.new_group(ranks) | |||||
| if i == (rank // model_parallel_size): | |||||
| _MODEL_PARALLEL_GROUP = group | |||||
| def model_parallel_is_initialized(): | |||||
| """Check if model and data parallel groups are initialized.""" | |||||
| if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: | |||||
| return False | |||||
| return True | |||||
| def get_model_parallel_group(): | |||||
| """Get the model parallel group the caller rank belongs to.""" | |||||
| assert _MODEL_PARALLEL_GROUP is not None, \ | |||||
| 'model parallel group is not initialized' | |||||
| return _MODEL_PARALLEL_GROUP | |||||
| def get_data_parallel_group(): | |||||
| """Get the data parallel group the caller rank belongs to.""" | |||||
| assert _DATA_PARALLEL_GROUP is not None, \ | |||||
| 'data parallel group is not initialized' | |||||
| return _DATA_PARALLEL_GROUP | |||||
| def get_model_parallel_world_size(): | |||||
| """Return world size for the model parallel group.""" | |||||
| return torch.distributed.get_world_size(group=get_model_parallel_group()) | |||||
| def get_model_parallel_rank(): | |||||
| """Return my rank for the model parallel group.""" | |||||
| return torch.distributed.get_rank(group=get_model_parallel_group()) | |||||
| def get_model_parallel_src_rank(): | |||||
| """Calculate the global rank corresponding to a local rank zeor | |||||
| in the model parallel group.""" | |||||
| global_rank = torch.distributed.get_rank() | |||||
| local_world_size = get_model_parallel_world_size() | |||||
| return (global_rank // local_world_size) * local_world_size | |||||
| def get_data_parallel_world_size(): | |||||
| """Return world size for the data parallel group.""" | |||||
| return torch.distributed.get_world_size(group=get_data_parallel_group()) | |||||
| def get_data_parallel_rank(): | |||||
| """Return my rank for the data parallel group.""" | |||||
| return torch.distributed.get_rank(group=get_data_parallel_group()) | |||||
| def destroy_model_parallel(): | |||||
| """Set the groups to none.""" | |||||
| global _MODEL_PARALLEL_GROUP | |||||
| _MODEL_PARALLEL_GROUP = None | |||||
| global _DATA_PARALLEL_GROUP | |||||
| _DATA_PARALLEL_GROUP = None | |||||
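For the docstring's example (8 GPUs, `model_parallel_size` = 2), the loops above produce the following rank groupings; this pure-Python sketch reproduces them without touching `torch.distributed`:

world_size, model_parallel_size = 8, 2

data_parallel_groups = [list(range(i, world_size, model_parallel_size))
                        for i in range(model_parallel_size)]
model_parallel_groups = [list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
                         for i in range(world_size // model_parallel_size)]

print(model_parallel_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(data_parallel_groups)   # [[0, 2, 4, 6], [1, 3, 5, 7]]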
| @@ -0,0 +1,357 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # Parts of the code here are adapted from PyTorch | |||||
| # repo: https://github.com/pytorch/pytorch | |||||
| import math | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| import torch.nn.init as init | |||||
| from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm | |||||
| from torch.nn.parameter import Parameter | |||||
| from .initialize import get_model_parallel_rank, get_model_parallel_world_size | |||||
| from .mappings import (copy_to_model_parallel_region, | |||||
| gather_from_model_parallel_region, | |||||
| reduce_from_model_parallel_region, | |||||
| scatter_to_model_parallel_region) | |||||
| from .random import get_cuda_rng_tracker | |||||
| from .utils import VocabUtility, divide, split_tensor_along_last_dim | |||||
| def _initialize_affine_weight(weight, | |||||
| output_size, | |||||
| input_size, | |||||
| per_partition_size, | |||||
| partition_dim, | |||||
| init_method, | |||||
| stride=1, | |||||
| return_master_weight=False): | |||||
| """Initialize affine weight for model parallel. | |||||
| Build the master weight on all processes and scatter | |||||
| the relevant chunk.""" | |||||
| # If we only use 1 process for model parallelism, bypass scatter. | |||||
| world_size = get_model_parallel_world_size() | |||||
| if world_size == 1: | |||||
| init_method(weight) | |||||
| if return_master_weight: | |||||
| return weight | |||||
| return None | |||||
| # Initialize master weight | |||||
| master_weight = torch.empty( | |||||
| output_size, input_size, dtype=weight.dtype, requires_grad=False) | |||||
| init_method(master_weight) | |||||
| # Split and copy | |||||
| per_partition_per_stride_size = divide(per_partition_size, stride) | |||||
| weight_list = torch.split( | |||||
| master_weight, per_partition_per_stride_size, dim=partition_dim) | |||||
| rank = get_model_parallel_rank() | |||||
| my_weight_list = weight_list[rank::world_size] | |||||
| with torch.no_grad(): | |||||
| torch.cat(my_weight_list, dim=partition_dim, out=weight) | |||||
| if return_master_weight: | |||||
| return master_weight | |||||
| return None | |||||
| class VocabParallelEmbedding(torch.nn.Module): | |||||
| """Embedding parallelized in the vocabulary dimension. | |||||
| This is mainly adapted from torch.nn.Embedding and all the default | |||||
| values are kept. | |||||
| Arguments: | |||||
| num_embeddings: vocabulary size. | |||||
| embedding_dim: size of hidden state. | |||||
| init_method: method to initialize weights. | |||||
| """ | |||||
| def __init__(self, | |||||
| num_embeddings, | |||||
| embedding_dim, | |||||
| init_method=init.xavier_normal_): | |||||
| super(VocabParallelEmbedding, self).__init__() | |||||
| # Keep the input dimensions. | |||||
| self.num_embeddings = num_embeddings | |||||
| self.embedding_dim = embedding_dim | |||||
| # Set the defaults for compatibility. | |||||
| self.padding_idx = None | |||||
| self.max_norm = None | |||||
| self.norm_type = 2. | |||||
| self.scale_grad_by_freq = False | |||||
| self.sparse = False | |||||
| self._weight = None | |||||
| # Divide the weight matrix along the vocabulary dimension. | |||||
| self.vocab_start_index, self.vocab_end_index = \ | |||||
| VocabUtility.vocab_range_from_global_vocab_size( | |||||
| self.num_embeddings, get_model_parallel_rank(), | |||||
| get_model_parallel_world_size()) | |||||
| self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # noqa | |||||
| # Allocate weights. | |||||
| self.weight = Parameter( | |||||
| torch.Tensor(self.num_embeddings_per_partition, | |||||
| self.embedding_dim)) | |||||
| self.weight.model_parallel = True | |||||
| # And initialize. | |||||
| _initialize_affine_weight(self.weight, self.num_embeddings, | |||||
| self.embedding_dim, | |||||
| self.num_embeddings_per_partition, 0, | |||||
| init_method) | |||||
| def forward(self, input_): | |||||
| # Build the mask. | |||||
| input_mask = (input_ < self.vocab_start_index) | \ | |||||
| (input_ >= self.vocab_end_index) | |||||
| # Mask the input. | |||||
| masked_input = input_.clone() - self.vocab_start_index | |||||
| masked_input[input_mask] = 0 | |||||
| # Get the embeddings. | |||||
| output_parallel = F.embedding(masked_input, self.weight, | |||||
| self.padding_idx, self.max_norm, | |||||
| self.norm_type, self.scale_grad_by_freq, | |||||
| self.sparse) | |||||
| # Mask the output embedding. | |||||
| output_parallel[input_mask, :] = 0.0 | |||||
| # Reduce across all the model parallel GPUs. | |||||
| output = reduce_from_model_parallel_region(output_parallel) | |||||
| return output | |||||
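The class above gives each rank a contiguous slice of the vocabulary and relies on the final all-reduce to recombine results. A single-process sketch of the masking trick for a hypothetical vocabulary of 10 split across 2 ranks, with this rank owning ids 5–9:

import torch
import torch.nn.functional as F

vocab_size, world_size, rank, hidden = 10, 2, 1, 4
per_partition = vocab_size // world_size
vocab_start, vocab_end = rank * per_partition, (rank + 1) * per_partition

weight = torch.randn(per_partition, hidden)      # this rank's slice of the table
input_ = torch.tensor([[1, 6, 9]])
input_mask = (input_ < vocab_start) | (input_ >= vocab_end)
masked_input = input_.clone() - vocab_start
masked_input[input_mask] = 0                     # out-of-range ids look up row 0

out = F.embedding(masked_input, weight)
out[input_mask, :] = 0.0                         # and then contribute zero to the all-reduce
print(out.shape)                                 # torch.Size([1, 3, 4])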
| class ParallelEmbedding(torch.nn.Module): | |||||
| """Embedding parallelized in the embedding dimension. | |||||
| This is mainly adapted from torch.nn.Embedding and all the default | |||||
| values are kept. | |||||
| Arguments: | |||||
| num_embeddings: vocabulary size. | |||||
| embedding_dim: size of hidden state. | |||||
| init_method: method to initialize weights. | |||||
| """ | |||||
| def __init__(self, | |||||
| num_embeddings, | |||||
| embedding_dim, | |||||
| init_method=init.xavier_normal_, | |||||
| keep_master_weight_for_test=False): | |||||
| super(ParallelEmbedding, self).__init__() | |||||
| # Keep the input dimensions. | |||||
| self.num_embeddings = num_embeddings | |||||
| self.embedding_dim = embedding_dim | |||||
| # Set some defaults for compatibility. | |||||
| self.padding_idx = None | |||||
| self.max_norm = None | |||||
| self.norm_type = 2. | |||||
| self.scale_grad_by_freq = False | |||||
| self.sparse = False | |||||
| self._weight = None | |||||
| # Divide the weight matrix along the embedding dimension. | |||||
| world_size = get_model_parallel_world_size() | |||||
| self.embedding_dim_per_partition = divide(self.embedding_dim, | |||||
| world_size) | |||||
| # Allocate weights. | |||||
| self.weight = Parameter( | |||||
| torch.Tensor(self.num_embeddings, | |||||
| self.embedding_dim_per_partition)) | |||||
| self.weight.model_parallel = True | |||||
| # And initialize. | |||||
| _initialize_affine_weight( | |||||
| self.weight, | |||||
| self.num_embeddings, | |||||
| self.embedding_dim, | |||||
| self.embedding_dim_per_partition, | |||||
| 1, | |||||
| init_method, | |||||
| stride=1, | |||||
| return_master_weight=False) | |||||
| def forward(self, input_): | |||||
| input_parallel = copy_to_model_parallel_region(input_) | |||||
| output_parallel = F.embedding(input_parallel, self.weight, | |||||
| self.padding_idx, self.max_norm, | |||||
| self.norm_type, self.scale_grad_by_freq, | |||||
| self.sparse) | |||||
| output = gather_from_model_parallel_region(output_parallel) | |||||
| return output | |||||
| class ColumnParallelLinear(torch.nn.Module): | |||||
| """Linear layer with column parallelism. | |||||
| The linear layer is defined as Y = XA + b. A is parallelized along | |||||
| its second dimension as A = [A_1, ..., A_p]. | |||||
| Arguments: | |||||
| input_size: first dimension of matrix A. | |||||
| output_size: second dimension of matrix A. | |||||
| bias: If true, add bias | |||||
| gather_output: If true, call all-gather on output and make Y available | |||||
| to all GPUs, otherwise, every GPU will have its output | |||||
| which is Y_i = XA_i | |||||
| init_method: method to initialize weights. Note that bias is always set | |||||
| to zero. | |||||
| stride: For the strided linear layers. | |||||
| keep_master_weight_for_test: This was added for testing and should be | |||||
| set to False. It returns the master weights | |||||
| used for initialization. | |||||
| """ | |||||
| def __init__(self, | |||||
| input_size, | |||||
| output_size, | |||||
| bias=True, | |||||
| gather_output=True, | |||||
| init_method=init.xavier_normal_, | |||||
| stride=1, | |||||
| keep_master_weight_for_test=False): | |||||
| super(ColumnParallelLinear, self).__init__() | |||||
| # Keep input parameters | |||||
| self.input_size = input_size | |||||
| self.output_size = output_size | |||||
| self.gather_output = gather_output | |||||
| # Divide the weight matrix along the last dimension. | |||||
| world_size = get_model_parallel_world_size() | |||||
| self.output_size_per_partition = divide(output_size, world_size) | |||||
| # Parameters. | |||||
| # Note: torch.nn.functional.linear performs XA^T + b and as a result | |||||
| # we allocate the transpose. | |||||
| self.weight = Parameter( | |||||
| torch.Tensor(self.output_size_per_partition, self.input_size)) | |||||
| self.weight.model_parallel = True | |||||
| if bias: | |||||
| self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) | |||||
| self.bias.model_parallel = True | |||||
| # Always initialize bias to zero. | |||||
| with torch.no_grad(): | |||||
| self.bias.zero_() | |||||
| else: | |||||
| self.register_parameter('bias', None) | |||||
| # Initialize weight. | |||||
| self.master_weight = _initialize_affine_weight( | |||||
| self.weight, | |||||
| self.output_size, | |||||
| self.input_size, | |||||
| self.output_size_per_partition, | |||||
| 0, | |||||
| init_method, | |||||
| stride=stride, | |||||
| return_master_weight=keep_master_weight_for_test) | |||||
| def forward(self, input_): | |||||
| # Set up backprop all-reduce. | |||||
| input_parallel = copy_to_model_parallel_region(input_) | |||||
| # Matrix multiply. | |||||
| output_parallel = F.linear(input_parallel, self.weight, self.bias) | |||||
| if self.gather_output: | |||||
| # All-gather across the partitions. | |||||
| output = gather_from_model_parallel_region(output_parallel) | |||||
| else: | |||||
| output = output_parallel | |||||
| return output | |||||
| class RowParallelLinear(torch.nn.Module): | |||||
| """Linear layer with row parallelism. | |||||
| The linear layer is defined as Y = XA + b. A is parallelized along | |||||
| its first dimension and X along its second dimension as: | |||||
| - - | |||||
| | A_1 | | |||||
| | . | | |||||
| A = | . | X = [X_1, ..., X_p] | |||||
| | . | | |||||
| | A_p | | |||||
| - - | |||||
| Arguments: | |||||
| input_size: first dimension of matrix A. | |||||
| output_size: second dimension of matrix A. | |||||
| bias: If true, add bias. Note that bias is not parallelized. | |||||
| input_is_parallel: If true, we assume that the input is already | |||||
| split across the GPUs and we do not split | |||||
| again. | |||||
| init_method: method to initialize weights. Note that bias is always set | |||||
| to zero. | |||||
| stride: For the strided linear layers. | |||||
| keep_master_weight_for_test: This was added for testing and should be | |||||
| set to False. It returns the master weights | |||||
| used for initialization. | |||||
| """ | |||||
| def __init__(self, | |||||
| input_size, | |||||
| output_size, | |||||
| bias=True, | |||||
| input_is_parallel=False, | |||||
| init_method=init.xavier_normal_, | |||||
| stride=1, | |||||
| keep_master_weight_for_test=False): | |||||
| super(RowParallelLinear, self).__init__() | |||||
| # Keep input parameters | |||||
| self.input_size = input_size | |||||
| self.output_size = output_size | |||||
| self.input_is_parallel = input_is_parallel | |||||
| # Divide the weight matrix along the last dimension. | |||||
| world_size = get_model_parallel_world_size() | |||||
| self.input_size_per_partition = divide(input_size, world_size) | |||||
| # Parameters. | |||||
| # Note: torch.nn.functional.linear performs XA^T + b and as a result | |||||
| # we allocate the transpose. | |||||
| self.weight = Parameter( | |||||
| torch.Tensor(self.output_size, self.input_size_per_partition)) | |||||
| self.weight.model_parallel = True | |||||
| if bias: | |||||
| self.bias = Parameter(torch.Tensor(self.output_size)) | |||||
| # Always initialize bias to zero. | |||||
| with torch.no_grad(): | |||||
| self.bias.zero_() | |||||
| else: | |||||
| self.register_parameter('bias', None) | |||||
| # Initialize weight. | |||||
| self.master_weight = _initialize_affine_weight( | |||||
| self.weight, | |||||
| self.output_size, | |||||
| self.input_size, | |||||
| self.input_size_per_partition, | |||||
| 1, | |||||
| init_method, | |||||
| stride=stride, | |||||
| return_master_weight=keep_master_weight_for_test) | |||||
| def forward(self, input_): | |||||
| # Set up backprop all-reduce. | |||||
| if self.input_is_parallel: | |||||
| input_parallel = input_ | |||||
| else: | |||||
| input_parallel = scatter_to_model_parallel_region(input_) | |||||
| # Matrix multiply. | |||||
| output_parallel = F.linear(input_parallel, self.weight) | |||||
| # All-reduce across all the partitions. | |||||
| output_ = reduce_from_model_parallel_region(output_parallel) | |||||
| if self.bias is not None: | |||||
| output = output_ + self.bias | |||||
| else: | |||||
| output = output_ | |||||
| return output | |||||
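These two layers are typically composed back to back (column parallel first with `gather_output=False`, then row parallel with `input_is_parallel=True`), so the only communication is the final all-reduce. A single-process check that splitting the first weight by columns and the second by rows, then summing the partial products, reproduces the unpartitioned result (sizes invented; the elementwise activation between the two layers is omitted):

import torch

torch.manual_seed(0)
x = torch.randn(2, 8)
A = torch.randn(8, 6)                     # first linear: Y = X A
B = torch.randn(6, 8)                     # second linear: Z = Y B
full = x @ A @ B

A_cols = torch.chunk(A, 2, dim=1)         # each "rank" keeps some columns of A
B_rows = torch.chunk(B, 2, dim=0)         # and the matching rows of B
partial = sum((x @ a) @ b for a, b in zip(A_cols, B_rows))   # the sum is the all-reduce

print(torch.allclose(full, partial, atol=1e-5))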
| @@ -0,0 +1,144 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import torch | |||||
| from .initialize import get_model_parallel_group | |||||
| from .utils import split_tensor_along_last_dim | |||||
| def _reduce(input_): | |||||
| """All-reduce the the input tensor across model parallel group.""" | |||||
| group = get_model_parallel_group() | |||||
| # Bypass the function if we are using only 1 GPU. | |||||
| if torch.distributed.get_world_size(group=group) == 1: | |||||
| return input_ | |||||
| # All-reduce. | |||||
| torch.distributed.all_reduce(input_, group=group) | |||||
| return input_ | |||||
| def _split(input_): | |||||
| """Split the tensor along its last dimension and keep the | |||||
| corresponding slice.""" | |||||
| group = get_model_parallel_group() | |||||
| # Bypass the function if we are using only 1 GPU. | |||||
| if torch.distributed.get_world_size(group=group) == 1: | |||||
| return input_ | |||||
| # Split along last dimension. | |||||
| world_size = torch.distributed.get_world_size(group=group) | |||||
| input_list = split_tensor_along_last_dim(input_, world_size) | |||||
| # Note: torch.split does not create contiguous tensors by default. | |||||
| rank = torch.distributed.get_rank(group=group) | |||||
| output = input_list[rank].contiguous() | |||||
| return output | |||||
| def _gather(input_): | |||||
| """Gather tensors and concatinate along the last dimension.""" | |||||
| group = get_model_parallel_group() | |||||
| # Bypass the function if we are using only 1 GPU. | |||||
| if torch.distributed.get_world_size(group=group) == 1: | |||||
| return input_ | |||||
| # Size and dimension. | |||||
| last_dim = input_.dim() - 1 | |||||
| rank = torch.distributed.get_rank(group=group) | |||||
| world_size = torch.distributed.get_world_size(group=group) | |||||
| tensor_list = [torch.empty_like(input_) for _ in range(world_size)] | |||||
| tensor_list[rank] = input_ | |||||
| torch.distributed.all_gather(tensor_list, input_, group=group) | |||||
| # Note: torch.cat already creates a contiguous tensor. | |||||
| output = torch.cat(tensor_list, dim=last_dim).contiguous() | |||||
| return output | |||||
| class _CopyToModelParallelRegion(torch.autograd.Function): | |||||
| """Pass the input to the model parallel region.""" | |||||
| @staticmethod | |||||
| def forward(ctx, input_): | |||||
| return input_ | |||||
| @staticmethod | |||||
| def backward(ctx, grad_output): | |||||
| return _reduce(grad_output) | |||||
| class _ReduceFromModelParallelRegion(torch.autograd.Function): | |||||
| """All-redcue the input from the model parallel region.""" | |||||
| @staticmethod | |||||
| def forward(ctx, input_): | |||||
| return _reduce(input_) | |||||
| @staticmethod | |||||
| def backward(ctx, grad_output): | |||||
| return grad_output | |||||
| class _ScatterToModelParallelRegion(torch.autograd.Function): | |||||
| """Split the input and keep only the corresponding chuck to the rank.""" | |||||
| @staticmethod | |||||
| def forward(ctx, input_): | |||||
| return _split(input_) | |||||
| @staticmethod | |||||
| def backward(ctx, grad_output): | |||||
| return _gather(grad_output) | |||||
| class _GatherFromModelParallelRegion(torch.autograd.Function): | |||||
| """Gather the input from model parallel region and concatinate.""" | |||||
| @staticmethod | |||||
| def forward(ctx, input_): | |||||
| return _gather(input_) | |||||
| @staticmethod | |||||
| def backward(ctx, grad_output): | |||||
| return _split(grad_output) | |||||
| # ----------------- | |||||
| # Helper functions. | |||||
| # ----------------- | |||||
| def copy_to_model_parallel_region(input_): | |||||
| return _CopyToModelParallelRegion.apply(input_) | |||||
| def reduce_from_model_parallel_region(input_): | |||||
| return _ReduceFromModelParallelRegion.apply(input_) | |||||
| def scatter_to_model_parallel_region(input_): | |||||
| return _ScatterToModelParallelRegion.apply(input_) | |||||
| def gather_from_model_parallel_region(input_): | |||||
| return _GatherFromModelParallelRegion.apply(input_) | |||||
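The scatter and gather helpers both operate on the last dimension; a single-process sketch of that slicing (world size and rank are made up, and the real `_gather` uses `all_gather` instead of a local concatenation):

import torch

world_size, rank = 4, 2
x = torch.arange(2 * 8, dtype=torch.float32).view(2, 8)

chunks = torch.split(x, x.size(-1) // world_size, dim=-1)
mine = chunks[rank].contiguous()          # what _split keeps on this rank
restored = torch.cat(chunks, dim=-1)      # what _gather reassembles across ranks

print(mine.shape, torch.equal(restored, x))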
| @@ -0,0 +1,408 @@ | |||||
| # Modified by Samyam Rajbhandari | |||||
| # Used to partition the activations stored for backward propagation | |||||
| # Therefore reduces the memory consumption | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # Parts of the code here are adapted from PyTorch | |||||
| # repo: https://github.com/pytorch/pytorch | |||||
| import contextlib | |||||
| import torch | |||||
| import torch.distributed as dist | |||||
| from torch import _C | |||||
| from torch.cuda import _lazy_call | |||||
| from torch.cuda import device as device_ctx_manager | |||||
| from .initialize import (get_data_parallel_rank, get_model_parallel_group, | |||||
| get_model_parallel_rank, | |||||
| get_model_parallel_world_size) | |||||
| # from torch.utils.checkpoint import detach_variable | |||||
| PARTITION_ACTIVATIONS = False | |||||
| PA_CORRECTNESS_TEST = False | |||||
| def see_memory_usage(message, force=False): | |||||
| if not force: | |||||
| return | |||||
| dist.barrier() | |||||
| if dist.get_rank() == 0: | |||||
| print(message) | |||||
| print('Memory Allocated ', | |||||
| torch.cuda.memory_allocated() / (1024 * 1024 * 1024), | |||||
| 'GigaBytes') | |||||
| print('Max Memory Allocated ', | |||||
| torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), | |||||
| 'GigaBytes') | |||||
| print('Cache Allocated ', | |||||
| torch.cuda.memory_cached() / (1024 * 1024 * 1024), 'GigaBytes') | |||||
| print('Max cache Allocated ', | |||||
| torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), | |||||
| 'GigaBytes') | |||||
| print(' ') | |||||
| # input("Press Any Key To Continue ..") | |||||
| mp_rank = None # get_model_parallel_rank() | |||||
| mp_size = None # get_model_parallel_world_size() | |||||
| mp_group = None # get_model_parallel_group() | |||||
| # Default name for the model parallel rng tracker. | |||||
| _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' | |||||
| transport_stream = None | |||||
| cuda_device = None | |||||
| def detach_variable(inputs, device=None): | |||||
| if isinstance(inputs, tuple): | |||||
| out = [] | |||||
| for inp in inputs: | |||||
| if not isinstance(inp, torch.Tensor): | |||||
| out.append(inp) | |||||
| continue | |||||
| requires_grad = inp.requires_grad | |||||
| if device is not None: | |||||
| x = inp.to(device=device) | |||||
| else: | |||||
| x = inp | |||||
| x = x.detach() | |||||
| x.requires_grad = requires_grad | |||||
| out.append(x) | |||||
| return tuple(out) | |||||
| else: | |||||
| raise RuntimeError( | |||||
| 'Only tuple of tensors is supported. Got Unsupported input type: ', | |||||
| type(inputs).__name__) | |||||
| def _set_cuda_rng_state(new_state, device=-1): | |||||
| """Sets the random number generator state of the current GPU. | |||||
| Arguments: | |||||
| new_state (torch.ByteTensor): The desired state | |||||
| This function is adapted from PyTorch repo (torch.cuda.set_rng_state) | |||||
| with a single change: the input state is not cloned. Cloning caused | |||||
| major performance issues for 4+ GPU cases. | |||||
| """ | |||||
| if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): | |||||
| # older PyTorch | |||||
| def cb(): | |||||
| with device_ctx_manager(device): | |||||
| _C._cuda_setRNGState(new_state) | |||||
| else: | |||||
| # newer PyTorch | |||||
| if device == -1: | |||||
| device = torch.device('cuda') | |||||
| elif isinstance(device, str): | |||||
| device = torch.device(device) | |||||
| elif isinstance(device, int): | |||||
| device = torch.device('cuda', device) | |||||
| def cb(): | |||||
| idx = device.index | |||||
| if idx is None: | |||||
| idx = torch.cuda.current_device() | |||||
| default_generator = torch.cuda.default_generators[idx] | |||||
| default_generator.set_state(new_state) | |||||
| _lazy_call(cb) | |||||
| class CudaRNGStatesTracker: | |||||
| """Tracker for the cuda RNG states. | |||||
| Using the `add` method, a cuda rng state is initialized based on | |||||
| the input `seed` and is assigned to `name`. Later, by forking the | |||||
| rng state, we can perform operations and return to our starting | |||||
| cuda state. | |||||
| """ | |||||
| def __init__(self): | |||||
| # Map from a string name to the cuda rng state. | |||||
| self.states_ = {} | |||||
| # Seeds are just for bookkeeping and ensure no seed is set twice. | |||||
| self.seeds_ = set() | |||||
| def reset(self): | |||||
| """Set to the initial state (no tracker).""" | |||||
| self.states_ = {} | |||||
| self.seeds_ = set() | |||||
| def get_states(self): | |||||
| """Get rng states. Copy the dictionary so we have direct | |||||
| pointers to the states, not just a pointer to the dictionary.""" | |||||
| states = {} | |||||
| for name in self.states_: | |||||
| states[name] = self.states_[name] | |||||
| return states | |||||
| def set_states(self, states): | |||||
| """Set the rng states. For efficiency purposes, we do not check | |||||
| the size of seed for compatibility.""" | |||||
| self.states_ = states | |||||
| def add(self, name, seed): | |||||
| """Track the rng state.""" | |||||
| # Check seed is not already used. | |||||
| if seed in self.seeds_: | |||||
| raise Exception('seed {} already exists'.format(seed)) | |||||
| self.seeds_.add(seed) | |||||
| # Check that state is not already defined. | |||||
| if name in self.states_: | |||||
| raise Exception('cuda rng state {} already exists'.format(name)) | |||||
| # Get the current rng state. | |||||
| orig_rng_state = torch.cuda.get_rng_state() | |||||
| # Set the new state and store it. | |||||
| torch.cuda.manual_seed(seed) | |||||
| self.states_[name] = torch.cuda.get_rng_state() | |||||
| # Reset rng state to what it was. | |||||
| _set_cuda_rng_state(orig_rng_state) | |||||
| @contextlib.contextmanager | |||||
| def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): | |||||
| """Fork the cuda rng state, perform operations, and exit with | |||||
| the original state.""" | |||||
| # Check if we have added the state | |||||
| if name not in self.states_: | |||||
| raise Exception('cuda rng state {} is not added'.format(name)) | |||||
| # Store current rng state. | |||||
| orig_cuda_rng_state = torch.cuda.get_rng_state() | |||||
| # Set rng state to the desired one | |||||
| _set_cuda_rng_state(self.states_[name]) | |||||
| # Do the stuff we wanted to do. | |||||
| try: | |||||
| yield | |||||
| finally: | |||||
| # Update the current rng state for later use. | |||||
| self.states_[name] = torch.cuda.get_rng_state() | |||||
| # And set the state to the original state we started with. | |||||
| _set_cuda_rng_state(orig_cuda_rng_state) | |||||
| # RNG tracker object. | |||||
| _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() | |||||
| def get_cuda_rng_tracker(): | |||||
| """Get cuda rng tracker.""" | |||||
| return _CUDA_RNG_STATE_TRACKER | |||||
| def model_parallel_cuda_manual_seed(seed): | |||||
| """Initialize model parallel cuda seed. | |||||
| This function should be called after the model parallel is | |||||
| initialized. Also, no torch.cuda.manual_seed should be called | |||||
| after this function. Basically, this is a replacement for that | |||||
| function. | |||||
| Two sets of RNG states are tracked: | |||||
| default state: This is for data parallelism and is the same among a | |||||
| set of model parallel GPUs but different across | |||||
| different model parallel groups. This is used for | |||||
| example for dropout in the non-model-parallel regions. | |||||
| model-parallel state: This state is different among a set of model | |||||
| parallel GPUs, but the same across data parallel | |||||
| groups. This is used for example for dropout in | |||||
| model parallel regions. | |||||
| """ | |||||
| # 2718 is just for fun and any POSITIVE value will work. | |||||
| offset = seed + 2718 | |||||
| model_parallel_seed = offset + get_model_parallel_rank() | |||||
| # Data parallel gets the original seed. | |||||
| data_parallel_seed = seed | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print( | |||||
| '> initializing model parallel cuda seeds on global rank {}, ' | |||||
| 'model parallel rank {}, and data parallel rank {} with ' | |||||
| 'model parallel seed: {} and data parallel seed: {}'.format( | |||||
| torch.distributed.get_rank(), get_model_parallel_rank(), | |||||
| get_data_parallel_rank(), model_parallel_seed, | |||||
| data_parallel_seed), | |||||
| flush=True) | |||||
| _CUDA_RNG_STATE_TRACKER.reset() | |||||
| # Set the default state. | |||||
| torch.cuda.manual_seed(data_parallel_seed) | |||||
| # and model parallel state. | |||||
| _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, | |||||
| model_parallel_seed) | |||||
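The net effect of the seeding scheme above: every model parallel rank draws its dropout masks from a distinct offset seed, while all ranks keep the original seed for the data parallel (default) state. A pure-Python sketch with a hypothetical base seed:

seed, model_parallel_world_size = 1234, 2

for model_parallel_rank in range(model_parallel_world_size):
    model_parallel_seed = seed + 2718 + model_parallel_rank
    data_parallel_seed = seed
    print(model_parallel_rank, model_parallel_seed, data_parallel_seed)
# 0 3952 1234
# 1 3953 1234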
| def get_partition_start(item): | |||||
| global mp_rank, mp_size, mp_group | |||||
| partition_size = get_partition_size(item) | |||||
| start = partition_size * mp_rank | |||||
| return int(start) | |||||
| def get_partition_size(item): | |||||
| global mp_rank, mp_size, mp_group | |||||
| size = item.numel() | |||||
| partition_size = size / mp_size | |||||
| return int(partition_size) | |||||
| def get_full_inputs(tensors): | |||||
| inputs = [] | |||||
| for i in range(int(len(tensors) / 2) - 1): | |||||
| item = tensors[2 * i] | |||||
| size = tensors[2 * i + 1] | |||||
| partition_size = item.numel() | |||||
| tensor_size = partition_size * mp_size | |||||
| flat_tensor = torch.zeros([tensor_size], | |||||
| dtype=item.dtype, | |||||
| device=item.device) | |||||
| partitions = [] | |||||
| for i in range(mp_size): | |||||
| part_i = flat_tensor.narrow(0, partition_size * i, partition_size) | |||||
| if i == mp_rank: | |||||
| part_i.copy_(item) | |||||
| partitions.append(part_i) | |||||
| dist.all_gather(partitions, partitions[mp_rank], group=mp_group) | |||||
| input_tensor = flat_tensor.view(list(size.numpy())) | |||||
| item.data = input_tensor.data | |||||
| inputs.append(item) | |||||
| inputs.append(tensors[-2]) | |||||
| return tuple(inputs) | |||||
| class CheckpointFunction(torch.autograd.Function): | |||||
| """This function is adapted from torch.utils.checkpoint with | |||||
| two main changes: | |||||
| 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` | |||||
| 2) the states in the model parallel tracker are also properly | |||||
| tracked/set/reset. | |||||
| """ | |||||
| @staticmethod | |||||
| def forward(ctx, run_function, *args): | |||||
| ctx.run_function = run_function | |||||
| global mp_rank, mp_size, mp_group | |||||
| if mp_rank is None: | |||||
| mp_rank = get_model_parallel_rank() | |||||
| mp_size = get_model_parallel_world_size() | |||||
| mp_group = get_model_parallel_group() | |||||
| global cuda_device, transport_stream, PARTITION_ACTIVATIONS | |||||
| if cuda_device is None: | |||||
| if dist.get_rank() == 0: | |||||
| print( | |||||
| f'Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}' | |||||
| ) | |||||
| cuda_device = torch.cuda.current_device() | |||||
| # The transport stream is used to overlap the allgather communication for the activations | |||||
| # with the computation in the backward pass | |||||
| transport_stream = torch.cuda.Stream(device=cuda_device) | |||||
| if PARTITION_ACTIVATIONS: | |||||
| inputs = [ | |||||
| item.detach().contiguous().view(-1).narrow( | |||||
| 0, get_partition_start(item), | |||||
| get_partition_size(item)).clone() for item in args[:-1] | |||||
| ] | |||||
| inputs.append(args[-1]) | |||||
| # just in case something funky is happening such as reuse of inputs | |||||
| inputs_cuda = [item.to(cuda_device) for item in args] | |||||
| # Copy the rng states. | |||||
| ctx.fwd_cpu_rng_state = torch.get_rng_state() | |||||
| ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() | |||||
| ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() | |||||
| # ctx.save_for_backward(*args) | |||||
| with torch.no_grad(): | |||||
| outputs = run_function(*inputs_cuda) | |||||
| del inputs_cuda | |||||
| if PARTITION_ACTIVATIONS: | |||||
| new_args = [] | |||||
| for arg, inp in zip(args, inputs): | |||||
| size = torch.tensor(arg.size()) | |||||
| arg.data = inp.data | |||||
| new_args.append(arg) | |||||
| new_args.append(size) | |||||
| ctx.save_for_backward(*new_args) | |||||
| else: | |||||
| ctx.save_for_backward(*args) | |||||
| return outputs | |||||
| @staticmethod | |||||
| def backward(ctx, *args): | |||||
| if not torch.autograd._is_checkpoint_valid(): | |||||
| raise RuntimeError('Checkpointing is not compatible with .grad(), ' | |||||
| 'please use .backward() if possible') | |||||
| global cuda_device, transport_stream, PARTITION_ACTIVATIONS | |||||
| if PARTITION_ACTIVATIONS: | |||||
| with torch.cuda.stream(transport_stream): | |||||
| inputs = get_full_inputs(ctx.saved_tensors) | |||||
| detached_inputs = detach_variable(inputs) | |||||
| else: | |||||
| inputs = ctx.saved_tensors | |||||
| detached_inputs = detach_variable(inputs) | |||||
| # Store the current states. | |||||
| bwd_cpu_rng_state = torch.get_rng_state() | |||||
| bwd_cuda_rng_state = torch.cuda.get_rng_state() | |||||
| bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() | |||||
| # Set the states to what it used to be before the forward pass. | |||||
| torch.set_rng_state(ctx.fwd_cpu_rng_state) | |||||
| _set_cuda_rng_state(ctx.fwd_cuda_rng_state) | |||||
| get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) | |||||
| if PARTITION_ACTIVATIONS: | |||||
| current_stream = torch.cuda.current_stream() | |||||
| current_stream.wait_stream(transport_stream) | |||||
| with torch.enable_grad(): | |||||
| outputs = ctx.run_function(*detached_inputs) | |||||
| # Set the states back to what it was at the start of this function. | |||||
| torch.set_rng_state(bwd_cpu_rng_state) | |||||
| _set_cuda_rng_state(bwd_cuda_rng_state) | |||||
| get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) | |||||
| if isinstance(outputs, torch.Tensor): | |||||
| outputs = (outputs, ) | |||||
| torch.autograd.backward(outputs, args) | |||||
| return (None, ) + tuple(inp.grad for inp in detached_inputs) | |||||
| def checkpoint(function, *args): | |||||
| """Checkpoint a model or part of the model. | |||||
| This has been directly copied from torch.utils.checkpoint.""" | |||||
| return CheckpointFunction.apply(function, *args) | |||||
| def partition_activations_in_checkpoint(partition_activation): | |||||
| global PARTITION_ACTIVATIONS | |||||
| PARTITION_ACTIVATIONS = partition_activation | |||||
| if dist.get_rank() == 0: | |||||
| print( | |||||
| f'**************Partition Activations {PARTITION_ACTIVATIONS}************' | |||||
| ) | |||||
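A minimal sketch of wrapping a sub-module with the checkpoint() entry point above (it assumes torch.distributed and the mpu model-parallel groups are initialized, since checkpoint() queries the model-parallel rank on first use; the layer and shapes are illustrative):

import torch

layer = torch.nn.Linear(16, 16).cuda()

def run_layer(x):
    return layer(x)

# Keep full activations on every rank; passing True would instead shard the
# saved activations across the model-parallel group.
partition_activations_in_checkpoint(False)

x = torch.randn(4, 16, device='cuda', requires_grad=True)
y = checkpoint(run_layer, x)  # forward runs under torch.no_grad()
y.sum().backward()            # forward is recomputed here with the restored RNG state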
| @@ -0,0 +1,86 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import argparse | |||||
| import os | |||||
| import random | |||||
| import mpu | |||||
| import numpy | |||||
| import torch | |||||
| class IdentityLayer(torch.nn.Module): | |||||
| def __init__(self, size, scale=1.0): | |||||
| super(IdentityLayer, self).__init__() | |||||
| self.weight = torch.nn.Parameter(scale * torch.randn(size)) | |||||
| def forward(self): | |||||
| return self.weight | |||||
| def set_random_seed(seed): | |||||
| """Set random seed for reproducability.""" | |||||
| random.seed(seed) | |||||
| numpy.random.seed(seed) | |||||
| torch.manual_seed(seed) | |||||
| mpu.model_parallel_cuda_manual_seed(seed) | |||||
| def initialize_distributed(backend='nccl'): | |||||
| """Initialize torch.distributed.""" | |||||
| # Get local rank in case it is provided. | |||||
| parser = argparse.ArgumentParser() | |||||
| parser.add_argument( | |||||
| '--local_rank', | |||||
| type=int, | |||||
| default=None, | |||||
| help='local rank passed from distributed launcher') | |||||
| args = parser.parse_args() | |||||
| local_rank = args.local_rank | |||||
| # Get rank and world size. | |||||
| rank = int(os.getenv('RANK', '0')) | |||||
| world_size = int(os.getenv('WORLD_SIZE', '1')) | |||||
| print('> initializing torch.distributed with local rank: {}, ' | |||||
| 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) | |||||
| # Set the device id. | |||||
| device = rank % torch.cuda.device_count() | |||||
| if local_rank is not None: | |||||
| device = local_rank | |||||
| torch.cuda.set_device(device) | |||||
| # Call the init process. | |||||
| init_method = 'tcp://' | |||||
| master_ip = os.getenv('MASTER_ADDR', 'localhost') | |||||
| master_port = os.getenv('MASTER_PORT', '6000') | |||||
| init_method += master_ip + ':' + master_port | |||||
| torch.distributed.init_process_group( | |||||
| backend=backend, | |||||
| world_size=world_size, | |||||
| rank=rank, | |||||
| init_method=init_method) | |||||
| def print_separator(message): | |||||
| torch.distributed.barrier() | |||||
| filler_len = (78 - len(message)) // 2 | |||||
| filler = '-' * filler_len | |||||
| string = '\n' + filler + ' {} '.format(message) + filler | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print(string, flush=True) | |||||
| torch.distributed.barrier() | |||||
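These helpers read the usual torch.distributed environment variables; a single-process sketch of driving them directly (the port and seed values are illustrative, and mpu is the model-parallel package imported by the test scripts):

import os
import mpu

os.environ.setdefault('RANK', '0')
os.environ.setdefault('WORLD_SIZE', '1')
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '6000')

initialize_distributed(backend='nccl')  # parses --local_rank, then init_process_group
mpu.initialize_model_parallel(1)        # needed before the model-parallel CUDA seed is set
set_random_seed(1234)                   # python / numpy / torch plus mpu's CUDA seed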
| @@ -0,0 +1,106 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import random | |||||
| import sys | |||||
| import mpu | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from commons import (IdentityLayer, initialize_distributed, print_separator, | |||||
| set_random_seed) | |||||
| from mpu.cross_entropy import vocab_parallel_cross_entropy | |||||
| sys.path.append('../..') | |||||
| def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, | |||||
| seed): | |||||
| set_random_seed(seed) | |||||
| identity = IdentityLayer((batch_size, seq_length, vocab_size), | |||||
| scale=logits_scale).cuda() | |||||
| logits = identity() | |||||
| target = torch.cuda.LongTensor(size=(batch_size, | |||||
| seq_length)).random_(0, vocab_size) | |||||
| loss = F.cross_entropy( | |||||
| logits.view(-1, | |||||
| logits.size()[-1]), target.view(-1), | |||||
| reduction='none').view_as(target).mean() | |||||
| loss.backward() | |||||
| return loss, identity.weight.grad | |||||
| def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed): | |||||
| set_random_seed(seed) | |||||
| identity = IdentityLayer((batch_size, seq_length, vocab_size), | |||||
| scale=logits_scale).cuda() | |||||
| logits = identity() | |||||
| logits_parallel = mpu.scatter_to_model_parallel_region(logits) | |||||
| target = torch.cuda.LongTensor(size=(batch_size, | |||||
| seq_length)).random_(0, vocab_size) | |||||
| loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() | |||||
| loss.backward() | |||||
| return loss, identity.weight.grad | |||||
| def test_cross_entropy(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing cross entropy with model parallel size {} ...'.format( | |||||
| model_parallel_size)) | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| batch_size = 13 | |||||
| seq_length = 17 | |||||
| vocab_size_per_partition = 11 | |||||
| logits_scale = 1000.0 | |||||
| vocab_size = vocab_size_per_partition * model_parallel_size | |||||
| seed = 1234 | |||||
| loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, | |||||
| vocab_size, logits_scale, | |||||
| seed) | |||||
| loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size, | |||||
| logits_scale, seed) | |||||
| error = loss_torch.sub_(loss_mpu).abs().max() | |||||
| print(' max error in loss on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| error = grad_torch.sub_(grad_mpu).abs().max() | |||||
| print(' max error in grad on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('>> passed the test :-)') | |||||
| if __name__ == '__main__': | |||||
| initialize_distributed() | |||||
| world_size = torch.distributed.get_world_size() | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| print_separator('test cross entropy') | |||||
| test_cross_entropy(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| @@ -0,0 +1,91 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import functools | |||||
| import operator | |||||
| import sys | |||||
| import mpu | |||||
| import torch | |||||
| from commons import initialize_distributed, print_separator | |||||
| from mpu import data as data_utils | |||||
| sys.path.append('../..') | |||||
| def test_broadcast_data(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print( | |||||
| '> testing broadcast_data with model parallel size {} ...'.format( | |||||
| model_parallel_size)) | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| torch.manual_seed(1234 + mpu.get_data_parallel_rank()) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| key_size_t = { | |||||
| 'key1': [7, 11], | |||||
| 'key2': [8, 2, 1], | |||||
| 'key3': [13], | |||||
| 'key4': [5, 1, 2], | |||||
| 'key5': [5, 12] | |||||
| } | |||||
| keys = list(key_size_t.keys()) | |||||
| data = {} | |||||
| data_t = {} | |||||
| for key in key_size_t: | |||||
| data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) | |||||
| data_t[key] = data[key].clone() | |||||
| data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) | |||||
| data_t['keyX'] = data['keyX'].clone() | |||||
| if mpu.get_model_parallel_rank() != 0: | |||||
| data = None | |||||
| data_utils._check_data_types(keys, data_t, torch.int64) | |||||
| key_size, key_numel, \ | |||||
| total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) | |||||
| for key in keys: | |||||
| assert key_size[key] == key_size_t[key] | |||||
| total_numel_t = 0 | |||||
| for key in keys: | |||||
| target_size = functools.reduce(operator.mul, key_size_t[key], 1) | |||||
| assert key_numel[key] == target_size | |||||
| total_numel_t += target_size | |||||
| assert total_numel == total_numel_t | |||||
| data_b = data_utils.broadcast_data(keys, data, torch.int64) | |||||
| for key in keys: | |||||
| tensor = data_t[key].cuda() | |||||
| assert data_b[key].sub(tensor).abs().max() == 0 | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('>> passed the test :-)') | |||||
| if __name__ == '__main__': | |||||
| initialize_distributed() | |||||
| world_size = torch.distributed.get_world_size() | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| print_separator('test broadcast data') | |||||
| test_broadcast_data(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| @@ -0,0 +1,95 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import sys | |||||
| import mpu | |||||
| import torch | |||||
| from commons import initialize_distributed, print_separator | |||||
| sys.path.append('../..') | |||||
| def test_initialize_model_parallel(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing initialize_model_parallel with size {} ...'.format( | |||||
| model_parallel_size)) | |||||
| model_parallel_size_ = min(model_parallel_size, | |||||
| torch.distributed.get_world_size()) | |||||
| assert not mpu.model_parallel_is_initialized() | |||||
| mpu.initialize_model_parallel(model_parallel_size_) | |||||
| assert mpu.model_parallel_is_initialized() | |||||
| # Checks. | |||||
| def check(group, world_size, rank): | |||||
| assert world_size == torch.distributed.get_world_size(group=group) | |||||
| assert rank == torch.distributed.get_rank(group=group) | |||||
| # Model parallel. | |||||
| world_size = model_parallel_size_ | |||||
| rank = torch.distributed.get_rank() % model_parallel_size_ | |||||
| assert world_size == mpu.get_model_parallel_world_size() | |||||
| assert rank == mpu.get_model_parallel_rank() | |||||
| check(mpu.get_model_parallel_group(), world_size, rank) | |||||
| # Data parallel. | |||||
| world_size = torch.distributed.get_world_size() // model_parallel_size_ | |||||
| rank = torch.distributed.get_rank() // model_parallel_size_ | |||||
| assert world_size == mpu.get_data_parallel_world_size() | |||||
| assert rank == mpu.get_data_parallel_rank() | |||||
| check(mpu.get_data_parallel_group(), world_size, rank) | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('>> passed the test :-)') | |||||
| def test_get_model_parallel_src_rank(model_parallel_size_): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing get_model_parallel_src_rank with size {} ...'.format( | |||||
| model_parallel_size_)) | |||||
| model_parallel_size = min(model_parallel_size_, | |||||
| torch.distributed.get_world_size()) | |||||
| assert not mpu.model_parallel_is_initialized() | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| assert mpu.model_parallel_is_initialized() | |||||
| # Checks | |||||
| src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank() | |||||
| assert mpu.get_model_parallel_src_rank() == src_rank | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('>> passed the test :-)') | |||||
| if __name__ == '__main__': | |||||
| initialize_distributed() | |||||
| world_size = torch.distributed.get_world_size() | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| print_separator('test initialize model parallel') | |||||
| test_initialize_model_parallel(model_parallel_size) | |||||
| print_separator('test model parallel source rank') | |||||
| test_get_model_parallel_src_rank(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| @@ -0,0 +1,533 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import random | |||||
| import sys | |||||
| import mpu | |||||
| import torch | |||||
| import torch.nn.init as init | |||||
| from commons import initialize_distributed, print_separator, set_random_seed | |||||
| from mpu import layers | |||||
| from torch.nn.parameter import Parameter | |||||
| sys.path.append('../..') | |||||
| def test_parallel_embedding(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing parallel embedding with model parallel size {} ...'. | |||||
| format(model_parallel_size)) | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| batch_size = 17 | |||||
| seq_length = 23 | |||||
| vocab_size = 48 | |||||
| hidden_size = 16 | |||||
| seed = 1236 | |||||
| set_random_seed(123) | |||||
| input_data = torch.LongTensor(size=(batch_size, seq_length)).random_( | |||||
| 0, vocab_size).cuda() | |||||
| loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() | |||||
| set_random_seed(seed) | |||||
| embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() | |||||
| output = embedding_original(input_data) | |||||
| loss_original = torch.mul(output, loss_weight).sum() | |||||
| loss_original.backward() | |||||
| set_random_seed(seed) | |||||
| embedding_parallel = layers.ParallelEmbedding( | |||||
| vocab_size, hidden_size, init_method=init.normal_).cuda() | |||||
| output = embedding_parallel(input_data) | |||||
| loss_parallel = torch.mul(output, loss_weight).sum() | |||||
| loss_parallel.backward() | |||||
| set_random_seed(seed) | |||||
| embedding_vocab_parallel = layers.VocabParallelEmbedding( | |||||
| vocab_size, hidden_size, init_method=init.normal_).cuda() | |||||
| output = embedding_vocab_parallel(input_data) | |||||
| loss_vocab_parallel = torch.mul(output, loss_weight).sum() | |||||
| loss_vocab_parallel.backward() | |||||
| torch.distributed.barrier() | |||||
| error = loss_parallel.sub(loss_original).abs() | |||||
| print(' error in loss (parallel) on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-12, 'error: {}'.format(error) | |||||
| torch.distributed.barrier() | |||||
| error = loss_vocab_parallel.sub(loss_original).abs() | |||||
| print(' error in loss (vocab parallel) on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-12, 'error: {}'.format(error) | |||||
| weight_grad_orig = torch.split(embedding_original.weight.grad, | |||||
| hidden_size // model_parallel_size, | |||||
| 1)[mpu.get_model_parallel_rank()] | |||||
| error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() | |||||
| print(' error in grad (parallel) on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-12, 'error: {}'.format(error) | |||||
| weight_grad_orig = torch.split(embedding_original.weight.grad, | |||||
| vocab_size // model_parallel_size, | |||||
| 0)[mpu.get_model_parallel_rank()] | |||||
| error = embedding_vocab_parallel.weight.grad.sub( | |||||
| weight_grad_orig).abs().max() | |||||
| print(' error in grad (vocab parallel) on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-12, 'error: {}'.format(error) | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('>> passed the test :-)') | |||||
| def test_initialize_affine_weight(model_parallel_size): | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing initialize_affine_weight with model parallel ' | |||||
| 'size: {}'.format(model_parallel_size)) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| seed = 12345 | |||||
| input_size_coeff = 13 | |||||
| input_size = input_size_coeff * model_parallel_size | |||||
| output_size_coeff = 17 | |||||
| output_size = output_size_coeff * model_parallel_size | |||||
| # --------------- | |||||
| # Column parallel | |||||
| # --------------- | |||||
| weight = torch.empty(output_size_coeff, input_size) | |||||
| set_random_seed(seed) | |||||
| layers._initialize_affine_weight(weight, output_size, input_size, | |||||
| output_size_coeff, 0, | |||||
| torch.nn.init.normal_) | |||||
| # Target. | |||||
| set_random_seed(seed) | |||||
| master_weight = torch.empty(output_size, input_size) | |||||
| torch.nn.init.normal_(master_weight) | |||||
| rank = mpu.get_model_parallel_rank() | |||||
| my_weight = torch.split( | |||||
| master_weight, output_size_coeff, dim=0)[rank].contiguous().clone() | |||||
| # Compare. | |||||
| error = weight.sub(my_weight).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' column parallel max error (should be zero) on global rank ' | |||||
| '{}: {}'.format(torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| # ------------ | |||||
| # Row parallel | |||||
| # ------------ | |||||
| weight = torch.empty(output_size, input_size_coeff) | |||||
| set_random_seed(seed) | |||||
| mpu.layers._initialize_affine_weight(weight, output_size, input_size, | |||||
| input_size_coeff, 1, | |||||
| torch.nn.init.normal_) | |||||
| # Target. | |||||
| set_random_seed(seed) | |||||
| master_weight = torch.empty(output_size, input_size) | |||||
| torch.nn.init.normal_(master_weight) | |||||
| rank = mpu.get_model_parallel_rank() | |||||
| my_weight = torch.split( | |||||
| master_weight, input_size_coeff, dim=1)[rank].contiguous().clone() | |||||
| # Compare. | |||||
| error = weight.sub(my_weight).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' row parallel max error (should be zero) on global rank ' | |||||
| '{}: {}'.format(torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print(' >> passed the test :-)') | |||||
| class IdentityLayer2D(torch.nn.Module): | |||||
| def __init__(self, m, n): | |||||
| super(IdentityLayer2D, self).__init__() | |||||
| self.weight = Parameter(torch.Tensor(m, n)) | |||||
| torch.nn.init.xavier_normal_(self.weight) | |||||
| def forward(self): | |||||
| return self.weight | |||||
| def test_column_parallel_linear(model_parallel_size): | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing ColumnParallelLinear with model parallel ' | |||||
| 'size: {}'.format(model_parallel_size)) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| seed = 12345 | |||||
| set_random_seed(seed) | |||||
| input_size_coeff = 13 | |||||
| input_size = input_size_coeff * model_parallel_size | |||||
| output_size_coeff = 17 | |||||
| output_size = output_size_coeff * model_parallel_size | |||||
| batch_size = 7 | |||||
| # Network | |||||
| identity_layer = IdentityLayer2D(batch_size, input_size).cuda() | |||||
| linear_layer = mpu.ColumnParallelLinear( | |||||
| input_size, output_size, keep_master_weight_for_test=True).cuda() | |||||
| loss_weight = torch.randn([batch_size, output_size]).cuda() | |||||
| # Forward | |||||
| input_ = identity_layer() | |||||
| output = linear_layer(input_) | |||||
| loss = torch.mul(output, loss_weight).sum() | |||||
| # Backward | |||||
| loss.backward() | |||||
| # Values. | |||||
| dLdY = loss_weight | |||||
| X = identity_layer.weight | |||||
| A = linear_layer.master_weight.cuda() | |||||
| dLdA = torch.matmul(dLdY.t(), X) | |||||
| dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) | |||||
| dLdX = torch.matmul(dLdY, A) | |||||
| rank = mpu.get_model_parallel_rank() | |||||
| my_dLdA = torch.split( | |||||
| dLdA, output_size_coeff, dim=0)[rank].contiguous().clone() | |||||
| error = my_dLdA.sub(linear_layer.weight.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' error in dLdA on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| my_dLdb = torch.split( | |||||
| dLdb, output_size_coeff, dim=0)[rank].contiguous().clone() | |||||
| error = my_dLdb.sub(linear_layer.bias.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' error in dLdb on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| error = dLdX.sub(identity_layer.weight.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' error in dLdX on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print(' >> passed the test :-)') | |||||
| def test_row_parallel_linear(model_parallel_size): | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing RowParallelLinear with model parallel ' | |||||
| 'size: {}'.format(model_parallel_size)) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| seed = 12345 | |||||
| set_random_seed(seed) | |||||
| input_size_coeff = 13 | |||||
| input_size = input_size_coeff * model_parallel_size | |||||
| output_size_coeff = 17 | |||||
| output_size = output_size_coeff * model_parallel_size | |||||
| batch_size = 7 | |||||
| # Network | |||||
| identity_layer = IdentityLayer2D(batch_size, input_size).cuda() | |||||
| linear_layer = mpu.RowParallelLinear( | |||||
| input_size, output_size, keep_master_weight_for_test=True).cuda() | |||||
| loss_weight = torch.randn([batch_size, output_size]).cuda() | |||||
| # Forward | |||||
| input_ = identity_layer() | |||||
| output = linear_layer(input_) | |||||
| loss = torch.mul(output, loss_weight).sum() | |||||
| # Backward | |||||
| loss.backward() | |||||
| # Values. | |||||
| dLdY = loss_weight | |||||
| X = identity_layer.weight | |||||
| A = linear_layer.master_weight.cuda() | |||||
| dLdA = torch.matmul(dLdY.t(), X) | |||||
| dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) | |||||
| dLdX = torch.matmul(dLdY, A) | |||||
| rank = mpu.get_model_parallel_rank() | |||||
| my_dLdA = torch.split( | |||||
| dLdA, input_size_coeff, dim=1)[rank].contiguous().clone() | |||||
| error = my_dLdA.sub(linear_layer.weight.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' error in dLdA on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| error = dLdb.sub(linear_layer.bias.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' error in dLdb on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| error = dLdX.sub(identity_layer.weight.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' error in dLdX on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print(' >> passed the test :-)') | |||||
| class IdentityLayer3D(torch.nn.Module): | |||||
| def __init__(self, m, n, k): | |||||
| super(IdentityLayer3D, self).__init__() | |||||
| self.weight = Parameter(torch.Tensor(m, n, k)) | |||||
| torch.nn.init.xavier_normal_(self.weight) | |||||
| def forward(self): | |||||
| return self.weight | |||||
| def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, | |||||
| hidden_size_per_att_head, dropout_prob, batch_size, | |||||
| sequence_length): | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| seed = 12345 | |||||
| set_random_seed(seed) | |||||
| num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size( | |||||
| ) # noqa | |||||
| hidden_size = hidden_size_per_att_head * num_att_heads | |||||
| # Network | |||||
| identity_layer = IdentityLayer3D(batch_size, sequence_length, | |||||
| hidden_size).cuda() | |||||
| attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads, | |||||
| dropout_prob).cuda() | |||||
| loss_weight = torch.randn([batch_size, sequence_length, | |||||
| hidden_size]).cuda() | |||||
| attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() | |||||
| # Forward | |||||
| input_ = identity_layer() | |||||
| output = attention_layer(input_, attention_mask) | |||||
| loss = torch.mul(output, loss_weight).sum() | |||||
| # Backward | |||||
| loss.backward() | |||||
| rank = mpu.get_model_parallel_rank() | |||||
| mpu.destroy_model_parallel() | |||||
| return rank, hidden_size, model_parallel_size, loss, \ | |||||
| attention_layer, identity_layer | |||||
| def test_parallel_self_attention(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing ParallelSelfAttention with model parallel ' | |||||
| 'size: {}'.format(model_parallel_size)) | |||||
| num_att_heads_per_partition = 3 | |||||
| hidden_size_per_att_head = 7 | |||||
| dropout_prob = 0.0 # has to be zero | |||||
| batch_size = 5 | |||||
| sequence_length = 13 | |||||
| rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ | |||||
| attention_layer_1, identity_layer_1 = parallel_self_attention( | |||||
| 1, num_att_heads_per_partition, | |||||
| hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) | |||||
| rank, hidden_size, model_parallel_size, loss, \ | |||||
| attention_layer, identity_layer = parallel_self_attention( | |||||
| model_parallel_size, num_att_heads_per_partition, | |||||
| hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) | |||||
| assert hidden_size_1 == hidden_size | |||||
| error = loss_1.sub(loss).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' loss error on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 5.0e-6 | |||||
| my_lin_grad_list = torch.split( | |||||
| attention_layer_1.query_key_value.weight.grad, | |||||
| hidden_size // model_parallel_size, 0)[rank::model_parallel_size] | |||||
| my_lin_grad = torch.cat(my_lin_grad_list, dim=0) | |||||
| error = my_lin_grad.sub( | |||||
| attention_layer.query_key_value.weight.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' weight gradient error on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 5.0e-6 | |||||
| error = identity_layer_1.weight.grad.sub( | |||||
| identity_layer.weight.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' input gradient error on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 5.0e-6 | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print(' >> passed the test :-)') | |||||
| def parallel_transformer(model_parallel_size, num_att_heads_per_partition, | |||||
| hidden_size_per_att_head, batch_size, | |||||
| sequence_length): | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| seed = 12345 | |||||
| set_random_seed(seed) | |||||
| num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size( | |||||
| ) | |||||
| hidden_size = hidden_size_per_att_head * num_att_heads | |||||
| intermediate_size = 4 * hidden_size | |||||
| # Network | |||||
| identity_layer = IdentityLayer3D(batch_size, sequence_length, | |||||
| hidden_size).cuda() | |||||
| transformer_layer = mpu.BertParallelTransformerLayer( | |||||
| hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, | |||||
| torch.nn.functional.relu, 1.0e-5).cuda() | |||||
| loss_weight = torch.randn([batch_size, sequence_length, | |||||
| hidden_size]).cuda() | |||||
| attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() | |||||
| # Forward | |||||
| input_ = identity_layer() | |||||
| output = transformer_layer(input_, attention_mask) | |||||
| loss = torch.mul(output, loss_weight).sum() | |||||
| # Backward | |||||
| loss.backward() | |||||
| rank = mpu.get_model_parallel_rank() | |||||
| mpu.destroy_model_parallel() | |||||
| return rank, hidden_size, model_parallel_size, loss, \ | |||||
| transformer_layer, identity_layer | |||||
| def test_parallel_transformer_layer(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing ParallelTransformerLayer with model parallel ' | |||||
| 'size: {}'.format(model_parallel_size)) | |||||
| num_att_heads_per_partition = 3 | |||||
| hidden_size_per_att_head = 7 | |||||
| batch_size = 5 | |||||
| sequence_length = 13 | |||||
| rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ | |||||
| transformer_layer_1, identity_layer_1 = parallel_transformer( | |||||
| 1, num_att_heads_per_partition, | |||||
| hidden_size_per_att_head, batch_size, sequence_length) | |||||
| rank, hidden_size, model_parallel_size, loss, \ | |||||
| transformer_layer, identity_layer = parallel_transformer( | |||||
| model_parallel_size, num_att_heads_per_partition, | |||||
| hidden_size_per_att_head, batch_size, sequence_length) | |||||
| error = loss_1.sub(loss).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' loss error on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 5.0e-5, 'error: {}'.format(error) | |||||
| error = identity_layer_1.weight.grad.sub( | |||||
| identity_layer.weight.grad).abs().max() | |||||
| torch.distributed.barrier() | |||||
| print(' input gradient error on global rank {}: {}'.format( | |||||
| torch.distributed.get_rank(), error)) | |||||
| assert error < 5.0e-5, 'error: {}'.format(error) | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print(' >> passed the test :-)') | |||||
| if __name__ == '__main__': | |||||
| torch.backends.cudnn.deterministic = True | |||||
| torch.backends.cudnn.benchmark = False | |||||
| initialize_distributed() | |||||
| world_size = torch.distributed.get_world_size() | |||||
| print_separator('test initialize affine weight') | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| test_initialize_affine_weight(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| print_separator('test parallel embedding') | |||||
| test_parallel_embedding(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| print_separator('test column-parallel linear') | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| test_column_parallel_linear(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| print_separator('test row-parallel linear') | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| test_row_parallel_linear(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| print_separator('test parallel self-attention') | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| test_parallel_self_attention(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| print_separator('test parallel transformer') | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| test_parallel_transformer_layer(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| @@ -0,0 +1,206 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import sys | |||||
| import mpu | |||||
| import torch | |||||
| from commons import initialize_distributed, print_separator | |||||
| sys.path.append('../..') | |||||
| def test_set_cuda_rng_state(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing set_rng_state with size {} ...'.format( | |||||
| model_parallel_size)) | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| size = 123 | |||||
| seed = 1234 | |||||
| torch.cuda.manual_seed(seed) | |||||
| tensor = torch.cuda.FloatTensor(size) | |||||
| # Get the state | |||||
| rng_state = torch.cuda.get_rng_state() | |||||
| rng_state_copy = rng_state.clone() | |||||
| # Do some stuff. | |||||
| for _ in range(5): | |||||
| torch.randn(size, out=tensor) | |||||
| result_1 = tensor.clone() | |||||
| assert rng_state.sub(rng_state_copy).max() == 0 | |||||
| assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 | |||||
| # State should be different. | |||||
| new_rng_state = torch.cuda.get_rng_state() | |||||
| max_diff = new_rng_state.sub(rng_state).max() | |||||
| print( | |||||
| ' max diff in rng state (should be non-zero) on global rank {}: {}'. | |||||
| format(torch.distributed.get_rank(), max_diff)) | |||||
| assert max_diff > 0 | |||||
| # Reset the rng state and do the same stuff. | |||||
| mpu.random._set_cuda_rng_state(rng_state) | |||||
| for _ in range(5): | |||||
| torch.randn(size, out=tensor) | |||||
| mpu.random._set_cuda_rng_state(rng_state) | |||||
| for _ in range(5): | |||||
| torch.randn(size, out=tensor) | |||||
| result_2 = tensor.clone() | |||||
| # Results should be the same | |||||
| error = result_2.sub(result_1).abs().max() | |||||
| print(' max error in generated tensors (should be zero) on ' | |||||
| 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| # Input state should have remained intact. | |||||
| error = rng_state.sub(rng_state_copy).max() | |||||
| print(' max error in rng state (should be zero) on global rank {}: {}'. | |||||
| format(torch.distributed.get_rank(), error)) | |||||
| assert error == 0 | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('>> passed the test :-)') | |||||
| def test_cuda_rng_tracker(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing cuda rng tracker with size {} ...'.format( | |||||
| model_parallel_size)) | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| seed_1 = 1234 | |||||
| seed_2 = 4321 | |||||
| size = [12, 21] | |||||
| tensor = torch.cuda.FloatTensor(size) | |||||
| # Set to seed_1 and generate two tensors. | |||||
| torch.cuda.manual_seed(seed_1) | |||||
| torch.randn(size, out=tensor) | |||||
| target_11 = tensor.clone() | |||||
| torch.randn(size, out=tensor) | |||||
| target_12 = tensor.clone() | |||||
| # Set to seed_2 and generate two tensors. | |||||
| torch.cuda.manual_seed(seed_2) | |||||
| torch.randn(size, out=tensor) | |||||
| target_21 = tensor.clone() | |||||
| torch.randn(size, out=tensor) | |||||
| target_22 = tensor.clone() | |||||
| # Now if we interleave seed_1 and seed_2, | |||||
| # we should still get the same tensors | |||||
| torch.cuda.manual_seed(seed_1) | |||||
| mpu.get_cuda_rng_tracker().add('test', seed_2) | |||||
| torch.randn(size, out=tensor) | |||||
| result_11 = tensor.clone() | |||||
| with mpu.get_cuda_rng_tracker().fork('test'): | |||||
| torch.randn(size, out=tensor) | |||||
| result_21 = tensor.clone() | |||||
| torch.randn(size, out=tensor) | |||||
| result_12 = tensor.clone() | |||||
| with mpu.get_cuda_rng_tracker().fork('test'): | |||||
| torch.randn(size, out=tensor) | |||||
| result_22 = tensor.clone() | |||||
| diff = result_11.sub(result_21).abs().max() | |||||
| diff = min(diff, result_12.sub(result_22).abs().max()) | |||||
| print(' max diff in generated tensors (should be non-zero) on ' | |||||
| 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) | |||||
| assert diff > 1.0e-6 | |||||
| error = max( | |||||
| result_11.sub(target_11).abs().max(), | |||||
| result_12.sub(target_12).abs().max()) | |||||
| error = max(error, result_21.sub(target_21).abs().max()) | |||||
| error = max(error, result_22.sub(target_22).abs().max()) | |||||
| print(' max error in generated tensors (should be zero) on ' | |||||
| 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) | |||||
| assert error < 1.0e-6 | |||||
| # Reset the tracker | |||||
| mpu.get_cuda_rng_tracker().reset() | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('>> passed the test :-)') | |||||
| def test_model_parallel_cuda_manual_seed(model_parallel_size): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('> testing model parallel cuda manual seed with size {} ...'. | |||||
| format(model_parallel_size)) | |||||
| mpu.initialize_model_parallel(model_parallel_size) | |||||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||||
| mpu.model_parallel_cuda_manual_seed(12345) | |||||
| assert torch.cuda.initial_seed() == 12345 | |||||
| with mpu.get_cuda_rng_tracker().fork(): | |||||
| assert torch.cuda.initial_seed() == (12345 + 2718 | |||||
| + mpu.get_model_parallel_rank()) | |||||
| # Reset the tracker | |||||
| mpu.get_cuda_rng_tracker().reset() | |||||
| # Reset groups | |||||
| mpu.destroy_model_parallel() | |||||
| torch.distributed.barrier() | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print('>> passed the test :-)') | |||||
| if __name__ == '__main__': | |||||
| initialize_distributed() | |||||
| world_size = torch.distributed.get_world_size() | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| print_separator('test set rng state') | |||||
| test_set_cuda_rng_state(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| print_separator('test cuda rng tracker') | |||||
| test_cuda_rng_tracker(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| model_parallel_size = 1 | |||||
| while model_parallel_size <= world_size: | |||||
| print_separator('test model parallel cuda manual seed') | |||||
| test_model_parallel_cuda_manual_seed(model_parallel_size) | |||||
| model_parallel_size *= 2 | |||||
| @@ -0,0 +1,70 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import torch | |||||
| def ensure_divisibility(numerator, denominator): | |||||
| """Ensure that numerator is divisible by the denominator.""" | |||||
| assert numerator % denominator == 0, '{} is not divisible by {}'.format( | |||||
| numerator, denominator) | |||||
| def divide(numerator, denominator): | |||||
| """Ensure that numerator is divisible by the denominator and return | |||||
| the division value.""" | |||||
| ensure_divisibility(numerator, denominator) | |||||
| return numerator // denominator | |||||
| def split_tensor_along_last_dim(tensor, | |||||
| num_partitions, | |||||
| contiguous_split_chunks=False): | |||||
| """Split a tensor along its last dimension. | |||||
| Arguments: | |||||
| tensor: input tensor. | |||||
| num_partitions: number of partitions to split the tensor | |||||
| contiguous_split_chunks: If True, make each chunk contiguous | |||||
| in memory. | |||||
| """ | |||||
| # Get the size and dimension. | |||||
| last_dim = tensor.dim() - 1 | |||||
| last_dim_size = divide(tensor.size()[last_dim], num_partitions) | |||||
| # Split. | |||||
| tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) | |||||
| # Note: torch.split does not create contiguous tensors by default. | |||||
| if contiguous_split_chunks: | |||||
| return tuple(chunk.contiguous() for chunk in tensor_list) | |||||
| return tensor_list | |||||
| class VocabUtility: | |||||
| """Split the vocabulary into `world_size` chunks amd return the | |||||
| first and last index of the vocabulary belonging to the `rank` | |||||
| partition: Note that indecies in [fist, last)""" | |||||
| @staticmethod | |||||
| def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, | |||||
| rank, world_size): | |||||
| index_f = rank * per_partition_vocab_size | |||||
| index_l = index_f + per_partition_vocab_size | |||||
| return index_f, index_l | |||||
| @staticmethod | |||||
| def vocab_range_from_global_vocab_size(global_vocab_size, rank, | |||||
| world_size): | |||||
| per_partition_vocab_size = divide(global_vocab_size, world_size) | |||||
| return VocabUtility.vocab_range_from_per_partition_vocab_size( | |||||
| per_partition_vocab_size, rank, world_size) | |||||
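A small self-contained sketch of how these helpers are typically used when sharding a vocabulary or an activation across model-parallel ranks (the sizes, rank and world size are illustrative):

import torch

global_vocab_size, world_size, rank = 50304, 8, 3

# Each rank owns a contiguous slice of the vocabulary: 50304 / 8 = 6288 tokens,
# so rank 3 gets indices [18864, 25152).
start, end = VocabUtility.vocab_range_from_global_vocab_size(
    global_vocab_size, rank, world_size)
assert (start, end) == (18864, 25152)

# split_tensor_along_last_dim is the analogous helper for activations.
hidden = torch.randn(2, 4, 16)
chunks = split_tensor_along_last_dim(hidden, world_size, contiguous_split_chunks=True)
assert len(chunks) == world_size and chunks[0].shape == (2, 4, 2)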
| @@ -0,0 +1,61 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import glob | |||||
| import os | |||||
| import statistics | |||||
| import sys | |||||
| import json | |||||
| path_pattern = sys.argv[1] | |||||
| target_type = sys.argv[2] | |||||
| best_value, best_result, best_name = None, None, None | |||||
| mean_result = {} | |||||
| print(path_pattern) | |||||
| for dir_path in glob.glob(path_pattern, recursive=True): | |||||
| entry = os.path.basename(dir_path) | |||||
| valid_result = None | |||||
| test_found = os.path.exists(os.path.join(dir_path, 'test_results.json')) | |||||
| valid_path = os.path.join(dir_path, 'results.json') | |||||
| if os.path.exists(valid_path): | |||||
| print(entry) | |||||
| with open(valid_path) as file: | |||||
| valid_result = json.load(file) | |||||
| else: | |||||
| print(f'{entry} no validation results') | |||||
| continue | |||||
| if not test_found: | |||||
| print(f'{entry} not tested yet') | |||||
| if target_type == 'max': | |||||
| metric = sys.argv[3] | |||||
| metric_value = valid_result[metric] | |||||
| if best_value is None or metric_value > best_value: | |||||
| best_value = metric_value | |||||
| best_result = valid_result | |||||
| best_name = entry | |||||
| elif target_type == 'mean' or target_type == 'median': | |||||
| if mean_result: | |||||
| for metric, value in valid_result.items(): | |||||
| if metric not in ['type', 'epoch']: | |||||
| mean_result[metric].append(value) | |||||
| else: | |||||
| mean_result = { | |||||
| metric: [value] | |||||
| for metric, value in valid_result.items() | |||||
| if metric not in ['type', 'epoch'] | |||||
| } | |||||
| if target_type == 'max': | |||||
| print(f'Best result found at {best_name}: {best_result}') | |||||
| elif target_type == 'mean': | |||||
| mean_result = { | |||||
| metric: sum(value) / len(value) | |||||
| for metric, value in mean_result.items() | |||||
| } | |||||
| print(f'Mean result {mean_result}') | |||||
| elif target_type == 'median': | |||||
| mean_result = { | |||||
| metric: statistics.median(value) | |||||
| for metric, value in mean_result.items() | |||||
| } | |||||
| print(f'Median result {mean_result}') | |||||
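A hypothetical invocation (the script name, glob pattern and metric name are all illustrative): running "python summarize_results.py 'finetune_runs/*' max rouge-1" would report the matched run directory with the highest validation rouge-1, while passing mean or median as the second argument aggregates every numeric metric (except type and epoch) across all matched directories and needs no third argument.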
| @@ -0,0 +1,22 @@ | |||||
| boto3 | |||||
| botocore | |||||
| deepspeed | |||||
| fasttext | |||||
| filelock | |||||
| ftfy | |||||
| langdetect | |||||
| lsh | |||||
| matplotlib | |||||
| mpi4py | |||||
| nltk | |||||
| pandas | |||||
| regex | |||||
| requests | |||||
| rouge_score | |||||
| scikit_learn | |||||
| scipy | |||||
| sentencepiece | |||||
| termcolor | |||||
| tldextract | |||||
| tqdm | |||||
| transformers | |||||
| @@ -0,0 +1,10 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import sys | |||||
| if sys.argv[1] == 'block': | |||||
| from test.test_block import main | |||||
| main() | |||||
| elif sys.argv[1] == 'rel_shift': | |||||
| from test.test_rel_shift import main | |||||
| main() | |||||
| @@ -0,0 +1,389 @@ | |||||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """ Tasks data utility.""" | |||||
| import copy | |||||
| import pickle | |||||
| import re | |||||
| from typing import Dict, List, Optional | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.utils.data | |||||
| from torch.utils.data.dataloader import default_collate | |||||
| from modelscope.models.nlp.mglm import mpu | |||||
| def clean_text(text): | |||||
| """Remove new lines and multiple spaces and adjust end of sentence dot.""" | |||||
| text = text.replace('\n', ' ') | |||||
| text = re.sub(r'\s+', ' ', text) | |||||
| for _ in range(3): | |||||
| text = text.replace(' . ', '. ') | |||||
| return text | |||||
| class InputExample(object): | |||||
| """A raw input example consisting of one or two segments of text and a label""" | |||||
| def __init__(self, | |||||
| guid, | |||||
| text_a, | |||||
| text_b=None, | |||||
| label=None, | |||||
| logits=None, | |||||
| meta: Optional[Dict] = None, | |||||
| idx=-1, | |||||
| num_choices=1): | |||||
| """ | |||||
| Create a new InputExample. | |||||
| :param guid: a unique textual identifier | |||||
| :param text_a: the sequence of text | |||||
| :param text_b: an optional, second sequence of text | |||||
| :param label: an optional label | |||||
| :param logits: an optional list of per-class logits | |||||
| :param meta: an optional dictionary to store arbitrary meta information | |||||
| :param idx: an optional numeric index | |||||
| """ | |||||
| self.guid = guid | |||||
| self.text_a = text_a | |||||
| self.text_b = text_b | |||||
| self.label = label | |||||
| self.logits = logits | |||||
| self.idx = idx | |||||
| self.num_choices = num_choices | |||||
| self.meta = meta if meta else {} | |||||
| def __repr__(self): | |||||
| return str(self.to_json_string()) | |||||
| def to_dict(self): | |||||
| """Serialize this instance to a Python dictionary.""" | |||||
| output = copy.deepcopy(self.__dict__) | |||||
| return output | |||||
| def to_json_string(self): | |||||
| """Serialize this instance to a JSON string.""" | |||||
| return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' | |||||
| @staticmethod | |||||
| def load_examples(path: str) -> List['InputExample']: | |||||
| """Load a set of input examples from a file""" | |||||
| with open(path, 'rb') as fh: | |||||
| return pickle.load(fh) | |||||
| @staticmethod | |||||
| def save_examples(examples: List['InputExample'], path: str) -> None: | |||||
| """Save a set of input examples to a file""" | |||||
| with open(path, 'wb') as fh: | |||||
| pickle.dump(examples, fh) | |||||
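A minimal sketch of constructing and persisting examples with the class above (field values and the path are illustrative):

example = InputExample(
    guid='train-0',
    text_a='An article to be summarized ...',
    label='A short reference summary.',
    meta={'source': 'toy'})

InputExample.save_examples([example], '/tmp/examples.pkl')
reloaded = InputExample.load_examples('/tmp/examples.pkl')
assert reloaded[0].guid == 'train-0' and reloaded[0].meta['source'] == 'toy'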
| def num_special_tokens_to_add(text_a_ids, | |||||
| text_b_ids, | |||||
| answer_ids, | |||||
| add_cls, | |||||
| add_sep, | |||||
| add_piece, | |||||
| add_eos=True): | |||||
| num_tokens = 0 | |||||
| if add_cls: | |||||
| num_tokens += 1 | |||||
| if text_b_ids and add_sep: | |||||
| num_tokens += 1 | |||||
| if add_eos: | |||||
| num_tokens += 1 | |||||
| if not answer_ids and add_piece: | |||||
| num_tokens += 1 | |||||
| return num_tokens | |||||
| def build_input_from_ids(text_a_ids, | |||||
| text_b_ids, | |||||
| answer_ids, | |||||
| max_seq_length, | |||||
| tokenizer, | |||||
| args=None, | |||||
| add_cls=True, | |||||
| add_sep=False, | |||||
| add_piece=False, | |||||
| add_eos=True, | |||||
| mask_id=None): | |||||
| if mask_id is None: | |||||
| mask_id = tokenizer.get_command('MASK').Id | |||||
| eos_id = tokenizer.get_command('eos').Id | |||||
| cls_id = tokenizer.get_command('ENC').Id | |||||
| sep_id = tokenizer.get_command('sep').Id | |||||
| ids = [] | |||||
| types = [] | |||||
| paddings = [] | |||||
| # CLS | |||||
| if add_cls: | |||||
| ids.append(cls_id) | |||||
| types.append(0) | |||||
| paddings.append(1) | |||||
| # A | |||||
| len_text_a = len(text_a_ids) | |||||
| ids.extend(text_a_ids) | |||||
| types.extend([0] * len_text_a) | |||||
| paddings.extend([1] * len_text_a) | |||||
| # B | |||||
| if text_b_ids is not None: | |||||
| # SEP | |||||
| if add_sep: | |||||
| ids.append(sep_id) | |||||
| types.append(0) | |||||
| paddings.append(1) | |||||
| len_text_b = len(text_b_ids) | |||||
| ids.extend(text_b_ids) | |||||
| types.extend([1] * len_text_b) | |||||
| paddings.extend([1] * len_text_b) | |||||
| eos_length = 1 if add_eos else 0 | |||||
| # Cap the size. | |||||
| if len(ids) >= max_seq_length - eos_length: | |||||
| max_seq_length_m1 = max_seq_length - 1 | |||||
| ids = ids[0:max_seq_length_m1] | |||||
| types = types[0:max_seq_length_m1] | |||||
| paddings = paddings[0:max_seq_length_m1] | |||||
| end_type = 0 if text_b_ids is None else 1 | |||||
| if add_eos: | |||||
| ids.append(eos_id) | |||||
| types.append(end_type) | |||||
| paddings.append(1) | |||||
| sep = len(ids) | |||||
| target_ids = [0] * len(ids) | |||||
| loss_masks = [0] * len(ids) | |||||
| position_ids = list(range(len(ids))) | |||||
| block_position_ids = [0] * len(ids) | |||||
| # Piece | |||||
| if add_piece or answer_ids is not None: | |||||
| sop_id = tokenizer.get_command('sop').Id | |||||
| mask_position = ids.index( | |||||
| mask_id | |||||
| ) if not args.sentinel_token else args.max_position_embeddings | |||||
| ids.append(sop_id) | |||||
| types.append(end_type) | |||||
| paddings.append(1) | |||||
| position_ids.append(mask_position) | |||||
| block_position_ids.append(1) | |||||
| if answer_ids is not None: | |||||
| len_answer = len(answer_ids) | |||||
| ids.extend(answer_ids[:-1]) | |||||
| types.extend([end_type] * (len_answer - 1)) | |||||
| paddings.extend([1] * (len_answer - 1)) | |||||
| position_ids.extend([mask_position] * (len_answer - 1)) | |||||
| if not args.no_block_position: | |||||
| block_position_ids.extend(range(2, len(answer_ids) + 1)) | |||||
| else: | |||||
| block_position_ids.extend([1] * (len(answer_ids) - 1)) | |||||
| target_ids.extend(answer_ids) | |||||
| loss_masks.extend([1] * len(answer_ids)) | |||||
| else: | |||||
| target_ids.append(0) | |||||
| loss_masks.append(1) | |||||
| # Padding. | |||||
| padding_length = max_seq_length - len(ids) | |||||
| if padding_length > 0: | |||||
| ids.extend([eos_id] * padding_length) | |||||
| types.extend([eos_id] * padding_length) | |||||
| paddings.extend([0] * padding_length) | |||||
| position_ids.extend([0] * padding_length) | |||||
| block_position_ids.extend([0] * padding_length) | |||||
| target_ids.extend([0] * padding_length) | |||||
| loss_masks.extend([0] * padding_length) | |||||
| if not args.masked_lm: | |||||
| position_ids = [position_ids, block_position_ids] | |||||
| return ids, types, paddings, position_ids, sep, target_ids, loss_masks | |||||
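| # A minimal usage sketch (hedged: `context`, `continuation`, `tokenizer` | |||||
| # and `args` are placeholders, mirroring how LMDataset further below | |||||
| # calls this helper): | |||||
| #   prompt = tokenizer.EncodeAsIds(' ' + context).tokenization \ | |||||
| #       + [tokenizer.get_command('MASK').Id] | |||||
| #   answer = tokenizer.EncodeAsIds(' ' + continuation).tokenization | |||||
| #   ids, types, paddings, position_ids, sep, target_ids, loss_masks = \ | |||||
| #       build_input_from_ids(prompt, None, answer, 512, tokenizer, args=args, | |||||
| #           add_cls=True, add_sep=False, add_piece=True, add_eos=False) | |||||
| # build_decoder_input builds the decoder-side sequence for fast decoding, | |||||
| # anchored at the [MASK] position of the encoder input. | |||||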
| def build_decoder_input(enc_ids, answer_ids, max_seq_length, | |||||
| max_dec_seq_length, tokenizer): | |||||
| mask_id = tokenizer.get_command('MASK').Id | |||||
| eos_id = tokenizer.get_command('eos').Id | |||||
| sop_id = tokenizer.get_command('sop').Id | |||||
| enc_len = len(enc_ids) # noqa | |||||
| masks = [] | |||||
| # TODO: it probably takes too much memory | |||||
| # for i in range(max_dec_seq_length): | |||||
| # m = [1]*enc_len + [0]*(max_seq_length - enc_len) + [1]*(i+1) + [0]*(max_dec_seq_length-1-i) | |||||
| # masks.append(m) | |||||
| mask_position = enc_ids.index(mask_id) | |||||
| len_answer = len(answer_ids) | |||||
| ids = [sop_id] + answer_ids[:-1] | |||||
| types = [0] * len_answer # not used | |||||
| paddings = [1] * len_answer | |||||
| position_ids = [mask_position] * len_answer | |||||
| block_position_ids = list(range(1, len_answer + 1)) | |||||
| target_ids = answer_ids | |||||
| loss_masks = [1] * len_answer | |||||
| # Padding. | |||||
| padding_length = max_dec_seq_length - len(ids) | |||||
| if padding_length > 0: | |||||
| ids.extend([eos_id] * padding_length) | |||||
| types.extend([0] * padding_length) | |||||
| paddings.extend([0] * padding_length) | |||||
| position_ids.extend([0] * padding_length) | |||||
| block_position_ids.extend([0] * padding_length) | |||||
| target_ids.extend([0] * padding_length) | |||||
| loss_masks.extend([0] * padding_length) | |||||
| position_ids = [position_ids, block_position_ids] | |||||
| return ids, types, paddings, position_ids, masks, target_ids, loss_masks | |||||
| def build_sample(ids, | |||||
| types=None, | |||||
| paddings=None, | |||||
| positions=None, | |||||
| masks=None, | |||||
| label=None, | |||||
| unique_id=None, | |||||
| target=None, | |||||
| logit_mask=None, | |||||
| segment_ids=None, | |||||
| prompt_ids=None): | |||||
| """Convert to numpy and return a sample consumed by the batch producer.""" | |||||
| ids_np = np.array(ids, dtype=np.int64) | |||||
| sample = {'text': ids_np, 'label': int(label)} | |||||
| if types is not None: | |||||
| types_np = np.array(types, dtype=np.int64) | |||||
| sample['types'] = types_np | |||||
| if paddings is not None: | |||||
| paddings_np = np.array(paddings, dtype=np.int64) | |||||
| sample['padding_mask'] = paddings_np | |||||
| if positions is not None: | |||||
| positions_np = np.array(positions, dtype=np.int64) | |||||
| sample['position'] = positions_np | |||||
| if masks is not None: | |||||
| masks_np = np.array(masks, dtype=np.int64) | |||||
| sample['mask'] = masks_np | |||||
| if target is not None: | |||||
| target_np = np.array(target, dtype=np.int64) | |||||
| sample['target'] = target_np | |||||
| if logit_mask is not None: | |||||
| logit_mask_np = np.array(logit_mask, dtype=np.int64) | |||||
| sample['logit_mask'] = logit_mask_np | |||||
| if segment_ids is not None: | |||||
| segment_ids = np.array(segment_ids, dtype=np.int64) | |||||
| sample['segment_id'] = segment_ids | |||||
| if prompt_ids is not None: | |||||
| prompt_ids = np.array(prompt_ids, dtype=np.int64) | |||||
| sample['prompt_pos'] = prompt_ids | |||||
| if unique_id is not None: | |||||
| sample['uid'] = unique_id | |||||
| return sample | |||||
| def build_decoder_sample(sample, dec_ids, dec_position, dec_masks, dec_target, | |||||
| dec_logit_mask): | |||||
| sample['dec_text'] = np.array(dec_ids) | |||||
| sample['dec_position'] = np.array(dec_position) | |||||
| sample['dec_mask'] = np.array(dec_masks) | |||||
| sample['dec_target'] = np.array(dec_target) | |||||
| sample['dec_logit_mask'] = np.array(dec_logit_mask) | |||||
| return sample | |||||
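| # my_collate pads the choice dimension so every sample in a batch has the | |||||
| # same number of choices (repeating the first choice and masking the filler | |||||
| # via 'loss_mask'), then defers to default_collate; 'uid' values stay a | |||||
| # plain Python list instead of being collated into a tensor. | |||||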
| def my_collate(batch): | |||||
| new_batch = [{key: value | |||||
| for key, value in sample.items() if key != 'uid'} | |||||
| for sample in batch] | |||||
| text_list = [sample['text'] for sample in batch] | |||||
| def pad_choice_dim(data, choice_num): | |||||
| if len(data) < choice_num: | |||||
| data = np.concatenate([data] | |||||
| + [data[0:1]] * (choice_num - len(data))) | |||||
| return data | |||||
| if len(text_list[0].shape) == 2: | |||||
| choice_nums = list(map(len, text_list)) | |||||
| max_choice_num = max(choice_nums) | |||||
| for i, sample in enumerate(new_batch): | |||||
| for key, value in sample.items(): | |||||
| if key != 'label': | |||||
| sample[key] = pad_choice_dim(value, max_choice_num) | |||||
| else: | |||||
| sample[key] = value | |||||
| sample['loss_mask'] = np.array( | |||||
| [1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]), | |||||
| dtype=np.int64) | |||||
| if 'dec_text' in new_batch[0]: | |||||
| choice_nums = [len(sample['dec_text']) for sample in new_batch] | |||||
| if choice_nums.count(choice_nums[0]) != len(choice_nums): | |||||
| max_choice_num = max(choice_nums) | |||||
| for i, sample in enumerate(new_batch): | |||||
| for key, value in sample.items(): | |||||
| if key.startswith('dec_'): | |||||
| sample[key] = pad_choice_dim(value, max_choice_num) | |||||
| sample['loss_mask'] = np.array( | |||||
| [1] * choice_nums[i] + [0] * # noqa | |||||
| (max_choice_num - choice_nums[i]), | |||||
| dtype=np.int64) | |||||
| new_batch = default_collate(new_batch) | |||||
| if 'uid' in batch[0]: | |||||
| uid_list = [sample['uid'] for sample in batch] | |||||
| new_batch['uid'] = uid_list | |||||
| return new_batch | |||||
| class FakeDataloader: | |||||
| def __init__(self, num_iters): | |||||
| self.num_iters = num_iters | |||||
| def __iter__(self): | |||||
| if self.num_iters is not None: | |||||
| for _ in range(self.num_iters): | |||||
| yield None | |||||
| else: | |||||
| while True: | |||||
| yield None | |||||
| def build_data_loader(dataset, | |||||
| batch_size, | |||||
| num_workers, | |||||
| drop_last, | |||||
| shuffle=True, | |||||
| only_rank0=False): | |||||
| """Data loader. Note that batch-size is the local (per GPU) batch-size.""" | |||||
| # Sampler. | |||||
| if only_rank0: | |||||
| rank, world_size = 0, 1 | |||||
| else: | |||||
| world_size = mpu.get_data_parallel_world_size() | |||||
| rank = mpu.get_data_parallel_rank() | |||||
| sampler = torch.utils.data.distributed.DistributedSampler( | |||||
| dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) | |||||
| # Data loader. Note that batch size is the per GPU batch size. | |||||
| data_loader = torch.utils.data.DataLoader( | |||||
| dataset, | |||||
| batch_size=batch_size, | |||||
| sampler=sampler, | |||||
| shuffle=False, | |||||
| num_workers=num_workers, | |||||
| drop_last=drop_last, | |||||
| pin_memory=True, | |||||
| collate_fn=my_collate) | |||||
| return data_loader | |||||
| @@ -0,0 +1,249 @@ | |||||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Evaluation utilities.""" | |||||
| import datetime | |||||
| import os | |||||
| import random | |||||
| import time | |||||
| from collections import OrderedDict | |||||
| from typing import List | |||||
| import mpu | |||||
| import torch | |||||
| from finetune_glm import process_batch | |||||
| from sklearn.metrics import f1_score | |||||
| from tasks.data_utils import InputExample, build_data_loader | |||||
| from utils import debug_finetune_data, get_spare_port, print_rank_0 | |||||
| def accuracy_metric(predictions, labels, examples): | |||||
| count = 0 | |||||
| num_predictions = max(len(predictions), 1) | |||||
| assert len(predictions) == len(labels) | |||||
| for prediction, label in zip(predictions, labels): | |||||
| count += prediction == label | |||||
| return count * 100.0 / num_predictions | |||||
| def f1_metric(predictions, labels, examples): | |||||
| return f1_score(labels, predictions) | |||||
| def f1_macro_metric(predictions, labels, examples): | |||||
| return f1_score(labels, predictions, average='macro') | |||||
| global_tokenizer = None | |||||
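| # accuracy_func_provider builds one dataloader per validation/test file and | |||||
| # returns a metrics_func closure that runs eval_func on each dataloader, | |||||
| # reports per-dataset scores, and aggregates an example-count-weighted | |||||
| # average across datasets. | |||||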
| def accuracy_func_provider(single_dataset_provider, | |||||
| metric_dict, | |||||
| args, | |||||
| is_test=False, | |||||
| eval_func=None, | |||||
| output_func=None, | |||||
| only_rank0=True, | |||||
| tokenizer=None): | |||||
| """Provide function that calculates accuracies.""" | |||||
| # Build dataloaders. | |||||
| global global_tokenizer | |||||
| global_tokenizer = tokenizer | |||||
| if only_rank0 and torch.distributed.is_initialized( | |||||
| ) and torch.distributed.get_rank() != 0: | |||||
| return None | |||||
| if is_test and not args.eval_valid: | |||||
| datapaths = args.test_data if args.test_data is not None else ['test'] | |||||
| else: | |||||
| datapaths = args.valid_data if args.valid_data is not None else ['dev'] | |||||
| if eval_func is None: | |||||
| eval_func = multichoice_evaluate | |||||
| dataloaders = [] | |||||
| eval_batch_size = args.eval_batch_size if args.eval_batch_size else args.batch_size | |||||
| for datapath in datapaths: | |||||
| dataset = single_dataset_provider(datapath) | |||||
| dataloader = build_data_loader( | |||||
| dataset, | |||||
| eval_batch_size, | |||||
| num_workers=args.num_workers, | |||||
| drop_last=False, | |||||
| shuffle=False, | |||||
| only_rank0=only_rank0) | |||||
| dataloaders.append((dataset.dataset_name, dataloader)) | |||||
| def metrics_func(model, | |||||
| epoch, | |||||
| output_predictions=False, | |||||
| summary_writer=None): | |||||
| print_rank_0('calculating metrics ...') | |||||
| score_dict = OrderedDict([(key, 0.0) for key in metric_dict | |||||
| ]) if isinstance(metric_dict, dict) else { | |||||
| metric_dict: 0.0 | |||||
| } # noqa | |||||
| total = 0 | |||||
| for name, dataloader in dataloaders: | |||||
| example_dict = None | |||||
| if hasattr(dataloader.dataset, 'examples'): | |||||
| example_dict = dataloader.dataset.examples | |||||
| start_time = time.time() | |||||
| predictions, labels, examples = eval_func(model, dataloader, | |||||
| example_dict, args) | |||||
| elapsed_time = time.time() - start_time | |||||
| if output_predictions and torch.distributed.get_rank() == 0: | |||||
| filename = os.path.join(args.log_dir, name + '.jsonl') | |||||
| output_func(predictions, examples, filename) | |||||
| total_count = len(predictions) | |||||
| single_dict = { | |||||
| key: metric(predictions, labels, examples) | |||||
| for key, metric in metric_dict.items() | |||||
| } | |||||
| output_str = ' > |epoch: {}| metrics for {}: total {}'.format( | |||||
| epoch, name, total_count) | |||||
| for key, value in single_dict.items(): | |||||
| output_str += ' {} = {:.4f} %'.format(key, value) | |||||
| if summary_writer is not None and epoch >= 0 and not is_test and len( | |||||
| dataloaders) > 1: | |||||
| summary_writer.add_scalar(f'Train/valid_{name}_{key}', | |||||
| value, epoch) | |||||
| output_str += ' elapsed time (sec): {:.3f}'.format(elapsed_time) | |||||
| if len(dataloaders) > 1: | |||||
| print_rank_0(output_str) | |||||
| for key in score_dict: | |||||
| score_dict[key] += single_dict[key] * total_count | |||||
| total += total_count | |||||
| score_dict = { | |||||
| key: score / float(total) | |||||
| for key, score in score_dict.items() | |||||
| } | |||||
| output_str = ' >> |epoch: {}| overall: total = {}'.format(epoch, total) | |||||
| for key, score in score_dict.items(): | |||||
| output_str += ' {} = {:.4f}'.format(key, score) | |||||
| if summary_writer is not None and epoch >= 0 and not is_test: | |||||
| summary_writer.add_scalar(f'Train/valid_{key}', score, epoch) | |||||
| print_rank_0(output_str) | |||||
| return score_dict | |||||
| return metrics_func | |||||
| segment_length = 10 | |||||
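| # multichoice_evaluate splits an over-long choice dimension into chunks of | |||||
| # `segment_length` to bound activation memory, writes each (prediction, | |||||
| # label) pair into a TCPStore keyed by example uid so results can be | |||||
| # gathered across ranks, then reads them back in example order. | |||||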
| def multichoice_evaluate(model, dataloader, example_dict, args): | |||||
| """Calculate correct over total answers and return prediction if the | |||||
| `output_predictions` is true.""" | |||||
| model.eval() | |||||
| port = get_spare_port(args) | |||||
| print_rank_0(f'Using port {port}') | |||||
| store = torch.distributed.TCPStore(args.master_ip, port, | |||||
| torch.distributed.get_world_size(), | |||||
| torch.distributed.get_rank() == 0, | |||||
| datetime.timedelta(seconds=30)) | |||||
| # file_path = os.path.join("/cache", args.experiment_name + "_store") | |||||
| # print_rank_0(f"Using file store at {file_path}") | |||||
| # store = torch.distributed.FileStore(file_path, torch.distributed.get_world_size()) | |||||
| with torch.no_grad(): | |||||
| # For all the batches in the dataset. | |||||
| for _, batch in enumerate(dataloader): | |||||
| # Run the model forward. | |||||
| data = process_batch(batch, args) | |||||
| if args.pretrained_bert: | |||||
| tokens, types, labels_, attention_mask = data['text'], data[ | |||||
| 'types'], data['label'], data['padding_mask'] | |||||
| inputs = [tokens, types, attention_mask] | |||||
| elif args.cloze_eval: | |||||
| tokens, labels_, position_ids = data['text'], data[ | |||||
| 'label'], data['position'] | |||||
| attention_mask, target_ids, logit_mask = data['mask'], data[ | |||||
| 'target'], data['logit_mask'] | |||||
| if not args.fast_decode: | |||||
| inputs = [ | |||||
| tokens, position_ids, attention_mask, target_ids, | |||||
| logit_mask | |||||
| ] | |||||
| if args.continuous_prompt: | |||||
| prompt_pos = data['prompt_pos'] | |||||
| inputs.append(prompt_pos) | |||||
| else: | |||||
| dec_input_ids, dec_position_ids, dec_attention_mask = data[ | |||||
| 'dec_text'], data['dec_position'], data['dec_mask'] | |||||
| dec_target_ids, dec_logit_mask = data['dec_target'], data[ | |||||
| 'dec_logit_mask'] | |||||
| inputs = [ | |||||
| tokens, position_ids, attention_mask, dec_input_ids, | |||||
| dec_position_ids, dec_attention_mask, dec_target_ids, | |||||
| dec_logit_mask | |||||
| ] | |||||
| else: | |||||
| tokens, labels_, position_ids, attention_mask = data[ | |||||
| 'text'], data['label'], data['position'], data['mask'] | |||||
| inputs = [tokens, position_ids, attention_mask] | |||||
| if len(inputs[0].shape | |||||
| ) == 3 and inputs[0].size(1) > segment_length: | |||||
| logit_list = [] | |||||
| for i in range((inputs[0].size(1) - 1) // segment_length + 1): | |||||
| input_batch = [ | |||||
| arg[:, i * segment_length:(i + 1) * segment_length] | |||||
| for arg in inputs | |||||
| ] | |||||
| if args.pretrained_bert: | |||||
| logits = model(*input_batch) | |||||
| else: | |||||
| logits, *mems = model(*input_batch) | |||||
| logit_list.append(logits) | |||||
| logits = torch.cat(logit_list, dim=1) | |||||
| elif args.cloze_eval and args.fast_decode: | |||||
| logit_list = [] | |||||
| num_choices = inputs[3].size(1) | |||||
| for i in range((num_choices - 1) // segment_length + 1): | |||||
| input_batch = inputs[:3] + [ | |||||
| arg[:, i * segment_length:(i + 1) * segment_length] | |||||
| for arg in inputs[3:] | |||||
| ] | |||||
| logits, *mems = model(*input_batch) | |||||
| logit_list.append(logits) | |||||
| logits = torch.cat(logit_list, dim=1) | |||||
| else: | |||||
| if args.pretrained_bert: | |||||
| logits = model(*inputs) | |||||
| else: | |||||
| logits, *mems = model(*inputs) | |||||
| if 'segment_id' in data: | |||||
| from torch_scatter import scatter_sum | |||||
| if 'loss_mask' in data: | |||||
| logits = logits * data['loss_mask'] | |||||
| logits = scatter_sum(logits, data['segment_id'], dim=1) | |||||
| elif 'loss_mask' in data: | |||||
| loss_mask = data['loss_mask'] | |||||
| logits = logits * loss_mask - 10000.0 * (1.0 - loss_mask) | |||||
| uid_list = batch['uid'] | |||||
| if isinstance(uid_list, torch.Tensor): | |||||
| uid_list = uid_list.cpu().numpy().tolist() | |||||
| predicted = torch.argmax(logits, dim=-1).tolist() | |||||
| labels = labels_.tolist() | |||||
| if args.task.lower() == 'wsc': | |||||
| predicted = [1 if pred == 0 else 0 for pred in predicted] | |||||
| if mpu.get_model_parallel_rank() == 0: | |||||
| for uid, prediction, label in zip(uid_list, predicted, labels): | |||||
| store.set(uid, str((prediction, label))) | |||||
| model.train() | |||||
| torch.distributed.barrier() | |||||
| predictions, labels, examples = [], [], [] | |||||
| for uid, example in example_dict.items(): | |||||
| prediction, label = eval(store.get(uid)) | |||||
| predictions.append(prediction) | |||||
| labels.append(label) | |||||
| examples.append(example) | |||||
| torch.distributed.barrier() | |||||
| return predictions, labels, examples | |||||
| @@ -0,0 +1,249 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import math | |||||
| from bisect import bisect_right | |||||
| from itertools import accumulate | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| from tasks.data_utils import build_input_from_ids, num_special_tokens_to_add | |||||
| from tasks.language_model.detokenizer import get_detokenizer | |||||
| from utils import print_rank_0 | |||||
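| # LMDataset scores documents with a sliding window: each document is cut | |||||
| # into windows of max_seq_len with stride overlapping_eval, and for every | |||||
| # window after the first only the freshly added tokens contribute to the | |||||
| # loss. In block-LM mode the window is wrapped as [CLS] ... [MASK] [sop] | |||||
| # via build_input_from_ids. | |||||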
| class LMDataset(torch.utils.data.Dataset): | |||||
| def __init__(self, args, documents, tokenizer, num_original_tokens, | |||||
| num_tokenized_tokens): | |||||
| self.args = args | |||||
| self.documents = documents | |||||
| self.max_seq_len = args.seq_length - 1 | |||||
| self.tokenizer = tokenizer | |||||
| self.overalapping_eval = args.overlapping_eval | |||||
| if self.overalapping_eval is None: | |||||
| self.overalapping_eval = self.max_seq_len | |||||
| self.overalapping_eval = max(1, self.overalapping_eval) | |||||
| self.num_original_tokens = num_original_tokens | |||||
| self.num_tokenized_tokens = num_tokenized_tokens | |||||
| # remove first sequence tokens | |||||
| targets = [ | |||||
| max(len(tokens) - self.max_seq_len, 0) for tokens in self.documents | |||||
| ] | |||||
| self.num_sequences = [ | |||||
| max(math.ceil(target / self.overalapping_eval) + 1, 1) | |||||
| for target in targets | |||||
| ] | |||||
| self.weights = list(accumulate(self.num_sequences)) | |||||
| self.left_weights = [0] + self.weights[:-1] | |||||
| self.unidirectional = args.unidirectional | |||||
| self.block_lm = args.block_lm | |||||
| mask_token = 'gMASK' if args.task_mask else 'MASK' | |||||
| self.mask_id = self.tokenizer.get_command(mask_token).Id | |||||
| def __len__(self): | |||||
| return sum(self.num_sequences) | |||||
| def __getitem__(self, idx): | |||||
| document_idx = bisect_right(self.weights, idx) | |||||
| idx = idx - self.left_weights[document_idx] | |||||
| start_idx = idx * self.overalapping_eval | |||||
| end_idx = start_idx + self.max_seq_len | |||||
| tokens = self.documents[document_idx][start_idx:end_idx] | |||||
| if self.block_lm: | |||||
| if idx == 0 or self.unidirectional: | |||||
| prompt, text = tokens[:1], tokens[1:] | |||||
| else: | |||||
| prompt_length = self.max_seq_len - self.overalapping_eval | |||||
| prompt, text = tokens[:prompt_length], tokens[prompt_length:] | |||||
| prompt = prompt + [self.mask_id] | |||||
| num_special_tokens = num_special_tokens_to_add( | |||||
| prompt, | |||||
| None, | |||||
| text, | |||||
| add_cls=True, | |||||
| add_sep=False, | |||||
| add_piece=True, | |||||
| add_eos=False) | |||||
| data = build_input_from_ids( | |||||
| prompt, | |||||
| None, | |||||
| text, | |||||
| self.max_seq_len + num_special_tokens + 1, | |||||
| self.tokenizer, | |||||
| args=self.args, | |||||
| add_cls=True, | |||||
| add_sep=False, | |||||
| add_piece=True, | |||||
| add_eos=False, | |||||
| mask_id=self.mask_id) | |||||
| ids, types, paddings, position_ids, sep, target_ids, loss_masks = data | |||||
| if idx != 0 and self.unidirectional: | |||||
| loss_masks = np.array(loss_masks, dtype=np.int64) | |||||
| loss_masks[:-self.overalapping_eval] = 0 | |||||
| return { | |||||
| 'text': np.array(ids, dtype=np.int64), | |||||
| 'target': np.array(target_ids, dtype=np.int64), | |||||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||||
| 'loss_mask': np.array(loss_masks, dtype=np.int64), | |||||
| 'position_id': np.array(position_ids, dtype=np.int64) | |||||
| } | |||||
| else: | |||||
| loss_masks = [1] * len(tokens) | |||||
| if len(tokens) < self.max_seq_len: | |||||
| tokens = tokens + [0] * (self.max_seq_len - len(tokens)) | |||||
| loss_masks = loss_masks + [0] * ( | |||||
| self.max_seq_len - len(loss_masks)) | |||||
| if idx != 0: | |||||
| loss_masks = np.array(loss_masks, dtype=np.int64) | |||||
| loss_masks[:-self.overalapping_eval] = 0 | |||||
| return { | |||||
| 'text': np.array(tokens, dtype=np.int64), | |||||
| 'loss_mask': np.array(loss_masks, dtype=np.int64) | |||||
| } | |||||
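| # LambadaDataset targets last-word prediction: in strict mode the final | |||||
| # whitespace-delimited word (re-encoded with a leading space) is the answer; | |||||
| # otherwise the last token of the encoded text is used. For block-LM the | |||||
| # context gets a trailing [MASK] / [gMASK] before build_input_from_ids. | |||||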
| class LambadaDataset(torch.utils.data.Dataset): | |||||
| def __init__(self, args, tokenizer, strict=True): | |||||
| data_path = args.valid_data[0] | |||||
| print_rank_0( | |||||
| '> building lambada dataset from {} ...'.format(data_path)) | |||||
| self.args = args | |||||
| self.max_seq_length = args.seq_length | |||||
| self.tokenizer = tokenizer | |||||
| self.pad_idx = tokenizer.get_command('pad').Id | |||||
| self.strict = strict | |||||
| self.block_lm = args.block_lm | |||||
| self.unidirectional = args.unidirectional | |||||
| mask_token = 'gMASK' if args.task_mask else 'MASK' | |||||
| self.mask_id = self.tokenizer.get_command(mask_token).Id | |||||
| self.tokens = [] | |||||
| self.labels = [] | |||||
| with open(data_path, 'r') as f: | |||||
| for line in f.readlines(): | |||||
| text = json.loads(line)['text'] | |||||
| tokens, labels = self.get_tokens(text) | |||||
| self.tokens.append(tokens) | |||||
| self.labels.append(labels) | |||||
| def get_tokens(self, text): | |||||
| if not self.strict: | |||||
| tokens = self.tokenizer.EncodeAsIds(text).tokenization | |||||
| return tokens[:-1], [tokens[-1]] | |||||
| last_token = text.split()[-1] | |||||
| start_idx = text.rfind(last_token) | |||||
| beginning_tokens = self.tokenizer.EncodeAsIds( | |||||
| text[:start_idx].strip()).tokenization | |||||
| last_token = self.tokenizer.EncodeAsIds(' ' + last_token).tokenization | |||||
| return beginning_tokens, last_token | |||||
| def __len__(self): | |||||
| return len(self.tokens) | |||||
| def __getitem__(self, idx): | |||||
| tokens, answer = self.tokens[idx], self.labels[idx] | |||||
| if self.block_lm: | |||||
| if self.unidirectional: | |||||
| tokens, answer_tokens = tokens[:1], tokens[1:] + answer | |||||
| else: | |||||
| answer_tokens = answer | |||||
| tokens = tokens + [self.mask_id] | |||||
| num_special_tokens = num_special_tokens_to_add( | |||||
| tokens, | |||||
| None, | |||||
| answer_tokens, | |||||
| add_cls=True, | |||||
| add_sep=False, | |||||
| add_piece=True) | |||||
| left_shift = len(tokens) + len( | |||||
| answer_tokens) + num_special_tokens - self.max_seq_length | |||||
| if left_shift > 0: | |||||
| tokens = tokens[left_shift:] | |||||
| data = build_input_from_ids( | |||||
| tokens, | |||||
| None, | |||||
| answer_tokens, | |||||
| self.max_seq_length, | |||||
| self.tokenizer, | |||||
| args=self.args, | |||||
| add_cls=True, | |||||
| add_sep=False, | |||||
| add_piece=True, | |||||
| mask_id=self.mask_id) | |||||
| ids, types, paddings, position_ids, sep, target_ids, loss_masks = data | |||||
| if self.unidirectional: | |||||
| loss_masks = np.array(loss_masks, dtype=np.int64) | |||||
| last_index = len(loss_masks) | |||||
| while loss_masks[last_index - 1] == 0: | |||||
| last_index -= 1 | |||||
| loss_masks[:last_index - len(answer)] = 0 | |||||
| return { | |||||
| 'text': np.array(ids, dtype=np.int64), | |||||
| 'target': np.array(target_ids, dtype=np.int64), | |||||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||||
| 'loss_mask': np.array(loss_masks, dtype=np.int64), | |||||
| 'position_id': np.array(position_ids, dtype=np.int64) | |||||
| } | |||||
| else: | |||||
| left_shift = len(tokens) - self.max_seq_length | |||||
| if left_shift > 0: | |||||
| tokens = tokens[left_shift:] | |||||
| ids = tokens + answer | |||||
| if len(ids) < self.max_seq_length: | |||||
| ids = ids + [0] * (self.max_seq_length - len(ids)) | |||||
| loss_masks = [0] * len(tokens) + [1] * len(answer) | |||||
| if len(loss_masks) < self.max_seq_length: | |||||
| loss_masks = loss_masks + [0] * ( | |||||
| self.max_seq_length - len(loss_masks)) | |||||
| return { | |||||
| 'text': np.array(ids, dtype=np.int64), | |||||
| 'loss_mask': np.array(loss_masks, dtype=np.int64) | |||||
| } | |||||
| def build_lambada_dataset(tokenizer, args): | |||||
| """Build lambada dataset.""" | |||||
| assert len(args.valid_data) == 1 | |||||
| val_dataset = LambadaDataset(args, tokenizer, strict=True) | |||||
| print_rank_0(' > found {} samples, {} label tokens.'.format( | |||||
| len(val_dataset), sum(map(len, val_dataset.labels)))) | |||||
| return val_dataset | |||||
| def build_lm_dataset(tokenizer, args): | |||||
| documents = [] | |||||
| num_tokens, num_original_tokens = 0, 0 | |||||
| with open(args.valid_data[0], encoding='utf-8') as file: | |||||
| for line in file: | |||||
| tokens = tokenizer.EncodeAsIds(line.strip()).tokenization | |||||
| num_tokens += len(tokens) | |||||
| num_original_tokens += len(line.strip().split(' ')) | |||||
| documents.append(tokens) | |||||
| val_dataset = LMDataset(args, documents, tokenizer, num_original_tokens, | |||||
| num_tokens) | |||||
| print_rank_0( | |||||
| ' > number of documents: {}, number of original tokens: {}, number of tokenized tokens: {}' | |||||
| .format(len(documents), num_original_tokens, num_tokens)) | |||||
| return val_dataset | |||||
| def build_wikitext103_dataset(tokenizer, args): | |||||
| """""" | |||||
| assert len(args.valid_data) == 1 | |||||
| with open(args.valid_data[0], 'rb') as reader: | |||||
| entire_data = reader.read().decode('utf-8') | |||||
| num_original_tokens = len(entire_data.strip().split(' ')) | |||||
| entire_data = get_detokenizer('wikitext')(entire_data) | |||||
| print_rank_0(entire_data[:1024]) | |||||
| tokenized_data = tokenizer.EncodeAsIds(entire_data).tokenization | |||||
| num_tokenized_tokens = len(tokenized_data) | |||||
| val_dataset = LMDataset(args, [tokenized_data], tokenizer, | |||||
| num_original_tokens, num_tokenized_tokens) | |||||
| print_rank_0(' > number of original tokens: {}, number of detokenized ' | |||||
| 'tokens: {}'.format(num_original_tokens, | |||||
| num_tokenized_tokens)) | |||||
| return val_dataset | |||||
| @@ -0,0 +1,63 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import re | |||||
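| # These detokenizers undo dataset-specific tokenization (PTB escapes such as | |||||
| # " n't"/" N ", wikitext "@-@"-style separators) so perplexity is computed on | |||||
| # text close to its original form; lambada needs no detokenization. | |||||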
| def ptb_detokenizer(string): | |||||
| string = string.replace(" '", "'") | |||||
| string = string.replace(' \n', '\n') | |||||
| string = string.replace('\n ', '\n') | |||||
| string = string.replace(" n't", "n't") | |||||
| string = string.replace(' N ', '1 ') | |||||
| string = string.replace('$ 1', '$1') | |||||
| string = string.replace('# 1', '#1') | |||||
| return string | |||||
| def wikitext_detokenizer(string): | |||||
| # contractions | |||||
| string = string.replace("s '", "s'") | |||||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||||
| # number separators | |||||
| string = string.replace(' @-@ ', '-') | |||||
| string = string.replace(' @,@ ', ',') | |||||
| string = string.replace(' @.@ ', '.') | |||||
| # punctuation | |||||
| string = string.replace(' : ', ': ') | |||||
| string = string.replace(' ; ', '; ') | |||||
| string = string.replace(' . ', '. ') | |||||
| string = string.replace(' ! ', '! ') | |||||
| string = string.replace(' ? ', '? ') | |||||
| string = string.replace(' , ', ', ') | |||||
| # double brackets | |||||
| string = re.sub(r'\(\s*([^\)]*?)\s*\)', r'(\1)', string) | |||||
| string = re.sub(r'\[\s*([^\]]*?)\s*\]', r'[\1]', string) | |||||
| string = re.sub(r'{\s*([^}]*?)\s*}', r'{\1}', string) | |||||
| string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) | |||||
| string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) | |||||
| # miscellaneous | |||||
| string = string.replace('= = = =', '====') | |||||
| string = string.replace('= = =', '===') | |||||
| string = string.replace('= =', '==') | |||||
| string = string.replace(' ' + chr(176) + ' ', chr(176)) | |||||
| string = string.replace(' \n', '\n') | |||||
| string = string.replace('\n ', '\n') | |||||
| string = string.replace(' N ', ' 1 ') | |||||
| string = string.replace(" 's", "'s") | |||||
| return string | |||||
| def lambada_detokenizer(string): | |||||
| return string | |||||
| def get_detokenizer(dataset): | |||||
| return DETOKENIZERS[dataset] | |||||
| DETOKENIZERS = { | |||||
| 'ptb': ptb_detokenizer, | |||||
| 'wikitext': wikitext_detokenizer, | |||||
| 'lambada': lambada_detokenizer, | |||||
| } | |||||
| @@ -0,0 +1,254 @@ | |||||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """GPT2 zero-shot evaluation.""" | |||||
| import functools | |||||
| import math | |||||
| import mpu | |||||
| import torch | |||||
| from finetune_glm import finetune | |||||
| from pretrain_glm import get_batch | |||||
| from tasks.data_utils import build_data_loader | |||||
| from tasks.language_model.dataset import (build_lambada_dataset, | |||||
| build_lm_dataset, | |||||
| build_wikitext103_dataset) | |||||
| from utils import print_rank_0 | |||||
| global_tokenizer = None | |||||
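| # lm_forward_step runs one forward pass: with eval_metric 'loss' it returns | |||||
| # the cross-entropy summed over the loss mask (left unnormalized on purpose), | |||||
| # with eval_metric None the mask-averaged loss, and with 'classify' a | |||||
| # per-sample indicator of whether every unmasked token was predicted exactly | |||||
| # ('accuracy' returns the batch sum of that indicator). | |||||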
| def lm_forward_step(data, model, args, timers, mems, eval_metric=None): | |||||
| """Forward step.""" | |||||
| # Get the batch. | |||||
| if timers is not None: | |||||
| timers('batch generator').start() | |||||
| if 'mask' in data: | |||||
| data['attention_mask'] = data.pop('mask') | |||||
| tokens, labels, loss_mask, attention_mask, position_ids = get_batch( | |||||
| data, args) | |||||
| if timers is not None: | |||||
| timers('batch generator').stop() | |||||
| def print_masked_text(batch_id): | |||||
| block_position_ids = position_ids[:, 1] | |||||
| position_ids_ = position_ids[:, 0] | |||||
| output_tokens = [] | |||||
| sep = attention_mask[batch_id].item() | |||||
| for i, token in enumerate(tokens[batch_id, :sep].tolist()): | |||||
| if global_tokenizer is not None: | |||||
| token = global_tokenizer.IdToToken(token) | |||||
| if token.startswith('[MASK'): | |||||
| token = f'[{position_ids_[batch_id, i].item()}, {token}]' | |||||
| if token.startswith('##') and len( | |||||
| output_tokens) > 0 and not output_tokens[-1].endswith( | |||||
| ']'): | |||||
| output_tokens[-1] += token[2:] | |||||
| else: | |||||
| output_tokens.append(token) | |||||
| else: | |||||
| output_tokens.append(str(token)) | |||||
| print(' '.join(output_tokens)) | |||||
| last_index = None | |||||
| for i in range(sep, tokens.size(1)): | |||||
| if global_tokenizer.IdToToken( | |||||
| tokens[batch_id, i].item()).startswith('<|startofpiece'): | |||||
| if last_index is not None: | |||||
| print( | |||||
| global_tokenizer.DecodeIds( | |||||
| tokens[batch_id, last_index:i].tolist()), '|', | |||||
| global_tokenizer.DecodeIds( | |||||
| labels[batch_id, last_index:i].tolist())), | |||||
| print(position_ids_[batch_id, last_index:i].tolist(), | |||||
| block_position_ids[batch_id, last_index:i].tolist()) | |||||
| last_index = i | |||||
| if last_index is not None: | |||||
| print( | |||||
| global_tokenizer.DecodeIds(tokens[batch_id, | |||||
| last_index:].tolist()), '|', | |||||
| global_tokenizer.DecodeIds(labels[batch_id, | |||||
| last_index:].tolist())) | |||||
| print(position_ids_[batch_id, last_index:].tolist(), | |||||
| block_position_ids[batch_id, last_index:].tolist()) | |||||
| # Forward model. | |||||
| if args.continuous_prompt: | |||||
| prompt_pos = data['prompt_pos'].long().cuda() | |||||
| logits, *mems = model( | |||||
| tokens, position_ids, attention_mask, *mems, prompt_pos=prompt_pos) | |||||
| else: | |||||
| logits, *mems = model(tokens, position_ids, attention_mask, *mems) | |||||
| if eval_metric is None or eval_metric == 'loss': | |||||
| losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), | |||||
| labels) | |||||
| loss_mask = loss_mask.view(-1) | |||||
| # The loss is not normalized for fair comparison | |||||
| loss = torch.sum(losses.view(-1) * loss_mask) | |||||
| if eval_metric is None: | |||||
| loss = loss / loss_mask.sum() | |||||
| return loss, mems, 'bert' | |||||
| elif eval_metric == 'accuracy' or eval_metric == 'classify': | |||||
| logits = mpu.gather_from_model_parallel_region(logits) | |||||
| outputs = torch.argmax(logits, -1) | |||||
| correct = (outputs == labels).float() | |||||
| correct[(1 - loss_mask).bool()] = 1 | |||||
| correct = correct.prod(-1) | |||||
| if eval_metric == 'accuracy': | |||||
| correct = correct.sum() | |||||
| return correct, mems, 'bert' | |||||
| else: | |||||
| raise NotImplementedError( | |||||
| 'Metric {} not implemented'.format(eval_metric)) | |||||
| def classify_evaluate(model, dataloader, example_dict, args): | |||||
| """Evaluation.""" | |||||
| # Turn on evaluation mode which disables dropout. | |||||
| model.eval() | |||||
| predictions, labels, examples = [], [], [] | |||||
| with torch.no_grad(): | |||||
| # For all the batches in the dataset. | |||||
| for iteration, batch in enumerate(dataloader): | |||||
| # Forward evaluation. | |||||
| output, _, _ = lm_forward_step( | |||||
| batch, model, args, None, [], eval_metric='classify') | |||||
| uid_list = batch['uid'] | |||||
| example_batch = [example_dict[uid] for uid in uid_list] | |||||
| predictions.extend(output.long().tolist()) | |||||
| label = batch['label'].tolist() | |||||
| labels.extend(label) | |||||
| examples.extend(example_batch) | |||||
| return predictions, labels, examples | |||||
| def evaluate(model, dataloader, eval_metric, args): | |||||
| """Evaluation.""" | |||||
| # Turn on evaluation mode which disables dropout. | |||||
| model.eval() | |||||
| total_output, total_count = 0.0, 0 | |||||
| total_tokens = 0 | |||||
| with torch.no_grad(): | |||||
| # For all the batches in the dataset. | |||||
| for iteration, batch in enumerate(dataloader): | |||||
| if (iteration + 1) % args.log_interval == 0: | |||||
| print_rank_0('> working on iteration: {}'.format(iteration)) | |||||
| # Forward evaluation. | |||||
| output, _, _ = lm_forward_step( | |||||
| batch, model, args, None, [], eval_metric=eval_metric) | |||||
| count = batch['text'].size(0) | |||||
| count = torch.cuda.LongTensor([count]) | |||||
| # Reduce across processes. | |||||
| torch.distributed.all_reduce( | |||||
| output, group=mpu.get_data_parallel_group()) | |||||
| torch.distributed.all_reduce( | |||||
| count, group=mpu.get_data_parallel_group()) | |||||
| total_output += output.item() | |||||
| total_count += count.item() | |||||
| total_tokens += batch['loss_mask'].sum().item() | |||||
| totals = torch.cuda.FloatTensor([total_output, total_tokens]) | |||||
| torch.distributed.all_reduce(totals, group=mpu.get_data_parallel_group()) | |||||
| total_output, total_tokens = totals.tolist() | |||||
| print(total_tokens) | |||||
| return {eval_metric: total_output}, total_count | |||||
| def evaluate_and_print_results(data_loader, model, eval_metric, args): | |||||
| """Evaluate and print results on screen.""" | |||||
| # Evaluate and get results. | |||||
| output, _ = evaluate(model, data_loader, eval_metric, args) | |||||
| string = '' | |||||
| if eval_metric == 'loss': | |||||
| output = output['loss'] | |||||
| num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens | |||||
| num_original_tokens = data_loader.dataset.num_original_tokens | |||||
| val_loss = output / (num_tokenized_tokens - 1) | |||||
| ppl = math.exp(min(20, val_loss)) | |||||
| token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1) | |||||
| adjusted_ppl = math.exp(min(20, val_loss * token_ratio)) | |||||
| string += 'avg loss: {:.4E} | '.format(val_loss) | |||||
| string += 'ppl: {:.4E} | '.format(ppl) | |||||
| string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl) | |||||
| string += 'token ratio: {} |'.format(token_ratio) | |||||
| score_dict = { | |||||
| 'avg loss': val_loss, | |||||
| 'ppl': ppl, | |||||
| 'adjusted ppl': adjusted_ppl | |||||
| } | |||||
| elif eval_metric == 'accuracy': | |||||
| output = output['accuracy'] | |||||
| num_examples = len(data_loader.dataset) | |||||
| acc = output / num_examples * 100 | |||||
| string += 'number correct: {} | '.format(output) | |||||
| string += 'total examples: {} | '.format(num_examples) | |||||
| string += 'avg accuracy: {:.2f}'.format(acc) | |||||
| score_dict = {'accuracy': acc} | |||||
| else: | |||||
| raise NotImplementedError('evaluation method for {} metric is not ' | |||||
| 'implemented yet.'.format(eval_metric)) | |||||
| length = len(string) + 1 | |||||
| print_rank_0('-' * length) | |||||
| print_rank_0(string) | |||||
| print_rank_0('-' * length) | |||||
| return score_dict | |||||
| def metrics_func_provider(args, tokenizer, is_test): | |||||
| """Privde metrics callback function.""" | |||||
| if args.task.lower() == 'lambda': | |||||
| eval_metric = 'accuracy' | |||||
| dataset = build_lambada_dataset(tokenizer, args) | |||||
| elif args.task == 'wikitext': | |||||
| eval_metric = 'loss' | |||||
| dataset = build_wikitext103_dataset(tokenizer, args) | |||||
| elif args.task == 'language_model': | |||||
| eval_metric = 'loss' | |||||
| dataset = build_lm_dataset(tokenizer, args) | |||||
| else: | |||||
| raise NotImplementedError('{} task is not implemented.'.format( | |||||
| args.task)) | |||||
| # Data stuff | |||||
| dataloader = build_data_loader( | |||||
| dataset, | |||||
| args.eval_batch_size, | |||||
| args.num_workers, | |||||
| drop_last=False, | |||||
| shuffle=False) | |||||
| def metrics_func(model, | |||||
| epoch, | |||||
| output_predictions=False, | |||||
| summary_writer=None): | |||||
| return evaluate_and_print_results( | |||||
| dataloader, model, eval_metric=eval_metric, args=args) | |||||
| global global_tokenizer | |||||
| global_tokenizer = tokenizer | |||||
| return metrics_func | |||||
| def main(args): | |||||
| """Main program.""" | |||||
| finetune( | |||||
| args, | |||||
| None, {}, | |||||
| end_of_epoch_callback_provider=metrics_func_provider, | |||||
| forward_step=lm_forward_step) | |||||
| @@ -0,0 +1,667 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import os | |||||
| import random | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| import torch.utils.data | |||||
| from data_utils.corpora import punctuation_standardization | |||||
| from tasks.data_utils import InputExample | |||||
| from tqdm import tqdm | |||||
| from utils import print_rank_0 | |||||
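| # The *_detokenize helpers map corpus-specific escape tokens (e.g. '-lrb-', | |||||
| # '<S_SEP>', '<unk>') back to natural punctuation or model special tokens | |||||
| # before the text is re-tokenized with the model tokenizer. | |||||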
| def gigaword_detokenize(string, is_target=False): | |||||
| _tok_dict = { | |||||
| '(': '-lrb-', | |||||
| ')': '-rrb-', | |||||
| '[': '-lsb-', | |||||
| ']': '-rsb-', | |||||
| '{': '-lcb-', | |||||
| '}': '-rcb-', | |||||
| '&': '&amp;', | |||||
| '<': '&lt;', | |||||
| '>': '&gt;' | |||||
| } | |||||
| string = string.replace('UNK', '[UNK]') | |||||
| string = string.replace('<unk>', '[UNK]') | |||||
| for key, value in _tok_dict.items(): | |||||
| string = string.replace(value, key) | |||||
| # string = string.replace("''", "\"") | |||||
| # string = string.replace("``", "\"") | |||||
| # string = string.replace("`", "'") | |||||
| # string = string.replace(" n't", "n't") | |||||
| # string = string.replace(" 's", "'s") | |||||
| # string = string.replace(" 'd", "'d") | |||||
| # string = string.replace(" 'll", "'ll") | |||||
| return string | |||||
| def cnndm_detokenize(string, is_target=False): | |||||
| _tok_dict = { | |||||
| '(': '-LRB-', | |||||
| ')': '-RRB-', | |||||
| '[': '-LSB-', | |||||
| ']': '-RSB-', | |||||
| '{': '-LCB-', | |||||
| '}': '-RCB-' | |||||
| } | |||||
| if not is_target: | |||||
| string = string.replace('<S_SEP>', '') | |||||
| else: | |||||
| string = string.replace('<S_SEP>', '[SEP]') | |||||
| for key, value in _tok_dict.items(): | |||||
| string = string.replace(value, key) | |||||
| string = string.replace("''", "\"") | |||||
| string = string.replace('``', "\"") | |||||
| string = string.replace('`', "'") | |||||
| string = string.replace(" n't", "n't") | |||||
| string = string.replace(" 's", "'s") | |||||
| string = string.replace(" 'd", "'d") | |||||
| string = string.replace(" 'll", "'ll") | |||||
| return string | |||||
| def blanklm_detokenize(string, is_target=False): | |||||
| string = string.replace('_UNK', '[UNK]') | |||||
| string = string.replace('<blank>', '[MASK]') | |||||
| return string | |||||
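| # SummmaryProcessor reads parallel <split>.source / <split>.target files, | |||||
| # standardizes punctuation, applies the task detokenizer and wraps each pair | |||||
| # into an InputExample whose meta['ref'] is the target round-tripped through | |||||
| # the tokenizer, so evaluation references match the model's tokenization. | |||||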
| class SummmaryProcessor: | |||||
| def __init__(self, task, data_dir, tokenizer): | |||||
| self.task = task | |||||
| self.data_dir = data_dir | |||||
| self.tokenizer = tokenizer | |||||
| def create_examples(self, split): | |||||
| if split == 'train': | |||||
| filename = 'train' | |||||
| elif split == 'dev': | |||||
| filename = 'val' | |||||
| elif split == 'test': | |||||
| filename = 'test' | |||||
| else: | |||||
| raise NotImplementedError(split) | |||||
| print_rank_0( | |||||
| f'Creating {self.task}-{split} dataset from {self.data_dir}') | |||||
| if self.task == 'gigaword': | |||||
| detokenizer = gigaword_detokenize | |||||
| elif self.task == 'cnn_dm': | |||||
| detokenizer = cnndm_detokenize | |||||
| else: | |||||
| detokenizer = None | |||||
| source_texts, target_texts = [], [] | |||||
| with open( | |||||
| os.path.join(self.data_dir, f'{filename}.source'), | |||||
| encoding='utf-8') as file: | |||||
| for line in file: | |||||
| line = line.strip() | |||||
| line = punctuation_standardization(line) | |||||
| line = detokenizer(line) if detokenizer else line | |||||
| source_texts.append(line) | |||||
| with open( | |||||
| os.path.join(self.data_dir, f'{filename}.target'), | |||||
| encoding='utf-8') as file: | |||||
| for line in file: | |||||
| line = line.strip() | |||||
| line = punctuation_standardization(line) | |||||
| line = detokenizer( | |||||
| line, is_target=True) if detokenizer else line | |||||
| target_texts.append(line) | |||||
| assert len(source_texts) == len(target_texts) | |||||
| example_list = [] | |||||
| for idx, (source_text, | |||||
| target_text) in enumerate(zip(source_texts, target_texts)): | |||||
| if (idx + 1) % 20000 == 0: | |||||
| print_rank_0(f'Complete {idx + 1} examples') | |||||
| guid = '%s-%s' % (split, idx) | |||||
| meta = { | |||||
| 'ref': | |||||
| self.tokenizer.DecodeIds( | |||||
| self.tokenizer.EncodeAsIds(target_text).tokenization) | |||||
| } | |||||
| example = InputExample( | |||||
| guid=guid, text_a=source_text, text_b=target_text, meta=meta) | |||||
| if idx < 10: | |||||
| print_rank_0( | |||||
| (source_text.encode('utf-8'), target_text.encode('utf-8'), | |||||
| meta['ref'].encode('utf-8'))) | |||||
| example_list.append(example) | |||||
| return example_list | |||||
| class SQuADProcessor: | |||||
| def __init__(self, data_dir, tokenizer): | |||||
| self.data_dir = data_dir | |||||
| self.tokenizer = tokenizer | |||||
| def create_examples(self, split): | |||||
| if split == 'train': | |||||
| filename = 'train.json' | |||||
| elif split == 'dev': | |||||
| filename = 'dev.json' | |||||
| elif split == 'test': | |||||
| filename = 'test.json' | |||||
| else: | |||||
| raise NotImplementedError(split) | |||||
| print_rank_0(f'Creating SQuAD-{split} dataset from {self.data_dir}') | |||||
| example_list = [] | |||||
| idx = 0 | |||||
| with open( | |||||
| os.path.join(self.data_dir, filename), | |||||
| encoding='utf-8') as file: | |||||
| dataset = json.load(file) | |||||
| for paragraphs in dataset: | |||||
| for paragraph in paragraphs['paragraphs']: | |||||
| context = paragraph['context'] | |||||
| for qa in paragraph['qas']: | |||||
| question = qa['question'] | |||||
| answers = {answer['text'] for answer in qa['answers']} | |||||
| answer_starts = { | |||||
| answer['text']: answer['answer_start'] | |||||
| for answer in qa['answers'] | |||||
| } | |||||
| for answer in answers: | |||||
| guid = '%s-%s' % (split, idx) | |||||
| meta = { | |||||
| 'answer_start': | |||||
| answer_starts[answer], | |||||
| 'answer': | |||||
| answer, | |||||
| 'question': | |||||
| question, | |||||
| 'ref': | |||||
| self.tokenizer.DecodeIds( | |||||
| self.tokenizer.EncodeAsIds( | |||||
| question).tokenization) | |||||
| } | |||||
| example = InputExample( | |||||
| guid=guid, text_a=context, meta=meta) | |||||
| if idx < 10: | |||||
| print_rank_0((context.encode('utf-8'), | |||||
| answer.encode('utf-8'), | |||||
| meta['ref'].encode('utf-8'))) | |||||
| example_list.append(example) | |||||
| idx += 1 | |||||
| print_rank_0(f'Creating {len(example_list)} examples for {split}') | |||||
| return example_list | |||||
| class XSumProcessor: | |||||
| def __init__(self, data_dir, tokenizer): | |||||
| self.data_dir = data_dir | |||||
| self.tokenizer = tokenizer | |||||
| def create_examples(self, split): | |||||
| if split == 'train': | |||||
| key = 'train' | |||||
| elif split == 'dev': | |||||
| key = 'validation' | |||||
| elif split == 'test': | |||||
| key = 'test' | |||||
| else: | |||||
| raise NotImplementedError(split) | |||||
| print_rank_0(f'Creating XSUM-{split} dataset from {self.data_dir}') | |||||
| with open( | |||||
| os.path.join( | |||||
| self.data_dir, | |||||
| 'XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json')) as file: | |||||
| id_list = json.load(file) | |||||
| id_list = id_list[key] | |||||
| source_texts, target_texts = [], [] | |||||
| for i, idx in enumerate(id_list): | |||||
| with open(os.path.join(self.data_dir, f'{idx}.summary')) as file: | |||||
| key, sentences = None, [] | |||||
| source_text, target_text = None, None | |||||
| for line in file: | |||||
| line = line.strip() | |||||
| if line.startswith('[SN]'): | |||||
| if key is not None: | |||||
| if key == 'RESTBODY': | |||||
| source_text = ' '.join(sentences) | |||||
| elif key == 'FIRST-SENTENCE': | |||||
| target_text = ' '.join(sentences) | |||||
| key = line[4:-4] | |||||
| sentences = [] | |||||
| elif line: | |||||
| sentences.append(line) | |||||
| if key is not None: | |||||
| if key == 'RESTBODY': | |||||
| source_text = ' '.join(sentences) | |||||
| elif key == 'FIRST-SENTENCE': | |||||
| target_text = ' '.join(sentences) | |||||
| source_texts.append(source_text) | |||||
| target_texts.append(target_text) | |||||
| if (i + 1) % 1000 == 0: | |||||
| print_rank_0(f'Complete {i + 1} examples') | |||||
| assert len(source_texts) == len(target_texts) | |||||
| example_list = [] | |||||
| for idx, (source_text, | |||||
| target_text) in enumerate(zip(source_texts, target_texts)): | |||||
| if (idx + 1) % 20000 == 0: | |||||
| print_rank_0(f'Complete {idx + 1} examples') | |||||
| guid = '%s-%s' % (split, idx) | |||||
| meta = { | |||||
| 'ref': | |||||
| self.tokenizer.DecodeIds( | |||||
| self.tokenizer.EncodeAsIds(target_text).tokenization) | |||||
| } | |||||
| example = InputExample( | |||||
| guid=guid, text_a=source_text, text_b=target_text, meta=meta) | |||||
| if idx < 10: | |||||
| print_rank_0( | |||||
| (source_text.encode('utf-8'), target_text.encode('utf-8'), | |||||
| meta['ref'].encode('utf-8'))) | |||||
| example_list.append(example) | |||||
| return example_list | |||||
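| # Seq2SeqDataset encodes a summarization source as | |||||
| #   [CLS] [sMASK|MASK] ' Content:' + source_tokens   (padded to max_src_length) | |||||
| # and, for training, appends [sop] + target with block position ids so the | |||||
| # generated span is attached at the mask position; 'attention_mask' holds the | |||||
| # separator index (length of the source block) rather than a full matrix. | |||||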
| class Seq2SeqDataset(torch.utils.data.Dataset): | |||||
| def __init__(self, args, split, tokenizer): | |||||
| self.args = args | |||||
| self.task, self.data_dir = args.task.lower(), args.data_dir | |||||
| self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length | |||||
| self.split = split | |||||
| self.tokenizer = tokenizer | |||||
| self.dataset_name = split | |||||
| if self.task in ['gigaword', 'cnn_dm', 'cnn_dm_original']: | |||||
| self.processor = SummmaryProcessor(self.task, self.data_dir, | |||||
| tokenizer) | |||||
| elif self.task in ['xsum']: | |||||
| self.processor = XSumProcessor(self.data_dir, tokenizer) | |||||
| elif self.task in ['squad_generation']: | |||||
| self.processor = SQuADProcessor(self.data_dir, tokenizer) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| example_list = self.processor.create_examples(split) | |||||
| self.example_list = example_list | |||||
| self.examples = {example.guid: example for example in example_list} | |||||
| print_rank_0(f'Return {len(self.examples)} {split} examples') | |||||
| def __len__(self): | |||||
| return len(self.example_list) | |||||
| def __getitem__(self, idx): | |||||
| example = self.example_list[idx] | |||||
| cls_id = self.tokenizer.get_command('ENC').Id | |||||
| mask_token = 'sMASK' if self.args.task_mask else 'MASK' | |||||
| mask_id = self.tokenizer.get_command(mask_token).Id | |||||
| pad_id = self.tokenizer.get_command('pad').Id | |||||
| sop_id = self.tokenizer.get_command('sop').Id | |||||
| eop_id = self.tokenizer.get_command('eop').Id | |||||
| if self.task in ['gigaword', 'cnn_dm', 'cnn_dm_original', 'xsum']: | |||||
| source_text, target_text = example.text_a, example.text_b | |||||
| source_tokens = self.tokenizer.EncodeAsIds( | |||||
| ' ' + source_text).tokenization | |||||
| prompt = [cls_id, mask_id | |||||
| ] + self.tokenizer.EncodeAsIds(' Content:').tokenization | |||||
| if len(source_tokens) > self.max_src_length - len(prompt): | |||||
| source_tokens = source_tokens[:self.max_src_length | |||||
| - len(prompt)] | |||||
| source_tokens = prompt + source_tokens | |||||
| elif self.task == 'squad_generation': | |||||
| source_text = example.text_a | |||||
| target_text, answer = example.meta['question'], example.meta[ | |||||
| 'answer'] | |||||
| source_tokens = self.tokenizer.EncodeAsIds( | |||||
| source_text.rstrip() + ' Question:').tokenization | |||||
| answer_tokens = self.tokenizer.EncodeAsIds(' Answer: ' | |||||
| + answer).tokenization | |||||
| if len(source_tokens | |||||
| ) > self.max_src_length - len(answer_tokens) - 2: | |||||
| max_src_length = self.max_src_length - len(answer_tokens) - 2 | |||||
| answer_pattern = self.tokenizer.EncodeAsIds( | |||||
| ' ' + answer).tokenization | |||||
| def sub_finder(mylist, pattern): | |||||
| matches = [] | |||||
| for i in range(len(mylist)): | |||||
| if mylist[i] == pattern[0] and mylist[ | |||||
| i:i + len(pattern)] == pattern: | |||||
| matches.append(i) | |||||
| return matches | |||||
| answer_indices = sub_finder(source_tokens, answer_pattern) | |||||
| if len(answer_indices) == 0: | |||||
| print(f'Answer {answer} does not exist in the source text') | |||||
| source_tokens = source_tokens[:max_src_length] | |||||
| else: | |||||
| start_index = max(answer_indices[0] - max_src_length // 2, | |||||
| 0) | |||||
| source_tokens = source_tokens[start_index:start_index | |||||
| + max_src_length] | |||||
| source_tokens = [cls_id] + source_tokens + [mask_id | |||||
| ] + answer_tokens | |||||
| else: | |||||
| raise NotImplementedError | |||||
| if len(source_tokens) < self.max_src_length: | |||||
| source_tokens = source_tokens + [pad_id] * ( | |||||
| self.max_src_length - len(source_tokens)) | |||||
| sep = len(source_tokens) | |||||
| position_ids = list(range(len(source_tokens))) | |||||
| block_position_ids = [0] * len(source_tokens) | |||||
| mask_pos = source_tokens.index(mask_id) | |||||
| if self.split == 'train': | |||||
| target_tokens = self.tokenizer.EncodeAsIds( | |||||
| ' ' + target_text).tokenization | |||||
| target_tokens = target_tokens + [eop_id] | |||||
| if len(target_tokens) > self.max_tgt_length: | |||||
| target_tokens = target_tokens[:self.max_tgt_length] | |||||
| loss_mask = [1] * len(target_tokens) | |||||
| if len(target_tokens) < self.max_tgt_length: | |||||
| loss_mask += [0] * (self.max_tgt_length - len(target_tokens)) | |||||
| target_tokens += [pad_id] * ( | |||||
| self.max_tgt_length - len(target_tokens)) | |||||
| tokens = source_tokens + [sop_id] + target_tokens[:-1] | |||||
| loss_mask = [0] * len(source_tokens) + loss_mask | |||||
| target_ids = [0] * len(source_tokens) + target_tokens | |||||
| position_ids += [mask_pos] * len(target_tokens) | |||||
| if self.args.no_block_position: | |||||
| block_position_ids += [1] * len(target_tokens) | |||||
| else: | |||||
| block_position_ids += list(range(1, len(target_tokens) + 1)) | |||||
| position_ids = [position_ids, block_position_ids] | |||||
| sample = { | |||||
| 'text': np.array(tokens, dtype=np.int64), | |||||
| 'target': np.array(target_ids, dtype=np.int64), | |||||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||||
| 'loss_mask': np.array(loss_mask, dtype=np.int64), | |||||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||||
| 'uid': example.guid | |||||
| } | |||||
| else: | |||||
| tokens = source_tokens + [sop_id] | |||||
| position_ids = position_ids + [mask_pos] | |||||
| block_position_ids = block_position_ids + [1] | |||||
| position_ids = [position_ids, block_position_ids] | |||||
| sample = { | |||||
| 'text': np.array(tokens, dtype=np.int64), | |||||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||||
| 'uid': example.guid | |||||
| } | |||||
| return sample | |||||
| class ExtractionDataset(torch.utils.data.Dataset): | |||||
| def __init__(self, args, split, tokenizer): | |||||
| self.args = args | |||||
| task, data_dir = args.task.lower(), args.data_dir | |||||
| self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length | |||||
| self.split = split | |||||
| self.tokenizer = tokenizer | |||||
| if split == 'train': | |||||
| filename = 'train' | |||||
| elif split == 'dev': | |||||
| filename = 'valid' | |||||
| elif split == 'test': | |||||
| filename = 'test' | |||||
| else: | |||||
| raise NotImplementedError(split) | |||||
| print_rank_0(f'Creating {task}-{split} dataset from {data_dir}') | |||||
| self.dataset_name = split | |||||
| source_texts, target_texts = [], [] | |||||
| with open( | |||||
| os.path.join(data_dir, f'{filename}.source'), | |||||
| encoding='utf-8') as file: | |||||
| for line in file: | |||||
| line = line.strip() | |||||
| source_texts.append(line) | |||||
| with open( | |||||
| os.path.join(data_dir, f'{filename}.target'), | |||||
| encoding='utf-8') as file: | |||||
| for line in file: | |||||
| line = line.strip() | |||||
| target_texts.append(line) | |||||
| self.examples, self.example_list = {}, [] | |||||
| for idx, (source_text, | |||||
| target_text) in enumerate(zip(source_texts, target_texts)): | |||||
| if (idx + 1) % 20000 == 0: | |||||
| print_rank_0(f'Complete {idx + 1} examples') | |||||
| guid = '%s-%s' % (split, idx) | |||||
| meta = {'ref': target_text} | |||||
| example = InputExample( | |||||
| guid=guid, text_a=source_text, text_b=target_text, meta=meta) | |||||
| self.examples[guid] = example | |||||
| self.example_list.append(example) | |||||
| print_rank_0(f'Return {len(self.examples)} {split} examples') | |||||
| def __len__(self): | |||||
| return len(self.example_list) | |||||
| def __getitem__(self, idx): | |||||
| example = self.example_list[idx] | |||||
| source_text, target_text = example.text_a, example.text_b | |||||
| mask_token = 'MASK' | |||||
| mask_id = self.tokenizer.get_command(mask_token).Id | |||||
| sop_id = self.tokenizer.get_command('sop').Id | |||||
| eop_id = self.tokenizer.get_command('eop').Id | |||||
| pad_id = self.tokenizer.get_command('pad').Id | |||||
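| # pad_to: truncate or right-pad a token id list to exactly max_len. | |||||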
| def pad_to(text, max_len, pad_id): | |||||
| if len(text) > max_len: | |||||
| text = text[:max_len] | |||||
| else: | |||||
| text = text + [pad_id] * (max_len - len(text)) | |||||
| return text | |||||
| source_tokens = self.tokenizer.EncodeAsIds(source_text).tokenization | |||||
| masked_tgt = target_text.split('|') | |||||
| source_tokens = pad_to(source_tokens, self.max_src_length, pad_id) | |||||
| sep = len(source_tokens) | |||||
| position_ids = list(range(len(source_tokens))) | |||||
| block_position_ids = [0] * len(source_tokens) | |||||
| if self.split == 'train': | |||||
| mask_positions = [ | |||||
| i for i, x in enumerate(source_tokens) if x == mask_id | |||||
| ] | |||||
| assert len(mask_positions) <= len(masked_tgt) | |||||
| tokens = source_tokens | |||||
| target_ids = [0] * len(source_tokens) | |||||
| loss_mask = [0] * len(source_tokens) | |||||
| for i, mask_pos in enumerate(mask_positions): | |||||
| tgt_text = masked_tgt[i] | |||||
| tgt_tokens = self.tokenizer.EncodeAsIds( | |||||
| ' ' + tgt_text).tokenization | |||||
| tokens += [sop_id] + tgt_tokens | |||||
| target_ids += tgt_tokens + [eop_id] | |||||
| loss_mask += [1] * (len(tgt_tokens) + 1) | |||||
| position_ids += [mask_pos] * (len(tgt_tokens) + 1) | |||||
| block_position_ids += [ | |||||
| i + 1 for i in range(len(tgt_tokens) + 1) | |||||
| ] | |||||
| tokens = pad_to(tokens, self.max_src_length + self.max_tgt_length, | |||||
| pad_id) | |||||
| target_ids = pad_to(target_ids, | |||||
| self.max_src_length + self.max_tgt_length, | |||||
| pad_id) | |||||
| loss_mask = pad_to(loss_mask, | |||||
| self.max_src_length + self.max_tgt_length, 0) | |||||
| position_ids = pad_to(position_ids, | |||||
| self.max_src_length + self.max_tgt_length, 0) | |||||
| block_position_ids = pad_to( | |||||
| block_position_ids, self.max_src_length + self.max_tgt_length, | |||||
| 0) | |||||
| position_ids = [position_ids, block_position_ids] | |||||
| sample = { | |||||
| 'text': np.array(tokens, dtype=np.int64), | |||||
| 'target': np.array(target_ids, dtype=np.int64), | |||||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||||
| 'loss_mask': np.array(loss_mask, dtype=np.int64), | |||||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||||
| 'uid': example.guid | |||||
| } | |||||
| else: | |||||
| tokens = source_tokens + [sop_id] | |||||
| mask_pos = source_tokens.index(mask_id) | |||||
| position_ids = position_ids + [mask_pos] | |||||
| block_position_ids = block_position_ids + [1] | |||||
| position_ids = [position_ids, block_position_ids] | |||||
| sample = { | |||||
| 'text': np.array(tokens, dtype=np.int64), | |||||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||||
| 'uid': example.guid | |||||
| } | |||||
| return sample | |||||
| class BlankLMDataset(torch.utils.data.Dataset): | |||||
| def __init__(self, args, split, tokenizer): | |||||
| self.args = args | |||||
| task, data_dir = args.task.lower(), args.data_dir | |||||
| self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length | |||||
| self.split = split | |||||
| assert args.tokenizer_type == 'BertWordPieceTokenizer' | |||||
| self.tokenizer = tokenizer | |||||
| if split == 'train': | |||||
| filename = 'train' | |||||
| elif split == 'dev': | |||||
| filename = 'valid' | |||||
| elif split == 'test': | |||||
| filename = 'test' | |||||
| else: | |||||
| raise NotImplementedError(split) | |||||
| print_rank_0(f'Creating {task}-{split} dataset from {data_dir}') | |||||
| self.dataset_name = split | |||||
| detokenizer = blanklm_detokenize | |||||
| source_texts, target_texts = [], [] | |||||
| with open( | |||||
| os.path.join(data_dir, f'{filename}.txt'), | |||||
| encoding='utf-8') as file: | |||||
| for line in file: | |||||
| line = line.strip() | |||||
| line = detokenizer(line) if detokenizer else line | |||||
| target_texts.append(line) | |||||
| if split == 'test': | |||||
| with open( | |||||
| os.path.join( | |||||
| data_dir, | |||||
| f'blank/test.maskratio{args.blank_maskratio:.1f}.blank' | |||||
| ), | |||||
| encoding='utf-8') as file: | |||||
| for line in file: | |||||
| line = line.strip() | |||||
| line = detokenizer(line) if detokenizer else line | |||||
| source_texts.append(line) | |||||
| else: | |||||
| source_texts = target_texts | |||||
| self.examples, self.example_list = {}, [] | |||||
| for idx, (source_text, | |||||
| target_text) in enumerate(zip(source_texts, target_texts)): | |||||
| # if idx > 10000: | |||||
| # break | |||||
| if (idx + 1) % 20000 == 0: | |||||
| print_rank_0(f'Complete {idx + 1} examples') | |||||
| guid = '%s-%s' % (split, idx) | |||||
| meta = {'ref': target_text} | |||||
| example = InputExample( | |||||
| guid=guid, text_a=source_text, text_b=target_text, meta=meta) | |||||
| self.examples[guid] = example | |||||
| self.example_list.append(example) | |||||
| print_rank_0(f'Return {len(self.examples)} {split} examples') | |||||
| self.random = random.Random(args.seed) | |||||
| def __len__(self): | |||||
| return len(self.example_list) | |||||
| def __getitem__(self, idx): | |||||
| example = self.example_list[idx] | |||||
| source_text, target_text = example.text_a, example.text_b # noqa | |||||
| mask_token = 'gMASK' if self.args.task_mask else 'MASK' | |||||
| mask_id = self.tokenizer.get_command(mask_token).Id | |||||
| sop_id = self.tokenizer.get_command('sop').Id | |||||
| eop_id = self.tokenizer.get_command('eop').Id | |||||
| pad_id = self.tokenizer.get_command('pad').Id | |||||
| if self.split in ['train', 'dev']: | |||||
| masked_src, masked_tgt = self.mask_text(source_text) | |||||
| source_text = masked_src | |||||
| def pad_to(text, max_len, pad_id): | |||||
| if len(text) > max_len: | |||||
| text = text[:max_len] | |||||
| else: | |||||
| text = text + [pad_id] * (max_len - len(text)) | |||||
| return text | |||||
| source_tokens = self.tokenizer.EncodeAsIds(' ' | |||||
| + source_text).tokenization | |||||
| source_tokens = pad_to(source_tokens, self.max_src_length, pad_id) | |||||
| sep = len(source_tokens) | |||||
| position_ids = list(range(len(source_tokens))) | |||||
| block_position_ids = [0] * len(source_tokens) | |||||
| if self.split in ['train', 'dev']: | |||||
| mask_positions = [ | |||||
| i for i, x in enumerate(source_tokens) if x == mask_id | |||||
| ] | |||||
| assert len(mask_positions) <= len(masked_tgt) | |||||
| tokens = source_tokens | |||||
| target_ids = [0] * len(source_tokens) | |||||
| loss_mask = [0] * len(source_tokens) | |||||
| for i, mask_pos in enumerate(mask_positions): | |||||
| tgt_text = masked_tgt[i] | |||||
| tgt_tokens = self.tokenizer.EncodeAsIds( | |||||
| ' ' + tgt_text).tokenization | |||||
| tokens += [sop_id] + tgt_tokens | |||||
| target_ids += tgt_tokens + [eop_id] | |||||
| loss_mask += [1] * (len(tgt_tokens) + 1) | |||||
| position_ids += [mask_pos] * (len(tgt_tokens) + 1) | |||||
| block_position_ids += [ | |||||
| i + 1 for i in range(len(tgt_tokens) + 1) | |||||
| ] | |||||
| max_length = self.max_src_length + int( | |||||
| self.max_src_length * self.args.blank_maskratio) | |||||
| tokens = pad_to(tokens, max_length, pad_id) | |||||
| target_ids = pad_to(target_ids, max_length, pad_id) | |||||
| loss_mask = pad_to(loss_mask, max_length, 0) | |||||
| position_ids = pad_to(position_ids, max_length, 0) | |||||
| block_position_ids = pad_to(block_position_ids, max_length, 0) | |||||
| position_ids = [position_ids, block_position_ids] | |||||
| sample = { | |||||
| 'text': np.array(tokens, dtype=np.int64), | |||||
| 'target': np.array(target_ids, dtype=np.int64), | |||||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||||
| 'loss_mask': np.array(loss_mask, dtype=np.int64), | |||||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||||
| 'uid': example.guid | |||||
| } | |||||
| else: | |||||
| tokens = source_tokens + [sop_id] | |||||
| mask_pos = source_tokens.index(mask_id) | |||||
| position_ids = position_ids + [mask_pos] | |||||
| block_position_ids = block_position_ids + [1] | |||||
| position_ids = [position_ids, block_position_ids] | |||||
| sample = { | |||||
| 'text': np.array(tokens, dtype=np.int64), | |||||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||||
| 'uid': example.guid | |||||
| } | |||||
| return sample | |||||
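| # mask_text: randomly pick a blank_maskratio fraction of whitespace-separated tokens, | |||||
| # replace them with [MASK] (merging consecutive masks into one blank) and return the | |||||
| # masked source together with the list of blanked-out target spans. | |||||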
| def mask_text(self, text): | |||||
| tokens = text.split() | |||||
| mask_ratio = self.args.blank_maskratio | |||||
| n = len(tokens) | |||||
| indices = sorted(self.random.sample(range(n), int(n * mask_ratio))) | |||||
| masked_src, masked_tgt = '', [] | |||||
| for i, idx in enumerate(indices): | |||||
| if i == 0 or idx != indices[i - 1] + 1: | |||||
| masked_tgt.append('') | |||||
| masked_tgt[-1] += ' ' + tokens[idx] | |||||
| tokens[idx] = '[MASK]' | |||||
| for i, token in enumerate(tokens): | |||||
| if i != 0 and token == '[MASK]' and tokens[i - 1] == '[MASK]': | |||||
| continue | |||||
| masked_src += ' ' + token | |||||
| return masked_src, masked_tgt | |||||
| @@ -0,0 +1,538 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import datetime | |||||
| import random | |||||
| import string | |||||
| import mpu | |||||
| import torch | |||||
| import torch.nn.functional as F | |||||
| from generation_utils import (BeamSearchScorer, LogitsProcessorList, | |||||
| MinLengthLogitsProcessor, | |||||
| NoRepeatNGramLogitsProcessor) | |||||
| from rouge_score import rouge_scorer | |||||
| from utils import print_rank_0 | |||||
| def _is_digit(w): | |||||
| for ch in w: | |||||
| if not (ch.isdigit() or ch == ','): | |||||
| return False | |||||
| return True | |||||
| gigaword_tok_dict = { | |||||
| '(': '-lrb-', | |||||
| ')': '-rrb-', | |||||
| '[': '-lsb-', | |||||
| ']': '-rsb-', | |||||
| '{': '-lcb-', | |||||
| '}': '-rcb-', | |||||
| '[UNK]': 'UNK', | |||||
| '&': '&', | |||||
| '<': '<', | |||||
| '>': '>' | |||||
| } | |||||
| cnndm_tok_dict = { | |||||
| '(': '-LRB-', | |||||
| ')': '-RRB-', | |||||
| '[': '-LSB-', | |||||
| ']': '-RSB-', | |||||
| '{': '-LCB-', | |||||
| '}': '-RCB-' | |||||
| } | |||||
| def fix_tokenization(text, dataset): | |||||
| if dataset == 'cnn_dm_org': | |||||
| return text | |||||
| if dataset == 'gigaword': | |||||
| text = text.replace('[UNK]', 'UNK') | |||||
| return text | |||||
| input_tokens = text.split() | |||||
| output_tokens = [] | |||||
| has_left_quote = False | |||||
| has_left_single_quote = False | |||||
| i = 0 | |||||
| prev_dash = False | |||||
| while i < len(input_tokens): | |||||
| tok = input_tokens[i] | |||||
| flag_prev_dash = False | |||||
| if tok == "\"": | |||||
| if has_left_quote: | |||||
| output_tokens.append("''") | |||||
| else: | |||||
| output_tokens.append('``') | |||||
| has_left_quote = not has_left_quote | |||||
| i += 1 | |||||
| elif tok == "'" and len( | |||||
| output_tokens) > 0 and output_tokens[-1].endswith( | |||||
| 'n') and i < len(input_tokens) - 1 and input_tokens[ | |||||
| i + 1] == 't': # noqa | |||||
| output_tokens[-1] = output_tokens[-1][:-1] | |||||
| output_tokens.append("n't") | |||||
| i += 2 | |||||
| elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[ | |||||
| i + 1] in ('s', 'd', 'll'): | |||||
| output_tokens.append("'" + input_tokens[i + 1]) | |||||
| i += 2 | |||||
| elif tok == "'": | |||||
| if has_left_single_quote: | |||||
| output_tokens.append("'") | |||||
| else: | |||||
| output_tokens.append('`') | |||||
| has_left_single_quote = not has_left_single_quote | |||||
| i += 1 | |||||
| elif tok == '.' and i < len(input_tokens) - 2 and input_tokens[ | |||||
| i + 1] == '.' and input_tokens[i + 2] == '.': | |||||
| output_tokens.append('...') | |||||
| i += 3 | |||||
| elif tok == ',' and len(output_tokens) > 0 and _is_digit( | |||||
| output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit( | |||||
| input_tokens[i + 1]): | |||||
| # $ 3 , 000 -> $ 3,000 | |||||
| output_tokens[-1] += ',' + input_tokens[i + 1] | |||||
| i += 2 | |||||
| elif tok == '.' and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and \ | |||||
| input_tokens[i + 1].isdigit(): | |||||
| # 3 . 03 -> 3.03 | |||||
| output_tokens[-1] += '.' + input_tokens[i + 1] | |||||
| i += 2 | |||||
| elif tok == '.' and len(output_tokens) > 0 and len( | |||||
| output_tokens[-1]) == 1 and output_tokens[-1].isalpha( # noqa | |||||
| ) and i < len(input_tokens) - 2 and len( # noqa | |||||
| input_tokens[i + 1]) == 1 and input_tokens[ | |||||
| i + 1].isalpha( # noqa | |||||
| ) and input_tokens[i + 2] == '.': # noqa | |||||
| # U . N . -> U.N. | |||||
| k = i + 3 | |||||
| while k + 2 < len(input_tokens): | |||||
| if len(input_tokens[k + 1]) == 1 and input_tokens[ | |||||
| k + 1].isalpha() and input_tokens[k + 2] == '.': | |||||
| k += 2 | |||||
| else: | |||||
| break | |||||
| output_tokens[-1] += ''.join(input_tokens[i:k]) | |||||
| i = k | |||||
| elif tok == '-': | |||||
| if i < len(input_tokens) - 1 and input_tokens[i + 1] == '-': | |||||
| output_tokens.append('--') | |||||
| i += 2 | |||||
| elif i == len(input_tokens) - 1 or i == 0: | |||||
| output_tokens.append('-') | |||||
| i += 1 | |||||
| elif output_tokens[-1] not in string.punctuation and input_tokens[ | |||||
| i + 1][0] not in string.punctuation: | |||||
| output_tokens[-1] += '-' | |||||
| i += 1 | |||||
| flag_prev_dash = True | |||||
| else: | |||||
| output_tokens.append('-') | |||||
| i += 1 | |||||
| elif prev_dash and len( | |||||
| output_tokens) > 0 and tok[0] not in string.punctuation: | |||||
| output_tokens[-1] += tok | |||||
| i += 1 | |||||
| else: | |||||
| output_tokens.append(tok) | |||||
| i += 1 | |||||
| prev_dash = flag_prev_dash | |||||
| return ' '.join(output_tokens) | |||||
| def count_tokens(tokens): | |||||
| counter = {} | |||||
| for t in tokens: | |||||
| if t in counter.keys(): | |||||
| counter[t] += 1 | |||||
| else: | |||||
| counter[t] = 1 | |||||
| return counter | |||||
| def get_f1(text_a, text_b): | |||||
| tokens_a = text_a.lower().split() | |||||
| tokens_b = text_b.lower().split() | |||||
| if len(tokens_a) == 0 or len(tokens_b) == 0: | |||||
| return 1 if len(tokens_a) == len(tokens_b) else 0 | |||||
| set_a = count_tokens(tokens_a) | |||||
| set_b = count_tokens(tokens_b) | |||||
| match = 0 | |||||
| for token in set_a.keys(): | |||||
| if token in set_b.keys(): | |||||
| match += min(set_a[token], set_b[token]) | |||||
| p = match / len(tokens_a) | |||||
| r = match / len(tokens_b) | |||||
| return 2.0 * p * r / (p + r + 1e-5) | |||||
| def remove_duplicate(l_list, duplicate_rate): | |||||
| tk_list = [l.lower().split() for l in l_list] # noqa | |||||
| r_list = [] | |||||
| history_set = set() | |||||
| for i, w_list in enumerate(tk_list): | |||||
| w_set = set(w_list) | |||||
| if len(w_set & history_set) / len(w_set) <= duplicate_rate: | |||||
| r_list.append(l_list[i]) | |||||
| history_set |= w_set | |||||
| return r_list | |||||
| def rouge_metric(predictions, | |||||
| labels, | |||||
| examples, | |||||
| metric='rouge-1', | |||||
| duplicate_rate=0.7, | |||||
| dataset='cnn_dm'): | |||||
| metric_dict = { | |||||
| 'rouge-1': 'rouge1', | |||||
| 'rouge-2': 'rouge2', | |||||
| 'rouge-l': 'rougeLsum' | |||||
| } | |||||
| refs = [example.meta['ref'] for example in examples] | |||||
| ref_list = [] | |||||
| for ref in refs: | |||||
| ref = ref.strip().split('[SEP]') | |||||
| ref = [fix_tokenization(sentence, dataset=dataset) for sentence in ref] | |||||
| ref = '\n'.join(ref) | |||||
| ref_list.append(ref) | |||||
| pred_list = [] | |||||
| for prediction in predictions: | |||||
| buf = [] | |||||
| for sentence in prediction.strip().split('[SEP]'): | |||||
| sentence = fix_tokenization(sentence, dataset=dataset) | |||||
| if any(get_f1(sentence, s) > 1.0 for s in buf): | |||||
| continue | |||||
| s_len = len(sentence.split()) | |||||
| if s_len <= 4: | |||||
| continue | |||||
| buf.append(sentence) | |||||
| if duplicate_rate and duplicate_rate < 1: | |||||
| buf = remove_duplicate(buf, duplicate_rate) | |||||
| line = '\n'.join(buf) | |||||
| pred_list.append(line) | |||||
| if torch.distributed.get_rank() == 0: | |||||
| import json | |||||
| with open('./results.json', 'w') as output: | |||||
| for ref, pred in zip(ref_list, pred_list): | |||||
| output.write(json.dumps({'ref': ref, 'pred': pred}) + '\n') | |||||
| scorer = rouge_scorer.RougeScorer([metric_dict[metric]], use_stemmer=True) | |||||
| scores = [ | |||||
| scorer.score(pred, ref) for pred, ref in zip(pred_list, ref_list) | |||||
| ] | |||||
| scores = [score[metric_dict[metric]].fmeasure for score in scores] | |||||
| scores = sum(scores) / len(scores) | |||||
| return scores | |||||
| def process_batch(batch, args): | |||||
| """Process batch and produce inputs for the model.""" | |||||
| tokens = batch['text'].long().cuda() | |||||
| attention_mask = batch['attention_mask'].long().cuda() | |||||
| position_ids = batch['position_id'].long().cuda() | |||||
| return tokens, attention_mask, position_ids | |||||
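| # Beam-search based evaluator for generative tasks (e.g. summarization, question generation). | |||||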
| class DecoderEvaluater: | |||||
| def __init__(self, args, tokenizer): | |||||
| self.tokenizer = tokenizer | |||||
| self.start_token = tokenizer.get_command('sop').Id | |||||
| self.end_token = tokenizer.get_command('eop').Id | |||||
| self.mask_token = tokenizer.get_command( | |||||
| 'sMASK').Id if args.task_mask else tokenizer.get_command('MASK').Id | |||||
| self.pad_token = tokenizer.get_command('pad').Id | |||||
| self.processors = LogitsProcessorList() | |||||
| if args.min_tgt_length > 0: | |||||
| processor = MinLengthLogitsProcessor(args.min_tgt_length, | |||||
| self.end_token) | |||||
| self.processors.append(processor) | |||||
| if args.no_repeat_ngram_size > 0: | |||||
| processor = NoRepeatNGramLogitsProcessor(args.no_repeat_ngram_size) | |||||
| self.processors.append(processor) | |||||
| def evaluate(self, model, dataloader, example_dict, args): | |||||
| """Calculate correct over total answers and return prediction if the | |||||
| `output_predictions` is true.""" | |||||
| model.eval() | |||||
| store = torch.distributed.TCPStore(args.master_ip, | |||||
| 18931 + random.randint(0, 10000), | |||||
| mpu.get_data_parallel_world_size(), | |||||
| torch.distributed.get_rank() == 0, | |||||
| datetime.timedelta(seconds=30)) | |||||
| print_rank_0('Distributed store created') | |||||
| with torch.no_grad(): | |||||
| # For all the batches in the dataset. | |||||
| for idx, data in enumerate(dataloader): | |||||
| tokens, attention_mask, position_ids = process_batch( | |||||
| data, args) | |||||
| batch_size = tokens.size(0) | |||||
| beam_scorer = BeamSearchScorer( | |||||
| batch_size=batch_size, | |||||
| max_length=args.out_seq_length, | |||||
| num_beams=args.num_beams, | |||||
| device=tokens.device, | |||||
| length_penalty=args.length_penalty, | |||||
| do_early_stopping=False, | |||||
| ) | |||||
| beam_scores = torch.zeros((batch_size, args.num_beams), | |||||
| dtype=torch.float, | |||||
| device=tokens.device) | |||||
| beam_scores[:, 1:] = -1e9 | |||||
| beam_scores = beam_scores.view((batch_size * args.num_beams, )) | |||||
| # Run the model forward. | |||||
| counter = 0 | |||||
| while counter < args.tgt_seq_length: | |||||
| if counter == 0: | |||||
| next_token_logits, *mems = model( | |||||
| tokens, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| return_memory=True) | |||||
| seq_length = next_token_logits.size(1) | |||||
| next_token_logits = next_token_logits[:, -1] | |||||
| next_token_logits = next_token_logits.unsqueeze( | |||||
| 1).repeat(1, args.num_beams, | |||||
| 1).view(batch_size * args.num_beams, -1) | |||||
| mems = [ | |||||
| mem.unsqueeze(1).repeat( | |||||
| 1, args.num_beams, 1, | |||||
| 1).view(batch_size * args.num_beams, | |||||
| seq_length, -1) for mem in mems | |||||
| ] | |||||
| position_ids = tokens.new_ones(batch_size, | |||||
| args.num_beams, 2, 1) | |||||
| for i, text in enumerate(tokens.tolist()): | |||||
| mask_pos = text.index(self.mask_token) | |||||
| position_ids[i, :, 0] = mask_pos | |||||
| position_ids = position_ids.reshape( | |||||
| batch_size * args.num_beams, 2, 1) | |||||
| tokens = tokens.new_zeros(batch_size * args.num_beams, | |||||
| 0) | |||||
| attention_mask = tokens.new_zeros( | |||||
| [batch_size * args.num_beams]) | |||||
| else: | |||||
| if not args.no_block_position: | |||||
| position_ids[:, 1] = counter + 1 | |||||
| last_token = tokens[:, -1:] | |||||
| next_token_logits, *mems = model( | |||||
| last_token, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| *mems, | |||||
| return_memory=True) | |||||
| next_token_logits = next_token_logits[:, -1] | |||||
| next_token_scores = F.log_softmax( | |||||
| next_token_logits, dim=-1) | |||||
| next_token_scores = self.processors( | |||||
| tokens, next_token_scores) | |||||
| next_token_scores = next_token_scores + beam_scores[:, None].expand_as( | |||||
| next_token_scores) | |||||
| vocab_size = next_token_scores.shape[-1] | |||||
| next_token_scores = next_token_scores.view( | |||||
| batch_size, args.num_beams * vocab_size) | |||||
| probs = F.softmax(next_token_scores, dim=-1) | |||||
| if args.select_topk: | |||||
| _, next_tokens = torch.topk( | |||||
| probs, k=2 * args.num_beams, dim=-1, largest=True) | |||||
| else: | |||||
| next_tokens = torch.multinomial( | |||||
| probs, num_samples=2 * args.num_beams) | |||||
| next_token_scores = torch.gather(next_token_scores, -1, | |||||
| next_tokens) | |||||
| next_token_scores, _indices = torch.sort( | |||||
| next_token_scores, descending=True, dim=1) | |||||
| next_tokens = torch.gather(next_tokens, -1, _indices) | |||||
| next_indices = next_tokens // vocab_size | |||||
| next_tokens = next_tokens % vocab_size | |||||
| # stateless | |||||
| beam_outputs = beam_scorer.process( | |||||
| tokens, | |||||
| next_token_scores, | |||||
| next_tokens, | |||||
| next_indices, | |||||
| eos_token_id=self.end_token, | |||||
| pad_token_id=self.pad_token) | |||||
| beam_scores = beam_outputs['next_beam_scores'] | |||||
| beam_next_tokens = beam_outputs['next_beam_tokens'] | |||||
| beam_idx = beam_outputs['next_beam_indices'] | |||||
| beam_next_tokens = beam_next_tokens.unsqueeze(-1) | |||||
| tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], | |||||
| dim=-1) | |||||
| mems = [mem[beam_idx] for mem in mems] if mems else [] | |||||
| if beam_scorer.is_done: | |||||
| break | |||||
| counter += 1 | |||||
| tokens, _ = beam_scorer.finalize( | |||||
| tokens, | |||||
| beam_scores, | |||||
| next_tokens, | |||||
| next_indices, | |||||
| eos_token_id=self.end_token, | |||||
| pad_token_id=self.pad_token) | |||||
| predictions = [] | |||||
| for text in tokens.tolist(): | |||||
| text = [ | |||||
| token for token in text | |||||
| if token not in [self.end_token, self.pad_token] | |||||
| ] | |||||
| text = self.tokenizer.DecodeIds(text) | |||||
| predictions.append(text) | |||||
| uid_list = data['uid'] | |||||
| if isinstance(uid_list, torch.Tensor): | |||||
| uid_list = uid_list.cpu().numpy().tolist() | |||||
| for uid, prediction in zip(uid_list, predictions): | |||||
| store.set(uid, prediction) | |||||
| if (idx + 1) % args.log_interval == 0: | |||||
| print_rank_0(f'Iteration {idx + 1} / {len(dataloader)}') | |||||
| model.train() | |||||
| torch.distributed.barrier() | |||||
| print_rank_0('Evaluation completed') | |||||
| predictions, examples = [], [] | |||||
| for uid, example in example_dict.items(): | |||||
| predictions.append(store.get(uid).decode('utf-8')) | |||||
| examples.append(example) | |||||
| torch.distributed.barrier() | |||||
| return predictions, [], examples | |||||
| def blanklm_fix_tokenization(text): | |||||
| text = text.replace('` `', '``') | |||||
| text = text.replace("\' \'", "\'\'") | |||||
| text = text.replace("n \' t", "n\'t") | |||||
| text = text.replace("\' s", "\'s") | |||||
| text = text.replace("\' m", "\'m") | |||||
| text = text.replace("\' re", "\'re") | |||||
| text = text.replace('. . .', '...') | |||||
| text = text.replace(' . .', ' ..') | |||||
| text = text.replace('- -', '--') | |||||
| text = text.replace('u . s .', 'u.s.') | |||||
| text = text.replace('u . k .', 'u.k.') | |||||
| text = text.replace('e . g .', 'e.g.') | |||||
| return text | |||||
| class BlankLMEvaluater(DecoderEvaluater): | |||||
| def evaluate(self, model, dataloader, example_dict, args): | |||||
| model.eval() | |||||
| store = torch.distributed.TCPStore(args.master_ip, | |||||
| 18931 + random.randint(0, 10000), | |||||
| mpu.get_data_parallel_world_size(), | |||||
| torch.distributed.get_rank() == 0, | |||||
| datetime.timedelta(seconds=30)) | |||||
| print_rank_0('Distributed store created') | |||||
| with torch.no_grad(): | |||||
| for idx, data in enumerate(dataloader): | |||||
| tokens, attention_mask, position_ids = process_batch( | |||||
| data, args) | |||||
| src_tokens = tokens | |||||
| batch_size = tokens.size(0) | |||||
| mask_positions = [] | |||||
| current_mask = [] | |||||
| for text in tokens.tolist(): | |||||
| mask_positions.append([ | |||||
| i for i, x in enumerate(text) if x == self.mask_token | |||||
| ]) | |||||
| current_mask.append(0) | |||||
| # print(self.tokenizer.DecodeIds(text)) | |||||
| # print(mask_positions[-1]) | |||||
| counter = 0 | |||||
| done = [False] * batch_size | |||||
| while counter < args.tgt_seq_length: | |||||
| if counter == 0: | |||||
| # print(tokens) | |||||
| # print(position_ids) | |||||
| next_token_logits, *mems = model( | |||||
| tokens, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| return_memory=True) | |||||
| next_token_logits = next_token_logits[:, -1] | |||||
| position_ids = tokens.new_ones(batch_size, 2, 1) | |||||
| for i, text in enumerate(tokens.tolist()): | |||||
| mask_pos = mask_positions[i][current_mask[i]] | |||||
| position_ids[i, 0] = mask_pos | |||||
| tokens = tokens.new_zeros(batch_size, 0) | |||||
| attention_mask = tokens.new_zeros(batch_size) | |||||
| else: | |||||
| position_ids[:, 1] = position_ids[:, 1] + 1 | |||||
| last_token = tokens[:, -1:] | |||||
| next_token_logits, *mems = model( | |||||
| last_token, | |||||
| position_ids, | |||||
| attention_mask, | |||||
| *mems, | |||||
| return_memory=True) | |||||
| next_token_logits = next_token_logits[:, -1] | |||||
| next_token_scores = F.log_softmax( | |||||
| next_token_logits, dim=-1) | |||||
| next_token_scores = self.processors( | |||||
| tokens, next_token_scores) | |||||
| next_tokens = next_token_scores.max(dim=-1)[1] | |||||
| # print(self.tokenizer.DecodeIds(next_tokens.tolist())) | |||||
| for i, next_token in enumerate(next_tokens.tolist()): | |||||
| if next_token == self.end_token: | |||||
| if current_mask[i] + 1 < len(mask_positions[i]): | |||||
| current_mask[i] += 1 | |||||
| next_tokens[i] = self.start_token | |||||
| position_ids[i, 0] = mask_positions[i][ | |||||
| current_mask[i]] | |||||
| position_ids[i, 1] = 0 | |||||
| else: | |||||
| done[i] = True | |||||
| if done[i]: | |||||
| next_tokens[i] = self.pad_token | |||||
| if all(done): | |||||
| break | |||||
| tokens = torch.cat( | |||||
| [tokens, next_tokens.unsqueeze(-1)], dim=-1) | |||||
| counter += 1 | |||||
| predictions = [] | |||||
| for i, text in enumerate(tokens.tolist()): | |||||
| text = [ | |||||
| token for token in text | |||||
| if token not in [self.end_token, self.pad_token] | |||||
| ] | |||||
| blanks = [[]] | |||||
| for token in text: | |||||
| if token == self.start_token: | |||||
| blanks.append([]) | |||||
| else: | |||||
| blanks[-1].append(token) | |||||
| output_tokens = [] | |||||
| current_blank = 0 | |||||
| for token in src_tokens[i].tolist(): | |||||
| if token == self.mask_token: | |||||
| if current_blank < len(blanks): | |||||
| output_tokens += blanks[current_blank] | |||||
| current_blank += 1 | |||||
| else: | |||||
| if token not in [self.pad_token]: | |||||
| output_tokens.append(token) | |||||
| text = self.tokenizer.DecodeIds(output_tokens[:-1]) | |||||
| text = blanklm_fix_tokenization(text) | |||||
| predictions.append(text) | |||||
| # print(text) | |||||
| uid_list = data['uid'] | |||||
| if isinstance(uid_list, torch.Tensor): | |||||
| uid_list = uid_list.cpu().numpy().tolist() | |||||
| for uid, prediction in zip(uid_list, predictions): | |||||
| store.set(uid, prediction) | |||||
| if (idx + 1) % args.log_interval == 0: | |||||
| print_rank_0(f'Iteration {idx + 1} / {len(dataloader)}') | |||||
| model.train() | |||||
| torch.distributed.barrier() | |||||
| print_rank_0('Evaluation completed') | |||||
| predictions, examples = [], [] | |||||
| for uid, example in example_dict.items(): | |||||
| predictions.append(store.get(uid).decode('utf-8')) | |||||
| examples.append(example) | |||||
| torch.distributed.barrier() | |||||
| return predictions, [], examples | |||||
| @@ -0,0 +1,151 @@ | |||||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Race.""" | |||||
| import functools | |||||
| from collections import OrderedDict | |||||
| import mpu | |||||
| import torch | |||||
| from finetune_glm import finetune | |||||
| from pretrain_glm import get_batch | |||||
| from tasks.eval_utils import accuracy_func_provider | |||||
| from tasks.seq2seq.dataset import (BlankLMDataset, ExtractionDataset, | |||||
| Seq2SeqDataset) | |||||
| from tasks.seq2seq.evaluate import (BlankLMEvaluater, DecoderEvaluater, | |||||
| rouge_metric) | |||||
| global_tokenizer = None | |||||
| def seq2seq_forward_step(data, model, args, timers, mems): | |||||
| """Forward step.""" | |||||
| # Get the batch. | |||||
| if timers is not None: | |||||
| timers('batch generator').start() | |||||
| tokens, labels, loss_mask, attention_mask, position_ids = get_batch( | |||||
| data, args) | |||||
| if timers is not None: | |||||
| timers('batch generator').stop() | |||||
| # Forward model. | |||||
| logits, *mems = model(tokens, position_ids, attention_mask, *mems) | |||||
| # logits, loss_mask = logits[:, args.src_seq_length:], loss_mask[:, args.src_seq_length:] | |||||
| # target_ids = target_ids[:, args.src_seq_length:] | |||||
| losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), | |||||
| labels) | |||||
| if args.label_smoothing > 0.0: | |||||
| epsilon = args.label_smoothing | |||||
| smooth_loss = -torch.nn.functional.log_softmax( | |||||
| logits, dim=-1).mean(dim=-1) | |||||
| losses = (1 - epsilon) * losses + epsilon * smooth_loss | |||||
| loss_mask = loss_mask.reshape(-1) | |||||
| # The loss is not normalized for fair comparison | |||||
| loss = torch.sum(losses.reshape(-1) * loss_mask) / loss_mask.sum() | |||||
| return loss, mems, 'bert' | |||||
| def train_valid_datasets_provider(args, tokenizer): | |||||
| """Provide train and validation datasets.""" | |||||
| if args.task.lower() == 'blank': | |||||
| train_dataset = BlankLMDataset( | |||||
| args, split='train', tokenizer=tokenizer) | |||||
| valid_dataset = None | |||||
| elif args.task.lower() == 'extraction': | |||||
| train_dataset = ExtractionDataset( | |||||
| args, split='train', tokenizer=tokenizer) | |||||
| valid_dataset = None | |||||
| else: | |||||
| train_dataset = Seq2SeqDataset( | |||||
| args, split='train', tokenizer=tokenizer) | |||||
| valid_dataset = None | |||||
| global global_tokenizer | |||||
| global_tokenizer = tokenizer | |||||
| return train_dataset, valid_dataset | |||||
| def metrics_func_provider(args, tokenizer, is_test): | |||||
| """Provide metrics callback function.""" | |||||
| def single_dataset_provider(split): | |||||
| if args.task.lower() == 'blank': | |||||
| return BlankLMDataset(args, split=split, tokenizer=tokenizer) | |||||
| elif args.task.lower() == 'extraction': | |||||
| return ExtractionDataset(args, split=split, tokenizer=tokenizer) | |||||
| else: | |||||
| return Seq2SeqDataset(args, split=split, tokenizer=tokenizer) | |||||
| if args.task.lower() in ['blank', 'extraction']: | |||||
| evaluater = BlankLMEvaluater(args, tokenizer) | |||||
| eval_func = evaluater.evaluate | |||||
| metric_dict = {} | |||||
| else: | |||||
| evaluater = DecoderEvaluater(args, tokenizer) | |||||
| eval_func = evaluater.evaluate | |||||
| if args.tokenizer_type == 'BertWordPieceTokenizer': | |||||
| dataset = 'cnn_dm' | |||||
| elif args.task.lower() == 'gigaword': | |||||
| dataset = 'gigaword' | |||||
| else: | |||||
| dataset = 'cnn_dm_org' | |||||
| metric_dict = OrderedDict({ | |||||
| 'rouge-1': | |||||
| functools.partial(rouge_metric, metric='rouge-1', dataset=dataset), | |||||
| 'rouge-2': | |||||
| functools.partial(rouge_metric, metric='rouge-2', dataset=dataset), | |||||
| 'rouge-l': | |||||
| functools.partial(rouge_metric, metric='rouge-l', dataset=dataset) | |||||
| }) | |||||
| def output_func(predictions, examples, output_file): | |||||
| with open(output_file + '.hyps', 'w', encoding='utf-8') as output: | |||||
| for prediction in predictions: | |||||
| output.write(prediction) | |||||
| output.write('\n') | |||||
| with open(output_file + '.refs', 'w', encoding='utf-8') as output: | |||||
| for example in examples: | |||||
| output.write(example.meta['ref']) | |||||
| output.write('\n') | |||||
| if args.task.lower() == 'squad_generation': | |||||
| with open( | |||||
| output_file + '.source', 'w', encoding='utf-8') as output: | |||||
| for example in examples: | |||||
| output.write( | |||||
| example.text_a.replace('\n', ' ') + ' Answer: ' | |||||
| + example.meta['answer']) | |||||
| output.write('\n') | |||||
| return accuracy_func_provider( | |||||
| single_dataset_provider, | |||||
| metric_dict, | |||||
| args, | |||||
| is_test=is_test, | |||||
| eval_func=eval_func, | |||||
| output_func=output_func, | |||||
| only_rank0=False) | |||||
| def main(args): | |||||
| if args.src_seq_length > args.max_position_embeddings: | |||||
| args.max_position_embeddings = args.src_seq_length | |||||
| if args.task.lower() in [ | |||||
| 'cnn_dm', 'cnn_dm_original', 'gigaword', 'blank', | |||||
| 'squad_generation', 'xsum', 'extraction' | |||||
| ]: | |||||
| finetune( | |||||
| args, | |||||
| train_valid_datasets_provider, {}, | |||||
| end_of_epoch_callback_provider=metrics_func_provider, | |||||
| forward_step=seq2seq_forward_step) | |||||
| else: | |||||
| raise NotImplementedError(args.task) | |||||
| @@ -0,0 +1,137 @@ | |||||
| # Use GLM for your NLU tasks | |||||
| To use GLM for your own NLU tasks, you should implement a subclass of `DataProcessor` in [tasks/superglue/dataset.py](dataset.py) and a subclass of `PVP` in [tasks/superglue/pvp.py](pvp.py). We will take the RTE and ReCoRD tasks in SuperGLUE as examples. | |||||
| ## 1. Design your patterns | |||||
| RTE is an NLI task in which the model is required to predict text entailment between a premise and a hypothesis. The label can be `entailment` or `not_entailment`. One sample from the training set is: | |||||
| ``` | |||||
| premise: No Weapons of Mass Destruction Found in Iraq Yet. | |||||
| hypothesis: Weapons of Mass Destruction Found in Iraq. | |||||
| label: not_entailment | |||||
| ``` | |||||
| We design the pattern as | |||||
| ``` | |||||
| "`hypothesis`"?, [MASK], "`premise`" | |||||
| ``` | |||||
| GLM predicts "Yes" for `entailment` and "No" for `not_entailment`. "Yes" and "No" are called verbalizers for `entailment` and `not_entailment`. | |||||
| ReCoRD is a multi-choice QA task. Each example consists of a news article and a Cloze-style question about the article in which one entity is masked out. The system must predict the masked-out entity from a list of possible entities in the provided passage. We directly adopt the cloze-style question as our pattern and use GLM to predict the masked entity. | |||||
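| For illustration, a schematic ReCoRD-style instance (hypothetical, not an actual dataset sample) has the following shape: | |||||
| ``` | |||||
| passage:  <news article mentioning entities E1, ..., En> | |||||
| query:    Officials said @placeholder would respond to the report. | |||||
| entities: [E1, ..., En] | |||||
| answer:   the entity that correctly fills @placeholder | |||||
| ``` | |||||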
| ## 2. Implement subclass of `DataProcessor` | |||||
| A subclass of `DataProcessor` should implement `get_train_examples`, `get_dev_examples` and `get_test_examples`, which return the examples of the train, dev, and test sets. The returned value is a list of `InputExample`. It should also implement `get_labels` to return the list of possible labels. Here we take the `RteProcessor` as an example: | |||||
| ```python | |||||
| class RteProcessor(DataProcessor): | |||||
| """Processor for the RTE data set.""" | |||||
| def get_train_examples(self, data_dir): | |||||
| return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") | |||||
| def get_dev_examples(self, data_dir, for_train=False): | |||||
| return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") | |||||
| def get_test_examples(self, data_dir): | |||||
| return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") | |||||
| def get_unlabeled_examples(self, data_dir): | |||||
| return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") | |||||
| def get_labels(self): | |||||
| return ["entailment", "not_entailment"] | |||||
| def _create_examples(self, path: str, set_type: str, hypothesis_name: str = "hypothesis", | |||||
| premise_name: str = "premise") -> List[InputExample]: | |||||
| examples = [] | |||||
| with open(path, encoding='utf8') as f: | |||||
| for line_idx, line in enumerate(f): | |||||
| example_json = json.loads(line) | |||||
| idx = example_json['idx'] | |||||
| if isinstance(idx, str): | |||||
| try: | |||||
| idx = int(idx) | |||||
| except ValueError: | |||||
| idx = line_idx | |||||
| label = example_json.get('label') | |||||
| guid = "%s-%s" % (set_type, idx) | |||||
| text_a = example_json[premise_name] | |||||
| text_b = example_json[hypothesis_name] | |||||
| example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx) | |||||
| examples.append(example) | |||||
| return examples | |||||
| ``` | |||||
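| As a quick sanity check, the processor can be exercised on its own. This is a hypothetical snippet: the `data/RTE` path is assumed to contain the SuperGLUE `train.jsonl`/`val.jsonl`/`test.jsonl` files, and the base `DataProcessor` is assumed to need no extra constructor arguments here. | |||||
| ```python | |||||
| # Hypothetical usage sketch; "data/RTE" is an assumed directory holding the | |||||
| # SuperGLUE-format jsonl files described above. | |||||
| processor = RteProcessor() | |||||
| examples = processor.get_train_examples("data/RTE") | |||||
| print(len(examples), "training examples") | |||||
| print(processor.get_labels())   # ["entailment", "not_entailment"] | |||||
| print(examples[0].text_a[:80])  # premise of the first example | |||||
| ``` | |||||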
| After that, you should add the implemented class to ``PROCESSORS`` at the end of [tasks/superglue/dataset.py](dataset.py): | |||||
| ```python | |||||
| PROCESSORS = { | |||||
| ... | |||||
| "rte": RteProcessor | |||||
| } | |||||
| ``` | |||||
| ## 3. Implement subclass of `PVP` | |||||
| To implement a subclass of `PVP`, you should first decide whether your verbalizers are single-token or multi-token. The verbalizers in RTE, "Yes" and "No", are single-token. In contrast, the verbalizers in ReCoRD are multi-token, as one entity can be tokenized into multiple tokens with a WordPiece or BPE tokenizer. | |||||
| For a single-token task, you should set `is_multi_token=False` in the class definition. You should implement `get_parts` to return the inputs to GLM given an example and `verbalize` to return the verbalizer given a label. Take `RtePVP` as an example: | |||||
| ```python | |||||
| class RtePVP(PVP): | |||||
| is_multi_token = False | |||||
| VERBALIZER = { | |||||
| "not_entailment": [" No"], | |||||
| "entailment": [" Yes"] | |||||
| } | |||||
| @property | |||||
| def spell_length(self): | |||||
| return self.pattern_id | |||||
| def get_parts(self, example: InputExample) -> FilledPattern: | |||||
| # switch text_a and text_b to get the correct order | |||||
| text_a = example.text_a | |||||
| text_b = example.text_b.rstrip(string.punctuation) | |||||
| return ['"', self.shortenable(text_b), '" ?'], [[self.mask], ', "', self.shortenable(text_a), '"'] | |||||
| def verbalize(self, label) -> List[str]: | |||||
| return RtePVP.VERBALIZER[label] | |||||
| ``` | |||||
| We use `PVP.shortenable` to mark the segments that can be truncated when the input exceeds the maximum sequence length. | |||||
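| Putting the pieces together for the RTE example from step 1, the filled pattern fed to GLM reads roughly as follows, and GLM is expected to fill the [MASK] with " No": | |||||
| ``` | |||||
| "Weapons of Mass Destruction Found in Iraq"?, [MASK], "No Weapons of Mass Destruction Found in Iraq Yet." | |||||
| ``` | |||||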
| For a multi-token task, you should set `is_multi_token=True` in the class definition. You should implement `get_parts` to return the inputs to GLM given an example and `get_answers` to return the candidates. Take `RecordPVP` as an example: | |||||
| ```python | |||||
| class RecordPVP(PVP): | |||||
| is_multi_token = True | |||||
| def get_answers(self, example: InputExample): | |||||
| choices = example.meta['candidates'] | |||||
| choices = [" " + choice for choice in choices] | |||||
| return choices | |||||
| def get_parts(self, example: InputExample) -> FilledPattern: | |||||
| premise = self.shortenable(example.text_a) | |||||
| assert '@placeholder' in example.text_b, f'question "{example.text_b}" does not contain a @placeholder token' | |||||
| question_a, question_b = example.text_b.split('@placeholder') | |||||
| return [premise, " " + question_a.rstrip(), [self.mask], question_b], [] | |||||
| ``` | |||||
| After that, you should add the implemented class to `PVPS` at the end of [tasks/superglue/pvp.py](pvp.py): | |||||
| ```python | |||||
| PVPS = { | |||||
| ... | |||||
| 'rte': RtePVP, | |||||
| 'record': RecordPVP | |||||
| } | |||||
| ``` | |||||
| ## 4. Run the experiment | |||||
| To run the experiment for your new task, you should create a config file like [config_tasks/task_rte.sh](/config_tasks/task_rte.sh). You should also specify the evaluation metrics for the task in `DEFAULT_METRICS` of [tasks/superglue/finetune.py](finetune.py): | |||||
| ```python | |||||
| DEFAULT_METRICS = { | |||||
| ... | |||||
| "record": [("EM", qa_exact_match), ("F1", qa_f1)], | |||||
| "rte": [("accuracy", accuracy_metric)] | |||||
| } | |||||
| ``` | |||||
| Then you can run the experiment with [finetune_superglue.sh](/scripts/finetune_superglue.sh): | |||||
| ```shell | |||||
| bash scripts/finetune_superglue.sh \ | |||||
| config_tasks/model_blocklm_large.sh \ | |||||
| config_tasks/task_rte.sh | |||||
| ``` | |||||
| @@ -0,0 +1,101 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| """ | |||||
| Official evaluation script for ReCoRD v1.0. | |||||
| (Some functions are adopted from the SQuAD evaluation script.) | |||||
| """ | |||||
| from __future__ import print_function | |||||
| import functools | |||||
| import re | |||||
| import string | |||||
| from collections import Counter, defaultdict | |||||
| from typing import List | |||||
| from tasks.data_utils import InputExample | |||||
| def normalize_answer(s): | |||||
| """Lower text and remove punctuation, articles and extra whitespace.""" | |||||
| def remove_articles(text): | |||||
| return re.sub(r'\b(a|an|the)\b', ' ', text) | |||||
| def white_space_fix(text): | |||||
| return ' '.join(text.split()) | |||||
| def remove_punc(text): | |||||
| exclude = set(string.punctuation) | |||||
| return ''.join(ch for ch in text if ch not in exclude) | |||||
| def lower(text): | |||||
| return text.lower() | |||||
| return white_space_fix(remove_articles(remove_punc(lower(s)))) | |||||
| def f1_score(prediction, ground_truth): | |||||
| prediction_tokens = normalize_answer(prediction).split() | |||||
| ground_truth_tokens = normalize_answer(ground_truth).split() | |||||
| common = Counter(prediction_tokens) & Counter(ground_truth_tokens) | |||||
| num_same = sum(common.values()) | |||||
| if num_same == 0: | |||||
| return 0 | |||||
| precision = 1.0 * num_same / len(prediction_tokens) | |||||
| recall = 1.0 * num_same / len(ground_truth_tokens) | |||||
| f1 = (2 * precision * recall) / (precision + recall) | |||||
| return f1 | |||||
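| # Worked example (illustrative): f1_score("the cat sat", "a cat sat down") | |||||
| # normalizes away the articles, leaving the token overlap {"cat", "sat"}: | |||||
| # precision = 2/2, recall = 2/3, F1 = 0.8. | |||||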
| def exact_match_score(prediction, ground_truth): | |||||
| return normalize_answer(prediction) == normalize_answer(ground_truth) | |||||
| def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): | |||||
| if not ground_truths: | |||||
| return 0.0 | |||||
| scores_for_ground_truths = [] | |||||
| for ground_truth in ground_truths: | |||||
| score = metric_fn(prediction, ground_truth) | |||||
| scores_for_ground_truths.append(score) | |||||
| return max(scores_for_ground_truths) | |||||
| def qa_evaluate(predictions, labels, examples: List[InputExample], metric): | |||||
| assert len(examples) == len(predictions) | |||||
| score = 0.0 | |||||
| for example, prediction in zip(examples, predictions): | |||||
| ground_truths = example.meta['answers'] | |||||
| prediction = example.meta['candidates'][prediction] | |||||
| if ground_truths: | |||||
| score += metric_max_over_ground_truths(metric, prediction, | |||||
| ground_truths) | |||||
| score = 100.0 * score / len(predictions) | |||||
| return score | |||||
| def multirc_em(predictions, labels, examples: List[InputExample]): | |||||
| """Compute the exact match (EM) for a sequence of predictions and actual labels""" | |||||
| question_ids = [example.meta['question_idx'] for example in examples] | |||||
| unique_questions = set(question_ids) | |||||
| q_actuals = list(zip(question_ids, labels)) | |||||
| q_predictions = list(zip(question_ids, predictions)) | |||||
| actuals_per_question = defaultdict(list) | |||||
| predictions_per_question = defaultdict(list) | |||||
| for qid, val in q_actuals: | |||||
| actuals_per_question[qid].append(val) | |||||
| for qid, val in q_predictions: | |||||
| predictions_per_question[qid].append(val) | |||||
| em = 0 | |||||
| for qid in unique_questions: | |||||
| if actuals_per_question[qid] == predictions_per_question[qid]: | |||||
| em += 1 | |||||
| em /= len(unique_questions) | |||||
| return em | |||||
| qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score) | |||||
| qa_f1 = functools.partial(qa_evaluate, metric=f1_score) | |||||
| @@ -0,0 +1,138 @@ | |||||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Race.""" | |||||
| from collections import OrderedDict | |||||
| from finetune_glm import finetune | |||||
| from tasks.eval_utils import (accuracy_func_provider, accuracy_metric, | |||||
| f1_macro_metric, f1_metric) | |||||
| from tasks.superglue.dataset import (CLASSIFICATION_DATASETS, | |||||
| MULTI_CHOICE_DATASETS, PROCESSORS, | |||||
| SuperGlueDataset, get_output_func) | |||||
| from tasks.superglue.evaluate import multirc_em, qa_exact_match, qa_f1 | |||||
| from tasks.superglue.pvp import PVPS | |||||
| DEFAULT_METRICS = { | |||||
| 'record': [('EM', qa_exact_match), ('F1', qa_f1)], | |||||
| 'copa': [('accuracy', accuracy_metric)], | |||||
| 'rte': [('accuracy', accuracy_metric)], | |||||
| 'boolq': [('accuracy', accuracy_metric)], | |||||
| 'wic': [('accuracy', accuracy_metric)], | |||||
| 'wsc': [('accuracy', accuracy_metric)], | |||||
| 'cb': [('accuracy', accuracy_metric), ('f1-macro', f1_macro_metric)], | |||||
| 'multirc': [('f1a', f1_metric), ('em', multirc_em), | |||||
| ('acc', accuracy_metric)], | |||||
| 'mnli': [('accuracy', accuracy_metric)], | |||||
| 'sst2': [('accuracy', accuracy_metric)], | |||||
| 'qnli': [('accuracy', accuracy_metric)], | |||||
| 'qqp': [('accuracy', accuracy_metric)], | |||||
| 'mrpc': [('accuracy', accuracy_metric)], | |||||
| 'cola': [('accuracy', accuracy_metric)], | |||||
| 'squad': [('accuracy', accuracy_metric)], | |||||
| } | |||||
| def train_valid_datasets_provider(args, tokenizer, pattern_text=False): | |||||
| """Provide train and validation datasets.""" | |||||
| task_name = args.task.lower() | |||||
| data_dir = args.data_dir | |||||
| train_dataset = SuperGlueDataset( | |||||
| args, | |||||
| task_name, | |||||
| data_dir, | |||||
| args.seq_length, | |||||
| 'train', | |||||
| tokenizer, | |||||
| pattern_text=pattern_text) | |||||
| valid_dataset = SuperGlueDataset( | |||||
| args, | |||||
| task_name, | |||||
| data_dir, | |||||
| args.seq_length, | |||||
| 'dev', | |||||
| tokenizer, | |||||
| for_train=True, | |||||
| pattern_text=pattern_text) | |||||
| return train_dataset, valid_dataset | |||||
| def metrics_func_provider(args, tokenizer, is_test): | |||||
| """Privde metrics callback function.""" | |||||
| def single_dataset_provider(split): | |||||
| return SuperGlueDataset(args, args.task.lower(), args.data_dir, | |||||
| args.seq_length, split, tokenizer) | |||||
| output_func = get_output_func(args.task.lower(), args) | |||||
| eval_func = None | |||||
| if args.task.lower() in ['wsc', 'squad' | |||||
| ] and args.cloze_eval and not args.wsc_negative: | |||||
| from tasks.language_model.finetune import classify_evaluate | |||||
| eval_func = classify_evaluate | |||||
| metric_dict = OrderedDict(DEFAULT_METRICS[args.task.lower()]) | |||||
| return accuracy_func_provider( | |||||
| single_dataset_provider, | |||||
| metric_dict, | |||||
| args, | |||||
| is_test=is_test, | |||||
| eval_func=eval_func, | |||||
| output_func=output_func, | |||||
| only_rank0=False, | |||||
| tokenizer=tokenizer) | |||||
| def main(args): | |||||
| model_kwargs = {} | |||||
| processor = PROCESSORS[args.task.lower()](args) | |||||
| pvp = PVPS[args.task.lower()]( | |||||
| args, | |||||
| None, | |||||
| processor.get_labels(), | |||||
| args.seq_length, | |||||
| pattern_id=args.pattern_id, | |||||
| is_multi_token=args.multi_token, | |||||
| num_prompt_tokens=args.num_prompt_tokens) | |||||
| if args.continuous_prompt: | |||||
| model_kwargs['spell_length'] = pvp.spell_length | |||||
| if args.task.lower() in ['wsc', 'squad' | |||||
| ] and args.cloze_eval and not args.wsc_negative: | |||||
| from tasks.language_model.finetune import lm_forward_step | |||||
| finetune( | |||||
| args, | |||||
| train_valid_datasets_provider, | |||||
| model_kwargs, | |||||
| end_of_epoch_callback_provider=metrics_func_provider, | |||||
| forward_step=lm_forward_step) | |||||
| else: | |||||
| if args.cloze_eval: | |||||
| multi_token = pvp.is_multi_token | |||||
| else: | |||||
| multi_token = args.task.lower() in MULTI_CHOICE_DATASETS | |||||
| args.multi_token = multi_token | |||||
| if not multi_token: | |||||
| model_kwargs[ | |||||
| 'model_type'] = 'multiple_choice' if args.cloze_eval else 'classification' | |||||
| model_kwargs['multi_token'] = False | |||||
| model_kwargs['num_labels'] = len(processor.get_labels()) | |||||
| else: | |||||
| model_kwargs['model_type'] = 'multiple_choice' | |||||
| model_kwargs['multi_token'] = True | |||||
| model_kwargs['num_labels'] = 1 | |||||
| finetune( | |||||
| args, | |||||
| train_valid_datasets_provider, | |||||
| model_kwargs, | |||||
| end_of_epoch_callback_provider=metrics_func_provider) | |||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import random | |||||
| from argparse import Namespace | |||||
| import numpy as np | |||||
| from blocklm_utils import ConstructBlockStrategy | |||||
| # rng = random.Random() | |||||
| # span_lengths = [2, 3, 4, 2, 3, 4] | |||||
| # length = 100 | |||||
| # | |||||
| # counts = np.array([0] * length) | |||||
| # for _ in range(10000): | |||||
| # rng.shuffle(span_lengths) | |||||
| # spans = ConstructBlockStrategy.sample_spans(span_lengths, length, rng) | |||||
| # for start, end in spans: | |||||
| # counts[start: end] += 1 | |||||
| # print(counts) | |||||
| def main(): | |||||
| args = Namespace() | |||||
| args.seq_length = 10 | |||||
| args.eod_token = 0 | |||||
| strategy = ConstructBlockStrategy( | |||||
| args, None, bert_ratio=0.4, max_seq_length=128) | |||||
| counts = np.array([0] * 10) | |||||
| for _ in range(10000): | |||||
| spans = strategy.sample_span_in_document( | |||||
| np.array([1, 2, 3, 0, 4, 5, 6, 7, 9, 0], dtype=np.int64), [1, 1], | |||||
| random.Random()) | |||||
| for start, end in spans: | |||||
| counts[start:end] += 1 | |||||
| print(counts) | |||||
| @@ -0,0 +1,27 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import matplotlib.pyplot as plt | |||||
| import numpy as np | |||||
| from learning_rates import AnnealingLR | |||||
| from torch.nn.modules import Linear | |||||
| from torch.optim import Adam | |||||
| def main(): | |||||
| model = Linear(10, 10) | |||||
| optimizer = Adam(model.parameters()) | |||||
| lr_scheduler = AnnealingLR( | |||||
| optimizer, | |||||
| start_lr=0.00015, | |||||
| warmup_iter=3000, | |||||
| num_iters=300000, | |||||
| decay_style='cosine', | |||||
| decay_ratio=0.1) | |||||
| steps = np.arange(0, 400000, 10, dtype=np.int64) | |||||
| rates = [] | |||||
| for step in steps: | |||||
| lr_scheduler.num_iters = step | |||||
| rates.append(lr_scheduler.get_lr()) | |||||
| print(rates) | |||||
| plt.plot(steps, rates) | |||||
| plt.savefig('lr.pdf', format='pdf') | |||||
| @@ -0,0 +1,472 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import deepspeed | |||||
| import torch | |||||
| from apex.optimizers import FusedAdam as Adam | |||||
| from torch import distributed as dist | |||||
| from . import mpu | |||||
| from .fp16 import DynamicLossScaler, FP16_Module, FP16_Optimizer | |||||
| from .model import DistributedDataParallel as LocalDDP | |||||
| from .model import (GLMForMultiTokenCloze, GLMForMultiTokenClozeFast, | |||||
| GLMForSequenceClassification, GLMForSingleTokenCloze, | |||||
| GLMModel) | |||||
| from .model import PyTorchDistributedDataParallel as TorchDDP | |||||
| from .model import glm_get_params_for_weight_decay_optimization | |||||
| from .utils import get_checkpoint_iteration, get_checkpoint_name, print_rank_0 | |||||
| def load_pretrained(model, checkpoint_path, args, task_tokens=None): | |||||
| load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path) | |||||
| checkpoint_name = get_checkpoint_name(load_dir, tag, release) | |||||
| if mpu.get_data_parallel_rank() == 0: | |||||
| print('global rank {} is loading pretrained model {}'.format( | |||||
| torch.distributed.get_rank(), checkpoint_name)) | |||||
| # Load the checkpoint. | |||||
| sd = torch.load(checkpoint_name, map_location='cpu') | |||||
| if args.deepspeed: | |||||
| model = model.module | |||||
| if isinstance(model, TorchDDP): | |||||
| model = model.module | |||||
| if isinstance(model, FP16_Module): | |||||
| model = model.module | |||||
| if hasattr(model, 'model'): | |||||
| model = model.model | |||||
| # Model. | |||||
| def extend_embedding_weights(state_weights, model_weights): | |||||
| original_length = state_weights.shape[0] | |||||
| assert original_length <= args.max_position_embeddings + 1 | |||||
| new_weights = model_weights.clone() | |||||
| new_weights[:original_length] = state_weights | |||||
| return new_weights | |||||
| if args.block_lm: | |||||
| if 'transformer.block_position_embeddings.weight' in sd['module']: | |||||
| position_weights = sd['module'][ | |||||
| 'transformer.position_embeddings.weight'] | |||||
| if args.max_position_embeddings + 1 > position_weights.shape[0]: | |||||
| sd['module'][ | |||||
| 'transformer.position_embeddings.weight'] = extend_embedding_weights( | |||||
| position_weights, | |||||
| model.state_dict() | |||||
| ['transformer.position_embeddings.weight'].data) | |||||
| print_rank_0( | |||||
| f'Extend position embedding to {args.max_position_embeddings + 1}' | |||||
| ) | |||||
| if 'transformer.block_position_embeddings.weight' in sd['module']: | |||||
| block_position_weights = sd['module'][ | |||||
| 'transformer.block_position_embeddings.weight'] | |||||
| if args.max_position_embeddings + 1 > block_position_weights.shape[ | |||||
| 0]: | |||||
| sd['module'][ | |||||
| 'transformer.block_position_embeddings.weight'] = extend_embedding_weights( | |||||
| block_position_weights, | |||||
| model.state_dict() | |||||
| ['transformer.block_position_embeddings.weight'].data) | |||||
| print_rank_0( | |||||
| f'Extend block position embedding to {args.max_position_embeddings + 1}' | |||||
| ) | |||||
| for key in list(model.state_dict().keys()): | |||||
| print(key) | |||||
| model.state_dict()[key.replace( | |||||
| 'mixins.block_position_embedding.block_position_embeddings.weight', | |||||
| 'transformer.block_position_embeddings.weight').replace( | |||||
| 'transformer.word_embeddings.weight', | |||||
| 'word_embeddings.weight')] = model.state_dict().pop(key) | |||||
| missing_keys, unexpected_keys = model.load_state_dict( | |||||
| sd['module'], strict=False) | |||||
| if missing_keys or unexpected_keys: | |||||
| print_rank_0( | |||||
| f'Missing keys {missing_keys}, unexpected keys {unexpected_keys}') | |||||
| if args.continuous_prompt and args.prompt_init: | |||||
| model.prompt_spell.init_embedding(model.word_embeddings.weight.data, | |||||
| task_tokens) | |||||
| def get_model(args, | |||||
| model_type=None, | |||||
| multi_token=True, | |||||
| num_labels=None, | |||||
| spell_length=None): | |||||
| """Build the model.""" | |||||
| print_rank_0('building GLM model ...') | |||||
| if args.pretrained_bert: | |||||
| if model_type == 'multiple_choice': | |||||
| model = BertForMultipleChoice.from_pretrained( | |||||
| args.tokenizer_model_type, | |||||
| cache_dir=args.cache_dir, | |||||
| fp32_layernorm=args.fp32_layernorm, | |||||
| fp32_embedding=args.fp32_embedding, | |||||
| layernorm_epsilon=args.layernorm_epsilon) | |||||
| elif model_type == 'classification': | |||||
| model = BertForSequenceClassification.from_pretrained( | |||||
| args.tokenizer_model_type, | |||||
| cache_dir=args.cache_dir, | |||||
| fp32_layernorm=args.fp32_layernorm, | |||||
| fp32_embedding=args.fp32_embedding, | |||||
| layernorm_epsilon=args.layernorm_epsilon, | |||||
| num_labels=num_labels) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| else: | |||||
| output_predict, parallel_output = True, True | |||||
| if (model_type == 'multiple_choice' | |||||
| or model_type == 'classification') and not args.cloze_eval: | |||||
| output_predict = False | |||||
| if model_type is not None: | |||||
| parallel_output = False | |||||
| if spell_length is not None: | |||||
| print_rank_0(f'Continuous spell length {spell_length}') | |||||
| model = GLMModel( | |||||
| num_layers=args.num_layers, | |||||
| vocab_size=args.vocab_size, | |||||
| hidden_size=args.hidden_size, | |||||
| num_attention_heads=args.num_attention_heads, | |||||
| embedding_dropout_prob=args.hidden_dropout, | |||||
| attention_dropout_prob=args.attention_dropout, | |||||
| output_dropout_prob=args.hidden_dropout, | |||||
| max_sequence_length=args.max_position_embeddings, | |||||
| max_memory_length=args.mem_length, | |||||
| checkpoint_activations=args.checkpoint_activations, | |||||
| checkpoint_num_layers=args.checkpoint_num_layers, | |||||
| parallel_output=parallel_output, | |||||
| relative_encoding=args.transformer_xl, | |||||
| block_position_encoding=args.block_lm and not args.masked_lm, | |||||
| output_predict=output_predict, | |||||
| spell_length=spell_length, | |||||
| spell_func=args.prompt_func, | |||||
| attention_scale=args.attention_scale) | |||||
| if args.freeze_transformer: | |||||
| model.freeze_transformer( | |||||
| tune_prefix_layers=args.tune_prefix_layers) | |||||
| if model_type is not None: | |||||
| if model_type == 'multiple_choice': | |||||
| if args.cloze_eval: | |||||
| if multi_token: | |||||
| if args.fast_decode: | |||||
| model = GLMForMultiTokenClozeFast( | |||||
| model, length_penalty=args.length_penalty) | |||||
| else: | |||||
| model = GLMForMultiTokenCloze( | |||||
| model, length_penalty=args.length_penalty) | |||||
| else: | |||||
| model = GLMForSingleTokenCloze( | |||||
| model, take_softmax=args.adapet) | |||||
| else: | |||||
| model = GLMForSequenceClassification( | |||||
| model, | |||||
| args.hidden_size, | |||||
| args.output_dropout, | |||||
| args.pool_token, | |||||
| num_class=num_labels) | |||||
| elif model_type == 'classification': | |||||
| model = GLMForSequenceClassification( | |||||
| model, | |||||
| args.hidden_size, | |||||
| args.output_dropout, | |||||
| args.pool_token, | |||||
| num_class=num_labels) | |||||
| elif model_type == 'generation': | |||||
| pass | |||||
| else: | |||||
| raise NotImplementedError(model_type) | |||||
| if mpu.get_data_parallel_rank() == 0: | |||||
| print( | |||||
| ' > number of parameters on model parallel rank {}: {}'.format( | |||||
| mpu.get_model_parallel_rank(), | |||||
| sum([p.nelement() for p in model.parameters()])), | |||||
| flush=True) | |||||
| # Convert to half precision first to avoid OOM for models that cannot fit in GPU memory in full precision. | |||||
| if args.fp16: | |||||
| model.half() | |||||
| # GPU allocation. | |||||
| model.cuda(torch.cuda.current_device()) | |||||
| # Fp16 conversion. | |||||
| if args.fp16: | |||||
| model = FP16_Module(model) | |||||
| # Wrap model for distributed training. | |||||
| if not args.deepspeed and (args.train_iters or args.epochs): | |||||
| if args.DDP_impl == 'torch': | |||||
| i = torch.cuda.current_device() | |||||
| model = TorchDDP( | |||||
| model, | |||||
| device_ids=[i], | |||||
| output_device=i, | |||||
| process_group=mpu.get_data_parallel_group()) | |||||
| elif args.DDP_impl == 'local': | |||||
| model = LocalDDP(model) | |||||
| else: | |||||
| print_rank_0('Skip DDP model') | |||||
| return model | |||||
| def get_optimizer_param_groups(model): | |||||
| # Build parameter groups (weight decay and non-decay). | |||||
| while isinstance(model, (LocalDDP, TorchDDP, FP16_Module)): | |||||
| model = model.module | |||||
| param_groups = glm_get_params_for_weight_decay_optimization(model) | |||||
| # Add model parallel attribute if it is not set. | |||||
| for param_group in param_groups: | |||||
| # print('## param_group', len(param_group['params'])) | |||||
| for param in param_group['params']: | |||||
| if not hasattr(param, 'model_parallel'): | |||||
| param.model_parallel = False | |||||
| return param_groups | |||||
| def get_optimizer(param_groups, args): | |||||
| """Set up the optimizer.""" | |||||
| if args.cpu_optimizer: | |||||
| # Apex FusedAdam uses decoupled weight decay so use the same here | |||||
| if args.cpu_torch_adam: | |||||
| cpu_adam_optimizer = torch.optim.AdamW | |||||
| else: | |||||
| from deepspeed.ops.adam import DeepSpeedCPUAdam | |||||
| cpu_adam_optimizer = DeepSpeedCPUAdam | |||||
| optimizer = cpu_adam_optimizer( | |||||
| param_groups, lr=args.lr, weight_decay=args.weight_decay) | |||||
| else: | |||||
| # Use FusedAdam. | |||||
| if args.optimizer == 'adam': | |||||
| optimizer = Adam( | |||||
| param_groups, | |||||
| lr=args.lr, | |||||
| weight_decay=args.weight_decay, | |||||
| betas=(args.adam_beta1, args.adam_beta2), | |||||
| eps=args.adam_eps) | |||||
| elif args.optimizer == 'adafactor': | |||||
| from transformers import Adafactor | |||||
| optimizer = Adafactor( | |||||
| param_groups, | |||||
| lr=args.lr, | |||||
| relative_step=False, | |||||
| warmup_init=False) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| print(f'Optimizer = {optimizer.__class__.__name__}') | |||||
| if hasattr(args, 'deepspeed') and args.deepspeed: | |||||
| raise NotImplementedError | |||||
| # fp16 wrapper is not required for DeepSpeed. | |||||
| # return optimizer | |||||
| # Wrap into fp16 optimizer. | |||||
| if args.fp16: | |||||
| optimizer = FP16_Optimizer( | |||||
| optimizer, | |||||
| static_loss_scale=args.loss_scale, | |||||
| dynamic_loss_scale=args.dynamic_loss_scale, | |||||
| dynamic_loss_args={ | |||||
| 'scale_window': args.loss_scale_window, | |||||
| 'min_scale': args.min_scale, | |||||
| 'delayed_shift': args.hysteresis | |||||
| }) | |||||
| return optimizer | |||||
| def get_learning_rate_scheduler(optimizer, args): | |||||
| """Build the learning rate scheduler.""" | |||||
| # Add linear learning rate scheduler. | |||||
| if args.lr_decay_iters is not None: | |||||
| num_iters = args.lr_decay_iters | |||||
| else: | |||||
| num_iters = args.train_iters | |||||
| if args.finetune: | |||||
| num_iters = num_iters // args.gradient_accumulation_steps | |||||
| num_iters = max(1, num_iters) | |||||
| init_step = -1 | |||||
| warmup_iter = args.warmup * num_iters | |||||
| lr_scheduler = AnnealingLR( | |||||
| optimizer, | |||||
| start_lr=args.lr, | |||||
| warmup_iter=warmup_iter, | |||||
| num_iters=num_iters - warmup_iter, | |||||
| decay_style=args.lr_decay_style, | |||||
| last_iter=init_step, | |||||
| decay_ratio=args.lr_decay_ratio) | |||||
| return lr_scheduler | |||||
| def setup_model_and_optimizer(args, | |||||
| model_type=None, | |||||
| multi_token=True, | |||||
| num_labels=None, | |||||
| spell_length=None): | |||||
| """Setup model and optimizer.""" | |||||
| model = get_model( | |||||
| args, | |||||
| model_type=model_type, | |||||
| multi_token=multi_token, | |||||
| num_labels=num_labels, | |||||
| spell_length=spell_length) | |||||
| param_groups = get_optimizer_param_groups(model) | |||||
| if args.train_data is not None or args.data_dir is not None and ( | |||||
| args.epochs > 0 or args.train_iters > 0): | |||||
| if args.deepspeed: | |||||
| print_rank_0('DeepSpeed is enabled.') | |||||
| model, optimizer, _, _ = deepspeed.initialize( | |||||
| model=model, | |||||
| model_parameters=param_groups, | |||||
| args=args, | |||||
| mpu=mpu, | |||||
| dist_init_required=False) | |||||
| else: | |||||
| optimizer = get_optimizer(param_groups, args) | |||||
| lr_scheduler = get_learning_rate_scheduler(optimizer, args) | |||||
| else: | |||||
| optimizer, lr_scheduler = None, None | |||||
| return model, optimizer, lr_scheduler | |||||
| def backward_step(optimizer, model, lm_loss, args, timers): | |||||
| """Backward step.""" | |||||
| # Total loss. | |||||
| loss = lm_loss | |||||
| # Backward pass. | |||||
| if args.deepspeed: | |||||
| model.backward(loss) | |||||
| else: | |||||
| # optimizer.zero_grad() | |||||
| if args.fp16: | |||||
| optimizer.backward(loss, update_master_grads=False) | |||||
| else: | |||||
| loss.backward() | |||||
| if args.deepspeed or args.DDP_impl == 'torch': | |||||
| # DeepSpeed's backward pass already handles the all-reduce communication. | |||||
| # Reset the timer to avoid breaking timer logs below. | |||||
| timers('allreduce').reset() | |||||
| else: | |||||
| timers('allreduce').start() | |||||
| model.allreduce_params( | |||||
| reduce_after=False, fp32_allreduce=args.fp32_allreduce) | |||||
| timers('allreduce').stop() | |||||
| # Update master gradients. | |||||
| if not args.deepspeed: | |||||
| if args.fp16: | |||||
| optimizer.update_master_grads() | |||||
| # Clip gradients to help prevent them from exploding. | |||||
| if args.clip_grad > 0: | |||||
| if not args.fp16: | |||||
| mpu.clip_grad_norm(model.parameters(), args.clip_grad) | |||||
| else: | |||||
| optimizer.clip_master_grads(args.clip_grad) | |||||
| return lm_loss | |||||
| def see_memory_usage(message, force=False): | |||||
| if not force: | |||||
| return | |||||
| dist.barrier() | |||||
| if dist.get_rank() == 0: | |||||
| print(message) | |||||
| print('Memory Allocated ', | |||||
| torch.cuda.memory_allocated() / (1024 * 1024 * 1024), | |||||
| 'GigaBytes') | |||||
| print('Max Memory Allocated ', | |||||
| torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), | |||||
| 'GigaBytes') | |||||
| print('Cache Allocated ', | |||||
| torch.cuda.memory_cached() / (1024 * 1024 * 1024), 'GigaBytes') | |||||
| print('Max cache Allocated ', | |||||
| torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), | |||||
| 'GigaBytes') | |||||
| print(' ') | |||||
| # input("Press Any Key To Continue ..") | |||||
| def train_step(data_iterator, | |||||
| model, | |||||
| optimizer, | |||||
| lr_scheduler, | |||||
| args, | |||||
| timers, | |||||
| forward_step_func, | |||||
| mems=None, | |||||
| single_step=False): | |||||
| """Single training step.""" | |||||
| lm_loss_total, count = 0.0, 0 | |||||
| mems = [] if mems is None else mems | |||||
| if not args.deepspeed: | |||||
| optimizer.zero_grad() | |||||
| while True: | |||||
| skipped_iter, complete = 0, False | |||||
| # Forward model for one step. | |||||
| timers('forward').start() | |||||
| lm_loss, mems, _ = forward_step_func(data_iterator, model, args, | |||||
| timers, mems) | |||||
| timers('forward').stop() | |||||
| # print_rank_0("Forward step") | |||||
| if not args.deepspeed: | |||||
| lm_loss /= args.gradient_accumulation_steps | |||||
| reduced_loss = lm_loss.detach().clone().view(1) | |||||
| torch.distributed.all_reduce( | |||||
| reduced_loss.data, group=mpu.get_data_parallel_group()) | |||||
| reduced_loss.data = reduced_loss.data / ( | |||||
| args.world_size / args.model_parallel_size) | |||||
| if not DynamicLossScaler._has_inf_or_nan(reduced_loss): | |||||
| lm_loss_total += reduced_loss | |||||
| count += 1 | |||||
| # Calculate gradients, reduce across processes, and clip. | |||||
| timers('backward').start() | |||||
| backward_step(optimizer, model, lm_loss, args, timers) | |||||
| timers('backward').stop() | |||||
| # print_rank_0("Backward step") | |||||
| # Update parameters. | |||||
| timers('optimizer').start() | |||||
| if args.deepspeed: | |||||
| if model.is_gradient_accumulation_boundary(): | |||||
| model.step() | |||||
| complete = True | |||||
| if not (args.fp16 and optimizer.overflow): | |||||
| lr_scheduler.step() | |||||
| else: | |||||
| skipped_iter = 1 | |||||
| else: | |||||
| model.step() | |||||
| else: | |||||
| if count == args.gradient_accumulation_steps: | |||||
| optimizer.step() | |||||
| complete = True | |||||
| # Update learning rate. | |||||
| if not (args.fp16 and optimizer.overflow): | |||||
| lr_scheduler.step() | |||||
| else: | |||||
| skipped_iter = 1 | |||||
| # print_rank_0("Optimizer step") | |||||
| timers('optimizer').stop() | |||||
| if complete: | |||||
| break | |||||
| else: | |||||
| print_rank_0('Found NaN loss, skip backward') | |||||
| del lm_loss, reduced_loss | |||||
| mems = [] | |||||
| if single_step: | |||||
| break | |||||
| if args.deepspeed: | |||||
| lm_loss_total = lm_loss_total / count | |||||
| return lm_loss_total, skipped_iter, mems | |||||
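Taken together, the helpers in this file cover model construction, optimization, and a single training step. A minimal sketch of how they are typically wired into a training loop (forward_step, train_data_iterator, timers, and the args attributes load_pretrained and log_interval are illustrative placeholders, not names defined in this file):

def run_training(args, forward_step, train_data_iterator, timers):
    # Build the GLM model, optimizer and LR scheduler in one call.
    model, optimizer, lr_scheduler = setup_model_and_optimizer(args)

    # Optionally warm-start from a pretrained GLM checkpoint.
    if args.load_pretrained:
        load_pretrained(model, args.load_pretrained, args)

    mems = []
    for iteration in range(args.train_iters):
        # One optimizer update; gradient accumulation happens inside train_step.
        lm_loss, skipped_iter, mems = train_step(
            train_data_iterator, model, optimizer, lr_scheduler, args,
            timers, forward_step, mems=mems)
        if iteration % args.log_interval == 0:
            print_rank_0(f'iteration {iteration} | lm loss {lm_loss.item():.4f}')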
| @@ -0,0 +1,529 @@ | |||||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Utilities for logging and serialization""" | |||||
| import os | |||||
| import random | |||||
| import subprocess | |||||
| import time | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| from . import mpu | |||||
| from .fp16 import FP16_Optimizer | |||||
| SUMMARY_WRITER_DIR_NAME = 'runs' | |||||
| def get_log_dir(name, base): | |||||
| return os.path.join(base, SUMMARY_WRITER_DIR_NAME, name) | |||||
| def print_rank_0(message): | |||||
| if torch.distributed.is_initialized(): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| print(message, flush=True) | |||||
| else: | |||||
| print(message, flush=True) | |||||
| def get_hostname(): | |||||
| hostname_cmd = ['hostname -I'] | |||||
| result = subprocess.check_output(hostname_cmd, shell=True) | |||||
| master_addr = result.decode('utf-8').split()[0] | |||||
| return master_addr | |||||
| def get_spare_port(args): | |||||
| if torch.distributed.get_rank() == 0: | |||||
| port = subprocess.check_output(['shuf -n 1 -i 10000-65535'], | |||||
| shell=True) | |||||
| port = int(port.strip()) | |||||
| if port == args.master_port: | |||||
| port = subprocess.check_output(['shuf -n 1 -i 10000-65535'], | |||||
| shell=True) | |||||
| port = int(port.strip()) | |||||
| port = torch.cuda.LongTensor([port]) | |||||
| else: | |||||
| port = torch.cuda.LongTensor([0]) | |||||
| torch.distributed.broadcast(port, 0) | |||||
| port = port.item() | |||||
| return port | |||||
| def print_and_save_args(args, verbose=True, log_dir=None): | |||||
| """Print arguments.""" | |||||
| if verbose: | |||||
| print('arguments:', flush=True) | |||||
| for arg in vars(args): | |||||
| dots = '.' * (29 - len(arg)) | |||||
| print( | |||||
| ' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True) | |||||
| if log_dir is not None: | |||||
| json_file = os.path.join(log_dir, 'config.json') | |||||
| with open(json_file, 'w') as output: | |||||
| json.dump(vars(args), output, sort_keys=True) | |||||
| if args.deepspeed and args.deepspeed_config is not None: | |||||
| with open(args.deepspeed_config) as file: | |||||
| deepspeed_config = json.load(file) | |||||
| deepspeed_json_file = os.path.join(log_dir, | |||||
| 'config_gpt_large.json') | |||||
| with open(deepspeed_json_file, 'w') as output: | |||||
| json.dump(deepspeed_config, output) | |||||
| def print_params_min_max_norm(optimizer, iteration): | |||||
| """Print min, max, and norm of all parameters.""" | |||||
| index = 0 | |||||
| rank = torch.distributed.get_rank() | |||||
| string = 'iteration, rank, index, model-parallel, min, max, norm\n' | |||||
| optimizer_ = optimizer | |||||
| if isinstance(optimizer, FP16_Optimizer): | |||||
| optimizer_ = optimizer.optimizer | |||||
| for param_group in optimizer_.param_groups: | |||||
| for param in param_group['params']: | |||||
| index += 1 | |||||
| min_ = param.data.min() | |||||
| max_ = param.data.max() | |||||
| norm = param.data.norm() | |||||
| string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( | |||||
| iteration, rank, index, int(param.model_parallel)) | |||||
| string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) | |||||
| print(string, flush=True) | |||||
| class Timers: | |||||
| """Group of timers.""" | |||||
| class Timer: | |||||
| """Timer.""" | |||||
| def __init__(self, name): | |||||
| self.name_ = name | |||||
| self.elapsed_ = 0.0 | |||||
| self.started_ = False | |||||
| self.start_time = time.time() | |||||
| def start(self): | |||||
| """Start the timer.""" | |||||
| assert not self.started_, 'timer has already been started' | |||||
| torch.cuda.synchronize() | |||||
| self.start_time = time.time() | |||||
| self.started_ = True | |||||
| def stop(self): | |||||
| """Stop the timer.""" | |||||
| assert self.started_, 'timer is not started' | |||||
| torch.cuda.synchronize() | |||||
| self.elapsed_ += (time.time() - self.start_time) | |||||
| self.started_ = False | |||||
| def reset(self): | |||||
| """Reset timer.""" | |||||
| self.elapsed_ = 0.0 | |||||
| self.started_ = False | |||||
| def elapsed(self, reset=True): | |||||
| """Calculate the elapsed time.""" | |||||
| started_ = self.started_ | |||||
| # If the timer is running, stop it first. | |||||
| if self.started_: | |||||
| self.stop() | |||||
| # Get the elapsed time. | |||||
| elapsed_ = self.elapsed_ | |||||
| # Reset the elapsed time | |||||
| if reset: | |||||
| self.reset() | |||||
| # If timing was in progress, set it back. | |||||
| if started_: | |||||
| self.start() | |||||
| return elapsed_ | |||||
| def __init__(self): | |||||
| self.timers = {} | |||||
| def __call__(self, name): | |||||
| if name not in self.timers: | |||||
| self.timers[name] = self.Timer(name) | |||||
| return self.timers[name] | |||||
| def log(self, names, normalizer=1.0, reset=True): | |||||
| """Log a group of timers.""" | |||||
| assert normalizer > 0.0 | |||||
| string = 'time (ms)' | |||||
| for name in names: | |||||
| elapsed_time = self.timers[name].elapsed( | |||||
| reset=reset) * 1000.0 / normalizer | |||||
| string += ' | {}: {:.2f}'.format(name, elapsed_time) | |||||
| print_rank_0(string) | |||||
| def report_memory(name): | |||||
| """Simple GPU memory report.""" | |||||
| mega_bytes = 1024.0 * 1024.0 | |||||
| string = name + ' memory (MB)' | |||||
| string += ' | allocated: {}'.format(torch.cuda.memory_allocated() | |||||
| / mega_bytes) | |||||
| string += ' | max allocated: {}'.format(torch.cuda.max_memory_allocated() | |||||
| / mega_bytes) | |||||
| string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes) | |||||
| string += ' | max cached: {}'.format(torch.cuda.max_memory_reserved() | |||||
| / mega_bytes) | |||||
| print_rank_0(string) | |||||
| def get_checkpoint_name(checkpoints_path, | |||||
| iteration, | |||||
| release=False, | |||||
| zero=False): | |||||
| if release: | |||||
| d = 'release' | |||||
| else: | |||||
| d = '{}'.format(iteration) | |||||
| if zero: | |||||
| dp_rank = mpu.get_data_parallel_rank() | |||||
| d += '_zero_dp_rank_{}'.format(dp_rank) | |||||
| return os.path.join( | |||||
| checkpoints_path, d, | |||||
| 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank())) | |||||
| def ensure_directory_exists(filename): | |||||
| dirname = os.path.dirname(filename) | |||||
| if not os.path.exists(dirname): | |||||
| os.makedirs(dirname, exist_ok=True) | |||||
| def get_checkpoint_tracker_filename(checkpoints_path): | |||||
| return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') | |||||
| def save_zero_checkpoint(args, iteration, optimizer): | |||||
| zero_sd = { | |||||
| 'iteration': iteration, | |||||
| 'optimizer_state_dict': optimizer.state_dict() | |||||
| } | |||||
| zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True) | |||||
| ensure_directory_exists(zero_checkpoint_name) | |||||
| torch.save(zero_sd, zero_checkpoint_name) | |||||
| print(' successfully saved {}'.format(zero_checkpoint_name)) | |||||
| def save_checkpoint(iteration, | |||||
| model, | |||||
| optimizer, | |||||
| lr_scheduler, | |||||
| args, | |||||
| tag=None, | |||||
| barrier=True, | |||||
| only_changed_parameters=False, | |||||
| no_deepspeed=False, | |||||
| no_save_optim=False): | |||||
| """Save a model checkpoint.""" | |||||
| if tag is None: | |||||
| tag = str(iteration) | |||||
| if args.deepspeed and not no_deepspeed: | |||||
| save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag) | |||||
| else: | |||||
| # Only rank zero of the data parallel group writes to the disk. | |||||
| if mpu.get_data_parallel_rank() == 0: | |||||
| checkpoint_name = get_checkpoint_name(args.save, tag) | |||||
| print( | |||||
| 'global rank {} is saving checkpoint at iteration {:7d} to {}'. | |||||
| format(torch.distributed.get_rank(), iteration, | |||||
| checkpoint_name)) | |||||
| sd = {'iteration': iteration} | |||||
| if args.deepspeed: | |||||
| model = model.module | |||||
| state_dict = model.state_dict() | |||||
| if only_changed_parameters: | |||||
| requires_grad_dict = {} | |||||
| for name, parameter in model.named_parameters(): | |||||
| requires_grad_dict[name] = parameter.requires_grad | |||||
| state_dict = { | |||||
| key: value | |||||
| for key, value in state_dict.items() | |||||
| if requires_grad_dict[key] | |||||
| } | |||||
| sd['module'] = state_dict | |||||
| # Optimizer stuff. | |||||
| if not args.no_save_optim and not no_save_optim: | |||||
| if optimizer is not None: | |||||
| sd['optimizer'] = optimizer.state_dict() | |||||
| if lr_scheduler is not None: | |||||
| sd['lr_scheduler'] = lr_scheduler.state_dict() | |||||
| # rng states. | |||||
| if not args.no_save_rng: | |||||
| sd['random_rng_state'] = random.getstate() | |||||
| sd['np_rng_state'] = np.random.get_state() | |||||
| sd['torch_rng_state'] = torch.get_rng_state() | |||||
| sd['cuda_rng_state'] = torch.cuda.get_rng_state() | |||||
| sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker( | |||||
| ).get_states() | |||||
| ensure_directory_exists(checkpoint_name) | |||||
| torch.save(sd, checkpoint_name) | |||||
| print(' successfully saved {}'.format(checkpoint_name)) | |||||
| # Wait so everyone is done (necessary) | |||||
| if barrier: | |||||
| torch.distributed.barrier() | |||||
| # And update the latest iteration | |||||
| if torch.distributed.get_rank() == 0: | |||||
| tracker_filename = get_checkpoint_tracker_filename(args.save) | |||||
| with open(tracker_filename, 'w') as f: | |||||
| f.write(tag) | |||||
| def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag): | |||||
| """Save a model checkpoint.""" | |||||
| sd = {} | |||||
| sd['iteration'] = iteration | |||||
| if lr_scheduler is not None: | |||||
| sd['client_lr_scheduler'] = lr_scheduler.state_dict() | |||||
| # rng states. | |||||
| if not args.no_save_rng: | |||||
| sd['random_rng_state'] = random.getstate() | |||||
| sd['np_rng_state'] = np.random.get_state() | |||||
| sd['torch_rng_state'] = torch.get_rng_state() | |||||
| sd['cuda_rng_state'] = torch.cuda.get_rng_state() | |||||
| sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() | |||||
| model.save_checkpoint(args.save, tag, client_state=sd) | |||||
| def get_checkpoint_iteration(load_path): | |||||
| # Read the tracker file and set the iteration. | |||||
| tracker_filename = get_checkpoint_tracker_filename(load_path) | |||||
| if not os.path.isfile(tracker_filename): | |||||
| print_rank_0('WARNING: could not find the metadata file {} '.format( | |||||
| tracker_filename)) | |||||
| if os.path.isdir(load_path): | |||||
| path = os.path.normpath(load_path) | |||||
| load_dir, tag = os.path.split(path) | |||||
| print_rank_0( | |||||
| 'Try to directly load the checkpoint from the directory') | |||||
| return load_dir, tag, False, True | |||||
| print_rank_0(' will not load any checkpoints and will start from ' | |||||
| 'random') | |||||
| return load_path, 0, False, False | |||||
| with open(tracker_filename, 'r') as f: | |||||
| metastring = f.read().strip() | |||||
| release = metastring == 'release' | |||||
| # try: | |||||
| # iteration = int(metastring) | |||||
| # except ValueError: | |||||
| # release = metastring == 'release' | |||||
| # if not release: | |||||
| # print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( | |||||
| # tracker_filename)) | |||||
| # exit() | |||||
| # assert iteration > 0 or release, 'error parsing metadata file {}'.format( | |||||
| # tracker_filename) | |||||
| return load_path, metastring, release, True | |||||
| def load_checkpoint(model, | |||||
| optimizer, | |||||
| lr_scheduler, | |||||
| args, | |||||
| no_deepspeed=False, | |||||
| no_load_optim=False): | |||||
| """Load a model checkpoint.""" | |||||
| load_dir, tag, release, success = get_checkpoint_iteration(args.load) | |||||
| if not success: | |||||
| return 0 | |||||
| if args.deepspeed and not no_deepspeed: | |||||
| checkpoint_name, sd = model.load_checkpoint( | |||||
| load_dir, | |||||
| tag, | |||||
| load_optimizer_states=not args.no_load_optim and not no_load_optim, | |||||
| load_lr_scheduler_states=not args.no_load_lr_scheduler) | |||||
| if not args.no_load_lr_scheduler and 'client_lr_scheduler' in sd: | |||||
| lr_scheduler.load_state_dict(sd['client_lr_scheduler']) | |||||
| print_rank_0('Load lr scheduler state') | |||||
| if checkpoint_name is None: | |||||
| if mpu.get_data_parallel_rank() == 0: | |||||
| print('Unable to load checkpoint.') | |||||
| return tag | |||||
| else: | |||||
| # Checkpoint. | |||||
| checkpoint_name = get_checkpoint_name(load_dir, tag, release) | |||||
| if mpu.get_data_parallel_rank() == 0: | |||||
| print('global rank {} is loading checkpoint {}'.format( | |||||
| torch.distributed.get_rank(), checkpoint_name)) | |||||
| # Load the checkpoint. | |||||
| sd = torch.load(checkpoint_name, map_location='cpu') | |||||
| # Model. | |||||
| if args.deepspeed: | |||||
| model = model.module | |||||
| missing_keys, unexpected_keys = model.load_state_dict( | |||||
| sd['module'], strict=False) | |||||
| if missing_keys or unexpected_keys: | |||||
| print_rank_0( | |||||
| f'Missing keys {missing_keys}, unexpected keys {unexpected_keys}' | |||||
| ) | |||||
| # Optimizer. | |||||
| if not release and not args.finetune and not args.no_load_optim and not no_load_optim: | |||||
| try: | |||||
| if optimizer is not None: | |||||
| optimizer.load_state_dict(sd['optimizer']) | |||||
| if lr_scheduler is not None: | |||||
| lr_scheduler.load_state_dict(sd['lr_scheduler']) | |||||
| except KeyError: | |||||
| print_rank_0( | |||||
| 'Unable to load optimizer from checkpoint {}. ' | |||||
| 'Specify --no-load-optim or --finetune to prevent ' | |||||
| 'attempting to load the optimizer ' | |||||
| 'state.'.format(checkpoint_name)) | |||||
| # Iterations. | |||||
| if args.finetune or release: | |||||
| iteration = 0 | |||||
| else: | |||||
| try: | |||||
| iteration = sd['iteration'] | |||||
| except KeyError: | |||||
| try: # Backward compatible with older checkpoints | |||||
| iteration = sd['total_iters'] | |||||
| except KeyError: | |||||
| print_rank_0( | |||||
| 'A metadata file exists but the iteration could not be loaded ' | |||||
| 'from checkpoint {}, starting from iteration 0'.format( | |||||
| checkpoint_name)) | |||||
| iteration = 0 | |||||
| # rng states. | |||||
| if not release and not args.finetune and not args.no_load_rng: | |||||
| try: | |||||
| random.setstate(sd['random_rng_state']) | |||||
| np.random.set_state(sd['np_rng_state']) | |||||
| torch.set_rng_state(sd['torch_rng_state']) | |||||
| torch.cuda.set_rng_state(sd['cuda_rng_state']) | |||||
| mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states']) | |||||
| except KeyError: | |||||
| print_rank_0( | |||||
| 'Unable to load random state from checkpoint {}. ' | |||||
| 'Specify --no-load-rng or --finetune to prevent ' | |||||
| 'attempting to load the random ' | |||||
| 'state.'.format(checkpoint_name)) | |||||
| if mpu.get_data_parallel_rank() == 0: | |||||
| print(' successfully loaded {}'.format(checkpoint_name)) | |||||
| return iteration | |||||
| def load_weights(src, dst, dst2src=False): | |||||
| """ | |||||
| Loads weights from src to dst via in place copy. | |||||
| src is a huggingface gpt2model, while dst is one of our models. | |||||
| dst2src=True loads parameters from our models into huggingface's. | |||||
| Note: dst2src=True is still untested. | |||||
| """ | |||||
| conv_layer = 'Conv1D' in str(type(src)) | |||||
| for n, p in src.named_parameters(): | |||||
| if dst2src: | |||||
| data = dst._parameters[n].data | |||||
| load = p.data | |||||
| else: | |||||
| data = p.data | |||||
| load = dst._parameters[n].data | |||||
| if conv_layer and 'weight' in n: | |||||
| data = data.t().contiguous() | |||||
| load.copy_(data) | |||||
| # dst._parameters[n].data.copy_(data) | |||||
| def load_mlp(our, oai, dst2src=False): | |||||
| load_weights(oai.c_fc, our.dense_h_to_4h, dst2src) | |||||
| load_weights(oai.c_proj, our.dense_4h_to_h, dst2src) | |||||
| def load_attention(our, oai, dst2src=False): | |||||
| load_weights(oai.c_attn, our.query_key_value, dst2src) | |||||
| load_weights(oai.c_proj, our.dense, dst2src) | |||||
| def load_transformer_layer(our, oai, dst2src=False): | |||||
| load_weights(oai.ln_1, our.input_layernorm, dst2src) | |||||
| load_weights(oai.ln_2, our.post_attention_layernorm, dst2src) | |||||
| load_mlp(our.mlp, oai.mlp, dst2src) | |||||
| load_attention(our.attention, oai.attn, dst2src) | |||||
| def move_weights(our, oai, dst2src=False): | |||||
| """ | |||||
| Loads weights from `oai` to `our` via in place copy. | |||||
| `oai` is a huggingface gpt2model, while `our` is one of our models. | |||||
| dst2src=True loads parameters from our models into huggingface's. | |||||
| Note: dst2src=True is still untested. | |||||
| """ | |||||
| # while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)): | |||||
| # our=our.module | |||||
| transformer_model = oai.transformer | |||||
| load_weights(transformer_model.ln_f, our.transformer.final_layernorm, | |||||
| dst2src) | |||||
| load_weights(transformer_model.wte, our.word_embeddings, dst2src) | |||||
| load_weights(transformer_model.wpe, our.position_embeddings, dst2src) | |||||
| for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h): | |||||
| load_transformer_layer(our_layer, oai_layer, dst2src) | |||||
| def debug_finetune_data(local_vars, batch_id, tokenizer): | |||||
| tokens, target_ids = local_vars['tokens'], local_vars['target_ids'] | |||||
| attention_mask, logit_mask, position_ids = local_vars[ | |||||
| 'attention_mask'], local_vars['logit_mask'], local_vars['position_ids'] | |||||
| output_tokens = [] | |||||
| sep = attention_mask[batch_id].item() | |||||
| for i, token in enumerate(tokens[batch_id][:sep].tolist()): | |||||
| token = tokenizer.IdToToken(token) | |||||
| if token == '[MASK]': | |||||
| token = f'[{position_ids[batch_id][0, i].item()}]' | |||||
| output_tokens.append(token) | |||||
| print(' '.join(output_tokens)) | |||||
| target_positions = [] | |||||
| for i in range(sep, tokens.size(-1)): | |||||
| if logit_mask[batch_id][i]: | |||||
| target_positions.append(i) | |||||
| print(target_positions) | |||||
| print(tokenizer.DecodeIds(tokens[batch_id][target_positions].tolist())) | |||||
| if len(target_ids.shape) > 2: | |||||
| print( | |||||
| tokenizer.DecodeIds( | |||||
| target_ids[batch_id][target_positions].tolist())) | |||||
| else: | |||||
| print(tokenizer.DecodeIds(target_ids[batch_id].tolist())) | |||||
| print(position_ids[batch_id][:, target_positions]) | |||||
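Two quick reference points for the utilities above. First, the on-disk layout that save_checkpoint / load_checkpoint expect, following get_checkpoint_name and get_checkpoint_tracker_filename (shown for model-parallel rank 0, with tag "1000" as an example):

<args.save>/latest_checkpointed_iteration.txt   # tracker file holding the latest tag, e.g. "1000"
<args.save>/1000/mp_rank_00_model_states.pt     # per model-parallel-rank state dict

Second, a minimal usage sketch of the Timers group (start/stop call torch.cuda.synchronize, so a CUDA device is assumed):

timers = Timers()
timers('forward').start()
# ... run a forward pass ...
timers('forward').stop()
timers.log(['forward'], normalizer=1)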
| @@ -516,6 +516,12 @@ TASK_OUTPUTS = { | |||||
| # } | # } | ||||
| Tasks.text_generation: [OutputKeys.TEXT], | Tasks.text_generation: [OutputKeys.TEXT], | ||||
| # summarization result for single sample | |||||
| # { | |||||
| # "text": "this is the text generated by a model." | |||||
| # } | |||||
| Tasks.text_summarization: [OutputKeys.TEXT], | |||||
| # text generation result for single sample | # text generation result for single sample | ||||
| # { | # { | ||||
| # "text": "北京" | # "text": "北京" | ||||
| @@ -31,6 +31,7 @@ if TYPE_CHECKING: | |||||
| from .translation_pipeline import TranslationPipeline | from .translation_pipeline import TranslationPipeline | ||||
| from .word_segmentation_pipeline import WordSegmentationPipeline | from .word_segmentation_pipeline import WordSegmentationPipeline | ||||
| from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | ||||
| from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | |||||
| from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ | from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ | ||||
| WordSegmentationThaiPipeline | WordSegmentationThaiPipeline | ||||
| @@ -71,6 +72,7 @@ else: | |||||
| 'word_segmentation_pipeline': ['WordSegmentationPipeline'], | 'word_segmentation_pipeline': ['WordSegmentationPipeline'], | ||||
| 'zero_shot_classification_pipeline': | 'zero_shot_classification_pipeline': | ||||
| ['ZeroShotClassificationPipeline'], | ['ZeroShotClassificationPipeline'], | ||||
| 'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], | |||||
| 'multilingual_word_segmentation_pipeline': [ | 'multilingual_word_segmentation_pipeline': [ | ||||
| 'MultilingualWordSegmentationPipeline', | 'MultilingualWordSegmentationPipeline', | ||||
| 'WordSegmentationThaiPipeline' | 'WordSegmentationThaiPipeline' | ||||
| @@ -0,0 +1,43 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| from typing import Any, Dict, Optional, Union | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.models.nlp import MGLMForTextSummarization | |||||
| from modelscope.pipelines.base import Pipeline, Tensor | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import (MGLMSummarizationPreprocessor, | |||||
| Preprocessor) | |||||
| from modelscope.utils.constant import Tasks | |||||
| __all__ = ['MGLMTextSummarizationPipeline'] | |||||
| @PIPELINES.register_module( | |||||
| group_key=Tasks.text_summarization, | |||||
| module_name=Pipelines.mglm_text_summarization) | |||||
| class MGLMTextSummarizationPipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[MGLMForTextSummarization, str], | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| *args, | |||||
| **kwargs): | |||||
| model = MGLMForTextSummarization(model) if isinstance(model, | |||||
| str) else model | |||||
| self.model = model | |||||
| self.model.eval() | |||||
| if preprocessor is None: | |||||
| preprocessor = MGLMSummarizationPreprocessor() | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| # define the forward pass | |||||
| def forward(self, inputs: Union[Dict, str], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| inputs = {'text': inputs} if isinstance(inputs, str) else inputs | |||||
| return self.model.generate(inputs) | |||||
| # format the outputs from pipeline | |||||
| def postprocess(self, input, **kwargs) -> Dict[str, Any]: | |||||
| return input | |||||
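Once registered under Pipelines.mglm_text_summarization, the class can be built through the standard pipeline factory; a brief sketch using the model id exercised in the unit test at the end of this change:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

summarizer = pipeline(
    task=Tasks.text_summarization,
    model='ZhipuAI/Multilingual-GLM-Summarization-zh')
print(summarizer('... a long news article ...'))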
| @@ -18,16 +18,16 @@ if TYPE_CHECKING: | |||||
| from .nlp import ( | from .nlp import ( | ||||
| DocumentSegmentationPreprocessor, FaqQuestionAnsweringPreprocessor, | DocumentSegmentationPreprocessor, FaqQuestionAnsweringPreprocessor, | ||||
| FillMaskPoNetPreprocessor, NLPPreprocessor, | FillMaskPoNetPreprocessor, NLPPreprocessor, | ||||
| NLPTokenizerPreprocessorBase, TextRankingPreprocessor, | |||||
| RelationExtractionPreprocessor, SentenceEmbeddingPreprocessor, | |||||
| SequenceClassificationPreprocessor, TokenClassificationPreprocessor, | |||||
| TextErrorCorrectionPreprocessor, TextGenerationPreprocessor, | |||||
| Text2TextGenerationPreprocessor, Tokenize, | |||||
| NLPTokenizerPreprocessorBase, PassageRankingPreprocessor, | |||||
| TextRankingPreprocessor, RelationExtractionPreprocessor, | |||||
| SentenceEmbeddingPreprocessor, SequenceClassificationPreprocessor, | |||||
| TokenClassificationPreprocessor, TextErrorCorrectionPreprocessor, | |||||
| TextGenerationPreprocessor, Text2TextGenerationPreprocessor, Tokenize, | |||||
| WordSegmentationBlankSetToLabelPreprocessor, | WordSegmentationBlankSetToLabelPreprocessor, | ||||
| ZeroShotClassificationPreprocessor, TextGenerationJiebaPreprocessor, | |||||
| SentencePiecePreprocessor, DialogIntentPredictionPreprocessor, | |||||
| DialogModelingPreprocessor, DialogStateTrackingPreprocessor, | |||||
| ConversationalTextToSqlPreprocessor, | |||||
| MGLMSummarizationPreprocessor, ZeroShotClassificationPreprocessor, | |||||
| TextGenerationJiebaPreprocessor, SentencePiecePreprocessor, | |||||
| DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, | |||||
| DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor, | |||||
| TableQuestionAnsweringPreprocessor, NERPreprocessorViet, | TableQuestionAnsweringPreprocessor, NERPreprocessorViet, | ||||
| NERPreprocessorThai, WordSegmentationPreprocessorThai) | NERPreprocessorThai, WordSegmentationPreprocessorThai) | ||||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | ||||
| @@ -57,6 +57,7 @@ else: | |||||
| 'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor', | 'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor', | ||||
| 'Tokenize', 'Text2TextGenerationPreprocessor', | 'Tokenize', 'Text2TextGenerationPreprocessor', | ||||
| 'WordSegmentationBlankSetToLabelPreprocessor', | 'WordSegmentationBlankSetToLabelPreprocessor', | ||||
| 'MGLMSummarizationPreprocessor', | |||||
| 'ZeroShotClassificationPreprocessor', | 'ZeroShotClassificationPreprocessor', | ||||
| 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', | 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', | ||||
| 'NERPreprocessorViet', 'NERPreprocessorThai', | 'NERPreprocessorViet', 'NERPreprocessorThai', | ||||
| @@ -29,6 +29,7 @@ if TYPE_CHECKING: | |||||
| MultiWOZBPETextField, IntentBPETextField) | MultiWOZBPETextField, IntentBPETextField) | ||||
| from .space_T_en import ConversationalTextToSqlPreprocessor | from .space_T_en import ConversationalTextToSqlPreprocessor | ||||
| from .space_T_cn import TableQuestionAnsweringPreprocessor | from .space_T_cn import TableQuestionAnsweringPreprocessor | ||||
| from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor | |||||
| else: | else: | ||||
| _import_structure = { | _import_structure = { | ||||
| 'nlp_base': [ | 'nlp_base': [ | ||||
| @@ -62,6 +63,7 @@ else: | |||||
| 'text_error_correction': [ | 'text_error_correction': [ | ||||
| 'TextErrorCorrectionPreprocessor', | 'TextErrorCorrectionPreprocessor', | ||||
| ], | ], | ||||
| 'mglm_summarization_preprocessor': ['MGLMSummarizationPreprocessor'], | |||||
| 'token_classification_thai_preprocessor': [ | 'token_classification_thai_preprocessor': [ | ||||
| 'NERPreprocessorThai', | 'NERPreprocessorThai', | ||||
| 'WordSegmentationPreprocessorThai', | 'WordSegmentationPreprocessorThai', | ||||
| @@ -0,0 +1,32 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | |||||
| import os.path as osp | |||||
| import re | |||||
| from typing import Any, Dict, Iterable, Optional, Tuple, Union | |||||
| from modelscope.metainfo import Models, Preprocessors | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.preprocessors.base import Preprocessor | |||||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||||
| from modelscope.utils.config import Config, ConfigFields | |||||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | |||||
| from modelscope.utils.hub import get_model_type, parse_label_mapping | |||||
| from modelscope.utils.logger import get_logger | |||||
| from modelscope.utils.nlp import import_external_nltk_data | |||||
| from modelscope.utils.type_assert import type_assert | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.nlp, module_name=Preprocessors.mglm_summarization) | |||||
| class MGLMSummarizationPreprocessor(Preprocessor): | |||||
| def __init__(self, *args, **kwargs): | |||||
| """preprocess the data | |||||
| Args: | |||||
| model_dir (str): model path | |||||
| """ | |||||
| super().__init__(*args, **kwargs) | |||||
| @type_assert(object, (str, tuple, Dict)) | |||||
| def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]: | |||||
| return data | |||||
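The preprocessor is registered under Preprocessors.mglm_summarization but is effectively a pass-through: __call__ returns its input unchanged, leaving tokenization to the model. A quick illustrative check:

preprocessor = MGLMSummarizationPreprocessor()
assert preprocessor('some document text') == 'some document text'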
| @@ -1,18 +1,25 @@ | |||||
| boto3 | |||||
| en_core_web_sm>=2.3.5 | en_core_web_sm>=2.3.5 | ||||
| fasttext | |||||
| filelock | |||||
| ftfy | |||||
| jieba>=0.42.1 | jieba>=0.42.1 | ||||
| megatron_util | |||||
| matplotlib | |||||
| nltk | |||||
| pai-easynlp | pai-easynlp | ||||
| pandas | |||||
| # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | # protobuf version beyond 3.20.0 is not compatible with TensorFlow 1.x, therefore is discouraged. | ||||
| protobuf>=3.19.0,<3.21.0 | protobuf>=3.19.0,<3.21.0 | ||||
| pythainlp | pythainlp | ||||
| pyvi | pyvi | ||||
| # rouge-score was recently updated from 0.0.4 to 0.0.7, | |||||
| # which introduced compatibility issues that are being investigated | |||||
| rouge_score<=0.0.4 | |||||
| regex | |||||
| sacremoses>=0.0.41 | sacremoses>=0.0.41 | ||||
| scikit_learn | |||||
| sentencepiece | |||||
| seqeval | seqeval | ||||
| spacy>=2.3.5 | spacy>=2.3.5 | ||||
| subword_nmt>=0.3.8 | subword_nmt>=0.3.8 | ||||
| termcolor | |||||
| text2sql_lgesql | text2sql_lgesql | ||||
| tokenizers | tokenizers | ||||
| transformers>=4.12.0 | transformers>=4.12.0 | ||||
| @@ -0,0 +1,47 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import unittest | |||||
| from modelscope.models import Model | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.preprocessors import MGLMSummarizationPreprocessor | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class mGLMTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def setUp(self) -> None: | |||||
| self.output_dir = 'unittest_output' | |||||
| os.makedirs(self.output_dir, exist_ok=True) | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||||
| def test_run_with_mglm_with_name(self): | |||||
| model = 'ZhipuAI/Multilingual-GLM-Summarization-zh' | |||||
| preprocessor = MGLMSummarizationPreprocessor() | |||||
| pipe = pipeline( | |||||
| task=Tasks.text_summarization, | |||||
| model=model, | |||||
| preprocessor=preprocessor, | |||||
| ) | |||||
| result = pipe( | |||||
| '据中国载人航天工程办公室消息,北京时间2022年10月25日,梦天实验舱与长征五号B遥四运载火箭组合体已转运至发射区。后续将按计划开展发射前各项功能检查和联合测试等工作,计划于近日择机实施发射。目前,文昌航天发射场设施设备状态良好,参试各单位正在加紧开展任务准备,全力以赴确保空间站建造任务决战决胜。' # noqa | |||||
| ) | |||||
| print(result) | |||||
| model = 'ZhipuAI/Multilingual-GLM-Summarization-en' | |||||
| preprocessor = MGLMSummarizationPreprocessor() | |||||
| pipe = pipeline( | |||||
| task=Tasks.text_summarization, | |||||
| model=model, | |||||
| preprocessor=preprocessor, | |||||
| ) | |||||
| result = pipe( | |||||
| '据中国载人航天工程办公室消息,北京时间2022年10月25日,梦天实验舱与长征五号B遥四运载火箭组合体已转运至发射区。后续将按计划开展发射前各项功能检查和联合测试等工作,计划于近日择机实施发射。目前,文昌航天发射场设施设备状态良好,参试各单位正在加紧开展任务准备,全力以赴确保空间站建造任务决战决胜。' # noqa | |||||
| ) | |||||
| print(result) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||