* mglm init
* add mglm requirements

Co-authored-by: Yufeng <zhuyufeng@gmail.com>
Co-authored-by: wenmeng.zwm <wenmeng.zwm@alibaba-inc.com>
| @@ -82,6 +82,7 @@ class Models(object): | |||
| bert_for_ds = 'bert-for-document-segmentation' | |||
| ponet = 'ponet' | |||
| T5 = 'T5' | |||
| mglm = 'mglm' | |||
| bloom = 'bloom' | |||
| # audio models | |||
| @@ -251,6 +252,7 @@ class Pipelines(object): | |||
| relation_extraction = 'relation-extraction' | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| mglm_text_summarization = 'mglm-text-summarization' | |||
| translation_en_to_de = 'translation_en_to_de' # keep it underscore | |||
| translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | |||
| translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | |||
| @@ -376,6 +378,7 @@ class Preprocessors(object): | |||
| re_tokenizer = 're-tokenizer' | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| mglm_summarization = 'mglm-summarization' | |||
| sentence_piece = 'sentence-piece' | |||
| # audio preprocessor | |||
| @@ -35,6 +35,7 @@ if TYPE_CHECKING: | |||
| SbertTokenizerFast, | |||
| ) | |||
| from .T5 import T5ForConditionalGeneration | |||
| from .mglm import MGLMForTextSummarization | |||
| from .task_models import ( | |||
| FeatureExtractionModel, | |||
| InformationExtractionModel, | |||
| @@ -106,6 +107,7 @@ else: | |||
| ], | |||
| 'sentence_embedding': ['SentenceEmbedding'], | |||
| 'T5': ['T5ForConditionalGeneration'], | |||
| 'mglm': ['MGLMForTextSummarization'], | |||
| 'gpt_neo': ['GPTNeoModel'], | |||
| 'bloom': ['BloomModel'], | |||
| } | |||
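The three registry hunks above (Models, Pipelines, Preprocessors) plus the lazy-import entry are what route MGLM through the standard ModelScope pipeline machinery. A minimal usage sketch follows; the task string and the model id 'ZhipuAI/mglm-text-summarization' are illustrative assumptions, not confirmed hub names.

    from modelscope.pipelines import pipeline

    # Hypothetical model id; the task string is assumed to resolve to the
    # 'mglm-text-summarization' pipeline registered above.
    summarizer = pipeline(
        task='text-summarization',
        model='ZhipuAI/mglm-text-summarization')
    print(summarizer('A long multilingual article to be summarized ...'))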
| @@ -0,0 +1,22 @@ | |||
| # Modified by Zhipu.AI | |||
| # Original Copyright (c) Alibaba, Inc. and its affiliates. | |||
| from typing import TYPE_CHECKING | |||
| from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .mglm_for_text_summarization import MGLMForTextSummarization | |||
| else: | |||
| _import_structure = { | |||
| 'mglm_for_text_summarization': ['MGLMForTextSummarization'], | |||
| } | |||
| import sys | |||
| sys.modules[__name__] = LazyImportModule( | |||
| __name__, | |||
| globals()['__file__'], | |||
| _import_structure, | |||
| module_spec=__spec__, | |||
| extra_objects={}, | |||
| ) | |||
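The package __init__ follows ModelScope's LazyImportModule pattern: importing the package itself is cheap, and mglm_for_text_summarization (with its torch/deepspeed dependencies) is only loaded when an exported name is first touched. A small sketch of the resulting behaviour, assuming the package is installed at modelscope.models.nlp.mglm:

    # Importing the package does not yet import the heavy submodule.
    from modelscope.models.nlp import mglm

    # First attribute access triggers the lazy import declared in _import_structure.
    model_cls = mglm.MGLMForTextSummarization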
| @@ -0,0 +1,793 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """argparser configuration""" | |||
| import argparse | |||
| import os | |||
| import deepspeed | |||
| import json | |||
| import torch | |||
| from .utils import get_hostname | |||
| def add_model_config_args(parser): | |||
| """Model arguments""" | |||
| group = parser.add_argument_group('model', 'model configuration') | |||
| group.add_argument( | |||
| '--transformer-xl', | |||
| action='store_true', | |||
| help='use transformer-xl for training') | |||
| group.add_argument( | |||
| '--pretrained-bert', | |||
| action='store_true', | |||
| help='use a pretrained bert-large-uncased model instead ' | |||
| 'of initializing from scratch. See ' | |||
| '--tokenizer-model-type to specify which pretrained ' | |||
| 'BERT model to use') | |||
| group.add_argument( | |||
| '--encoder-decoder', | |||
| action='store_true', | |||
| help='use the encoder-decoder architecture for blocklm') | |||
| group.add_argument( | |||
| '--attention-dropout', | |||
| type=float, | |||
| default=0.1, | |||
| help='dropout probability for attention weights') | |||
| group.add_argument( | |||
| '--num-attention-heads', | |||
| type=int, | |||
| default=16, | |||
| help='num of transformer attention heads') | |||
| group.add_argument( | |||
| '--hidden-size', type=int, default=1024, help='transformer hidden size') | |||
| group.add_argument( | |||
| '--intermediate-size', | |||
| type=int, | |||
| default=None, | |||
| help='transformer embedding dimension for FFN. ' | |||
| 'Set to 4*`--hidden-size` if it is None') | |||
| group.add_argument( | |||
| '--num-layers', type=int, default=24, help='num decoder layers') | |||
| group.add_argument( | |||
| '--layernorm-epsilon', | |||
| type=float, | |||
| default=1e-5, | |||
| help='layer norm epsilon') | |||
| group.add_argument( | |||
| '--hidden-dropout', | |||
| type=float, | |||
| default=0.1, | |||
| help='dropout probability for hidden state transformer') | |||
| group.add_argument( | |||
| '--output-dropout', | |||
| type=float, | |||
| default=0.1, | |||
| help='dropout probability for pooled output') | |||
| group.add_argument( | |||
| '--max-position-embeddings', | |||
| type=int, | |||
| default=512, | |||
| help='maximum number of position embeddings to use') | |||
| group.add_argument( | |||
| '--vocab-size', | |||
| type=int, | |||
| default=250112, | |||
| help='vocab size to use for non-character-level ' | |||
| 'tokenization. This value will only be used when ' | |||
| 'creating a tokenizer') | |||
| group.add_argument( | |||
| '--deep-init', | |||
| action='store_true', | |||
| help='initialize bert model similar to gpt2 model. ' | |||
| 'Scales initialization of projection layers by a ' | |||
| 'factor of 1/sqrt(2N). Necessary to train bert ' | |||
| 'models larger than BERT-Large.') | |||
| group.add_argument( | |||
| '--make-vocab-size-divisible-by', | |||
| type=int, | |||
| default=128, | |||
| help='Pad the vocab size to be divisible by this value. ' | |||
| 'This is added for computational efficiency reasons.') | |||
| group.add_argument( | |||
| '--cpu-optimizer', action='store_true', help='Run optimizer on CPU') | |||
| group.add_argument( | |||
| '--cpu_torch_adam', | |||
| action='store_true', | |||
| help='Use Torch Adam as optimizer on CPU.') | |||
| return parser | |||
| def add_fp16_config_args(parser): | |||
| """Mixed precision arguments.""" | |||
| group = parser.add_argument_group('fp16', 'fp16 configurations') | |||
| group.add_argument( | |||
| '--fp16', action='store_true', help='Run model in fp16 mode') | |||
| group.add_argument( | |||
| '--fp32-embedding', action='store_true', help='embedding in fp32') | |||
| group.add_argument( | |||
| '--fp32-layernorm', action='store_true', help='layer norm in fp32') | |||
| group.add_argument( | |||
| '--fp32-tokentypes', | |||
| action='store_true', | |||
| help='embedding token types in fp32') | |||
| group.add_argument( | |||
| '--fp32-allreduce', action='store_true', help='all-reduce in fp32') | |||
| group.add_argument( | |||
| '--hysteresis', | |||
| type=int, | |||
| default=2, | |||
| help='hysteresis for dynamic loss scaling') | |||
| group.add_argument( | |||
| '--loss-scale', | |||
| type=float, | |||
| default=None, | |||
| help='Static loss scaling, positive power of 2 ' | |||
| 'values can improve fp16 convergence. If None, dynamic ' | |||
| 'loss scaling is used.') | |||
| group.add_argument( | |||
| '--loss-scale-window', | |||
| type=float, | |||
| default=1000, | |||
| help='Window over which to raise/lower dynamic scale') | |||
| group.add_argument( | |||
| '--min-scale', | |||
| type=float, | |||
| default=1, | |||
| help='Minimum loss scale for dynamic loss scale') | |||
| group.add_argument('--attention-scale', type=float, default=1.0) | |||
| return parser | |||
| def add_training_args(parser): | |||
| """Training arguments.""" | |||
| group = parser.add_argument_group('train', 'training configurations') | |||
| group.add_argument( | |||
| '--experiment-name', | |||
| type=str, | |||
| default='gpt-345M', | |||
| help='The experiment name for summary and checkpoint') | |||
| group.add_argument( | |||
| '--batch-size', type=int, default=4, help='Data Loader batch size') | |||
| group.add_argument( | |||
| '--gradient-accumulation-steps', | |||
| type=int, | |||
| default=1, | |||
| help='Number of gradient accumulation steps') | |||
| group.add_argument( | |||
| '--weight-decay', | |||
| type=float, | |||
| default=0.01, | |||
| help='weight decay coefficient for L2 regularization') | |||
| group.add_argument( | |||
| '--checkpoint-activations', | |||
| action='store_true', | |||
| help='checkpoint activation to allow for training ' | |||
| 'with larger models and sequences') | |||
| group.add_argument( | |||
| '--checkpoint-num-layers', | |||
| type=int, | |||
| default=1, | |||
| help='chunk size (number of layers) for checkpointing') | |||
| group.add_argument( | |||
| '--deepspeed-activation-checkpointing', | |||
| action='store_true', | |||
| help='uses activation checkpointing from deepspeed') | |||
| group.add_argument( | |||
| '--epochs', | |||
| type=int, | |||
| default=None, | |||
| help='Number of finetuning epochs. Zero results in evaluation only.') | |||
| group.add_argument( | |||
| '--clip-grad', type=float, default=1.0, help='gradient clipping') | |||
| group.add_argument( | |||
| '--train-iters', | |||
| type=int, | |||
| default=0, | |||
| help='total number of iterations to train over all training runs') | |||
| group.add_argument('--label-smoothing', type=float, default=0.0) | |||
| group.add_argument( | |||
| '--log-interval', type=int, default=100, help='report interval') | |||
| group.add_argument( | |||
| '--summary-dir', | |||
| type=str, | |||
| default='', | |||
| help='The directory to store the summary') | |||
| group.add_argument('--seed', type=int, default=1234, help='random seed') | |||
| # Batch producer arguments | |||
| group.add_argument( | |||
| '--reset-position-ids', | |||
| action='store_true', | |||
| help='Reset position ids after end-of-document token.') | |||
| group.add_argument( | |||
| '--reset-attention-mask', | |||
| action='store_true', | |||
| help='Reset self attention mask after ' | |||
| 'end-of-document token.') | |||
| # Learning rate. | |||
| group.add_argument( | |||
| '--lr-decay-iters', | |||
| type=int, | |||
| default=None, | |||
| help='number of iterations to decay LR over.' | |||
| ' If None, defaults to `--train-iters`*`--epochs`') | |||
| group.add_argument( | |||
| '--lr-decay-style', | |||
| type=str, | |||
| default='linear', | |||
| choices=['constant', 'linear', 'cosine', 'exponential'], | |||
| help='learning rate decay function') | |||
| group.add_argument('--lr-decay-ratio', type=float, default=0.1) | |||
| group.add_argument( | |||
| '--lr', type=float, default=1.0e-4, help='initial learning rate') | |||
| group.add_argument( | |||
| '--warmup', | |||
| type=float, | |||
| default=0.01, | |||
| help='percentage of data to warmup on (.01 = 1 percent of all ' | |||
| 'training iters). Default 0.01') | |||
| group.add_argument( | |||
| '--switch-linear', | |||
| action='store_true', | |||
| help='Switch to linear decay for cosine decay') | |||
| # model checkpointing | |||
| group.add_argument( | |||
| '--save', | |||
| type=str, | |||
| default=None, | |||
| help='Output directory to save checkpoints to.') | |||
| group.add_argument('--new-save-directory', action='store_true') | |||
| group.add_argument( | |||
| '--save-epoch', | |||
| type=int, | |||
| default=1, | |||
| help='number of epochs between saves') | |||
| group.add_argument( | |||
| '--save-interval', | |||
| type=int, | |||
| default=5000, | |||
| help='number of iterations between saves') | |||
| group.add_argument( | |||
| '--no-save-optim', | |||
| action='store_true', | |||
| help='Do not save current optimizer.') | |||
| group.add_argument( | |||
| '--no-save-rng', | |||
| action='store_true', | |||
| help='Do not save current rng state.') | |||
| group.add_argument( | |||
| '--load', | |||
| type=str, | |||
| default=None, | |||
| help='Path to a directory containing a model checkpoint.') | |||
| group.add_argument( | |||
| '--no-load-optim', | |||
| action='store_true', | |||
| help='Do not load optimizer when loading checkpoint.') | |||
| group.add_argument( | |||
| '--no-load-rng', | |||
| action='store_true', | |||
| help='Do not load rng state when loading checkpoint.') | |||
| group.add_argument( | |||
| '--no-load-lr-scheduler', | |||
| action='store_true', | |||
| help='Do not load lr scheduler when loading checkpoint.') | |||
| group.add_argument( | |||
| '--no-deepspeed-load', | |||
| action='store_true', | |||
| help='Do not use deepspeed when loading the checkpoint') | |||
| group.add_argument( | |||
| '--finetune', | |||
| action='store_true', | |||
| help='Load model for finetuning. Do not load optimizer ' | |||
| 'or rng state from checkpoint and set iteration to 0. ' | |||
| 'Assumed when loading a release checkpoint.') | |||
| group.add_argument( | |||
| '--resume-dataloader', | |||
| action='store_true', | |||
| help='Resume the dataloader when resuming training. ' | |||
| 'Does not apply to tfrecords dataloader, try resuming ' | |||
| 'with a different seed in this case.') | |||
| # distributed training args | |||
| group.add_argument( | |||
| '--distributed-backend', | |||
| default='nccl', | |||
| help= | |||
| 'which backend to use for distributed training. One of [gloo, nccl]', | |||
| choices=['nccl', 'gloo']) | |||
| group.add_argument( | |||
| '--DDP-impl', | |||
| default='torch', | |||
| choices=['local', 'torch', 'none'], | |||
| help='which DistributedDataParallel implementation to use.') | |||
| group.add_argument( | |||
| '--local_rank', | |||
| type=int, | |||
| default=None, | |||
| help='local rank passed from distributed launcher') | |||
| # BlockLM training args | |||
| group.add_argument( | |||
| '--block-lm', | |||
| action='store_true', | |||
| help='whether use the BlockLM pre-training') | |||
| group.add_argument( | |||
| '--masked-lm', | |||
| action='store_true', | |||
| help='whether to use the mlm objective') | |||
| group.add_argument('--bert-prob', type=float, default=0.5) | |||
| group.add_argument('--gpt-infill-prob', type=float, default=0.5) | |||
| group.add_argument('--gpt-min-ratio', type=float, default=0.5) | |||
| group.add_argument('--gap-sentence-prob', type=float, default=0.0) | |||
| group.add_argument('--gap-sentence-ratio', type=float, default=0.15) | |||
| group.add_argument('--avg-block-length', type=int, default=3) | |||
| group.add_argument('--short-seq-prob', type=float, default=0.0) | |||
| group.add_argument('--single-span-prob', type=float, default=0.0) | |||
| group.add_argument( | |||
| '--task-mask', | |||
| action='store_true', | |||
| help='Use different mask for generation and blank filling') | |||
| group.add_argument( | |||
| '--no-shuffle-block', | |||
| action='store_true', | |||
| help='do not shuffle the blocks when filling the blank') | |||
| group.add_argument( | |||
| '--no-block-position', | |||
| action='store_true', | |||
| help='Use (rough) absolute positions instead of block positions') | |||
| group.add_argument( | |||
| '--sentinel-token', | |||
| action='store_true', | |||
| help='Use sentinel (mask) tokens to replace 2d position encoding') | |||
| group.add_argument('--block-mask-prob', type=float, default=0.0) | |||
| group.add_argument('--context-mask-ratio', type=float, default=0.0) | |||
| group.add_argument( | |||
| '--random-position', | |||
| action='store_true', | |||
| help='Use random start position to cover all the position embeddings') | |||
| return parser | |||
| def add_evaluation_args(parser): | |||
| """Evaluation arguments.""" | |||
| group = parser.add_argument_group('validation', | |||
| 'validation configurations') | |||
| group.add_argument( | |||
| '--eval-batch-size', | |||
| type=int, | |||
| default=None, | |||
| help='Data Loader batch size for evaluation datasets. ' | |||
| 'Defaults to `--batch-size`') | |||
| group.add_argument( | |||
| '--eval-iters', | |||
| type=int, | |||
| default=100, | |||
| help='number of iterations to run for evaluation ' | |||
| 'on the validation/test set') | |||
| group.add_argument( | |||
| '--eval-interval', | |||
| type=int, | |||
| default=1000, | |||
| help='interval between running evaluation on validation set') | |||
| group.add_argument( | |||
| '--eval-epoch', | |||
| type=int, | |||
| default=1, | |||
| help='number of epochs between evaluations on the validation set') | |||
| group.add_argument( | |||
| '--eval-seq-length', | |||
| type=int, | |||
| default=None, | |||
| help='Maximum sequence length to process for ' | |||
| 'evaluation. Defaults to `--seq-length`') | |||
| group.add_argument( | |||
| '--eval-max-preds-per-seq', | |||
| type=int, | |||
| default=None, | |||
| help='Maximum number of predictions to use for ' | |||
| 'evaluation. Defaults to ' | |||
| 'math.ceil(`--eval-seq-length`*.15/10)*10') | |||
| group.add_argument('--overlapping-eval', type=int, default=32) | |||
| return parser | |||
| def add_text_generate_args(parser): | |||
| """Text generate arguments.""" | |||
| group = parser.add_argument_group('Text generation', 'configurations') | |||
| group.add_argument('--temperature', type=float, default=1.0) | |||
| group.add_argument('--top_p', type=float, default=0.0) | |||
| group.add_argument('--top_k', type=int, default=0) | |||
| group.add_argument('--out-seq-length', type=int, default=256) | |||
| group.add_argument('--num-beams', type=int, default=1) | |||
| group.add_argument('--length-penalty', type=float, default=0.0) | |||
| group.add_argument('--no-repeat-ngram-size', type=int, default=0) | |||
| group.add_argument('--min-tgt-length', type=int, default=0) | |||
| group.add_argument('--select-topk', action='store_true') | |||
| group.add_argument('--blank-maskratio', type=float, default=0.1) | |||
| return parser | |||
| def add_data_args(parser): | |||
| """Train/valid/test data arguments.""" | |||
| group = parser.add_argument_group('data', 'data configurations') | |||
| group.add_argument( | |||
| '--model-parallel-size', | |||
| type=int, | |||
| default=1, | |||
| help='size of the model parallel.') | |||
| group.add_argument( | |||
| '--shuffle', | |||
| action='store_true', | |||
| help='Shuffle data. Shuffling is deterministic ' | |||
| 'based on seed and current epoch.') | |||
| group.add_argument('--filter-english', action='store_true') | |||
| group.add_argument( | |||
| '--train-data', | |||
| nargs='+', | |||
| default=None, | |||
| help='Whitespace separated filenames or corpora names ' | |||
| 'for training.') | |||
| group.add_argument( | |||
| '--valid-data', | |||
| nargs='*', | |||
| default=None, | |||
| help="""Filename for validation data.""") | |||
| group.add_argument( | |||
| '--test-data', | |||
| nargs='*', | |||
| default=None, | |||
| help="""Filename for testing""") | |||
| group.add_argument( | |||
| '--data-dir', | |||
| type=str, | |||
| default=None, | |||
| help='The data path to all the data files') | |||
| group.add_argument( | |||
| '--input-data-sizes-file', | |||
| type=str, | |||
| default='sizes.txt', | |||
| help='the filename containing all the shards sizes') | |||
| group.add_argument( | |||
| '--delim', default=',', help='delimiter used to parse csv data files') | |||
| group.add_argument( | |||
| '--text-key', | |||
| default='sentence', | |||
| help='key to use to extract text from json/csv') | |||
| group.add_argument( | |||
| '--eval-text-key', | |||
| default=None, | |||
| help='key to use to extract text from ' | |||
| 'json/csv evaluation datasets') | |||
| group.add_argument( | |||
| '--split', | |||
| default='1000,1,1', | |||
| help='comma-separated list of proportions for training,' | |||
| ' validation, and test split') | |||
| group.add_argument( | |||
| '--no-lazy-loader', | |||
| action='store_true', | |||
| help='do not lazily read the data set') | |||
| group.add_argument('--half-lazy-loader', action='store_true') | |||
| group.add_argument( | |||
| '--loader-scatter', | |||
| type=int, | |||
| default=None, | |||
| help='Number of scatters to use for dataloaders') | |||
| group.add_argument( | |||
| '--loose-json', | |||
| action='store_true', | |||
| help='Use loose json (one json-formatted string per ' | |||
| 'newline), instead of tight json (data file is one ' | |||
| 'json string)') | |||
| group.add_argument( | |||
| '--presplit-sentences', | |||
| action='store_true', | |||
| help='Dataset content consists of documents where ' | |||
| 'each document consists of newline separated sentences') | |||
| group.add_argument( | |||
| '--num-workers', | |||
| type=int, | |||
| default=2, | |||
| help="""Number of workers to use for dataloading""") | |||
| group.add_argument( | |||
| '--tokenizer-model-type', | |||
| type=str, | |||
| default=None, | |||
| help="Model type to use for sentencepiece tokenization \ | |||
| (one of ['bpe', 'char', 'unigram', 'word']) or \ | |||
| bert vocab to use for BertWordPieceTokenizer (one of \ | |||
| ['bert-large-uncased', 'bert-large-cased', etc.])") | |||
| group.add_argument( | |||
| '--tokenizer-path', | |||
| type=str, | |||
| default='tokenizer.model', | |||
| help='path used to save/load sentencepiece tokenization ' | |||
| 'models') | |||
| group.add_argument( | |||
| '--tokenizer-type', | |||
| type=str, | |||
| default='BertWordPieceTokenizer', | |||
| choices=[ | |||
| 'CharacterLevelTokenizer', 'SentencePieceTokenizer', | |||
| 'BertWordPieceTokenizer', 'GPT2BPETokenizer', 'ChineseSPTokenizer' | |||
| ], | |||
| help='what type of tokenizer to use') | |||
| group.add_argument('--no-pre-tokenize', action='store_true') | |||
| group.add_argument( | |||
| '--cache-dir', | |||
| default=None, | |||
| type=str, | |||
| help='Where to store pre-trained BERT downloads') | |||
| group.add_argument( | |||
| '--use-tfrecords', | |||
| action='store_true', | |||
| help='load `--train-data`, `--valid-data`, ' | |||
| '`--test-data` from BERT tf records instead of ' | |||
| 'normal data pipeline') | |||
| group.add_argument( | |||
| '--seq-length', | |||
| type=int, | |||
| default=512, | |||
| help='Maximum sequence length to process') | |||
| group.add_argument( | |||
| '--mem-length', | |||
| type=int, | |||
| default=0, | |||
| help='The memory length to preserve') | |||
| group.add_argument( | |||
| '--max-preds-per-seq', | |||
| type=int, | |||
| default=None, | |||
| help='Maximum number of predictions to use per sequence. ' | |||
| 'Defaults to math.ceil(`--seq-length`*.15/10)*10. ' | |||
| 'MUST BE SPECIFIED IF `--use-tfrecords` is True.') | |||
| group.add_argument('--non-sentence-start', type=float, default=0.0) | |||
| group.add_argument( | |||
| '--sample-one-document', | |||
| action='store_true', | |||
| help='only sample one document in one sample') | |||
| group.add_argument( | |||
| '--load-splits', | |||
| type=str, | |||
| default=None, | |||
| help='The path to load split indices from') | |||
| group.add_argument( | |||
| '--save-splits', | |||
| type=str, | |||
| default=None, | |||
| help='The path to save split indices to') | |||
| group.add_argument( | |||
| '--save-test-data', | |||
| type=str, | |||
| default=None, | |||
| help='The path to save the test data') | |||
| group.add_argument( | |||
| '--multi-task-data', | |||
| nargs='*', | |||
| default=None, | |||
| help='Downstream task names for multi-task pre-training') | |||
| group.add_argument( | |||
| '--multi-task-ratio', | |||
| type=float, | |||
| default=0.0, | |||
| help='Ratio for multi-task pre-training') | |||
| group.add_argument('--multi-seq-length', type=int, default=None) | |||
| group.add_argument('--multi-batch-size', type=int, default=None) | |||
| return parser | |||
| def add_finetune_config_args(parser): | |||
| group = parser.add_argument_group('finetune', 'finetune configurations') | |||
| group.add_argument('--task', type=str, help='Task name.') | |||
| group.add_argument( | |||
| '--load-pretrained', | |||
| type=str, | |||
| help='Load pretrained model', | |||
| default=None) | |||
| group.add_argument( | |||
| '--pool-token', | |||
| type=str, | |||
| choices=['start', 'pad', 'cls'], | |||
| help='The token to pool the sequence representation', | |||
| default='cls') | |||
| group.add_argument( | |||
| '--cloze-eval', | |||
| action='store_true', | |||
| help='Evaluation dataset with cloze task') | |||
| group.add_argument( | |||
| '--multi-token', | |||
| action='store_true', | |||
| help='Use multi token for cloze evaluation') | |||
| group.add_argument( | |||
| '--segment-length', | |||
| type=int, | |||
| default=0, | |||
| help='The maximum segment length for cloze evaluation') | |||
| group.add_argument( | |||
| '--loss-func', | |||
| type=str, | |||
| choices=['cross_entropy', 'hinge', 'generative', 'mix'], | |||
| default='cross_entropy') | |||
| group.add_argument('--block-lm-ratio', type=float, default=0.0) | |||
| group.add_argument( | |||
| '--adapet', | |||
| action='store_true', | |||
| help='Use the decoupled cross entropy loss in AdaPET') | |||
| group.add_argument('--pattern-id', type=int, default=0) | |||
| group.add_argument( | |||
| '--fast-decode', | |||
| action='store_true', | |||
| help= | |||
| 'Fast decode for multi-token cloze. Can only be used without checkpoint activation.' | |||
| ) | |||
| group.add_argument('--few-superglue', action='store_true') | |||
| group.add_argument( | |||
| '--eval-valid', | |||
| action='store_true', | |||
| help='Whether evaluate on the valid set') | |||
| group.add_argument('--validation-metric', type=str, default=None) | |||
| group.add_argument( | |||
| '--unidirectional', | |||
| action='store_true', | |||
| help='Use the left to right language model') | |||
| group.add_argument('--src-seq-length', type=int, default=None) | |||
| group.add_argument('--tgt-seq-length', type=int, default=None) | |||
| group.add_argument('--adam-beta1', type=float, default=0.9) | |||
| group.add_argument('--adam-beta2', type=float, default=0.999) | |||
| group.add_argument('--adam-eps', type=float, default=1e-8) | |||
| group.add_argument( | |||
| '--optimizer', type=str, choices=['adam', 'adafactor'], default='adam') | |||
| group.add_argument('--wsc-negative', action='store_true') | |||
| group.add_argument('--overwrite', action='store_true') | |||
| group.add_argument('--no-validation', action='store_true') | |||
| # Continuous prompt arguments | |||
| group.add_argument( | |||
| '--continuous-prompt', | |||
| action='store_true', | |||
| help='Use continuous prompt for PET') | |||
| group.add_argument('--num-prompt-tokens', type=int, default=0) | |||
| group.add_argument( | |||
| '--prompt-func', default='lstm', choices=['lstm', 'mlp', 'none']) | |||
| group.add_argument( | |||
| '--freeze-transformer', action='store_true', default=False) | |||
| group.add_argument('--tune-prefix-layers', type=int, default=None) | |||
| group.add_argument('--prefix-prompt', type=int, default=0) | |||
| group.add_argument('--prompt-init', action='store_true', default=False) | |||
| return parser | |||
| def get_args(): | |||
| """Parse all the args.""" | |||
| parser = argparse.ArgumentParser(description='PyTorch BERT Model') | |||
| parser = add_model_config_args(parser) | |||
| parser = add_fp16_config_args(parser) | |||
| parser = add_training_args(parser) | |||
| parser = add_evaluation_args(parser) | |||
| parser = add_text_generate_args(parser) | |||
| parser = add_data_args(parser) | |||
| parser = add_finetune_config_args(parser) | |||
| # Include DeepSpeed configuration arguments | |||
| parser = deepspeed.add_config_arguments(parser) | |||
| args = parser.parse_args(args=[]) | |||
| if not args.train_data and not args.data_dir: | |||
| print('WARNING: No training data specified') | |||
| args.cuda = torch.cuda.is_available() | |||
| args.rank = int(os.getenv('RANK', '0')) | |||
| args.world_size = int(os.getenv('WORLD_SIZE', '1')) | |||
| if hasattr(args, 'deepspeed_mpi') and args.deepspeed_mpi: | |||
| mpi_define_env(args) | |||
| elif os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'): | |||
| # We are using (OpenMPI) mpirun for launching distributed data parallel processes | |||
| local_rank = int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')) | |||
| local_size = int(os.getenv('OMPI_COMM_WORLD_LOCAL_SIZE')) | |||
| # Possibly running with Slurm | |||
| num_nodes = int(os.getenv('SLURM_JOB_NUM_NODES', '1')) | |||
| nodeid = int(os.getenv('SLURM_NODEID', '0')) | |||
| args.local_rank = local_rank | |||
| args.rank = nodeid * local_size + local_rank | |||
| args.world_size = num_nodes * local_size | |||
| args.model_parallel_size = min(args.model_parallel_size, args.world_size) | |||
| if args.rank == 0: | |||
| print('using world size: {} and model-parallel size: {} '.format( | |||
| args.world_size, args.model_parallel_size)) | |||
| args.dynamic_loss_scale = False | |||
| if args.loss_scale is None: | |||
| args.dynamic_loss_scale = True | |||
| if args.rank == 0: | |||
| print(' > using dynamic loss scaling') | |||
| # The fp32_* / fp16_* args are only meant to be active when | |||
| # --fp16 is set, so they should all default to False otherwise. | |||
| if not args.fp16: | |||
| args.fp32_embedding = False | |||
| args.fp32_tokentypes = False | |||
| args.fp32_layernorm = False | |||
| if hasattr(args, 'deepspeed' | |||
| ) and args.deepspeed and args.deepspeed_config is not None: | |||
| with open(args.deepspeed_config) as file: | |||
| deepspeed_config = json.load(file) | |||
| if 'train_micro_batch_size_per_gpu' in deepspeed_config: | |||
| args.batch_size = deepspeed_config[ | |||
| 'train_micro_batch_size_per_gpu'] | |||
| if 'gradient_accumulation_steps' in deepspeed_config: | |||
| args.gradient_accumulation_steps = deepspeed_config[ | |||
| 'gradient_accumulation_steps'] | |||
| else: | |||
| args.gradient_accumulation_steps = 1 | |||
| if 'optimizer' in deepspeed_config: | |||
| optimizer_params_config = deepspeed_config['optimizer'].get( | |||
| 'params', {}) | |||
| args.lr = optimizer_params_config.get('lr', args.lr) | |||
| args.weight_decay = optimizer_params_config.get( | |||
| 'weight_decay', args.weight_decay) | |||
| return args | |||
| def mpi_define_env(args): | |||
| from mpi4py import MPI | |||
| comm = MPI.COMM_WORLD | |||
| rank = comm.Get_rank() | |||
| world_size = comm.Get_size() | |||
| master_addr = None | |||
| if rank == 0: | |||
| master_addr = get_hostname() | |||
| master_addr = comm.bcast(master_addr, root=0) | |||
| # Determine local rank by assuming hostnames are unique | |||
| proc_name = MPI.Get_processor_name() | |||
| all_procs = comm.allgather(proc_name) | |||
| local_rank = sum([i == proc_name for i in all_procs[:rank]]) | |||
| os.environ['RANK'] = str(rank) | |||
| os.environ['WORLD_SIZE'] = str(world_size) | |||
| args.local_rank = local_rank | |||
| args.world_size = world_size | |||
| args.rank = rank | |||
| os.environ['MASTER_ADDR'] = master_addr | |||
| os.environ[ | |||
| 'MASTER_PORT'] = '29500' # TORCH_DISTRIBUTED_DEFAULT_PORT = 29500 | |||
| print( | |||
| 'Discovered MPI settings of world_rank={}, local_rank={}, world_size={}, master_addr={}, master_port={}' | |||
| .format(os.environ['RANK'], args.local_rank, os.environ['WORLD_SIZE'], | |||
| os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])) | |||
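get_args() parses with args=[], so the defaults above (or programmatic overrides) drive the configuration, and a supplied DeepSpeed config can override a handful of them. Below is a minimal sketch of the only fields get_args() reads back from that config; the values are illustrative, not a recommended training setup.

    import json

    # Only these keys are consulted by get_args() when args.deepspeed is set
    # and args.deepspeed_config points at the written file.
    ds_config = {
        'train_micro_batch_size_per_gpu': 8,        # -> args.batch_size
        'gradient_accumulation_steps': 4,           # -> args.gradient_accumulation_steps
        'optimizer': {'params': {'lr': 5e-5,        # -> args.lr
                                 'weight_decay': 0.1}},  # -> args.weight_decay
    }
    with open('ds_config.json', 'w') as f:
        json.dump(ds_config, f, indent=2)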
| @@ -0,0 +1,625 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import copy | |||
| import math | |||
| import random | |||
| import numpy as np | |||
| import torch | |||
| import torch.utils.data | |||
| from scipy.stats import poisson | |||
| from . import mpu | |||
| from .utils import print_rank_0 | |||
| def rindex(lst, val, start=None): | |||
| if start is None: | |||
| start = len(lst) - 1 | |||
| for i in range(start, -1, -1): | |||
| if lst[i] == val: | |||
| return i | |||
| return -1 | |||
| def index_in_list(lst, val, start=None): | |||
| if start is None: | |||
| start = 0 | |||
| for i in range(start, len(lst)): | |||
| if lst[i] == val: | |||
| return i | |||
| return -1 | |||
| class ConstructBlockStrategy: | |||
| def __init__(self, | |||
| args, | |||
| tokenizer, | |||
| max_seq_length, | |||
| bert_prob=1.0, | |||
| gap_sentence_prob=0.0, | |||
| gpt_infill_prob=0.5, | |||
| gpt_min_ratio=0.5, | |||
| bert_ratio=0.15, | |||
| gap_sentence_ratio=0.15, | |||
| average_block_length=3, | |||
| max_block_length=40, | |||
| block_mask_prob=0.0, | |||
| context_mask_ratio=0.0, | |||
| context_mask_range=3, | |||
| short_seq_prob=0.0, | |||
| single_span_prob=0.0, | |||
| block_position_encoding=True, | |||
| encoder_decoder=False, | |||
| shuffle_blocks=True, | |||
| sentinel_token=False, | |||
| task_mask=False, | |||
| random_position=False, | |||
| masked_lm=False): | |||
| self.eod_token = args.eod_token | |||
| self.tokenizer = tokenizer | |||
| self.count = 0 | |||
| self.max_seq_length = max_seq_length | |||
| self.rank = mpu.get_data_parallel_rank() | |||
| self.world_size = mpu.get_data_parallel_world_size() | |||
| # self.rank = 0 | |||
| # self.world_size = 1 | |||
| assert 0.0 <= bert_prob <= 1.0 | |||
| self.bert_prob = bert_prob | |||
| self.gap_sentence_prob = gap_sentence_prob | |||
| self.gpt_prob = 1 - bert_prob - gap_sentence_prob | |||
| assert self.gpt_prob >= -1e-10 | |||
| self.infill_prob = gpt_infill_prob | |||
| self.gpt_min_ratio = gpt_min_ratio | |||
| self.bert_ratio = bert_ratio | |||
| self.gap_sentence_ratio = gap_sentence_ratio | |||
| self.block_length_distribution = [ | |||
| poisson.pmf(i, average_block_length) | |||
| for i in range(1, max_block_length) | |||
| ] | |||
| self.block_mask_prob = block_mask_prob | |||
| self.context_mask_ratio = context_mask_ratio | |||
| self.context_mask_range = context_mask_range | |||
| self.short_seq_prob = short_seq_prob | |||
| self.single_span_prob = single_span_prob | |||
| self.block_position_encoding = block_position_encoding | |||
| self.encoder_decoder = encoder_decoder | |||
| self.shuffle_blocks = shuffle_blocks | |||
| self.sentinel_token = sentinel_token | |||
| self.generation_mask = 'gMASK' if task_mask else 'MASK' | |||
| self.generation_mask = self.tokenizer.get_command( | |||
| self.generation_mask).Id | |||
| self.gap_sentence_mask = 'sMASK' if task_mask else 'MASK' | |||
| self.gap_sentence_mask = self.tokenizer.get_command( | |||
| self.gap_sentence_mask).Id | |||
| self.random_position = random_position | |||
| self.masked_lm = masked_lm | |||
| print_rank_0( | |||
| f'BERT prob {self.bert_prob}, gap sent prob {self.gap_sentence_prob}, GPT prob {self.gpt_prob}, infill prob {self.infill_prob}' # noqa | |||
| ) | |||
| print_rank_0( | |||
| f'generation min ratio {self.gpt_min_ratio}, block ratio {self.bert_ratio}, gap sent ratio {self.gap_sentence_ratio}' # noqa | |||
| ) | |||
| print_rank_0( | |||
| f'block length distribution {self.block_length_distribution}') | |||
| print_rank_0( | |||
| f'block mask prob {self.block_mask_prob}, context mask ratio {self.context_mask_ratio}' | |||
| ) | |||
| def contains_sentence_end(self, tok): | |||
| tok = self.tokenizer.IdToToken(tok) | |||
| # True if the token contains any ASCII or CJK sentence-ending punctuation. | |||
| return any(ch in tok for ch in ('.', '?', '!', ';', ':', '。', '？', '！', '；', '…', '\n')) | |||
| @staticmethod | |||
| def sample_spans(span_lengths, total_length, rng, offset=0): | |||
| blank_length = total_length - sum(span_lengths) | |||
| m = blank_length - len(span_lengths) + 1 | |||
| places = [rng.randrange(m + 1) for _ in range(len(span_lengths))] | |||
| places.sort() | |||
| spans = [] | |||
| for place, span_length in zip(places, span_lengths): | |||
| start = offset + place | |||
| end = offset + place + span_length | |||
| spans.append((start, end)) | |||
| offset += span_length + 1 | |||
| return spans | |||
| def sample_span_in_document(self, tokens, masked_lengths, rng): | |||
| rng.shuffle(masked_lengths) | |||
| mask_spans = [] | |||
| mask_index = 0 | |||
| indices = [-1] + np.where(tokens == self.eod_token)[0].tolist() | |||
| last_index = len(tokens) | |||
| documents = [] | |||
| for index in reversed(indices): | |||
| start_index = index | |||
| if start_index + 1 < len(tokens) and tokens[ | |||
| start_index + 1] == self.tokenizer.get_command('ENC').Id: | |||
| start_index += 1 | |||
| length = last_index - start_index - 1 | |||
| if last_index == len(tokens) and length > 0: | |||
| length -= 1 | |||
| documents.append((start_index + 1, length)) | |||
| last_index = index | |||
| documents.sort(key=lambda x: x[1]) | |||
| for i, (offset, length) in enumerate(documents): | |||
| if i == len(documents) - 1: | |||
| current_masked_length, current_count = 0, 0 | |||
| while mask_index + current_count < len( | |||
| masked_lengths | |||
| ) and masked_lengths[ | |||
| mask_index + # noqa | |||
| current_count] + current_masked_length + current_count <= length: | |||
| current_masked_length += masked_lengths[mask_index | |||
| + current_count] | |||
| current_count += 1 | |||
| if current_count > 0: | |||
| spans = self.sample_spans( | |||
| masked_lengths[mask_index:mask_index + current_count], | |||
| length, | |||
| rng, | |||
| offset=offset) | |||
| mask_spans += spans | |||
| if mask_index + current_count < len(masked_lengths) - 1: | |||
| print(length, masked_lengths[mask_index:], | |||
| masked_lengths[:mask_index], indices) | |||
| else: | |||
| current_masked_total = int(length * self.bert_ratio) | |||
| current_masked_length, current_count = 0, 0 | |||
| while mask_index + current_count < len( | |||
| masked_lengths | |||
| ) and masked_lengths[ | |||
| mask_index + # noqa | |||
| current_count] + current_masked_length <= current_masked_total: | |||
| current_masked_length += masked_lengths[mask_index | |||
| + current_count] | |||
| current_count += 1 | |||
| if current_count > 0: | |||
| spans = self.sample_spans( | |||
| masked_lengths[mask_index:mask_index + current_count], | |||
| length, | |||
| rng, | |||
| offset=offset) | |||
| mask_spans += spans | |||
| mask_index += current_count | |||
| return mask_spans | |||
| def make_masked_data(self, | |||
| tokens, | |||
| loss_masks, | |||
| attention_mask, | |||
| block_spans, | |||
| rng, | |||
| task='bert'): | |||
| position_ids = np.arange(len(tokens), dtype=np.long) | |||
| targets = copy.deepcopy(tokens) | |||
| mask_id = self.tokenizer.get_command('MASK').Id | |||
| mlm_masks = np.zeros(len(tokens), dtype=np.long) | |||
| for start, end in block_spans: | |||
| for idx in range(start, end): | |||
| tokens[idx] = mask_id | |||
| mlm_masks[start:end] = 1 | |||
| loss_masks = loss_masks * mlm_masks | |||
| return tokens, targets, loss_masks, position_ids | |||
| def make_block_data(self, | |||
| tokens, | |||
| loss_masks, | |||
| attention_mask, | |||
| block_spans, | |||
| rng, | |||
| task='bert'): | |||
| text_length = len(tokens) | |||
| position_ids = np.ones(len(tokens), dtype=np.long) | |||
| for start, end in block_spans: | |||
| position_ids[start + 1:end] = 0 | |||
| position_ids = np.cumsum(position_ids) - 1 | |||
| if self.random_position and position_ids[-1] < self.max_seq_length - 1: | |||
| position_bias = self.max_seq_length - position_ids[-1] | |||
| position_bias = rng.randrange(0, position_bias) | |||
| position_ids = position_ids + position_bias | |||
| if self.encoder_decoder or not self.shuffle_blocks: | |||
| block_spans.sort(key=lambda x: x[0]) | |||
| else: | |||
| rng.shuffle(block_spans) | |||
| if self.sentinel_token: | |||
| block_spans = [(start, end, idx) | |||
| for idx, (start, end) in enumerate(block_spans)] | |||
| else: | |||
| block_spans = [(start, end, 0) for start, end in block_spans] | |||
| target_tokens, target_position_ids, target_block_position_ids, targets = [], [], [], [] | |||
| for start, end, idx in block_spans: | |||
| sop_token = 'sop' if idx == 0 else f'sop{idx}' | |||
| target_tokens.append([self.tokenizer.get_command(sop_token).Id]) | |||
| span_tokens = copy.deepcopy(tokens[start:end]) | |||
| if self.block_mask_prob > 0.0 and task == 'bert': | |||
| for sub_idx in range(len(span_tokens)): | |||
| if random.random() < self.block_mask_prob: | |||
| span_tokens[sub_idx] = self.tokenizer.get_command( | |||
| 'dBLOCK').Id | |||
| target_tokens.append(span_tokens) | |||
| targets.append(tokens[start:end]) | |||
| targets.append([self.tokenizer.get_command('eop').Id]) | |||
| if not self.sentinel_token: | |||
| target_position_id = position_ids[start:end] | |||
| target_position_ids.append(target_position_id) | |||
| target_position_ids.append([target_position_id[0]]) | |||
| else: | |||
| target_position_ids.append([self.max_seq_length] * # noqa | |||
| (end - start + 1)) | |||
| if self.block_position_encoding: | |||
| target_block_position_ids.append( | |||
| np.arange(1, end - start + 2, dtype=np.long)) | |||
| else: | |||
| target_block_position_ids.append([1] * (end - start + 1)) | |||
| block_spans.sort(key=lambda x: x[0]) | |||
| source_tokens, source_position_ids, local_spans = [], [], [] | |||
| last, current_length = 0, 0 | |||
| for start, end, idx in block_spans: | |||
| if task == 'generation': | |||
| mask_id = self.generation_mask | |||
| elif task == 'gap_sentence': | |||
| mask_id = self.gap_sentence_mask | |||
| else: | |||
| mask_token = 'MASK' if idx == 0 else f'MASK{idx}' | |||
| mask_id = self.tokenizer.get_command(mask_token).Id | |||
| local_spans.append((current_length, current_length + start - last)) | |||
| source_tokens.append(tokens[last:start]) | |||
| source_tokens.append([mask_id]) | |||
| source_position_ids.append(position_ids[last:start]) | |||
| source_position_ids.append([position_ids[start]]) | |||
| current_length += start - last + 1 | |||
| last = end | |||
| if last < len(tokens): | |||
| local_spans.append( | |||
| (current_length, current_length + len(tokens) - last)) | |||
| source_tokens.append(tokens[last:]) | |||
| source_position_ids.append(position_ids[last:]) | |||
| source_length = sum(map(len, source_tokens)) | |||
| if attention_mask is not None: | |||
| assert source_length == attention_mask | |||
| if target_tokens and self.eod_token in np.concatenate( | |||
| target_tokens).tolist(): | |||
| print('Found EOS in target', self.tokenizer.DecodeIds(tokens)) | |||
| raise RuntimeError | |||
| if self.encoder_decoder: | |||
| target_tokens = target_tokens + [ | |||
| self.tokenizer.get_command('eop').Id | |||
| ] | |||
| loss_masks = np.ones(len(target_tokens), dtype=np.long) | |||
| return source_tokens, target_tokens, loss_masks | |||
| else: | |||
| tokens = np.concatenate(source_tokens + target_tokens) | |||
| if task == 'bert' and self.context_mask_ratio > 0: | |||
| mask_candidates = set() | |||
| for start, end in local_spans: | |||
| if start != 0: | |||
| local_end = min(end, start + self.context_mask_range) | |||
| mask_candidates.update(range(start, local_end)) | |||
| if end != 0: | |||
| local_start = max(start, end - self.context_mask_range) | |||
| mask_candidates.update(range(local_start, end)) | |||
| mask_pos = rng.sample( | |||
| mask_candidates, | |||
| int(self.context_mask_ratio * text_length)) | |||
| for pos in mask_pos: | |||
| tokens[pos] = self.tokenizer.get_command('dBLOCK').Id | |||
| targets = np.concatenate(source_tokens + targets) | |||
| loss_masks = np.ones(len(tokens), dtype=np.long) | |||
| loss_masks[:source_length] = 0 | |||
| position_ids = np.concatenate(source_position_ids | |||
| + target_position_ids) | |||
| block_position_ids = np.concatenate( | |||
| [np.zeros(source_length, dtype=np.long)] | |||
| + target_block_position_ids) | |||
| position_ids = np.stack([position_ids, block_position_ids], axis=0) | |||
| if attention_mask is not None: | |||
| return tokens, targets, loss_masks, position_ids | |||
| else: | |||
| return tokens, targets, loss_masks, position_ids, source_length | |||
| def generate_blank_data(self, | |||
| sample, | |||
| masked_lengths, | |||
| attention_mask, | |||
| rng, | |||
| task='bert'): | |||
| rng.shuffle(masked_lengths) | |||
| tokens, loss_masks = sample['text'], sample['loss_mask'] | |||
| assert tokens[0] == self.tokenizer.get_command('ENC').Id | |||
| block_spans = self.sample_span_in_document(tokens, masked_lengths, rng) | |||
| if len(block_spans) < len(masked_lengths): | |||
| return None | |||
| if self.masked_lm: | |||
| data = self.make_masked_data(tokens, loss_masks, attention_mask, | |||
| block_spans, rng) | |||
| else: | |||
| data = self.make_block_data( | |||
| tokens, | |||
| loss_masks, | |||
| attention_mask, | |||
| block_spans, | |||
| rng, | |||
| task=task) | |||
| return data | |||
| def split_samples(self, samples, rng): | |||
| target_length = rng.randrange(32, self.max_seq_length - 1) | |||
| num_splits = (self.max_seq_length - 1) // target_length | |||
| new_samples = [] | |||
| cls_id = self.tokenizer.get_command('ENC').Id | |||
| eos_id = self.tokenizer.get_command('eos').Id | |||
| for sample in samples: | |||
| tokens, loss_masks = sample['text'][1:], sample['loss_mask'][1:] | |||
| for _ in range(num_splits): | |||
| if target_length >= len(tokens): | |||
| new_tokens, new_loss_masks = tokens, loss_masks | |||
| else: | |||
| random_start = rng.randrange(0, | |||
| len(tokens) - target_length) | |||
| while random_start > 0 and ( | |||
| tokens[random_start] == eos_id or # noqa | |||
| not (self.contains_sentence_end( # noqa | |||
| tokens[random_start - 1]) or # noqa | |||
| tokens[random_start - 1] == eos_id)): # noqa | |||
| random_start -= 1 | |||
| random_end = random_start + target_length | |||
| while random_end > random_start and not ( | |||
| self.contains_sentence_end(tokens[random_end - 1]) | |||
| or tokens[random_end - 1] == eos_id): | |||
| random_end -= 1 | |||
| if random_end - random_start < target_length // 2: | |||
| random_end = random_start + target_length | |||
| new_tokens, new_loss_masks = tokens[ | |||
| random_start:random_end], loss_masks[ | |||
| random_start:random_end] | |||
| new_tokens = np.concatenate(([cls_id], new_tokens)) | |||
| new_loss_masks = np.concatenate(([0], new_loss_masks)) | |||
| new_samples.append({ | |||
| 'text': new_tokens, | |||
| 'loss_mask': new_loss_masks | |||
| }) | |||
| return new_samples | |||
| def construct_blocks(self, samples): | |||
| worker_info = torch.utils.data.get_worker_info() | |||
| if worker_info is not None: | |||
| worker_id, num_workers = worker_info.id, worker_info.num_workers | |||
| else: | |||
| worker_id, num_workers = 0, 1 | |||
| rng = random.Random((self.count * num_workers + worker_id) | |||
| * self.world_size + self.rank) | |||
| self.count += 1 | |||
| token_batch, target_batch, loss_mask_batch, position_id_batch = [], [], [], [] | |||
| source_batch, target_batch = [], [] | |||
| if rng.random() < self.short_seq_prob: | |||
| samples = self.split_samples(samples, rng) | |||
| rand = rng.random() | |||
| single_span = rand < self.single_span_prob | |||
| rand = 0.0 if single_span else rng.random() | |||
| attention_mask = [] | |||
| if rand < self.bert_prob: | |||
| mode = 'bert' | |||
| for sample in samples: | |||
| if single_span: | |||
| masked_lengths = [ | |||
| rng.choices( | |||
| range(1, | |||
| len(self.block_length_distribution) + 1), | |||
| weights=self.block_length_distribution)[0] | |||
| ] | |||
| masked_count = masked_lengths[0] | |||
| else: | |||
| masked_lengths, masked_count = [], 0 | |||
| while masked_count < int( | |||
| self.bert_ratio * len(sample['text'])): | |||
| block_length = rng.choices( | |||
| range(1, | |||
| len(self.block_length_distribution) + 1), | |||
| weights=self.block_length_distribution)[0] | |||
| masked_lengths.append(block_length) | |||
| masked_count += block_length | |||
| if self.masked_lm: | |||
| sep = len(sample['text']) | |||
| else: | |||
| sep = len( | |||
| sample['text']) - masked_count + len(masked_lengths) | |||
| data = self.generate_blank_data( | |||
| sample, masked_lengths, sep, rng, task='bert') | |||
| if data is not None: | |||
| if self.encoder_decoder: | |||
| source_tokens, target_tokens, loss_masks = data | |||
| source_batch.append(source_tokens) | |||
| target_batch.append(target_tokens) | |||
| loss_mask_batch.append(loss_masks) | |||
| else: | |||
| tokens, targets, loss_masks, position_ids = data | |||
| token_batch.append(tokens) | |||
| target_batch.append(targets) | |||
| loss_mask_batch.append(loss_masks) | |||
| position_id_batch.append(position_ids) | |||
| attention_mask.append(sep) | |||
| elif rand < self.bert_prob + self.gap_sentence_prob: | |||
| mode = 'sentence' | |||
| for sample in samples: | |||
| tokens, loss_masks = sample['text'], sample['loss_mask'] | |||
| sentence_spans = [] | |||
| last_index = 1 if tokens[0] == self.tokenizer.get_command( | |||
| 'ENC').Id else 0 | |||
| for i in range(len(tokens)): | |||
| if self.contains_sentence_end(tokens[i]): | |||
| if last_index < i + 1: | |||
| sentence_spans.append((last_index, i + 1)) | |||
| last_index = i + 1 | |||
| elif tokens[i] == self.tokenizer.get_command('eos').Id: | |||
| last_index = i + 1 | |||
| if last_index < len(tokens): | |||
| sentence_spans.append((last_index, len(tokens))) | |||
| if not sentence_spans and torch.distributed.get_rank() == 0: | |||
| try: | |||
| print(self.tokenizer.DecodeIds(tokens[1:])) | |||
| except IndexError: | |||
| print(tokens[1:]) | |||
| rng.shuffle(sentence_spans) | |||
| block_spans, block_length = [], 0 | |||
| for start, end in sentence_spans: | |||
| block_spans.append((start, end)) | |||
| block_length += end - start | |||
| if block_length >= int( | |||
| self.gap_sentence_ratio * len(tokens)): | |||
| break | |||
| data = self.make_block_data( | |||
| tokens, | |||
| loss_masks, | |||
| None, | |||
| block_spans, | |||
| rng, | |||
| task='gap_sentence') | |||
| tokens, targets, loss_masks, position_ids, sep = data | |||
| token_batch.append(tokens) | |||
| target_batch.append(targets) | |||
| loss_mask_batch.append(loss_masks) | |||
| position_id_batch.append(position_ids) | |||
| attention_mask.append(sep) | |||
| else: | |||
| # start_indices = [index_in_list(sample['loss_mask'], 1) for sample in samples] | |||
| # end_indices = [rindex(sample['loss_mask'], 1) for sample in samples] | |||
| # start_index, end_index = max(start_indices), min(end_indices) - self.min_generation_length | |||
| # if end_index < start_index + 1: | |||
| # end_index = start_index + 1 | |||
| # division = rng.randrange(start_index, end_index) | |||
| mode = 'gpt' | |||
| max_generation_length = rng.randint( | |||
| int(self.gpt_min_ratio | |||
| * min(map(lambda x: len(x['text']), samples))), | |||
| max(map(lambda x: len(x['text']), samples)) - 2) | |||
| for sample in samples: | |||
| generation_length = min(max_generation_length, | |||
| len(sample['text']) - 2) | |||
| attention_mask.append( | |||
| len(sample['text']) - generation_length + 1) | |||
| multiple_doc = index_in_list( | |||
| sample['text'], | |||
| self.tokenizer.get_command('eos').Id) not in [ | |||
| -1, len(sample['text']) - 1 | |||
| ] # noqa | |||
| if multiple_doc or rng.random() < self.infill_prob: | |||
| division = len(sample['text']) - generation_length | |||
| tokens, loss_masks = sample['text'], sample['loss_mask'] | |||
| source_tokens, target_tokens = tokens[:division], tokens[ | |||
| division:] | |||
| target_masks = loss_masks[division:] | |||
| tokens = np.concatenate((source_tokens, [ | |||
| self.generation_mask, | |||
| self.tokenizer.get_command('sop').Id | |||
| ], target_tokens[:-1])) | |||
| targets = np.concatenate( | |||
| (source_tokens, [self.generation_mask], target_tokens)) | |||
| loss_masks = np.concatenate( | |||
| (np.zeros(len(source_tokens) + 1, | |||
| dtype=np.long), target_masks)) | |||
| token_batch.append(tokens) | |||
| target_batch.append(targets) | |||
| loss_mask_batch.append(loss_masks) | |||
| position_ids = np.arange( | |||
| len(source_tokens) + len(target_tokens) + 1, | |||
| dtype=np.long) | |||
| position_ids[len(source_tokens) + 1:] = len(source_tokens) | |||
| if self.block_position_encoding: | |||
| block_position_ids = np.concatenate( | |||
| (np.zeros(len(source_tokens), dtype=np.long), | |||
| np.arange(len(target_tokens) + 1, dtype=np.long))) | |||
| else: | |||
| block_position_ids = np.concatenate( | |||
| (np.zeros(len(source_tokens) + 1, dtype=np.long), | |||
| np.ones(len(target_tokens) + 1, dtype=np.long))) | |||
| position_id_batch.append( | |||
| np.stack([position_ids, block_position_ids], axis=0)) | |||
| else: | |||
| tokens, targets, loss_masks, position_ids = self.generate_blank_data( | |||
| sample, [generation_length], | |||
| attention_mask[-1], | |||
| rng, | |||
| task='generation') | |||
| token_batch.append(tokens) | |||
| target_batch.append(targets) | |||
| loss_mask_batch.append(loss_masks) | |||
| position_id_batch.append(position_ids) | |||
| if tokens is None: | |||
| print(sample, generation_length, multiple_doc) | |||
| if self.encoder_decoder: | |||
| return { | |||
| 'text': torch.tensor(source_batch, dtype=torch.long), | |||
| 'target': torch.tensor(target_batch, dtype=torch.long), | |||
| 'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long) | |||
| } | |||
| else: | |||
| token_batch, target_batch, loss_mask_batch, position_id_batch = self.pad_batch( | |||
| token_batch, target_batch, loss_mask_batch, position_id_batch) | |||
| return { | |||
| 'text': torch.tensor(token_batch, dtype=torch.long), | |||
| 'target': torch.tensor(target_batch, dtype=torch.long), | |||
| 'loss_mask': torch.tensor(loss_mask_batch, dtype=torch.long), | |||
| 'position_id': | |||
| torch.tensor(position_id_batch, dtype=torch.long), | |||
| 'attention_mask': | |||
| torch.tensor(attention_mask, dtype=torch.long), | |||
| 'mode': mode | |||
| } | |||
| @staticmethod | |||
| def pad_batch(token_batch, target_batch, loss_mask_batch, | |||
| position_id_batch): | |||
| seq_lengths = list(map(len, token_batch)) | |||
| if seq_lengths.count(seq_lengths[0]) != len(seq_lengths): | |||
| max_length = max(seq_lengths) | |||
| token_batch = [ | |||
| np.concatenate( | |||
| (tokens, np.zeros(max_length - len(tokens), | |||
| dtype=np.long))) | |||
| for tokens in token_batch | |||
| ] | |||
| target_batch = [ | |||
| np.concatenate( | |||
| (targets, | |||
| np.zeros(max_length - len(targets), dtype=np.long))) | |||
| for targets in target_batch | |||
| ] | |||
| loss_mask_batch = [ | |||
| np.concatenate( | |||
| (loss_masks, | |||
| np.zeros(max_length - len(loss_masks), dtype=np.long))) | |||
| for loss_masks in loss_mask_batch | |||
| ] | |||
| position_id_batch = [ | |||
| np.concatenate((position_ids, | |||
| np.zeros( | |||
| (2, max_length - position_ids.shape[1]), | |||
| dtype=np.long)), | |||
| axis=1) for position_ids in position_id_batch | |||
| ] | |||
| return token_batch, target_batch, loss_mask_batch, position_id_batch | |||
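ConstructBlockStrategy.sample_spans is the piece that decides where masked blocks land: it draws sorted start offsets into the unmasked budget and shifts every subsequent span by the lengths already placed plus one, so spans never overlap and keep at least a one-token gap. A standalone re-implementation of just that placement rule, for illustration only (the real code is the staticmethod above):

    import random

    def sample_spans(span_lengths, total_length, rng, offset=0):
        # Budget left over after the spans themselves are accounted for.
        blank_length = total_length - sum(span_lengths)
        m = blank_length - len(span_lengths) + 1
        # Sorted start offsets drawn into the blank budget.
        places = sorted(rng.randrange(m + 1) for _ in range(len(span_lengths)))
        spans = []
        for place, span_length in zip(places, span_lengths):
            spans.append((offset + place, offset + place + span_length))
            # Shift later spans past this one plus a one-token gap.
            offset += span_length + 1
        return spans

    # Two non-overlapping (start, end) spans of lengths 2 and 3 inside 10 tokens.
    print(sample_spans([2, 3], 10, random.Random(0)))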
| @@ -0,0 +1,513 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """parses arguments and preps data loader""" | |||
| import copy | |||
| import os | |||
| import random | |||
| from bisect import bisect_right | |||
| from itertools import accumulate | |||
| import numpy as np | |||
| import torch | |||
| import torch.utils.data | |||
| from . import data_utils, mpu | |||
| from .blocklm_utils import ConstructBlockStrategy | |||
| from .data_utils.tokenization import make_tokenizer | |||
| from .utils import print_rank_0 | |||
| class MultiTaskDataset(torch.utils.data.Dataset): | |||
| def __init__(self, | |||
| tasks, | |||
| datasets, | |||
| reweight=True, | |||
| temperature=0.8, | |||
| max_limit=200000): | |||
| super(MultiTaskDataset, self).__init__() | |||
| self.tasks = tasks | |||
| self.datasets = datasets | |||
| self.reweight = reweight | |||
| self.temperature = temperature | |||
| self.lens = [len(dataset) for dataset in datasets] | |||
| self.weights = np.array( | |||
| [min(length, max_limit)**temperature for length in self.lens]) | |||
| self.total_len = sum(self.lens) | |||
| self.cumulative_lens = list(accumulate(self.lens)) | |||
| if self.reweight: | |||
| print_rank_0(list(zip(self.tasks, self.lens, self.weights))) | |||
| else: | |||
| print_rank_0(list(zip(self.tasks, self.lens))) | |||
| self.weights /= self.weights.sum() | |||
| def __len__(self): | |||
| return self.total_len * 1000 | |||
| @staticmethod | |||
| def pet_wrapper(data): | |||
| text = data['text'] | |||
| loss_mask = data['logit_mask'] | |||
| target = data['target'] | |||
| attention_mask = data['mask'] | |||
| position_id = data['position'] | |||
| label = data['label'] | |||
| if len(text.shape) == 2: | |||
| text = text[label] | |||
| loss_mask = loss_mask[label] | |||
| target = target[label] | |||
| attention_mask = attention_mask[label] | |||
| position_id = position_id[label] | |||
| else: | |||
| target = target[label] | |||
| if not target.shape: | |||
| target = target.repeat(len(text)) | |||
| return { | |||
| 'text': text, | |||
| 'target': target, | |||
| 'loss_mask': loss_mask, | |||
| 'position_id': position_id, | |||
| 'attention_mask': attention_mask | |||
| } | |||
| def __getitem__(self, idx): | |||
| if self.reweight: | |||
| rng = random.Random(idx) | |||
| rng = np.random.RandomState( | |||
| seed=[rng.randint(0, 2**32 - 1) for _ in range(16)]) | |||
| dataset_idx = rng.choice( | |||
| np.arange(len(self.datasets)), p=self.weights) | |||
| dataset = self.datasets[dataset_idx] | |||
| sample_idx = rng.choice(np.arange(len(dataset))) | |||
| item = self.datasets[dataset_idx][sample_idx] | |||
| else: | |||
| dataset_idx = bisect_right(self.cumulative_lens, idx) | |||
| if dataset_idx == 0: | |||
| sample_idx = idx | |||
| else: | |||
| sample_idx = idx - self.cumulative_lens[dataset_idx - 1] | |||
| item = self.datasets[dataset_idx][sample_idx] | |||
| item = self.pet_wrapper(item) | |||
| return item | |||
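When `reweight` is on, `MultiTaskDataset` first samples a task with probability proportional to `min(len, max_limit) ** temperature` (normalized), then draws a uniform example from that task, so very large datasets are capped and down-weighted relative to plain proportional mixing. A quick arithmetic sketch with toy sizes and the defaults `temperature=0.8`, `max_limit=200000`:

```python
import numpy as np

lens = np.array([1_000, 100_000, 500_000], dtype=float)  # toy per-task sizes
max_limit, temperature = 200_000, 0.8

weights = np.minimum(lens, max_limit) ** temperature
weights /= weights.sum()

# Proportional mixing would give ~0.2% / 16.6% / 83.2%;
# the cap plus temperature flatten this to ~0.9% / 36.2% / 62.9%.
print(weights.round(3))  # [0.009 0.362 0.629]
```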
| class DataConfig: | |||
| def __init__(self, defaults=None): | |||
| super(DataConfig, self).__init__() | |||
| if defaults is None: | |||
| defaults = {} | |||
| self.defaults = defaults | |||
| def apply(self, args, tokenizer): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('configuring data') | |||
| self.apply_defaults(args) | |||
| return make_loaders(args, tokenizer) | |||
| def set_defaults(self, **kwargs): | |||
| for k, v in kwargs.items(): | |||
| self.defaults[k] = v | |||
| def apply_defaults(self, args): | |||
| for k, v in self.defaults.items(): | |||
| k = k.replace('-', '_') | |||
| if not hasattr(args, k): | |||
| setattr(args, k, v) | |||
| def prepare_tokenizer(args): | |||
| add_sentinel_token = 0 | |||
| if args.sentinel_token: | |||
| add_sentinel_token = args.max_position_embeddings | |||
| tokenizer = make_tokenizer( | |||
| args.tokenizer_type, | |||
| None, | |||
| args.tokenizer_path, | |||
| args.vocab_size, | |||
| args.tokenizer_model_type, | |||
| add_block_symbols=args.block_lm, | |||
| cache_dir=args.cache_dir, | |||
| add_sentinel_token=add_sentinel_token, | |||
| add_task_mask=args.task_mask, | |||
| add_decoder_mask=args.block_mask_prob > 0.0 | |||
| or args.context_mask_ratio > 0.0) | |||
| if mpu.get_model_parallel_rank() == 0: | |||
| num_tokens = tokenizer.num_tokens | |||
| eod_token = tokenizer.get_command('eos').Id | |||
| assert eod_token == tokenizer.get_command('pad').Id | |||
| before = num_tokens | |||
| after = before | |||
| multiple = args.make_vocab_size_divisible_by | |||
| while (after % multiple) != 0: | |||
| after += 1 | |||
| print_rank_0('> padded vocab (size: {}) with {} dummy ' | |||
| 'tokens (new size: {})'.format(before, after - before, | |||
| after)) | |||
| print_rank_0('> found end-of-document token: {}'.format(eod_token)) | |||
| token_counts = torch.cuda.LongTensor([after, eod_token]) | |||
| else: | |||
| token_counts = torch.cuda.LongTensor([0, 0]) | |||
| # Broadcast num tokens. | |||
| torch.distributed.broadcast( | |||
| token_counts, | |||
| mpu.get_model_parallel_src_rank(), | |||
| group=mpu.get_model_parallel_group()) | |||
| num_tokens = token_counts[0].item() | |||
| eod_token = token_counts[1].item() | |||
| args.vocab_size, args.eod_token = num_tokens, eod_token | |||
| return tokenizer | |||
| def make_data_loader(dataset, | |||
| tokenizer, | |||
| batch_size, | |||
| num_iters, | |||
| args, | |||
| shuffle=False, | |||
| block_collate=False): | |||
| world_size = torch.distributed.get_world_size( | |||
| group=mpu.get_data_parallel_group()) | |||
| rank = torch.distributed.get_rank(group=mpu.get_data_parallel_group()) | |||
| if args.loader_scatter is not None: | |||
| rank = rank // args.loader_scatter | |||
| world_size = world_size // args.loader_scatter | |||
| batch_size = batch_size // args.loader_scatter | |||
| distributed = world_size > 1 | |||
| if args.transformer_xl: | |||
| batch_sampler = data_utils.samplers.DistributedSequentialSampler( | |||
| len(dataset), num_iters, batch_size, rank, world_size) | |||
| else: | |||
| if shuffle: | |||
| sampler = data_utils.samplers.RandomSampler( | |||
| dataset, | |||
| replacement=True, | |||
| num_samples=batch_size * args.train_iters | |||
| * args.gradient_accumulation_steps) | |||
| else: | |||
| sampler = torch.utils.data.SequentialSampler(dataset) | |||
| drop_last = distributed | |||
| # the GPUs in the same model parallel group receive the same data | |||
| if distributed: | |||
| batch_sampler = data_utils.samplers.DistributedBatchSampler( | |||
| sampler, | |||
| batch_size, | |||
| drop_last, | |||
| rank, | |||
| world_size, | |||
| gradient_accumulation_steps=args.gradient_accumulation_steps) | |||
| else: | |||
| batch_sampler = torch.utils.data.BatchSampler( | |||
| sampler, batch_size, drop_last) | |||
| collate_fn = None | |||
| if block_collate: | |||
| collate_fn = ConstructBlockStrategy( | |||
| args, | |||
| tokenizer, | |||
| args.seq_length, | |||
| bert_prob=args.bert_prob, | |||
| gap_sentence_prob=args.gap_sentence_prob, | |||
| gap_sentence_ratio=args.gap_sentence_ratio, | |||
| gpt_infill_prob=args.gpt_infill_prob, | |||
| average_block_length=args.avg_block_length, | |||
| gpt_min_ratio=args.gpt_min_ratio, | |||
| block_mask_prob=args.block_mask_prob, | |||
| context_mask_ratio=args.context_mask_ratio, | |||
| short_seq_prob=args.short_seq_prob, | |||
| single_span_prob=args.single_span_prob, | |||
| shuffle_blocks=not args.no_shuffle_block, | |||
| block_position_encoding=not args.no_block_position, | |||
| sentinel_token=args.sentinel_token, | |||
| encoder_decoder=args.encoder_decoder, | |||
| task_mask=args.task_mask, | |||
| random_position=args.random_position, | |||
| masked_lm=args.masked_lm).construct_blocks | |||
| data_loader = torch.utils.data.DataLoader( | |||
| dataset, | |||
| batch_sampler=batch_sampler, | |||
| num_workers=args.num_workers, | |||
| pin_memory=True, | |||
| collate_fn=collate_fn) | |||
| return data_loader | |||
| def make_tfrecord_loaders(args): | |||
| """Load train/val/test dataset from shuffled TFRecords""" | |||
| # tf_dl is a sub-module of this package, so a relative import is needed here. | |||
| from .data_utils import tf_dl  # noqa: F401  (binds data_utils.tf_dl used below) | |||
| data_set_args = { | |||
| 'batch_size': args.batch_size, | |||
| 'max_seq_len': args.seq_length, | |||
| 'max_preds_per_seq': args.max_preds_per_seq, | |||
| 'train': True, | |||
| 'num_workers': max(args.num_workers, 1), | |||
| 'seed': args.seed + args.rank + 1, | |||
| 'threaded_dl': args.num_workers > 0 | |||
| } | |||
| train = data_utils.tf_dl.TFRecordDataLoader(args.train_data, | |||
| **data_set_args) | |||
| data_set_args['train'] = False | |||
| if args.eval_seq_length is not None: | |||
| data_set_args['max_seq_len'] = args.eval_seq_length | |||
| if args.eval_max_preds_per_seq is not None: | |||
| data_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq | |||
| valid = None | |||
| if args.valid_data is not None: | |||
| valid = data_utils.tf_dl.TFRecordDataLoader(args.valid_data, | |||
| **data_set_args) | |||
| test = None | |||
| if args.test_data is not None: | |||
| test = data_utils.tf_dl.TFRecordDataLoader(args.test_data, | |||
| **data_set_args) | |||
| tokenizer = data_utils.make_tokenizer( | |||
| args.tokenizer_type, | |||
| train, | |||
| args.tokenizer_path, | |||
| args.vocab_size, | |||
| args.tokenizer_model_type, | |||
| cache_dir=args.cache_dir) | |||
| return (train, valid, test), tokenizer | |||
| def make_loaders(args, tokenizer): | |||
| """makes training/val/test""" | |||
| if args.use_tfrecords: | |||
| return make_tfrecord_loaders(args) | |||
| world_size = torch.distributed.get_world_size( | |||
| group=mpu.get_data_parallel_group()) | |||
| if args.loader_scatter is not None: | |||
| assert world_size % args.loader_scatter == 0 | |||
| batch_size = args.batch_size * world_size | |||
| eval_batch_size = batch_size | |||
| if args.eval_batch_size is not None: | |||
| eval_batch_size = args.eval_batch_size * world_size | |||
| seq_length = args.seq_length | |||
| if seq_length < 0: | |||
| seq_length = seq_length * world_size | |||
| eval_seq_length = args.eval_seq_length | |||
| if eval_seq_length is not None and eval_seq_length < 0: | |||
| eval_seq_length = eval_seq_length * world_size | |||
| split = get_split(args) | |||
| data_set_args = { | |||
| 'path': args.train_data, | |||
| 'seq_length': seq_length, | |||
| 'mem_length': args.mem_length, | |||
| 'delim': args.delim, | |||
| 'text_key': args.text_key, | |||
| 'label_key': 'label', | |||
| 'ds_type': args.data_set_type, | |||
| 'split': split, | |||
| 'loose': args.loose_json, | |||
| 'max_preds_per_seq': args.max_preds_per_seq, | |||
| 'presplit_sentences': args.presplit_sentences, | |||
| 'sample_one_document': args.sample_one_document, | |||
| 'filter_english': args.filter_english, | |||
| 'pre_tokenize': not args.no_pre_tokenize, | |||
| 'tokenizer': tokenizer, | |||
| 'save_splits': args.save_splits, | |||
| 'load_splits': args.load_splits, | |||
| 'save_test_data': args.save_test_data, | |||
| 'no_lazy_loader': args.no_lazy_loader, | |||
| 'loader_scatter': args.loader_scatter, | |||
| 'data_parallel_rank': mpu.get_data_parallel_rank(), | |||
| 'non_sentence_start': args.non_sentence_start, | |||
| 'half_lazy_loader': args.half_lazy_loader | |||
| } | |||
| eval_set_args = copy.copy(data_set_args) | |||
| eval_set_args['split'] = [1.] | |||
| # if optional eval args were set then replace their | |||
| # equivalent values in the arg dict | |||
| if eval_seq_length: | |||
| eval_set_args['seq_length'] = eval_seq_length | |||
| if args.eval_max_preds_per_seq: | |||
| eval_set_args['max_preds_per_seq'] = args.eval_max_preds_per_seq | |||
| if args.eval_text_key is not None: | |||
| eval_set_args['text_key'] = args.eval_text_key | |||
| # make datasets splits and tokenizer | |||
| train, valid, test = None, None, None | |||
| if args.train_data is not None: | |||
| train = data_utils.make_dataset(**data_set_args) | |||
| if data_utils.should_split(split): | |||
| train, valid, test = train | |||
| eval_set_args['tokenizer'] = tokenizer | |||
| # make training and val dataset if necessary | |||
| if valid is None and args.valid_data is not None: | |||
| eval_set_args['path'] = args.valid_data | |||
| valid = data_utils.make_dataset(**eval_set_args) | |||
| eval_set_args['tokenizer'] = tokenizer | |||
| if test is None and args.test_data is not None: | |||
| eval_set_args['path'] = args.test_data | |||
| test = data_utils.make_dataset(**eval_set_args) | |||
| # wrap datasets with data loader | |||
| use_block = args.block_lm or args.encoder_decoder | |||
| if train is not None and args.batch_size > 0: | |||
| train = make_data_loader( | |||
| train, | |||
| tokenizer, | |||
| batch_size, | |||
| args.train_iters, | |||
| args, | |||
| shuffle=args.shuffle, | |||
| block_collate=use_block) | |||
| args.do_train = True | |||
| else: | |||
| args.do_train = False | |||
| eval_batch_size = eval_batch_size if eval_batch_size != 0 else batch_size | |||
| if valid is not None: | |||
| valid = make_data_loader( | |||
| valid, | |||
| tokenizer, | |||
| eval_batch_size, | |||
| args.train_iters, | |||
| args, | |||
| shuffle=args.shuffle, | |||
| block_collate=use_block) | |||
| args.do_valid = True | |||
| else: | |||
| args.do_valid = False | |||
| if test is not None: | |||
| test = make_data_loader( | |||
| test, | |||
| tokenizer, | |||
| eval_batch_size, | |||
| len(test) // eval_batch_size + 1, | |||
| args, | |||
| shuffle=args.shuffle, | |||
| block_collate=use_block) | |||
| args.do_test = True | |||
| else: | |||
| args.do_test = False | |||
| return train, valid, test | |||
| def build_multi_task_dataset(args, tokenizer): | |||
| task_dirs = { | |||
| 'mnli': 'MNLI', | |||
| 'cola': 'CoLA', | |||
| 'mrpc': 'MRPC', | |||
| 'qnli': 'QNLI', | |||
| 'qqp': 'QQP', | |||
| 'sst2': 'SST-2', | |||
| 'agnews': 'Agnews', | |||
| 'yelp-polarity': 'yelp_review_polarity_csv', | |||
| 'yelp-full': 'yelp_review_full_csv', | |||
| 'yahoo': 'Yahoo', | |||
| 'squad': 'SQuAD', | |||
| 'race': 'RACE' | |||
| } | |||
| # SuperGlueDataset is imported lazily here; the relative path is assumed to | |||
| # mirror the original GLM layout (tasks/superglue/dataset.py). | |||
| from .tasks.superglue.dataset import SuperGlueDataset | |||
| train, valid = None, None | |||
| if mpu.get_model_parallel_rank() == 0: | |||
| multi_seq_length = args.seq_length | |||
| if args.multi_seq_length is not None: | |||
| multi_seq_length = args.multi_seq_length | |||
| train_datasets, valid_datasets = [], [] | |||
| for task in args.multi_task_data: | |||
| task = task.lower() | |||
| data_dir = os.path.join(args.data_dir, task_dirs[task]) | |||
| train_datasets.append( | |||
| SuperGlueDataset( | |||
| args, | |||
| task, | |||
| data_dir, | |||
| multi_seq_length, | |||
| 'train', | |||
| tokenizer, | |||
| pattern_ensemble=True)) | |||
| valid_datasets.append( | |||
| SuperGlueDataset( | |||
| args, | |||
| task, | |||
| data_dir, | |||
| multi_seq_length, | |||
| 'dev', | |||
| tokenizer, | |||
| pattern_ensemble=True)) | |||
| train = MultiTaskDataset(args.multi_task_data, train_datasets) | |||
| valid = MultiTaskDataset(args.multi_task_data, valid_datasets) | |||
| world_size = torch.distributed.get_world_size( | |||
| group=mpu.get_data_parallel_group()) | |||
| multi_batch_size = args.batch_size * world_size | |||
| if args.multi_batch_size is not None: | |||
| multi_batch_size = args.multi_batch_size * world_size | |||
| train = make_data_loader( | |||
| train, | |||
| tokenizer, | |||
| multi_batch_size, | |||
| args.train_iters, | |||
| args, | |||
| shuffle=True) | |||
| valid = make_data_loader( | |||
| valid, | |||
| tokenizer, | |||
| multi_batch_size, | |||
| args.train_iters, | |||
| args, | |||
| shuffle=True) | |||
| return train, valid | |||
| def get_split(args): | |||
| """ | |||
| Get dataset splits from comma separated string list | |||
| """ | |||
| splits = [] | |||
| if args.split.find(',') != -1: | |||
| splits = [float(s) for s in args.split.split(',')] | |||
| elif args.split.find('/') != -1: | |||
| splits = [float(s) for s in args.split.split('/')] | |||
| else: | |||
| splits = [float(args.split)] | |||
| split_total = sum(splits) | |||
| if split_total < 1.: | |||
| splits.append(1 - split_total) | |||
| while len(splits) < 3: | |||
| splits.append(0.) | |||
| splits = splits[:3] | |||
| if args.valid_data is not None: | |||
| splits[1] = 0. | |||
| if args.test_data is not None: | |||
| splits[2] = 0. | |||
| final_sum = sum(splits) | |||
| return [s / final_sum for s in splits] | |||
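`get_split` accepts comma- or slash-separated proportions, appends the remainder if they sum to less than 1, pads/truncates to exactly three entries (train/val/test), zeroes the val/test slots when explicit `--valid-data`/`--test-data` paths are supplied, and finally renormalizes. A standalone re-implementation of that rule, for illustration only:

```python
def split_from_string(split: str, have_valid: bool = False, have_test: bool = False):
    """Illustrative copy of the parsing rule in get_split above."""
    sep = ',' if ',' in split else '/'
    splits = [float(s) for s in split.split(sep)] if sep in split else [float(split)]
    if sum(splits) < 1.0:               # leftover probability becomes the next split
        splits.append(1.0 - sum(splits))
    splits = (splits + [0.0, 0.0])[:3]  # always exactly train / val / test
    if have_valid:
        splits[1] = 0.0
    if have_test:
        splits[2] = 0.0
    total = sum(splits)
    return [s / total for s in splits]

print(split_from_string('949,50,1'))  # [0.949, 0.05, 0.001]
print(split_from_string('0.9'))       # ~[0.9, 0.1, 0.0], remainder goes to validation
```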
| def configure_data(): | |||
| """add cmdline flags for configuring datasets""" | |||
| # These are options that are used by data_utils, but are either | |||
| # deprecated or not meant to be exposed to the command line user. | |||
| # These options are intended to be set in code by specific scripts. | |||
| defaults = { | |||
| 'world_size': 1, | |||
| 'rank': -1, | |||
| 'persist_state': 0, | |||
| 'lazy': False, | |||
| 'transpose': False, | |||
| 'data_set_type': 'supervised', | |||
| 'seq_length': 256, | |||
| 'eval_seq_length': 256, | |||
| 'samples_per_shard': 100 | |||
| } | |||
| return DataConfig(defaults=defaults) | |||
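Downstream, these pieces are wired together roughly as in the original GLM training entry point: build the tokenizer, create a `DataConfig`, override its defaults, and let `apply` hand back the three loaders. A hedged usage sketch (assumes `args` comes from the argparse setup in this PR and that `torch.distributed`/`mpu` are already initialized):

```python
tokenizer = prepare_tokenizer(args)  # also fills args.vocab_size / args.eod_token
data_config = configure_data()
data_config.set_defaults(data_set_type='Block', transpose=False)
train_loader, valid_loader, test_loader = data_config.apply(args, tokenizer)
```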
| @@ -0,0 +1,341 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """utils for creating datasets""" | |||
| import math | |||
| import os | |||
| import random | |||
| import time | |||
| import torch | |||
| from . import corpora | |||
| from .datasets import (BertSentencepairDataset, BlockDataset, ConcatDataset, | |||
| GPT2Dataset, ShuffleDataset, SplitDataset, XLDataset, | |||
| split_ds) | |||
| from .lazy_loader import (LazyLoader, LazyWriter, exists_lazy, exists_scatter, | |||
| get_scatter_path) | |||
| from .samplers import DistributedBatchSampler | |||
| from .tokenization import (BertWordPieceTokenizer, CharacterLevelTokenizer, | |||
| CommandToken, GPT2BPETokenizer, Tokenization, | |||
| Tokenizer, make_tokenizer) | |||
| TRAIN_DATA = 0 | |||
| VAL_DATA = 1 | |||
| TEST_DATA = 2 | |||
| def should_split(split): | |||
| """ | |||
| given split proportions checks if should split | |||
| Examples: | |||
| >>> should_split([10,0,0]) | |||
| False | |||
| >>> should_split([1,.1,.2]) | |||
| True | |||
| """ | |||
| return max(split) / sum(split) != 1. | |||
| def get_ext(path): | |||
| """gets path extension""" | |||
| return os.path.splitext(path)[1] | |||
| def get_dataset(name, | |||
| tokenizer, | |||
| pre_tokenize, | |||
| data_parallel_rank, | |||
| loader_scatter=None, | |||
| no_lazy_loader=False, | |||
| half_lazy_loader=False): | |||
| """gets dataset object based on keyword args and file at `path`""" | |||
| global_rank = torch.distributed.get_rank() | |||
| if not supported_corpus(name): | |||
| raise NotImplementedError('dataset %s is not supported' % name) | |||
| dataset = corpora.NAMED_CORPORA[name] | |||
| path = dataset.PATH | |||
| if issubclass(dataset, corpora.PromptReader): | |||
| if not (exists_lazy(path, data_type='prompt') | |||
| and exists_lazy(path, data_type='text')) and not ( | |||
| loader_scatter is not None and exists_scatter( | |||
| path, data_type='prompt', scatter_num=loader_scatter) | |||
| and exists_scatter( | |||
| path, data_type='text', scatter_num=loader_scatter)): | |||
| # create cached version of dataset for lazy loading if it doesn't exist | |||
| if global_rank == 0: | |||
| print(f'Creating lazy loader for dataset {name}') | |||
| prompt_writer = LazyWriter( | |||
| path, data_type='prompt', is_array=pre_tokenize) | |||
| text_writer = LazyWriter( | |||
| path, data_type='text', is_array=pre_tokenize) | |||
| writers = {'prompt': prompt_writer, 'text': text_writer} | |||
| reader = dataset( | |||
| writers=writers, | |||
| tokenizer=tokenizer, | |||
| tokenize=pre_tokenize) | |||
| reader.process() | |||
| prompt_writer.close() | |||
| text_writer.close() | |||
| else: | |||
| while not os.path.exists( | |||
| LazyWriter.get_len_path(path, data_type='prompt')): | |||
| time.sleep(1) | |||
| map_fn = (lambda x: x.tolist()) if pre_tokenize else None | |||
| if loader_scatter is not None: | |||
| if not (exists_scatter( | |||
| path, data_type='prompt', scatter_num=loader_scatter) | |||
| and exists_scatter( | |||
| path, data_type='text', scatter_num=loader_scatter)): | |||
| if global_rank == 0: | |||
| print(f'Creating scatter loader for dataset {name}') | |||
| prompts = LazyLoader( | |||
| path, | |||
| data_type='prompt', | |||
| map_fn=map_fn, | |||
| mem_map=True, | |||
| is_array=pre_tokenize) | |||
| texts = LazyLoader( | |||
| path, | |||
| data_type='text', | |||
| map_fn=map_fn, | |||
| mem_map=True, | |||
| is_array=pre_tokenize) | |||
| indices = list(range(len(texts))) | |||
| random.shuffle(indices) | |||
| segment_length = (len(indices) - 1) // loader_scatter + 1 | |||
| for i in range(loader_scatter): | |||
| scatter_path = get_scatter_path(path, scatter_rank=i) | |||
| prompt_writer = LazyWriter( | |||
| scatter_path, | |||
| data_type='prompt', | |||
| is_array=pre_tokenize) | |||
| text_writer = LazyWriter( | |||
| scatter_path, | |||
| data_type='text', | |||
| is_array=pre_tokenize) | |||
| for idx in indices[i * segment_length:(i + 1) | |||
| * segment_length]: | |||
| prompt_writer.write(prompts[idx]) | |||
| text_writer.write(texts[idx]) | |||
| prompt_writer.close() | |||
| text_writer.close() | |||
| else: | |||
| while not (exists_scatter( | |||
| path, data_type='prompt', | |||
| scatter_num=loader_scatter) and exists_scatter( | |||
| path, | |||
| data_type='text', | |||
| scatter_num=loader_scatter)): | |||
| time.sleep(1) | |||
| scatter_path = get_scatter_path( | |||
| path, scatter_rank=data_parallel_rank % loader_scatter) | |||
| print(f'Rank {global_rank} is using scatter from {scatter_path}') | |||
| prompts = LazyLoader( | |||
| scatter_path, | |||
| data_type='prompt', | |||
| map_fn=map_fn, | |||
| mem_map=True, | |||
| is_array=pre_tokenize, | |||
| load_memory=no_lazy_loader, | |||
| half_load=half_lazy_loader) | |||
| texts = LazyLoader( | |||
| scatter_path, | |||
| data_type='text', | |||
| map_fn=map_fn, | |||
| mem_map=True, | |||
| is_array=pre_tokenize, | |||
| load_memory=no_lazy_loader, | |||
| half_load=half_lazy_loader) | |||
| else: | |||
| prompts = LazyLoader( | |||
| path, | |||
| data_type='prompt', | |||
| map_fn=map_fn, | |||
| mem_map=True, | |||
| is_array=pre_tokenize, | |||
| load_memory=no_lazy_loader, | |||
| half_load=half_lazy_loader) | |||
| texts = LazyLoader( | |||
| path, | |||
| data_type='text', | |||
| map_fn=map_fn, | |||
| mem_map=True, | |||
| is_array=pre_tokenize, | |||
| load_memory=no_lazy_loader, | |||
| half_load=half_lazy_loader) | |||
| text = corpora.PromptDataset( | |||
| prompt_loader=prompts, | |||
| text_loader=texts, | |||
| tokenizer=tokenizer, | |||
| to_tokenize=not pre_tokenize) | |||
| if loader_scatter is None: | |||
| if global_rank == 0: | |||
| print(f'Create dataset {name} with {len(text)} documents') | |||
| for i in range(10): | |||
| rand_id = i if i < 5 else random.randrange(len(text)) | |||
| sample_tokens = text[rand_id]['tokens'][:1024] | |||
| print(sample_tokens) | |||
| print(tokenizer.DecodeIds(sample_tokens).encode('utf-8')) | |||
| else: | |||
| for scatter_id in range(loader_scatter): | |||
| if data_parallel_rank % loader_scatter == scatter_id and data_parallel_rank // loader_scatter == 0: | |||
| print( | |||
| f'Create dataset {name} at scatter {scatter_id} with {len(text)} documents' | |||
| ) | |||
| for i in range(10): | |||
| sample_tokens = text[i]['tokens'][:1024] | |||
| print(sample_tokens) | |||
| print(tokenizer.DecodeIds(sample_tokens)) | |||
| torch.distributed.barrier() | |||
| return text | |||
| elif issubclass(dataset, corpora.KeyReader): | |||
| if not (exists_lazy(path, data_type='text') | |||
| and exists_lazy(path, data_type='mask')): | |||
| # create cached version of dataset for lazy loading if it doesn't exist | |||
| if global_rank == 0: | |||
| text_writer = LazyWriter( | |||
| path, data_type='text', is_array=pre_tokenize) | |||
| mask_writer = LazyWriter(path, data_type='mask', is_array=True) | |||
| writers = {'mask': mask_writer, 'text': text_writer} | |||
| dataset( | |||
| writers=writers, | |||
| tokenizer=tokenizer, | |||
| tokenize=pre_tokenize) | |||
| mask_writer.close() | |||
| text_writer.close() | |||
| else: | |||
| while not os.path.exists( | |||
| LazyWriter.get_len_path(path, data_type='mask')): | |||
| time.sleep(1) | |||
| map_fn = (lambda x: x.tolist()) if pre_tokenize else None | |||
| masks = LazyLoader( | |||
| path, data_type='mask', map_fn=map_fn, mem_map=True, is_array=True) | |||
| texts = LazyLoader( | |||
| path, | |||
| data_type='text', | |||
| map_fn=map_fn, | |||
| mem_map=True, | |||
| is_array=pre_tokenize) | |||
| text = corpora.KeyDataset( | |||
| mask_loader=masks, | |||
| text_loader=texts, | |||
| tokenizer=tokenizer, | |||
| to_tokenize=not pre_tokenize) | |||
| return text | |||
| def supported_corpus(corpus_name): | |||
| """checks if corpus name is defined in `corpora.py`""" | |||
| return corpus_name in corpora.NAMED_CORPORA | |||
| def make_dataset(path, | |||
| seq_length, | |||
| mem_length, | |||
| shuffle=True, | |||
| split=None, | |||
| tokenizer=None, | |||
| sample_one_document=False, | |||
| pre_tokenize=False, | |||
| ds_type='', | |||
| save_splits=None, | |||
| load_splits=None, | |||
| save_test_data=None, | |||
| no_lazy_loader=False, | |||
| loader_scatter=None, | |||
| data_parallel_rank=None, | |||
| filter_english=False, | |||
| non_sentence_start=0.0, | |||
| half_lazy_loader=False, | |||
| **kwargs): | |||
| """function to create datasets+tokenizers for common options""" | |||
| if split is None: | |||
| split = [1.] | |||
| # get one or multiple datasets and concatenate | |||
| if isinstance(path, str): | |||
| ds = get_dataset( | |||
| path, | |||
| tokenizer=tokenizer, | |||
| pre_tokenize=pre_tokenize, | |||
| no_lazy_loader=no_lazy_loader, | |||
| loader_scatter=loader_scatter, | |||
| data_parallel_rank=data_parallel_rank, | |||
| half_lazy_loader=half_lazy_loader) | |||
| else: | |||
| ds = [ | |||
| get_dataset( | |||
| p, | |||
| tokenizer=tokenizer, | |||
| pre_tokenize=pre_tokenize, | |||
| no_lazy_loader=no_lazy_loader, | |||
| loader_scatter=loader_scatter, | |||
| data_parallel_rank=data_parallel_rank, | |||
| half_lazy_loader=half_lazy_loader) for p in path | |||
| ] | |||
| ds = ConcatDataset(ds) | |||
| # Split dataset into train/val/test (and wrap bert dataset) | |||
| def wrap_dataset(dataset): | |||
| if ds_type.lower() == 'bert': | |||
| presplit_sentences = kwargs[ | |||
| 'presplit_sentences'] if 'presplit_sentences' in kwargs else False | |||
| dataset = BertSentencepairDataset( | |||
| dataset, | |||
| max_seq_len=seq_length, | |||
| presplit_sentences=presplit_sentences) | |||
| elif ds_type.lower() == 'gpt-xl': | |||
| assert pre_tokenize | |||
| dataset = XLDataset( | |||
| dataset, | |||
| tokenizer, | |||
| max_seq_len=seq_length, | |||
| mem_len=mem_length, | |||
| sample_across_doc=not sample_one_document) | |||
| elif ds_type.lower() == 'gpt2': | |||
| dataset = GPT2Dataset( | |||
| dataset, | |||
| tokenizer, | |||
| max_seq_len=seq_length, | |||
| sample_across_doc=not sample_one_document) | |||
| elif ds_type.lower() == 'block': | |||
| dataset = BlockDataset( | |||
| dataset, | |||
| tokenizer, | |||
| max_seq_len=seq_length, | |||
| sample_across_doc=not sample_one_document, | |||
| filter_english=filter_english, | |||
| non_sentence_start=non_sentence_start) | |||
| return dataset | |||
| if should_split(split): | |||
| ds = split_ds( | |||
| ds, | |||
| split, | |||
| shuffle=shuffle, | |||
| save_splits=save_splits, | |||
| load_splits=load_splits) | |||
| if save_test_data is not None and torch.distributed.get_rank() == 0: | |||
| test_ds = ds[-1] | |||
| with open(save_test_data, 'w', encoding='utf-8') as output: | |||
| for data in test_ds: | |||
| text = data['tokens'] | |||
| text = tokenizer.DecodeIds(text) | |||
| output.write(text) | |||
| output.write('\n') | |||
| print(f'Write test data to {save_test_data}') | |||
| ds = [wrap_dataset(d) if d is not None else None for d in ds] | |||
| else: | |||
| ds = wrap_dataset(ds) | |||
| return ds | |||
| @@ -0,0 +1,583 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """several datasets with preset arguments""" | |||
| import os | |||
| import random | |||
| from collections import defaultdict | |||
| from multiprocessing import Process, Queue | |||
| from queue import Empty | |||
| import json | |||
| import tqdm | |||
| from torch.utils import data | |||
| from modelscope.models.nlp.mglm.utils import print_rank_0 | |||
| from .datasets import csv_dataset, json_dataset | |||
| from .lazy_loader import LazyLoader | |||
| NUM_PROCESSES = 100 | |||
| def punctuation_standardization(string: str): | |||
| punctuation_dict = { | |||
| '\u201c': "\"", | |||
| '\u201d': "\"", | |||
| '\u2019': "'", | |||
| '\u2018': "'", | |||
| '\u2013': '-' | |||
| } | |||
| for key, value in punctuation_dict.items(): | |||
| string = string.replace(key, value) | |||
| return string | |||
| class KeyDataset(data.Dataset): | |||
| def __init__(self, text_loader, mask_loader, **kwargs): | |||
| self.texts = text_loader | |||
| self.masks = mask_loader | |||
| self.is_lazy = False | |||
| if isinstance(self.texts, LazyLoader) and isinstance( | |||
| self.masks, LazyLoader): | |||
| self.text_lens = self.texts.lens | |||
| self.is_lazy = True | |||
| def get_text_len(self, idx): | |||
| return self.text_lens[idx] | |||
| def __getitem__(self, index): | |||
| text = self.texts[index] | |||
| mask_length = self.masks[index] | |||
| mask = [] | |||
| for i, length in enumerate(mask_length): | |||
| if i % 2 == 0: | |||
| mask += [0] * length | |||
| else: | |||
| mask += [1] * length | |||
| assert len(text) == len(mask) | |||
| return {'tokens': text, 'loss_masks': mask} | |||
| def __len__(self): | |||
| return len(self.texts) | |||
| class PromptDataset(data.Dataset): | |||
| def __init__(self, | |||
| prompt_loader, | |||
| text_loader, | |||
| tokenizer=None, | |||
| to_tokenize=False, | |||
| **kwargs): | |||
| self.prompts = prompt_loader | |||
| self.texts = text_loader | |||
| self.tokenizer = tokenizer | |||
| self.to_tokenize = to_tokenize | |||
| self.is_lazy = False | |||
| if isinstance(self.prompts, LazyLoader) and isinstance( | |||
| self.texts, LazyLoader): | |||
| self.prompt_lens = self.prompts.lens | |||
| self.text_lens = self.texts.lens | |||
| self.is_lazy = True | |||
| def get_text_len(self, idx): | |||
| return self.prompt_lens[idx] + self.text_lens[idx] | |||
| def __getitem__(self, index): | |||
| prompt = self.prompts[index] | |||
| text = self.texts[index] | |||
| if self.to_tokenize: | |||
| prompt = self.tokenizer.EncodeAsIds(prompt).tokenization | |||
| text = self.tokenizer.EncodeAsIds(text).tokenization | |||
| return { | |||
| 'tokens': prompt + text, | |||
| 'loss_masks': [0] * len(prompt) + [1] * len(text) | |||
| } | |||
| def __len__(self): | |||
| return len(self.prompts) | |||
| class DataReader: | |||
| PATH = None | |||
| assert_str = None | |||
| reserve_punct = False | |||
| split_row = True | |||
| TASK_QUEUE_LIMIT = 10000000 | |||
| DONE_QUEUE_LIMIT = 10000000 | |||
| def tokenize_worker(self, input, output, info, tokenizer, tokenize): | |||
| raise NotImplementedError | |||
| def print_info(self, info): | |||
| pass | |||
| def __init__(self, writers, tokenizer=None, tokenize=False, **kwargs): | |||
| print(self.PATH) | |||
| print(self.assert_str) | |||
| assert os.path.exists(self.PATH), self.assert_str | |||
| print_rank_0(f'Creating dataset from {self.PATH}') | |||
| self.tokenizer = tokenizer | |||
| self.tokenize = tokenize | |||
| self.writers = writers | |||
| def process(self): | |||
| if os.path.isdir(self.PATH): | |||
| paths = [ | |||
| os.path.join(top, name) for top, _, names in os.walk(self.PATH) | |||
| for name in names | |||
| ] | |||
| # paths = [entry.path for entry in os.scandir(self.PATH) if | |||
| # not entry.is_dir() and not entry.name.endswith("bz2")] | |||
| else: | |||
| paths = [self.PATH] | |||
| task_queue, done_queue, info_queue = Queue( | |||
| maxsize=self.TASK_QUEUE_LIMIT), Queue( | |||
| maxsize=self.DONE_QUEUE_LIMIT), Queue() | |||
| processes = [] | |||
| for i in range(NUM_PROCESSES): | |||
| process = Process( | |||
| target=self.tokenize_worker, | |||
| args=(task_queue, done_queue, info_queue, self.tokenizer, | |||
| self.tokenize)) | |||
| process.start() | |||
| processes.append(process) | |||
| def read_input_to_queue(): | |||
| for path in paths: | |||
| print_rank_0(f'Start reading {path}') | |||
| with open(path) as file: | |||
| items = json.load(file) | |||
| for item in items: | |||
| task_queue.put(item) | |||
| # if self.split_row: | |||
| # for row in file: | |||
| # task_queue.put(row) | |||
| # else: | |||
| # items = json.load(file) | |||
| # for item in items["RECORDS"]: | |||
| # task_queue.put(item) | |||
| print_rank_0('Read input complete') | |||
| for i in range(len(processes)): | |||
| task_queue.put('STOP') | |||
| process = Process(target=read_input_to_queue) | |||
| process.start() | |||
| count = len(processes) | |||
| progress_bar = tqdm.tqdm() | |||
| while True: | |||
| data = done_queue.get() | |||
| if data == 'COMPLETE': | |||
| count -= 1 | |||
| if count == 0: | |||
| break | |||
| else: | |||
| self.write_result(data, self.writers) | |||
| progress_bar.update() | |||
| progress_bar.close() | |||
| self.print_info(info_queue) | |||
| @staticmethod | |||
| def write_result(data, writers): | |||
| raise NotImplementedError | |||
| @staticmethod | |||
| def get_token_count(contents): | |||
| return sum(map(len, contents)) | |||
| @classmethod | |||
| def process_sample(cls, text, tokenizer, tokenize): | |||
| if isinstance(text, str) and tokenize: | |||
| if not cls.reserve_punct: | |||
| text = punctuation_standardization(text) | |||
| text = tokenizer.EncodeAsIds(text).tokenization if text else [] | |||
| return text | |||
| @staticmethod | |||
| def trim_field(content, max_length): | |||
| if len(content) > max_length: | |||
| content = content[:max_length] | |||
| content += '......' | |||
| return content | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| raise NotImplementedError | |||
| class PromptReader(DataReader): | |||
| is_json = True | |||
| def tokenize_worker(self, input, output, info, tokenizer, tokenize): | |||
| for row in iter(input.get, 'STOP'): | |||
| if row: | |||
| if self.is_json: | |||
| row = row.rstrip() | |||
| row = json.loads(row) | |||
| prompts, texts = self.process_line(row, tokenizer, tokenize) | |||
| for prompt, text in zip(prompts, texts): | |||
| output.put((prompt, text)) | |||
| output.put('COMPLETE') | |||
| @staticmethod | |||
| def write_result(data, writers): | |||
| prompt, text = data | |||
| writers['prompt'].write(prompt) | |||
| writers['text'].write(text) | |||
| class KeyReader(DataReader): | |||
| PATH = '/root/data/wikipedia/wiki-key.txt' | |||
| assert_str = 'make sure to set PATH for wikipedia data_utils/corpora.py' | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| keys, contents = data['key'], data['content'] | |||
| assert len(keys) == len(contents) | |||
| for i in range(1, len(keys)): | |||
| keys[i] = ' ' + keys[i] | |||
| contents = [' ' + content for content in contents] | |||
| keys = [tokenizer.EncodeAsIds(key).tokenization for key in keys] | |||
| contents = [ | |||
| tokenizer.EncodeAsIds(content).tokenization for content in contents | |||
| ] | |||
| summary = sum(keys, []) | |||
| summary_prefix = self.process_sample('Summary: ', tokenizer, tokenize) | |||
| summary_mask = [len(summary_prefix), len(summary)] | |||
| summary = summary_prefix + summary | |||
| text, text_mask = [], [] | |||
| for key, content in zip(keys, contents): | |||
| content = content + [tokenizer.get_command('eop').Id] | |||
| text += key | |||
| text += content | |||
| text_mask.append(len(key)) | |||
| text_mask.append(len(content)) | |||
| return (summary, summary_mask), (text, text_mask) | |||
| def tokenize_worker(self, input, output, info, tokenizer, tokenize): | |||
| for row in iter(input.get, 'STOP'): | |||
| data = json.loads(row) | |||
| summary, content = self.process_line(data, tokenizer, tokenize) | |||
| output.put((summary, content)) | |||
| output.put('COMPLETE') | |||
| @staticmethod | |||
| def write_result(data, writers): | |||
| summary, content = data | |||
| writers['text'].write(summary[0]) | |||
| writers['mask'].write(summary[1]) | |||
| writers['text'].write(content[0]) | |||
| writers['mask'].write(content[1]) | |||
| class zhihu(PromptReader): | |||
| PATH = '/dataset/fd5061f6/data/tokenize_data/zhihu.lazy' | |||
| reserve_punct = True | |||
| assert_str = 'make sure to set PATH for zhihu data_utils/corpora.py' | |||
| qtitle_prefix = '问题:' | |||
| qcontent_prefix = '问题描述:' | |||
| user_prefix = '回答用户:' | |||
| answer_prefix = ' 回答:' | |||
| # qtitle_prefix = [] | |||
| # qcontent_prefix = [] | |||
| # user_prefix = [] | |||
| # answer_prefix = [] | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| prompts, texts = [], [] | |||
| ans_length = len(data.get('ans-content', '')) | |||
| ans_up = data.get('ans-up-num', '') | |||
| ans_up = int(ans_up) if ans_up else 0 | |||
| if ans_length > 100 or ans_up > 1000: | |||
| qtitle = data['q_title'] | |||
| qcontent = data['q-content'] | |||
| if qcontent is None: | |||
| qcontent = '' | |||
| qcontent = self.trim_field(qcontent, max_length=100) | |||
| user = data.get('user-signature', '') | |||
| prompt = self.qtitle_prefix + qtitle + self.qcontent_prefix + qcontent + self.user_prefix + user + self.answer_prefix # noqa | |||
| text = data['ans-content'] | |||
| prompt, text = self.process_sample(prompt, tokenizer, | |||
| tokenize), self.process_sample( | |||
| text, tokenizer, tokenize) | |||
| prompts.append(prompt) | |||
| texts.append(text) | |||
| # prompt = data["q_title"] + data["q-content"] + data["user-signature"] | |||
| # text = data["ans-content"] | |||
| # prompts.append(prompt) | |||
| # texts.append(text) | |||
| return prompts, texts | |||
| class zhidao(PromptReader): | |||
| PATH = '/root/data/zhidao/zhidao' | |||
| reserve_punct = True | |||
| assert_str = 'make sure to set PATH for zhidao data_utils/corpora.py' | |||
| qtitle_prefix = '问题:' | |||
| qcontent_prefix = '问题描述:' | |||
| answer_prefix = '回答:' | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| if 'title' not in data: | |||
| return [], [] | |||
| prompts, texts = [], [] | |||
| qtitle = data['title'] | |||
| qcontent = data.get('content', '') | |||
| qcontent = self.trim_field(qcontent, max_length=100) | |||
| prompt = self.qtitle_prefix + qtitle + self.qcontent_prefix + qcontent + self.answer_prefix | |||
| prompt = self.process_sample(prompt, tokenizer, tokenize) | |||
| if 'best_answer' in data: | |||
| text = data['best_answer']['content'] | |||
| if len(text) > 10: | |||
| text = self.process_sample(text, tokenizer, tokenize) | |||
| prompts.append(prompt) | |||
| texts.append(text) | |||
| for answer in data.get('other_answers', []): | |||
| text = answer['content'] | |||
| if len(text) > 100: | |||
| text = self.process_sample(text, tokenizer, tokenize) | |||
| prompts.append(prompt) | |||
| texts.append(text) | |||
| return prompts, texts | |||
| class baike(PromptReader): | |||
| PATH = '/dataset/fd5061f6/data/tokenize_data/baike.lazy' | |||
| reserve_punct = True | |||
| assert_str = 'make sure to set PATH for baike data_utils/corpora.py' | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| prompts, texts = [], [] | |||
| text = data.get('title', '') + data.get('abstract', '') + data.get( | |||
| 'content', '') | |||
| if text: | |||
| p, t = self.process_sample('', tokenizer, | |||
| tokenize), self.process_sample( | |||
| text, tokenizer, tokenize) | |||
| prompts.append(p) | |||
| texts.append(t) | |||
| return prompts, texts | |||
| class wikipedia(PromptReader): | |||
| """ | |||
| dataset for wikipedia with arguments configured for convenience | |||
| command line usage: `--train-data wikipedia` | |||
| """ | |||
| # PATH = '/dataset/data/wiki.txt' | |||
| PATH = '/root/data/bert_data/wiki.txt' | |||
| assert_str = 'make sure to set PATH for wikipedia data_utils/corpora.py' | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| text = data['text'] | |||
| prompt, text = self.process_sample('', tokenizer, | |||
| tokenize), self.process_sample( | |||
| text, tokenizer, tokenize) | |||
| return [prompt], [text] | |||
| class TestDataset(PromptReader): | |||
| PATH = '/root/data/test.json' | |||
| assert_str = 'make sure to set PATH for test data_utils/corpora.py' | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| prompt, text = data['prompt'], data['text'] | |||
| prompt, text = self.process_sample(prompt, tokenizer, | |||
| tokenize), self.process_sample( | |||
| text, tokenizer, tokenize) | |||
| return [prompt], [text] | |||
| class OpenWebText(PromptReader): | |||
| PATH = '/dataset/fd5061f6/english_data/openwebtext2' | |||
| assert_str = 'make sure to set PATH for openwebtext data_utils/corpora.py' | |||
| def __init__(self, *args, **kwargs): | |||
| import fasttext | |||
| super().__init__(*args, **kwargs) | |||
| self.model = fasttext.load_model( | |||
| '/dataset/fd5061f6/english_data/lid.176.bin') | |||
| print_rank_0('Load language detection model') | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| text = data['text'] | |||
| if len(text) > 100: | |||
| lang = self.model.predict(text.replace('\n', ''))[0][0] | |||
| if lang == '__label__en': | |||
| prompt, text = self.process_sample( | |||
| '', tokenizer, | |||
| tokenize), self.process_sample(text, tokenizer, tokenize) | |||
| return [prompt], [text] | |||
| return [], [] | |||
| class CCNews(PromptReader): | |||
| PATH = '/mnt/cc_news.json' | |||
| assert_str = 'make sure to set PATH for cc-news data_utils/corpora.py' | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| text = '' | |||
| title = data.get('title', None) | |||
| description = data.get('description', None) | |||
| maintext = data.get('maintext', None) | |||
| if title: | |||
| text += title.strip() + ' ' | |||
| if description and (not maintext | |||
| or not maintext.startswith(description)): | |||
| text += description.strip() + ' ' | |||
| if maintext: | |||
| text += maintext | |||
| if len(text) > 100: | |||
| prompt, text = self.process_sample('', tokenizer, | |||
| tokenize), self.process_sample( | |||
| text, tokenizer, tokenize) | |||
| return [prompt], [text] | |||
| else: | |||
| return [], [] | |||
| class BertData(PromptReader): | |||
| is_json = False | |||
| PATH = '/dataset/fd5061f6/english_data/wikibook' | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| if data: | |||
| prompt, text = '', data | |||
| prompt, text = self.process_sample(prompt, tokenizer, | |||
| tokenize), self.process_sample( | |||
| text, tokenizer, tokenize) | |||
| return [prompt], [text] | |||
| else: | |||
| return [], [] | |||
| class Pile(PromptReader): | |||
| is_json = True | |||
| PATH = '/mnt/train' | |||
| filtered_sources = [ | |||
| 'Github', 'StackExchange', 'DM Mathematics', 'Ubuntu IRC', 'EuroParl', | |||
| 'YoutubeSubtitles', 'Enron Emails' | |||
| ] | |||
| downsample_sources = {'PubMed Central': 0.3, 'ArXiv': 0.3, 'FreeLaw': 0.3} | |||
| def print_info(self, info): | |||
| total_dict = defaultdict(int) | |||
| while True: | |||
| try: | |||
| source_dict = info.get(block=False) | |||
| for source, length in source_dict.items(): | |||
| total_dict[source] += length | |||
| except Empty: | |||
| break | |||
| print_rank_0(total_dict) | |||
| def tokenize_worker(self, input, output, info, tokenizer, tokenize): | |||
| source_dict = defaultdict(int) | |||
| for row in iter(input.get, 'STOP'): | |||
| row = row.rstrip() | |||
| if row: | |||
| if self.is_json: | |||
| row = json.loads(row) | |||
| prompts, texts, source = self.process_line( | |||
| row, tokenizer, tokenize) | |||
| length = 0 | |||
| for prompt, text in zip(prompts, texts): | |||
| length += len(text) | |||
| output.put((prompt, text)) | |||
| if source: | |||
| source_dict[source] += length | |||
| output.put('COMPLETE') | |||
| info.put(source_dict) | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| source = data['meta'].get('pile_set_name', None) | |||
| text = data.get('text', None) | |||
| if source and text: | |||
| if source in self.filtered_sources: | |||
| return [], [], None | |||
| elif source in self.downsample_sources and random.random( | |||
| ) > self.downsample_sources[source]: | |||
| return [], [], None | |||
| else: | |||
| prompt, text = self.process_sample( | |||
| '', tokenizer, | |||
| tokenize), self.process_sample(text, tokenizer, tokenize) | |||
| return [prompt], [text], source | |||
| else: | |||
| return [], [], None | |||
| class Stories(PromptReader): | |||
| is_json = True | |||
| PATH = '/dataset/fd5061f6/english_data/stories_31G.jsonl' | |||
| def process_line(self, data, tokenizer, tokenize): | |||
| text = data.get('text', None) | |||
| if text: | |||
| prompt, text = self.process_sample('', tokenizer, | |||
| tokenize), self.process_sample( | |||
| text, tokenizer, tokenize) | |||
| return [prompt], [text] | |||
| else: | |||
| return [], [] | |||
| class BertBaseData(BertData): | |||
| PATH = '/root/data/formatted_one_article_per_line' | |||
| class BertLargeData(BertData): | |||
| PATH = '/dataset/c07bd62b/cognitive/zhengxiao/formatted_one_article_per_line_large' | |||
| class WuDaoCorpus(PromptReader): | |||
| # PATH = "/dataset/fd5061f6/chinese_data/WuDao" | |||
| PATH = '/wudao' | |||
| is_json = False | |||
| reserve_punct = True | |||
| split_row = False | |||
| def process_line(self, item, tokenizer, tokenize): | |||
| prompts, texts = [], [] | |||
| text = '' | |||
| title = item.get('title', None) | |||
| content = item.get('content', None) | |||
| if title: | |||
| text += title.strip() + ' ' | |||
| if content: | |||
| text += content | |||
| if len(text) > 100: | |||
| prompt, text = self.process_sample('', tokenizer, | |||
| tokenize), self.process_sample( | |||
| text, tokenizer, tokenize) | |||
| prompts.append(prompt) | |||
| texts.append(text) | |||
| return prompts, texts | |||
| NAMED_CORPORA = { | |||
| 'wikipedia': wikipedia, | |||
| 'wikipedia-key': KeyReader, | |||
| 'openwebtext': OpenWebText, | |||
| 'zhihu': zhihu, | |||
| 'zhidao': zhidao, | |||
| 'baike': baike, | |||
| 'test': TestDataset, | |||
| 'wikibook': BertData, | |||
| 'bert-base': BertBaseData, | |||
| 'bert-large': BertLargeData, | |||
| 'cc-news': CCNews, | |||
| 'pile': Pile, | |||
| 'stories': Stories, | |||
| 'wudao': WuDaoCorpus | |||
| } | |||
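Each `NAMED_CORPORA` entry maps a `--train-data` keyword to a reader class, and `get_dataset` (defined earlier in this PR) resolves names through this registry. Adding a corpus therefore only needs a `PromptReader` subclass implementing `process_line` (one row in, parallel lists of prompts and texts out) plus a registry entry. A hedged sketch; the class name, `PATH`, and JSON field names are hypothetical, and the import path is assumed from this PR's layout:

```python
from modelscope.models.nlp.mglm.data_utils.corpora import NAMED_CORPORA, PromptReader


class MyNewsCorpus(PromptReader):
    # Hypothetical location of a JSON-lines dump, one object per line.
    PATH = '/data/my_news.jsonl'
    assert_str = 'make sure to set PATH for my-news data_utils/corpora.py'
    is_json = True

    def process_line(self, data, tokenizer, tokenize):
        # Empty prompt, article body as the target text; skip very short articles.
        text = data.get('body', '')
        if len(text) < 100:
            return [], []
        prompt = self.process_sample('', tokenizer, tokenize)
        text = self.process_sample(text, tokenizer, tokenize)
        return [prompt], [text]


# After this, `--train-data my-news` resolves through get_dataset().
NAMED_CORPORA['my-news'] = MyNewsCorpus
```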
| @@ -0,0 +1,71 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import glob | |||
| import os | |||
| import json | |||
| import nltk | |||
| nltk.download('punkt') | |||
| class NLTKSegmenter: | |||
| def __init__(self): | |||
| pass | |||
| @staticmethod | |||
| def segment_string(article): | |||
| return nltk.tokenize.sent_tokenize(article) | |||
| wiki_path = 'data/extracted' | |||
| output_path = 'formatted/wiki-key.txt' | |||
| segmenter = NLTKSegmenter() | |||
| with open(output_path, 'w') as output: | |||
| for dirname in glob.glob(os.path.join(wiki_path, '*'), recursive=False): | |||
| for filename in glob.glob( | |||
| os.path.join(dirname, 'wiki_*'), recursive=True): | |||
| print(filename) | |||
| article_lines = [] | |||
| article_open = False | |||
| with open(filename, mode='r', newline='\n') as file: | |||
| for line in file: | |||
| line = line.rstrip() | |||
| if '<doc id=' in line: | |||
| article_open = True | |||
| elif '</doc>' in line: | |||
| key_sentences, contents = [], [] | |||
| key, content = None, [] | |||
| for sentences in article_lines[1:]: | |||
| if len(sentences) > 1: | |||
| if key: | |||
| if len(content) > 0 or len(contents) == 0: | |||
| key_sentences.append(key) | |||
| contents.append(content) | |||
| else: | |||
| contents[-1].append(key) | |||
| key, content = None, [] | |||
| key_sentences.append(sentences[0]) | |||
| contents.append(sentences[1:]) | |||
| elif len(sentences) > 0: | |||
| if key: | |||
| content.append(sentences[0]) | |||
| else: | |||
| key = sentences[0] | |||
| if key: | |||
| if len(content) > 0 or len(contents) == 0: | |||
| key_sentences.append(key) | |||
| contents.append(content) | |||
| else: | |||
| contents[-1].append(key) | |||
| contents = [' '.join(content) for content in contents] | |||
| article = {'key': key_sentences, 'content': contents} | |||
| output.write(json.dumps(article)) | |||
| output.write('\n') | |||
| article_open = False | |||
| article_lines = [] | |||
| else: | |||
| if article_open and line: | |||
| sentences = segmenter.segment_string(line) | |||
| article_lines.append(sentences) | |||
| @@ -0,0 +1,256 @@ | |||
| # Modified by Zhipu.AI | |||
| # This file is provided as is from: | |||
| # https://github.com/huggingface/pytorch-pretrained-BERT | |||
| # Please refer to their repository for copyright. | |||
| """ | |||
| Utilities for working with the local dataset cache. | |||
| This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp | |||
| Copyright by the AllenNLP authors. | |||
| """ | |||
| from __future__ import (absolute_import, division, print_function, | |||
| unicode_literals) | |||
| import logging | |||
| import os | |||
| import shutil | |||
| import sys | |||
| import tempfile | |||
| from functools import wraps | |||
| from hashlib import sha256 | |||
| from io import open | |||
| from urllib.parse import urlparse | |||
| import boto3 | |||
| import json | |||
| import requests | |||
| from botocore.exceptions import ClientError | |||
| from tqdm import tqdm | |||
| try: | |||
| from pathlib import Path | |||
| PYTORCH_PRETRAINED_BERT_CACHE = Path( | |||
| os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', | |||
| Path.home() / '.pytorch_pretrained_bert')) | |||
| except (AttributeError, ImportError): | |||
| PYTORCH_PRETRAINED_BERT_CACHE = os.getenv( | |||
| 'PYTORCH_PRETRAINED_BERT_CACHE', | |||
| os.path.join(os.path.expanduser('~'), '.pytorch_pretrained_bert')) | |||
| logger = logging.getLogger(__name__) # pylint: disable=invalid-name | |||
| def url_to_filename(url, etag=None): | |||
| """ | |||
| Convert `url` into a hashed filename in a repeatable way. | |||
| If `etag` is specified, append its hash to the url's, delimited | |||
| by a period. | |||
| """ | |||
| url_bytes = url.encode('utf-8') | |||
| url_hash = sha256(url_bytes) | |||
| filename = url_hash.hexdigest() | |||
| if etag: | |||
| etag_bytes = etag.encode('utf-8') | |||
| etag_hash = sha256(etag_bytes) | |||
| filename += '.' + etag_hash.hexdigest() | |||
| return filename | |||
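The cache key is simply the SHA-256 hex digest of the URL, with a second digest of the ETag appended after a period when one is available, so the same URL/ETag pair always maps to the same filename under the cache directory. A quick stdlib illustration (URL and ETag are placeholders):

```python
from hashlib import sha256

url = 'https://example.com/vocab.txt'
etag = '"abc123"'

filename = sha256(url.encode('utf-8')).hexdigest()
filename += '.' + sha256(etag.encode('utf-8')).hexdigest()
print(filename)  # '<64 hex chars>.<64 hex chars>', stored in PYTORCH_PRETRAINED_BERT_CACHE
```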
| def filename_to_url(filename, cache_dir=None): | |||
| """ | |||
| Return the url and etag (which may be ``None``) stored for `filename`. | |||
| Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. | |||
| """ | |||
| if cache_dir is None: | |||
| cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | |||
| if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | |||
| cache_dir = str(cache_dir) | |||
| cache_path = os.path.join(cache_dir, filename) | |||
| if not os.path.exists(cache_path): | |||
| raise EnvironmentError('file {} not found'.format(cache_path)) | |||
| meta_path = cache_path + '.json' | |||
| if not os.path.exists(meta_path): | |||
| raise EnvironmentError('file {} not found'.format(meta_path)) | |||
| with open(meta_path, encoding='utf-8') as meta_file: | |||
| metadata = json.load(meta_file) | |||
| url = metadata['url'] | |||
| etag = metadata['etag'] | |||
| return url, etag | |||
| def cached_path(url_or_filename, cache_dir=None): | |||
| """ | |||
| Given something that might be a URL (or might be a local path), | |||
| determine which. If it's a URL, download the file and cache it, and | |||
| return the path to the cached file. If it's already a local path, | |||
| make sure the file exists and then return the path. | |||
| """ | |||
| if cache_dir is None: | |||
| cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | |||
| if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): | |||
| url_or_filename = str(url_or_filename) | |||
| if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | |||
| cache_dir = str(cache_dir) | |||
| parsed = urlparse(url_or_filename) | |||
| if parsed.scheme in ('http', 'https', 's3'): | |||
| # URL, so get it from the cache (downloading if necessary) | |||
| return get_from_cache(url_or_filename, cache_dir) | |||
| elif os.path.exists(url_or_filename): | |||
| # File, and it exists. | |||
| return url_or_filename | |||
| elif parsed.scheme == '': | |||
| # File, but it doesn't exist. | |||
| raise EnvironmentError('file {} not found'.format(url_or_filename)) | |||
| else: | |||
| # Something unknown | |||
| raise ValueError( | |||
| 'unable to parse {} as a URL or as a local path'.format( | |||
| url_or_filename)) | |||
| def split_s3_path(url): | |||
| """Split a full s3 path into the bucket name and path.""" | |||
| parsed = urlparse(url) | |||
| if not parsed.netloc or not parsed.path: | |||
| raise ValueError('bad s3 path {}'.format(url)) | |||
| bucket_name = parsed.netloc | |||
| s3_path = parsed.path | |||
| # Remove '/' at beginning of path. | |||
| if s3_path.startswith('/'): | |||
| s3_path = s3_path[1:] | |||
| return bucket_name, s3_path | |||
| def s3_request(func): | |||
| """ | |||
| Wrapper function for s3 requests in order to create more helpful error | |||
| messages. | |||
| """ | |||
| @wraps(func) | |||
| def wrapper(url, *args, **kwargs): | |||
| try: | |||
| return func(url, *args, **kwargs) | |||
| except ClientError as exc: | |||
| if int(exc.response['Error']['Code']) == 404: | |||
| raise EnvironmentError('file {} not found'.format(url)) | |||
| else: | |||
| raise | |||
| return wrapper | |||
| @s3_request | |||
| def s3_etag(url): | |||
| """Check ETag on S3 object.""" | |||
| s3_resource = boto3.resource('s3') | |||
| bucket_name, s3_path = split_s3_path(url) | |||
| s3_object = s3_resource.Object(bucket_name, s3_path) | |||
| return s3_object.e_tag | |||
| @s3_request | |||
| def s3_get(url, temp_file): | |||
| """Pull a file directly from S3.""" | |||
| s3_resource = boto3.resource('s3') | |||
| bucket_name, s3_path = split_s3_path(url) | |||
| s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) | |||
| def http_get(url, temp_file): | |||
| req = requests.get(url, stream=True) | |||
| content_length = req.headers.get('Content-Length') | |||
| total = int(content_length) if content_length is not None else None | |||
| progress = tqdm(unit='B', total=total) | |||
| for chunk in req.iter_content(chunk_size=1024): | |||
| if chunk: # filter out keep-alive new chunks | |||
| progress.update(len(chunk)) | |||
| temp_file.write(chunk) | |||
| progress.close() | |||
| def get_from_cache(url, cache_dir=None): | |||
| """ | |||
| Given a URL, look for the corresponding dataset in the local cache. | |||
| If it's not there, download it. Then return the path to the cached file. | |||
| """ | |||
| if cache_dir is None: | |||
| cache_dir = PYTORCH_PRETRAINED_BERT_CACHE | |||
| if sys.version_info[0] == 3 and isinstance(cache_dir, Path): | |||
| cache_dir = str(cache_dir) | |||
| if not os.path.exists(cache_dir): | |||
| os.makedirs(cache_dir) | |||
| # Get eTag to add to filename, if it exists. | |||
| if url.startswith('s3://'): | |||
| etag = s3_etag(url) | |||
| else: | |||
| response = requests.head(url, allow_redirects=True) | |||
| if response.status_code != 200: | |||
| raise IOError( | |||
| 'HEAD request failed for url {} with status code {}'.format( | |||
| url, response.status_code)) | |||
| etag = response.headers.get('ETag') | |||
| filename = url_to_filename(url, etag) | |||
| # get cache path to put the file | |||
| cache_path = os.path.join(cache_dir, filename) | |||
| if not os.path.exists(cache_path): | |||
| # Download to temporary file, then copy to cache dir once finished. | |||
| # Otherwise you get corrupt cache entries if the download gets interrupted. | |||
| with tempfile.NamedTemporaryFile() as temp_file: | |||
| logger.info('%s not found in cache, downloading to %s', url, | |||
| temp_file.name) | |||
| # GET file object | |||
| if url.startswith('s3://'): | |||
| s3_get(url, temp_file) | |||
| else: | |||
| http_get(url, temp_file) | |||
| # we are copying the file before closing it, so flush to avoid truncation | |||
| temp_file.flush() | |||
| # shutil.copyfileobj() starts at the current position, so go to the start | |||
| temp_file.seek(0) | |||
| logger.info('copying %s to cache at %s', temp_file.name, | |||
| cache_path) | |||
| with open(cache_path, 'wb') as cache_file: | |||
| shutil.copyfileobj(temp_file, cache_file) | |||
| logger.info('creating metadata file for %s', cache_path) | |||
| meta = {'url': url, 'etag': etag} | |||
| meta_path = cache_path + '.json' | |||
| with open(meta_path, 'w', encoding='utf-8') as meta_file: | |||
| json.dump(meta, meta_file) | |||
| logger.info('removing temp file %s', temp_file.name) | |||
| return cache_path | |||
| def read_set_from_file(filename): | |||
| ''' | |||
| Extract a de-duped collection (set) of text from a file. | |||
| Expected file format is one item per line. | |||
| ''' | |||
| collection = set() | |||
| with open(filename, 'r', encoding='utf-8') as file_: | |||
| for line in file_: | |||
| collection.add(line.rstrip()) | |||
| return collection | |||
| def get_file_extension(path, dot=True, lower=True): | |||
| ext = os.path.splitext(path)[1] | |||
| ext = ext if dot else ext[1:] | |||
| return ext.lower() if lower else ext | |||
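A minimal usage sketch of the caching helper above. The import name and URL are placeholders, not part of the diff; the point is that the second call returns the already-cached file instead of downloading again.

```python
import os
import file_utils  # hypothetical import name for the module defined above

url = 'https://example.com/some-vocab.json'   # placeholder URL
cache_dir = os.path.expanduser('~/.cache/demo')

path_first = file_utils.get_from_cache(url, cache_dir=cache_dir)
path_second = file_utils.get_from_cache(url, cache_dir=cache_dir)

# The cached filename is derived from the URL (plus its ETag when the server
# provides one), so repeated calls resolve to the same local file.
assert path_first == path_second
```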
| @@ -0,0 +1,286 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """utils for loading text from disk""" | |||
| import mmap | |||
| import os | |||
| import pickle as pkl | |||
| import time | |||
| from itertools import accumulate | |||
| import numpy as np | |||
| import torch | |||
| from torch.multiprocessing import Lock | |||
| def get_lazy_path(path): | |||
| """ | |||
| Gets directory path where lazy files are stored. | |||
| """ | |||
| return os.path.splitext(path)[0] + '.lazy' | |||
| def exists_lazy(path, data_type='data'): | |||
| """ | |||
| Check if we've already made a lazy version of this file for the `data_type` field. | |||
| """ | |||
| if not os.path.exists(get_lazy_path(path)): | |||
| return False | |||
| contents = os.listdir(get_lazy_path(path)) | |||
| if data_type not in contents: | |||
| return False | |||
| if data_type + '.len.pkl' not in contents: | |||
| return False | |||
| return True | |||
| def get_scatter_path(path, scatter_rank): | |||
| path = os.path.splitext(path)[0] + '.scatter' | |||
| scatter_path = os.path.join(path, str(scatter_rank)) | |||
| return scatter_path | |||
| def exists_scatter(path, scatter_num=64, data_type='data'): | |||
| for i in range(scatter_num): | |||
| scatter_path = get_scatter_path(path, scatter_rank=i) | |||
| if not exists_lazy(scatter_path, data_type=data_type): | |||
| return False | |||
| return True | |||
| class LazyWriter: | |||
| def __init__(self, | |||
| path, | |||
| data_type, | |||
| is_array=False, | |||
| array_data_type=np.int32): | |||
| lazypath = get_lazy_path(path) | |||
| if not os.path.exists(lazypath): | |||
| os.makedirs(lazypath) | |||
| self.datapath = os.path.join(lazypath, data_type) | |||
| self.lenpath = os.path.join(lazypath, data_type + '.len.pkl') | |||
| self.array_data_type = array_data_type | |||
| self.output = open(self.datapath, 'wb') | |||
| self.lengths = [] | |||
| self.is_array = is_array | |||
| @staticmethod | |||
| def get_len_path(path, data_type): | |||
| lazypath = get_lazy_path(path) | |||
| return os.path.join(lazypath, data_type + '.len.pkl') | |||
| def write(self, s): | |||
| if isinstance(s, dict): | |||
| s = s['text'] | |||
| if self.is_array: | |||
| encoded = np.array( | |||
| s, dtype=self.array_data_type).tobytes(order='C') | |||
| self.output.write(encoded) | |||
| self.lengths.append(len(s)) | |||
| else: | |||
| encoded = s.encode('utf-8') | |||
| self.output.write(encoded) | |||
| self.lengths.append(len(encoded)) | |||
| def close(self): | |||
| self.output.close() | |||
| with open(self.lenpath, 'wb') as f: | |||
| pkl.dump(self.lengths, f) | |||
| def split_strings(strings, start, chr_lens): | |||
| """ | |||
| Split strings based on string lengths and given start. | |||
| """ | |||
| return [ | |||
| strings[i - start:j - start] | |||
| for i, j in zip([start] + chr_lens[:-1], chr_lens) | |||
| ] | |||
| class ProcessorTokenizer: | |||
| """ | |||
| Callable class that runs a preprocessing step, as well as a tokenization step, | |||
| on input text. | |||
| """ | |||
| def __init__(self, tokenizer, process_fn=None): | |||
| self.tokenizer = tokenizer | |||
| self.process_fn = process_fn | |||
| def __call__(self, string): | |||
| if self.tokenizer is not None: | |||
| string = self.tokenizer(string, process_fn=self.process_fn) | |||
| elif self.process_fn is not None: | |||
| string = self.process_fn(string) | |||
| return string | |||
| class LazyLoader(object): | |||
| """ | |||
| Arguments: | |||
| path: path to directory where array entries are concatenated into one big string file | |||
| and the .len file are located | |||
| data_type (str): Some datasets have multiple fields that are stored in different paths. | |||
| `data_type` specifies which of these fields to load in this class | |||
| mem_map (boolean): Specifies whether to memory map file `path` | |||
| map_fn (callable): Fetched strings are passed through map_fn before being returned. | |||
| Example of lazy loader directory structure: | |||
| file.json | |||
| file.lazy/ | |||
| data_type1 | |||
| data_type1.len.pkl | |||
| data_type2 | |||
| data_type2.len.pkl | |||
| """ | |||
| def __init__(self, | |||
| path, | |||
| data_type='data', | |||
| mem_map=False, | |||
| map_fn=None, | |||
| is_array=False, | |||
| array_data_type=np.int32, | |||
| load_memory=False, | |||
| half_load=False): | |||
| lazypath = get_lazy_path(path) | |||
| datapath = os.path.join(lazypath, data_type) | |||
| # get file where array entries are concatenated into one big string | |||
| self._file = open(datapath, 'rb') | |||
| self.file = self._file | |||
| self.is_array = is_array | |||
| self.array_data_type = array_data_type | |||
| # memory map file if necessary | |||
| lenpath = os.path.join(lazypath, data_type + '.len.pkl') | |||
| self.lens = pkl.load(open(lenpath, 'rb')) | |||
| if half_load: | |||
| self.lens = self.lens[:2 * len(self.lens) // 3] | |||
| self.ends = list(accumulate(self.lens)) | |||
| self.dumb_ends = list(self.ends) | |||
| self.mem_map = mem_map | |||
| self.load_memory = load_memory | |||
| if self.load_memory: | |||
| data_type_size = np.dtype(self.array_data_type).itemsize | |||
| if half_load: | |||
| self.file = self.file.read(sum(self.lens) * data_type_size) | |||
| else: | |||
| self.file = self.file.read() | |||
| self.file = np.ndarray( | |||
| shape=(len(self.file) // data_type_size, ), | |||
| dtype=array_data_type, | |||
| buffer=self.file, | |||
| order='C') | |||
| elif self.mem_map: | |||
| if is_array: | |||
| if self.ends[-1] == 0: | |||
| self.file = np.array([], dtype=array_data_type) | |||
| else: | |||
| self.file = np.memmap( | |||
| self.file, dtype=array_data_type, mode='r', order='C') | |||
| else: | |||
| if self.ends[-1] == 0: | |||
| self.file = bytearray() | |||
| else: | |||
| self.file = mmap.mmap( | |||
| self.file.fileno(), 0, prot=mmap.PROT_READ) | |||
| self.read_lock = Lock() | |||
| self.process_fn = map_fn | |||
| self.map_fn = map_fn | |||
| self._tokenizer = None | |||
| self.is_lazy = True | |||
| def SetTokenizer(self, tokenizer): | |||
| """ | |||
| logic to set and remove (set to None) tokenizer. | |||
| combines preprocessing/tokenization into one callable. | |||
| """ | |||
| if tokenizer is None: | |||
| if not hasattr(self, '_tokenizer'): | |||
| self._tokenizer = tokenizer | |||
| else: | |||
| self._tokenizer = tokenizer | |||
| self.map_fn = ProcessorTokenizer(tokenizer, self.process_fn) | |||
| def GetTokenizer(self): | |||
| return self._tokenizer | |||
| def __getitem__(self, index): | |||
| """ | |||
| read file and splice strings based on string ending array `self.ends` | |||
| """ | |||
| if not isinstance(index, slice): | |||
| if index == 0: | |||
| start = 0 | |||
| else: | |||
| start = self.ends[index - 1] | |||
| end = self.ends[index] | |||
| rtn = self.file_read(start, end) | |||
| if self.map_fn is not None: | |||
| rtn = self.map_fn(rtn) | |||
| else: | |||
| # if slice, fetch strings with 1 diskread and then splice in memory | |||
| chr_lens = self.ends[index] | |||
| if index.start == 0 or index.start is None: | |||
| start = 0 | |||
| else: | |||
| start = self.ends[index.start - 1] | |||
| stop = chr_lens[-1] | |||
| strings = self.file_read(start, stop) | |||
| rtn = split_strings(strings, start, chr_lens) | |||
| if self.map_fn is not None: | |||
| rtn = [self.map_fn(s) for s in rtn] | |||
| return rtn | |||
| def __len__(self): | |||
| return len(self.ends) | |||
| def file_read(self, start=0, end=None): | |||
| """read specified portion of file""" | |||
| data_type_size = np.dtype(self.array_data_type).itemsize | |||
| # atomic reads to avoid race conditions with multiprocess dataloader | |||
| self.read_lock.acquire() | |||
| if not self.mem_map and not self.load_memory: | |||
| # seek to start of file read | |||
| if self.is_array: | |||
| start = start * data_type_size | |||
| end = end * data_type_size if end is not None else None | |||
| self.file.seek(start) | |||
| # read to end of file if no end point provided | |||
| if end is None: | |||
| rtn = self.file.read() | |||
| # else read amount needed to reach end point | |||
| else: | |||
| rtn = self.file.read(end - start) | |||
| if self.is_array: | |||
| rtn = np.ndarray( | |||
| shape=(len(rtn) // data_type_size, ), | |||
| dtype=self.array_data_type, | |||
| buffer=rtn, | |||
| order='C') | |||
| else: | |||
| rtn = rtn.decode('utf-8', 'ignore') | |||
| else: | |||
| rtn = self.file[start:end] | |||
| if self.is_array: | |||
| rtn = rtn.copy() | |||
| else: | |||
| rtn = rtn.decode('utf-8', 'strict') | |||
| self.read_lock.release() | |||
| # TODO: @raulp figure out mem map byte string bug | |||
| # if mem map'd need to decode byte string to string | |||
| # # rtn = str(rtn) | |||
| # if self.mem_map: | |||
| # rtn = rtn.decode('unicode_escape') | |||
| return rtn | |||
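A small round-trip sketch of the lazy format described above. File names are arbitrary and `LazyWriter`/`LazyLoader` are assumed to be imported from this module; writing creates `corpus.lazy/text` plus `text.len.pkl`, and the loader then fetches records by byte offset.

```python
# Write three records, then read them back individually and as a slice.
writer = LazyWriter('corpus.json', data_type='text', is_array=False)
for doc in ['first document', 'second document', 'third document']:
    writer.write(doc)
writer.close()

loader = LazyLoader('corpus.json', data_type='text', mem_map=False)
print(len(loader))    # 3
print(loader[1])      # 'second document'
print(loader[0:2])    # ['first document', 'second document']
```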
| @@ -0,0 +1,190 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """batch samplers that work with either random or sequential data samplers""" | |||
| import math | |||
| import os | |||
| import sys | |||
| import numpy as np | |||
| import torch | |||
| from torch.utils import data | |||
| class RandomSampler(data.sampler.Sampler): | |||
| r""" | |||
| Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, | |||
| but this class lets the user set an epoch like DistributedSampler | |||
| Samples elements randomly. If without replacement, then sample from a shuffled dataset. | |||
| If with replacement, then user can specify ``num_samples`` to draw. | |||
| Arguments: | |||
| data_source (Dataset): dataset to sample from | |||
| num_samples (int): number of samples to draw, default=len(dataset) | |||
| replacement (bool): samples are drawn with replacement if ``True``, default=False | |||
| """ | |||
| def __init__(self, data_source, replacement=False, num_samples=None): | |||
| super(RandomSampler, self).__init__(data_source) | |||
| self.data_source = data_source | |||
| self.replacement = replacement | |||
| self._num_samples = num_samples | |||
| self.epoch = -1 | |||
| if self._num_samples is not None and replacement is False: | |||
| raise ValueError( | |||
| 'With replacement=False, num_samples should not be specified, ' | |||
| 'since a random permute will be performed.') | |||
| if not isinstance(self.num_samples, int) or self.num_samples <= 0: | |||
| raise ValueError('num_samples should be a positive integer ' | |||
| 'value, but got num_samples={}'.format( | |||
| self.num_samples)) | |||
| if not isinstance(self.replacement, bool): | |||
| raise ValueError('replacement should be a boolean value, but got ' | |||
| 'replacement={}'.format(self.replacement)) | |||
| @property | |||
| def num_samples(self): | |||
| # dataset size might change at runtime | |||
| if self._num_samples is None: | |||
| return len(self.data_source) | |||
| return self._num_samples | |||
| def __iter__(self): | |||
| n = len(self.data_source) | |||
| g = torch.Generator() | |||
| if self.epoch >= 0: | |||
| g.manual_seed(self.epoch) | |||
| if self.replacement: | |||
| for _ in range(self.num_samples // 32): | |||
| yield from torch.randint( | |||
| high=n, size=(32, ), dtype=torch.int64, | |||
| generator=g).tolist() | |||
| yield from torch.randint( | |||
| high=n, | |||
| size=(self.num_samples % 32, ), | |||
| dtype=torch.int64, | |||
| generator=g).tolist() | |||
| else: | |||
| yield from torch.randperm(n, generator=g).tolist() | |||
| def __len__(self): | |||
| return self.num_samples | |||
| def set_epoch(self, epoch): | |||
| self.epoch = epoch | |||
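A short sketch of why `set_epoch` matters (the `RandomSampler` above is assumed to be in scope): seeding the permutation with the epoch makes the shuffle order reproducible for a given epoch while still differing between epochs.

```python
dataset = list(range(10))            # stand-in dataset
sampler = RandomSampler(dataset)     # replacement=False -> one full permutation

sampler.set_epoch(0)
order_epoch0 = list(iter(sampler))
sampler.set_epoch(1)
order_epoch1 = list(iter(sampler))   # a different permutation

sampler.set_epoch(0)
assert list(iter(sampler)) == order_epoch0   # same epoch -> same order
```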
| class DistributedSequentialSampler(data.sampler.Sampler): | |||
| def __init__(self, | |||
| num_samples, | |||
| train_iters, | |||
| batch_size, | |||
| rank=-1, | |||
| world_size=2): | |||
| super().__init__(num_samples) | |||
| if rank == -1: | |||
| rank = 0 | |||
| world_size = 1 | |||
| self.num_samples = num_samples | |||
| self.rank = rank | |||
| self.world_size = world_size | |||
| self.start_iter = 0 | |||
| self.train_iters = train_iters | |||
| self.batch_size = batch_size | |||
| self.batch_bias = [ | |||
| i * (num_samples // batch_size) for i in range(batch_size) | |||
| ] | |||
| def __iter__(self): | |||
| for idx in range(self.start_iter, self.train_iters * 10): | |||
| batch = [(idx + bias) % self.num_samples | |||
| for bias in self.batch_bias] | |||
| tbatch = self._batch(batch) | |||
| yield tbatch | |||
| def __len__(self): | |||
| return self.train_iters | |||
| def _batch(self, batch): | |||
| """extracts samples only pertaining to this worker's batch""" | |||
| start = self.rank * self.batch_size // self.world_size | |||
| end = (self.rank + 1) * self.batch_size // self.world_size | |||
| return batch[start:end] | |||
| class DistributedBatchSampler(data.sampler.BatchSampler): | |||
| """ | |||
| similar to normal implementation of distributed sampler, except implementation is at the | |||
| batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary | |||
| data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. | |||
| """ | |||
| def __init__(self, | |||
| sampler, | |||
| batch_size, | |||
| drop_last, | |||
| rank=-1, | |||
| world_size=2, | |||
| wrap_last=False, | |||
| gradient_accumulation_steps=None): | |||
| super(DistributedBatchSampler, self).__init__(sampler, batch_size, | |||
| drop_last) | |||
| if rank == -1: | |||
| assert False, 'should not be here' | |||
| self.rank = rank | |||
| self.world_size = world_size | |||
| self.sampler.wrap_around = 0 | |||
| self.wrap_around = 0 | |||
| self.wrap_last = wrap_last | |||
| self.start_iter = 0 | |||
| self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps # noqa | |||
| def __iter__(self): | |||
| batch = [] | |||
| i = 0 | |||
| for idx in self.data_iterator(self.sampler, wrap_around=False): | |||
| batch.append(idx) | |||
| if len(batch) == self.batch_size: | |||
| tbatch = self._batch(batch) | |||
| if i >= self.start_iter * self.effective_batch_size: | |||
| yield tbatch | |||
| self.start_iter = 0 | |||
| i += len(batch) | |||
| batch = [] | |||
| batch_len = len(batch) | |||
| if batch_len > 0 and not self.drop_last: | |||
| if self.wrap_last: | |||
| self.sampler.wrap_around -= (self.batch_size) | |||
| self.wrap_around += (len(batch)) | |||
| self.wrap_around %= self.batch_size | |||
| yield self._batch(batch) | |||
| if self.wrap_last: | |||
| self.sampler.wrap_around += self.batch_size | |||
| def data_iterator(self, _iter, wrap_around=False): | |||
| """iterates through data and handles wrap around""" | |||
| for i, idx in enumerate(_iter): | |||
| if i < self.wrap_around % self.batch_size: | |||
| continue | |||
| if wrap_around: | |||
| self.wrap_around += 1 | |||
| self.wrap_around %= self.batch_size | |||
| yield idx | |||
| def _batch(self, batch): | |||
| """extracts samples only pertaining to this worker's batch""" | |||
| start = self.rank * self.batch_size // self.world_size | |||
| end = (self.rank + 1) * self.batch_size // self.world_size | |||
| return batch[start:end] | |||
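A usage sketch (sizes and rank are made up; `RandomSampler` and `DistributedBatchSampler` above are assumed to be in scope): the batch sampler accumulates a global batch and each rank keeps only its own slice of it.

```python
from torch.utils.data import DataLoader

dataset = list(range(100))
sampler = RandomSampler(dataset)
batch_sampler = DistributedBatchSampler(
    sampler, batch_size=8, drop_last=True, rank=0, world_size=2)

# With a global batch size of 8 and world_size=2, rank 0 keeps batch[0:4]
# of every global batch (rank 1 would keep batch[4:8]).
loader = DataLoader(dataset, batch_sampler=batch_sampler)
local_batch = next(iter(loader))
print(local_batch)          # tensor holding 4 dataset elements
```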
| @@ -0,0 +1,158 @@ | |||
| # Modified by Zhipu.AI | |||
| """ | |||
| from https://github.com/openai/gpt-2/, changed for chinese | |||
| """ | |||
| import os # yapf: disable | |||
| """ | |||
| SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation | |||
| systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements | |||
| subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.]) and unigram language model [Kudo] with the | |||
| extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end | |||
| system that does not depend on language-specific pre/postprocessing. | |||
| https://github.com/google/sentencepiece | |||
| pip install sentencepiece | |||
| or git clone https://github.com/google/sentencepiece.git | |||
| python setup.py install | |||
| """ | |||
| def get_pairs(word): | |||
| pairs = set() | |||
| prev_char = word[0] | |||
| for char in word[1:]: | |||
| pairs.add((prev_char, char)) | |||
| prev_char = char | |||
| return pairs | |||
| class Encoder: | |||
| def __init__(self, encoder, bpe_merges): | |||
| self.encoder = encoder | |||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||
| self.cache = {} | |||
| self.max_len = 0 | |||
| def bpe(self, token): | |||
| if token in self.cache: | |||
| return self.cache[token] | |||
| word = tuple(token) | |||
| pairs = get_pairs(word) | |||
| if not pairs: | |||
| return token | |||
| while True: | |||
| bigram = min( | |||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||
| if bigram not in self.bpe_ranks: | |||
| break | |||
| first, second = bigram | |||
| new_word = [] | |||
| i = 0 | |||
| while i < len(word): | |||
| try: | |||
| j = word.index(first, i) | |||
| new_word.extend(word[i:j]) | |||
| i = j | |||
| except ValueError:  # 'first' does not occur after position i | |||
| new_word.extend(word[i:]) | |||
| break | |||
| if word[i] == first and i < len(word) - 1 and word[ | |||
| i + 1] == second: | |||
| new_word.append(first + second) | |||
| i += 2 | |||
| else: | |||
| new_word.append(word[i]) | |||
| i += 1 | |||
| new_word = tuple(new_word) | |||
| word = new_word | |||
| if len(word) == 1: | |||
| break | |||
| else: | |||
| pairs = get_pairs(word) | |||
| word = ' '.join(word) | |||
| self.cache[token] = word | |||
| return word | |||
| def encode(self, text): | |||
| return [self.encoder.get(token, 1) for token in self.tokenize(text)] | |||
| def decode(self, tokens): | |||
| text = ''.join([self.decoder[token] for token in tokens]) | |||
| return text | |||
| def tokenize(self, text): | |||
| bpe_tokens = [] | |||
| bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' ')) | |||
| return bpe_tokens | |||
| def convert_tokens_to_ids(self, tokens): | |||
| return [self.encoder.get(token, 1) for token in tokens] | |||
| class Encoder_SP: | |||
| def __init__(self, model_path): | |||
| import sentencepiece as spm | |||
| self.sp = spm.SentencePieceProcessor() | |||
| self.sp.Load(model_path) | |||
| def encode(self, text): | |||
| """ | |||
| text="...." | |||
| """ | |||
| return self.sp.EncodeAsIds(text) | |||
| def decode(self, tokens): | |||
| """ | |||
| tokens=[x1,x2,...] | |||
| """ | |||
| text = [int(token) for token in tokens] | |||
| # print(text) | |||
| return self.sp.DecodeIds(text) | |||
| def tokenize(self, text): | |||
| return self.sp.EncodeAsPieces(text) | |||
| def convert_tokens_to_ids(self, tokens): | |||
| return [self.sp.PieceToId(token) for token in tokens] | |||
| def convert_token_to_id(self, token): | |||
| return self.sp.PieceToId(token) | |||
| def convert_id_to_token(self, idx): | |||
| return self.sp.IdToPiece(idx) | |||
| def get_encoder(encoder_file, bpe_file): | |||
| import json | |||
| filepath, filename = os.path.split(encoder_file) | |||
| shotname, extension = os.path.splitext(filename) | |||
| if ('.model' == extension) and (bpe_file == ''): | |||
| return Encoder_SP(encoder_file) | |||
| else: | |||
| with open(encoder_file, 'r', encoding='utf-8') as f: | |||
| encoder = json.load(f) | |||
| with open(bpe_file, 'r', encoding='utf-8') as f: | |||
| bpe_data = f.read() | |||
| bpe_merges = [ | |||
| tuple(merge_str.split()) | |||
| for merge_str in bpe_data.split('\n')[1:-1] | |||
| ] | |||
| return Encoder( | |||
| encoder=encoder, | |||
| bpe_merges=bpe_merges, | |||
| ) | |||
| def from_pretrained(model_path): | |||
| return get_encoder(model_path + '/tokenizer/mglm250k/mglm250k-uni.model', | |||
| '') | |||
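A usage sketch for the SentencePiece path (the model directory is a placeholder; this requires the `sentencepiece` package and the `.model` file shipped with the mGLM checkpoint):

```python
enc = from_pretrained('/path/to/mglm_model_dir')   # resolves to an Encoder_SP

text = '今天天气不错'                  # arbitrary example sentence
pieces = enc.tokenize(text)            # sentencepiece pieces
ids = enc.encode(text)                 # token ids
print(pieces)
print(ids)
print(enc.decode(ids))                 # reconstructs the original text
```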
| @@ -0,0 +1,359 @@ | |||
| # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Tokenization classes for OpenAI GPT.""" | |||
| from __future__ import (absolute_import, division, print_function, | |||
| unicode_literals) | |||
| import logging | |||
| import os | |||
| import sys | |||
| from io import open | |||
| import json | |||
| import regex as re | |||
| from .file_utils import cached_path | |||
| try: | |||
| from functools import lru_cache | |||
| except ImportError: | |||
| # Just a dummy decorator to get the checks to run on python2 | |||
| # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. | |||
| def lru_cache(): | |||
| return lambda func: func | |||
| logger = logging.getLogger(__name__) | |||
| PRETRAINED_VOCAB_ARCHIVE_MAP = { | |||
| 'gpt2': '.pytorch_pretrained_bert/gpt2-vocab.json', | |||
| 'roberta': '.pytorch_pretrained_bert/roberta-vocab.json' | |||
| } | |||
| PRETRAINED_MERGES_ARCHIVE_MAP = { | |||
| 'gpt2': '.pytorch_pretrained_bert/gpt2-merges.txt', | |||
| 'roberta': '.pytorch_pretrained_bert/roberta-merges.txt' | |||
| } | |||
| PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { | |||
| 'gpt2': 1024, | |||
| } | |||
| VOCAB_NAME = 'vocab.json' | |||
| MERGES_NAME = 'merges.txt' | |||
| SPECIAL_TOKENS_NAME = 'special_tokens.txt' | |||
| @lru_cache() | |||
| def bytes_to_unicode(): | |||
| """ | |||
| Returns a list of utf-8 bytes and a corresponding list of unicode strings. | |||
| The reversible bpe codes work on unicode strings. | |||
| This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. | |||
| When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. | |||
| This is a significant percentage of your normal, say, 32K bpe vocab. | |||
| To avoid that, we want lookup tables between utf-8 bytes and unicode strings. | |||
| And avoids mapping to whitespace/control characters the bpe code barfs on. | |||
| """ | |||
| _chr = unichr if sys.version_info[0] == 2 else chr | |||
| bs = list(range(ord('!'), | |||
| ord('~') + 1)) + list(range( | |||
| ord('¡'), | |||
| ord('¬') + 1)) + list(range(ord('®'), | |||
| ord('ÿ') + 1)) | |||
| cs = bs[:] | |||
| n = 0 | |||
| for b in range(2**8): | |||
| if b not in bs: | |||
| bs.append(b) | |||
| cs.append(2**8 + n) | |||
| n += 1 | |||
| cs = [_chr(n) for n in cs] | |||
| return dict(zip(bs, cs)) | |||
| def get_pairs(word): | |||
| """Return set of symbol pairs in a word. | |||
| Word is represented as tuple of symbols (symbols being variable-length strings). | |||
| """ | |||
| pairs = set() | |||
| prev_char = word[0] | |||
| for char in word[1:]: | |||
| pairs.add((prev_char, char)) | |||
| prev_char = char | |||
| return pairs | |||
| class GPT2Tokenizer(object): | |||
| """ | |||
| GPT-2 BPE tokenizer. Peculiarities: | |||
| - Byte-level BPE | |||
| """ | |||
| @classmethod | |||
| def from_pretrained(cls, | |||
| pretrained_model_name_or_path, | |||
| cache_dir=None, | |||
| *inputs, | |||
| **kwargs): | |||
| """ | |||
| Instantiate a PreTrainedBertModel from a pre-trained model file. | |||
| Download and cache the pre-trained model file if needed. | |||
| """ | |||
| if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: | |||
| vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[ | |||
| pretrained_model_name_or_path] | |||
| merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[ | |||
| pretrained_model_name_or_path] | |||
| special_tokens_file = None | |||
| else: | |||
| vocab_file = os.path.join(pretrained_model_name_or_path, | |||
| VOCAB_NAME) | |||
| merges_file = os.path.join(pretrained_model_name_or_path, | |||
| MERGES_NAME) | |||
| special_tokens_file = os.path.join(pretrained_model_name_or_path, | |||
| SPECIAL_TOKENS_NAME) | |||
| if not os.path.exists(special_tokens_file): | |||
| special_tokens_file = None | |||
| else: | |||
| logger.info('loading special tokens file {}'.format( | |||
| special_tokens_file)) | |||
| # redirect to the cache, if necessary | |||
| # try: | |||
| # resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) | |||
| # resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) | |||
| # except EnvironmentError: | |||
| # logger.error( | |||
| # "Model name '{}' was not found in model name list ({}). " | |||
| # "We assumed '{}' was a path or url but couldn't find files {} and {} " | |||
| # "at this path or url.".format( | |||
| # pretrained_model_name_or_path, | |||
| # ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), | |||
| # pretrained_model_name_or_path, | |||
| # vocab_file, merges_file)) | |||
| # return None | |||
| # if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: | |||
| # logger.info("loading vocabulary file {}".format(vocab_file)) | |||
| # logger.info("loading merges file {}".format(merges_file)) | |||
| # else: | |||
| # logger.info("loading vocabulary file {} from cache at {}".format( | |||
| # vocab_file, resolved_vocab_file)) | |||
| # logger.info("loading merges file {} from cache at {}".format( | |||
| # merges_file, resolved_merges_file)) | |||
| resolved_vocab_file = vocab_file | |||
| resolved_merges_file = merges_file | |||
| logger.info('loading vocabulary file {}'.format(vocab_file)) | |||
| logger.info('loading merges file {}'.format(merges_file)) | |||
| if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: | |||
| # if we're using a pretrained model, ensure the tokenizer wont index sequences longer | |||
| # than the number of positional embeddings | |||
| max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ | |||
| pretrained_model_name_or_path] | |||
| kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) | |||
| # Instantiate tokenizer. | |||
| if special_tokens_file and 'special_tokens' not in kwargs: | |||
| special_tokens = open( | |||
| special_tokens_file, encoding='utf-8').read().split('\n')[:-1] | |||
| else: | |||
| special_tokens = kwargs.pop('special_tokens', []) | |||
| tokenizer = cls( | |||
| resolved_vocab_file, | |||
| resolved_merges_file, | |||
| special_tokens=special_tokens, | |||
| *inputs, | |||
| **kwargs) | |||
| return tokenizer | |||
| def __init__(self, | |||
| vocab_file, | |||
| merges_file, | |||
| errors='replace', | |||
| special_tokens=None, | |||
| max_len=None): | |||
| self.max_len = max_len if max_len is not None else int(1e12) | |||
| self.encoder = json.load(open(vocab_file)) | |||
| self.decoder = {v: k for k, v in self.encoder.items()} | |||
| self.errors = errors # how to handle errors in decoding | |||
| self.byte_encoder = bytes_to_unicode() | |||
| self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} | |||
| bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] | |||
| bpe_merges = [tuple(merge.split()) for merge in bpe_data] | |||
| self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) | |||
| self.cache = {} | |||
| # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions | |||
| self.pat = re.compile( | |||
| r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" | |||
| ) | |||
| self.special_tokens = {} | |||
| self.special_tokens_decoder = {} | |||
| self.set_special_tokens(special_tokens) | |||
| def __len__(self): | |||
| return len(self.encoder) + len(self.special_tokens) | |||
| def set_special_tokens(self, special_tokens): | |||
| """ Add a list of additional tokens to the encoder. | |||
| The additional tokens are indexed starting from the last index of the | |||
| current vocabulary in the order of the `special_tokens` list. | |||
| """ | |||
| if not special_tokens: | |||
| self.special_tokens = {} | |||
| self.special_tokens_decoder = {} | |||
| return | |||
| self.special_tokens = dict((tok, len(self.encoder) + i) | |||
| for i, tok in enumerate(special_tokens)) | |||
| self.special_tokens_decoder = { | |||
| v: k | |||
| for k, v in self.special_tokens.items() | |||
| } | |||
| logger.info('Special tokens {}'.format(self.special_tokens)) | |||
| def bpe(self, token): | |||
| if token in self.cache: | |||
| return self.cache[token] | |||
| word = tuple(token) | |||
| pairs = get_pairs(word) | |||
| if not pairs: | |||
| return token | |||
| while True: | |||
| bigram = min( | |||
| pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) | |||
| if bigram not in self.bpe_ranks: | |||
| break | |||
| first, second = bigram | |||
| new_word = [] | |||
| i = 0 | |||
| while i < len(word): | |||
| try: | |||
| j = word.index(first, i) | |||
| new_word.extend(word[i:j]) | |||
| i = j | |||
| except ValueError:  # 'first' does not occur after position i | |||
| new_word.extend(word[i:]) | |||
| break | |||
| if word[i] == first and i < len(word) - 1 and word[ | |||
| i + 1] == second: | |||
| new_word.append(first + second) | |||
| i += 2 | |||
| else: | |||
| new_word.append(word[i]) | |||
| i += 1 | |||
| new_word = tuple(new_word) | |||
| word = new_word | |||
| if len(word) == 1: | |||
| break | |||
| else: | |||
| pairs = get_pairs(word) | |||
| word = ' '.join(word) | |||
| self.cache[token] = word | |||
| return word | |||
| def tokenize(self, text): | |||
| """ Tokenize a string. """ | |||
| bpe_tokens = [] | |||
| for token in re.findall(self.pat, text): | |||
| if sys.version_info[0] == 2: | |||
| token = ''.join(self.byte_encoder[ord(b)] for b in token) | |||
| else: | |||
| token = ''.join(self.byte_encoder[b] | |||
| for b in token.encode('utf-8')) | |||
| bpe_tokens.extend( | |||
| bpe_token for bpe_token in self.bpe(token).split(' ')) | |||
| return bpe_tokens | |||
| def convert_tokens_to_ids(self, tokens): | |||
| """ Converts a sequence of tokens into ids using the vocab. """ | |||
| ids = [] | |||
| if isinstance(tokens, str) or (sys.version_info[0] == 2 | |||
| and isinstance(tokens, unicode)): | |||
| if tokens in self.special_tokens: | |||
| return self.special_tokens[tokens] | |||
| else: | |||
| return self.encoder.get(tokens, 0) | |||
| for token in tokens: | |||
| if token in self.special_tokens: | |||
| ids.append(self.special_tokens[token]) | |||
| else: | |||
| ids.append(self.encoder.get(token, 0)) | |||
| if len(ids) > self.max_len: | |||
| logger.warning( | |||
| 'Token indices sequence length is longer than the specified maximum ' | |||
| ' sequence length for this OpenAI GPT model ({} > {}). Running this' | |||
| ' sequence through the model will result in indexing errors'. | |||
| format(len(ids), self.max_len)) | |||
| return ids | |||
| def convert_ids_to_tokens(self, ids, skip_special_tokens=False): | |||
| """Converts a sequence of ids in BPE tokens using the vocab.""" | |||
| tokens = [] | |||
| for i in ids: | |||
| if i in self.special_tokens_decoder: | |||
| if not skip_special_tokens: | |||
| tokens.append(self.special_tokens_decoder[i]) | |||
| else: | |||
| tokens.append(self.decoder[i]) | |||
| return tokens | |||
| def encode(self, text): | |||
| return self.convert_tokens_to_ids(self.tokenize(text)) | |||
| def decode(self, tokens): | |||
| text = ''.join([self.decoder[token] for token in tokens]) | |||
| text = bytearray([self.byte_decoder[c] for c in text]).decode( | |||
| 'utf-8', errors=self.errors) | |||
| return text | |||
| def save_vocabulary(self, vocab_path): | |||
| """Save the tokenizer vocabulary and merge files to a directory.""" | |||
| if not os.path.isdir(vocab_path): | |||
| logger.error('Vocabulary path ({}) should be a directory'.format( | |||
| vocab_path)) | |||
| return | |||
| vocab_file = os.path.join(vocab_path, VOCAB_NAME) | |||
| merge_file = os.path.join(vocab_path, MERGES_NAME) | |||
| special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME) | |||
| with open(vocab_file, 'w', encoding='utf-8') as f: | |||
| f.write(json.dumps(self.encoder, ensure_ascii=False)) | |||
| index = 0 | |||
| with open(merge_file, 'w', encoding='utf-8') as writer: | |||
| writer.write(u'#version: 0.2\n') | |||
| for bpe_tokens, token_index in sorted( | |||
| self.bpe_ranks.items(), key=lambda kv: kv[1]): | |||
| if index != token_index: | |||
| logger.warning( | |||
| 'Saving vocabulary to {}: BPE merge indices are not consecutive.' | |||
| ' Please check that the tokenizer is not corrupted!'. | |||
| format(merge_file)) | |||
| index = token_index | |||
| writer.write(' '.join(bpe_tokens) + u'\n') | |||
| index += 1 | |||
| index = len(self.encoder) | |||
| with open(special_tokens_file, 'w', encoding='utf-8') as writer: | |||
| for token, token_index in sorted( | |||
| self.special_tokens.items(), key=lambda kv: kv[1]): | |||
| if index != token_index: | |||
| logger.warning( | |||
| 'Saving special tokens vocabulary to {}: BPE indices are not consecutive.' | |||
| ' Please check that the tokenizer is not corrupted!'. | |||
| format(special_tokens_file)) | |||
| index = token_index | |||
| writer.write(token + u'\n') | |||
| index += 1 | |||
| return vocab_file, merge_file, special_tokens_file | |||
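A round-trip sketch for the byte-level BPE tokenizer above (the vocab/merges paths are placeholders for the standard GPT-2 files): because every byte maps to some unicode symbol, decoding the encoded ids recovers the original string whenever the resulting pieces are present in the vocabulary.

```python
tokenizer = GPT2Tokenizer('gpt2-vocab.json', 'gpt2-merges.txt')

text = 'Hello world!'
tokens = tokenizer.tokenize(text)               # byte-level BPE pieces
ids = tokenizer.convert_tokens_to_ids(tokens)   # vocab lookups
assert tokenizer.decode(ids) == text
```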
| @@ -0,0 +1,408 @@ | |||
| # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Tokenization classes. Provided as is from https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py""" # noqa | |||
| from __future__ import (absolute_import, division, print_function, | |||
| unicode_literals) | |||
| import collections | |||
| import logging | |||
| import os | |||
| import unicodedata | |||
| from io import open | |||
| from .file_utils import cached_path | |||
| logger = logging.getLogger(__name__) | |||
| PRETRAINED_VOCAB_ARCHIVE_MAP = { | |||
| 'bert-base-uncased': | |||
| '.pytorch_pretrained_bert/bert-base-uncased-vocab.txt', | |||
| 'bert-large-uncased': | |||
| '.pytorch_pretrained_bert/bert-large-uncased-vocab.txt', | |||
| 'bert-base-cased': | |||
| '.pytorch_pretrained_bert/bert-base-cased-vocab.txt', | |||
| 'bert-large-cased': | |||
| '.pytorch_pretrained_bert/bert-large-cased-vocab.txt', | |||
| 'bert-base-multilingual-uncased': | |||
| 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt', | |||
| 'bert-base-multilingual-cased': | |||
| 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt', | |||
| 'bert-base-chinese': | |||
| 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt', | |||
| } | |||
| PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { | |||
| 'bert-base-uncased': 512, | |||
| 'bert-large-uncased': 512, | |||
| 'bert-base-cased': 512, | |||
| 'bert-large-cased': 512, | |||
| 'bert-base-multilingual-uncased': 512, | |||
| 'bert-base-multilingual-cased': 512, | |||
| 'bert-base-chinese': 512, | |||
| } | |||
| VOCAB_NAME = 'vocab.txt' | |||
| def load_vocab(vocab_file): | |||
| """Loads a vocabulary file into a dictionary.""" | |||
| vocab = collections.OrderedDict() | |||
| index = 0 | |||
| with open(vocab_file, 'r', encoding='utf-8') as reader: | |||
| while True: | |||
| token = reader.readline() | |||
| if not token: | |||
| break | |||
| token = token.strip() | |||
| vocab[token] = index | |||
| index += 1 | |||
| return vocab | |||
| def whitespace_tokenize(text): | |||
| """Runs basic whitespace cleaning and splitting on a piece of text.""" | |||
| text = text.strip() | |||
| if not text: | |||
| return [] | |||
| tokens = text.split() | |||
| return tokens | |||
| class BertTokenizer(object): | |||
| """Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||
| def __init__(self, | |||
| vocab_file, | |||
| do_lower_case=True, | |||
| max_len=None, | |||
| do_basic_tokenize=True, | |||
| never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): | |||
| """Constructs a BertTokenizer. | |||
| Args: | |||
| vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||
| do_lower_case: Whether to lower case the input | |||
| Only has an effect when do_wordpiece_only=False | |||
| do_basic_tokenize: Whether to do basic tokenization before wordpiece. | |||
| max_len: An artificial maximum length to truncate tokenized sequences to; | |||
| Effective maximum length is always the minimum of this | |||
| value (if specified) and the underlying BERT model's | |||
| sequence length. | |||
| never_split: List of tokens which will never be split during tokenization. | |||
| Only has an effect when do_wordpiece_only=False | |||
| """ | |||
| if not os.path.isfile(vocab_file): | |||
| raise ValueError( | |||
| "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " | |||
| 'model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`' | |||
| .format(vocab_file)) | |||
| self.vocab = load_vocab(vocab_file) | |||
| self.ids_to_tokens = collections.OrderedDict([ | |||
| (ids, tok) for tok, ids in self.vocab.items() | |||
| ]) | |||
| self.do_basic_tokenize = do_basic_tokenize | |||
| if do_basic_tokenize: | |||
| self.basic_tokenizer = BasicTokenizer( | |||
| do_lower_case=do_lower_case, never_split=never_split) | |||
| self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||
| self.max_len = max_len if max_len is not None else int(1e12) | |||
| def tokenize(self, text): | |||
| if self.do_basic_tokenize: | |||
| split_tokens = [] | |||
| for token in self.basic_tokenizer.tokenize(text): | |||
| for sub_token in self.wordpiece_tokenizer.tokenize(token): | |||
| split_tokens.append(sub_token) | |||
| else: | |||
| split_tokens = self.wordpiece_tokenizer.tokenize(text) | |||
| return split_tokens | |||
| def convert_tokens_to_ids(self, tokens): | |||
| """Converts a sequence of tokens into ids using the vocab.""" | |||
| ids = [] | |||
| for token in tokens: | |||
| ids.append(self.vocab[token]) | |||
| if len(ids) > self.max_len: | |||
| logger.warning( | |||
| 'Token indices sequence length is longer than the specified maximum ' | |||
| ' sequence length for this BERT model ({} > {}). Running this' | |||
| ' sequence through BERT will result in indexing errors'.format( | |||
| len(ids), self.max_len)) | |||
| return ids | |||
| def convert_ids_to_tokens(self, ids): | |||
| """Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||
| tokens = [] | |||
| for i in ids: | |||
| tokens.append(self.ids_to_tokens[i]) | |||
| return tokens | |||
| @classmethod | |||
| def from_pretrained(cls, | |||
| pretrained_model_name_or_path, | |||
| cache_dir=None, | |||
| *inputs, | |||
| **kwargs): | |||
| """ | |||
| Instantiate a PreTrainedBertModel from a pre-trained model file. | |||
| Download and cache the pre-trained model file if needed. | |||
| """ | |||
| if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: | |||
| vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[ | |||
| pretrained_model_name_or_path] | |||
| else: | |||
| vocab_file = pretrained_model_name_or_path | |||
| if os.path.isdir(vocab_file): | |||
| vocab_file = os.path.join(vocab_file, VOCAB_NAME) | |||
| # redirect to the cache, if necessary | |||
| try: | |||
| resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) | |||
| except EnvironmentError: | |||
| logger.error( | |||
| "Model name '{}' was not found in model name list ({}). " | |||
| "We assumed '{}' was a path or url but couldn't find any file " | |||
| 'associated to this path or url.'.format( | |||
| pretrained_model_name_or_path, | |||
| ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), | |||
| vocab_file)) | |||
| return None | |||
| if resolved_vocab_file == vocab_file: | |||
| logger.info('loading vocabulary file {}'.format(vocab_file)) | |||
| else: | |||
| logger.info('loading vocabulary file {} from cache at {}'.format( | |||
| vocab_file, resolved_vocab_file)) | |||
| if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: | |||
| # if we're using a pretrained model, ensure the tokenizer wont index sequences longer | |||
| # than the number of positional embeddings | |||
| max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[ | |||
| pretrained_model_name_or_path] | |||
| kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) | |||
| # Instantiate tokenizer. | |||
| tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) | |||
| return tokenizer | |||
| class BasicTokenizer(object): | |||
| """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||
| def __init__(self, | |||
| do_lower_case=True, | |||
| never_split=('[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]')): | |||
| """Constructs a BasicTokenizer. | |||
| Args: | |||
| do_lower_case: Whether to lower case the input. | |||
| """ | |||
| self.do_lower_case = do_lower_case | |||
| self.never_split = never_split | |||
| def tokenize(self, text): | |||
| """Tokenizes a piece of text.""" | |||
| text = self._clean_text(text) | |||
| # This was added on November 1st, 2018 for the multilingual and Chinese | |||
| # models. This is also applied to the English models now, but it doesn't | |||
| # matter since the English models were not trained on any Chinese data | |||
| # and generally don't have any Chinese data in them (there are Chinese | |||
| # characters in the vocabulary because Wikipedia does have some Chinese | |||
| # words in the English Wikipedia.). | |||
| text = self._tokenize_chinese_chars(text) | |||
| orig_tokens = whitespace_tokenize(text) | |||
| split_tokens = [] | |||
| for token in orig_tokens: | |||
| if self.do_lower_case and token not in self.never_split: | |||
| token = token.lower() | |||
| token = self._run_strip_accents(token) | |||
| split_tokens.extend(self._run_split_on_punc(token)) | |||
| output_tokens = whitespace_tokenize(' '.join(split_tokens)) | |||
| return output_tokens | |||
| def _run_strip_accents(self, text): | |||
| """Strips accents from a piece of text.""" | |||
| text = unicodedata.normalize('NFD', text) | |||
| output = [] | |||
| for char in text: | |||
| cat = unicodedata.category(char) | |||
| if cat == 'Mn': | |||
| continue | |||
| output.append(char) | |||
| return ''.join(output) | |||
| def _run_split_on_punc(self, text): | |||
| """Splits punctuation on a piece of text.""" | |||
| if text in self.never_split: | |||
| return [text] | |||
| chars = list(text) | |||
| i = 0 | |||
| start_new_word = True | |||
| output = [] | |||
| while i < len(chars): | |||
| char = chars[i] | |||
| if _is_punctuation(char): | |||
| output.append([char]) | |||
| start_new_word = True | |||
| else: | |||
| if start_new_word: | |||
| output.append([]) | |||
| start_new_word = False | |||
| output[-1].append(char) | |||
| i += 1 | |||
| return [''.join(x) for x in output] | |||
| def _tokenize_chinese_chars(self, text): | |||
| """Adds whitespace around any CJK character.""" | |||
| output = [] | |||
| for char in text: | |||
| cp = ord(char) | |||
| if self._is_chinese_char(cp): | |||
| output.append(' ') | |||
| output.append(char) | |||
| output.append(' ') | |||
| else: | |||
| output.append(char) | |||
| return ''.join(output) | |||
| def _is_chinese_char(self, cp): | |||
| """Checks whether CP is the codepoint of a CJK character.""" | |||
| # This defines a "chinese character" as anything in the CJK Unicode block: | |||
| # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||
| # | |||
| # Note that the CJK Unicode block is NOT all Japanese and Korean characters, | |||
| # despite its name. The modern Korean Hangul alphabet is a different block, | |||
| # as is Japanese Hiragana and Katakana. Those alphabets are used to write | |||
| # space-separated words, so they are not treated specially and handled | |||
| # like all of the other languages. | |||
| if ((cp >= 0x4E00 and cp <= 0x9FFF) or # noqa | |||
| (cp >= 0x3400 and cp <= 0x4DBF) or # noqa | |||
| (cp >= 0x20000 and cp <= 0x2A6DF) or # noqa | |||
| (cp >= 0x2A700 and cp <= 0x2B73F) or # noqa | |||
| (cp >= 0x2B740 and cp <= 0x2B81F) or # noqa | |||
| (cp >= 0x2B820 and cp <= 0x2CEAF) or # noqa | |||
| (cp >= 0xF900 and cp <= 0xFAFF) or # noqa | |||
| (cp >= 0x2F800 and cp <= 0x2FA1F)): # noqa | |||
| return True | |||
| return False | |||
| def _clean_text(self, text): | |||
| """Performs invalid character removal and whitespace cleanup on text.""" | |||
| output = [] | |||
| for char in text: | |||
| cp = ord(char) | |||
| if cp == 0 or cp == 0xfffd or _is_control(char): | |||
| continue | |||
| if _is_whitespace(char): | |||
| output.append(' ') | |||
| else: | |||
| output.append(char) | |||
| return ''.join(output) | |||
| class WordpieceTokenizer(object): | |||
| """Runs WordPiece tokenization.""" | |||
| def __init__(self, vocab, unk_token='[UNK]', max_input_chars_per_word=100): | |||
| self.vocab = vocab | |||
| self.unk_token = unk_token | |||
| self.max_input_chars_per_word = max_input_chars_per_word | |||
| def tokenize(self, text): | |||
| """Tokenizes a piece of text into its word pieces. | |||
| This uses a greedy longest-match-first algorithm to perform tokenization | |||
| using the given vocabulary. | |||
| For example: | |||
| input = "unaffable" | |||
| output = ["un", "##aff", "##able"] | |||
| Args: | |||
| text: A single token or whitespace separated tokens. This should have | |||
| already been passed through `BasicTokenizer`. | |||
| Returns: | |||
| A list of wordpiece tokens. | |||
| """ | |||
| output_tokens = [] | |||
| for token in whitespace_tokenize(text): | |||
| chars = list(token) | |||
| if len(chars) > self.max_input_chars_per_word: | |||
| output_tokens.append(self.unk_token) | |||
| continue | |||
| is_bad = False | |||
| start = 0 | |||
| sub_tokens = [] | |||
| while start < len(chars): | |||
| end = len(chars) | |||
| cur_substr = None | |||
| while start < end: | |||
| substr = ''.join(chars[start:end]) | |||
| if start > 0: | |||
| substr = '##' + substr | |||
| if substr in self.vocab: | |||
| cur_substr = substr | |||
| break | |||
| end -= 1 | |||
| if cur_substr is None: | |||
| is_bad = True | |||
| break | |||
| sub_tokens.append(cur_substr) | |||
| start = end | |||
| if is_bad: | |||
| output_tokens.append(self.unk_token) | |||
| else: | |||
| output_tokens.extend(sub_tokens) | |||
| return output_tokens | |||
| def _is_whitespace(char): | |||
| """Checks whether `chars` is a whitespace character.""" | |||
| # \t, \n, and \r are technically control characters but we treat them | |||
| # as whitespace since they are generally considered as such. | |||
| if char == ' ' or char == '\t' or char == '\n' or char == '\r': | |||
| return True | |||
| cat = unicodedata.category(char) | |||
| if cat == 'Zs': | |||
| return True | |||
| return False | |||
| def _is_control(char): | |||
| """Checks whether `chars` is a control character.""" | |||
| # These are technically control characters but we count them as whitespace | |||
| # characters. | |||
| if char == '\t' or char == '\n' or char == '\r': | |||
| return False | |||
| cat = unicodedata.category(char) | |||
| if cat.startswith('C'): | |||
| return True | |||
| return False | |||
| def _is_punctuation(char): | |||
| """Checks whether `chars` is a punctuation character.""" | |||
| cp = ord(char) | |||
| # We treat all non-letter/number ASCII as punctuation. | |||
| # Characters such as "^", "$", and "`" are not in the Unicode | |||
| # Punctuation class but we treat them as punctuation anyways, for | |||
| # consistency. | |||
| if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) | |||
| or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): | |||
| return True | |||
| cat = unicodedata.category(char) | |||
| if cat.startswith('P'): | |||
| return True | |||
| return False | |||
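An end-to-end sketch of the WordPiece pipeline above (the vocab path is a placeholder for a standard one-wordpiece-per-line `vocab.txt`, which must contain the pieces produced below):

```python
tokenizer = BertTokenizer('vocab.txt', do_lower_case=True)

# BasicTokenizer lower-cases, strips accents, splits punctuation and wraps CJK
# characters in whitespace; WordpieceTokenizer then splits rare words greedily,
# e.g. 'unaffable' -> ['un', '##aff', '##able'] given a suitable vocabulary.
tokens = tokenizer.tokenize('Unaffable weather, 真的!')
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)
print(ids)
```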
| @@ -0,0 +1,20 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from .fp16 import * # noqa | |||
| from .fp16util import (BN_convert_float, FP16Model, clip_grad_norm, | |||
| convert_module, convert_network, | |||
| master_params_to_model_params, | |||
| model_grads_to_master_grads, network_to_half, | |||
| prep_param_lists, to_python_float, tofp16) | |||
| from .loss_scaler import * # noqa | |||
| @@ -0,0 +1,660 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Stable version of apex FP16 Optimizer""" | |||
| import torch | |||
| from torch import nn | |||
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||
| from torch.autograd import Variable | |||
| from torch.nn.parameter import Parameter | |||
| from .fp16util import (clip_grad_norm, master_params_to_model_params, | |||
| model_grads_to_master_grads) | |||
| from .loss_scaler import DynamicLossScaler, LossScaler | |||
| FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) | |||
| HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) | |||
| def conversion_helper(val, conversion): | |||
| """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" | |||
| if not isinstance(val, (tuple, list)): | |||
| return conversion(val) | |||
| rtn = [conversion_helper(v, conversion) for v in val] | |||
| if isinstance(val, tuple): | |||
| rtn = tuple(rtn) | |||
| return rtn | |||
| def fp32_to_fp16(val): | |||
| """Convert fp32 `val` to fp16""" | |||
| def half_conversion(val): | |||
| val_typecheck = val | |||
| if isinstance(val_typecheck, (Parameter, Variable)): | |||
| val_typecheck = val.data | |||
| if isinstance(val_typecheck, FLOAT_TYPES): | |||
| val = val.half() | |||
| return val | |||
| return conversion_helper(val, half_conversion) | |||
| def fp16_to_fp32(val): | |||
| """Convert fp16 `val` to fp32""" | |||
| def float_conversion(val): | |||
| val_typecheck = val | |||
| if isinstance(val_typecheck, (Parameter, Variable)): | |||
| val_typecheck = val.data | |||
| if isinstance(val_typecheck, HALF_TYPES): | |||
| val = val.float() | |||
| return val | |||
| return conversion_helper(val, float_conversion) | |||
| class FP16_Module(nn.Module): | |||
| def __init__(self, module): | |||
| super(FP16_Module, self).__init__() | |||
| self.add_module('module', module.half()) | |||
| def forward(self, *inputs, **kwargs): | |||
| return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) | |||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||
| return self.module.named_parameters(prefix=prefix, recurse=recurse) | |||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
| return self.module.state_dict(destination, prefix, keep_vars) | |||
| def load_state_dict(self, state_dict, strict=True): | |||
| return self.module.load_state_dict(state_dict, strict=strict) | |||
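A brief sketch of the fp32/fp16 boundary handled by `FP16_Module` (assumes a CUDA device; layer sizes are arbitrary): inputs are cast to half before the wrapped module runs, and outputs are cast back to float.

```python
import torch
from torch import nn

model = FP16_Module(nn.Linear(16, 4).cuda())   # wrapped module holds fp16 weights

x = torch.randn(2, 16).cuda()                  # fp32 input
y = model(x)                                   # fp16 inside, fp32 outside
print(model.module.weight.dtype)               # torch.float16
print(y.dtype)                                 # torch.float32
```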
| # TODO: Update overflow check + downscale to use Carl's fused kernel. | |||
| class FP16_Optimizer(object): | |||
| """ | |||
| :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, | |||
| and manage static or dynamic loss scaling and master weights in a manner transparent to the user. | |||
| For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, | |||
| and changing the call to ``backward``. | |||
| Example:: | |||
| model = torch.nn.Linear(D_in, D_out).cuda().half() | |||
| optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||
| # Name the FP16_Optimizer instance to replace the existing optimizer | |||
| # (recommended but not required): | |||
| optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||
| ... | |||
| # loss.backward() becomes: | |||
| optimizer.backward(loss) | |||
| ... | |||
| Example with dynamic loss scaling:: | |||
| ... | |||
| optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) | |||
| # optional arg to control dynamic loss scaling behavior | |||
| # dynamic_loss_args={'scale_window' : 500}) | |||
| # Usually, dynamic_loss_args is not necessary. | |||
| Args: | |||
| init_optimizer (torch.optim.optimizer): Existing optimizer created with the parameters to optimize. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`. | |||
| static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale gradients computed by the model. Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. | |||
| dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. | |||
| dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. | |||
| verbose (bool, optional, default=True): By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check. If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``. ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling. | |||
| ``init_optimizer`` is expected to have been constructed in the ordinary way. | |||
| It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be | |||
| named to replace ``init_optimizer``, for two reasons: | |||
| First, it means that references to the same name | |||
| later in the file will not have to change. | |||
| Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to | |||
| modify ``init_optimizer``. If you do choose a unique name for the new | |||
| :class:`FP16_Optimizer` instance, you should only work with this new instance, | |||
| because the preexisting optimizer might no longer behave as expected. | |||
| ``init_optimizer`` may be any Pytorch optimizer. | |||
| It may contain a mixture of fp16 and fp32 parameters organized into any number of | |||
| ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will | |||
| ingest these ``param_groups`` and remember them. | |||
| Calls to :: | |||
| loss.backward() | |||
| must be replaced with :: | |||
| optimizer.backward(loss) | |||
| because :class:`FP16_Optimizer` requires ownership of the backward pass to implement | |||
| loss scaling and copies to master gradients. | |||
| .. note:: | |||
| Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients | |||
| are downscaled before being applied. This means that adjusting the loss scale, or using | |||
| dynamic loss scaling, should not require retuning the learning rate or any other | |||
| hyperparameters. | |||
| **Advanced options** | |||
| **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. | |||
| See docstring for :attr:`step`. | |||
| **Gradient clipping**: Use :attr:`clip_master_grads`. | |||
| **Multiple losses**: If your model accumulates gradients from multiple losses, | |||
| this can be made more efficient by supplying ``update_master_grads=False`` | |||
| to :attr:`backward`. See docstring for :attr:`backward`. | |||
| **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: | |||
| print(optimizer.loss_scale) | |||
| optimizer.loss_scale = new_loss_scale | |||
| For static loss scaling, manually adjusting the loss scale over time is a reasonable | |||
| thing to do. During later epochs, gradients may become smaller, and a | |||
| higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss | |||
| scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting | |||
| the loss scale is not recommended. | |||
| **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in | |||
| Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer` | |||
| should still work as intended. | |||
| """ # noqa | |||
| def __init__(self, | |||
| init_optimizer, | |||
| static_loss_scale=1.0, | |||
| dynamic_loss_scale=False, | |||
| dynamic_loss_args=None, | |||
| verbose=False): | |||
| if not torch.cuda.is_available(): | |||
| raise SystemError('Cannot use fp16 without CUDA.') | |||
| self.verbose = verbose | |||
| self.optimizer = init_optimizer | |||
| # init_state_dict sets up an alternative way to cast per-param state tensors. | |||
| # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary. | |||
| # init_state_dict = init_optimizer.state_dict() | |||
| self.fp16_groups = [] | |||
| self.fp32_from_fp16_groups = [] | |||
| self.fp32_from_fp32_groups = [] | |||
| for i, param_group in enumerate(self.optimizer.param_groups): | |||
| self.maybe_print( | |||
| 'FP16_Optimizer processing param group {}:'.format(i)) | |||
| fp16_params_this_group = [] | |||
| fp32_params_this_group = [] | |||
| fp32_from_fp16_params_this_group = [] | |||
| for i, param in enumerate(param_group['params']): | |||
| if param.requires_grad: | |||
| if param.type() == 'torch.cuda.HalfTensor': | |||
| self.maybe_print( | |||
| 'FP16_Optimizer received torch.cuda.HalfTensor with {}' | |||
| .format(param.size())) | |||
| fp16_params_this_group.append(param) | |||
| master_param = param.detach().clone().float() | |||
| master_param.requires_grad = True | |||
| # Copy the model parallel flag. | |||
| master_param.model_parallel = param.model_parallel | |||
| param_group['params'][i] = master_param | |||
| fp32_from_fp16_params_this_group.append(master_param) | |||
| # Reset existing state dict key to the new master param. | |||
| # We still need to recast per-param state tensors, if any, to FP32. | |||
| if param in self.optimizer.state: | |||
| self.optimizer.state[ | |||
| master_param] = self.optimizer.state.pop(param) | |||
| elif param.type() == 'torch.cuda.FloatTensor': | |||
| self.maybe_print( | |||
| 'FP16_Optimizer received torch.cuda.FloatTensor with {}' | |||
| .format(param.size())) | |||
| fp32_params_this_group.append(param) | |||
| param_group['params'][i] = param | |||
| else: | |||
| raise TypeError( | |||
| 'Wrapped parameters must be either ' | |||
| 'torch.cuda.FloatTensor or torch.cuda.HalfTensor. ' | |||
| 'Received {}'.format(param.type())) | |||
| self.fp16_groups.append(fp16_params_this_group) | |||
| self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) | |||
| self.fp32_from_fp32_groups.append(fp32_params_this_group) | |||
| # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors | |||
| self.optimizer.load_state_dict(self.optimizer.state_dict()) | |||
| # alternative way to cast per-param state tensors: | |||
| # self.optimizer.load_state_dict(init_state_dict) | |||
| if dynamic_loss_scale: | |||
| self.dynamic_loss_scale = True | |||
| if dynamic_loss_args is not None: | |||
| self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) | |||
| else: | |||
| self.loss_scaler = DynamicLossScaler() | |||
| else: | |||
| self.dynamic_loss_scale = False | |||
| self.loss_scaler = LossScaler(static_loss_scale) | |||
| self.overflow = False | |||
| self.first_closure_call_this_step = True | |||
| self.clip_grad_norm = clip_grad_norm | |||
| def maybe_print(self, msg): | |||
| if self.verbose: | |||
| print(msg) | |||
| def __getstate__(self): | |||
| raise RuntimeError( | |||
| 'FP16_Optimizer should be serialized using state_dict().') | |||
| def __setstate__(self, state): | |||
| raise RuntimeError( | |||
| 'FP16_Optimizer should be deserialized using load_state_dict().') | |||
| def zero_grad(self, set_grads_to_None=False): | |||
| """ | |||
| Zero fp32 and fp16 parameter grads. | |||
| """ | |||
| # In principle, only the .grad attributes of the model params need to be zeroed, | |||
| # because gradients are copied into the FP32 master params. However, we zero | |||
| # all gradients owned by the optimizer, just to be safe: | |||
| for group in self.optimizer.param_groups: | |||
| for p in group['params']: | |||
| if set_grads_to_None: | |||
| p.grad = None | |||
| else: | |||
| if p.grad is not None: | |||
| p.grad.detach_() | |||
| p.grad.zero_() | |||
| # Zero fp16 gradients owned by the model: | |||
| for fp16_group in self.fp16_groups: | |||
| for param in fp16_group: | |||
| if set_grads_to_None: | |||
| param.grad = None | |||
| else: | |||
| if param.grad is not None: | |||
| param.grad.detach_( | |||
| ) # as in torch.optim.optimizer.zero_grad() | |||
| param.grad.zero_() | |||
| def _check_overflow(self): | |||
| params = [] | |||
| for group in self.fp16_groups: | |||
| for param in group: | |||
| params.append(param) | |||
| for group in self.fp32_from_fp32_groups: | |||
| for param in group: | |||
| params.append(param) | |||
| self.overflow = self.loss_scaler.has_overflow(params) | |||
| def _update_scale(self, has_overflow=False): | |||
| self.loss_scaler.update_scale(has_overflow) | |||
| def _master_params_to_model_params(self): | |||
| for fp16_group, fp32_from_fp16_group in zip( | |||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||
| master_params_to_model_params(fp16_group, fp32_from_fp16_group) | |||
| def _model_params_to_master_params(self): | |||
| for fp16_group, fp32_from_fp16_group in zip( | |||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||
| master_params_to_model_params(fp32_from_fp16_group, fp16_group) | |||
| # To consider: Integrate distributed with this wrapper by registering a hook on each variable | |||
| # that does the overflow check, gradient copy + downscale, and fp32 allreduce in a different stream. | |||
| def _model_grads_to_master_grads(self): | |||
| for fp16_group, fp32_from_fp16_group in zip( | |||
| self.fp16_groups, self.fp32_from_fp16_groups): | |||
| model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) | |||
| def _downscale_master(self): | |||
| if self.loss_scale != 1.0: | |||
| for group in self.optimizer.param_groups: | |||
| for param in group['params']: | |||
| if param.grad is not None: | |||
| param.grad.data.mul_(1. / self.loss_scale) | |||
| def clip_master_grads(self, max_norm, norm_type=2): | |||
| """ | |||
| Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. | |||
| Args: | |||
| max_norm (float or int): max norm of the gradients | |||
| norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for | |||
| infinity norm. | |||
| Returns: | |||
| Total norm of the current fp32 gradients (viewed as a single vector). | |||
| .. warning:: | |||
| Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``). | |||
| """ # noqa | |||
| if not self.overflow: | |||
| fp32_params = [] | |||
| for param_group in self.optimizer.param_groups: | |||
| for param in param_group['params']: | |||
| fp32_params.append(param) | |||
| return self.clip_grad_norm(fp32_params, max_norm, norm_type) | |||
| else: | |||
| return -1 | |||
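| # Illustrative usage sketch (assumes `optimizer` is an FP16_Optimizer and | |||
| # `loss` has already been computed; not part of the original file): | |||
| # | |||
| #     optimizer.zero_grad() | |||
| #     optimizer.backward(loss) | |||
| #     grad_norm = optimizer.clip_master_grads(max_norm=1.0) | |||
| #     if grad_norm == -1: | |||
| #         pass  # fp16 grads overflowed; step() will skip this update | |||
| #     optimizer.step() | |||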
| def state_dict(self): | |||
| """ | |||
| Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. | |||
| This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict | |||
| of the contained Pytorch optimizer. | |||
| Example:: | |||
| checkpoint = {} | |||
| checkpoint['model'] = model.state_dict() | |||
| checkpoint['optimizer'] = optimizer.state_dict() | |||
| torch.save(checkpoint, "saved.pth") | |||
| """ | |||
| state_dict = {} | |||
| state_dict['loss_scaler'] = self.loss_scaler | |||
| state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale | |||
| state_dict['overflow'] = self.overflow | |||
| state_dict[ | |||
| 'first_closure_call_this_step'] = self.first_closure_call_this_step | |||
| state_dict['optimizer_state_dict'] = self.optimizer.state_dict() | |||
| state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups | |||
| return state_dict | |||
| def load_state_dict(self, state_dict): | |||
| """ | |||
| Loads a state_dict created by an earlier call to state_dict(). | |||
| If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, | |||
| whose parameters in turn came from ``model``, it is expected that the user | |||
| will call ``model.load_state_dict()`` before | |||
| ``fp16_optimizer_instance.load_state_dict()`` is called. | |||
| Example:: | |||
| model = torch.nn.Linear(D_in, D_out).cuda().half() | |||
| optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) | |||
| optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) | |||
| ... | |||
| checkpoint = torch.load("saved.pth") | |||
| model.load_state_dict(checkpoint['model']) | |||
| optimizer.load_state_dict(checkpoint['optimizer']) | |||
| """ | |||
| # I think it should actually be ok to reload the optimizer before the model. | |||
| self.loss_scaler = state_dict['loss_scaler'] | |||
| self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] | |||
| self.overflow = state_dict['overflow'] | |||
| self.first_closure_call_this_step = state_dict[ | |||
| 'first_closure_call_this_step'] | |||
| self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) | |||
| # At this point, the optimizer's references to the model's fp32 parameters are up to date. | |||
| # The optimizer's hyperparameters and internal buffers are also up to date. | |||
| # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still | |||
| # out of date. There are two options. | |||
| # 1: Refresh the master params from the model's fp16 params. | |||
| # This requires less storage but incurs precision loss. | |||
| # 2: Save and restore the fp32 master copies separately. | |||
| # We choose option 2. | |||
| # | |||
| # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device | |||
| # of their associated parameters, because it's possible those buffers might not exist yet in | |||
| # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been | |||
| # constructed in the same way as the one whose state_dict we are loading, the same master params | |||
| # are guaranteed to exist, so we can just copy_() from the saved master params. | |||
| for current_group, saved_group in zip(self.fp32_from_fp16_groups, | |||
| state_dict['fp32_from_fp16']): | |||
| for current, saved in zip(current_group, saved_group): | |||
| current.data.copy_(saved.data) | |||
| def step(self, closure=None): # could add clip option. | |||
| """ | |||
| If no closure is supplied, :attr:`step` should be called after | |||
| ``fp16_optimizer_obj.backward(loss)``. | |||
| :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to | |||
| :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params | |||
| originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run | |||
| another forward pass using their model. | |||
| If a closure is supplied, :attr:`step` may be called without a prior call to | |||
| :attr:`backward(loss)`. | |||
| This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. | |||
| However, the user should take care that any ``loss.backward()`` call within the closure | |||
| has been replaced by ``fp16_optimizer_obj.backward(loss)``. | |||
| Args: | |||
| closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. | |||
| Example with closure:: | |||
| # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an | |||
| # existing pytorch optimizer. | |||
| for input, target in dataset: | |||
| def closure(): | |||
| optimizer.zero_grad() | |||
| output = model(input) | |||
| loss = loss_fn(output, target) | |||
| # loss.backward() becomes: | |||
| optimizer.backward(loss) | |||
| return loss | |||
| optimizer.step(closure) | |||
| .. warning:: | |||
| Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling. | |||
| .. _`ordinary Pytorch optimizer use`: | |||
| http://pytorch.org/docs/master/optim.html#optimizer-step-closure | |||
| """ # noqa | |||
| scale = self.loss_scaler.loss_scale | |||
| self._update_scale(self.overflow) | |||
| if self.overflow: | |||
| self.maybe_print( | |||
| 'OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}' | |||
| .format(scale, self.loss_scale)) | |||
| return | |||
| if closure is not None: | |||
| retval = self._step_with_closure(closure) | |||
| else: | |||
| retval = self.optimizer.step() | |||
| self._master_params_to_model_params() | |||
| return retval | |||
| def _step_with_closure(self, closure): | |||
| def wrapped_closure(): | |||
| # helpful for debugging | |||
| # print("Calling wrapped_closure, first_closure_call_this_step = {}" | |||
| # .format(self.first_closure_call_this_step)) | |||
| if self.first_closure_call_this_step: | |||
| # We expect that the fp16 params are initially fresh on entering self.step(), | |||
| # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() | |||
| # is called within self.optimizer.step(). | |||
| self.first_closure_call_this_step = False | |||
| else: | |||
| # If self.optimizer.step() internally calls wrapped_closure more than once, | |||
| # it may update the fp32 params after each call. However, self.optimizer | |||
| # doesn't know about the fp16 params at all. If the fp32 params get updated, | |||
| # we can't rely on self.optimizer to refresh the fp16 params. We need | |||
| # to handle that manually: | |||
| self._master_params_to_model_params() | |||
| # Our API expects the user to give us ownership of the backward() call by | |||
| # replacing all calls to loss.backward() with optimizer.backward(loss). | |||
| # This requirement holds whether or not the call to backward() is made within a closure. | |||
| # If the user is properly calling optimizer.backward(loss) within "closure," | |||
| # calling closure() here will give the fp32 master params fresh gradients | |||
| # for the optimizer to play with, so all wrapped_closure needs to do is call | |||
| # closure() and return the loss. | |||
| temp_loss = closure() | |||
| while (self.overflow): | |||
| scale = self.loss_scaler.loss_scale | |||
| self._update_scale(self.overflow) | |||
| self.maybe_print( | |||
| 'OVERFLOW within closure! Skipping step. Attempted loss scale: {}, ' | |||
| 'reducing to {}'.format(scale, self.loss_scale)) | |||
| temp_loss = closure() | |||
| return temp_loss | |||
| retval = self.optimizer.step(wrapped_closure) | |||
| self.first_closure_call_this_step = True | |||
| return retval | |||
| def backward(self, loss, update_master_grads=True, retain_graph=False): | |||
| """ | |||
| :attr:`backward` performs the following conceptual steps: | |||
| 1. fp32_loss = loss.float() (see first Note below) | |||
| 2. scaled_loss = fp32_loss*loss_scale | |||
| 3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). | |||
| 4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32. | |||
| 5. Finally, master grads are divided by loss_scale. | |||
| In this way, after :attr:`backward`, the master params have fresh gradients, | |||
| and :attr:`step` may be called. | |||
| .. note:: | |||
| :attr:`backward` internally converts the loss to fp32 before applying the loss scale. | |||
| This provides some additional safety against overflow if the user has supplied an | |||
| fp16 loss value. | |||
| However, for maximum overflow safety, the user should | |||
| compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to | |||
| :attr:`backward`. | |||
| .. warning:: | |||
| The gradients found in a model's leaves after the call to | |||
| :attr:`backward` should not be regarded as valid in general, | |||
| because it's possible | |||
| they have been scaled (and in the case of dynamic loss scaling, | |||
| the scale factor may change over time). | |||
| If the user wants to inspect gradients after a call to :attr:`backward`, | |||
| only the master gradients should be regarded as valid. These can be retrieved via | |||
| :attr:`inspect_master_grad_data()`. | |||
| Args: | |||
| loss: The loss output by the user's model. loss may be either float or half (but see first Note above). | |||
| update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. | |||
| retain_graph (bool, optional, default=False): Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``. If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below). | |||
| Example:: | |||
| # Ordinary operation: | |||
| optimizer.backward(loss) | |||
| # Naive operation with multiple losses (technically valid, but less efficient): | |||
| # fp32 grads will be correct after the second call, but | |||
| # the first call incurs an unnecessary fp16->fp32 grad copy. | |||
| optimizer.backward(loss1) | |||
| optimizer.backward(loss2) | |||
| # More efficient way to handle multiple losses: | |||
| # The fp16->fp32 grad copy is delayed until fp16 grads from all | |||
| # losses have been accumulated. | |||
| optimizer.backward(loss1, update_master_grads=False) | |||
| optimizer.backward(loss2, update_master_grads=False) | |||
| optimizer.update_master_grads() | |||
| """ # noqa | |||
| # To consider: try multiple backward passes using retain_grad=True to find | |||
| # a loss scale that works. After you find a loss scale that works, do a final dummy | |||
| # backward pass with retain_graph=False to tear down the graph. Doing this would avoid | |||
| # discarding the iteration, but probably wouldn't improve overall efficiency. | |||
| self.loss_scaler.backward(loss.float(), retain_graph=retain_graph) | |||
| if update_master_grads: | |||
| self.update_master_grads() | |||
| def update_master_grads(self): | |||
| """ | |||
| Copy the ``.grad`` attribute from stored references to fp16 parameters to | |||
| the ``.grad`` attribute of the fp32 master parameters that are directly | |||
| updated by the optimizer. :attr:`update_master_grads` only needs to be called if | |||
| ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. | |||
| """ # noqa | |||
| if self.dynamic_loss_scale: | |||
| self._check_overflow() | |||
| if self.overflow: return # noqa | |||
| self._model_grads_to_master_grads() | |||
| self._downscale_master() | |||
| def inspect_master_grad_data(self): | |||
| """ | |||
| When running with :class:`FP16_Optimizer`, | |||
| ``.grad`` attributes of a model's fp16 leaves should not be | |||
| regarded as truthful, because they might be scaled. | |||
| After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, | |||
| the fp32 master params' ``.grad`` | |||
| attributes will contain valid gradients properly divided by the loss scale. However, | |||
| because :class:`FP16_Optimizer` flattens some parameters, accessing them may be | |||
| nonintuitive. :attr:`inspect_master_grad_data` | |||
| allows those gradients to be viewed with shapes corresponding to their associated model leaves. | |||
| Returns: | |||
| List of lists (one list for each parameter group). The list for each parameter group | |||
| is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. | |||
| """ | |||
| if self.overflow: | |||
| print( | |||
| 'Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. ' | |||
| 'Gradients are currently invalid (may be inf, nan, or stale). Returning None.' | |||
| ) | |||
| return None | |||
| else: | |||
| # The optimizer owns only references to master params. | |||
| master_grads_data = [] | |||
| for param_group in self.optimizer.param_groups: | |||
| master_grads_this_group = [] | |||
| for param in param_group['params']: | |||
| if param.grad is not None: | |||
| master_grads_this_group.append(param.grad.data) | |||
| else: | |||
| master_grads_this_group.append(None) | |||
| master_grads_data.append(master_grads_this_group) | |||
| return master_grads_data | |||
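| # Illustrative sketch (not part of the original file): after a successful | |||
| # optimizer.backward(loss), the unscaled fp32 gradients can be inspected | |||
| # per parameter group: | |||
| # | |||
| #     master_grads = optimizer.inspect_master_grad_data() | |||
| #     if master_grads is not None: | |||
| #         for group_grads in master_grads: | |||
| #             for grad in group_grads: | |||
| #                 if grad is not None: | |||
| #                     print(grad.norm()) | |||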
| # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale" | |||
| def _get_loss_scale(self): | |||
| return self.loss_scaler.loss_scale | |||
| def _set_loss_scale(self, value): | |||
| self.loss_scaler.cur_scale = value | |||
| loss_scale = property(_get_loss_scale, _set_loss_scale) | |||
| # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state" | |||
| def _get_state(self): | |||
| return self.optimizer.state | |||
| def _set_state(self, value): | |||
| self.optimizer.state = value | |||
| state = property(_get_state, _set_state) | |||
| # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups" | |||
| # (for example, to adjust the learning rate) | |||
| def _get_param_groups(self): | |||
| return self.optimizer.param_groups | |||
| def _set_param_groups(self, value): | |||
| self.optimizer.param_groups = value | |||
| param_groups = property(_get_param_groups, _set_param_groups) | |||
| @@ -0,0 +1,220 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| import torch.nn as nn | |||
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||
| from torch.autograd import Variable | |||
| from modelscope.models.nlp.mglm import mpu | |||
| class tofp16(nn.Module): | |||
| """ | |||
| Utility module that implements:: | |||
| def forward(self, input): | |||
| return input.half() | |||
| """ | |||
| def __init__(self): | |||
| super(tofp16, self).__init__() | |||
| def forward(self, input): | |||
| return input.half() | |||
| def BN_convert_float(module): | |||
| """ | |||
| Utility function for network_to_half(). | |||
| Retained for legacy purposes. | |||
| """ | |||
| if isinstance( | |||
| module, | |||
| torch.nn.modules.batchnorm._BatchNorm) and module.affine is True: | |||
| module.float() | |||
| for child in module.children(): | |||
| BN_convert_float(child) | |||
| return module | |||
| def network_to_half(network): | |||
| """ | |||
| Convert model to half precision in a batchnorm-safe way. | |||
| Retained for legacy purposes. It is recommended to use FP16Model. | |||
| """ | |||
| return nn.Sequential(tofp16(), BN_convert_float(network.half())) | |||
| def convert_module(module, dtype): | |||
| """ | |||
| Converts a module's immediate parameters and buffers to dtype. | |||
| """ | |||
| for param in module.parameters(recurse=False): | |||
| if param is not None: | |||
| if param.data.dtype.is_floating_point: | |||
| param.data = param.data.to(dtype=dtype) | |||
| if param._grad is not None and param._grad.data.dtype.is_floating_point: | |||
| param._grad.data = param._grad.data.to(dtype=dtype) | |||
| for buf in module.buffers(recurse=False): | |||
| if buf is not None and buf.data.dtype.is_floating_point: | |||
| buf.data = buf.data.to(dtype=dtype) | |||
| def convert_network(network, dtype): | |||
| """ | |||
| Converts a network's parameters and buffers to dtype. | |||
| """ | |||
| for module in network.modules(): | |||
| if isinstance(module, torch.nn.modules.batchnorm._BatchNorm | |||
| ) and module.affine is True: | |||
| continue | |||
| convert_module(module, dtype) | |||
| return network | |||
| class FP16Model(nn.Module): | |||
| """ | |||
| Convert model to half precision in a batchnorm-safe way. | |||
| """ | |||
| def __init__(self, network): | |||
| super(FP16Model, self).__init__() | |||
| self.network = convert_network(network, dtype=torch.half) | |||
| def forward(self, *inputs): | |||
| inputs = tuple(t.half() for t in inputs) | |||
| return self.network(*inputs) | |||
| def backwards_debug_hook(grad): | |||
| raise RuntimeError( | |||
| 'master_params received a gradient in the backward pass!') | |||
| def prep_param_lists(model, flat_master=False): | |||
| """ | |||
| Creates a list of FP32 master parameters for a given model, as in | |||
| `Training Neural Networks with Mixed Precision: Real Examples`_. | |||
| Args: | |||
| model (torch.nn.Module): Existing Pytorch model | |||
| flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. | |||
| Returns: | |||
| A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. | |||
| Example:: | |||
| model_params, master_params = prep_param_lists(model) | |||
| .. warning:: | |||
| Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. | |||
| .. _`Training Neural Networks with Mixed Precision: Real Examples`: | |||
| http://on-demand.gputechconf.com/gtc/2018/video/S81012/ | |||
| """ # noqa | |||
| model_params = [ | |||
| param for param in model.parameters() if param.requires_grad | |||
| ] | |||
| if flat_master: | |||
| # Give the user some more useful error messages | |||
| try: | |||
| # flatten_dense_tensors returns a contiguous flat array. | |||
| # http://pytorch.org/docs/master/_modules/torch/_utils.html | |||
| master_params = _flatten_dense_tensors( | |||
| [param.data for param in model_params]).float() | |||
| except: # noqa | |||
| print( | |||
| 'Error in prep_param_lists: model may contain a mixture of parameters ' | |||
| 'of different types. Use flat_master=False, or use FP16_Optimizer.' | |||
| ) | |||
| raise | |||
| master_params = torch.nn.Parameter(master_params) | |||
| master_params.requires_grad = True | |||
| # master_params.register_hook(backwards_debug_hook) | |||
| if master_params.grad is None: | |||
| master_params.grad = master_params.new(*master_params.size()) | |||
| return model_params, [master_params] | |||
| else: | |||
| master_params = [ | |||
| param.clone().float().detach() for param in model_params | |||
| ] | |||
| for param in master_params: | |||
| param.requires_grad = True | |||
| return model_params, master_params | |||
| def model_grads_to_master_grads(model_params, | |||
| master_params, | |||
| flat_master=False): | |||
| """ | |||
| Copy model gradients to master gradients. | |||
| Args: | |||
| model_params: List of model parameters created by :func:`prep_param_lists`. | |||
| master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. | |||
| """ # noqa | |||
| if flat_master: | |||
| # The flattening may incur one more deep copy than is necessary. | |||
| master_params[0].grad.data.copy_( | |||
| _flatten_dense_tensors([p.grad.data for p in model_params])) | |||
| else: | |||
| for model, master in zip(model_params, master_params): | |||
| if model.grad is not None: | |||
| if master.grad is None: | |||
| master.grad = Variable( | |||
| master.data.new(*master.data.size())) | |||
| master.grad.data.copy_(model.grad.data) | |||
| else: | |||
| master.grad = None | |||
| def master_params_to_model_params(model_params, | |||
| master_params, | |||
| flat_master=False): | |||
| """ | |||
| Copy master parameters to model parameters. | |||
| Args: | |||
| model_params: List of model parameters created by :func:`prep_param_lists`. | |||
| master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. | |||
| """ # noqa | |||
| if flat_master: | |||
| for model, master in zip( | |||
| model_params, | |||
| _unflatten_dense_tensors(master_params[0].data, model_params)): | |||
| model.data.copy_(master) | |||
| else: | |||
| for model, master in zip(model_params, master_params): | |||
| model.data.copy_(master.data) | |||
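| # Illustrative end-to-end sketch of the manual master-parameter workflow | |||
| # these helpers support (hypothetical model and static loss scale; not part | |||
| # of the original file): | |||
| # | |||
| #     model = torch.nn.Linear(16, 4).cuda().half() | |||
| #     model_params, master_params = prep_param_lists(model) | |||
| #     optimizer = torch.optim.SGD(master_params, lr=1e-3) | |||
| # | |||
| #     loss = model(torch.randn(8, 16).cuda().half()).float().sum() | |||
| #     (loss * 128.0).backward()              # static loss scale of 128 | |||
| #     model_grads_to_master_grads(model_params, master_params) | |||
| #     for p in master_params: | |||
| #         p.grad.data.mul_(1.0 / 128.0)      # undo the loss scale | |||
| #     optimizer.step() | |||
| #     master_params_to_model_params(model_params, master_params) | |||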
| # Backward compatibility fixes | |||
| def to_python_float(t): | |||
| if hasattr(t, 'item'): | |||
| return t.item() | |||
| else: | |||
| return t[0] | |||
| TORCH_MAJOR = int(torch.__version__.split('.')[0]) | |||
| TORCH_MINOR = int(torch.__version__.split('.')[1]) | |||
| clip_grad_norm = mpu.clip_grad_norm | |||
| @@ -0,0 +1,245 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| from modelscope.models.nlp.mglm import mpu | |||
| # item() is a recent addition, so this helps with backward compatibility. | |||
| def to_python_float(t): | |||
| if hasattr(t, 'item'): | |||
| return t.item() | |||
| else: | |||
| return t[0] | |||
| class LossScaler: | |||
| """ | |||
| Class that manages a static loss scale. This class is intended to interact with | |||
| :class:`FP16_Optimizer`, and should not be directly manipulated by the user. | |||
| Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to | |||
| :class:`FP16_Optimizer`'s constructor. | |||
| Args: | |||
| scale (float, optional, default=1.0): The loss scale. | |||
| """ | |||
| def __init__(self, scale=1): | |||
| self.cur_scale = scale | |||
| # `params` is a list / generator of torch.Variable | |||
| def has_overflow(self, params): | |||
| return False | |||
| # `x` is a torch.Tensor | |||
| def _has_inf_or_nan(x): | |||
| return False | |||
| def update_scale(self, overflow): | |||
| pass | |||
| @property | |||
| def loss_scale(self): | |||
| return self.cur_scale | |||
| def scale_gradient(self, module, grad_in, grad_out): | |||
| return tuple(self.loss_scale * g for g in grad_in) | |||
| def backward(self, loss, retain_graph=False): | |||
| scaled_loss = loss * self.loss_scale | |||
| scaled_loss.backward(retain_graph=retain_graph) | |||
| class DynamicLossScaler: | |||
| """ | |||
| Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` | |||
| indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of | |||
| :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` | |||
| operates, because the default options can be changed using the | |||
| ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. | |||
| Loss scaling is designed to combat the problem of underflowing gradients encountered late in | |||
| the training of fp16 networks. Dynamic loss scaling begins by attempting a very high loss | |||
| scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are | |||
| encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has | |||
| occurred. | |||
| :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, | |||
| and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. | |||
| If a certain number of iterations occur without overflowing gradients detected, | |||
| :class:`DynamicLossScaler` increases the loss scale once more. | |||
| In this way :class:`DynamicLossScaler` attempts to "ride the edge" of | |||
| always using the highest loss scale possible without incurring overflow. | |||
| Args: | |||
| init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`. | |||
| scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. | |||
| scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. | |||
| """ # noqa | |||
| def __init__(self, | |||
| init_scale=2**32, | |||
| scale_factor=2., | |||
| scale_window=1000, | |||
| min_scale=1, | |||
| delayed_shift=1, | |||
| consecutive_hysteresis=False): | |||
| self.cur_scale = init_scale | |||
| self.cur_iter = 0 | |||
| self.last_overflow_iter = -1 | |||
| self.scale_factor = scale_factor | |||
| self.scale_window = scale_window | |||
| self.min_scale = min_scale | |||
| self.delayed_shift = delayed_shift | |||
| self.cur_hysteresis = delayed_shift | |||
| self.consecutive_hysteresis = consecutive_hysteresis | |||
| # `params` is a list / generator of torch.Variable | |||
| def has_overflow_serial(self, params): | |||
| for p in params: | |||
| if p.grad is not None and DynamicLossScaler._has_inf_or_nan( | |||
| p.grad.data): | |||
| return True | |||
| return False | |||
| def has_overflow(self, params): | |||
| overflow = self.has_overflow_serial(params) | |||
| # Since each model parallel GPU carries only part of the model, | |||
| # make sure overflow flag is synced across all the model parallel GPUs | |||
| overflow_gpu = torch.cuda.ByteTensor([overflow]) | |||
| torch.distributed.all_reduce( | |||
| overflow_gpu, | |||
| op=torch.distributed.ReduceOp.MAX, | |||
| group=mpu.get_model_parallel_group()) | |||
| overflow = overflow_gpu[0].item() | |||
| return bool(overflow) | |||
| # `x` is a torch.Tensor | |||
| def _has_inf_or_nan(x): | |||
| try: | |||
| # if x is half, the .float() incurs an additional deep copy, but it's necessary if | |||
| # Pytorch's .sum() creates a one-element tensor of the same type as x | |||
| # (which is true for some recent version of pytorch). | |||
| cpu_sum = float(x.float().sum()) | |||
| # More efficient version that can be used if .sum() returns a Python scalar | |||
| # cpu_sum = float(x.sum()) | |||
| except RuntimeError as instance: | |||
| # We want to check if inst is actually an overflow exception. | |||
| # RuntimeError could come from a different error. | |||
| # If so, we still want the exception to propagate. | |||
| if 'value cannot be converted' not in instance.args[0]: | |||
| raise | |||
| return True | |||
| else: | |||
| if cpu_sum == float( | |||
| 'inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: | |||
| return True | |||
| return False | |||
| # `overflow` is boolean indicating whether the gradient overflowed | |||
| def update_scale(self, overflow): | |||
| if not hasattr(self, 'min_scale'): | |||
| self.min_scale = 1 | |||
| if not hasattr(self, 'delayed_shift'): | |||
| self.delayed_shift = 1 | |||
| if not hasattr(self, 'cur_hysteresis'): | |||
| self.cur_hysteresis = 1 | |||
| if not hasattr(self, 'consecutive_hysteresis'): | |||
| self.consecutive_hysteresis = True | |||
| if overflow: | |||
| # self.cur_scale /= self.scale_factor | |||
| if self.delayed_shift == 1 or self.cur_hysteresis == 1: | |||
| self.cur_scale = max(self.cur_scale / self.scale_factor, | |||
| self.min_scale) | |||
| else: | |||
| self.cur_hysteresis -= 1 | |||
| self.last_overflow_iter = self.cur_iter | |||
| else: | |||
| if self.consecutive_hysteresis: | |||
| self.cur_hysteresis = self.delayed_shift | |||
| if (self.cur_iter | |||
| - self.last_overflow_iter) % self.scale_window == 0: | |||
| if not self.consecutive_hysteresis: | |||
| self.cur_hysteresis = self.delayed_shift | |||
| self.cur_scale *= self.scale_factor | |||
| self.cur_iter += 1 | |||
| @property | |||
| def loss_scale(self): | |||
| return self.cur_scale | |||
| def scale_gradient(self, module, grad_in, grad_out): | |||
| return tuple(self.loss_scale * g for g in grad_in) | |||
| def backward(self, loss, retain_graph=False): | |||
| scaled_loss = loss * self.loss_scale | |||
| scaled_loss.backward(retain_graph=retain_graph) | |||
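| # Illustrative sketch of how update_scale() evolves the scale (scale_factor=2 | |||
| # and a small scale_window for brevity; not part of the original file): | |||
| # | |||
| #     scaler = DynamicLossScaler(init_scale=2**15, scale_window=2) | |||
| #     scaler.update_scale(True)     # overflow -> scale halves to 2**14 | |||
| #     scaler.update_scale(False)    # 1 clean iteration, scale unchanged | |||
| #     scaler.update_scale(False)    # 2 clean iterations -> doubles to 2**15 | |||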
| ############################################################## | |||
| # Example usage below here -- assuming it's in a separate file | |||
| ############################################################## | |||
| """ | |||
| TO-DO separate out into an example. | |||
| if __name__ == "__main__": | |||
| import torch | |||
| from torch.autograd import Variable | |||
| from dynamic_loss_scaler import DynamicLossScaler | |||
| # N is batch size; D_in is input dimension; | |||
| # H is hidden dimension; D_out is output dimension. | |||
| N, D_in, H, D_out = 64, 1000, 100, 10 | |||
| # Create random Tensors to hold inputs and outputs, and wrap them in Variables. | |||
| x = Variable(torch.randn(N, D_in), requires_grad=False) | |||
| y = Variable(torch.randn(N, D_out), requires_grad=False) | |||
| w1 = Variable(torch.randn(D_in, H), requires_grad=True) | |||
| w2 = Variable(torch.randn(H, D_out), requires_grad=True) | |||
| parameters = [w1, w2] | |||
| learning_rate = 1e-6 | |||
| optimizer = torch.optim.SGD(parameters, lr=learning_rate) | |||
| loss_scaler = DynamicLossScaler() | |||
| for t in range(500): | |||
| y_pred = x.mm(w1).clamp(min=0).mm(w2) | |||
| loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale | |||
| print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) | |||
| print('Iter {} scaled loss: {}'.format(t, loss.data[0])) | |||
| print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) | |||
| # Run backprop | |||
| optimizer.zero_grad() | |||
| loss.backward() | |||
| # Check for overflow | |||
| has_overflow = DynamicLossScaler.has_overflow(parameters) | |||
| # If no overflow, unscale grad and update as usual | |||
| if not has_overflow: | |||
| for param in parameters: | |||
| param.grad.data.mul_(1. / loss_scaler.loss_scale) | |||
| optimizer.step() | |||
| # Otherwise, don't do anything -- ie, skip iteration | |||
| else: | |||
| print('OVERFLOW!') | |||
| # Update loss scale for next iteration | |||
| loss_scaler.update_scale(has_overflow) | |||
| """ | |||
| @@ -0,0 +1,483 @@ | |||
| # Copyright 2020 The HuggingFace Inc. team | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from abc import ABC, abstractmethod | |||
| from collections import UserDict | |||
| from typing import Iterable, List, Optional, Tuple | |||
| import torch | |||
| PROCESS_INPUTS_DOCSTRING = r""" | |||
| Args: | |||
| input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): | |||
| Indices of input sequence tokens in the vocabulary. | |||
| Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See | |||
| :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
| details. | |||
| `What are input IDs? <../glossary.html#input-ids>`__ | |||
| next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`): | |||
| Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses. | |||
| next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): | |||
| :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses. | |||
| next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`): | |||
| Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond. | |||
| pad_token_id (:obj:`int`, `optional`): | |||
| The id of the `padding` token. | |||
| eos_token_id (:obj:`int`, `optional`): | |||
| The id of the `end-of-sequence` token. | |||
| Return: | |||
| :obj:`UserDict`: A dictionary composed of the fields as defined above: | |||
| - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated | |||
| scores of all non-finished beams. | |||
| - **next_beam_tokens** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens | |||
| to be added to the non-finished beam_hypotheses. | |||
| - **next_beam_indices** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices | |||
| indicating to which beam the next tokens shall be added. | |||
| """ | |||
| FINALIZE_INPUTS_DOCSTRING = r""" | |||
| Args: | |||
| input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`): | |||
| Indices of input sequence tokens in the vocabulary. | |||
| Indices can be obtained using any class inheriting from :class:`~transformers.PretrainedTokenizer`. See | |||
| :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for | |||
| details. | |||
| `What are input IDs? <../glossary.html#input-ids>`__ | |||
| final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): | |||
| The final scores of all non-finished beams. | |||
| final_beam_tokens (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): | |||
| The last tokens to be added to the non-finished beam_hypotheses. | |||
| final_beam_indices (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`): | |||
| The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added. | |||
| pad_token_id (:obj:`int`, `optional`): | |||
| The id of the `padding` token. | |||
| eos_token_id (:obj:`int`, `optional`): | |||
| The id of the `end-of-sequence` token. | |||
| Return: | |||
| :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated | |||
| sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all | |||
| batches finished early due to the :obj:`eos_token_id`. | |||
| """ | |||
| class BeamScorer(ABC): | |||
| """ | |||
| Abstract base class for all beam scorers that are used for :meth:`~transformers.PretrainedModel.beam_search` and | |||
| :meth:`~transformers.PretrainedModel.beam_sample`. | |||
| """ | |||
| @abstractmethod | |||
| def process(self, input_ids: torch.LongTensor, | |||
| next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, | |||
| next_indices: torch.LongTensor, | |||
| **kwargs) -> Tuple[torch.Tensor]: | |||
| raise NotImplementedError('This is an abstract method.') | |||
| @abstractmethod | |||
| def finalize(self, input_ids: torch.LongTensor, | |||
| next_scores: torch.FloatTensor, next_tokens: torch.LongTensor, | |||
| next_indices: torch.LongTensor, **kwargs) -> torch.LongTensor: | |||
| raise NotImplementedError('This is an abstract method.') | |||
| class BeamSearchScorer(BeamScorer): | |||
| r""" | |||
| :class:`transformers.BeamScorer` implementing standard beam search decoding. | |||
| Adapted in part from `Facebook's XLM beam search code | |||
| <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__. | |||
| Args: | |||
| batch_size (:obj:`int`): | |||
| Batch Size of :obj:`input_ids` for which beam search decoding is run in parallel. | |||
| max_length (:obj:`int`): | |||
| The maximum length of the sequence to be generated. | |||
| num_beams (:obj:`int`): | |||
| Number of beams for beam search. | |||
| device (:obj:`torch.device`): | |||
| Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of | |||
| :obj:`BeamSearchScorer` will be allocated. | |||
| length_penalty (:obj:`float`, `optional`, defaults to 1.0): | |||
| Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the | |||
| model to generate shorter sequences, to a value > 1.0 in order to encourage the model to produce longer | |||
| sequences. | |||
| do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||
| Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. | |||
| num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): | |||
| The number of beam hypotheses that shall be returned upon calling | |||
| :meth:`~transformer.BeamSearchScorer.finalize`. | |||
| """ | |||
| def __init__( | |||
| self, | |||
| batch_size: int, | |||
| max_length: int, | |||
| num_beams: int, | |||
| device: torch.device, | |||
| length_penalty: Optional[float] = 1.0, | |||
| do_early_stopping: Optional[bool] = False, | |||
| num_beam_hyps_to_keep: Optional[int] = 1, | |||
| ): | |||
| self.max_length = max_length | |||
| self.num_beams = num_beams | |||
| self.device = device | |||
| self.length_penalty = length_penalty | |||
| self.do_early_stopping = do_early_stopping | |||
| self.num_beam_hyps_to_keep = num_beam_hyps_to_keep | |||
| self._is_init = False | |||
| self._beam_hyps = [ | |||
| BeamHypotheses( | |||
| num_beams=self.num_beams, | |||
| max_length=self.max_length, | |||
| length_penalty=self.length_penalty, | |||
| early_stopping=self.do_early_stopping, | |||
| ) for _ in range(batch_size) | |||
| ] | |||
| self._done = torch.tensor([False for _ in range(batch_size)], | |||
| dtype=torch.bool, | |||
| device=self.device) | |||
| # if not isinstance(num_beams, int) or num_beams <= 1: | |||
| # raise ValueError( | |||
| # ) | |||
| @property | |||
| def is_done(self) -> bool: | |||
| return self._done.all() | |||
| def process(self, | |||
| input_ids: torch.LongTensor, | |||
| next_scores: torch.FloatTensor, | |||
| next_tokens: torch.LongTensor, | |||
| next_indices: torch.LongTensor, | |||
| pad_token_id: Optional[int] = None, | |||
| eos_token_id: Optional[int] = None, | |||
| mems=None) -> Tuple[torch.Tensor]: | |||
| cur_len = input_ids.shape[-1] | |||
| batch_size = len(self._beam_hyps) | |||
| assert batch_size == (input_ids.shape[0] // self.num_beams) | |||
| if isinstance(eos_token_id, int): | |||
| eos_token_id = [eos_token_id] | |||
| device = next_scores.device | |||
| next_beam_scores = torch.zeros((batch_size, self.num_beams), | |||
| dtype=next_scores.dtype, | |||
| device=device) | |||
| next_beam_tokens = torch.zeros((batch_size, self.num_beams), | |||
| dtype=next_tokens.dtype, | |||
| device=device) | |||
| next_beam_indices = torch.zeros((batch_size, self.num_beams), | |||
| dtype=next_indices.dtype, | |||
| device=device) | |||
| for batch_idx, beam_hyp in enumerate(self._beam_hyps): | |||
| if self._done[batch_idx]: | |||
| assert ( | |||
| len(beam_hyp) >= self.num_beams | |||
| ), 'Batch can only be done if at least {} beams have been generated'.format( | |||
| self.num_beams) | |||
| assert ( | |||
| eos_token_id is not None and pad_token_id is not None | |||
| ), 'generated beams >= num_beams -> eos_token_id and pad_token have to be defined' | |||
| # pad the batch | |||
| next_beam_scores[batch_idx, :] = 0 | |||
| next_beam_tokens[batch_idx, :] = pad_token_id | |||
| next_beam_indices[batch_idx, :] = 0 | |||
| continue | |||
| # next tokens for this sentence | |||
| beam_idx = 0 | |||
| for beam_token_rank, (next_token, next_score, | |||
| next_index) in enumerate( | |||
| zip(next_tokens[batch_idx], | |||
| next_scores[batch_idx], | |||
| next_indices[batch_idx])): | |||
| batch_beam_idx = batch_idx * self.num_beams + next_index | |||
| # add to generated hypotheses if end of sentence | |||
| if (eos_token_id is not None) and (next_token.item() | |||
| in eos_token_id): | |||
| # if beam_token does not belong to top num_beams tokens, it should not be added | |||
| is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams | |||
| if is_beam_token_worse_than_top_num_beams: | |||
| continue | |||
| beam_hyp.add( | |||
| input_ids[batch_beam_idx].clone(), | |||
| next_score.item(), | |||
| mems=[mem[[next_index.item()]] | |||
| for mem in mems] if mems else None) | |||
| else: | |||
| # add next predicted token since it is not eos_token | |||
| next_beam_scores[batch_idx, beam_idx] = next_score | |||
| next_beam_tokens[batch_idx, beam_idx] = next_token | |||
| next_beam_indices[batch_idx, beam_idx] = batch_beam_idx | |||
| beam_idx += 1 | |||
| # once the beam for next step is full, don't add more tokens to it. | |||
| if beam_idx == self.num_beams: | |||
| break | |||
| if beam_idx < self.num_beams: | |||
| raise ValueError( | |||
| f'At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected.' # noqa | |||
| ) # noqa | |||
| # Check if we are done so that we can save a pad step if all(done) | |||
| self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( | |||
| next_scores[batch_idx].max().item(), cur_len) | |||
| return UserDict({ | |||
| 'next_beam_scores': next_beam_scores.view(-1), | |||
| 'next_beam_tokens': next_beam_tokens.view(-1), | |||
| 'next_beam_indices': next_beam_indices.view(-1), | |||
| }) | |||
| def finalize(self, | |||
| input_ids: torch.LongTensor, | |||
| final_beam_scores: torch.FloatTensor, | |||
| final_beam_tokens: torch.LongTensor, | |||
| final_beam_indices: torch.LongTensor, | |||
| pad_token_id: Optional[int] = None, | |||
| eos_token_id: Optional[int] = None, | |||
| mems=None) -> Tuple[torch.LongTensor, List[torch.Tensor]]: | |||
| batch_size = len(self._beam_hyps) | |||
| # finalize all open beam hypotheses and add to generated hypotheses | |||
| for batch_idx, beam_hyp in enumerate(self._beam_hyps): | |||
| if self._done[batch_idx]: | |||
| continue | |||
| # need to add best num_beams hypotheses to generated hyps | |||
| for beam_id in range(self.num_beams): | |||
| batch_beam_idx = batch_idx * self.num_beams + beam_id | |||
| final_score = final_beam_scores[batch_beam_idx].item() | |||
| final_tokens = input_ids[batch_beam_idx] | |||
| beam_hyp.add( | |||
| final_tokens, | |||
| final_score, | |||
| mems=[mem[[batch_beam_idx]] | |||
| for mem in mems] if mems else None) | |||
| # select the best hypotheses | |||
| sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep) | |||
| best = [] | |||
| # retrieve best hypotheses | |||
| for i, beam_hyp in enumerate(self._beam_hyps): | |||
| sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0]) | |||
| for j in range(self.num_beam_hyps_to_keep): | |||
| best_hyp, mems = sorted_hyps.pop()[1:] | |||
| sent_lengths[self.num_beam_hyps_to_keep * i | |||
| + j] = len(best_hyp) | |||
| best.append((best_hyp, mems)) | |||
| # prepare for adding eos | |||
| sent_max_len = min(sent_lengths.max().item(), self.max_length) | |||
| decoded: torch.LongTensor = input_ids.new( | |||
| batch_size * self.num_beam_hyps_to_keep, sent_max_len) | |||
| # shorter batches are padded if needed | |||
| if sent_lengths.min().item() != sent_lengths.max().item(): | |||
| assert pad_token_id is not None, '`pad_token_id` has to be defined' | |||
| decoded.fill_(pad_token_id) | |||
| # fill with hypotheses and eos_token_id if the latter fits in | |||
| mems = [] | |||
| for i, (hypo, mem) in enumerate(best): | |||
| decoded[i, :sent_lengths[i]] = hypo | |||
| if sent_lengths[i] < sent_max_len: | |||
| decoded[i, sent_lengths[i]] = eos_token_id | |||
| mems.append(mem) | |||
| mems = [ | |||
| torch.cat([mem[i] for mem in mems], dim=0) | |||
| for i in range(len(mems[0])) | |||
| ] if mems and mems[0] else None | |||
| return decoded, mems | |||
| class BeamHypotheses: | |||
| def __init__(self, num_beams: int, max_length: int, length_penalty: float, | |||
| early_stopping: bool): | |||
| """ | |||
| Initialize n-best list of hypotheses. | |||
| """ | |||
| self.max_length = max_length - 1 # ignoring bos_token | |||
| self.length_penalty = length_penalty | |||
| self.early_stopping = early_stopping | |||
| self.num_beams = num_beams | |||
| self.beams = [] | |||
| self.worst_score = 1e9 | |||
| def __len__(self): | |||
| """ | |||
| Number of hypotheses in the list. | |||
| """ | |||
| return len(self.beams) | |||
| def add(self, hyp: torch.LongTensor, sum_logprobs: float, mems=None): | |||
| """ | |||
| Add a new hypothesis to the list. | |||
| """ | |||
| score = sum_logprobs / (max(hyp.shape[-1], 1)**self.length_penalty) | |||
| if len(self) < self.num_beams or score > self.worst_score: | |||
| self.beams.append((score, hyp, mems)) | |||
| if len(self) > self.num_beams: | |||
| sorted_next_scores = sorted([ | |||
| (s, idx) for idx, (s, _, _) in enumerate(self.beams) | |||
| ]) | |||
| del self.beams[sorted_next_scores[0][1]] | |||
| self.worst_score = sorted_next_scores[1][0] | |||
| else: | |||
| self.worst_score = min(score, self.worst_score) | |||
| def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool: | |||
| """ | |||
| If there are enough hypotheses and none of the hypotheses being generated can become better than the worst | |||
| one in the heap, then we are done with this sentence. | |||
| """ | |||
| if len(self) < self.num_beams: | |||
| return False | |||
| elif self.early_stopping: | |||
| return True | |||
| else: | |||
| cur_score = best_sum_logprobs / cur_len**self.length_penalty | |||
| ret = self.worst_score >= cur_score | |||
| return ret | |||
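The score used by `add` and `is_done` is simply the sum of log-probabilities divided by `length**length_penalty`. A minimal standalone sketch of that normalisation with toy numbers (no dependency on the surrounding classes):

    # Toy illustration of the beam score used above:
    # score = sum_logprobs / (hyp_len ** length_penalty)
    length_penalty = 0.7
    candidates = [(-4.2, 6), (-3.1, 4), (-5.0, 9)]  # (sum of log-probs, hypothesis length)
    scores = [lp / (n ** length_penalty) for lp, n in candidates]
    print(scores, max(scores))
    # A larger length_penalty divides the (negative) log-prob sum by a larger factor,
    # which makes longer hypotheses look relatively better.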
| class LogitsProcessor(ABC): | |||
| """Abstract base class for all logit processors that can be applied during generation.""" | |||
| def __call__(self, input_ids: torch.LongTensor, | |||
| scores: torch.FloatTensor) -> torch.FloatTensor: | |||
| """Torch method for processing logits.""" | |||
| raise NotImplementedError( | |||
| f'{self.__class__} is an abstract class. Only classes inheriting this class can be called.' | |||
| ) | |||
| class LogitsProcessorList(list): | |||
| """ | |||
| This class can be used to create a list of :class:`~transformers.LogitsProcessor` or | |||
| :class:`~transformers.LogitsWarper` to subsequently process a :obj:`scores` input tensor. This class inherits from | |||
| list and adds a specific `__call__` method to apply each :class:`~transformers.LogitsProcessor` or | |||
| :class:`~transformers.LogitsProcessor` to the inputs. | |||
| """ | |||
| def __call__(self, input_ids: torch.LongTensor, | |||
| scores: torch.FloatTensor) -> torch.FloatTensor: | |||
| for processor in self: | |||
| scores = processor(input_ids, scores) | |||
| return scores | |||
| class MinLengthLogitsProcessor(LogitsProcessor): | |||
| r""" | |||
| :class:`transformers.LogitsProcessor` enforcing a min-length by setting EOS probability to 0. | |||
| Args: | |||
| min_length (:obj:`int`): | |||
| The minimum length below which the score of :obj:`eos_token_id` is set to :obj:`-float("Inf")`. | |||
| eos_token_id (:obj:`int`): | |||
| The id of the `end-of-sequence` token. | |||
| """ | |||
| def __init__(self, min_length: int, eos_token_id: int): | |||
| if not isinstance(min_length, int) or min_length < 0: | |||
| raise ValueError( | |||
| f'`min_length` has to be a positive integer, but is {min_length}' | |||
| ) | |||
| if not isinstance(eos_token_id, int) or eos_token_id < 0: | |||
| raise ValueError( | |||
| f'`eos_token_id` has to be a positive integer, but is {eos_token_id}' | |||
| ) | |||
| self.min_length = min_length | |||
| self.eos_token_id = eos_token_id | |||
| def __call__(self, input_ids: torch.LongTensor, | |||
| scores: torch.FloatTensor) -> torch.FloatTensor: | |||
| cur_len = input_ids.shape[-1] | |||
| if cur_len < self.min_length: | |||
| scores[:, self.eos_token_id] = -float('inf') | |||
| return scores | |||
| class NoRepeatNGramLogitsProcessor(LogitsProcessor): | |||
| r""" | |||
| :class:`transformers.LogitsProcessor` that enforces no repetition of n-grams. See `Fairseq | |||
| <https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345>`__. | |||
| Args: | |||
| ngram_size (:obj:`int`): | |||
| All ngrams of size :obj:`ngram_size` can only occur once. | |||
| """ | |||
| def __init__(self, ngram_size: int): | |||
| if not isinstance(ngram_size, int) or ngram_size <= 0: | |||
| raise ValueError( | |||
| f'`ngram_size` has to be a strictly positive integer, but is {ngram_size}' | |||
| ) | |||
| self.ngram_size = ngram_size | |||
| def __call__(self, input_ids: torch.LongTensor, | |||
| scores: torch.FloatTensor) -> torch.FloatTensor: | |||
| num_batch_hypotheses = scores.shape[0] | |||
| cur_len = input_ids.shape[-1] | |||
| banned_batch_tokens = self._calc_banned_ngram_tokens( | |||
| input_ids, num_batch_hypotheses, cur_len) | |||
| for i, banned_tokens in enumerate(banned_batch_tokens): | |||
| scores[i, banned_tokens] = -float('inf') | |||
| return scores | |||
| def _calc_banned_ngram_tokens(self, prev_input_ids: torch.Tensor, | |||
| num_hypos: int, | |||
| cur_len: int) -> List[Iterable[int]]: | |||
| """Copied from fairseq for no_repeat_ngram in beam_search""" | |||
| if cur_len + 1 < self.ngram_size: | |||
| # return no banned tokens if we haven't generated ngram_size tokens yet | |||
| return [[] for _ in range(num_hypos)] | |||
| generated_ngrams = [{} for _ in range(num_hypos)] | |||
| for idx in range(num_hypos): | |||
| gen_tokens = prev_input_ids[idx].tolist() | |||
| generated_ngram = generated_ngrams[idx] | |||
| for ngram in zip(*[gen_tokens[i:] | |||
| for i in range(self.ngram_size)]): | |||
| prev_ngram_tuple = tuple(ngram[:-1]) | |||
| generated_ngram[prev_ngram_tuple] = generated_ngram.get( | |||
| prev_ngram_tuple, []) + [ngram[-1]] | |||
| def _get_generated_ngrams(hypo_idx): | |||
| # Before decoding the next token, prevent decoding of ngrams that have already appeared | |||
| start_idx = cur_len + 1 - self.ngram_size | |||
| ngram_idx = tuple(prev_input_ids[hypo_idx, | |||
| start_idx:cur_len].tolist()) | |||
| return generated_ngrams[hypo_idx].get(ngram_idx, []) | |||
| banned_tokens = [ | |||
| _get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos) | |||
| ] | |||
| return banned_tokens | |||
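A usage sketch for the processors above, assuming they are importable from this module's `generation_utils` (an assumed path); it only exercises them on toy logits:

    import torch

    # Assumed import path for the utilities added in this file.
    from modelscope.models.nlp.mglm.generation_utils import (
        LogitsProcessorList, MinLengthLogitsProcessor, NoRepeatNGramLogitsProcessor)

    vocab_size, eos_token_id = 10, 9
    processors = LogitsProcessorList([
        MinLengthLogitsProcessor(min_length=5, eos_token_id=eos_token_id),
        NoRepeatNGramLogitsProcessor(ngram_size=2),
    ])

    input_ids = torch.tensor([[1, 2, 3, 2]])   # one partial sequence of length 4
    scores = torch.zeros(1, vocab_size)        # uniform toy logits
    scores = processors(input_ids, scores)

    print(scores[0, eos_token_id])  # -inf: the sequence is still shorter than min_length
    print(scores[0, 3])             # -inf: the bigram (2, 3) was already generated, so 3 is banned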
| @@ -0,0 +1,469 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import os | |||
| import random | |||
| from os import path as osp | |||
| from typing import Dict | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from modelscope.hub.snapshot_download import snapshot_download | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import Tensor, TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.config import Config | |||
| from modelscope.utils.constant import ModelFile, Tasks | |||
| from . import mpu | |||
| from .arguments import get_args | |||
| from .generation_utils import BeamSearchScorer | |||
| from .train_utils import get_model | |||
| from .utils import load_checkpoint | |||
| __all__ = ['MGLMForTextSummarization'] | |||
| def setup_args(args): | |||
| args.block_lm = True | |||
| args.task_mask = True | |||
| args.cloze_eval = True | |||
| args.num_layers = 24 | |||
| args.hidden_size = 1536 | |||
| args.num_attention_heads = 16 | |||
| args.max_position_embeddings = 1024 | |||
| args.tokenizer_type = 'ChineseSPTokenizer' | |||
| args.load_pretrained = '' | |||
| args.DDP_impl = 'none' | |||
| args.model_parallel_size = 1 | |||
| args.fp16 = True | |||
| args.cache_dir = 'cache' | |||
| args.out_seq_length = 200 | |||
| args.seq_length = 512 | |||
| args.temperature = 0.9 | |||
| args.top_k = 2 | |||
| args.top_p = 0.8 | |||
| args.frequency_penalty = 0.1 | |||
| args.presence_penalty = 0.1 | |||
| args.mem_length = args.seq_length + args.mem_length - 1 | |||
| return args | |||
| def setup_model(args): | |||
| """Setup model and optimizer.""" | |||
| model = get_model(args, model_type='generation') | |||
| if args.load_pretrained is not None: | |||
| args.no_load_optim = True | |||
| args.load = args.load_pretrained | |||
| _ = load_checkpoint(model, None, None, args) | |||
| return model | |||
| def set_random_seed(seed): | |||
| """Set random seed for reproducability.""" | |||
| if seed is not None and seed > 0: | |||
| random.seed(seed) | |||
| np.random.seed(seed) | |||
| torch.manual_seed(seed) | |||
| mpu.model_parallel_cuda_manual_seed(seed) | |||
| def get_masks_and_position_ids(data, | |||
| eod_token, | |||
| reset_position_ids, | |||
| reset_attention_mask, | |||
| loss_mask=None, | |||
| attention_mask=None, | |||
| set_loss_mask=False, | |||
| mem_length=None): | |||
| # Extract batch size and sequence length. | |||
| batch_size, seq_length = data.size() | |||
| # Attention mask (lower triangular). | |||
| if mem_length: | |||
| if attention_mask is None: | |||
| attention_mask = torch.ones( | |||
| (1, seq_length, seq_length + mem_length), device=data.device) | |||
| attention_mask = torch.tril( | |||
| torch.triu(attention_mask, 1 - seq_length + mem_length), | |||
| mem_length) | |||
| else: | |||
| if reset_attention_mask: | |||
| att_mask_batch = batch_size | |||
| else: | |||
| att_mask_batch = 1 | |||
| if attention_mask is None: | |||
| attention_mask = torch.ones( | |||
| (att_mask_batch, seq_length, seq_length), device=data.device) | |||
| attention_mask = torch.tril(attention_mask) | |||
| attention_mask = attention_mask.unsqueeze(1) | |||
| # Loss mask. | |||
| if loss_mask is None: | |||
| loss_mask = torch.ones( | |||
| data.size(), dtype=torch.float, device=data.device) | |||
| # Position ids. | |||
| position_ids = torch.arange( | |||
| seq_length, dtype=torch.long, device=data.device) | |||
| position_ids = position_ids.unsqueeze(0).expand_as(data) | |||
| if set_loss_mask: | |||
| loss_mask[data == eod_token] = 0.0 | |||
| # We need to clone as the ids will be modified based on batch index. | |||
| if reset_position_ids: | |||
| position_ids = position_ids.clone() | |||
| if reset_position_ids or reset_attention_mask: | |||
| # Loop through the batches: | |||
| for b in range(batch_size): | |||
| # Find indices where the EOD token is. | |||
| eod_index = position_ids[b, data[b] == eod_token] | |||
| # Detach indices from positions if going to modify positions. | |||
| if reset_position_ids: | |||
| eod_index = eod_index.clone() | |||
| # Loop through EOD indices: | |||
| prev_index = 0 | |||
| for j in range(eod_index.size()[0]): | |||
| i = eod_index[j] | |||
| # Mask attention loss. | |||
| if reset_attention_mask: | |||
| attention_mask[b, 0, (i + 1):, :(i + 1)] = 0 | |||
| # Reset positions. | |||
| if reset_position_ids: | |||
| position_ids[b, (i + 1):] -= (i + 1 - prev_index) | |||
| prev_index = i + 1 | |||
| return attention_mask, loss_mask, position_ids | |||
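A small, dependency-free sketch of the non-memory branch above (lower-triangular causal mask plus default position ids, toy sizes assumed):

    import torch

    batch_size, seq_length = 2, 5
    data = torch.randint(0, 100, (batch_size, seq_length))

    # att_mask_batch == 1 when attention masks are not reset per sample
    attention_mask = torch.tril(torch.ones(1, seq_length, seq_length)).unsqueeze(1)
    loss_mask = torch.ones(data.size(), dtype=torch.float)
    position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0).expand_as(data)

    print(attention_mask.shape)  # torch.Size([1, 1, 5, 5]), lower-triangular causal mask
    print(loss_mask.shape)       # torch.Size([2, 5])
    print(position_ids[0])       # tensor([0, 1, 2, 3, 4])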
| def initialize_distributed(args): | |||
| """Initialize torch.distributed.""" | |||
| # Manually set the device ids. | |||
| device = args.rank % torch.cuda.device_count() | |||
| if args.local_rank is not None: | |||
| device = args.local_rank | |||
| torch.cuda.set_device(device) | |||
| # Call the init process | |||
| init_method = 'tcp://' | |||
| args.master_ip = os.getenv('MASTER_ADDR', 'localhost') | |||
| args.master_port = os.getenv('MASTER_PORT', '6000') | |||
| init_method += args.master_ip + ':' + args.master_port | |||
| torch.distributed.init_process_group( | |||
| backend=args.distributed_backend, | |||
| world_size=args.world_size, | |||
| rank=args.rank, | |||
| init_method=init_method) | |||
| # Set the model-parallel / data-parallel communicators. | |||
| mpu.initialize_model_parallel(args.model_parallel_size) | |||
| # Optional DeepSpeed Activation Checkpointing Features | |||
| # | |||
| if hasattr( | |||
| args, 'deepspeed' | |||
| ) and args.deepspeed and args.deepspeed_activation_checkpointing: | |||
| set_deepspeed_activation_checkpointing(args) | |||
| def get_batch(context_tokens, device, args): | |||
| tokens = context_tokens | |||
| tokens = tokens.view(args.batch_size, -1).contiguous() | |||
| tokens = tokens.to(device) | |||
| # Get the masks and position ids. | |||
| if args.block_lm: | |||
| attention_mask = torch.tensor([tokens.size(1)], | |||
| device=device, | |||
| dtype=torch.long) | |||
| position_ids = torch.arange( | |||
| tokens.size(1), device=device, dtype=torch.long) | |||
| if not args.no_block_position: | |||
| block_position_ids = torch.zeros( | |||
| tokens.size(1), device=device, dtype=torch.long) | |||
| position_ids = torch.stack((position_ids, block_position_ids), | |||
| dim=0) | |||
| position_ids = position_ids.unsqueeze(0) | |||
| else: | |||
| attention_mask, loss_mask, position_ids = get_masks_and_position_ids( | |||
| tokens, | |||
| args.eod_token, | |||
| reset_position_ids=False, | |||
| reset_attention_mask=False, | |||
| set_loss_mask=False, | |||
| mem_length=args.mem_length) | |||
| return tokens, attention_mask, position_ids | |||
| def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): | |||
| # This function has been mostly taken from huggingface conversational ai code at | |||
| # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313 | |||
| if top_k > 0: | |||
| # Remove all tokens with a probability less than the last token of the top-k | |||
| indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, | |||
| None] | |||
| logits[indices_to_remove] = filter_value | |||
| if top_p > 0.0: | |||
| # convert to 1D | |||
| logits = logits.view(logits.size()[1]).contiguous() | |||
| sorted_logits, sorted_indices = torch.sort(logits, descending=True) | |||
| cumulative_probs = torch.cumsum( | |||
| F.softmax(sorted_logits, dim=-1), dim=-1) | |||
| # Remove tokens with cumulative probability above the threshold | |||
| sorted_indices_to_remove = cumulative_probs > top_p | |||
| # Shift the indices to the right to also keep the first token above the threshold | |||
| sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ | |||
| ..., :-1].clone() | |||
| sorted_indices_to_remove[..., 0] = 0 | |||
| indices_to_remove = sorted_indices[sorted_indices_to_remove] | |||
| logits[indices_to_remove] = filter_value | |||
| # going back to 2D | |||
| logits = logits.view(1, -1).contiguous() | |||
| return logits | |||
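A standalone toy run of the same top-k / nucleus filtering idea, re-implemented inline so the numbers are easy to follow (the values of `top_k` and `top_p` are arbitrary):

    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
    top_k, top_p, filter_value = 2, 0.9, -float('Inf')

    # top-k: drop everything below the k-th largest logit
    kth = torch.topk(logits, top_k)[0][..., -1, None]
    logits = logits.masked_fill(logits < kth, filter_value)

    # top-p (nucleus): drop the tail once cumulative probability exceeds top_p
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum_probs > top_p
    remove[..., 1:] = remove[..., :-1].clone()   # shift right so the first token is always kept
    remove[..., 0] = False
    logits[0, sorted_indices[0][remove[0]]] = filter_value

    print(logits)  # only the two highest-scoring tokens keep finite values here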
| def sample_sequence(model, | |||
| tokenizer, | |||
| context_tokens, | |||
| context_length, | |||
| args, | |||
| device, | |||
| mems=None, | |||
| end_tokens=None): | |||
| if not args.block_lm: | |||
| context_tokens, attention_mask, position_ids = get_batch( | |||
| context_tokens, device, args) | |||
| tokens = torch.empty((args.num_beams, 0), | |||
| device=context_tokens.device, | |||
| dtype=torch.long) | |||
| else: | |||
| tokens = context_tokens.new_full((1, 1), | |||
| tokenizer.get_command('sop').Id) | |||
| counter = 0 | |||
| if mems is None: | |||
| mems = [] | |||
| if end_tokens is None: | |||
| end_tokens = [args.eod_token] | |||
| last_beam_num = 1 | |||
| output_tokens_list = [] | |||
| generated_tokens_list = [] | |||
| while counter < args.out_seq_length: | |||
| if counter == 0 and not args.block_lm: | |||
| next_token_logits, *mems = model(context_tokens, position_ids, | |||
| attention_mask, *mems) | |||
| else: | |||
| if args.block_lm: | |||
| if args.no_block_position: | |||
| position_ids = context_tokens.new_full( | |||
| (last_beam_num, 1), context_length + counter) | |||
| else: | |||
| position_ids = context_tokens.new_ones(last_beam_num, 2, 1) | |||
| position_ids[:, 0] = context_length | |||
| position_ids[:, 1] = counter + 1 | |||
| attention_mask = context_tokens.new_zeros( | |||
| [1], device=context_tokens.device, dtype=torch.long) | |||
| else: | |||
| position_ids = context_tokens.new_ones((last_beam_num, 1)) * ( | |||
| context_length + counter - 1) | |||
| attention_mask = context_tokens.new_ones( | |||
| last_beam_num, | |||
| 1, | |||
| 1, | |||
| args.mem_length + 1, | |||
| device=context_tokens.device, | |||
| dtype=torch.float) | |||
| last_token = tokens[:, -1:] | |||
| next_token_logits, *mems = model(last_token, position_ids, | |||
| attention_mask, *mems) | |||
| next_token_logits = next_token_logits[:, -1] | |||
| next_token_logits /= args.temperature | |||
| frequency_count = torch.zeros(next_token_logits.shape) | |||
| for tk in output_tokens_list: | |||
| frequency_count[0][tk] += 1 | |||
| next_token_logits -= (args.frequency_penalty | |||
| * frequency_count).to(device) | |||
| next_token_logits -= ( | |||
| args.presence_penalty * # noqa | |||
| (frequency_count > 0)).to(device) | |||
| next_token_logits = top_k_logits( | |||
| next_token_logits, top_k=args.top_k, top_p=args.top_p) | |||
| log_probs = F.softmax(next_token_logits, dim=-1) | |||
| prev = torch.multinomial(log_probs, num_samples=1)[0] | |||
| is_end = prev.item() in end_tokens | |||
| if is_end: | |||
| break | |||
| decode_tokens = tokenizer.DecodeIds([prev.item()]) # noqa | |||
| generated_tokens_list.append(prev.item()) | |||
| prev = prev.view(1, 1) | |||
| tokens = prev if tokens is None else torch.cat((tokens, prev), dim=1) | |||
| counter += 1 | |||
| output_tokens_list = tokens.view(-1).contiguous() | |||
| return torch.cat((context_tokens, tokens), dim=1), mems | |||
| def read_context(tokenizer, args, context): | |||
| terminate_runs, skip_run = 0, 0 # noqa | |||
| if mpu.get_model_parallel_rank() == 0: | |||
| while True: | |||
| # raw_text = input("\nContext prompt (stop to exit) >>> ") | |||
| raw_text = context | |||
| if not raw_text: | |||
| print('Prompt should not be empty!') | |||
| break | |||
| # if raw_text == "stop": | |||
| # terminate_runs = 1 | |||
| # break | |||
| generation_mask = '[gMASK]' if args.task_mask else '[MASK]' | |||
| if args.block_lm and 'MASK]' not in raw_text: | |||
| raw_text += ' ' + generation_mask | |||
| # output.write(raw_text) | |||
| context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization | |||
| if args.block_lm: | |||
| context_tokens = [tokenizer.get_command('ENC').Id | |||
| ] + context_tokens | |||
| if not raw_text.endswith('[gMASK]'): | |||
| context_tokens = context_tokens + [ | |||
| tokenizer.get_command('eos').Id | |||
| ] | |||
| context_length = len(context_tokens) | |||
| if context_length >= args.seq_length: | |||
| print('\nContext length', context_length, | |||
| '\nPlease provide a context shorter than the window length!') | |||
| break | |||
| break | |||
| else: | |||
| context_length = 0 | |||
| terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs]) | |||
| torch.distributed.broadcast( | |||
| terminate_runs_tensor, | |||
| mpu.get_model_parallel_src_rank(), | |||
| group=mpu.get_model_parallel_group()) | |||
| terminate_runs = terminate_runs_tensor[0].item() | |||
| if terminate_runs == 1: | |||
| return terminate_runs, None, None, None | |||
| context_length_tensor = torch.cuda.LongTensor([context_length]) | |||
| torch.distributed.broadcast( | |||
| context_length_tensor, | |||
| mpu.get_model_parallel_src_rank(), | |||
| group=mpu.get_model_parallel_group()) | |||
| context_length = context_length_tensor[0].item() | |||
| if mpu.get_model_parallel_rank() == 0: | |||
| context_tokens_tensor = torch.cuda.LongTensor(context_tokens) | |||
| else: | |||
| context_tokens_tensor = torch.cuda.LongTensor([0] * context_length) | |||
| torch.distributed.broadcast( | |||
| context_tokens_tensor, | |||
| mpu.get_model_parallel_src_rank(), | |||
| group=mpu.get_model_parallel_group()) | |||
| if mpu.get_model_parallel_rank() != 0: | |||
| raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist()) | |||
| return terminate_runs, raw_text, context_tokens_tensor, context_length | |||
| @MODELS.register_module(Tasks.text_summarization, module_name=Models.mglm) | |||
| class MGLMForTextSummarization(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the text summarization model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| from .configure_data import prepare_tokenizer | |||
| # Disable CuDNN. | |||
| torch.backends.cudnn.enabled = False | |||
| # Arguments. | |||
| self.args = setup_args(get_args()) | |||
| self.args.load_pretrained = model_dir | |||
| # Pytorch distributed. | |||
| try: | |||
| initialize_distributed(self.args) | |||
| except RuntimeError: | |||
| print('process group already initialized') | |||
| # Random seeds for reproducibility. | |||
| set_random_seed(self.args.seed) | |||
| # setting default batch size to 1 | |||
| self.args.batch_size = 1 | |||
| self.args.tokenizer_path = model_dir | |||
| self.tokenizer = prepare_tokenizer(self.args) | |||
| self.model = setup_model(self.args) | |||
| self.cfg = Config.from_file( | |||
| osp.join(model_dir, ModelFile.CONFIGURATION)) | |||
| def forward(self, input: Dict[str, str]) -> Dict[str, str]: | |||
| pass | |||
| def generate(self, input: Dict[str, str]) -> Dict[str, str]: | |||
| model = self.model | |||
| tokenizer = self.tokenizer | |||
| args = self.args | |||
| device = torch.cuda.current_device() | |||
| model.eval() | |||
| context = input['text'] + self.cfg.model.prompt | |||
| with torch.no_grad(): | |||
| terminate_runs, raw_text, context_tokens_tensor, context_length = read_context( | |||
| tokenizer, args, context) | |||
| mems = [] | |||
| tokens, attention_mask, position_ids = get_batch( | |||
| context_tokens_tensor, device, args) | |||
| mask_tokens = ['MASK', 'sMASK', 'gMASK' | |||
| ] if args.task_mask else ['MASK'] | |||
| mask_tokens = [ | |||
| tokenizer.get_command(token).Id for token in mask_tokens | |||
| ] | |||
| end_tokens = [tokenizer.get_command('eop').Id, args.eod_token] | |||
| mask_positions = [] | |||
| for token in mask_tokens: | |||
| mask_positions += (context_tokens_tensor == token).nonzero( | |||
| as_tuple=True)[0].tolist() | |||
| mask_positions.sort() | |||
| if args.no_block_position: | |||
| for mask_position in mask_positions: | |||
| position_ids[0, mask_position + 1:] += args.out_seq_length | |||
| _, *mems = model(tokens, position_ids, attention_mask, *mems) | |||
| for mask_position in mask_positions: | |||
| if args.no_block_position: | |||
| position = position_ids[0, mask_position].item() | |||
| else: | |||
| position = mask_position | |||
| tokens, mems, = sample_sequence( | |||
| model, | |||
| tokenizer, | |||
| tokens, | |||
| position, | |||
| args, | |||
| device, | |||
| mems=mems, | |||
| end_tokens=end_tokens) | |||
| output_tokens_list = tokens.view(-1).contiguous() | |||
| trim_decode_tokens = tokenizer.DecodeIds( | |||
| output_tokens_list.tolist()) | |||
| res = trim_decode_tokens.split('<|startofpiece|>')[-1] | |||
| print(res) | |||
| return {OutputKeys.TEXT: res} | |||
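A hedged usage sketch for the class above. The model directory is a placeholder, and running it assumes a CUDA device plus the distributed / model-parallel environment this file sets up:

    # Sketch only: '/path/to/mglm-model-dir' is a placeholder checkpoint directory,
    # and the assumed import path mirrors where this file lives in the PR.
    from modelscope.models.nlp.mglm.mglm_for_text_summarization import MGLMForTextSummarization

    model = MGLMForTextSummarization('/path/to/mglm-model-dir')
    result = model.generate({'text': 'Some long document to be summarized ...'})
    print(result)  # a dict keyed by OutputKeys.TEXT containing the generated summary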
| @@ -0,0 +1,20 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| from .distributed import (DistributedDataParallel, | |||
| PyTorchDistributedDataParallel) | |||
| from .downstream import (GLMForMultiTokenCloze, GLMForMultiTokenClozeFast, | |||
| GLMForSequenceClassification, GLMForSingleTokenCloze) | |||
| from .modeling_glm import (GLMModel, | |||
| glm_get_params_for_weight_decay_optimization) | |||
| @@ -0,0 +1,127 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| import torch.distributed as dist | |||
| from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors | |||
| from torch.autograd import Variable | |||
| from torch.nn.modules import Module | |||
| from torch.nn.parallel.distributed import DistributedDataParallel as DDP | |||
| from modelscope.models.nlp.mglm import mpu | |||
| class PyTorchDistributedDataParallel(DDP): | |||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||
| return self.module.named_parameters(prefix=prefix, recurse=recurse) | |||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
| sd = self.module.state_dict(destination, prefix, keep_vars) | |||
| return sd | |||
| def load_state_dict(self, state_dict, strict=True): | |||
| return self.module.load_state_dict(state_dict, strict=strict) | |||
| class DistributedDataParallel(Module): | |||
| def __init__(self, module): | |||
| super(DistributedDataParallel, self).__init__() | |||
| self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False | |||
| self.module = module | |||
| self.data_parallel_group = mpu.get_data_parallel_group() | |||
| src_rank = mpu.get_model_parallel_rank() | |||
| for p in self.module.parameters(): | |||
| if torch.is_tensor(p): | |||
| dist.broadcast(p, src_rank, group=self.data_parallel_group) | |||
| def allreduce_params(reduce_after=True, | |||
| no_scale=False, | |||
| fp32_allreduce=False): | |||
| if (self.needs_reduction): | |||
| self.needs_reduction = False | |||
| buckets = {} | |||
| for name, param in self.module.named_parameters(): | |||
| if param.requires_grad and param.grad is not None: | |||
| tp = (param.data.type()) | |||
| if tp not in buckets: | |||
| buckets[tp] = [] | |||
| buckets[tp].append(param) | |||
| if self.warn_on_half: | |||
| if torch.cuda.HalfTensor in buckets: | |||
| print( | |||
| 'WARNING: gloo dist backend for half parameters may be extremely slow. It is recommended to use the NCCL backend in this case.' # noqa | |||
| ) | |||
| self.warn_on_half = False | |||
| for tp in buckets: | |||
| bucket = buckets[tp] | |||
| grads = [param.grad.data for param in bucket] | |||
| coalesced = _flatten_dense_tensors(grads) | |||
| if fp32_allreduce: | |||
| coalesced = coalesced.float() | |||
| if not no_scale and not reduce_after: | |||
| coalesced /= dist.get_world_size( | |||
| group=self.data_parallel_group) | |||
| dist.all_reduce(coalesced, group=self.data_parallel_group) | |||
| torch.cuda.synchronize() | |||
| if not no_scale and reduce_after: | |||
| coalesced /= dist.get_world_size( | |||
| group=self.data_parallel_group) | |||
| for buf, synced in zip( | |||
| grads, _unflatten_dense_tensors(coalesced, grads)): | |||
| buf.copy_(synced) | |||
| self.hook_handles = [] | |||
| self.hooks = [] | |||
| for param in list(self.module.parameters()): | |||
| def allreduce_hook(*unused): | |||
| Variable._execution_engine.queue_callback(allreduce_params) | |||
| self.allreduce_params = allreduce_params | |||
| def forward(self, *inputs, **kwargs): | |||
| self.needs_reduction = True | |||
| return self.module(*inputs, **kwargs) | |||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
| sd = self.module.state_dict(destination, prefix, keep_vars) | |||
| return sd | |||
| def load_state_dict(self, state_dict, strict=True): | |||
| return self.module.load_state_dict(state_dict, strict=strict) | |||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||
| return self.module.named_parameters(prefix=prefix, recurse=recurse) | |||
| ''' | |||
| def _sync_buffers(self): | |||
| buffers = list(self.module._all_buffers()) | |||
| if len(buffers) > 0: | |||
| # cross-node buffer sync | |||
| flat_buffers = _flatten_dense_tensors(buffers) | |||
| dist.broadcast(flat_buffers, 0) | |||
| for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)): | |||
| buf.copy_(synced) | |||
| def train(self, mode=True): | |||
| # Clear NCCL communicator and CUDA event cache of the default group ID, | |||
| # These cache will be recreated at the later call. This is currently a | |||
| # work-around for a potential NCCL deadlock. | |||
| if dist._backend == dist.dist_backend.NCCL: | |||
| dist._clear_group_cache() | |||
| super(DistributedDataParallel, self).train(mode) | |||
| self.module.train(mode) | |||
| ''' | |||
| @@ -0,0 +1,242 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| """Multiple choice model.""" | |||
| import torch | |||
| import torch.nn | |||
| from .modeling_glm import GLMModel | |||
| class GLMForMultiTokenCloze(torch.nn.Module): | |||
| def __init__(self, | |||
| language_model: GLMModel, | |||
| take_softmax=True, | |||
| length_penalty=0.0): | |||
| super(GLMForMultiTokenCloze, self).__init__() | |||
| self.model = language_model | |||
| self.take_softmax = take_softmax | |||
| self.length_penalty = length_penalty | |||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
| # [h.remove() for h in self.hook_handles] | |||
| sd = self.model.state_dict(destination, prefix, keep_vars) | |||
| return sd | |||
| def load_state_dict(self, state_dict, strict=True): | |||
| return self.model.load_state_dict(state_dict, strict=strict) | |||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||
| return self.model.named_parameters(prefix=prefix, recurse=recurse) | |||
| def forward(self, | |||
| input_ids, | |||
| position_ids, | |||
| attention_mask, | |||
| target_ids=None, | |||
| logit_mask=None, | |||
| prompt_pos=None): | |||
| if target_ids is None: | |||
| return self.model(input_ids, position_ids, attention_mask) | |||
| num_choices = None | |||
| if len(input_ids.shape) == 3: | |||
| batch_size, num_choices = input_ids.shape[:2] | |||
| input_ids = input_ids.reshape(-1, input_ids.size(-1)) | |||
| attention_mask = attention_mask.reshape(-1, | |||
| *attention_mask.size()[2:]) | |||
| position_ids = position_ids.reshape(-1, *position_ids.size()[2:]) | |||
| target_ids = target_ids.reshape(-1, target_ids.size(-1)) | |||
| logit_mask = logit_mask.reshape(-1, logit_mask.size(-1)) | |||
| if prompt_pos is not None: | |||
| prompt_pos = prompt_pos.reshape(-1, prompt_pos.size(-1)) | |||
| outputs, *mems = self.model( | |||
| input_ids, position_ids, attention_mask, prompt_pos=prompt_pos) | |||
| if self.take_softmax: | |||
| outputs = torch.nn.functional.log_softmax(outputs, dim=-1) | |||
| # select the target logits | |||
| batch_ids = torch.arange( | |||
| target_ids.size(0), dtype=torch.long, device=target_ids.device) | |||
| batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids) | |||
| seq_ids = torch.arange( | |||
| target_ids.size(-1), dtype=torch.long, device=target_ids.device) | |||
| seq_ids = seq_ids.unsqueeze(0).expand_as(target_ids) | |||
| logits = outputs[batch_ids, seq_ids, target_ids] | |||
| logits = (logits * logit_mask).sum(dim=1) | |||
| if self.length_penalty > 0.0: | |||
| logits = logits / logit_mask.sum(dim=1)**self.length_penalty | |||
| if num_choices is not None: | |||
| logits = logits.view(-1, num_choices) | |||
| return (logits, *mems) | |||
| class GLMForMultiTokenClozeFast(torch.nn.Module): | |||
| def __init__(self, language_model, take_softmax=True, length_penalty=0.0): | |||
| super(GLMForMultiTokenClozeFast, self).__init__() | |||
| self.model = language_model | |||
| self.take_softmax = take_softmax | |||
| self.length_penalty = length_penalty | |||
| def forward(self, input_ids, position_ids, attention_mask, dec_input_ids, | |||
| dec_position_ids, dec_attention_mask, dec_target_ids, | |||
| dec_logit_mask): | |||
| # encoder | |||
| outputs, *mems = self.model( | |||
| input_ids, | |||
| position_ids, | |||
| attention_mask, | |||
| return_memory=True, | |||
| detach_memory=False) | |||
| batch_size, num_choices, max_dec_len = dec_input_ids.size() | |||
| max_enc_len = input_ids.size(-1) | |||
| enc_mems = [] | |||
| for hidden in mems: | |||
| hidden = hidden.unsqueeze(1).expand(-1, num_choices, -1, | |||
| -1).reshape( | |||
| batch_size * num_choices, | |||
| *hidden.size()[1:]) | |||
| enc_mems.append(hidden) | |||
| def build_dec_mask_matrix(seq_length, sep, memory_length=0): | |||
| m = enc_mems[0].new_ones((1, seq_length, seq_length)) | |||
| m = torch.tril(m) | |||
| # sep = dec_attention_mask | |||
| ids = torch.arange( | |||
| memory_length, device=sep.device, dtype=sep.dtype).view(1, -1) | |||
| mask = ids < sep.view(-1, 1) # batch * mem | |||
| mask = mask.unsqueeze(1).float().expand(-1, seq_length, -1) | |||
| m = m.expand(batch_size * num_choices, -1, -1) | |||
| m = torch.cat((mask, m), dim=2) | |||
| m = m.unsqueeze(1) | |||
| return m | |||
| dec_input_ids = dec_input_ids.reshape(-1, max_dec_len) | |||
| dec_position_ids = dec_position_ids.reshape( | |||
| -1, | |||
| *dec_position_ids.size()[2:]) | |||
| # dec_attention_mask = dec_attention_mask.reshape(-1, *dec_attention_mask.size()[2:]).unsqueeze(1) | |||
| dec_attention_mask = build_dec_mask_matrix( | |||
| max_dec_len, dec_attention_mask.reshape(-1), max_enc_len) | |||
| dec_target_ids = dec_target_ids.reshape(-1, dec_target_ids.size(-1)) | |||
| dec_logit_mask = dec_logit_mask.reshape(-1, dec_logit_mask.size(-1)) | |||
| outputs, *mems = self.model(dec_input_ids, dec_position_ids, | |||
| dec_attention_mask, *enc_mems) | |||
| if self.take_softmax: | |||
| outputs = torch.nn.functional.log_softmax(outputs, dim=-1) | |||
| batch_ids = torch.arange( | |||
| dec_target_ids.size(0), | |||
| dtype=torch.long, | |||
| device=dec_target_ids.device) | |||
| batch_ids = batch_ids.unsqueeze(1).expand_as(dec_target_ids) | |||
| seq_ids = torch.arange( | |||
| dec_target_ids.size(-1), | |||
| dtype=torch.long, | |||
| device=dec_target_ids.device) | |||
| seq_ids = seq_ids.unsqueeze(0).expand_as(dec_target_ids) | |||
| logits = outputs[batch_ids, seq_ids, dec_target_ids] | |||
| logits = (logits * dec_logit_mask).sum(dim=1) | |||
| if self.length_penalty > 0.0: | |||
| logits = logits / dec_logit_mask.sum(dim=1)**self.length_penalty | |||
| if num_choices is not None: | |||
| logits = logits.view(-1, num_choices) | |||
| return (logits, *mems) | |||
| class GLMForSingleTokenCloze(torch.nn.Module): | |||
| def __init__(self, language_model, take_softmax=False): | |||
| super().__init__() | |||
| self.model = language_model | |||
| self.take_softmax = take_softmax | |||
| def state_dict(self, destination=None, prefix='', keep_vars=False): | |||
| # [h.remove() for h in self.hook_handles] | |||
| sd = self.model.state_dict(destination, prefix, keep_vars) | |||
| return sd | |||
| def load_state_dict(self, state_dict, strict=True): | |||
| return self.model.load_state_dict(state_dict, strict=strict) | |||
| def named_parameters(self, prefix: str = '', recurse: bool = True): | |||
| return self.model.named_parameters(prefix=prefix, recurse=recurse) | |||
| def forward(self, | |||
| input_ids, | |||
| position_ids, | |||
| attention_mask, | |||
| target_ids=None, | |||
| logit_mask=None, | |||
| prompt_pos=None): | |||
| if target_ids is None: | |||
| return self.model(input_ids, position_ids, attention_mask) | |||
| assert len(input_ids.shape) == 2 | |||
| outputs, *mems = self.model( | |||
| input_ids, position_ids, attention_mask, prompt_pos=prompt_pos) | |||
| batch_ids = torch.arange( | |||
| outputs.size(0), | |||
| dtype=attention_mask.dtype, | |||
| device=attention_mask.device) | |||
| target_logits = outputs[batch_ids, attention_mask] | |||
| if self.take_softmax: | |||
| target_prob = torch.nn.functional.log_softmax( | |||
| target_logits, dim=-1) | |||
| else: | |||
| target_prob = target_logits | |||
| batch_ids = batch_ids.unsqueeze(1).expand_as(target_ids) | |||
| output = target_prob[batch_ids, target_ids] | |||
| return (output, target_logits, *mems) | |||
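The cloze heads above pick per-example logits with advanced indexing (`outputs[batch_ids, positions]`, then `[batch_ids, target_ids]`). A toy, dependency-free illustration of that gather pattern:

    import torch

    # outputs: (batch, seq_len, vocab) toy logits; positions: one index per example
    outputs = torch.arange(2 * 4 * 3, dtype=torch.float).reshape(2, 4, 3)
    positions = torch.tensor([1, 3])               # e.g. the mask position in each example
    batch_ids = torch.arange(outputs.size(0))

    target_logits = outputs[batch_ids, positions]  # shape (2, 3): one vocab row per example
    print(target_logits)

    # score a few candidate tokens per example, as the single-token cloze head does
    target_ids = torch.tensor([[0, 2], [1, 1]])
    scores = target_logits[batch_ids.unsqueeze(1).expand_as(target_ids), target_ids]
    print(scores)                                  # shape (2, 2)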
| class GLMForSequenceClassification(torch.nn.Module): | |||
| def __init__(self, | |||
| language_model, | |||
| hidden_size, | |||
| hidden_dropout, | |||
| pool_token, | |||
| num_class=1): | |||
| super().__init__() | |||
| self.pool_token = pool_token | |||
| self.model = language_model | |||
| self.num_class = num_class | |||
| # Multi-choice head. | |||
| self.pool_layer = torch.nn.Linear(hidden_size, hidden_size) | |||
| self.multichoice_dropout = torch.nn.Dropout(hidden_dropout) | |||
| self.multichoice_head = torch.nn.Linear(hidden_size, num_class) | |||
| def forward(self, input_ids, position_ids, attention_mask): | |||
| num_choices = None | |||
| if len(input_ids.shape) == 3: | |||
| assert self.num_class == 1 | |||
| batch_size, num_choices = input_ids.shape[:2] | |||
| input_ids = input_ids.reshape(-1, input_ids.size(-1)) | |||
| attention_mask = attention_mask.reshape(-1, | |||
| *attention_mask.size()[2:]) | |||
| position_ids = position_ids.reshape(-1, *position_ids.size()[2:]) | |||
| outputs, *mems = self.model(input_ids, position_ids, attention_mask) | |||
| if self.pool_token == 'start': | |||
| output = outputs[torch.arange( | |||
| outputs.size(0), | |||
| dtype=attention_mask.dtype, | |||
| device=attention_mask.device), attention_mask] | |||
| elif self.pool_token == 'pad': | |||
| output = outputs[torch.arange( | |||
| outputs.size(0), | |||
| dtype=attention_mask.dtype, | |||
| device=attention_mask.device), attention_mask - 1] | |||
| elif self.pool_token == 'cls': | |||
| output = outputs[:, 0] | |||
| else: | |||
| raise NotImplementedError | |||
| output = torch.tanh(self.pool_layer(output)) | |||
| multichoice_output = self.multichoice_dropout(output) | |||
| logits = self.multichoice_head(multichoice_output) | |||
| if num_choices is not None: | |||
| logits = logits.view(-1, num_choices) | |||
| return (logits, *mems) | |||
| @@ -0,0 +1,245 @@ | |||
| # Modified by Zhipu.AI | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """GPT-2 model.""" | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| from modelscope.models.nlp.mglm import mpu | |||
| from modelscope.models.nlp.mglm.model.prompt import PromptSpell | |||
| from modelscope.models.nlp.mglm.utils import print_rank_0 | |||
| def init_method_normal(std=0.02): | |||
| """Init method based on normal distribution. | |||
| This is only used for embeddings. The transformer has its | |||
| own initializer. | |||
| """ | |||
| def init_(tensor): | |||
| return torch.nn.init.normal_(tensor, mean=0.0, std=std) | |||
| return init_ | |||
| class GLMModel(torch.nn.Module): | |||
| """GLM Language model. | |||
| The output of the forward method is the logits (parallel or | |||
| serial depending on the `parallel_output` flag). | |||
| """ | |||
| def __init__( | |||
| self, | |||
| num_layers, | |||
| vocab_size, | |||
| hidden_size, | |||
| num_attention_heads, | |||
| embedding_dropout_prob, | |||
| attention_dropout_prob, | |||
| output_dropout_prob, | |||
| max_sequence_length, | |||
| max_memory_length, | |||
| checkpoint_activations, | |||
| checkpoint_num_layers=1, | |||
| parallel_output=True, | |||
| relative_encoding=False, | |||
| block_position_encoding=False, | |||
| output_predict=True, | |||
| spell_length=None, | |||
| spell_func='lstm', | |||
| attention_scale=1.0, | |||
| ): | |||
| super(GLMModel, self).__init__() | |||
| self.parallel_output = parallel_output | |||
| self.output_predict = output_predict | |||
| self.hidden_size = hidden_size | |||
| init_method = init_method_normal(std=0.02) | |||
| # Word embeddings (parallel). | |||
| self.word_embeddings = mpu.VocabParallelEmbedding( | |||
| vocab_size, hidden_size, init_method=init_method) | |||
| # Transformer | |||
| self.transformer = mpu.GPT2ParallelTransformer( | |||
| num_layers, | |||
| hidden_size, | |||
| num_attention_heads, | |||
| max_sequence_length, | |||
| max_memory_length, | |||
| embedding_dropout_prob, | |||
| attention_dropout_prob, | |||
| output_dropout_prob, | |||
| checkpoint_activations, | |||
| checkpoint_num_layers, | |||
| attention_scale=attention_scale, | |||
| relative_encoding=relative_encoding, | |||
| block_position_encoding=block_position_encoding) | |||
| if spell_length is not None: | |||
| self.prompt_spell = PromptSpell(spell_length, self.hidden_size, | |||
| spell_func) | |||
| def freeze_transformer(self, tune_prefix_layers=None): | |||
| log_str = 'Freeze transformer' | |||
| self.word_embeddings.requires_grad_(False) | |||
| self.transformer.requires_grad_(False) | |||
| if tune_prefix_layers is not None: | |||
| log_str += f' tune {tune_prefix_layers} prefix layers' | |||
| for i in range(tune_prefix_layers): | |||
| self.transformer.layers[i].requires_grad_(True) | |||
| print_rank_0(log_str) | |||
| def forward(self, | |||
| input_ids, | |||
| position_ids, | |||
| attention_mask, | |||
| *mems, | |||
| return_memory=False, | |||
| detach_memory=True, | |||
| prompt_pos=None): | |||
| # Embeddings. | |||
| batch_size = input_ids.size(0) | |||
| words_embeddings = self.word_embeddings(input_ids) | |||
| embeddings = words_embeddings | |||
| if prompt_pos is not None: | |||
| embeddings = embeddings.clone() | |||
| prompt_embeds = self.prompt_spell() | |||
| batch_index = torch.arange( | |||
| batch_size, device=input_ids.device).unsqueeze(1) | |||
| embeddings[batch_index, prompt_pos] = prompt_embeds | |||
| # Transformer. | |||
| transformer_output = self.transformer( | |||
| embeddings, | |||
| position_ids, | |||
| attention_mask, | |||
| mems, | |||
| return_memory=return_memory, | |||
| detach_memory=detach_memory) | |||
| logits, hidden_layers = transformer_output | |||
| outputs = hidden_layers | |||
| if self.output_predict: | |||
| # Parallel logits. | |||
| logits_parallel = mpu.copy_to_model_parallel_region(logits) | |||
| logits_parallel = F.linear(logits_parallel, | |||
| self.word_embeddings.weight) | |||
| if self.parallel_output: | |||
| return (logits_parallel, *outputs) | |||
| return (mpu.gather_from_model_parallel_region(logits_parallel), | |||
| *outputs) | |||
| else: | |||
| return (logits, *outputs) | |||
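The output head above ties the softmax projection to the input embedding via `F.linear(hidden, word_embeddings.weight)`. A toy single-device sketch of that weight tying, without the model-parallel wrappers:

    import torch
    import torch.nn.functional as F

    vocab_size, hidden_size = 100, 16
    word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)

    input_ids = torch.randint(0, vocab_size, (2, 5))
    hidden = word_embeddings(input_ids)   # stand-in for the transformer output

    # tied output head: project hidden states back onto the embedding matrix
    logits = F.linear(hidden, word_embeddings.weight)
    print(logits.shape)                   # torch.Size([2, 5, 100])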
| class EncoderDecoder(torch.nn.Module): | |||
| """Seq2Seq Transformer Model | |||
| The output of the forward method is the logits (parallel or serial depending on the `parallel_output` flag). | |||
| """ | |||
| def __init__(self, | |||
| num_layers, | |||
| vocab_size, | |||
| hidden_size, | |||
| num_attention_heads, | |||
| embedding_dropout_prob, | |||
| attention_dropout_prob, | |||
| output_dropout_prob, | |||
| max_sequence_length, | |||
| max_memory_length, | |||
| checkpoint_activations, | |||
| checkpoint_num_layers=1, | |||
| parallel_output=True, | |||
| output_predict=True): | |||
| super(EncoderDecoder, self).__init__() | |||
| self.parallel_output = parallel_output | |||
| self.output_predict = output_predict | |||
| init_method = init_method_normal(std=0.02) | |||
| # Word embeddings (parallel). | |||
| self.word_embeddings = mpu.VocabParallelEmbedding( | |||
| vocab_size, hidden_size, init_method=init_method) | |||
| # Transformer | |||
| self.encoder = mpu.GPT2ParallelTransformer( | |||
| num_layers, hidden_size, num_attention_heads, max_sequence_length, | |||
| max_memory_length, embedding_dropout_prob, attention_dropout_prob, | |||
| output_dropout_prob, checkpoint_activations, checkpoint_num_layers) | |||
| self.decoder = mpu.GPT2ParallelTransformer( | |||
| num_layers, | |||
| hidden_size, | |||
| num_attention_heads, | |||
| max_sequence_length, | |||
| max_memory_length, | |||
| embedding_dropout_prob, | |||
| attention_dropout_prob, | |||
| output_dropout_prob, | |||
| checkpoint_activations, | |||
| checkpoint_num_layers, | |||
| use_decoder_layer=True) | |||
| def forward(self, source_ids, target_ids, source_position_ids, | |||
| target_position_ids, source_mask, target_mask): | |||
| # Embeddings. | |||
| source_embeddings = self.word_embeddings(source_ids) | |||
| target_embeddings = self.word_embeddings(target_ids) | |||
| # Transformer. | |||
| encoder_output, _ = self.encoder(source_embeddings, | |||
| source_position_ids, source_mask) | |||
| decoder_output, _ = self.decoder(target_embeddings, | |||
| target_position_ids, target_mask) | |||
| if self.output_predict: | |||
| # Parallel logits. | |||
| output_parallel = mpu.copy_to_model_parallel_region(decoder_output) | |||
| logits_parallel = F.linear(output_parallel, | |||
| self.word_embeddings.weight) | |||
| if self.parallel_output: | |||
| return (logits_parallel, ) | |||
| return (mpu.gather_from_model_parallel_region(logits_parallel), ) | |||
| else: | |||
| return (decoder_output, ) | |||
| def glm_get_params_for_weight_decay_optimization(module): | |||
| weight_decay_params = {'params': []} | |||
| no_weight_decay_params = {'params': [], 'weight_decay': 0.0} | |||
| for module_ in module.modules(): | |||
| if isinstance(module_, (mpu.LayerNorm, torch.nn.LayerNorm)): | |||
| no_weight_decay_params['params'].extend([ | |||
| p for p in list(module_._parameters.values()) | |||
| if p is not None and p.requires_grad | |||
| ]) | |||
| else: | |||
| weight_decay_params['params'].extend([ | |||
| p for n, p in list(module_._parameters.items()) | |||
| if p is not None and p.requires_grad and n != 'bias' | |||
| ]) | |||
| no_weight_decay_params['params'].extend([ | |||
| p for n, p in list(module_._parameters.items()) | |||
| if p is not None and p.requires_grad and n == 'bias' | |||
| ]) | |||
| return weight_decay_params, no_weight_decay_params | |||
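A self-contained sketch that mirrors the split above (LayerNorm parameters and biases get zero weight decay) on a plain torch module and feeds both groups to an optimizer; it re-implements the check instead of importing the mpu-dependent helper:

    import torch

    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8), torch.nn.LayerNorm(8), torch.nn.Linear(8, 2))

    decay = {'params': []}
    no_decay = {'params': [], 'weight_decay': 0.0}
    for module in model.modules():
        if isinstance(module, torch.nn.LayerNorm):
            no_decay['params'].extend(p for p in module.parameters() if p.requires_grad)
        else:
            for name, p in module._parameters.items():
                if p is None or not p.requires_grad:
                    continue
                (no_decay if name == 'bias' else decay)['params'].append(p)

    optimizer = torch.optim.AdamW([decay, no_decay], lr=1e-4, weight_decay=0.01)
    print(len(decay['params']), len(no_decay['params']))  # 2 weight tensors, 4 no-decay tensors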
| @@ -0,0 +1,59 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import random | |||
| import torch | |||
| class PromptSpell(torch.nn.Module): | |||
| def __init__(self, spell_length, hidden_size, spell_func): | |||
| super(PromptSpell, self).__init__() | |||
| self.spell_length = spell_length | |||
| self.hidden_size = hidden_size | |||
| self.spell_embeddings = torch.nn.Embedding(self.spell_length, | |||
| self.hidden_size) | |||
| self.spell_func = spell_func | |||
| if self.spell_func == 'lstm': | |||
| self.lstm_head = torch.nn.LSTM( | |||
| input_size=self.hidden_size, | |||
| hidden_size=self.hidden_size, | |||
| num_layers=2, | |||
| # dropout=self.lstm_dropout, | |||
| bidirectional=True, | |||
| batch_first=True) # .to(torch.device("cuda")) | |||
| self.mlp_head = torch.nn.Sequential( | |||
| torch.nn.Linear(2 * self.hidden_size, self.hidden_size), | |||
| torch.nn.ReLU(), | |||
| torch.nn.Linear(self.hidden_size, self.hidden_size)) | |||
| elif self.spell_func == 'mlp': | |||
| self.mlp_head = torch.nn.Sequential( | |||
| torch.nn.Linear(self.hidden_size, self.hidden_size), | |||
| torch.nn.ReLU(), | |||
| torch.nn.Linear(self.hidden_size, self.hidden_size)) | |||
| elif self.spell_func != 'none': | |||
| raise NotImplementedError('Prompt function ' + self.spell_func) | |||
| def init_embedding(self, word_embeddings=None, task_tokens=None): | |||
| num_words = 5000 | |||
| with torch.no_grad(): | |||
| for i in range(self.spell_length): | |||
| rand_token = random.randrange(num_words) | |||
| if task_tokens is None: | |||
| target_embedding = word_embeddings[rand_token] | |||
| else: | |||
| word_embedding = word_embeddings[rand_token] | |||
| task_token = random.choice(task_tokens) | |||
| task_embedding = word_embeddings[task_token] | |||
| ratio = random.random() | |||
| target_embedding = word_embedding * ratio + task_embedding * ( | |||
| 1 - ratio) | |||
| self.spell_embeddings.weight.data[i] = target_embedding | |||
| def forward(self): | |||
| prompt_embeds = self.spell_embeddings.weight.unsqueeze(0) | |||
| if self.spell_func == 'lstm': | |||
| prompt_embeds = self.lstm_head(prompt_embeds)[0] | |||
| if self.spell_func == 'lstm' or self.spell_func == 'mlp': | |||
| prompt_embeds = self.mlp_head(prompt_embeds) | |||
| return prompt_embeds | |||
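`PromptSpell` only depends on torch, so it can be exercised directly; the import below is the path used by `modeling_glm` above:

    from modelscope.models.nlp.mglm.model.prompt import PromptSpell

    spell = PromptSpell(spell_length=4, hidden_size=8, spell_func='lstm')
    prompt_embeds = spell()        # forward() takes no inputs
    print(prompt_embeds.shape)     # torch.Size([1, 4, 8]): one embedding per prompt slot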
| @@ -0,0 +1,37 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Model parallel utility interface.""" | |||
| from .cross_entropy import vocab_parallel_cross_entropy | |||
| from .data import broadcast_data | |||
| from .grads import clip_grad_norm | |||
| from .initialize import (destroy_model_parallel, get_data_parallel_group, | |||
| get_data_parallel_rank, get_data_parallel_world_size, | |||
| get_model_parallel_group, get_model_parallel_rank, | |||
| get_model_parallel_src_rank, | |||
| get_model_parallel_world_size, | |||
| initialize_model_parallel, | |||
| model_parallel_is_initialized) | |||
| from .layers import (ColumnParallelLinear, ParallelEmbedding, | |||
| RowParallelLinear, VocabParallelEmbedding) | |||
| from .mappings import (copy_to_model_parallel_region, | |||
| gather_from_model_parallel_region, | |||
| reduce_from_model_parallel_region, | |||
| scatter_to_model_parallel_region) | |||
| from .random import (checkpoint, get_cuda_rng_tracker, | |||
| model_parallel_cuda_manual_seed, | |||
| partition_activations_in_checkpoint) | |||
| from .transformer import (BertParallelSelfAttention, | |||
| BertParallelTransformerLayer, | |||
| GPT2ParallelTransformer, LayerNorm) | |||
| @@ -0,0 +1,110 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| from .initialize import (get_model_parallel_group, get_model_parallel_rank, | |||
| get_model_parallel_world_size) | |||
| from .utils import VocabUtility | |||
| class _VocabParallelCrossEntropy(torch.autograd.Function): | |||
| @staticmethod | |||
| def forward(ctx, vocab_parallel_logits, target): | |||
| # Copy so the input remains unchanged. | |||
| logits = vocab_parallel_logits.clone() | |||
| # Maximum value along vocab dimension across all GPUs. | |||
| logits_max = torch.max(logits, dim=-1)[0] | |||
| torch.distributed.all_reduce( | |||
| logits_max, | |||
| op=torch.distributed.ReduceOp.MAX, | |||
| group=get_model_parallel_group()) | |||
| # Subtract the maximum value. | |||
| logits.sub_(logits_max.unsqueeze(dim=-1)) | |||
| # Sum of exponential of logits along vocab dimension across all GPUs. | |||
| exp_logits = logits.exp() | |||
| sum_exp_logits = exp_logits.sum(dim=-1) | |||
| torch.distributed.all_reduce( | |||
| sum_exp_logits, | |||
| op=torch.distributed.ReduceOp.SUM, | |||
| group=get_model_parallel_group()) | |||
| # Get the partition's vocab indices | |||
| get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size | |||
| partition_vocab_size = vocab_parallel_logits.size()[-1] | |||
| rank = get_model_parallel_rank() | |||
| world_size = get_model_parallel_world_size() | |||
| vocab_start_index, vocab_end_index = get_vocab_range( | |||
| partition_vocab_size, rank, world_size) | |||
| # Create a mask of valid vocab ids (1 means it needs to be masked). | |||
| target_mask = (target < vocab_start_index) | ( | |||
| target >= vocab_end_index) | |||
| masked_target = target.clone() - vocab_start_index | |||
| masked_target[target_mask] = 0 | |||
| # Get predicted-logits = logits[target]. | |||
| # For simplicity, we convert logits to a 2-D tensor with size | |||
| # [*, partition-vocab-size] and target to a 1-D tensor of size [*]. | |||
| logits_2d = logits.view(-1, partition_vocab_size) | |||
| masked_target_1d = masked_target.view(-1) | |||
| arange_1d = torch.arange( | |||
| start=0, end=logits_2d.size()[0], device=logits_2d.device) | |||
| predicted_logits_1d = logits_2d[arange_1d, masked_target_1d] | |||
| predicted_logits = predicted_logits_1d.view_as(target) | |||
| predicted_logits[target_mask] = 0.0 | |||
| # All reduce is needed to get the chunks from other GPUs. | |||
| torch.distributed.all_reduce( | |||
| predicted_logits, | |||
| op=torch.distributed.ReduceOp.SUM, | |||
| group=get_model_parallel_group()) | |||
| # Loss = log(sum(exp(logits))) - predicted-logit. | |||
| loss = torch.log(sum_exp_logits) - predicted_logits | |||
| # Store softmax, target-mask and masked-target for backward pass. | |||
| exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) | |||
| ctx.save_for_backward(exp_logits, target_mask, masked_target_1d) | |||
| return loss | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| # Retrieve tensors from the forward path. | |||
| softmax, target_mask, masked_target_1d = ctx.saved_tensors | |||
| # All the inputs have softmax as their gradient. | |||
| grad_input = softmax | |||
| # For simplicity, work with the 2D gradient. | |||
| partition_vocab_size = softmax.size()[-1] | |||
| grad_2d = grad_input.view(-1, partition_vocab_size) | |||
| # Add the gradient from matching classes. | |||
| arange_1d = torch.arange( | |||
| start=0, end=grad_2d.size()[0], device=grad_2d.device) | |||
| grad_2d[arange_1d, | |||
| masked_target_1d] -= (1.0 - target_mask.view(-1).float()) | |||
| # Finally elementwise multiplication with the output gradients. | |||
| grad_input.mul_(grad_output.unsqueeze(dim=-1)) | |||
| return grad_input, None | |||
| def vocab_parallel_cross_entropy(vocab_parallel_logits, target): | |||
| """Helper function for the cross entropy.""" | |||
| return _VocabParallelCrossEntropy.apply(vocab_parallel_logits, target) | |||
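A minimal usage sketch for the helper above, modeled on the cross entropy unit test later in this PR. It assumes `torch.distributed` and `mpu.initialize_model_parallel(...)` have already been called, that the package is importable as `mpu` (as the tests below do), and that the vocabulary size divides evenly across the model parallel ranks; all shapes are illustrative.

    import torch
    import mpu
    from mpu.cross_entropy import vocab_parallel_cross_entropy

    batch_size, seq_length, vocab_size = 4, 16, 32000
    logits = torch.randn(batch_size, seq_length, vocab_size,
                         device='cuda', requires_grad=True)
    # Keep only this rank's shard of the vocabulary dimension.
    logits_parallel = mpu.scatter_to_model_parallel_region(logits)
    target = torch.randint(0, vocab_size, (batch_size, seq_length), device='cuda')
    # Per-token losses, equivalent to F.cross_entropy(..., reduction='none').
    loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
    loss.backward()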
| @@ -0,0 +1,117 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| from .initialize import (get_model_parallel_group, get_model_parallel_rank, | |||
| get_model_parallel_src_rank) | |||
| _MAX_DATA_DIM = 5 | |||
| def _check_data_types(keys, data, target_dtype): | |||
| """Check that all the keys have the same target data type.""" | |||
| for key in keys: | |||
| assert data[key].dtype == target_dtype, '{} has data type {} which '\ | |||
| 'is different than {}'.format(key, data[key].dtype, target_dtype) | |||
| def _build_key_size_numel_dictionaries(keys, data): | |||
| """Build the size on rank 0 and broadcast.""" | |||
| max_dim = _MAX_DATA_DIM | |||
| sizes = [0 for _ in range(max_dim) for _ in keys] | |||
| # Pack the sizes on rank zero. | |||
| if get_model_parallel_rank() == 0: | |||
| offset = 0 | |||
| for key in keys: | |||
| assert data[key].dim( | |||
| ) < max_dim, 'you should increase MAX_DATA_DIM' | |||
| size = data[key].size() | |||
| for i, s in enumerate(size): | |||
| sizes[i + offset] = s | |||
| offset += max_dim | |||
| # Move to GPU and broadcast. | |||
| sizes_cuda = torch.cuda.LongTensor(sizes) | |||
| torch.distributed.broadcast( | |||
| sizes_cuda, | |||
| get_model_parallel_src_rank(), | |||
| group=get_model_parallel_group()) | |||
| # Move back to cpu and unpack. | |||
| sizes_cpu = sizes_cuda.cpu() | |||
| key_size = {} | |||
| key_numel = {} | |||
| total_numel = 0 | |||
| offset = 0 | |||
| for key in keys: | |||
| i = 0 | |||
| size = [] | |||
| numel = 1 | |||
| while sizes_cpu[offset + i] > 0: | |||
| this_size = sizes_cpu[offset + i] | |||
| size.append(this_size) | |||
| numel *= this_size | |||
| i += 1 | |||
| key_size[key] = size | |||
| key_numel[key] = numel | |||
| total_numel += numel | |||
| offset += max_dim | |||
| return key_size, key_numel, total_numel | |||
| def broadcast_data(keys, data, datatype): | |||
| """Broadcast data from rank zero of each model parallel group to the | |||
| members of the same model parallel group. | |||
| Arguments: | |||
| keys: list of keys in the data dictionary to be broadcast | |||
| data: data dictionary of string keys and cpu tensor values. | |||
| datatype: torch data type of all tensors in data associated | |||
| with keys. | |||
| """ | |||
| # Build (key, size) and (key, number of elements) dictionaries along | |||
| # with the total number of elements on all ranks. | |||
| key_size, key_numel, total_numel = _build_key_size_numel_dictionaries( | |||
| keys, data) | |||
| # Pack on rank zero. | |||
| if get_model_parallel_rank() == 0: | |||
| # Check that all keys have the same data type. | |||
| _check_data_types(keys, data, datatype) | |||
| # Flatten the data associated with the keys | |||
| flatten_data = torch.cat( | |||
| [data[key].contiguous().view(-1) for key in keys], dim=0).cuda() | |||
| else: | |||
| flatten_data = torch.empty( | |||
| total_numel, device=torch.cuda.current_device(), dtype=datatype) | |||
| # Broadcast. | |||
| torch.distributed.broadcast( | |||
| flatten_data, | |||
| get_model_parallel_src_rank(), | |||
| group=get_model_parallel_group()) | |||
| # Unpack | |||
| output = {} | |||
| offset = 0 | |||
| for key in keys: | |||
| size = key_size[key] | |||
| numel = key_numel[key] | |||
| output[key] = flatten_data.narrow(0, offset, numel).view(size) | |||
| offset += numel | |||
| return output | |||
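A sketch of a typical call site for `broadcast_data`, assuming the distributed and model parallel groups are initialized and the package is importable as `mpu`; the key names and shapes are made up for illustration.

    import torch
    import mpu
    from mpu import data as data_utils

    keys = ['text', 'loss_mask']
    if mpu.get_model_parallel_rank() == 0:
        # Only the source rank of each model parallel group loads the CPU batch.
        data = {
            'text': torch.randint(0, 1000, (8, 512)),
            'loss_mask': torch.randint(0, 2, (8, 512)),
        }
    else:
        data = None
    # Every rank in the group receives identical CUDA tensors back.
    batch = data_utils.broadcast_data(keys, data, torch.int64)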
| @@ -0,0 +1,72 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # Parts of the code here are adapted from PyTorch | |||
| # repo: https://github.com/pytorch/pytorch | |||
| import torch | |||
| from torch._six import inf | |||
| from .initialize import get_model_parallel_group, get_model_parallel_rank | |||
| def clip_grad_norm(parameters, max_norm, norm_type=2): | |||
| """Clips gradient norm of an iterable of parameters. | |||
| This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and | |||
| added functionality to handle model parallel parameters. Note that | |||
| the gradients are modified in place. | |||
| Arguments: | |||
| parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a | |||
| single Tensor that will have gradients normalized | |||
| max_norm (float or int): max norm of the gradients | |||
| norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for | |||
| infinity norm. | |||
| Returns: | |||
| Total norm of the parameters (viewed as a single vector). | |||
| """ | |||
| if isinstance(parameters, torch.Tensor): | |||
| parameters = [parameters] | |||
| parameters = list(filter(lambda p: p.grad is not None, parameters)) | |||
| max_norm = float(max_norm) | |||
| norm_type = float(norm_type) | |||
| if norm_type == inf: | |||
| total_norm = max(p.grad.data.abs().max() for p in parameters) | |||
| total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) | |||
| # Take max across all GPUs. | |||
| torch.distributed.all_reduce( | |||
| total_norm_cuda, | |||
| op=torch.distributed.ReduceOp.MAX, | |||
| group=get_model_parallel_group()) | |||
| total_norm = total_norm_cuda[0].item() | |||
| else: | |||
| total_norm = 0 | |||
| for p in parameters: | |||
| if p.model_parallel or (get_model_parallel_rank() == 0): | |||
| param_norm = p.grad.data.norm(norm_type) | |||
| total_norm += param_norm.item()**norm_type | |||
| # Sum across all model parallel GPUs. | |||
| total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) | |||
| torch.distributed.all_reduce( | |||
| total_norm_cuda, | |||
| op=torch.distributed.ReduceOp.SUM, | |||
| group=get_model_parallel_group()) | |||
| total_norm = total_norm_cuda[0].item()**(1. / norm_type) | |||
| clip_coef = max_norm / (total_norm + 1e-6) | |||
| if clip_coef < 1: | |||
| for p in parameters: | |||
| p.grad.data.mul_(clip_coef) | |||
| return total_norm | |||
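A hedged sketch of using this model-parallel-aware clipping in a training step. The `mpu.grads` import path is an assumption based on the usual Megatron-style layout of this package, and the parallel layer below is used only so that every parameter carries the `model_parallel` attribute the function expects.

    import torch
    import mpu
    from mpu.layers import ColumnParallelLinear
    from mpu.grads import clip_grad_norm  # module name assumed

    # Assumes torch.distributed and mpu.initialize_model_parallel(...) already ran.
    layer = ColumnParallelLinear(1024, 4096).cuda()
    optimizer = torch.optim.SGD(layer.parameters(), lr=1e-3)

    x = torch.randn(8, 1024, device='cuda')
    loss = layer(x).pow(2).mean()
    loss.backward()
    # Same call shape as torch.nn.utils.clip_grad_norm_, but partial norms are
    # summed across the model parallel group before the clip coefficient is applied.
    grad_norm = clip_grad_norm(layer.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()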
| @@ -0,0 +1,130 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Model and data parallel groups.""" | |||
| import torch | |||
| from .utils import ensure_divisibility | |||
| # Model parallel group that the current rank belongs to. | |||
| _MODEL_PARALLEL_GROUP = None | |||
| # Data parallel group that the current rank belongs to. | |||
| _DATA_PARALLEL_GROUP = None | |||
| def initialize_model_parallel(model_parallel_size_): | |||
| """ | |||
| Initialize model data parallel groups. | |||
| Arguments: | |||
| model_parallel_size: number of GPUs used to parallelize model. | |||
| Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we | |||
| use 2 GPUs to parallelize the model. The present function will | |||
| create 4 model parallel groups and 2 data parallel groups as: | |||
| 4 model parallel groups: | |||
| [g0, g1], [g2, g3], [g4, g5], [g6, g7] | |||
| 2 data parallel groups: | |||
| [g0, g2, g4, g6], [g1, g3, g5, g7] | |||
| Note that for efficiency, the caller should make sure adjacent ranks | |||
| are on the same DGX box. For example if we are using 2 DGX-1 boxes | |||
| with a total of 16 GPUs, rank 0 to 7 belong to the first box and | |||
| ranks 8 to 15 belong to the second box. | |||
| """ | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> initializing model parallel with size {}'.format( | |||
| model_parallel_size_)) | |||
| # Get world size and rank. Ensure some consistencies. | |||
| assert torch.distributed.is_initialized() | |||
| world_size = torch.distributed.get_world_size() | |||
| model_parallel_size = min(model_parallel_size_, world_size) | |||
| ensure_divisibility(world_size, model_parallel_size) | |||
| rank = torch.distributed.get_rank() | |||
| # Build the data parallel groups. | |||
| global _DATA_PARALLEL_GROUP | |||
| assert _DATA_PARALLEL_GROUP is None, \ | |||
| 'data parallel group is already initialized' | |||
| for i in range(model_parallel_size): | |||
| ranks = range(i, world_size, model_parallel_size) | |||
| group = torch.distributed.new_group(ranks) | |||
| if i == (rank % model_parallel_size): | |||
| _DATA_PARALLEL_GROUP = group | |||
| # Build the model parallel groups. | |||
| global _MODEL_PARALLEL_GROUP | |||
| assert _MODEL_PARALLEL_GROUP is None, \ | |||
| 'model parallel group is already initialized' | |||
| for i in range(world_size // model_parallel_size): | |||
| ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) | |||
| group = torch.distributed.new_group(ranks) | |||
| if i == (rank // model_parallel_size): | |||
| _MODEL_PARALLEL_GROUP = group | |||
| def model_parallel_is_initialized(): | |||
| """Check if model and data parallel groups are initialized.""" | |||
| if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None: | |||
| return False | |||
| return True | |||
| def get_model_parallel_group(): | |||
| """Get the model parallel group the caller rank belongs to.""" | |||
| assert _MODEL_PARALLEL_GROUP is not None, \ | |||
| 'model parallel group is not initialized' | |||
| return _MODEL_PARALLEL_GROUP | |||
| def get_data_parallel_group(): | |||
| """Get the data parallel group the caller rank belongs to.""" | |||
| assert _DATA_PARALLEL_GROUP is not None, \ | |||
| 'data parallel group is not initialized' | |||
| return _DATA_PARALLEL_GROUP | |||
| def get_model_parallel_world_size(): | |||
| """Return world size for the model parallel group.""" | |||
| return torch.distributed.get_world_size(group=get_model_parallel_group()) | |||
| def get_model_parallel_rank(): | |||
| """Return my rank for the model parallel group.""" | |||
| return torch.distributed.get_rank(group=get_model_parallel_group()) | |||
| def get_model_parallel_src_rank(): | |||
| """Calculate the global rank corresponding to a local rank zeor | |||
| in the model parallel group.""" | |||
| global_rank = torch.distributed.get_rank() | |||
| local_world_size = get_model_parallel_world_size() | |||
| return (global_rank // local_world_size) * local_world_size | |||
| def get_data_parallel_world_size(): | |||
| """Return world size for the data parallel group.""" | |||
| return torch.distributed.get_world_size(group=get_data_parallel_group()) | |||
| def get_data_parallel_rank(): | |||
| """Return my rank for the data parallel group.""" | |||
| return torch.distributed.get_rank(group=get_data_parallel_group()) | |||
| def destroy_model_parallel(): | |||
| """Set the groups to none.""" | |||
| global _MODEL_PARALLEL_GROUP | |||
| _MODEL_PARALLEL_GROUP = None | |||
| global _DATA_PARALLEL_GROUP | |||
| _DATA_PARALLEL_GROUP = None | |||
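A minimal initialization sketch for the group management above, assuming the processes were started by a standard PyTorch launcher (torchrun / torch.distributed.launch) that exports RANK, WORLD_SIZE, LOCAL_RANK, MASTER_ADDR and MASTER_PORT, and that the world size is divisible by the chosen model parallel size.

    import os
    import torch
    import mpu

    torch.distributed.init_process_group(backend='nccl')
    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', '0')))

    mpu.initialize_model_parallel(2)   # 2-way model parallelism
    print('model parallel rank', mpu.get_model_parallel_rank(),
          '| data parallel rank', mpu.get_data_parallel_rank())

    mpu.destroy_model_parallel()       # reset the groups, e.g. between tests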
| @@ -0,0 +1,357 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # Parts of the code here are adapted from PyTorch | |||
| # repo: https://github.com/pytorch/pytorch | |||
| import math | |||
| import torch | |||
| import torch.nn.functional as F | |||
| import torch.nn.init as init | |||
| from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm | |||
| from torch.nn.parameter import Parameter | |||
| from .initialize import get_model_parallel_rank, get_model_parallel_world_size | |||
| from .mappings import (copy_to_model_parallel_region, | |||
| gather_from_model_parallel_region, | |||
| reduce_from_model_parallel_region, | |||
| scatter_to_model_parallel_region) | |||
| from .random import get_cuda_rng_tracker | |||
| from .utils import VocabUtility, divide, split_tensor_along_last_dim | |||
| def _initialize_affine_weight(weight, | |||
| output_size, | |||
| input_size, | |||
| per_partition_size, | |||
| partition_dim, | |||
| init_method, | |||
| stride=1, | |||
| return_master_weight=False): | |||
| """Initialize affine weight for model parallel. | |||
| Build the master weight on all processes and scatter | |||
| the relevant chunk.""" | |||
| # If we only use 1 process for model parallelism, bypass scatter. | |||
| world_size = get_model_parallel_world_size() | |||
| if world_size == 1: | |||
| init_method(weight) | |||
| if return_master_weight: | |||
| return weight | |||
| return None | |||
| # Initialize master weight | |||
| master_weight = torch.empty( | |||
| output_size, input_size, dtype=weight.dtype, requires_grad=False) | |||
| init_method(master_weight) | |||
| # Split and copy | |||
| per_partition_per_stride_size = divide(per_partition_size, stride) | |||
| weight_list = torch.split( | |||
| master_weight, per_partition_per_stride_size, dim=partition_dim) | |||
| rank = get_model_parallel_rank() | |||
| my_weight_list = weight_list[rank::world_size] | |||
| with torch.no_grad(): | |||
| torch.cat(my_weight_list, dim=partition_dim, out=weight) | |||
| if return_master_weight: | |||
| return master_weight | |||
| return None | |||
| class VocabParallelEmbedding(torch.nn.Module): | |||
| """Embedding parallelized in the vocabulary dimension. | |||
| This is mainly adapted from torch.nn.Embedding and all the default | |||
| values are kept. | |||
| Arguments: | |||
| num_embeddings: vocabulary size. | |||
| embedding_dim: size of hidden state. | |||
| init_method: method to initialize weights. | |||
| """ | |||
| def __init__(self, | |||
| num_embeddings, | |||
| embedding_dim, | |||
| init_method=init.xavier_normal_): | |||
| super(VocabParallelEmbedding, self).__init__() | |||
| # Keep the input dimensions. | |||
| self.num_embeddings = num_embeddings | |||
| self.embedding_dim = embedding_dim | |||
| # Set the defaults for compatibility. | |||
| self.padding_idx = None | |||
| self.max_norm = None | |||
| self.norm_type = 2. | |||
| self.scale_grad_by_freq = False | |||
| self.sparse = False | |||
| self._weight = None | |||
| # Divide the weight matrix along the vocabulary dimension. | |||
| self.vocab_start_index, self.vocab_end_index = \ | |||
| VocabUtility.vocab_range_from_global_vocab_size( | |||
| self.num_embeddings, get_model_parallel_rank(), | |||
| get_model_parallel_world_size()) | |||
| self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index # noqa | |||
| # Allocate weights. | |||
| self.weight = Parameter( | |||
| torch.Tensor(self.num_embeddings_per_partition, | |||
| self.embedding_dim)) | |||
| self.weight.model_parallel = True | |||
| # And initialize. | |||
| _initialize_affine_weight(self.weight, self.num_embeddings, | |||
| self.embedding_dim, | |||
| self.num_embeddings_per_partition, 0, | |||
| init_method) | |||
| def forward(self, input_): | |||
| # Build the mask. | |||
| input_mask = (input_ < self.vocab_start_index) | \ | |||
| (input_ >= self.vocab_end_index) | |||
| # Mask the input. | |||
| masked_input = input_.clone() - self.vocab_start_index | |||
| masked_input[input_mask] = 0 | |||
| # Get the embeddings. | |||
| output_parallel = F.embedding(masked_input, self.weight, | |||
| self.padding_idx, self.max_norm, | |||
| self.norm_type, self.scale_grad_by_freq, | |||
| self.sparse) | |||
| # Mask the output embedding. | |||
| output_parallel[input_mask, :] = 0.0 | |||
| # Reduce across all the model parallel GPUs. | |||
| output = reduce_from_model_parallel_region(output_parallel) | |||
| return output | |||
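As a usage note, the class above behaves as a drop-in replacement for `torch.nn.Embedding` once the parallel groups exist: each rank stores only its slice of the vocabulary, and the all-reduce in `forward` restores the full activations. A small sketch follows (import path taken from this diff's `layers` module; sizes are illustrative and the vocabulary must divide evenly across ranks):

    import torch
    from mpu.layers import VocabParallelEmbedding

    embedding = VocabParallelEmbedding(num_embeddings=32000, embedding_dim=1024).cuda()
    tokens = torch.randint(0, 32000, (4, 128), device='cuda')
    hidden = embedding(tokens)   # full (4, 128, 1024) output on every rank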
| class ParallelEmbedding(torch.nn.Module): | |||
| """Embedding parallelized in the embedding dimension. | |||
| This is mainly adapted from torch.nn.Embedding and all the default | |||
| values are kept. | |||
| Arguments: | |||
| num_embeddings: vocabulary size. | |||
| embedding_dim: size of hidden state. | |||
| init_method: method to initialize weights. | |||
| """ | |||
| def __init__(self, | |||
| num_embeddings, | |||
| embedding_dim, | |||
| init_method=init.xavier_normal_, | |||
| keep_master_weight_for_test=False): | |||
| super(ParallelEmbedding, self).__init__() | |||
| # Keep the input dimensions. | |||
| self.num_embeddings = num_embeddings | |||
| self.embedding_dim = embedding_dim | |||
| # Set some defaults for compatibility. | |||
| self.padding_idx = None | |||
| self.max_norm = None | |||
| self.norm_type = 2. | |||
| self.scale_grad_by_freq = False | |||
| self.sparse = False | |||
| self._weight = None | |||
| # Divide the weight matrix along the embedding dimension. | |||
| world_size = get_model_parallel_world_size() | |||
| self.embedding_dim_per_partition = divide(self.embedding_dim, | |||
| world_size) | |||
| # Allocate weights. | |||
| self.weight = Parameter( | |||
| torch.Tensor(self.num_embeddings, | |||
| self.embedding_dim_per_partition)) | |||
| self.weight.model_parallel = True | |||
| # And initialize. | |||
| _initialize_affine_weight( | |||
| self.weight, | |||
| self.num_embeddings, | |||
| self.embedding_dim, | |||
| self.embedding_dim_per_partition, | |||
| 1, | |||
| init_method, | |||
| stride=1, | |||
| return_master_weight=False) | |||
| def forward(self, input_): | |||
| input_parallel = copy_to_model_parallel_region(input_) | |||
| output_parallel = F.embedding(input_parallel, self.weight, | |||
| self.padding_idx, self.max_norm, | |||
| self.norm_type, self.scale_grad_by_freq, | |||
| self.sparse) | |||
| output = gather_from_model_parallel_region(output_parallel) | |||
| return output | |||
| class ColumnParallelLinear(torch.nn.Module): | |||
| """Linear layer with column parallelism. | |||
| The linear layer is defined as Y = XA + b. A is parallelized along | |||
| its second dimension as A = [A_1, ..., A_p]. | |||
| Arguments: | |||
| input_size: first dimension of matrix A. | |||
| output_size: second dimension of matrix A. | |||
| bias: If true, add bias | |||
| gather_output: If true, call all-gather on output and make Y available | |||
| to all GPUs, otherwise, every GPU will have its output | |||
| which is Y_i = XA_i | |||
| init_method: method to initialize weights. Note that bias is always set | |||
| to zero. | |||
| stride: For the strided linear layers. | |||
| keep_master_weight_for_test: This was added for testing and should be | |||
| set to False. It returns the master weights | |||
| used for initialization. | |||
| """ | |||
| def __init__(self, | |||
| input_size, | |||
| output_size, | |||
| bias=True, | |||
| gather_output=True, | |||
| init_method=init.xavier_normal_, | |||
| stride=1, | |||
| keep_master_weight_for_test=False): | |||
| super(ColumnParallelLinear, self).__init__() | |||
| # Keep input parameters | |||
| self.input_size = input_size | |||
| self.output_size = output_size | |||
| self.gather_output = gather_output | |||
| # Divide the weight matrix along the last dimension. | |||
| world_size = get_model_parallel_world_size() | |||
| self.output_size_per_partition = divide(output_size, world_size) | |||
| # Parameters. | |||
| # Note: torch.nn.functional.linear performs XA^T + b and as a result | |||
| # we allocate the transpose. | |||
| self.weight = Parameter( | |||
| torch.Tensor(self.output_size_per_partition, self.input_size)) | |||
| self.weight.model_parallel = True | |||
| if bias: | |||
| self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) | |||
| self.bias.model_parallel = True | |||
| # Always initialize bias to zero. | |||
| with torch.no_grad(): | |||
| self.bias.zero_() | |||
| else: | |||
| self.register_parameter('bias', None) | |||
| # Initialize weight. | |||
| self.master_weight = _initialize_affine_weight( | |||
| self.weight, | |||
| self.output_size, | |||
| self.input_size, | |||
| self.output_size_per_partition, | |||
| 0, | |||
| init_method, | |||
| stride=stride, | |||
| return_master_weight=keep_master_weight_for_test) | |||
| def forward(self, input_): | |||
| # Set up backprop all-reduce. | |||
| input_parallel = copy_to_model_parallel_region(input_) | |||
| # Matrix multiply. | |||
| output_parallel = F.linear(input_parallel, self.weight, self.bias) | |||
| if self.gather_output: | |||
| # All-gather across the partitions. | |||
| output = gather_from_model_parallel_region(output_parallel) | |||
| else: | |||
| output = output_parallel | |||
| return output | |||
| class RowParallelLinear(torch.nn.Module): | |||
| """Linear layer with row parallelism. | |||
| The linear layer is defined as Y = XA + b. A is parallelized along | |||
| its first dimension and X along its second dimension as: | |||
| - - | |||
| | A_1 | | |||
| | . | | |||
| A = | . | X = [X_1, ..., X_p] | |||
| | . | | |||
| | A_p | | |||
| - - | |||
| Arguments: | |||
| input_size: first dimension of matrix A. | |||
| output_size: second dimension of matrix A. | |||
| bias: If true, add bias. Note that bias is not parallelized. | |||
| input_is_parallel: If true, we assume that the input is already | |||
| split across the GPUs and we do not split | |||
| again. | |||
| init_method: method to initialize weights. Note that bias is always set | |||
| to zero. | |||
| stride: For the strided linear layers. | |||
| keep_master_weight_for_test: This was added for testing and should be | |||
| set to False. It returns the master weights | |||
| used for initialization. | |||
| """ | |||
| def __init__(self, | |||
| input_size, | |||
| output_size, | |||
| bias=True, | |||
| input_is_parallel=False, | |||
| init_method=init.xavier_normal_, | |||
| stride=1, | |||
| keep_master_weight_for_test=False): | |||
| super(RowParallelLinear, self).__init__() | |||
| # Keep input parameters | |||
| self.input_size = input_size | |||
| self.output_size = output_size | |||
| self.input_is_parallel = input_is_parallel | |||
| # Divide the weight matrix along the last dimension. | |||
| world_size = get_model_parallel_world_size() | |||
| self.input_size_per_partition = divide(input_size, world_size) | |||
| # Parameters. | |||
| # Note: torch.nn.functional.linear performs XA^T + b and as a result | |||
| # we allocate the transpose. | |||
| self.weight = Parameter( | |||
| torch.Tensor(self.output_size, self.input_size_per_partition)) | |||
| self.weight.model_parallel = True | |||
| if bias: | |||
| self.bias = Parameter(torch.Tensor(self.output_size)) | |||
| # Always initialize bias to zero. | |||
| with torch.no_grad(): | |||
| self.bias.zero_() | |||
| else: | |||
| self.register_parameter('bias', None) | |||
| # Initialize weight. | |||
| self.master_weight = _initialize_affine_weight( | |||
| self.weight, | |||
| self.output_size, | |||
| self.input_size, | |||
| self.input_size_per_partition, | |||
| 1, | |||
| init_method, | |||
| stride=stride, | |||
| return_master_weight=keep_master_weight_for_test) | |||
| def forward(self, input_): | |||
| # Set up backprop all-reduce. | |||
| if self.input_is_parallel: | |||
| input_parallel = input_ | |||
| else: | |||
| input_parallel = scatter_to_model_parallel_region(input_) | |||
| # Matrix multiply. | |||
| output_parallel = F.linear(input_parallel, self.weight) | |||
| # All-reduce across all the partitions. | |||
| output_ = reduce_from_model_parallel_region(output_parallel) | |||
| if self.bias is not None: | |||
| output = output_ + self.bias | |||
| else: | |||
| output = output_ | |||
| return output | |||
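The two linear layers above are normally paired so that only one communication is needed per block: a column parallel projection with `gather_output=False` feeds a row parallel projection with `input_is_parallel=True`, so no all-gather is required between them. A sketch of that pattern, assuming the parallel groups are initialized and with the import path and hidden size as assumptions:

    import torch
    import torch.nn.functional as F
    from mpu.layers import ColumnParallelLinear, RowParallelLinear

    class ParallelMLP(torch.nn.Module):
        def __init__(self, hidden_size):
            super().__init__()
            # Output stays sharded; no all-gather between the two matmuls.
            self.dense_h_to_4h = ColumnParallelLinear(
                hidden_size, 4 * hidden_size, gather_output=False)
            # Input is already sharded; a single all-reduce produces the result.
            self.dense_4h_to_h = RowParallelLinear(
                4 * hidden_size, hidden_size, input_is_parallel=True)

        def forward(self, hidden_states):
            return self.dense_4h_to_h(F.gelu(self.dense_h_to_4h(hidden_states)))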
| @@ -0,0 +1,144 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| from .initialize import get_model_parallel_group | |||
| from .utils import split_tensor_along_last_dim | |||
| def _reduce(input_): | |||
| """All-reduce the the input tensor across model parallel group.""" | |||
| group = get_model_parallel_group() | |||
| # Bypass the function if we are using only 1 GPU. | |||
| if torch.distributed.get_world_size(group=group) == 1: | |||
| return input_ | |||
| # All-reduce. | |||
| torch.distributed.all_reduce(input_, group=group) | |||
| return input_ | |||
| def _split(input_): | |||
| """Split the tensor along its last dimension and keep the | |||
| corresponding slice.""" | |||
| group = get_model_parallel_group() | |||
| # Bypass the function if we are using only 1 GPU. | |||
| if torch.distributed.get_world_size(group=group) == 1: | |||
| return input_ | |||
| # Split along last dimension. | |||
| world_size = torch.distributed.get_world_size(group=group) | |||
| input_list = split_tensor_along_last_dim(input_, world_size) | |||
| # Note: torch.split does not create contiguous tensors by default. | |||
| rank = torch.distributed.get_rank(group=group) | |||
| output = input_list[rank].contiguous() | |||
| return output | |||
| def _gather(input_): | |||
| """Gather tensors and concatinate along the last dimension.""" | |||
| group = get_model_parallel_group() | |||
| # Bypass the function if we are using only 1 GPU. | |||
| if torch.distributed.get_world_size(group=group) == 1: | |||
| return input_ | |||
| # Size and dimension. | |||
| last_dim = input_.dim() - 1 | |||
| rank = torch.distributed.get_rank(group=group) | |||
| world_size = torch.distributed.get_world_size(group=group) | |||
| tensor_list = [torch.empty_like(input_) for _ in range(world_size)] | |||
| tensor_list[rank] = input_ | |||
| torch.distributed.all_gather(tensor_list, input_, group=group) | |||
| # Note: torch.cat already creates a contiguous tensor. | |||
| output = torch.cat(tensor_list, dim=last_dim).contiguous() | |||
| return output | |||
| class _CopyToModelParallelRegion(torch.autograd.Function): | |||
| """Pass the input to the model parallel region.""" | |||
| @staticmethod | |||
| def forward(ctx, input_): | |||
| return input_ | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| return _reduce(grad_output) | |||
| class _ReduceFromModelParallelRegion(torch.autograd.Function): | |||
| """All-redcue the input from the model parallel region.""" | |||
| @staticmethod | |||
| def forward(ctx, input_): | |||
| return _reduce(input_) | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| return grad_output | |||
| class _ScatterToModelParallelRegion(torch.autograd.Function): | |||
| """Split the input and keep only the corresponding chuck to the rank.""" | |||
| @staticmethod | |||
| def forward(ctx, input_): | |||
| return _split(input_) | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| return _gather(grad_output) | |||
| class _GatherFromModelParallelRegion(torch.autograd.Function): | |||
| """Gather the input from model parallel region and concatinate.""" | |||
| @staticmethod | |||
| def forward(ctx, input_): | |||
| return _gather(input_) | |||
| @staticmethod | |||
| def backward(ctx, grad_output): | |||
| return _split(grad_output) | |||
| # ----------------- | |||
| # Helper functions. | |||
| # ----------------- | |||
| def copy_to_model_parallel_region(input_): | |||
| return _CopyToModelParallelRegion.apply(input_) | |||
| def reduce_from_model_parallel_region(input_): | |||
| return _ReduceFromModelParallelRegion.apply(input_) | |||
| def scatter_to_model_parallel_region(input_): | |||
| return _ScatterToModelParallelRegion.apply(input_) | |||
| def gather_from_model_parallel_region(input_): | |||
| return _GatherFromModelParallelRegion.apply(input_) | |||
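To summarize the four helpers above: copy is an identity in forward and an all-reduce in backward, reduce is the opposite, scatter splits in forward and all-gathers in backward, and gather does the reverse. A small round-trip sketch, assuming the helpers are re-exported by `mpu` (as the tests in this PR assume) and that the last dimension divides evenly across ranks:

    import torch
    import mpu

    x = torch.randn(4, 8, device='cuda', requires_grad=True)
    x_shard = mpu.scatter_to_model_parallel_region(x)    # forward: split
    y = mpu.gather_from_model_parallel_region(x_shard)   # forward: all-gather
    y.sum().backward()                                   # gradients take the mirrored path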
| @@ -0,0 +1,408 @@ | |||
| # Modified by Samyam Rajbhandari | |||
| # Used to partition the activations stored for backward propagation | |||
| # Therefore reduces the memory consumption | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # Parts of the code here are adapted from PyTorch | |||
| # repo: https://github.com/pytorch/pytorch | |||
| import contextlib | |||
| import torch | |||
| import torch.distributed as dist | |||
| from torch import _C | |||
| from torch.cuda import _lazy_call | |||
| from torch.cuda import device as device_ctx_manager | |||
| from .initialize import (get_data_parallel_rank, get_model_parallel_group, | |||
| get_model_parallel_rank, | |||
| get_model_parallel_world_size) | |||
| # from torch.utils.checkpoint import detach_variable | |||
| PARTITION_ACTIVATIONS = False | |||
| PA_CORRECTNESS_TEST = False | |||
| def see_memory_usage(message, force=False): | |||
| if not force: | |||
| return | |||
| dist.barrier() | |||
| if dist.get_rank() == 0: | |||
| print(message) | |||
| print('Memory Allocated ', | |||
| torch.cuda.memory_allocated() / (1024 * 1024 * 1024), | |||
| 'GigaBytes') | |||
| print('Max Memory Allocated ', | |||
| torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), | |||
| 'GigaBytes') | |||
| print('Cache Allocated ', | |||
| torch.cuda.memory_cached() / (1024 * 1024 * 1024), 'GigaBytes') | |||
| print('Max cache Allocated ', | |||
| torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), | |||
| 'GigaBytes') | |||
| print(' ') | |||
| # input("Press Any Key To Continue ..") | |||
| mp_rank = None # get_model_parallel_rank() | |||
| mp_size = None # get_model_parallel_world_size() | |||
| mp_group = None # get_model_parallel_group() | |||
| # Default name for the model parallel rng tracker. | |||
| _MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng' | |||
| transport_stream = None | |||
| cuda_device = None | |||
| def detach_variable(inputs, device=None): | |||
| if isinstance(inputs, tuple): | |||
| out = [] | |||
| for inp in inputs: | |||
| if not isinstance(inp, torch.Tensor): | |||
| out.append(inp) | |||
| continue | |||
| requires_grad = inp.requires_grad | |||
| if device is not None: | |||
| x = inp.to(device=device) | |||
| else: | |||
| x = inp | |||
| x = x.detach() | |||
| x.requires_grad = requires_grad | |||
| out.append(x) | |||
| return tuple(out) | |||
| else: | |||
| raise RuntimeError( | |||
| 'Only tuple of tensors is supported. Got Unsupported input type: ', | |||
| type(inputs).__name__) | |||
| def _set_cuda_rng_state(new_state, device=-1): | |||
| """Sets the random number generator state of the current GPU. | |||
| Arguments: | |||
| new_state (torch.ByteTensor): The desired state | |||
| This function is adapted from PyTorch repo (torch.cuda.set_rng_state) | |||
| with a single change: the input state is not cloned. Cloning caused | |||
| major performance issues for 4+ GPU cases. | |||
| """ | |||
| if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState): | |||
| # older PyTorch | |||
| def cb(): | |||
| with device_ctx_manager(device): | |||
| _C._cuda_setRNGState(new_state) | |||
| else: | |||
| # newer PyTorch | |||
| if device == -1: | |||
| device = torch.device('cuda') | |||
| elif isinstance(device, str): | |||
| device = torch.device(device) | |||
| elif isinstance(device, int): | |||
| device = torch.device('cuda', device) | |||
| def cb(): | |||
| idx = device.index | |||
| if idx is None: | |||
| idx = torch.cuda.current_device() | |||
| default_generator = torch.cuda.default_generators[idx] | |||
| default_generator.set_state(new_state) | |||
| _lazy_call(cb) | |||
| class CudaRNGStatesTracker: | |||
| """Tracker for the cuda RNG states. | |||
| Using the `add` method, a cuda rng state is initialized based on | |||
| the input `seed` and is assigned to `name`. Later, by forking the | |||
| rng state, we can perform operations and return to our starting | |||
| cuda state. | |||
| """ | |||
| def __init__(self): | |||
| # Map from a string name to the cuda rng state. | |||
| self.states_ = {} | |||
| # Seeds are just for bookkeeping, to ensure no seed is set twice. | |||
| self.seeds_ = set() | |||
| def reset(self): | |||
| """Set to the initial state (no tracker).""" | |||
| self.states_ = {} | |||
| self.seeds_ = set() | |||
| def get_states(self): | |||
| """Get rng states. Copy the dictionary so we have direct | |||
| pointers to the states, not just a pointer to the dictionary.""" | |||
| states = {} | |||
| for name in self.states_: | |||
| states[name] = self.states_[name] | |||
| return states | |||
| def set_states(self, states): | |||
| """Set the rng states. For efficiency purposes, we do not check | |||
| the size of seed for compatibility.""" | |||
| self.states_ = states | |||
| def add(self, name, seed): | |||
| """Track the rng state.""" | |||
| # Check seed is not already used. | |||
| if seed in self.seeds_: | |||
| raise Exception('seed {} already exists'.format(seed)) | |||
| self.seeds_.add(seed) | |||
| # Check that state is not already defined. | |||
| if name in self.states_: | |||
| raise Exception('cuda rng state {} already exists'.format(name)) | |||
| # Get the current rng state. | |||
| orig_rng_state = torch.cuda.get_rng_state() | |||
| # Set the new state and store it. | |||
| torch.cuda.manual_seed(seed) | |||
| self.states_[name] = torch.cuda.get_rng_state() | |||
| # Reset rng state to what it was. | |||
| _set_cuda_rng_state(orig_rng_state) | |||
| @contextlib.contextmanager | |||
| def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME): | |||
| """Fork the cuda rng state, perform operations, and exit with | |||
| the original state.""" | |||
| # Check if we have added the state | |||
| if name not in self.states_: | |||
| raise Exception('cuda rng state {} is not added'.format(name)) | |||
| # Store current rng state. | |||
| orig_cuda_rng_state = torch.cuda.get_rng_state() | |||
| # Set rng state to the desired one | |||
| _set_cuda_rng_state(self.states_[name]) | |||
| # Do the stuff we wanted to do. | |||
| try: | |||
| yield | |||
| finally: | |||
| # Update the current rng state for later use. | |||
| self.states_[name] = torch.cuda.get_rng_state() | |||
| # And set the state to the original state we started with. | |||
| _set_cuda_rng_state(orig_cuda_rng_state) | |||
| # RNG tracker object. | |||
| _CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() | |||
| def get_cuda_rng_tracker(): | |||
| """Get cuda rng tracker.""" | |||
| return _CUDA_RNG_STATE_TRACKER | |||
| def model_parallel_cuda_manual_seed(seed): | |||
| """Initialize model parallel cuda seed. | |||
| This function should be called after the model parallel is | |||
| initialized. Also, no torch.cuda.manual_seed should be called | |||
| after this function. Basically, this is replacement for that | |||
| function. | |||
| Two set of RNG states are tracked: | |||
| default state: This is for data parallelism and is the same among a | |||
| set of model parallel GPUs but different across | |||
| different model parallel groups. This is used for | |||
| example for dropout in the non-model-parallel regions. | |||
| model-parallel state: This state is different among a set of model | |||
| parallel GPUs, but the same across data parallel | |||
| groups. This is used for example for dropout in | |||
| model parallel regions. | |||
| """ | |||
| # 2718 is just for fun and any POSITIVE value will work. | |||
| offset = seed + 2718 | |||
| model_parallel_seed = offset + get_model_parallel_rank() | |||
| # Data parallel gets the original seed. | |||
| data_parallel_seed = seed | |||
| if torch.distributed.get_rank() == 0: | |||
| print( | |||
| '> initializing model parallel cuda seeds on global rank {}, ' | |||
| 'model parallel rank {}, and data parallel rank {} with ' | |||
| 'model parallel seed: {} and data parallel seed: {}'.format( | |||
| torch.distributed.get_rank(), get_model_parallel_rank(), | |||
| get_data_parallel_rank(), model_parallel_seed, | |||
| data_parallel_seed), | |||
| flush=True) | |||
| _CUDA_RNG_STATE_TRACKER.reset() | |||
| # Set the default state. | |||
| torch.cuda.manual_seed(data_parallel_seed) | |||
| # and model parallel state. | |||
| _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, | |||
| model_parallel_seed) | |||
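A usage sketch of the seeding scheme above: seed once after the parallel groups exist, then fork into the model parallel RNG state around operations (such as dropout on tensor parallel activations) that must differ across model parallel ranks. Import paths follow the `.random` module name used elsewhere in this diff.

    import torch
    import mpu
    from mpu.random import get_cuda_rng_tracker

    # Assumes torch.distributed and mpu.initialize_model_parallel(...) already ran.
    mpu.model_parallel_cuda_manual_seed(1234)

    x = torch.randn(8, 1024, device='cuda')
    with get_cuda_rng_tracker().fork():
        # Different dropout mask per model parallel rank, identical across
        # data parallel replicas; outside the fork it is the other way around.
        x = torch.nn.functional.dropout(x, p=0.1, training=True)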
| def get_partition_start(item): | |||
| global mp_rank, mp_size, mp_group | |||
| partition_size = get_partition_size(item) | |||
| start = partition_size * mp_rank | |||
| return int(start) | |||
| def get_partition_size(item): | |||
| global mp_rank, mp_size, mp_group | |||
| size = item.numel() | |||
| partition_size = size / mp_size | |||
| return int(partition_size) | |||
| def get_full_inputs(tensors): | |||
| inputs = [] | |||
| for i in range(int(len(tensors) / 2) - 1): | |||
| item = tensors[2 * i] | |||
| size = tensors[2 * i + 1] | |||
| partition_size = item.numel() | |||
| tensor_size = partition_size * mp_size | |||
| flat_tensor = torch.zeros([tensor_size], | |||
| dtype=item.dtype, | |||
| device=item.device) | |||
| partitions = [] | |||
| for i in range(mp_size): | |||
| part_i = flat_tensor.narrow(0, partition_size * i, partition_size) | |||
| if i == mp_rank: | |||
| part_i.copy_(item) | |||
| partitions.append(part_i) | |||
| dist.all_gather(partitions, partitions[mp_rank], group=mp_group) | |||
| input_tensor = flat_tensor.view(list(size.numpy())) | |||
| item.data = input_tensor.data | |||
| inputs.append(item) | |||
| inputs.append(tensors[-2]) | |||
| return tuple(inputs) | |||
| class CheckpointFunction(torch.autograd.Function): | |||
| """This function is adapted from torch.utils.checkpoint with | |||
| two main changes: | |||
| 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` | |||
| 2) the states in the model parallel tracker are also properly | |||
| tracked/set/reset. | |||
| """ | |||
| @staticmethod | |||
| def forward(ctx, run_function, *args): | |||
| ctx.run_function = run_function | |||
| global mp_rank, mp_size, mp_group | |||
| if mp_rank is None: | |||
| mp_rank = get_model_parallel_rank() | |||
| mp_size = get_model_parallel_world_size() | |||
| mp_group = get_model_parallel_group() | |||
| global cuda_device, transport_stream, PARTITION_ACTIVATIONS | |||
| if cuda_device is None: | |||
| if dist.get_rank() == 0: | |||
| print( | |||
| f'Partition Activations {PARTITION_ACTIVATIONS} and Correctness Check {PA_CORRECTNESS_TEST}' | |||
| ) | |||
| cuda_device = torch.cuda.current_device() | |||
| # The transport stream is used to overlap the allgather communication for the activations | |||
| # with the computation in the backward pass | |||
| transport_stream = torch.cuda.Stream(device=cuda_device) | |||
| if PARTITION_ACTIVATIONS: | |||
| inputs = [ | |||
| item.detach().contiguous().view(-1).narrow( | |||
| 0, get_partition_start(item), | |||
| get_partition_size(item)).clone() for item in args[:-1] | |||
| ] | |||
| inputs.append(args[-1]) | |||
| # just in case something funky is happening such as reuse of inputs | |||
| inputs_cuda = [item.to(cuda_device) for item in args] | |||
| # Copy the rng states. | |||
| ctx.fwd_cpu_rng_state = torch.get_rng_state() | |||
| ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() | |||
| ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() | |||
| # ctx.save_for_backward(*args) | |||
| with torch.no_grad(): | |||
| outputs = run_function(*inputs_cuda) | |||
| del inputs_cuda | |||
| if PARTITION_ACTIVATIONS: | |||
| new_args = [] | |||
| for arg, inp in zip(args, inputs): | |||
| size = torch.tensor(arg.size()) | |||
| arg.data = inp.data | |||
| new_args.append(arg) | |||
| new_args.append(size) | |||
| ctx.save_for_backward(*new_args) | |||
| else: | |||
| ctx.save_for_backward(*args) | |||
| return outputs | |||
| @staticmethod | |||
| def backward(ctx, *args): | |||
| if not torch.autograd._is_checkpoint_valid(): | |||
| raise RuntimeError('Checkpointing is not compatible with .grad(), ' | |||
| 'please use .backward() if possible') | |||
| global cuda_device, transport_stream, PARTITION_ACTIVATIONS | |||
| if PARTITION_ACTIVATIONS: | |||
| with torch.cuda.stream(transport_stream): | |||
| inputs = get_full_inputs(ctx.saved_tensors) | |||
| detached_inputs = detach_variable(inputs) | |||
| else: | |||
| inputs = ctx.saved_tensors | |||
| detached_inputs = detach_variable(inputs) | |||
| # Store the current states. | |||
| bwd_cpu_rng_state = torch.get_rng_state() | |||
| bwd_cuda_rng_state = torch.cuda.get_rng_state() | |||
| bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() | |||
| # Set the states to what it used to be before the forward pass. | |||
| torch.set_rng_state(ctx.fwd_cpu_rng_state) | |||
| _set_cuda_rng_state(ctx.fwd_cuda_rng_state) | |||
| get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) | |||
| if PARTITION_ACTIVATIONS: | |||
| current_stream = torch.cuda.current_stream() | |||
| current_stream.wait_stream(transport_stream) | |||
| with torch.enable_grad(): | |||
| outputs = ctx.run_function(*detached_inputs) | |||
| # Set the states back to what it was at the start of this function. | |||
| torch.set_rng_state(bwd_cpu_rng_state) | |||
| _set_cuda_rng_state(bwd_cuda_rng_state) | |||
| get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) | |||
| if isinstance(outputs, torch.Tensor): | |||
| outputs = (outputs, ) | |||
| torch.autograd.backward(outputs, args) | |||
| return (None, ) + tuple(inp.grad for inp in detached_inputs) | |||
| def checkpoint(function, *args): | |||
| """Checkpoint a model or part of the model. | |||
| This has been directly copied from torch.utils.checkpoint.""" | |||
| return CheckpointFunction.apply(function, *args) | |||
| def partition_activations_in_checkpoint(partition_activation): | |||
| global PARTITION_ACTIVATIONS | |||
| PARTITION_ACTIVATIONS = partition_activation | |||
| if dist.get_rank() == 0: | |||
| print( | |||
| f'**************Partition Activations {PARTITION_ACTIVATIONS}************' | |||
| ) | |||
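A hedged sketch of driving the checkpointing entry point above; `block` and the tensors are placeholders, the model parallel groups must already be initialized (the forward pass queries them), and the import path again follows the `.random` module name used in this diff.

    import torch
    from mpu.random import checkpoint

    def block(x, w):
        return torch.tanh(torch.nn.functional.linear(x, w))

    x = torch.randn(4, 1024, device='cuda', requires_grad=True)
    w = torch.randn(1024, 1024, device='cuda', requires_grad=True)

    y = checkpoint(block, x, w)   # inputs are saved; `block` reruns during backward
    y.sum().backward()
    # Calling partition_activations_in_checkpoint(True) beforehand would additionally
    # shard the saved inputs across the model parallel ranks.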
| @@ -0,0 +1,86 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import argparse | |||
| import os | |||
| import random | |||
| import mpu | |||
| import numpy | |||
| import torch | |||
| class IdentityLayer(torch.nn.Module): | |||
| def __init__(self, size, scale=1.0): | |||
| super(IdentityLayer, self).__init__() | |||
| self.weight = torch.nn.Parameter(scale * torch.randn(size)) | |||
| def forward(self): | |||
| return self.weight | |||
| def set_random_seed(seed): | |||
| """Set random seed for reproducability.""" | |||
| random.seed(seed) | |||
| numpy.random.seed(seed) | |||
| torch.manual_seed(seed) | |||
| mpu.model_parallel_cuda_manual_seed(seed) | |||
| def initialize_distributed(backend='nccl'): | |||
| """Initialize torch.distributed.""" | |||
| # Get local rank in case it is provided. | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument( | |||
| '--local_rank', | |||
| type=int, | |||
| default=None, | |||
| help='local rank passed from distributed launcher') | |||
| args = parser.parse_args() | |||
| local_rank = args.local_rank | |||
| # Get rank and world size. | |||
| rank = int(os.getenv('RANK', '0')) | |||
| world_size = int(os.getenv('WORLD_SIZE', '1')) | |||
| print('> initializing torch.distributed with local rank: {}, ' | |||
| 'rank: {}, world size: {}'.format(local_rank, rank, world_size)) | |||
| # Set the device id. | |||
| device = rank % torch.cuda.device_count() | |||
| if local_rank is not None: | |||
| device = local_rank | |||
| torch.cuda.set_device(device) | |||
| # Call the init process. | |||
| init_method = 'tcp://' | |||
| master_ip = os.getenv('MASTER_ADDR', 'localhost') | |||
| master_port = os.getenv('MASTER_PORT', '6000') | |||
| init_method += master_ip + ':' + master_port | |||
| torch.distributed.init_process_group( | |||
| backend=backend, | |||
| world_size=world_size, | |||
| rank=rank, | |||
| init_method=init_method) | |||
| def print_separator(message): | |||
| torch.distributed.barrier() | |||
| filler_len = (78 - len(message)) // 2 | |||
| filler = '-' * filler_len | |||
| string = '\n' + filler + ' {} '.format(message) + filler | |||
| if torch.distributed.get_rank() == 0: | |||
| print(string, flush=True) | |||
| torch.distributed.barrier() | |||
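For orientation, each of the test scripts below follows the same pattern as these helpers: initialize `torch.distributed` from the launcher-provided environment, build the model parallel groups, seed, run assertions, then destroy the groups. A condensed sketch (the `commons` import mirrors how the tests import this file; the module name is otherwise an assumption):

    import mpu
    from commons import initialize_distributed, set_random_seed

    initialize_distributed()     # reads RANK/WORLD_SIZE/MASTER_* from the environment
    mpu.initialize_model_parallel(2)
    set_random_seed(1234)        # also seeds the model parallel RNG tracker
    # ... run checks ...
    mpu.destroy_model_parallel()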
| @@ -0,0 +1,106 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import random | |||
| import sys | |||
| import mpu | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from commons import (IdentityLayer, initialize_distributed, print_separator, | |||
| set_random_seed) | |||
| from mpu.cross_entropy import vocab_parallel_cross_entropy | |||
| sys.path.append('../..') | |||
| def torch_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, | |||
| seed): | |||
| set_random_seed(seed) | |||
| identity = IdentityLayer((batch_size, seq_length, vocab_size), | |||
| scale=logits_scale).cuda() | |||
| logits = identity() | |||
| target = torch.cuda.LongTensor(size=(batch_size, | |||
| seq_length)).random_(0, vocab_size) | |||
| loss = F.cross_entropy( | |||
| logits.view(-1, | |||
| logits.size()[-1]), target.view(-1), | |||
| reduction='none').view_as(target).mean() | |||
| loss.backward() | |||
| return loss, identity.weight.grad | |||
| def mpu_cross_entropy(batch_size, seq_length, vocab_size, logits_scale, seed): | |||
| set_random_seed(seed) | |||
| identity = IdentityLayer((batch_size, seq_length, vocab_size), | |||
| scale=logits_scale).cuda() | |||
| logits = identity() | |||
| logits_parallel = mpu.scatter_to_model_parallel_region(logits) | |||
| target = torch.cuda.LongTensor(size=(batch_size, | |||
| seq_length)).random_(0, vocab_size) | |||
| loss = vocab_parallel_cross_entropy(logits_parallel, target).mean() | |||
| loss.backward() | |||
| return loss, identity.weight.grad | |||
| def test_cross_entropy(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing cross entropy with model parallel size {} ...'.format( | |||
| model_parallel_size)) | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| batch_size = 13 | |||
| seq_length = 17 | |||
| vocab_size_per_partition = 11 | |||
| logits_scale = 1000.0 | |||
| vocab_size = vocab_size_per_partition * model_parallel_size | |||
| seed = 1234 | |||
| loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length, | |||
| vocab_size, logits_scale, | |||
| seed) | |||
| loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length, vocab_size, | |||
| logits_scale, seed) | |||
| error = loss_torch.sub_(loss_mpu).abs().max() | |||
| print(' max error in loss on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| error = grad_torch.sub_(grad_mpu).abs().max() | |||
| print(' max error in grad on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print('>> passed the test :-)') | |||
| if __name__ == '__main__': | |||
| initialize_distributed() | |||
| world_size = torch.distributed.get_world_size() | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| print_separator('test cross entropy') | |||
| test_cross_entropy(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| @@ -0,0 +1,91 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import functools | |||
| import operator | |||
| import sys | |||
| import mpu | |||
| import torch | |||
| from commons import initialize_distributed, print_separator | |||
| from mpu import data as data_utils | |||
| sys.path.append('../..') | |||
| def test_broadcast_data(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print( | |||
| '> testing broadcast_data with model parallel size {} ...'.format( | |||
| model_parallel_size)) | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| torch.manual_seed(1234 + mpu.get_data_parallel_rank()) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| key_size_t = { | |||
| 'key1': [7, 11], | |||
| 'key2': [8, 2, 1], | |||
| 'key3': [13], | |||
| 'key4': [5, 1, 2], | |||
| 'key5': [5, 12] | |||
| } | |||
| keys = list(key_size_t.keys()) | |||
| data = {} | |||
| data_t = {} | |||
| for key in key_size_t: | |||
| data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000) | |||
| data_t[key] = data[key].clone() | |||
| data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000) | |||
| data_t['keyX'] = data['keyX'].clone() | |||
| if mpu.get_model_parallel_rank() != 0: | |||
| data = None | |||
| data_utils._check_data_types(keys, data_t, torch.int64) | |||
| key_size, key_numel, \ | |||
| total_numel = data_utils._build_key_size_numel_dictionaries(keys, data) | |||
| for key in keys: | |||
| assert key_size[key] == key_size_t[key] | |||
| total_numel_t = 0 | |||
| for key in keys: | |||
| target_size = functools.reduce(operator.mul, key_size_t[key], 1) | |||
| assert key_numel[key] == target_size | |||
| total_numel_t += target_size | |||
| assert total_numel == total_numel_t | |||
| data_b = data_utils.broadcast_data(keys, data, torch.int64) | |||
| for key in keys: | |||
| tensor = data_t[key].cuda() | |||
| assert data_b[key].sub(tensor).abs().max() == 0 | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print('>> passed the test :-)') | |||
| if __name__ == '__main__': | |||
| initialize_distributed() | |||
| world_size = torch.distributed.get_world_size() | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| print_separator('test broadcast data') | |||
| test_broadcast_data(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| @@ -0,0 +1,95 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import sys | |||
| import mpu | |||
| import torch | |||
| from commons import initialize_distributed, print_separator | |||
| sys.path.append('../..') | |||
| def test_initialize_model_parallel(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing initialize_model_parallel with size {} ...'.format( | |||
| model_parallel_size)) | |||
| model_parallel_size_ = min(model_parallel_size, | |||
| torch.distributed.get_world_size()) | |||
| assert not mpu.model_parallel_is_initialized() | |||
| mpu.initialize_model_parallel(model_parallel_size_) | |||
| assert mpu.model_parallel_is_initialized() | |||
| # Checks. | |||
| def check(group, world_size, rank): | |||
| assert world_size == torch.distributed.get_world_size(group=group) | |||
| assert rank == torch.distributed.get_rank(group=group) | |||
| # Model parallel. | |||
| world_size = model_parallel_size_ | |||
| rank = torch.distributed.get_rank() % model_parallel_size_ | |||
| assert world_size == mpu.get_model_parallel_world_size() | |||
| assert rank == mpu.get_model_parallel_rank() | |||
| check(mpu.get_model_parallel_group(), world_size, rank) | |||
| # Data parallel. | |||
| world_size = torch.distributed.get_world_size() // model_parallel_size_ | |||
| rank = torch.distributed.get_rank() // model_parallel_size_ | |||
| assert world_size == mpu.get_data_parallel_world_size() | |||
| assert rank == mpu.get_data_parallel_rank() | |||
| check(mpu.get_data_parallel_group(), world_size, rank) | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print('>> passed the test :-)') | |||
| def test_get_model_parallel_src_rank(model_parallel_size_): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing get_model_parallel_src_rank with size {} ...'.format( | |||
| model_parallel_size_)) | |||
| model_parallel_size = min(model_parallel_size_, | |||
| torch.distributed.get_world_size()) | |||
| assert not mpu.model_parallel_is_initialized() | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| assert mpu.model_parallel_is_initialized() | |||
| # Checks | |||
| src_rank = torch.distributed.get_rank() - mpu.get_model_parallel_rank() | |||
| assert mpu.get_model_parallel_src_rank() == src_rank | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print('>> passed the test :-)') | |||
| if __name__ == '__main__': | |||
| initialize_distributed() | |||
| world_size = torch.distributed.get_world_size() | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| print_separator('test initialize model parallel') | |||
| test_initialize_model_parallel(model_parallel_size) | |||
| print_separator('test model parallel source rank') | |||
| test_get_model_parallel_src_rank(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| @@ -0,0 +1,533 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import random | |||
| import sys | |||
| import mpu | |||
| import torch | |||
| import torch.nn.init as init | |||
| from commons import initialize_distributed, print_separator, set_random_seed | |||
| from mpu import layers | |||
| from torch.nn.parameter import Parameter | |||
| sys.path.append('../..') | |||
| def test_parallel_embedding(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing parallel embedding with model parallel size {} ...'. | |||
| format(model_parallel_size)) | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| batch_size = 17 | |||
| seq_length = 23 | |||
| vocab_size = 48 | |||
| hidden_size = 16 | |||
| seed = 1236 | |||
| set_random_seed(123) | |||
| input_data = torch.LongTensor(size=(batch_size, seq_length)).random_( | |||
| 0, vocab_size).cuda() | |||
| loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda() | |||
| set_random_seed(seed) | |||
| embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda() | |||
| output = embedding_original(input_data) | |||
| loss_original = torch.mul(output, loss_weight).sum() | |||
| loss_original.backward() | |||
| set_random_seed(seed) | |||
| embedding_parallel = layers.ParallelEmbedding( | |||
| vocab_size, hidden_size, init_method=init.normal_).cuda() | |||
| output = embedding_parallel(input_data) | |||
| loss_parallel = torch.mul(output, loss_weight).sum() | |||
| loss_parallel.backward() | |||
| set_random_seed(seed) | |||
| embedding_vocab_parallel = layers.VocabParallelEmbedding( | |||
| vocab_size, hidden_size, init_method=init.normal_).cuda() | |||
| output = embedding_vocab_parallel(input_data) | |||
| loss_vocab_parallel = torch.mul(output, loss_weight).sum() | |||
| loss_vocab_parallel.backward() | |||
| torch.distributed.barrier() | |||
| error = loss_parallel.sub(loss_original).abs() | |||
| print(' error in loss (parallel) on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-12, 'error: {}'.format(error) | |||
| torch.distributed.barrier() | |||
| error = loss_vocab_parallel.sub(loss_original).abs() | |||
| print(' error in loss (vocab parallel) on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-12, 'error: {}'.format(error) | |||
| weight_grad_orig = torch.split(embedding_original.weight.grad, | |||
| hidden_size // model_parallel_size, | |||
| 1)[mpu.get_model_parallel_rank()] | |||
| error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max() | |||
| print(' error in grad (parallel) on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-12, 'error: {}'.format(error) | |||
| weight_grad_orig = torch.split(embedding_original.weight.grad, | |||
| vocab_size // model_parallel_size, | |||
| 0)[mpu.get_model_parallel_rank()] | |||
| error = embedding_vocab_parallel.weight.grad.sub( | |||
| weight_grad_orig).abs().max() | |||
| print(' error in grad (vocab parallel) on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-12, 'error: {}'.format(error) | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print('>> passed the test :-)') | |||
| def test_initialize_affine_weight(model_parallel_size): | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing initialize_affine_weight with model parallel ' | |||
| 'size: {}'.format(model_parallel_size)) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| seed = 12345 | |||
| input_size_coeff = 13 | |||
| input_size = input_size_coeff * model_parallel_size | |||
| output_size_coeff = 17 | |||
| output_size = output_size_coeff * model_parallel_size | |||
| # --------------- | |||
| # Column parallel | |||
| # --------------- | |||
| weight = torch.empty(output_size_coeff, input_size) | |||
| set_random_seed(seed) | |||
| layers._initialize_affine_weight(weight, output_size, input_size, | |||
| output_size_coeff, 0, | |||
| torch.nn.init.normal_) | |||
| # Target. | |||
| set_random_seed(seed) | |||
| master_weight = torch.empty(output_size, input_size) | |||
| torch.nn.init.normal_(master_weight) | |||
| rank = mpu.get_model_parallel_rank() | |||
| my_weight = torch.split( | |||
| master_weight, output_size_coeff, dim=0)[rank].contiguous().clone() | |||
| # Compare. | |||
| error = weight.sub(my_weight).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' column parallel max error (should be zero) on global rank ' | |||
| '{}: {}'.format(torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| # ------------ | |||
| # Row parallel | |||
| # ------------ | |||
| weight = torch.empty(output_size, input_size_coeff) | |||
| set_random_seed(seed) | |||
| mpu.layers._initialize_affine_weight(weight, output_size, input_size, | |||
| input_size_coeff, 1, | |||
| torch.nn.init.normal_) | |||
| # Target. | |||
| set_random_seed(seed) | |||
| master_weight = torch.empty(output_size, input_size) | |||
| torch.nn.init.normal_(master_weight) | |||
| rank = mpu.get_model_parallel_rank() | |||
| my_weight = torch.split( | |||
| master_weight, input_size_coeff, dim=1)[rank].contiguous().clone() | |||
| # Compare. | |||
| error = weight.sub(my_weight).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' row parallel max error (should be zero) on global rank ' | |||
| '{}: {}'.format(torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print(' >> passed the test :-)') | |||
| class IdentityLayer2D(torch.nn.Module): | |||
| def __init__(self, m, n): | |||
| super(IdentityLayer2D, self).__init__() | |||
| self.weight = Parameter(torch.Tensor(m, n)) | |||
| torch.nn.init.xavier_normal_(self.weight) | |||
| def forward(self): | |||
| return self.weight | |||
| def test_column_parallel_linear(model_parallel_size): | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing ColumnParallelLinear with model parallel ' | |||
| 'size: {}'.format(model_parallel_size)) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| seed = 12345 | |||
| set_random_seed(seed) | |||
| input_size_coeff = 13 | |||
| input_size = input_size_coeff * model_parallel_size | |||
| output_size_coeff = 17 | |||
| output_size = output_size_coeff * model_parallel_size | |||
| batch_size = 7 | |||
| # Network | |||
| identity_layer = IdentityLayer2D(batch_size, input_size).cuda() | |||
| linear_layer = mpu.ColumnParallelLinear( | |||
| input_size, output_size, keep_master_weight_for_test=True).cuda() | |||
| loss_weight = torch.randn([batch_size, output_size]).cuda() | |||
| # Forward | |||
| input_ = identity_layer() | |||
| output = linear_layer(input_) | |||
| loss = torch.mul(output, loss_weight).sum() | |||
| # Backward | |||
| loss.backward() | |||
| # Values. | |||
| dLdY = loss_weight | |||
| X = identity_layer.weight | |||
| A = linear_layer.master_weight.cuda() | |||
| dLdA = torch.matmul(dLdY.t(), X) | |||
| dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) | |||
| dLdX = torch.matmul(dLdY, A) | |||
| rank = mpu.get_model_parallel_rank() | |||
| my_dLdA = torch.split( | |||
| dLdA, output_size_coeff, dim=0)[rank].contiguous().clone() | |||
| error = my_dLdA.sub(linear_layer.weight.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' error in dLdA on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| my_dLdb = torch.split( | |||
| dLdb, output_size_coeff, dim=0)[rank].contiguous().clone() | |||
| error = my_dLdb.sub(linear_layer.bias.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' error in dLdb on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| error = dLdX.sub(identity_layer.weight.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' error in dLdX on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print(' >> passed the test :-)') | |||
| def test_row_parallel_linear(model_parallel_size): | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing RowParallelLinear with model parallel ' | |||
| 'size: {}'.format(model_parallel_size)) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| seed = 12345 | |||
| set_random_seed(seed) | |||
| input_size_coeff = 13 | |||
| input_size = input_size_coeff * model_parallel_size | |||
| output_size_coeff = 17 | |||
| output_size = output_size_coeff * model_parallel_size | |||
| batch_size = 7 | |||
| # Network | |||
| identity_layer = IdentityLayer2D(batch_size, input_size).cuda() | |||
| linear_layer = mpu.RowParallelLinear( | |||
| input_size, output_size, keep_master_weight_for_test=True).cuda() | |||
| loss_weight = torch.randn([batch_size, output_size]).cuda() | |||
| # Forward | |||
| input_ = identity_layer() | |||
| output = linear_layer(input_) | |||
| loss = torch.mul(output, loss_weight).sum() | |||
| # Backward | |||
| loss.backward() | |||
| # Values. | |||
| dLdY = loss_weight | |||
| X = identity_layer.weight | |||
| A = linear_layer.master_weight.cuda() | |||
| dLdA = torch.matmul(dLdY.t(), X) | |||
| dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1) | |||
| dLdX = torch.matmul(dLdY, A) | |||
| rank = mpu.get_model_parallel_rank() | |||
| my_dLdA = torch.split( | |||
| dLdA, input_size_coeff, dim=1)[rank].contiguous().clone() | |||
| error = my_dLdA.sub(linear_layer.weight.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' error in dLdA on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| error = dLdb.sub(linear_layer.bias.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' error in dLdb on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| error = dLdX.sub(identity_layer.weight.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' error in dLdX on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print(' >> passed the test :-)') | |||
| class IdentityLayer3D(torch.nn.Module): | |||
| def __init__(self, m, n, k): | |||
| super(IdentityLayer3D, self).__init__() | |||
| self.weight = Parameter(torch.Tensor(m, n, k)) | |||
| torch.nn.init.xavier_normal_(self.weight) | |||
| def forward(self): | |||
| return self.weight | |||
| def parallel_self_attention(model_parallel_size, num_att_heads_per_partition, | |||
| hidden_size_per_att_head, dropout_prob, batch_size, | |||
| sequence_length): | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| seed = 12345 | |||
| set_random_seed(seed) | |||
| num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size( | |||
| ) # noqa | |||
| hidden_size = hidden_size_per_att_head * num_att_heads | |||
| # Network | |||
| identity_layer = IdentityLayer3D(batch_size, sequence_length, | |||
| hidden_size).cuda() | |||
| attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads, | |||
| dropout_prob).cuda() | |||
| loss_weight = torch.randn([batch_size, sequence_length, | |||
| hidden_size]).cuda() | |||
| attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() | |||
| # Forward | |||
| input_ = identity_layer() | |||
| output = attention_layer(input_, attention_mask) | |||
| loss = torch.mul(output, loss_weight).sum() | |||
| # Backward | |||
| loss.backward() | |||
| rank = mpu.get_model_parallel_rank() | |||
| mpu.destroy_model_parallel() | |||
| return rank, hidden_size, model_parallel_size, loss, \ | |||
| attention_layer, identity_layer | |||
| def test_parallel_self_attention(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing ParallelSelfAttention with model parallel ' | |||
| 'size: {}'.format(model_parallel_size)) | |||
| num_att_heads_per_partition = 3 | |||
| hidden_size_per_att_head = 7 | |||
| dropout_prob = 0.0 # has to be zero | |||
| batch_size = 5 | |||
| sequence_length = 13 | |||
| rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ | |||
| attention_layer_1, identity_layer_1 = parallel_self_attention( | |||
| 1, num_att_heads_per_partition, | |||
| hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) | |||
| rank, hidden_size, model_parallel_size, loss, \ | |||
| attention_layer, identity_layer = parallel_self_attention( | |||
| model_parallel_size, num_att_heads_per_partition, | |||
| hidden_size_per_att_head, dropout_prob, batch_size, sequence_length) | |||
| assert hidden_size_1 == hidden_size | |||
| error = loss_1.sub(loss).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' loss error on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 5.0e-6 | |||
| my_lin_grad_list = torch.split( | |||
| attention_layer_1.query_key_value.weight.grad, | |||
| hidden_size // model_parallel_size, 0)[rank::model_parallel_size] | |||
| my_lin_grad = torch.cat(my_lin_grad_list, dim=0) | |||
| error = my_lin_grad.sub( | |||
| attention_layer.query_key_value.weight.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' weight gradient error on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 5.0e-6 | |||
| error = identity_layer_1.weight.grad.sub( | |||
| identity_layer.weight.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' input gradient error on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 5.0e-6 | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print(' >> passed the test :-)') | |||
| def parallel_transformer(model_parallel_size, num_att_heads_per_partition, | |||
| hidden_size_per_att_head, batch_size, | |||
| sequence_length): | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| seed = 12345 | |||
| set_random_seed(seed) | |||
| num_att_heads = num_att_heads_per_partition * torch.distributed.get_world_size( | |||
| ) | |||
| hidden_size = hidden_size_per_att_head * num_att_heads | |||
| intermediate_size = 4 * hidden_size | |||
| # Network | |||
| identity_layer = IdentityLayer3D(batch_size, sequence_length, | |||
| hidden_size).cuda() | |||
| transformer_layer = mpu.BertParallelTransformerLayer( | |||
| hidden_size, intermediate_size, num_att_heads, 0.0, 0.0, | |||
| torch.nn.functional.relu, 1.0e-5).cuda() | |||
| loss_weight = torch.randn([batch_size, sequence_length, | |||
| hidden_size]).cuda() | |||
| attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda() | |||
| # Forward | |||
| input_ = identity_layer() | |||
| output = transformer_layer(input_, attention_mask) | |||
| loss = torch.mul(output, loss_weight).sum() | |||
| # Backward | |||
| loss.backward() | |||
| rank = mpu.get_model_parallel_rank() | |||
| mpu.destroy_model_parallel() | |||
| return rank, hidden_size, model_parallel_size, loss, \ | |||
| transformer_layer, identity_layer | |||
| def test_parallel_transformer_layer(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing ParallelTransformerLayer with model parallel ' | |||
| 'size: {}'.format(model_parallel_size)) | |||
| num_att_heads_per_partition = 3 | |||
| hidden_size_per_att_head = 7 | |||
| batch_size = 5 | |||
| sequence_length = 13 | |||
| rank_1, hidden_size_1, model_parallel_size_1, loss_1, \ | |||
| transformer_layer_1, identity_layer_1 = parallel_transformer( | |||
| 1, num_att_heads_per_partition, | |||
| hidden_size_per_att_head, batch_size, sequence_length) | |||
| rank, hidden_size, model_parallel_size, loss, \ | |||
| transformer_layer, identity_layer = parallel_transformer( | |||
| model_parallel_size, num_att_heads_per_partition, | |||
| hidden_size_per_att_head, batch_size, sequence_length) | |||
| error = loss_1.sub(loss).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' loss error on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 5.0e-5, 'error: {}'.format(error) | |||
| error = identity_layer_1.weight.grad.sub( | |||
| identity_layer.weight.grad).abs().max() | |||
| torch.distributed.barrier() | |||
| print(' input gradient error on global rank {}: {}'.format( | |||
| torch.distributed.get_rank(), error)) | |||
| assert error < 5.0e-5, 'error: {}'.format(error) | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print(' >> passed the test :-)') | |||
| if __name__ == '__main__': | |||
| torch.backends.cudnn.deterministic = True | |||
| torch.backends.cudnn.benchmark = False | |||
| initialize_distributed() | |||
| world_size = torch.distributed.get_world_size() | |||
| print_separator('test initialize affine weight') | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| test_initialize_affine_weight(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| print_separator('test parallel embedding') | |||
| test_parallel_embedding(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| print_separator('test column-parallel linear') | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| test_column_parallel_linear(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| print_separator('test row-parallel linear') | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| test_row_parallel_linear(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| print_separator('test parallel self-attention') | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| test_parallel_self_attention(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| print_separator('test parallel transformer') | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| test_parallel_transformer_layer(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| @@ -0,0 +1,206 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import sys | |||
| import mpu | |||
| import torch | |||
| from commons import initialize_distributed, print_separator | |||
| sys.path.append('../..') | |||
| def test_set_cuda_rng_state(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing set_rng_state with size {} ...'.format( | |||
| model_parallel_size)) | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| size = 123 | |||
| seed = 1234 | |||
| torch.cuda.manual_seed(seed) | |||
| tensor = torch.cuda.FloatTensor(size) | |||
| # Get the state | |||
| rng_state = torch.cuda.get_rng_state() | |||
| rng_state_copy = rng_state.clone() | |||
| # Do some stuff. | |||
| for _ in range(5): | |||
| torch.randn(size, out=tensor) | |||
| result_1 = tensor.clone() | |||
| assert rng_state.sub(rng_state_copy).max() == 0 | |||
| assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0 | |||
| # State should be different. | |||
| new_rng_state = torch.cuda.get_rng_state() | |||
| max_diff = new_rng_state.sub(rng_state).max() | |||
| print( | |||
| ' max diff in rng state (should be non-zero) on global rank {}: {}'. | |||
| format(torch.distributed.get_rank(), max_diff)) | |||
| assert max_diff > 0 | |||
| # Reset the rng state and do the same stuff. | |||
| mpu.random._set_cuda_rng_state(rng_state) | |||
| for _ in range(5): | |||
| torch.randn(size, out=tensor) | |||
| mpu.random._set_cuda_rng_state(rng_state) | |||
| for _ in range(5): | |||
| torch.randn(size, out=tensor) | |||
| result_2 = tensor.clone() | |||
| # Results should be the same | |||
| error = result_2.sub(result_1).abs().max() | |||
| print(' max error in generated tensors (should be zero) on ' | |||
| 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| # Input state should have remained intact. | |||
| error = rng_state.sub(rng_state_copy).max() | |||
| print(' max error in rng state (should be zero) on global rank {}: {}'. | |||
| format(torch.distributed.get_rank(), error)) | |||
| assert error == 0 | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print('>> passed the test :-)') | |||
| def test_cuda_rng_tracker(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing cuda rng tracker with size {} ...'.format( | |||
| model_parallel_size)) | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| seed_1 = 1234 | |||
| seed_2 = 4321 | |||
| size = [12, 21] | |||
| tensor = torch.cuda.FloatTensor(size) | |||
| # Set to seed_1 and generate two tensors. | |||
| torch.cuda.manual_seed(seed_1) | |||
| torch.randn(size, out=tensor) | |||
| target_11 = tensor.clone() | |||
| torch.randn(size, out=tensor) | |||
| target_12 = tensor.clone() | |||
| # Set to seed_2 and generate two tensors. | |||
| torch.cuda.manual_seed(seed_2) | |||
| torch.randn(size, out=tensor) | |||
| target_21 = tensor.clone() | |||
| torch.randn(size, out=tensor) | |||
| target_22 = tensor.clone() | |||
| # Now if we interleave seed_1 and seed_2, | |||
| # we should still get the same tensors | |||
| torch.cuda.manual_seed(seed_1) | |||
| mpu.get_cuda_rng_tracker().add('test', seed_2) | |||
| torch.randn(size, out=tensor) | |||
| result_11 = tensor.clone() | |||
| with mpu.get_cuda_rng_tracker().fork('test'): | |||
| torch.randn(size, out=tensor) | |||
| result_21 = tensor.clone() | |||
| torch.randn(size, out=tensor) | |||
| result_12 = tensor.clone() | |||
| with mpu.get_cuda_rng_tracker().fork('test'): | |||
| torch.randn(size, out=tensor) | |||
| result_22 = tensor.clone() | |||
| diff = result_11.sub(result_21).abs().max() | |||
| diff = min(diff, result_12.sub(result_22).abs().max()) | |||
| print(' max diff in generated tensors (should be non-zero) on ' | |||
| 'global rank {}: {}'.format(torch.distributed.get_rank(), diff)) | |||
| assert diff > 1.0e-6 | |||
| error = max( | |||
| result_11.sub(target_11).abs().max(), | |||
| result_12.sub(target_12).abs().max()) | |||
| error = max(error, result_21.sub(target_21).abs().max()) | |||
| error = max(error, result_22.sub(target_22).abs().max()) | |||
| print(' max error in generated tensors (should be zero) on ' | |||
| 'global rank {}: {}'.format(torch.distributed.get_rank(), error)) | |||
| assert error < 1.0e-6 | |||
| # Reset the tracker | |||
| mpu.get_cuda_rng_tracker().reset() | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print('>> passed the test :-)') | |||
| def test_model_parallel_cuda_manual_seed(model_parallel_size): | |||
| if torch.distributed.get_rank() == 0: | |||
| print('> testing model parallel cuda manual seed with size {} ...'. | |||
| format(model_parallel_size)) | |||
| mpu.initialize_model_parallel(model_parallel_size) | |||
| model_parallel_size = mpu.get_model_parallel_world_size() | |||
| mpu.model_parallel_cuda_manual_seed(12345) | |||
| assert torch.cuda.initial_seed() == 12345 | |||
| with mpu.get_cuda_rng_tracker().fork(): | |||
| assert torch.cuda.initial_seed() == (12345 + 2718 | |||
| + mpu.get_model_parallel_rank()) | |||
| # Reset the tracker | |||
| mpu.get_cuda_rng_tracker().reset() | |||
| # Reset groups | |||
| mpu.destroy_model_parallel() | |||
| torch.distributed.barrier() | |||
| if torch.distributed.get_rank() == 0: | |||
| print('>> passed the test :-)') | |||
| if __name__ == '__main__': | |||
| initialize_distributed() | |||
| world_size = torch.distributed.get_world_size() | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| print_separator('test set rng state') | |||
| test_set_cuda_rng_state(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| print_separator('test cuda rng tracker') | |||
| test_cuda_rng_tracker(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| model_parallel_size = 1 | |||
| while model_parallel_size <= world_size: | |||
| print_separator('test model parallel cuda manual seed') | |||
| test_model_parallel_cuda_manual_seed(model_parallel_size) | |||
| model_parallel_size *= 2 | |||
| @@ -0,0 +1,70 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| import torch | |||
| def ensure_divisibility(numerator, denominator): | |||
| """Ensure that numerator is divisible by the denominator.""" | |||
| assert numerator % denominator == 0, '{} is not divisible by {}'.format( | |||
| numerator, denominator) | |||
| def divide(numerator, denominator): | |||
| """Ensure that numerator is divisible by the denominator and return | |||
| the division value.""" | |||
| ensure_divisibility(numerator, denominator) | |||
| return numerator // denominator | |||
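| # Illustrative note (not in the original file): divide() is the helper used | |||
| # throughout mpu to compute per-partition sizes, e.g. divide(512, 8) == 64, | |||
| # while divide(512, 7) raises an AssertionError via ensure_divisibility(). | |||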
| def split_tensor_along_last_dim(tensor, | |||
| num_partitions, | |||
| contiguous_split_chunks=False): | |||
| """Split a tensor along its last dimension. | |||
| Arguments: | |||
| tensor: input tensor. | |||
| num_partitions: number of partitions to split the tensor | |||
| contiguous_split_chunks: If True, make each chunk contiguous | |||
| in memory. | |||
| """ | |||
| # Get the size and dimension. | |||
| last_dim = tensor.dim() - 1 | |||
| last_dim_size = divide(tensor.size()[last_dim], num_partitions) | |||
| # Split. | |||
| tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) | |||
| # Note: torch.split does not create contiguous tensors by default. | |||
| if contiguous_split_chunks: | |||
| return tuple(chunk.contiguous() for chunk in tensor_list) | |||
| return tensor_list | |||
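| # Behaviour sketch (illustrative): splitting a tensor of shape (4, 6) into | |||
| # 3 partitions yields three views of shape (4, 2); they share storage with | |||
| # the input unless contiguous_split_chunks=True forces a copy per chunk. | |||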
| class VocabUtility: | |||
| """Split the vocabulary into `world_size` chunks amd return the | |||
| first and last index of the vocabulary belonging to the `rank` | |||
| partition: Note that indecies in [fist, last)""" | |||
| @staticmethod | |||
| def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, | |||
| rank, world_size): | |||
| index_f = rank * per_partition_vocab_size | |||
| index_l = index_f + per_partition_vocab_size | |||
| return index_f, index_l | |||
| @staticmethod | |||
| def vocab_range_from_global_vocab_size(global_vocab_size, rank, | |||
| world_size): | |||
| per_partition_vocab_size = divide(global_vocab_size, world_size) | |||
| return VocabUtility.vocab_range_from_per_partition_vocab_size( | |||
| per_partition_vocab_size, rank, world_size) | |||
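| # Example (values assumed for illustration): with a 50304-token vocabulary | |||
| # and world_size=4, rank 1 owns the half-open range | |||
| # VocabUtility.vocab_range_from_global_vocab_size(50304, 1, 4) == (12576, 25152). | |||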
| @@ -0,0 +1,61 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import glob | |||
| import os | |||
| import statistics | |||
| import sys | |||
| import json | |||
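| # Usage sketch (hypothetical arguments, for illustration only): | |||
| #   python <this script> 'runs/experiment-*' max rouge-1 | |||
| #   python <this script> 'runs/experiment-*' median | |||
| # argv[1] is a glob over result directories, argv[2] picks max/mean/median, | |||
| # and argv[3] names the metric to maximise when argv[2] == 'max'. | |||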
| path_pattern = sys.argv[1] | |||
| target_type = sys.argv[2] | |||
| best_value, best_result, best_name = None, None, None | |||
| mean_result = {} | |||
| print(path_pattern) | |||
| for dir_path in glob.glob(path_pattern, recursive=True): | |||
| entry = os.path.basename(dir_path) | |||
| valid_result = None | |||
| test_found = os.path.exists(os.path.join(dir_path, 'test_results.json')) | |||
| valid_path = os.path.join(dir_path, 'results.json') | |||
| if os.path.exists(valid_path): | |||
| print(entry) | |||
| with open(valid_path) as file: | |||
| valid_result = json.load(file) | |||
| else: | |||
| print(f'{entry} no validation results') | |||
| continue | |||
| if not test_found: | |||
| print(f'{entry} not tested yet') | |||
| if target_type == 'max': | |||
| metric = sys.argv[3] | |||
| metric_value = valid_result[metric] | |||
| if best_value is None or metric_value > best_value: | |||
| best_value = metric_value | |||
| best_result = valid_result | |||
| best_name = entry | |||
| elif target_type == 'mean' or target_type == 'median': | |||
| if mean_result: | |||
| for metric, value in valid_result.items(): | |||
| if metric not in ['type', 'epoch']: | |||
| mean_result[metric].append(value) | |||
| else: | |||
| mean_result = { | |||
| metric: [value] | |||
| for metric, value in valid_result.items() | |||
| if metric not in ['type', 'epoch'] | |||
| } | |||
| if target_type == 'max': | |||
| print(f'Best result found at {best_name}: {best_result}') | |||
| elif target_type == 'mean': | |||
| mean_result = { | |||
| metric: sum(value) / len(value) | |||
| for metric, value in mean_result.items() | |||
| } | |||
| print(f'Mean result {mean_result}') | |||
| elif target_type == 'median': | |||
| mean_result = { | |||
| metric: statistics.median(value) | |||
| for metric, value in mean_result.items() | |||
| } | |||
| print(f'Median result {mean_result}') | |||
| @@ -0,0 +1,22 @@ | |||
| boto3 | |||
| botocore | |||
| deepspeed | |||
| fasttext | |||
| filelock | |||
| ftfy | |||
| langdetect | |||
| lsh | |||
| matplotlib | |||
| mpi4py | |||
| nltk | |||
| pandas | |||
| regex | |||
| requests | |||
| rouge_score | |||
| scikit_learn | |||
| scipy | |||
| sentencepiece | |||
| termcolor | |||
| tldextract | |||
| tqdm | |||
| transformers | |||
| @@ -0,0 +1,10 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import sys | |||
| if sys.argv[1] == 'block': | |||
| from test.test_block import main | |||
| main() | |||
| elif sys.argv[1] == 'rel_shift': | |||
| from test.test_rel_shift import main | |||
| main() | |||
| @@ -0,0 +1,389 @@ | |||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """ Tasks data utility.""" | |||
| import copy | |||
| import pickle | |||
| import re | |||
| from typing import Dict, List, Optional | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| import torch.utils.data | |||
| from torch.utils.data.dataloader import default_collate | |||
| from modelscope.models.nlp.mglm import mpu | |||
| def clean_text(text): | |||
| """Remove new lines and multiple spaces and adjust end of sentence dot.""" | |||
| text = text.replace('\n', ' ') | |||
| text = re.sub(r'\s+', ' ', text) | |||
| for _ in range(3): | |||
| text = text.replace(' . ', '. ') | |||
| return text | |||
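| # Illustrative example (not in the original): clean_text('a \n b  . c') | |||
| # collapses the whitespace and tightens the sentence dot, returning 'a b. c'. | |||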
| class InputExample(object): | |||
| """A raw input example consisting of one or two segments of text and a label""" | |||
| def __init__(self, | |||
| guid, | |||
| text_a, | |||
| text_b=None, | |||
| label=None, | |||
| logits=None, | |||
| meta: Optional[Dict] = None, | |||
| idx=-1, | |||
| num_choices=1): | |||
| """ | |||
| Create a new InputExample. | |||
| :param guid: a unique textual identifier | |||
| :param text_a: the sequence of text | |||
| :param text_b: an optional, second sequence of text | |||
| :param label: an optional label | |||
| :param logits: an optional list of per-class logits | |||
| :param meta: an optional dictionary to store arbitrary meta information | |||
| :param idx: an optional numeric index | |||
| """ | |||
| self.guid = guid | |||
| self.text_a = text_a | |||
| self.text_b = text_b | |||
| self.label = label | |||
| self.logits = logits | |||
| self.idx = idx | |||
| self.num_choices = num_choices | |||
| self.meta = meta if meta else {} | |||
| def __repr__(self): | |||
| return str(self.to_json_string()) | |||
| def to_dict(self): | |||
| """Serialize this instance to a Python dictionary.""" | |||
| output = copy.deepcopy(self.__dict__) | |||
| return output | |||
| def to_json_string(self): | |||
| """Serialize this instance to a JSON string.""" | |||
| return json.dumps(self.to_dict(), indent=2, sort_keys=True) + '\n' | |||
| @staticmethod | |||
| def load_examples(path: str) -> List['InputExample']: | |||
| """Load a set of input examples from a file""" | |||
| with open(path, 'rb') as fh: | |||
| return pickle.load(fh) | |||
| @staticmethod | |||
| def save_examples(examples: List['InputExample'], path: str) -> None: | |||
| """Save a set of input examples to a file""" | |||
| with open(path, 'wb') as fh: | |||
| pickle.dump(examples, fh) | |||
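| # Construction sketch (field values are made up for illustration): | |||
| #   InputExample(guid='train-0', text_a='article body ...', | |||
| #                label='reference summary ...', meta={'ref': 'gold'}) | |||
| # load_examples/save_examples round-trip a List[InputExample] via pickle. | |||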
| def num_special_tokens_to_add(text_a_ids, | |||
| text_b_ids, | |||
| answer_ids, | |||
| add_cls, | |||
| add_sep, | |||
| add_piece, | |||
| add_eos=True): | |||
| num_tokens = 0 | |||
| if add_cls: | |||
| num_tokens += 1 | |||
| if text_b_ids and add_sep: | |||
| num_tokens += 1 | |||
| if add_eos: | |||
| num_tokens += 1 | |||
| if not answer_ids and add_piece: | |||
| num_tokens += 1 | |||
| return num_tokens | |||
| def build_input_from_ids(text_a_ids, | |||
| text_b_ids, | |||
| answer_ids, | |||
| max_seq_length, | |||
| tokenizer, | |||
| args=None, | |||
| add_cls=True, | |||
| add_sep=False, | |||
| add_piece=False, | |||
| add_eos=True, | |||
| mask_id=None): | |||
| if mask_id is None: | |||
| mask_id = tokenizer.get_command('MASK').Id | |||
| eos_id = tokenizer.get_command('eos').Id | |||
| cls_id = tokenizer.get_command('ENC').Id | |||
| sep_id = tokenizer.get_command('sep').Id | |||
| ids = [] | |||
| types = [] | |||
| paddings = [] | |||
| # CLS | |||
| if add_cls: | |||
| ids.append(cls_id) | |||
| types.append(0) | |||
| paddings.append(1) | |||
| # A | |||
| len_text_a = len(text_a_ids) | |||
| ids.extend(text_a_ids) | |||
| types.extend([0] * len_text_a) | |||
| paddings.extend([1] * len_text_a) | |||
| # B | |||
| if text_b_ids is not None: | |||
| # SEP | |||
| if add_sep: | |||
| ids.append(sep_id) | |||
| types.append(0) | |||
| paddings.append(1) | |||
| len_text_b = len(text_b_ids) | |||
| ids.extend(text_b_ids) | |||
| types.extend([1] * len_text_b) | |||
| paddings.extend([1] * len_text_b) | |||
| eos_length = 1 if add_eos else 0 | |||
| # Cap the size. | |||
| if len(ids) >= max_seq_length - eos_length: | |||
| max_seq_length_m1 = max_seq_length - 1 | |||
| ids = ids[0:max_seq_length_m1] | |||
| types = types[0:max_seq_length_m1] | |||
| paddings = paddings[0:max_seq_length_m1] | |||
| end_type = 0 if text_b_ids is None else 1 | |||
| if add_eos: | |||
| ids.append(eos_id) | |||
| types.append(end_type) | |||
| paddings.append(1) | |||
| sep = len(ids) | |||
| target_ids = [0] * len(ids) | |||
| loss_masks = [0] * len(ids) | |||
| position_ids = list(range(len(ids))) | |||
| block_position_ids = [0] * len(ids) | |||
| # Piece | |||
| if add_piece or answer_ids is not None: | |||
| sop_id = tokenizer.get_command('sop').Id | |||
| mask_position = ids.index( | |||
| mask_id | |||
| ) if not args.sentinel_token else args.max_position_embeddings | |||
| ids.append(sop_id) | |||
| types.append(end_type) | |||
| paddings.append(1) | |||
| position_ids.append(mask_position) | |||
| block_position_ids.append(1) | |||
| if answer_ids is not None: | |||
| len_answer = len(answer_ids) | |||
| ids.extend(answer_ids[:-1]) | |||
| types.extend([end_type] * (len_answer - 1)) | |||
| paddings.extend([1] * (len_answer - 1)) | |||
| position_ids.extend([mask_position] * (len_answer - 1)) | |||
| if not args.no_block_position: | |||
| block_position_ids.extend(range(2, len(answer_ids) + 1)) | |||
| else: | |||
| block_position_ids.extend([1] * (len(answer_ids) - 1)) | |||
| target_ids.extend(answer_ids) | |||
| loss_masks.extend([1] * len(answer_ids)) | |||
| else: | |||
| target_ids.append(0) | |||
| loss_masks.append(1) | |||
| # Padding. | |||
| padding_length = max_seq_length - len(ids) | |||
| if padding_length > 0: | |||
| ids.extend([eos_id] * padding_length) | |||
| types.extend([eos_id] * padding_length) | |||
| paddings.extend([0] * padding_length) | |||
| position_ids.extend([0] * padding_length) | |||
| block_position_ids.extend([0] * padding_length) | |||
| target_ids.extend([0] * padding_length) | |||
| loss_masks.extend([0] * padding_length) | |||
| if not args.masked_lm: | |||
| position_ids = [position_ids, block_position_ids] | |||
| return ids, types, paddings, position_ids, sep, target_ids, loss_masks | |||
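| # Layout sketch (illustrative) for a prompt containing one MASK token and an | |||
| # answer a1..am in blockwise (non-masked-lm) mode: | |||
| #   ids       = [CLS] prompt [eos] [sop] a1..a(m-1)  + padding | |||
| #   positions = 0, 1, ... over [CLS] prompt [eos], then the MASK position | |||
| #               repeated over the answer span; block positions are 0 for the | |||
| #               prompt and 1..m over [sop] a1..a(m-1) | |||
| #   target_ids/loss_masks cover a1..am, so the loss falls only on the answer. | |||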
| def build_decoder_input(enc_ids, answer_ids, max_seq_length, | |||
| max_dec_seq_length, tokenizer): | |||
| mask_id = tokenizer.get_command('MASK').Id | |||
| eos_id = tokenizer.get_command('eos').Id | |||
| sop_id = tokenizer.get_command('sop').Id | |||
| enc_len = len(enc_ids) # noqa | |||
| masks = [] | |||
| # TODO: it probably takes too much memory | |||
| # for i in range(max_dec_seq_length): | |||
| # m = [1]*enc_len + [0]*(max_seq_length - enc_len) + [1]*(i+1) + [0]*(max_dec_seq_length-1-i) | |||
| # masks.append(m) | |||
| mask_position = enc_ids.index(mask_id) | |||
| len_answer = len(answer_ids) | |||
| ids = [sop_id] + answer_ids[:-1] | |||
| types = [0] * len_answer # not used | |||
| paddings = [1] * len_answer | |||
| position_ids = [mask_position] * len_answer | |||
| block_position_ids = list(range(1, len_answer + 1)) | |||
| target_ids = answer_ids | |||
| loss_masks = [1] * len_answer | |||
| # Padding. | |||
| padding_length = max_dec_seq_length - len(ids) | |||
| if padding_length > 0: | |||
| ids.extend([eos_id] * padding_length) | |||
| types.extend([0] * padding_length) | |||
| paddings.extend([0] * padding_length) | |||
| position_ids.extend([0] * padding_length) | |||
| block_position_ids.extend([0] * padding_length) | |||
| target_ids.extend([0] * padding_length) | |||
| loss_masks.extend([0] * padding_length) | |||
| position_ids = [position_ids, block_position_ids] | |||
| return ids, types, paddings, position_ids, masks, target_ids, loss_masks | |||
| def build_sample(ids, | |||
| types=None, | |||
| paddings=None, | |||
| positions=None, | |||
| masks=None, | |||
| label=None, | |||
| unique_id=None, | |||
| target=None, | |||
| logit_mask=None, | |||
| segment_ids=None, | |||
| prompt_ids=None): | |||
| """Convert to numpy and return a sample consumed by the batch producer.""" | |||
| ids_np = np.array(ids, dtype=np.int64) | |||
| sample = {'text': ids_np, 'label': int(label)} | |||
| if types is not None: | |||
| types_np = np.array(types, dtype=np.int64) | |||
| sample['types'] = types_np | |||
| if paddings is not None: | |||
| paddings_np = np.array(paddings, dtype=np.int64) | |||
| sample['padding_mask'] = paddings_np | |||
| if positions is not None: | |||
| positions_np = np.array(positions, dtype=np.int64) | |||
| sample['position'] = positions_np | |||
| if masks is not None: | |||
| masks_np = np.array(masks, dtype=np.int64) | |||
| sample['mask'] = masks_np | |||
| if target is not None: | |||
| target_np = np.array(target, dtype=np.int64) | |||
| sample['target'] = target_np | |||
| if logit_mask is not None: | |||
| logit_mask_np = np.array(logit_mask, dtype=np.int64) | |||
| sample['logit_mask'] = logit_mask_np | |||
| if segment_ids is not None: | |||
| segment_ids = np.array(segment_ids, dtype=np.int64) | |||
| sample['segment_id'] = segment_ids | |||
| if prompt_ids is not None: | |||
| prompt_ids = np.array(prompt_ids, dtype=np.int64) | |||
| sample['prompt_pos'] = prompt_ids | |||
| if unique_id is not None: | |||
| sample['uid'] = unique_id | |||
| return sample | |||
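| # Sketch (illustrative field values): a typical cloze-style record looks like | |||
| #   build_sample(ids, types=types, paddings=paddings, positions=position_ids, | |||
| #                masks=sep, label=0, target=target_ids, logit_mask=loss_masks) | |||
| # everything becomes int64 numpy arrays that my_collate() below batches. | |||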
| def build_decoder_sample(sample, dec_ids, dec_position, dec_masks, dec_target, | |||
| dec_logit_mask): | |||
| sample['dec_text'] = np.array(dec_ids) | |||
| sample['dec_position'] = np.array(dec_position) | |||
| sample['dec_mask'] = np.array(dec_masks) | |||
| sample['dec_target'] = np.array(dec_target) | |||
| sample['dec_logit_mask'] = np.array(dec_logit_mask) | |||
| return sample | |||
| def my_collate(batch): | |||
| new_batch = [{key: value | |||
| for key, value in sample.items() if key != 'uid'} | |||
| for sample in batch] | |||
| text_list = [sample['text'] for sample in batch] | |||
| def pad_choice_dim(data, choice_num): | |||
| if len(data) < choice_num: | |||
| data = np.concatenate([data] | |||
| + [data[0:1]] * (choice_num - len(data))) | |||
| return data | |||
| if len(text_list[0].shape) == 2: | |||
| choice_nums = list(map(len, text_list)) | |||
| max_choice_num = max(choice_nums) | |||
| for i, sample in enumerate(new_batch): | |||
| for key, value in sample.items(): | |||
| if key != 'label': | |||
| sample[key] = pad_choice_dim(value, max_choice_num) | |||
| else: | |||
| sample[key] = value | |||
| sample['loss_mask'] = np.array( | |||
| [1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]), | |||
| dtype=np.int64) | |||
| if 'dec_text' in new_batch[0]: | |||
| choice_nums = [len(sample['dec_text']) for sample in new_batch] | |||
| if choice_nums.count(choice_nums[0]) != len(choice_nums): | |||
| max_choice_num = max(choice_nums) | |||
| for i, sample in enumerate(new_batch): | |||
| for key, value in sample.items(): | |||
| if key.startswith('dec_'): | |||
| sample[key] = pad_choice_dim(value, max_choice_num) | |||
| sample['loss_mask'] = np.array( | |||
| [1] * choice_nums[i] + [0] * # noqa | |||
| (max_choice_num - choice_nums[i]), | |||
| dtype=np.int64) | |||
| new_batch = default_collate(new_batch) | |||
| if 'uid' in batch[0]: | |||
| uid_list = [sample['uid'] for sample in batch] | |||
| new_batch['uid'] = uid_list | |||
| return new_batch | |||
| class FakeDataloader: | |||
| def __init__(self, num_iters): | |||
| self.num_iters = num_iters | |||
| def __iter__(self): | |||
| if self.num_iters is not None: | |||
| for _ in range(self.num_iters): | |||
| yield None | |||
| else: | |||
| while True: | |||
| yield None | |||
| def build_data_loader(dataset, | |||
| batch_size, | |||
| num_workers, | |||
| drop_last, | |||
| shuffle=True, | |||
| only_rank0=False): | |||
| """Data loader. Note that batch-size is the local (per GPU) batch-size.""" | |||
| # Sampler. | |||
| if only_rank0: | |||
| rank, world_size = 0, 1 | |||
| else: | |||
| world_size = mpu.get_data_parallel_world_size() | |||
| rank = mpu.get_data_parallel_rank() | |||
| sampler = torch.utils.data.distributed.DistributedSampler( | |||
| dataset, num_replicas=world_size, rank=rank, shuffle=shuffle) | |||
| # Data loader. Note that batch size is the per GPU batch size. | |||
| data_loader = torch.utils.data.DataLoader( | |||
| dataset, | |||
| batch_size=batch_size, | |||
| sampler=sampler, | |||
| shuffle=False, | |||
| num_workers=num_workers, | |||
| drop_last=drop_last, | |||
| pin_memory=True, | |||
| collate_fn=my_collate) | |||
| return data_loader | |||
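| # Usage sketch (hypothetical dataset object, illustrative only): | |||
| #   loader = build_data_loader(dataset, batch_size=8, num_workers=2, | |||
| #                              drop_last=False, shuffle=True) | |||
| # Shuffling is delegated to the per-rank DistributedSampler, which is why the | |||
| # DataLoader itself is constructed with shuffle=False. | |||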
| @@ -0,0 +1,249 @@ | |||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Evaluation utilities.""" | |||
| import datetime | |||
| import os | |||
| import random | |||
| import time | |||
| from collections import OrderedDict | |||
| from typing import List | |||
| import mpu | |||
| import torch | |||
| from finetune_glm import process_batch | |||
| from sklearn.metrics import f1_score | |||
| from tasks.data_utils import InputExample, build_data_loader | |||
| from utils import debug_finetune_data, get_spare_port, print_rank_0 | |||
| def accuracy_metric(predictions, labels, examples): | |||
| count = 0 | |||
| num_predictions = max(len(predictions), 1) | |||
| assert len(predictions) == len(labels) | |||
| for prediction, label in zip(predictions, labels): | |||
| count += prediction == label | |||
| return count * 100.0 / num_predictions | |||
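| # Illustrative check: accuracy_metric([1, 0, 1], [1, 1, 1], None) returns | |||
| # 2 * 100.0 / 3, i.e. the percentage of prediction/label pairs that match. | |||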
| def f1_metric(predictions, labels, examples): | |||
| return f1_score(labels, predictions) | |||
| def f1_macro_metric(predictions, labels, examples): | |||
| return f1_score(labels, predictions, average='macro') | |||
| global_tokenizer = None | |||
| def accuracy_func_provider(single_dataset_provider, | |||
| metric_dict, | |||
| args, | |||
| is_test=False, | |||
| eval_func=None, | |||
| output_func=None, | |||
| only_rank0=True, | |||
| tokenizer=None): | |||
| """Provide function that calculates accuracies.""" | |||
| # Build dataloaders. | |||
| global global_tokenizer | |||
| global_tokenizer = tokenizer | |||
| if only_rank0 and torch.distributed.is_initialized( | |||
| ) and torch.distributed.get_rank() != 0: | |||
| return None | |||
| if is_test and not args.eval_valid: | |||
| datapaths = args.test_data if args.test_data is not None else ['test'] | |||
| else: | |||
| datapaths = args.valid_data if args.valid_data is not None else ['dev'] | |||
| if eval_func is None: | |||
| eval_func = multichoice_evaluate | |||
| dataloaders = [] | |||
| eval_batch_size = args.eval_batch_size if args.eval_batch_size else args.batch_size | |||
| for datapath in datapaths: | |||
| dataset = single_dataset_provider(datapath) | |||
| dataloader = build_data_loader( | |||
| dataset, | |||
| eval_batch_size, | |||
| num_workers=args.num_workers, | |||
| drop_last=False, | |||
| shuffle=False, | |||
| only_rank0=only_rank0) | |||
| dataloaders.append((dataset.dataset_name, dataloader)) | |||
| def metrics_func(model, | |||
| epoch, | |||
| output_predictions=False, | |||
| summary_writer=None): | |||
| print_rank_0('calculating metrics ...') | |||
| score_dict = OrderedDict([(key, 0.0) for key in metric_dict | |||
| ]) if isinstance(metric_dict, dict) else { | |||
| metric_dict: 0.0 | |||
| } # noqa | |||
| total = 0 | |||
| for name, dataloader in dataloaders: | |||
| example_dict = None | |||
| if hasattr(dataloader.dataset, 'examples'): | |||
| example_dict = dataloader.dataset.examples | |||
| start_time = time.time() | |||
| predictions, labels, examples = eval_func(model, dataloader, | |||
| example_dict, args) | |||
| elapsed_time = time.time() - start_time | |||
| if output_predictions and torch.distributed.get_rank() == 0: | |||
| filename = os.path.join(args.log_dir, name + '.jsonl') | |||
| output_func(predictions, examples, filename) | |||
| total_count = len(predictions) | |||
| single_dict = { | |||
| key: metric(predictions, labels, examples) | |||
| for key, metric in metric_dict.items() | |||
| } | |||
| output_str = ' > |epoch: {}| metrics for {}: total {}'.format( | |||
| epoch, name, total_count) | |||
| for key, value in single_dict.items(): | |||
| output_str += ' {} = {:.4f} %'.format(key, value) | |||
| if summary_writer is not None and epoch >= 0 and not is_test and len( | |||
| dataloaders) > 1: | |||
| summary_writer.add_scalar(f'Train/valid_{name}_{key}', | |||
| value, epoch) | |||
| output_str += ' elapsed time (sec): {:.3f}'.format(elapsed_time) | |||
| if len(dataloaders) > 1: | |||
| print_rank_0(output_str) | |||
| for key in score_dict: | |||
| score_dict[key] += single_dict[key] * total_count | |||
| total += total_count | |||
| score_dict = { | |||
| key: score / float(total) | |||
| for key, score in score_dict.items() | |||
| } | |||
| output_str = ' >> |epoch: {}| overall: total = {}'.format(epoch, total) | |||
| for key, score in score_dict.items(): | |||
| output_str += ' {} = {:.4f}'.format(key, score) | |||
| if summary_writer is not None and epoch >= 0 and not is_test: | |||
| summary_writer.add_scalar(f'Train/valid_{key}', score, epoch) | |||
| print_rank_0(output_str) | |||
| return score_dict | |||
| return metrics_func | |||
| segment_length = 10 | |||
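| # Note (inferred from the code below): segment_length caps how many answer | |||
| # choices are scored in a single forward pass; larger choice sets are split | |||
| # into chunks of 10 and the resulting logits are concatenated along dim=1. | |||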
| def multichoice_evaluate(model, dataloader, example_dict, args): | |||
| """Calculate correct over total answers and return prediction if the | |||
| `output_predictions` is true.""" | |||
| model.eval() | |||
| port = get_spare_port(args) | |||
| print_rank_0(f'Using port {port}') | |||
| store = torch.distributed.TCPStore(args.master_ip, port, | |||
| torch.distributed.get_world_size(), | |||
| torch.distributed.get_rank() == 0, | |||
| datetime.timedelta(seconds=30)) | |||
| # file_path = os.path.join("/cache", args.experiment_name + "_store") | |||
| # print_rank_0(f"Using file store at {file_path}") | |||
| # store = torch.distributed.FileStore(file_path, torch.distributed.get_world_size()) | |||
| with torch.no_grad(): | |||
| # For all the batches in the dataset. | |||
| for _, batch in enumerate(dataloader): | |||
| # Run the model forward. | |||
| data = process_batch(batch, args) | |||
| if args.pretrained_bert: | |||
| tokens, types, labels_, attention_mask = data['text'], data[ | |||
| 'types'], data['label'], data['padding_mask'] | |||
| inputs = [tokens, types, attention_mask] | |||
| elif args.cloze_eval: | |||
| tokens, labels_, position_ids = data['text'], data[ | |||
| 'label'], data['position'] | |||
| attention_mask, target_ids, logit_mask = data['mask'], data[ | |||
| 'target'], data['logit_mask'] | |||
| if not args.fast_decode: | |||
| inputs = [ | |||
| tokens, position_ids, attention_mask, target_ids, | |||
| logit_mask | |||
| ] | |||
| if args.continuous_prompt: | |||
| prompt_pos = data['prompt_pos'] | |||
| inputs.append(prompt_pos) | |||
| else: | |||
| dec_input_ids, dec_position_ids, dec_attention_mask = data[ | |||
| 'dec_text'], data['dec_position'], data['dec_mask'] | |||
| dec_target_ids, dec_logit_mask = data['dec_target'], data[ | |||
| 'dec_logit_mask'] | |||
| inputs = [ | |||
| tokens, position_ids, attention_mask, dec_input_ids, | |||
| dec_position_ids, dec_attention_mask, dec_target_ids, | |||
| dec_logit_mask | |||
| ] | |||
| else: | |||
| tokens, labels_, position_ids, attention_mask = data[ | |||
| 'text'], data['label'], data['position'], data['mask'] | |||
| inputs = [tokens, position_ids, attention_mask] | |||
| if len(inputs[0].shape | |||
| ) == 3 and inputs[0].size(1) > segment_length: | |||
| logit_list = [] | |||
| for i in range((inputs[0].size(1) - 1) // segment_length + 1): | |||
| input_batch = [ | |||
| arg[:, i * segment_length:(i + 1) * segment_length] | |||
| for arg in inputs | |||
| ] | |||
| if args.pretrained_bert: | |||
| logits = model(*input_batch) | |||
| else: | |||
| logits, *mems = model(*input_batch) | |||
| logit_list.append(logits) | |||
| logits = torch.cat(logit_list, dim=1) | |||
| elif args.cloze_eval and args.fast_decode: | |||
| logit_list = [] | |||
| num_choices = inputs[3].size(1) | |||
| for i in range((num_choices - 1) // segment_length + 1): | |||
| input_batch = inputs[:3] + [ | |||
| arg[:, i * segment_length:(i + 1) * segment_length] | |||
| for arg in inputs[3:] | |||
| ] | |||
| logits, *mems = model(*input_batch) | |||
| logit_list.append(logits) | |||
| logits = torch.cat(logit_list, dim=1) | |||
| else: | |||
| if args.pretrained_bert: | |||
| logits = model(*inputs) | |||
| else: | |||
| logits, *mems = model(*inputs) | |||
| if 'segment_id' in data: | |||
| from torch_scatter import scatter_sum | |||
| if 'loss_mask' in data: | |||
| logits = logits * data['loss_mask'] | |||
| logits = scatter_sum(logits, data['segment_id'], dim=1) | |||
| elif 'loss_mask' in data: | |||
| loss_mask = data['loss_mask'] | |||
| logits = logits * loss_mask - 10000.0 * (1.0 - loss_mask) | |||
| uid_list = batch['uid'] | |||
| if isinstance(uid_list, torch.Tensor): | |||
| uid_list = uid_list.cpu().numpy().tolist() | |||
| predicted = torch.argmax(logits, dim=-1).tolist() | |||
| labels = labels_.tolist() | |||
| if args.task.lower() == 'wsc': | |||
| predicted = [1 if pred == 0 else 0 for pred in predicted] | |||
| if mpu.get_model_parallel_rank() == 0: | |||
| for uid, prediction, label in zip(uid_list, predicted, labels): | |||
| store.set(uid, str((prediction, label))) | |||
| model.train() | |||
| torch.distributed.barrier() | |||
| predictions, labels, examples = [], [], [] | |||
| for uid, example in example_dict.items(): | |||
| prediction, label = eval(store.get(uid)) | |||
| predictions.append(prediction) | |||
| labels.append(label) | |||
| examples.append(example) | |||
| torch.distributed.barrier() | |||
| return predictions, labels, examples | |||
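The evaluation loop above gathers per-rank predictions through a `torch.distributed.TCPStore`: each model-parallel rank-0 process writes a `(prediction, label)` pair keyed by the example uid, and after a barrier every rank reads the full result set back. Below is a minimal single-process sketch of that pattern; the port and uids are made up, and `ast.literal_eval` is used instead of the `eval` call in the diff for safety.

```python
import ast
import datetime
import torch

# Single-process stand-in for the rank-0 TCPStore used above
# (host/port and the uids are illustrative only).
store = torch.distributed.TCPStore('127.0.0.1', 29577, 1, True,
                                   datetime.timedelta(seconds=30))

# Each rank would set its own predictions keyed by example uid ...
store.set('dev-0', str((1, 1)))
store.set('dev-1', str((0, 1)))

# ... and, after a barrier, every rank reads the shared results back.
for uid in ['dev-0', 'dev-1']:
    prediction, label = ast.literal_eval(store.get(uid).decode())
    print(uid, prediction, label)
```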
| @@ -0,0 +1,249 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import math | |||
| from bisect import bisect_right | |||
| from itertools import accumulate | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from tasks.data_utils import build_input_from_ids, num_special_tokens_to_add | |||
| from tasks.language_model.detokenizer import get_detokenizer | |||
| from utils import print_rank_0 | |||
| class LMDataset(torch.utils.data.Dataset): | |||
| def __init__(self, args, documents, tokenizer, num_original_tokens, | |||
| num_tokenized_tokens): | |||
| self.args = args | |||
| self.documents = documents | |||
| self.max_seq_len = args.seq_length - 1 | |||
| self.tokenizer = tokenizer | |||
| self.overalapping_eval = args.overlapping_eval | |||
| if self.overalapping_eval is None: | |||
| self.overalapping_eval = self.max_seq_len | |||
| self.overalapping_eval = max(1, self.overalapping_eval) | |||
| self.num_original_tokens = num_original_tokens | |||
| self.num_tokenized_tokens = num_tokenized_tokens | |||
| # remove first sequence tokens | |||
| targets = [ | |||
| max(len(tokens) - self.max_seq_len, 0) for tokens in self.documents | |||
| ] | |||
| self.num_sequences = [ | |||
| max(math.ceil(target / self.overalapping_eval) + 1, 1) | |||
| for target in targets | |||
| ] | |||
| self.weights = list(accumulate(self.num_sequences)) | |||
| self.left_weights = [0] + self.weights[:-1] | |||
| self.unidirectional = args.unidirectional | |||
| self.block_lm = args.block_lm | |||
| mask_token = 'gMASK' if args.task_mask else 'MASK' | |||
| self.mask_id = self.tokenizer.get_command(mask_token).Id | |||
| def __len__(self): | |||
| return sum(self.num_sequences) | |||
| def __getitem__(self, idx): | |||
| document_idx = bisect_right(self.weights, idx) | |||
| idx = idx - self.left_weights[document_idx] | |||
| start_idx = idx * self.overalapping_eval | |||
| end_idx = start_idx + self.max_seq_len | |||
| tokens = self.documents[document_idx][start_idx:end_idx] | |||
| if self.block_lm: | |||
| if idx == 0 or self.unidirectional: | |||
| prompt, text = tokens[:1], tokens[1:] | |||
| else: | |||
| prompt_length = self.max_seq_len - self.overalapping_eval | |||
| prompt, text = tokens[:prompt_length], tokens[prompt_length:] | |||
| prompt = prompt + [self.mask_id] | |||
| num_special_tokens = num_special_tokens_to_add( | |||
| prompt, | |||
| None, | |||
| text, | |||
| add_cls=True, | |||
| add_sep=False, | |||
| add_piece=True, | |||
| add_eos=False) | |||
| data = build_input_from_ids( | |||
| prompt, | |||
| None, | |||
| text, | |||
| self.max_seq_len + num_special_tokens + 1, | |||
| self.tokenizer, | |||
| args=self.args, | |||
| add_cls=True, | |||
| add_sep=False, | |||
| add_piece=True, | |||
| add_eos=False, | |||
| mask_id=self.mask_id) | |||
| ids, types, paddings, position_ids, sep, target_ids, loss_masks = data | |||
| if idx != 0 and self.unidirectional: | |||
| loss_masks = np.array(loss_masks, dtype=np.int64) | |||
| loss_masks[:-self.overalapping_eval] = 0 | |||
| return { | |||
| 'text': np.array(ids, dtype=np.int64), | |||
| 'target': np.array(target_ids, dtype=np.int64), | |||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||
| 'loss_mask': np.array(loss_masks, dtype=np.int64), | |||
| 'position_id': np.array(position_ids, dtype=np.int64) | |||
| } | |||
| else: | |||
| loss_masks = [1] * len(tokens) | |||
| if len(tokens) < self.max_seq_len: | |||
| tokens = tokens + [0] * (self.max_seq_len - len(tokens)) | |||
| loss_masks = loss_masks + [0] * ( | |||
| self.max_seq_len - len(loss_masks)) | |||
| if idx != 0: | |||
| loss_masks = np.array(loss_masks, dtype=np.int64) | |||
| loss_masks[:-self.overalapping_eval] = 0 | |||
| return { | |||
| 'text': np.array(tokens, dtype=np.int64), | |||
| 'loss_mask': np.array(loss_masks, dtype=np.int64) | |||
| } | |||
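`LMDataset` flattens all documents into one index space: `num_sequences` holds the number of overlapping evaluation windows per document, `weights` their running total, and `bisect_right` maps a flat sample index back to a document and a window offset. A self-contained sketch of that mapping, using toy document lengths and assumed `max_seq_len`/`overlapping_eval` values:

```python
import math
from bisect import bisect_right
from itertools import accumulate

# Toy setup: three "documents" of 10, 25 and 3 tokens (illustrative values).
doc_lengths = [10, 25, 3]
max_seq_len, overlapping_eval = 8, 4

targets = [max(n - max_seq_len, 0) for n in doc_lengths]
num_sequences = [max(math.ceil(t / overlapping_eval) + 1, 1) for t in targets]
weights = list(accumulate(num_sequences))      # cumulative window counts
left_weights = [0] + weights[:-1]

def locate(idx):
    """Map a flat sample index to (document_idx, window_start)."""
    document_idx = bisect_right(weights, idx)
    local_idx = idx - left_weights[document_idx]
    return document_idx, local_idx * overlapping_eval

for idx in range(sum(num_sequences)):
    print(idx, locate(idx))
```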
| class LambadaDataset(torch.utils.data.Dataset): | |||
| def __init__(self, args, tokenizer, strict=True): | |||
| data_path = args.valid_data[0] | |||
| print_rank_0( | |||
| '> building lambada dataset from {} ...'.format(data_path)) | |||
| self.args = args | |||
| self.max_seq_length = args.seq_length | |||
| self.tokenizer = tokenizer | |||
| self.pad_idx = tokenizer.get_command('pad').Id | |||
| self.strict = strict | |||
| self.block_lm = args.block_lm | |||
| self.unidirectional = args.unidirectional | |||
| mask_token = 'gMASK' if args.task_mask else 'MASK' | |||
| self.mask_id = self.tokenizer.get_command(mask_token).Id | |||
| self.tokens = [] | |||
| self.labels = [] | |||
| with open(data_path, 'r') as f: | |||
| for line in f.readlines(): | |||
| text = json.loads(line)['text'] | |||
| tokens, labels = self.get_tokens(text) | |||
| self.tokens.append(tokens) | |||
| self.labels.append(labels) | |||
| def get_tokens(self, text): | |||
| if not self.strict: | |||
| tokens = self.tokenizer.EncodeAsIds(text).tokenization | |||
| return tokens[:-1], [tokens[-1]] | |||
| last_token = text.split()[-1] | |||
| start_idx = text.rfind(last_token) | |||
| beginning_tokens = self.tokenizer.EncodeAsIds( | |||
| text[:start_idx].strip()).tokenization | |||
| last_token = self.tokenizer.EncodeAsIds(' ' + last_token).tokenization | |||
| return beginning_tokens, last_token | |||
| def __len__(self): | |||
| return len(self.tokens) | |||
| def __getitem__(self, idx): | |||
| tokens, answer = self.tokens[idx], self.labels[idx] | |||
| if self.block_lm: | |||
| if self.unidirectional: | |||
| tokens, answer_tokens = tokens[:1], tokens[1:] + answer | |||
| else: | |||
| answer_tokens = answer | |||
| tokens = tokens + [self.mask_id] | |||
| num_special_tokens = num_special_tokens_to_add( | |||
| tokens, | |||
| None, | |||
| answer_tokens, | |||
| add_cls=True, | |||
| add_sep=False, | |||
| add_piece=True) | |||
| left_shift = len(tokens) + len( | |||
| answer_tokens) + num_special_tokens - self.max_seq_length | |||
| if left_shift > 0: | |||
| tokens = tokens[left_shift:] | |||
| data = build_input_from_ids( | |||
| tokens, | |||
| None, | |||
| answer_tokens, | |||
| self.max_seq_length, | |||
| self.tokenizer, | |||
| args=self.args, | |||
| add_cls=True, | |||
| add_sep=False, | |||
| add_piece=True, | |||
| mask_id=self.mask_id) | |||
| ids, types, paddings, position_ids, sep, target_ids, loss_masks = data | |||
| if self.unidirectional: | |||
| loss_masks = np.array(loss_masks, dtype=np.int64) | |||
| last_index = len(loss_masks) | |||
| while loss_masks[last_index - 1] == 0: | |||
| last_index -= 1 | |||
| loss_masks[:last_index - len(answer)] = 0 | |||
| return { | |||
| 'text': np.array(ids, dtype=np.int64), | |||
| 'target': np.array(target_ids, dtype=np.int64), | |||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||
| 'loss_mask': np.array(loss_masks, dtype=np.int64), | |||
| 'position_id': np.array(position_ids, dtype=np.int64) | |||
| } | |||
| else: | |||
| left_shift = len(tokens) - self.max_seq_length | |||
| if left_shift > 0: | |||
| tokens = tokens[left_shift:] | |||
| ids = tokens + answer | |||
| if len(ids) < self.max_seq_length: | |||
| ids = ids + [0] * (self.max_seq_length - len(ids)) | |||
| loss_masks = [0] * len(tokens) + [1] * len(answer) | |||
| if len(loss_masks) < self.max_seq_length: | |||
| loss_masks = loss_masks + [0] * ( | |||
| self.max_seq_length - len(loss_masks)) | |||
| return { | |||
| 'text': np.array(ids, dtype=np.int64), | |||
| 'loss_mask': np.array(loss_masks, dtype=np.int64) | |||
| } | |||
| def build_lambada_dataset(tokenizer, args): | |||
| """Build lambada dataset.""" | |||
| assert len(args.valid_data) == 1 | |||
| val_dataset = LambadaDataset(args, tokenizer, strict=True) | |||
| print_rank_0(' > found {} samples, {} label tokens.'.format( | |||
| len(val_dataset), sum(map(len, val_dataset.labels)))) | |||
| return val_dataset | |||
| def build_lm_dataset(tokenizer, args): | |||
| documents = [] | |||
| num_tokens, num_original_tokens = 0, 0 | |||
| with open(args.valid_data[0], encoding='utf-8') as file: | |||
| for line in file: | |||
| tokens = tokenizer.EncodeAsIds(line.strip()).tokenization | |||
| num_tokens += len(tokens) | |||
| num_original_tokens += len(line.strip().split(' ')) | |||
| documents.append(tokens) | |||
| val_dataset = LMDataset(args, documents, tokenizer, num_original_tokens, | |||
| num_tokens) | |||
| print_rank_0( | |||
| ' > number of documents: {}, number of original tokens: {}, number of detokenized tokens: {}' | |||
| .format(len(documents), num_original_tokens, num_tokens)) | |||
| return val_dataset | |||
| def build_wikitext103_dataset(tokenizer, args): | |||
| """""" | |||
| assert len(args.valid_data) == 1 | |||
| with open(args.valid_data[0], 'rb') as reader: | |||
| entire_data = reader.read().decode('utf-8') | |||
| num_original_tokens = len(entire_data.strip().split(' ')) | |||
| entire_data = get_detokenizer('wikitext')(entire_data) | |||
| print_rank_0(entire_data[:1024]) | |||
| tokenized_data = tokenizer.EncodeAsIds(entire_data).tokenization | |||
| num_tokenized_tokens = len(tokenized_data) | |||
| val_dataset = LMDataset(args, [tokenized_data], tokenizer, | |||
| num_original_tokens, num_tokenized_tokens) | |||
| print_rank_0(' > number of original tokens: {}, number of detokenized ' | |||
| 'tokens: {}'.format(num_original_tokens, | |||
| num_tokenized_tokens)) | |||
| return val_dataset | |||
| @@ -0,0 +1,63 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import re | |||
| def ptb_detokenizer(string): | |||
| string = string.replace(" '", "'") | |||
| string = string.replace(' \n', '\n') | |||
| string = string.replace('\n ', '\n') | |||
| string = string.replace(" n't", "n't") | |||
| string = string.replace(' N ', '1 ') | |||
| string = string.replace('$ 1', '$1') | |||
| string = string.replace('# 1', '#1') | |||
| return string | |||
| def wikitext_detokenizer(string): | |||
| # contractions | |||
| string = string.replace("s '", "s'") | |||
| string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) | |||
| # number separators | |||
| string = string.replace(' @-@ ', '-') | |||
| string = string.replace(' @,@ ', ',') | |||
| string = string.replace(' @.@ ', '.') | |||
| # punctuation | |||
| string = string.replace(' : ', ': ') | |||
| string = string.replace(' ; ', '; ') | |||
| string = string.replace(' . ', '. ') | |||
| string = string.replace(' ! ', '! ') | |||
| string = string.replace(' ? ', '? ') | |||
| string = string.replace(' , ', ', ') | |||
| # double brackets | |||
| string = re.sub(r'\(\s*([^\)]*?)\s*\)', r'(\1)', string) | |||
| string = re.sub(r'\[\s*([^\]]*?)\s*\]', r'[\1]', string) | |||
| string = re.sub(r'{\s*([^}]*?)\s*}', r'{\1}', string) | |||
| string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) | |||
| string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) | |||
| # miscellaneous | |||
| string = string.replace('= = = =', '====') | |||
| string = string.replace('= = =', '===') | |||
| string = string.replace('= =', '==') | |||
| string = string.replace(' ' + chr(176) + ' ', chr(176)) | |||
| string = string.replace(' \n', '\n') | |||
| string = string.replace('\n ', '\n') | |||
| string = string.replace(' N ', ' 1 ') | |||
| string = string.replace(" 's", "'s") | |||
| return string | |||
| def lambada_detokenizer(string): | |||
| return string | |||
| def get_detokenizer(dataset): | |||
| return DETOKENIZERS[dataset] | |||
| DETOKENIZERS = { | |||
| 'ptb': ptb_detokenizer, | |||
| 'wikitext': wikitext_detokenizer, | |||
| 'lambada': lambada_detokenizer, | |||
| } | |||
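The detokenizers above undo dataset-specific tokenization artifacts before token counting; the WikiText rules, for example, rejoin `@-@`/`@,@` hyphen and number separators and re-attach punctuation. A small usage sketch (the sample string is made up, the import path matches the one used in the evaluation code below):

```python
from tasks.language_model.detokenizer import get_detokenizer

raw = 'a well @-@ known author , born in 1 @,@ 990'
print(get_detokenizer('wikitext')(raw))
# -> 'a well-known author, born in 1,990'
```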
| @@ -0,0 +1,254 @@ | |||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """GPT2 zero-shot evaluation.""" | |||
| import functools | |||
| import math | |||
| import mpu | |||
| import torch | |||
| from finetune_glm import finetune | |||
| from pretrain_glm import get_batch | |||
| from tasks.data_utils import build_data_loader | |||
| from tasks.language_model.dataset import (build_lambada_dataset, | |||
| build_lm_dataset, | |||
| build_wikitext103_dataset) | |||
| from utils import print_rank_0 | |||
| global_tokenizer = None | |||
| def lm_forward_step(data, model, args, timers, mems, eval_metric=None): | |||
| """Forward step.""" | |||
| # Get the batch. | |||
| if timers is not None: | |||
| timers('batch generator').start() | |||
| if 'mask' in data: | |||
| data['attention_mask'] = data.pop('mask') | |||
| tokens, labels, loss_mask, attention_mask, position_ids = get_batch( | |||
| data, args) | |||
| if timers is not None: | |||
| timers('batch generator').stop() | |||
| def print_masked_text(batch_id): | |||
| block_position_ids = position_ids[:, 1] | |||
| position_ids_ = position_ids[:, 0] | |||
| output_tokens = [] | |||
| sep = attention_mask[batch_id].item() | |||
| for i, token in enumerate(tokens[batch_id, :sep].tolist()): | |||
| if global_tokenizer is not None: | |||
| token = global_tokenizer.IdToToken(token) | |||
| if token.startswith('[MASK'): | |||
| token = f'[{position_ids_[batch_id, i].item()}, {token}]' | |||
| if token.startswith('##') and len( | |||
| output_tokens) > 0 and not output_tokens[-1].endswith( | |||
| ']'): | |||
| output_tokens[-1] += token[2:] | |||
| else: | |||
| output_tokens.append(token) | |||
| else: | |||
| output_tokens.append(str(token)) | |||
| print(' '.join(output_tokens)) | |||
| last_index = None | |||
| for i in range(sep, tokens.size(1)): | |||
| if global_tokenizer.IdToToken( | |||
| tokens[batch_id, i].item()).startswith('<|startofpiece'): | |||
| if last_index is not None: | |||
| print( | |||
| global_tokenizer.DecodeIds( | |||
| tokens[batch_id, last_index:i].tolist()), '|', | |||
| global_tokenizer.DecodeIds( | |||
| labels[batch_id, last_index:i].tolist())) | |||
| print(position_ids_[batch_id, last_index:i].tolist(), | |||
| block_position_ids[batch_id, last_index:i].tolist()) | |||
| last_index = i | |||
| if last_index is not None: | |||
| print( | |||
| global_tokenizer.DecodeIds(tokens[batch_id, | |||
| last_index:].tolist()), '|', | |||
| global_tokenizer.DecodeIds(labels[batch_id, | |||
| last_index:].tolist())) | |||
| print(position_ids_[batch_id, last_index:].tolist(), | |||
| block_position_ids[batch_id, last_index:].tolist()) | |||
| # Forward model. | |||
| if args.continuous_prompt: | |||
| prompt_pos = data['prompt_pos'].long().cuda() | |||
| logits, *mems = model( | |||
| tokens, position_ids, attention_mask, *mems, prompt_pos=prompt_pos) | |||
| else: | |||
| logits, *mems = model(tokens, position_ids, attention_mask, *mems) | |||
| if eval_metric is None or eval_metric == 'loss': | |||
| losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), | |||
| labels) | |||
| loss_mask = loss_mask.view(-1) | |||
| # The loss is not normalized for fair comparison | |||
| loss = torch.sum(losses.view(-1) * loss_mask) | |||
| if eval_metric is None: | |||
| loss = loss / loss_mask.sum() | |||
| return loss, mems, 'bert' | |||
| elif eval_metric == 'accuracy' or eval_metric == 'classify': | |||
| logits = mpu.gather_from_model_parallel_region(logits) | |||
| outputs = torch.argmax(logits, -1) | |||
| correct = (outputs == labels).float() | |||
| correct[(1 - loss_mask).bool()] = 1 | |||
| correct = correct.prod(-1) | |||
| if eval_metric == 'accuracy': | |||
| correct = correct.sum() | |||
| return correct, mems, 'bert' | |||
| else: | |||
| raise NotImplementedError( | |||
| 'Metric {} not implemented'.format(eval_metric)) | |||
| def classify_evaluate(model, dataloader, example_dict, args): | |||
| """Evaluation.""" | |||
| # Turn on evaluation mode which disables dropout. | |||
| model.eval() | |||
| predictions, labels, examples = [], [], [] | |||
| with torch.no_grad(): | |||
| # For all the batches in the dataset. | |||
| for iteration, batch in enumerate(dataloader): | |||
| # Forward evaluation. | |||
| output, _, _ = lm_forward_step( | |||
| batch, model, args, None, [], eval_metric='classify') | |||
| uid_list = batch['uid'] | |||
| example_batch = [example_dict[uid] for uid in uid_list] | |||
| predictions.extend(output.long().tolist()) | |||
| label = batch['label'].tolist() | |||
| labels.extend(label) | |||
| examples.extend(example_batch) | |||
| return predictions, labels, examples | |||
| def evaluate(model, dataloader, eval_metric, args): | |||
| """Evaluation.""" | |||
| # Turn on evaluation mode which disables dropout. | |||
| model.eval() | |||
| total_output, total_count = 0.0, 0 | |||
| total_tokens = 0 | |||
| with torch.no_grad(): | |||
| # For all the batches in the dataset. | |||
| for iteration, batch in enumerate(dataloader): | |||
| if (iteration + 1) % args.log_interval == 0: | |||
| print_rank_0('> working on iteration: {}'.format(iteration)) | |||
| # Forward evaluation. | |||
| output, _, _ = lm_forward_step( | |||
| batch, model, args, None, [], eval_metric=eval_metric) | |||
| count = batch['text'].size(0) | |||
| count = torch.cuda.LongTensor([count]) | |||
| # Reduce across processes. | |||
| torch.distributed.all_reduce( | |||
| output, group=mpu.get_data_parallel_group()) | |||
| torch.distributed.all_reduce( | |||
| count, group=mpu.get_data_parallel_group()) | |||
| total_output += output.item() | |||
| total_count += count.item() | |||
| total_tokens += batch['loss_mask'].sum().item() | |||
| totals = torch.cuda.FloatTensor([total_output, total_tokens]) | |||
| torch.distributed.all_reduce(totals, group=mpu.get_data_parallel_group()) | |||
| total_output, total_tokens = totals.tolist() | |||
| print(total_tokens) | |||
| return {eval_metric: total_output}, total_count | |||
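`evaluate` sums the per-rank metric and token counts with `all_reduce` over the data-parallel group before anything is normalized. The sketch below reproduces that aggregation pattern in a single-process `gloo` group; the init address/port are placeholders, and with more ranks each one would contribute its own partial sums.

```python
import torch
import torch.distributed as dist

# Single-process group purely for illustration; address/port are arbitrary.
dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29501',
                        rank=0, world_size=1)

# Per-rank partial results (made-up numbers): [summed loss, token count].
totals = torch.tensor([123.4, 2048.0])
dist.all_reduce(totals)                  # sums the partials across ranks
total_loss, total_tokens = totals.tolist()
print(total_loss / total_tokens)         # normalize only after the reduce

dist.destroy_process_group()
```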
| def evaluate_and_print_results(data_loader, model, eval_metric, args): | |||
| """Evaluate and print results on screen.""" | |||
| # Evaluate and get results. | |||
| output, _ = evaluate(model, data_loader, eval_metric, args) | |||
| string = '' | |||
| if eval_metric == 'loss': | |||
| output = output['loss'] | |||
| num_tokenized_tokens = data_loader.dataset.num_tokenized_tokens | |||
| num_original_tokens = data_loader.dataset.num_original_tokens | |||
| val_loss = output / (num_tokenized_tokens - 1) | |||
| ppl = math.exp(min(20, val_loss)) | |||
| token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1) | |||
| adjusted_ppl = math.exp(min(20, val_loss * token_ratio)) | |||
| string += 'avg loss: {:.4E} | '.format(val_loss) | |||
| string += 'ppl: {:.4E} | '.format(ppl) | |||
| string += 'adjusted ppl: {:.4E} | '.format(adjusted_ppl) | |||
| string += 'token ratio: {} |'.format(token_ratio) | |||
| score_dict = { | |||
| 'avg loss': val_loss, | |||
| 'ppl': ppl, | |||
| 'adjusted ppl': adjusted_ppl | |||
| } | |||
| elif eval_metric == 'accuracy': | |||
| output = output['accuracy'] | |||
| num_examples = len(data_loader.dataset) | |||
| acc = output / num_examples * 100 | |||
| string += 'number correct: {} | '.format(output) | |||
| string += 'total examples: {} | '.format(num_examples) | |||
| string += 'avg accuracy: {:.2f}'.format(acc) | |||
| score_dict = {'accuracy': acc} | |||
| else: | |||
| raise NotImplementedError('evaluation method for {} metric is not ' | |||
| 'implemented yet.'.format(eval_metric)) | |||
| length = len(string) + 1 | |||
| print_rank_0('-' * length) | |||
| print_rank_0(string) | |||
| print_rank_0('-' * length) | |||
| return score_dict | |||
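For the `loss` metric the reported perplexity is `exp(total_loss / (num_tokenized_tokens - 1))`, and the adjusted perplexity rescales the per-token loss by the ratio of tokenizer tokens to original whitespace tokens, so the number stays comparable to word-level language models. A worked example with made-up counts:

```python
import math

# Made-up numbers for illustration.
total_loss = 250000.0
num_tokenized_tokens = 120000
num_original_tokens = 100000

val_loss = total_loss / (num_tokenized_tokens - 1)
ppl = math.exp(min(20, val_loss))
token_ratio = (num_tokenized_tokens - 1) / (num_original_tokens - 1)
adjusted_ppl = math.exp(min(20, val_loss * token_ratio))
print(f'loss {val_loss:.4f} | ppl {ppl:.2f} | adjusted ppl {adjusted_ppl:.2f}')
```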
| def metrics_func_provider(args, tokenizer, is_test): | |||
| """Privde metrics callback function.""" | |||
| if args.task.lower() == 'lambda': | |||
| eval_metric = 'accuracy' | |||
| dataset = build_lambada_dataset(tokenizer, args) | |||
| elif args.task == 'wikitext': | |||
| eval_metric = 'loss' | |||
| dataset = build_wikitext103_dataset(tokenizer, args) | |||
| elif args.task == 'language_model': | |||
| eval_metric = 'loss' | |||
| dataset = build_lm_dataset(tokenizer, args) | |||
| else: | |||
| raise NotImplementedError('{} task is not implemented.'.format( | |||
| args.task)) | |||
| # Data stuff | |||
| dataloader = build_data_loader( | |||
| dataset, | |||
| args.eval_batch_size, | |||
| args.num_workers, | |||
| drop_last=False, | |||
| shuffle=False) | |||
| def metrics_func(model, | |||
| epoch, | |||
| output_predictions=False, | |||
| summary_writer=None): | |||
| return evaluate_and_print_results( | |||
| dataloader, model, eval_metric=eval_metric, args=args) | |||
| global global_tokenizer | |||
| global_tokenizer = tokenizer | |||
| return metrics_func | |||
| def main(args): | |||
| """Main program.""" | |||
| finetune( | |||
| args, | |||
| None, {}, | |||
| end_of_epoch_callback_provider=metrics_func_provider, | |||
| forward_step=lm_forward_step) | |||
| @@ -0,0 +1,667 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import os | |||
| import random | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| import torch.utils.data | |||
| from data_utils.corpora import punctuation_standardization | |||
| from tasks.data_utils import InputExample | |||
| from tqdm import tqdm | |||
| from utils import print_rank_0 | |||
| def gigaword_detokenize(string, is_target=False): | |||
| _tok_dict = { | |||
| '(': '-lrb-', | |||
| ')': '-rrb-', | |||
| '[': '-lsb-', | |||
| ']': '-rsb-', | |||
| '{': '-lcb-', | |||
| '}': '-rcb-', | |||
| '&': '&', | |||
| '<': '<', | |||
| '>': '>' | |||
| } | |||
| string = string.replace('UNK', '[UNK]') | |||
| string = string.replace('<unk>', '[UNK]') | |||
| for key, value in _tok_dict.items(): | |||
| string = string.replace(value, key) | |||
| # string = string.replace("''", "\"") | |||
| # string = string.replace("``", "\"") | |||
| # string = string.replace("`", "'") | |||
| # string = string.replace(" n't", "n't") | |||
| # string = string.replace(" 's", "'s") | |||
| # string = string.replace(" 'd", "'d") | |||
| # string = string.replace(" 'll", "'ll") | |||
| return string | |||
| def cnndm_detokenize(string, is_target=False): | |||
| _tok_dict = { | |||
| '(': '-LRB-', | |||
| ')': '-RRB-', | |||
| '[': '-LSB-', | |||
| ']': '-RSB-', | |||
| '{': '-LCB-', | |||
| '}': '-RCB-' | |||
| } | |||
| if not is_target: | |||
| string = string.replace('<S_SEP>', '') | |||
| else: | |||
| string = string.replace('<S_SEP>', '[SEP]') | |||
| for key, value in _tok_dict.items(): | |||
| string = string.replace(value, key) | |||
| string = string.replace("''", "\"") | |||
| string = string.replace('``', "\"") | |||
| string = string.replace('`', "'") | |||
| string = string.replace(" n't", "n't") | |||
| string = string.replace(" 's", "'s") | |||
| string = string.replace(" 'd", "'d") | |||
| string = string.replace(" 'll", "'ll") | |||
| return string | |||
| def blanklm_detokenize(string, is_target=False): | |||
| string = string.replace('_UNK', '[UNK]') | |||
| string = string.replace('<blank>', '[MASK]') | |||
| return string | |||
| class SummaryProcessor: | |||
| def __init__(self, task, data_dir, tokenizer): | |||
| self.task = task | |||
| self.data_dir = data_dir | |||
| self.tokenizer = tokenizer | |||
| def create_examples(self, split): | |||
| if split == 'train': | |||
| filename = 'train' | |||
| elif split == 'dev': | |||
| filename = 'val' | |||
| elif split == 'test': | |||
| filename = 'test' | |||
| else: | |||
| raise NotImplementedError(split) | |||
| print_rank_0( | |||
| f'Creating {self.task}-{split} dataset from {self.data_dir}') | |||
| if self.task == 'gigaword': | |||
| detokenizer = gigaword_detokenize | |||
| elif self.task == 'cnn_dm': | |||
| detokenizer = cnndm_detokenize | |||
| else: | |||
| detokenizer = None | |||
| source_texts, target_texts = [], [] | |||
| with open( | |||
| os.path.join(self.data_dir, f'{filename}.source'), | |||
| encoding='utf-8') as file: | |||
| for line in file: | |||
| line = line.strip() | |||
| line = punctuation_standardization(line) | |||
| line = detokenizer(line) if detokenizer else line | |||
| source_texts.append(line) | |||
| with open( | |||
| os.path.join(self.data_dir, f'{filename}.target'), | |||
| encoding='utf-8') as file: | |||
| for line in file: | |||
| line = line.strip() | |||
| line = punctuation_standardization(line) | |||
| line = detokenizer( | |||
| line, is_target=True) if detokenizer else line | |||
| target_texts.append(line) | |||
| assert len(source_texts) == len(target_texts) | |||
| example_list = [] | |||
| for idx, (source_text, | |||
| target_text) in enumerate(zip(source_texts, target_texts)): | |||
| if (idx + 1) % 20000 == 0: | |||
| print_rank_0(f'Complete {idx + 1} examples') | |||
| guid = '%s-%s' % (split, idx) | |||
| meta = { | |||
| 'ref': | |||
| self.tokenizer.DecodeIds( | |||
| self.tokenizer.EncodeAsIds(target_text).tokenization) | |||
| } | |||
| example = InputExample( | |||
| guid=guid, text_a=source_text, text_b=target_text, meta=meta) | |||
| if idx < 10: | |||
| print_rank_0( | |||
| (source_text.encode('utf-8'), target_text.encode('utf-8'), | |||
| meta['ref'].encode('utf-8'))) | |||
| example_list.append(example) | |||
| return example_list | |||
| class SQuADProcessor: | |||
| def __init__(self, data_dir, tokenizer): | |||
| self.data_dir = data_dir | |||
| self.tokenizer = tokenizer | |||
| def create_examples(self, split): | |||
| if split == 'train': | |||
| filename = 'train.json' | |||
| elif split == 'dev': | |||
| filename = 'dev.json' | |||
| elif split == 'test': | |||
| filename = 'test.json' | |||
| else: | |||
| raise NotImplementedError(split) | |||
| print_rank_0(f'Creating SQuAD-{split} dataset from {self.data_dir}') | |||
| example_list = [] | |||
| idx = 0 | |||
| with open( | |||
| os.path.join(self.data_dir, filename), | |||
| encoding='utf-8') as file: | |||
| dataset = json.load(file) | |||
| for paragraphs in dataset: | |||
| for paragraph in paragraphs['paragraphs']: | |||
| context = paragraph['context'] | |||
| for qa in paragraph['qas']: | |||
| question = qa['question'] | |||
| answers = {answer['text'] for answer in qa['answers']} | |||
| answer_starts = { | |||
| answer['text']: answer['answer_start'] | |||
| for answer in qa['answers'] | |||
| } | |||
| for answer in answers: | |||
| guid = '%s-%s' % (split, idx) | |||
| meta = { | |||
| 'answer_start': | |||
| answer_starts[answer], | |||
| 'answer': | |||
| answer, | |||
| 'question': | |||
| question, | |||
| 'ref': | |||
| self.tokenizer.DecodeIds( | |||
| self.tokenizer.EncodeAsIds( | |||
| question).tokenization) | |||
| } | |||
| example = InputExample( | |||
| guid=guid, text_a=context, meta=meta) | |||
| if idx < 10: | |||
| print_rank_0((context.encode('utf-8'), | |||
| answer.encode('utf-8'), | |||
| meta['ref'].encode('utf-8'))) | |||
| example_list.append(example) | |||
| idx += 1 | |||
| print_rank_0(f'Creating {len(example_list)} examples for {split}') | |||
| return example_list | |||
| class XSumProcessor: | |||
| def __init__(self, data_dir, tokenizer): | |||
| self.data_dir = data_dir | |||
| self.tokenizer = tokenizer | |||
| def create_examples(self, split): | |||
| if split == 'train': | |||
| key = 'train' | |||
| elif split == 'dev': | |||
| key = 'validation' | |||
| elif split == 'test': | |||
| key = 'test' | |||
| else: | |||
| raise NotImplementedError(split) | |||
| print_rank_0(f'Creating XSUM-{split} dataset from {self.data_dir}') | |||
| with open( | |||
| os.path.join( | |||
| self.data_dir, | |||
| 'XSum-TRAINING-DEV-TEST-SPLIT-90-5-5.json')) as file: | |||
| id_list = json.load(file) | |||
| id_list = id_list[key] | |||
| source_texts, target_texts = [], [] | |||
| for i, idx in enumerate(id_list): | |||
| with open(os.path.join(self.data_dir, f'{idx}.summary')) as file: | |||
| key, sentences = None, [] | |||
| source_text, target_text = None, None | |||
| for line in file: | |||
| line = line.strip() | |||
| if line.startswith('[SN]'): | |||
| if key is not None: | |||
| if key == 'RESTBODY': | |||
| source_text = ' '.join(sentences) | |||
| elif key == 'FIRST-SENTENCE': | |||
| target_text = ' '.join(sentences) | |||
| key = line[4:-4] | |||
| sentences = [] | |||
| elif line: | |||
| sentences.append(line) | |||
| if key is not None: | |||
| if key == 'RESTBODY': | |||
| source_text = ' '.join(sentences) | |||
| elif key == 'FIRST-SENTENCE': | |||
| target_text = ' '.join(sentences) | |||
| source_texts.append(source_text) | |||
| target_texts.append(target_text) | |||
| if (i + 1) % 1000 == 0: | |||
| print_rank_0(f'Complete {i + 1} examples') | |||
| assert len(source_texts) == len(target_texts) | |||
| example_list = [] | |||
| for idx, (source_text, | |||
| target_text) in enumerate(zip(source_texts, target_texts)): | |||
| if (idx + 1) % 20000 == 0: | |||
| print_rank_0(f'Complete {idx + 1} examples') | |||
| guid = '%s-%s' % (split, idx) | |||
| meta = { | |||
| 'ref': | |||
| self.tokenizer.DecodeIds( | |||
| self.tokenizer.EncodeAsIds(target_text).tokenization) | |||
| } | |||
| example = InputExample( | |||
| guid=guid, text_a=source_text, text_b=target_text, meta=meta) | |||
| if idx < 10: | |||
| print_rank_0( | |||
| (source_text.encode('utf-8'), target_text.encode('utf-8'), | |||
| meta['ref'].encode('utf-8'))) | |||
| example_list.append(example) | |||
| return example_list | |||
| class Seq2SeqDataset(torch.utils.data.Dataset): | |||
| def __init__(self, args, split, tokenizer): | |||
| self.args = args | |||
| self.task, self.data_dir = args.task.lower(), args.data_dir | |||
| self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length | |||
| self.split = split | |||
| self.tokenizer = tokenizer | |||
| self.dataset_name = split | |||
| if self.task in ['gigaword', 'cnn_dm', 'cnn_dm_original']: | |||
| self.processor = SummaryProcessor(self.task, self.data_dir, | |||
| tokenizer) | |||
| elif self.task in ['xsum']: | |||
| self.processor = XSumProcessor(self.data_dir, tokenizer) | |||
| elif self.task in ['squad_generation']: | |||
| self.processor = SQuADProcessor(self.data_dir, tokenizer) | |||
| else: | |||
| raise NotImplementedError | |||
| example_list = self.processor.create_examples(split) | |||
| self.example_list = example_list | |||
| self.examples = {example.guid: example for example in example_list} | |||
| print_rank_0(f'Return {len(self.examples)} {split} examples') | |||
| def __len__(self): | |||
| return len(self.example_list) | |||
| def __getitem__(self, idx): | |||
| example = self.example_list[idx] | |||
| cls_id = self.tokenizer.get_command('ENC').Id | |||
| mask_token = 'sMASK' if self.args.task_mask else 'MASK' | |||
| mask_id = self.tokenizer.get_command(mask_token).Id | |||
| pad_id = self.tokenizer.get_command('pad').Id | |||
| sop_id = self.tokenizer.get_command('sop').Id | |||
| eop_id = self.tokenizer.get_command('eop').Id | |||
| if self.task in ['gigaword', 'cnn_dm', 'cnn_dm_original', 'xsum']: | |||
| source_text, target_text = example.text_a, example.text_b | |||
| source_tokens = self.tokenizer.EncodeAsIds( | |||
| ' ' + source_text).tokenization | |||
| prompt = [cls_id, mask_id | |||
| ] + self.tokenizer.EncodeAsIds(' Content:').tokenization | |||
| if len(source_tokens) > self.max_src_length - len(prompt): | |||
| source_tokens = source_tokens[:self.max_src_length | |||
| - len(prompt)] | |||
| source_tokens = prompt + source_tokens | |||
| elif self.task == 'squad_generation': | |||
| source_text = example.text_a | |||
| target_text, answer = example.meta['question'], example.meta[ | |||
| 'answer'] | |||
| source_tokens = self.tokenizer.EncodeAsIds( | |||
| source_text.rstrip() + ' Question:').tokenization | |||
| answer_tokens = self.tokenizer.EncodeAsIds(' Answer: ' | |||
| + answer).tokenization | |||
| if len(source_tokens | |||
| ) > self.max_src_length - len(answer_tokens) - 2: | |||
| max_src_length = self.max_src_length - len(answer_tokens) - 2 | |||
| answer_pattern = self.tokenizer.EncodeAsIds( | |||
| ' ' + answer).tokenization | |||
| def sub_finder(mylist, pattern): | |||
| matches = [] | |||
| for i in range(len(mylist)): | |||
| if mylist[i] == pattern[0] and mylist[ | |||
| i:i + len(pattern)] == pattern: | |||
| matches.append(i) | |||
| return matches | |||
| answer_indices = sub_finder(source_tokens, answer_pattern) | |||
| if len(answer_indices) == 0: | |||
| print(f'Answer {answer} does not exist in the source text') | |||
| source_tokens = source_tokens[:max_src_length] | |||
| else: | |||
| start_index = max(answer_indices[0] - max_src_length // 2, | |||
| 0) | |||
| source_tokens = source_tokens[start_index:start_index | |||
| + max_src_length] | |||
| source_tokens = [cls_id] + source_tokens + [mask_id | |||
| ] + answer_tokens | |||
| else: | |||
| raise NotImplementedError | |||
| if len(source_tokens) < self.max_src_length: | |||
| source_tokens = source_tokens + [pad_id] * ( | |||
| self.max_src_length - len(source_tokens)) | |||
| sep = len(source_tokens) | |||
| position_ids = list(range(len(source_tokens))) | |||
| block_position_ids = [0] * len(source_tokens) | |||
| mask_pos = source_tokens.index(mask_id) | |||
| if self.split == 'train': | |||
| target_tokens = self.tokenizer.EncodeAsIds( | |||
| ' ' + target_text).tokenization | |||
| target_tokens = target_tokens + [eop_id] | |||
| if len(target_tokens) > self.max_tgt_length: | |||
| target_tokens = target_tokens[:self.max_tgt_length] | |||
| loss_mask = [1] * len(target_tokens) | |||
| if len(target_tokens) < self.max_tgt_length: | |||
| loss_mask += [0] * (self.max_tgt_length - len(target_tokens)) | |||
| target_tokens += [pad_id] * ( | |||
| self.max_tgt_length - len(target_tokens)) | |||
| tokens = source_tokens + [sop_id] + target_tokens[:-1] | |||
| loss_mask = [0] * len(source_tokens) + loss_mask | |||
| target_ids = [0] * len(source_tokens) + target_tokens | |||
| position_ids += [mask_pos] * len(target_tokens) | |||
| if self.args.no_block_position: | |||
| block_position_ids += [1] * len(target_tokens) | |||
| else: | |||
| block_position_ids += list(range(1, len(target_tokens) + 1)) | |||
| position_ids = [position_ids, block_position_ids] | |||
| sample = { | |||
| 'text': np.array(tokens, dtype=np.int64), | |||
| 'target': np.array(target_ids, dtype=np.int64), | |||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||
| 'loss_mask': np.array(loss_mask, dtype=np.int64), | |||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||
| 'uid': example.guid | |||
| } | |||
| else: | |||
| tokens = source_tokens + [sop_id] | |||
| position_ids = position_ids + [mask_pos] | |||
| block_position_ids = block_position_ids + [1] | |||
| position_ids = [position_ids, block_position_ids] | |||
| sample = { | |||
| 'text': np.array(tokens, dtype=np.int64), | |||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||
| 'uid': example.guid | |||
| } | |||
| return sample | |||
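In the seq2seq samples above, generation is anchored at the mask token: the input is `source + [sop] + target[:-1]`, every target position reuses the mask position in the first position-id row, and the second (block) row counts 1..len(target). A toy sketch of that layout with hypothetical token ids:

```python
import numpy as np

# Hypothetical ids: cls=1, mask=2, sop=3, eop=4, pad=0.
cls_id, mask_id, sop_id, eop_id, pad_id = 1, 2, 3, 4, 0
source_tokens = [cls_id, mask_id, 11, 12, 13, pad_id]   # padded source
target_tokens = [21, 22, eop_id]                         # target + eop

mask_pos = source_tokens.index(mask_id)
tokens = source_tokens + [sop_id] + target_tokens[:-1]
target_ids = [0] * len(source_tokens) + target_tokens
loss_mask = [0] * len(source_tokens) + [1] * len(target_tokens)
position_ids = list(range(len(source_tokens))) + [mask_pos] * len(target_tokens)
block_position_ids = [0] * len(source_tokens) + list(range(1, len(target_tokens) + 1))

sample = {
    'text': np.array(tokens, dtype=np.int64),
    'target': np.array(target_ids, dtype=np.int64),
    'loss_mask': np.array(loss_mask, dtype=np.int64),
    'position_id': np.array([position_ids, block_position_ids], dtype=np.int64),
}
print({k: v.tolist() for k, v in sample.items()})
```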
| class ExtractionDataset(torch.utils.data.Dataset): | |||
| def __init__(self, args, split, tokenizer): | |||
| self.args = args | |||
| task, data_dir = args.task.lower(), args.data_dir | |||
| self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length | |||
| self.split = split | |||
| self.tokenizer = tokenizer | |||
| if split == 'train': | |||
| filename = 'train' | |||
| elif split == 'dev': | |||
| filename = 'valid' | |||
| elif split == 'test': | |||
| filename = 'test' | |||
| else: | |||
| raise NotImplementedError(split) | |||
| print_rank_0(f'Creating {task}-{split} dataset from {data_dir}') | |||
| self.dataset_name = split | |||
| source_texts, target_texts = [], [] | |||
| with open( | |||
| os.path.join(data_dir, f'{filename}.source'), | |||
| encoding='utf-8') as file: | |||
| for line in file: | |||
| line = line.strip() | |||
| source_texts.append(line) | |||
| with open( | |||
| os.path.join(data_dir, f'{filename}.target'), | |||
| encoding='utf-8') as file: | |||
| for line in file: | |||
| line = line.strip() | |||
| target_texts.append(line) | |||
| self.examples, self.example_list = {}, [] | |||
| for idx, (source_text, | |||
| target_text) in enumerate(zip(source_texts, target_texts)): | |||
| if (idx + 1) % 20000 == 0: | |||
| print_rank_0(f'Complete {idx + 1} examples') | |||
| guid = '%s-%s' % (split, idx) | |||
| meta = {'ref': target_text} | |||
| example = InputExample( | |||
| guid=guid, text_a=source_text, text_b=target_text, meta=meta) | |||
| self.examples[guid] = example | |||
| self.example_list.append(example) | |||
| print_rank_0(f'Return {len(self.examples)} {split} examples') | |||
| def __len__(self): | |||
| return len(self.example_list) | |||
| def __getitem__(self, idx): | |||
| example = self.example_list[idx] | |||
| source_text, target_text = example.text_a, example.text_b | |||
| mask_token = 'MASK' | |||
| mask_id = self.tokenizer.get_command(mask_token).Id | |||
| sop_id = self.tokenizer.get_command('sop').Id | |||
| eop_id = self.tokenizer.get_command('eop').Id | |||
| pad_id = self.tokenizer.get_command('pad').Id | |||
| def pad_to(text, max_len, pad_id): | |||
| if len(text) > max_len: | |||
| text = text[:max_len] | |||
| else: | |||
| text = text + [pad_id] * (max_len - len(text)) | |||
| return text | |||
| source_tokens = self.tokenizer.EncodeAsIds(source_text).tokenization | |||
| masked_tgt = target_text.split('|') | |||
| source_tokens = pad_to(source_tokens, self.max_src_length, pad_id) | |||
| sep = len(source_tokens) | |||
| position_ids = list(range(len(source_tokens))) | |||
| block_position_ids = [0] * len(source_tokens) | |||
| if self.split == 'train': | |||
| mask_positions = [ | |||
| i for i, x in enumerate(source_tokens) if x == mask_id | |||
| ] | |||
| assert len(mask_positions) <= len(masked_tgt) | |||
| tokens = source_tokens | |||
| target_ids = [0] * len(source_tokens) | |||
| loss_mask = [0] * len(source_tokens) | |||
| for i, mask_pos in enumerate(mask_positions): | |||
| tgt_text = masked_tgt[i] | |||
| tgt_tokens = self.tokenizer.EncodeAsIds( | |||
| ' ' + tgt_text).tokenization | |||
| tokens += [sop_id] + tgt_tokens | |||
| target_ids += tgt_tokens + [eop_id] | |||
| loss_mask += [1] * (len(tgt_tokens) + 1) | |||
| position_ids += [mask_pos] * (len(tgt_tokens) + 1) | |||
| block_position_ids += [ | |||
| i + 1 for i in range(len(tgt_tokens) + 1) | |||
| ] | |||
| tokens = pad_to(tokens, self.max_src_length + self.max_tgt_length, | |||
| pad_id) | |||
| target_ids = pad_to(target_ids, | |||
| self.max_src_length + self.max_tgt_length, | |||
| pad_id) | |||
| loss_mask = pad_to(loss_mask, | |||
| self.max_src_length + self.max_tgt_length, 0) | |||
| position_ids = pad_to(position_ids, | |||
| self.max_src_length + self.max_tgt_length, 0) | |||
| block_position_ids = pad_to( | |||
| block_position_ids, self.max_src_length + self.max_tgt_length, | |||
| 0) | |||
| position_ids = [position_ids, block_position_ids] | |||
| sample = { | |||
| 'text': np.array(tokens, dtype=np.int64), | |||
| 'target': np.array(target_ids, dtype=np.int64), | |||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||
| 'loss_mask': np.array(loss_mask, dtype=np.int64), | |||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||
| 'uid': example.guid | |||
| } | |||
| else: | |||
| tokens = source_tokens + [sop_id] | |||
| mask_pos = source_tokens.index(mask_id) | |||
| position_ids = position_ids + [mask_pos] | |||
| block_position_ids = block_position_ids + [1] | |||
| position_ids = [position_ids, block_position_ids] | |||
| sample = { | |||
| 'text': np.array(tokens, dtype=np.int64), | |||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||
| 'uid': example.guid | |||
| } | |||
| return sample | |||
| class BlankLMDataset(torch.utils.data.Dataset): | |||
| def __init__(self, args, split, tokenizer): | |||
| self.args = args | |||
| task, data_dir = args.task.lower(), args.data_dir | |||
| self.max_src_length, self.max_tgt_length = args.src_seq_length, args.tgt_seq_length | |||
| self.split = split | |||
| assert args.tokenizer_type == 'BertWordPieceTokenizer' | |||
| self.tokenizer = tokenizer | |||
| if split == 'train': | |||
| filename = 'train' | |||
| elif split == 'dev': | |||
| filename = 'valid' | |||
| elif split == 'test': | |||
| filename = 'test' | |||
| else: | |||
| raise NotImplementedError(split) | |||
| print_rank_0(f'Creating {task}-{split} dataset from {data_dir}') | |||
| self.dataset_name = split | |||
| detokenizer = blanklm_detokenize | |||
| source_texts, target_texts = [], [] | |||
| with open( | |||
| os.path.join(data_dir, f'{filename}.txt'), | |||
| encoding='utf-8') as file: | |||
| for line in file: | |||
| line = line.strip() | |||
| line = detokenizer(line) if detokenizer else line | |||
| target_texts.append(line) | |||
| if split == 'test': | |||
| with open( | |||
| os.path.join( | |||
| data_dir, | |||
| f'blank/test.maskratio{args.blank_maskratio:.1f}.blank' | |||
| ), | |||
| encoding='utf-8') as file: | |||
| for line in file: | |||
| line = line.strip() | |||
| line = detokenizer(line) if detokenizer else line | |||
| source_texts.append(line) | |||
| else: | |||
| source_texts = target_texts | |||
| self.examples, self.example_list = {}, [] | |||
| for idx, (source_text, | |||
| target_text) in enumerate(zip(source_texts, target_texts)): | |||
| # if idx > 10000: | |||
| # break | |||
| if (idx + 1) % 20000 == 0: | |||
| print_rank_0(f'Complete {idx + 1} examples') | |||
| guid = '%s-%s' % (split, idx) | |||
| meta = {'ref': target_text} | |||
| example = InputExample( | |||
| guid=guid, text_a=source_text, text_b=target_text, meta=meta) | |||
| self.examples[guid] = example | |||
| self.example_list.append(example) | |||
| print_rank_0(f'Return {len(self.examples)} {split} examples') | |||
| self.random = random.Random(args.seed) | |||
| def __len__(self): | |||
| return len(self.example_list) | |||
| def __getitem__(self, idx): | |||
| example = self.example_list[idx] | |||
| source_text, target_text = example.text_a, example.text_b # noqa | |||
| mask_token = 'gMASK' if self.args.task_mask else 'MASK' | |||
| mask_id = self.tokenizer.get_command(mask_token).Id | |||
| sop_id = self.tokenizer.get_command('sop').Id | |||
| eop_id = self.tokenizer.get_command('eop').Id | |||
| pad_id = self.tokenizer.get_command('pad').Id | |||
| if self.split in ['train', 'dev']: | |||
| masked_src, masked_tgt = self.mask_text(source_text) | |||
| source_text = masked_src | |||
| def pad_to(text, max_len, pad_id): | |||
| if len(text) > max_len: | |||
| text = text[:max_len] | |||
| else: | |||
| text = text + [pad_id] * (max_len - len(text)) | |||
| return text | |||
| source_tokens = self.tokenizer.EncodeAsIds(' ' | |||
| + source_text).tokenization | |||
| source_tokens = pad_to(source_tokens, self.max_src_length, pad_id) | |||
| sep = len(source_tokens) | |||
| position_ids = list(range(len(source_tokens))) | |||
| block_position_ids = [0] * len(source_tokens) | |||
| if self.split in ['train', 'dev']: | |||
| mask_positions = [ | |||
| i for i, x in enumerate(source_tokens) if x == mask_id | |||
| ] | |||
| assert len(mask_positions) <= len(masked_tgt) | |||
| tokens = source_tokens | |||
| target_ids = [0] * len(source_tokens) | |||
| loss_mask = [0] * len(source_tokens) | |||
| for i, mask_pos in enumerate(mask_positions): | |||
| tgt_text = masked_tgt[i] | |||
| tgt_tokens = self.tokenizer.EncodeAsIds( | |||
| ' ' + tgt_text).tokenization | |||
| tokens += [sop_id] + tgt_tokens | |||
| target_ids += tgt_tokens + [eop_id] | |||
| loss_mask += [1] * (len(tgt_tokens) + 1) | |||
| position_ids += [mask_pos] * (len(tgt_tokens) + 1) | |||
| block_position_ids += [ | |||
| i + 1 for i in range(len(tgt_tokens) + 1) | |||
| ] | |||
| max_length = self.max_src_length + int( | |||
| self.max_src_length * self.args.blank_maskratio) | |||
| tokens = pad_to(tokens, max_length, pad_id) | |||
| target_ids = pad_to(target_ids, max_length, pad_id) | |||
| loss_mask = pad_to(loss_mask, max_length, 0) | |||
| position_ids = pad_to(position_ids, max_length, 0) | |||
| block_position_ids = pad_to(block_position_ids, max_length, 0) | |||
| position_ids = [position_ids, block_position_ids] | |||
| sample = { | |||
| 'text': np.array(tokens, dtype=np.int64), | |||
| 'target': np.array(target_ids, dtype=np.int64), | |||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||
| 'loss_mask': np.array(loss_mask, dtype=np.int64), | |||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||
| 'uid': example.guid | |||
| } | |||
| else: | |||
| tokens = source_tokens + [sop_id] | |||
| mask_pos = source_tokens.index(mask_id) | |||
| position_ids = position_ids + [mask_pos] | |||
| block_position_ids = block_position_ids + [1] | |||
| position_ids = [position_ids, block_position_ids] | |||
| sample = { | |||
| 'text': np.array(tokens, dtype=np.int64), | |||
| 'attention_mask': np.array(sep, dtype=np.int64), | |||
| 'position_id': np.array(position_ids, dtype=np.int64), | |||
| 'uid': example.guid | |||
| } | |||
| return sample | |||
| def mask_text(self, text): | |||
| tokens = text.split() | |||
| mask_ratio = self.args.blank_maskratio | |||
| n = len(tokens) | |||
| indices = sorted(self.random.sample(range(n), int(n * mask_ratio))) | |||
| masked_src, masked_tgt = '', [] | |||
| for i, idx in enumerate(indices): | |||
| if i == 0 or idx != indices[i - 1] + 1: | |||
| masked_tgt.append('') | |||
| masked_tgt[-1] += ' ' + tokens[idx] | |||
| tokens[idx] = '[MASK]' | |||
| for i, token in enumerate(tokens): | |||
| if i != 0 and token == '[MASK]' and tokens[i - 1] == '[MASK]': | |||
| continue | |||
| masked_src += ' ' + token | |||
| return masked_src, masked_tgt | |||
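`mask_text` samples `blank_maskratio` of the whitespace tokens, collapses consecutive masked tokens into a single `[MASK]`, and returns the removed spans as the infilling targets. Below is a standalone re-implementation sketch of the same idea; the seed, ratio and sentence are arbitrary.

```python
import random

def mask_text(text, mask_ratio=0.3, seed=0):
    """Blank-infilling masking in the spirit of BlankLMDataset.mask_text."""
    rng = random.Random(seed)
    tokens = text.split()
    n = len(tokens)
    indices = sorted(rng.sample(range(n), int(n * mask_ratio)))
    masked_src, masked_tgt = '', []
    for i, idx in enumerate(indices):
        # Start a new target span unless this index continues the previous one.
        if i == 0 or idx != indices[i - 1] + 1:
            masked_tgt.append('')
        masked_tgt[-1] += ' ' + tokens[idx]
        tokens[idx] = '[MASK]'
    for i, token in enumerate(tokens):
        # Collapse runs of [MASK] into a single placeholder.
        if i != 0 and token == '[MASK]' and tokens[i - 1] == '[MASK]':
            continue
        masked_src += ' ' + token
    return masked_src, masked_tgt

src, tgt = mask_text('the quick brown fox jumps over the lazy dog')
print(src)   # sentence with [MASK] placeholders
print(tgt)   # the spans that were blanked out, one per placeholder
```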
| @@ -0,0 +1,538 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import datetime | |||
| import random | |||
| import string | |||
| import mpu | |||
| import torch | |||
| import torch.nn.functional as F | |||
| from generation_utils import (BeamSearchScorer, LogitsProcessorList, | |||
| MinLengthLogitsProcessor, | |||
| NoRepeatNGramLogitsProcessor) | |||
| from rouge_score import rouge_scorer | |||
| from utils import print_rank_0 | |||
| def _is_digit(w): | |||
| for ch in w: | |||
| if not (ch.isdigit() or ch == ','): | |||
| return False | |||
| return True | |||
| gigaword_tok_dict = { | |||
| '(': '-lrb-', | |||
| ')': '-rrb-', | |||
| '[': '-lsb-', | |||
| ']': '-rsb-', | |||
| '{': '-lcb-', | |||
| '}': '-rcb-', | |||
| '[UNK]': 'UNK', | |||
| '&': '&', | |||
| '<': '<', | |||
| '>': '>' | |||
| } | |||
| cnndm_tok_dict = { | |||
| '(': '-LRB-', | |||
| ')': '-RRB-', | |||
| '[': '-LSB-', | |||
| ']': '-RSB-', | |||
| '{': '-LCB-', | |||
| '}': '-RCB-' | |||
| } | |||
| def fix_tokenization(text, dataset): | |||
| if dataset == 'cnn_dm_org': | |||
| return text | |||
| if dataset == 'gigaword': | |||
| text = text.replace('[UNK]', 'UNK') | |||
| return text | |||
| input_tokens = text.split() | |||
| output_tokens = [] | |||
| has_left_quote = False | |||
| has_left_single_quote = False | |||
| i = 0 | |||
| prev_dash = False | |||
| while i < len(input_tokens): | |||
| tok = input_tokens[i] | |||
| flag_prev_dash = False | |||
| if tok == "\"": | |||
| if has_left_quote: | |||
| output_tokens.append("''") | |||
| else: | |||
| output_tokens.append('``') | |||
| has_left_quote = not has_left_quote | |||
| i += 1 | |||
| elif tok == "'" and len( | |||
| output_tokens) > 0 and output_tokens[-1].endswith( | |||
| 'n') and i < len(input_tokens) - 1 and input_tokens[ | |||
| i + 1] == 't': # noqa | |||
| output_tokens[-1] = output_tokens[-1][:-1] | |||
| output_tokens.append("n't") | |||
| i += 2 | |||
| elif tok == "'" and i < len(input_tokens) - 1 and input_tokens[ | |||
| i + 1] in ('s', 'd', 'll'): | |||
| output_tokens.append("'" + input_tokens[i + 1]) | |||
| i += 2 | |||
| elif tok == "'": | |||
| if has_left_single_quote: | |||
| output_tokens.append("'") | |||
| else: | |||
| output_tokens.append('`') | |||
| has_left_single_quote = not has_left_single_quote | |||
| i += 1 | |||
| elif tok == '.' and i < len(input_tokens) - 2 and input_tokens[ | |||
| i + 1] == '.' and input_tokens[i + 2] == '.': | |||
| output_tokens.append('...') | |||
| i += 3 | |||
| elif tok == ',' and len(output_tokens) > 0 and _is_digit( | |||
| output_tokens[-1]) and i < len(input_tokens) - 1 and _is_digit( | |||
| input_tokens[i + 1]): | |||
| # $ 3 , 000 -> $ 3,000 | |||
| output_tokens[-1] += ',' + input_tokens[i + 1] | |||
| i += 2 | |||
| elif tok == '.' and len(output_tokens) > 0 and output_tokens[-1].isdigit() and i < len(input_tokens) - 1 and \ | |||
| input_tokens[i + 1].isdigit(): | |||
| # 3 . 03 -> 3.03 | |||
| output_tokens[-1] += '.' + input_tokens[i + 1] | |||
| i += 2 | |||
| elif tok == '.' and len(output_tokens) > 0 and len( | |||
| output_tokens[-1]) == 1 and output_tokens[-1].isalpha( # noqa | |||
| ) and i < len(input_tokens) - 2 and len( # noqa | |||
| input_tokens[i + 1]) == 1 and input_tokens[ | |||
| i + 1].isalpha( # noqa | |||
| ) and input_tokens[i + 2] == '.': # noqa | |||
| # U . N . -> U.N. | |||
| k = i + 3 | |||
| while k + 2 < len(input_tokens): | |||
| if len(input_tokens[k + 1]) == 1 and input_tokens[ | |||
| k + 1].isalpha() and input_tokens[k + 2] == '.': | |||
| k += 2 | |||
| else: | |||
| break | |||
| output_tokens[-1] += ''.join(input_tokens[i:k]) | |||
| i = k | |||
| elif tok == '-': | |||
| if i < len(input_tokens) - 1 and input_tokens[i + 1] == '-': | |||
| output_tokens.append('--') | |||
| i += 2 | |||
| elif i == len(input_tokens) - 1 or i == 0: | |||
| output_tokens.append('-') | |||
| i += 1 | |||
| elif output_tokens[-1] not in string.punctuation and input_tokens[ | |||
| i + 1][0] not in string.punctuation: | |||
| output_tokens[-1] += '-' | |||
| i += 1 | |||
| flag_prev_dash = True | |||
| else: | |||
| output_tokens.append('-') | |||
| i += 1 | |||
| elif prev_dash and len( | |||
| output_tokens) > 0 and tok[0] not in string.punctuation: | |||
| output_tokens[-1] += tok | |||
| i += 1 | |||
| else: | |||
| output_tokens.append(tok) | |||
| i += 1 | |||
| prev_dash = flag_prev_dash | |||
| return ' '.join(output_tokens) | |||
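`fix_tokenization` maps generated text back onto the reference tokenization of each summarization benchmark (PTB-style quotes, rejoined numbers, contractions, acronyms). Two illustrative calls, assuming the function defined above is in scope; the inputs are made up and the expected outputs follow the rules shown here.

```python
# Assumes fix_tokenization from this module is importable/in scope.
print(fix_tokenization('the deal cost $ 3 , 000 .', dataset='cnn_dm'))
# -> 'the deal cost $ 3,000 .'
print(fix_tokenization('he said , " we agree " .', dataset='cnn_dm'))
# -> "he said , `` we agree '' ."
```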
| def count_tokens(tokens): | |||
| counter = {} | |||
| for t in tokens: | |||
| if t in counter.keys(): | |||
| counter[t] += 1 | |||
| else: | |||
| counter[t] = 1 | |||
| return counter | |||
| def get_f1(text_a, text_b): | |||
| tokens_a = text_a.lower().split() | |||
| tokens_b = text_b.lower().split() | |||
| if len(tokens_a) == 0 or len(tokens_b) == 0: | |||
| return 1 if len(tokens_a) == len(tokens_b) else 0 | |||
| set_a = count_tokens(tokens_a) | |||
| set_b = count_tokens(tokens_b) | |||
| match = 0 | |||
| for token in set_a.keys(): | |||
| if token in set_b.keys(): | |||
| match += min(set_a[token], set_b[token]) | |||
| p = match / len(tokens_a) | |||
| r = match / len(tokens_b) | |||
| return 2.0 * p * r / (p + r + 1e-5) | |||
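`get_f1` is a bag-of-words F1 between two strings, used below to drop near-duplicate generated sentences. A quick example, assuming `get_f1` from above is in scope:

```python
# p = 2/3 of text_a's tokens match, r = 2/2 of text_b's -> F1 of roughly 0.8
print(get_f1('the cat sat', 'the cat'))
```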
| def remove_duplicate(l_list, duplicate_rate): | |||
| tk_list = [l.lower().split() for l in l_list] # noqa | |||
| r_list = [] | |||
| history_set = set() | |||
| for i, w_list in enumerate(tk_list): | |||
| w_set = set(w_list) | |||
| if len(w_set & history_set) / len(w_set) <= duplicate_rate: | |||
| r_list.append(l_list[i]) | |||
| history_set |= w_set | |||
| return r_list | |||
| def rouge_metric(predictions, | |||
| labels, | |||
| examples, | |||
| metric='rouge-1', | |||
| duplicate_rate=0.7, | |||
| dataset='cnn_dm'): | |||
| metric_dict = { | |||
| 'rouge-1': 'rouge1', | |||
| 'rouge-2': 'rouge2', | |||
| 'rouge-l': 'rougeLsum' | |||
| } | |||
| refs = [example.meta['ref'] for example in examples] | |||
| ref_list = [] | |||
| for ref in refs: | |||
| ref = ref.strip().split('[SEP]') | |||
| ref = [fix_tokenization(sentence, dataset=dataset) for sentence in ref] | |||
| ref = '\n'.join(ref) | |||
| ref_list.append(ref) | |||
| pred_list = [] | |||
| for prediction in predictions: | |||
| buf = [] | |||
| for sentence in prediction.strip().split('[SEP]'): | |||
| sentence = fix_tokenization(sentence, dataset=dataset) | |||
| if any(get_f1(sentence, s) > 1.0 for s in buf): | |||
| continue | |||
| s_len = len(sentence.split()) | |||
| if s_len <= 4: | |||
| continue | |||
| buf.append(sentence) | |||
| if duplicate_rate and duplicate_rate < 1: | |||
| buf = remove_duplicate(buf, duplicate_rate) | |||
| line = '\n'.join(buf) | |||
| pred_list.append(line) | |||
| if torch.distributed.get_rank() == 0: | |||
| import json | |||
| with open('./results.json', 'w') as output: | |||
| for ref, pred in zip(ref_list, pred_list): | |||
| output.write(json.dumps({'ref': ref, 'pred': pred}) + '\n') | |||
| scorer = rouge_scorer.RougeScorer([metric_dict[metric]], use_stemmer=True) | |||
| scores = [ | |||
| scorer.score(pred, ref) for pred, ref in zip(pred_list, ref_list) | |||
| ] | |||
| scores = [score[metric_dict[metric]].fmeasure for score in scores] | |||
| scores = sum(scores) / len(scores) | |||
| return scores | |||
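`rouge_metric` relies on Google's `rouge_score` package; `rougeLsum` splits summaries on newlines, which is why references and predictions are joined with `'\n'` above. A minimal standalone call with made-up sentences:

```python
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeLsum'], use_stemmer=True)
scores = scorer.score('the cat sat on the mat.\nit was warm.',
                      'the cat was on the mat.\nit was sunny.')
print(scores['rouge1'].fmeasure, scores['rougeLsum'].fmeasure)
```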
| def process_batch(batch, args): | |||
| """Process batch and produce inputs for the model.""" | |||
| tokens = batch['text'].long().cuda() | |||
| attention_mask = batch['attention_mask'].long().cuda() | |||
| position_ids = batch['position_id'].long().cuda() | |||
| return tokens, attention_mask, position_ids | |||
| class DecoderEvaluater: | |||
| def __init__(self, args, tokenizer): | |||
| self.tokenizer = tokenizer | |||
| self.start_token = tokenizer.get_command('sop').Id | |||
| self.end_token = tokenizer.get_command('eop').Id | |||
| self.mask_token = tokenizer.get_command( | |||
| 'sMASK').Id if args.task_mask else tokenizer.get_command('MASK').Id | |||
| self.pad_token = tokenizer.get_command('pad').Id | |||
| self.processors = LogitsProcessorList() | |||
| if args.min_tgt_length > 0: | |||
| processor = MinLengthLogitsProcessor(args.min_tgt_length, | |||
| self.end_token) | |||
| self.processors.append(processor) | |||
| if args.no_repeat_ngram_size > 0: | |||
| processor = NoRepeatNGramLogitsProcessor(args.no_repeat_ngram_size) | |||
| self.processors.append(processor) | |||
| def evaluate(self, model, dataloader, example_dict, args): | |||
| """Calculate correct over total answers and return prediction if the | |||
| `output_predictions` is true.""" | |||
| model.eval() | |||
| store = torch.distributed.TCPStore(args.master_ip, | |||
| 18931 + random.randint(0, 10000), | |||
| mpu.get_data_parallel_world_size(), | |||
| torch.distributed.get_rank() == 0, | |||
| datetime.timedelta(seconds=30)) | |||
| print_rank_0('Distributed store created') | |||
| with torch.no_grad(): | |||
| # For all the batches in the dataset. | |||
| for idx, data in enumerate(dataloader): | |||
| tokens, attention_mask, position_ids = process_batch( | |||
| data, args) | |||
| batch_size = tokens.size(0) | |||
| beam_scorer = BeamSearchScorer( | |||
| batch_size=batch_size, | |||
| max_length=args.out_seq_length, | |||
| num_beams=args.num_beams, | |||
| device=tokens.device, | |||
| length_penalty=args.length_penalty, | |||
| do_early_stopping=False, | |||
| ) | |||
| beam_scores = torch.zeros((batch_size, args.num_beams), | |||
| dtype=torch.float, | |||
| device=tokens.device) | |||
| beam_scores[:, 1:] = -1e9 | |||
| beam_scores = beam_scores.view((batch_size * args.num_beams, )) | |||
| # Run the model forward. | |||
| counter = 0 | |||
| while counter < args.tgt_seq_length: | |||
| if counter == 0: | |||
| next_token_logits, *mems = model( | |||
| tokens, | |||
| position_ids, | |||
| attention_mask, | |||
| return_memory=True) | |||
| seq_length = next_token_logits.size(1) | |||
| next_token_logits = next_token_logits[:, -1] | |||
| next_token_logits = next_token_logits.unsqueeze( | |||
| 1).repeat(1, args.num_beams, | |||
| 1).view(batch_size * args.num_beams, -1) | |||
| mems = [ | |||
| mem.unsqueeze(1).repeat( | |||
| 1, args.num_beams, 1, | |||
| 1).view(batch_size * args.num_beams, | |||
| seq_length, -1) for mem in mems | |||
| ] | |||
| position_ids = tokens.new_ones(batch_size, | |||
| args.num_beams, 2, 1) | |||
| for i, text in enumerate(tokens.tolist()): | |||
| mask_pos = text.index(self.mask_token) | |||
| position_ids[i, :, 0] = mask_pos | |||
| position_ids = position_ids.reshape( | |||
| batch_size * args.num_beams, 2, 1) | |||
| tokens = tokens.new_zeros(batch_size * args.num_beams, | |||
| 0) | |||
| attention_mask = tokens.new_zeros( | |||
| [batch_size * args.num_beams]) | |||
| else: | |||
| if not args.no_block_position: | |||
| position_ids[:, 1] = counter + 1 | |||
| last_token = tokens[:, -1:] | |||
| next_token_logits, *mems = model( | |||
| last_token, | |||
| position_ids, | |||
| attention_mask, | |||
| *mems, | |||
| return_memory=True) | |||
| next_token_logits = next_token_logits[:, -1] | |||
| next_token_scores = F.log_softmax( | |||
| next_token_logits, dim=-1) | |||
| next_token_scores = self.processors( | |||
| tokens, next_token_scores) | |||
| next_token_scores = next_token_scores + beam_scores[:, None].expand_as( | |||
| next_token_scores) | |||
| vocab_size = next_token_scores.shape[-1] | |||
| next_token_scores = next_token_scores.view( | |||
| batch_size, args.num_beams * vocab_size) | |||
| probs = F.softmax(next_token_scores, dim=-1) | |||
| if args.select_topk: | |||
| _, next_tokens = torch.topk( | |||
| probs, k=2 * args.num_beams, dim=-1, largest=True) | |||
| else: | |||
| next_tokens = torch.multinomial( | |||
| probs, num_samples=2 * args.num_beams) | |||
| next_token_scores = torch.gather(next_token_scores, -1, | |||
| next_tokens) | |||
| next_token_scores, _indices = torch.sort( | |||
| next_token_scores, descending=True, dim=1) | |||
| next_tokens = torch.gather(next_tokens, -1, _indices) | |||
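| # The top candidates index a flattened (num_beams * vocab_size) axis; recover the source | |||
| # beam index and the token id from them. | |||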
| next_indices = next_tokens // vocab_size | |||
| next_tokens = next_tokens % vocab_size | |||
| # stateless | |||
| beam_outputs = beam_scorer.process( | |||
| tokens, | |||
| next_token_scores, | |||
| next_tokens, | |||
| next_indices, | |||
| eos_token_id=self.end_token, | |||
| pad_token_id=self.pad_token) | |||
| beam_scores = beam_outputs['next_beam_scores'] | |||
| beam_next_tokens = beam_outputs['next_beam_tokens'] | |||
| beam_idx = beam_outputs['next_beam_indices'] | |||
| beam_next_tokens = beam_next_tokens.unsqueeze(-1) | |||
| tokens = torch.cat([tokens[beam_idx, :], beam_next_tokens], | |||
| dim=-1) | |||
| mems = [mem[beam_idx] for mem in mems] if mems else [] | |||
| if beam_scorer.is_done: | |||
| break | |||
| counter += 1 | |||
| tokens, _ = beam_scorer.finalize( | |||
| tokens, | |||
| beam_scores, | |||
| next_tokens, | |||
| next_indices, | |||
| eos_token_id=self.end_token, | |||
| pad_token_id=self.pad_token) | |||
| predictions = [] | |||
| for text in tokens.tolist(): | |||
| text = [ | |||
| token for token in text | |||
| if token not in [self.end_token, self.pad_token] | |||
| ] | |||
| text = self.tokenizer.DecodeIds(text) | |||
| predictions.append(text) | |||
| uid_list = data['uid'] | |||
| if isinstance(uid_list, torch.Tensor): | |||
| uid_list = uid_list.cpu().numpy().tolist() | |||
| for uid, prediction in zip(uid_list, predictions): | |||
| store.set(uid, prediction) | |||
| if (idx + 1) % args.log_interval == 0: | |||
| print_rank_0(f'Iteration {idx + 1} / {len(dataloader)}') | |||
| model.train() | |||
| torch.distributed.barrier() | |||
| print_rank_0('Evaluation completed') | |||
| predictions, examples = [], [] | |||
| for uid, example in example_dict.items(): | |||
| predictions.append(store.get(uid).decode('utf-8')) | |||
| examples.append(example) | |||
| torch.distributed.barrier() | |||
| return predictions, [], examples | |||
| def blanklm_fix_tokenization(text): | |||
| text = text.replace('` `', '``') | |||
| text = text.replace("\' \'", "\'\'") | |||
| text = text.replace("n \' t", "n\'t") | |||
| text = text.replace("\' s", "\'s") | |||
| text = text.replace("\' m", "\'m") | |||
| text = text.replace("\' re", "\'re") | |||
| text = text.replace('. . .', '...') | |||
| text = text.replace(' . .', ' ..') | |||
| text = text.replace('- -', '--') | |||
| text = text.replace('u . s .', 'u.s.') | |||
| text = text.replace('u . k .', 'u.k.') | |||
| text = text.replace('e . g .', 'e.g.') | |||
| return text | |||
| class BlankLMEvaluater(DecoderEvaluater): | |||
| def evaluate(self, model, dataloader, example_dict, args): | |||
| model.eval() | |||
| store = torch.distributed.TCPStore(args.master_ip, | |||
| 18931 + random.randint(0, 10000), | |||
| mpu.get_data_parallel_world_size(), | |||
| torch.distributed.get_rank() == 0, | |||
| datetime.timedelta(seconds=30)) | |||
| print_rank_0('Distributed store created') | |||
| with torch.no_grad(): | |||
| for idx, data in enumerate(dataloader): | |||
| tokens, attention_mask, position_ids = process_batch( | |||
| data, args) | |||
| src_tokens = tokens | |||
| batch_size = tokens.size(0) | |||
| mask_positions = [] | |||
| current_mask = [] | |||
| for text in tokens.tolist(): | |||
| mask_positions.append([ | |||
| i for i, x in enumerate(text) if x == self.mask_token | |||
| ]) | |||
| current_mask.append(0) | |||
| # print(self.tokenizer.DecodeIds(text)) | |||
| # print(mask_positions[-1]) | |||
| counter = 0 | |||
| done = [False] * batch_size | |||
| while counter < args.tgt_seq_length: | |||
| if counter == 0: | |||
| # print(tokens) | |||
| # print(position_ids) | |||
| next_token_logits, *mems = model( | |||
| tokens, | |||
| position_ids, | |||
| attention_mask, | |||
| return_memory=True) | |||
| next_token_logits = next_token_logits[:, -1] | |||
| position_ids = tokens.new_ones(batch_size, 2, 1) | |||
| for i, text in enumerate(tokens.tolist()): | |||
| mask_pos = mask_positions[i][current_mask[i]] | |||
| position_ids[i, 0] = mask_pos | |||
| tokens = tokens.new_zeros(batch_size, 0) | |||
| attention_mask = tokens.new_zeros(batch_size) | |||
| else: | |||
| position_ids[:, 1] = position_ids[:, 1] + 1 | |||
| last_token = tokens[:, -1:] | |||
| next_token_logits, *mems = model( | |||
| last_token, | |||
| position_ids, | |||
| attention_mask, | |||
| *mems, | |||
| return_memory=True) | |||
| next_token_logits = next_token_logits[:, -1] | |||
| next_token_scores = F.log_softmax( | |||
| next_token_logits, dim=-1) | |||
| next_token_scores = self.processors( | |||
| tokens, next_token_scores) | |||
| next_tokens = next_token_scores.max(dim=-1)[1] | |||
| # print(self.tokenizer.DecodeIds(next_tokens.tolist())) | |||
| for i, next_token in enumerate(next_tokens.tolist()): | |||
| if next_token == self.end_token: | |||
| if current_mask[i] + 1 < len(mask_positions[i]): | |||
| current_mask[i] += 1 | |||
| next_tokens[i] = self.start_token | |||
| position_ids[i, 0] = mask_positions[i][ | |||
| current_mask[i]] | |||
| position_ids[i, 1] = 0 | |||
| else: | |||
| done[i] = True | |||
| if done[i]: | |||
| next_tokens[i] = self.pad_token | |||
| if all(done): | |||
| break | |||
| tokens = torch.cat( | |||
| [tokens, next_tokens.unsqueeze(-1)], dim=-1) | |||
| counter += 1 | |||
| predictions = [] | |||
| for i, text in enumerate(tokens.tolist()): | |||
| text = [ | |||
| token for token in text | |||
| if token not in [self.end_token, self.pad_token] | |||
| ] | |||
| blanks = [[]] | |||
| for token in text: | |||
| if token == self.start_token: | |||
| blanks.append([]) | |||
| else: | |||
| blanks[-1].append(token) | |||
| output_tokens = [] | |||
| current_blank = 0 | |||
| for token in src_tokens[i].tolist(): | |||
| if token == self.mask_token: | |||
| if current_blank < len(blanks): | |||
| output_tokens += blanks[current_blank] | |||
| current_blank += 1 | |||
| else: | |||
| if token not in [self.pad_token]: | |||
| output_tokens.append(token) | |||
| text = self.tokenizer.DecodeIds(output_tokens[:-1]) | |||
| text = blanklm_fix_tokenization(text) | |||
| predictions.append(text) | |||
| # print(text) | |||
| uid_list = data['uid'] | |||
| if isinstance(uid_list, torch.Tensor): | |||
| uid_list = uid_list.cpu().numpy().tolist() | |||
| for uid, prediction in zip(uid_list, predictions): | |||
| store.set(uid, prediction) | |||
| if (idx + 1) % args.log_interval == 0: | |||
| print_rank_0(f'Iteration {idx + 1} / {len(dataloader)}') | |||
| model.train() | |||
| torch.distributed.barrier() | |||
| print_rank_0('Evaluation completed') | |||
| predictions, examples = [], [] | |||
| for uid, example in example_dict.items(): | |||
| predictions.append(store.get(uid).decode('utf-8')) | |||
| examples.append(example) | |||
| torch.distributed.barrier() | |||
| return predictions, [], examples | |||
| @@ -0,0 +1,151 @@ | |||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Race.""" | |||
| import functools | |||
| from collections import OrderedDict | |||
| import mpu | |||
| import torch | |||
| from finetune_glm import finetune | |||
| from pretrain_glm import get_batch | |||
| from tasks.eval_utils import accuracy_func_provider | |||
| from tasks.seq2seq.dataset import (BlankLMDataset, ExtractionDataset, | |||
| Seq2SeqDataset) | |||
| from tasks.seq2seq.evaluate import (BlankLMEvaluater, DecoderEvaluater, | |||
| rouge_metric) | |||
| global_tokenizer = None | |||
| def seq2seq_forward_step(data, model, args, timers, mems): | |||
| """Forward step.""" | |||
| # Get the batch. | |||
| if timers is not None: | |||
| timers('batch generator').start() | |||
| tokens, labels, loss_mask, attention_mask, position_ids = get_batch( | |||
| data, args) | |||
| if timers is not None: | |||
| timers('batch generator').stop() | |||
| # Forward model. | |||
| logits, *mems = model(tokens, position_ids, attention_mask, *mems) | |||
| # logits, loss_mask = logits[:, args.src_seq_length:], loss_mask[:, args.src_seq_length:] | |||
| # target_ids = target_ids[:, args.src_seq_length:] | |||
| losses = mpu.vocab_parallel_cross_entropy(logits.contiguous().float(), | |||
| labels) | |||
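| # Label smoothing: blend the token-level NLL with the mean negative log-probability over the vocabulary. | |||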
| if args.label_smoothing > 0.0: | |||
| epsilon = args.label_smoothing | |||
| smooth_loss = -torch.nn.functional.log_softmax( | |||
| logits, dim=-1).mean(dim=-1) | |||
| losses = (1 - epsilon) * losses + epsilon * smooth_loss | |||
| loss_mask = loss_mask.reshape(-1) | |||
| # The loss is not normalized for fair comparison | |||
| loss = torch.sum(losses.reshape(-1) * loss_mask) / loss_mask.sum() | |||
| return loss, mems, 'bert' | |||
| def train_valid_datasets_provider(args, tokenizer): | |||
| """Provide train and validation datasets.""" | |||
| if args.task.lower() == 'blank': | |||
| train_dataset = BlankLMDataset( | |||
| args, split='train', tokenizer=tokenizer) | |||
| valid_dataset = None | |||
| elif args.task.lower() == 'extraction': | |||
| train_dataset = ExtractionDataset( | |||
| args, split='train', tokenizer=tokenizer) | |||
| valid_dataset = None | |||
| else: | |||
| train_dataset = Seq2SeqDataset( | |||
| args, split='train', tokenizer=tokenizer) | |||
| valid_dataset = None | |||
| global global_tokenizer | |||
| global_tokenizer = tokenizer | |||
| return train_dataset, valid_dataset | |||
| def metrics_func_provider(args, tokenizer, is_test): | |||
| """Provide metrics callback function.""" | |||
| def single_dataset_provider(split): | |||
| if args.task.lower() == 'blank': | |||
| return BlankLMDataset(args, split=split, tokenizer=tokenizer) | |||
| elif args.task.lower() == 'extraction': | |||
| return ExtractionDataset(args, split=split, tokenizer=tokenizer) | |||
| else: | |||
| return Seq2SeqDataset(args, split=split, tokenizer=tokenizer) | |||
| if args.task.lower() in ['blank', 'extraction']: | |||
| evaluater = BlankLMEvaluater(args, tokenizer) | |||
| eval_func = evaluater.evaluate | |||
| metric_dict = {} | |||
| else: | |||
| evaluater = DecoderEvaluater(args, tokenizer) | |||
| eval_func = evaluater.evaluate | |||
| if args.tokenizer_type == 'BertWordPieceTokenizer': | |||
| dataset = 'cnn_dm' | |||
| elif args.task.lower() == 'gigaword': | |||
| dataset = 'gigaword' | |||
| else: | |||
| dataset = 'cnn_dm_org' | |||
| metric_dict = OrderedDict({ | |||
| 'rouge-1': | |||
| functools.partial(rouge_metric, metric='rouge-1', dataset=dataset), | |||
| 'rouge-2': | |||
| functools.partial(rouge_metric, metric='rouge-2', dataset=dataset), | |||
| 'rouge-l': | |||
| functools.partial(rouge_metric, metric='rouge-l', dataset=dataset) | |||
| }) | |||
| def output_func(predictions, examples, output_file): | |||
| with open(output_file + '.hyps', 'w', encoding='utf-8') as output: | |||
| for prediction in predictions: | |||
| output.write(prediction) | |||
| output.write('\n') | |||
| with open(output_file + '.refs', 'w', encoding='utf-8') as output: | |||
| for example in examples: | |||
| output.write(example.meta['ref']) | |||
| output.write('\n') | |||
| if args.task.lower() == 'squad_generation': | |||
| with open( | |||
| output_file + '.source', 'w', encoding='utf-8') as output: | |||
| for example in examples: | |||
| output.write( | |||
| example.text_a.replace('\n', ' ') + ' Answer: ' | |||
| + example.meta['answer']) | |||
| output.write('\n') | |||
| return accuracy_func_provider( | |||
| single_dataset_provider, | |||
| metric_dict, | |||
| args, | |||
| is_test=is_test, | |||
| eval_func=eval_func, | |||
| output_func=output_func, | |||
| only_rank0=False) | |||
| def main(args): | |||
| if args.src_seq_length > args.max_position_embeddings: | |||
| args.max_position_embeddings = args.src_seq_length | |||
| if args.task.lower() in [ | |||
| 'cnn_dm', 'cnn_dm_original', 'gigaword', 'blank', | |||
| 'squad_generation', 'xsum', 'extraction' | |||
| ]: | |||
| finetune( | |||
| args, | |||
| train_valid_datasets_provider, {}, | |||
| end_of_epoch_callback_provider=metrics_func_provider, | |||
| forward_step=seq2seq_forward_step) | |||
| else: | |||
| raise NotImplementedError(args.task) | |||
| @@ -0,0 +1,137 @@ | |||
| # Use GLM for your NLU tasks | |||
| To use GLM for your own NLU tasks, you should implement a subclass of `DataProcessor` in [tasks/superglue/dataset.py](dataset.py) and a subclass of `PVP` in [tasks/superglue/pvp.py](pvp.py). We will take the RTE and ReCoRD tasks in SuperGLUE as examples. | |||
| ## 1. Design your patterns | |||
| RTE is an NLI task in which the model is required to predict text entailment between a premise and a hypothesis. The label can be `entailment` or `not_entailment`. One sample from the training set is | |||
| ``` | |||
| premise: No Weapons of Mass Destruction Found in Iraq Yet. | |||
| hypothesis: Weapons of Mass Destruction Found in Iraq. | |||
| label: not_entailment | |||
| ``` | |||
| We design the pattern as | |||
| ``` | |||
| "`hypothesis`"?, [MASK], "`premise`" | |||
| ``` | |||
| GLM predicts "Yes" for `entailment` and "No" for `not_entailment`. "Yes" and "No" are called verbalizers for `entailment` and `not_entailment`. | |||
| ReCoRD is a multi-choice QA task. Each example consists of a news article and a cloze-style question about the article in which one entity is masked out. The system must predict the masked-out entity from a list of possible entities in the provided passage. We directly adopt the cloze-style question as our pattern and use GLM to predict the masked entity. | |||
| ## 2. Implement a subclass of `DataProcessor` | |||
| A subclass of `DataProcessor` should implement `get_train_examples`, `get_dev_examples` and `get_test_examples`, which return the examples of the train, dev, and test sets. The returned value is a list of `InputExample`. It should also implement `get_labels` to return the list of possible labels. Here we take `RteProcessor` as an example: | |||
| ```python | |||
| class RteProcessor(DataProcessor): | |||
| """Processor for the RTE data set.""" | |||
| def get_train_examples(self, data_dir): | |||
| return self._create_examples(os.path.join(data_dir, "train.jsonl"), "train") | |||
| def get_dev_examples(self, data_dir, for_train=False): | |||
| return self._create_examples(os.path.join(data_dir, "val.jsonl"), "dev") | |||
| def get_test_examples(self, data_dir): | |||
| return self._create_examples(os.path.join(data_dir, "test.jsonl"), "test") | |||
| def get_unlabeled_examples(self, data_dir): | |||
| return self._create_examples(os.path.join(data_dir, "unlabeled.jsonl"), "unlabeled") | |||
| def get_labels(self): | |||
| return ["entailment", "not_entailment"] | |||
| def _create_examples(self, path: str, set_type: str, hypothesis_name: str = "hypothesis", | |||
| premise_name: str = "premise") -> List[InputExample]: | |||
| examples = [] | |||
| with open(path, encoding='utf8') as f: | |||
| for line_idx, line in enumerate(f): | |||
| example_json = json.loads(line) | |||
| idx = example_json['idx'] | |||
| if isinstance(idx, str): | |||
| try: | |||
| idx = int(idx) | |||
| except ValueError: | |||
| idx = line_idx | |||
| label = example_json.get('label') | |||
| guid = "%s-%s" % (set_type, idx) | |||
| text_a = example_json[premise_name] | |||
| text_b = example_json[hypothesis_name] | |||
| example = InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label, idx=idx) | |||
| examples.append(example) | |||
| return examples | |||
| ``` | |||
| After that, you should add the implemented class to ``PROCESSORS`` at the end of [tasks/superglue/dataset.py](dataset.py): | |||
| ```python | |||
| PROCESSORS = { | |||
| ... | |||
| "rte": RteProcessor | |||
| } | |||
| ``` | |||
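| As a quick check, the processor can then be instantiated through the registry and used to load examples. The sketch below is illustrative only; `args` and `data_dir` stand in for your own finetuning arguments, and the lookup follows the same pattern as `main` in [tasks/superglue/finetune.py](finetune.py): | |||
| ```python | |||
| # Illustrative sketch (not part of the repository): look up the processor by task name | |||
| # and load the training split; `args` and `data_dir` are placeholders. | |||
| processor = PROCESSORS["rte"](args) | |||
| train_examples = processor.get_train_examples(data_dir)  # list of InputExample | |||
| labels = processor.get_labels()  # ["entailment", "not_entailment"] | |||
| ``` | |||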
| ## 3. Implement a subclass of `PVP` | |||
| To implement a subclass of `PVP`, you should first decide whether your verbalizers are single-token or multi-token. The verbalizers in RTE, "Yes" and "No", are single-token. In contrast, the verbalizers in ReCoRD are multi-token, as one entity can be tokenized into multiple tokens by a WordPiece or BPE tokenizer. | |||
| For a single-token task, you should set `is_multi_token=False` in the class definition. You should implement `get_parts` to return the inputs to GLM given an example and `verbalize` to return the verbalizer given a label. Take `RtePVP` as an example: | |||
| ```python | |||
| class RtePVP(PVP): | |||
| is_multi_token = False | |||
| VERBALIZER = { | |||
| "not_entailment": [" No"], | |||
| "entailment": [" Yes"] | |||
| } | |||
| @property | |||
| def spell_length(self): | |||
| return self.pattern_id | |||
| def get_parts(self, example: InputExample) -> FilledPattern: | |||
| # switch text_a and text_b to get the correct order | |||
| text_a = example.text_a | |||
| text_b = example.text_b.rstrip(string.punctuation) | |||
| return ['"', self.shortenable(text_b), '" ?'], [[self.mask], ', "', self.shortenable(text_a), '"'] | |||
| def verbalize(self, label) -> List[str]: | |||
| return RtePVP.VERBALIZER[label] | |||
| ``` | |||
| We use `PVP.shortenable` to mark the segments that can be truncated when the input exceeds the maximum sequence length. | |||
| For a multi-token task, you should set `is_multi_token=True` in the class definition. You should implement `get_parts` to return the inputs to GLM given an example and `get_answers` to return the candidates. Take `RecordPVP` as an example: | |||
| ```python | |||
| class RecordPVP(PVP): | |||
| is_multi_token = True | |||
| def get_answers(self, example: InputExample): | |||
| choices = example.meta['candidates'] | |||
| choices = [" " + choice for choice in choices] | |||
| return choices | |||
| def get_parts(self, example: InputExample) -> FilledPattern: | |||
| premise = self.shortenable(example.text_a) | |||
| assert '@placeholder' in example.text_b, f'question "{example.text_b}" does not contain a @placeholder token' | |||
| question_a, question_b = example.text_b.split('@placeholder') | |||
| return [premise, " " + question_a.rstrip(), [self.mask], question_b], [] | |||
| ``` | |||
| After that, you should add the implemented class to `PVPS` at the end of [tasks/superglue/pvp.py](pvp.py): | |||
| ```python | |||
| PVPS = { | |||
| ... | |||
| 'rte': RtePVP, | |||
| 'record': RecordPVP | |||
| } | |||
| ``` | |||
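| The finetuning entry point later in this diff builds the PVP through this registry. A trimmed sketch of that call is shown below; the argument names are taken from `main` in [tasks/superglue/finetune.py](finetune.py), and the second positional argument (`None` here) is presumably the tokenizer slot filled in elsewhere: | |||
| ```python | |||
| # Illustrative sketch mirroring main() in tasks/superglue/finetune.py. | |||
| pvp = PVPS["rte"]( | |||
| args, None, processor.get_labels(), args.seq_length, | |||
| pattern_id=args.pattern_id, | |||
| is_multi_token=args.multi_token, | |||
| num_prompt_tokens=args.num_prompt_tokens) | |||
| ``` | |||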
| ## 4. Run the experiment | |||
| To run the experiment for your new task, you should create a config file like [config_tasks/task_rte.sh](/config_tasks/task_rte.sh). You should also specify the evaluation metrics for the task in `DEFAULT_METRICS` of [tasks/superglue/finetune.py](finetune.py): | |||
| ```python | |||
| DEFAULT_METRICS = { | |||
| ... | |||
| "record": [("EM", qa_exact_match), ("F1", qa_f1)], | |||
| "rte": [("accuracy", accuracy_metric)] | |||
| } | |||
| ``` | |||
| Then you can run the experiment with [finetune_superglue.sh](/scripts/finetune_superglue.sh): | |||
| ```shell | |||
| bash scripts/finetune_superglue.sh \ | |||
| config_tasks/model_blocklm_large.sh \ | |||
| config_tasks/task_rte.sh | |||
| ``` | |||
| @@ -0,0 +1,101 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| """ | |||
| Official evaluation script for ReCoRD v1.0. | |||
| (Some functions are adopted from the SQuAD evaluation script.) | |||
| """ | |||
| from __future__ import print_function | |||
| import functools | |||
| import re | |||
| import string | |||
| from collections import Counter, defaultdict | |||
| from typing import List | |||
| from tasks.data_utils import InputExample | |||
| def normalize_answer(s): | |||
| """Lower text and remove punctuation, articles and extra whitespace.""" | |||
| def remove_articles(text): | |||
| return re.sub(r'\b(a|an|the)\b', ' ', text) | |||
| def white_space_fix(text): | |||
| return ' '.join(text.split()) | |||
| def remove_punc(text): | |||
| exclude = set(string.punctuation) | |||
| return ''.join(ch for ch in text if ch not in exclude) | |||
| def lower(text): | |||
| return text.lower() | |||
| return white_space_fix(remove_articles(remove_punc(lower(s)))) | |||
| def f1_score(prediction, ground_truth): | |||
| prediction_tokens = normalize_answer(prediction).split() | |||
| ground_truth_tokens = normalize_answer(ground_truth).split() | |||
| common = Counter(prediction_tokens) & Counter(ground_truth_tokens) | |||
| num_same = sum(common.values()) | |||
| if num_same == 0: | |||
| return 0 | |||
| precision = 1.0 * num_same / len(prediction_tokens) | |||
| recall = 1.0 * num_same / len(ground_truth_tokens) | |||
| f1 = (2 * precision * recall) / (precision + recall) | |||
| return f1 | |||
| def exact_match_score(prediction, ground_truth): | |||
| return normalize_answer(prediction) == normalize_answer(ground_truth) | |||
| def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): | |||
| if not ground_truths: | |||
| return 0.0 | |||
| scores_for_ground_truths = [] | |||
| for ground_truth in ground_truths: | |||
| score = metric_fn(prediction, ground_truth) | |||
| scores_for_ground_truths.append(score) | |||
| return max(scores_for_ground_truths) | |||
| def qa_evaluate(predictions, labels, examples: List[InputExample], metric): | |||
| assert len(examples) == len(predictions) | |||
| score = 0.0 | |||
| for example, prediction in zip(examples, predictions): | |||
| ground_truths = example.meta['answers'] | |||
| prediction = example.meta['candidates'][prediction] | |||
| if ground_truths: | |||
| score += metric_max_over_ground_truths(metric, prediction, | |||
| ground_truths) | |||
| score = 100.0 * score / len(predictions) | |||
| return score | |||
| def multirc_em(predictions, labels, examples: List[InputExample]): | |||
| """Compute the exact match (EM) for a sequence of predictions and actual labels""" | |||
| question_ids = [example.meta['question_idx'] for example in examples] | |||
| unique_questions = set(question_ids) | |||
| q_actuals = list(zip(question_ids, labels)) | |||
| q_predictions = list(zip(question_ids, predictions)) | |||
| actuals_per_question = defaultdict(list) | |||
| predictions_per_question = defaultdict(list) | |||
| for qid, val in q_actuals: | |||
| actuals_per_question[qid].append(val) | |||
| for qid, val in q_predictions: | |||
| predictions_per_question[qid].append(val) | |||
| em = 0 | |||
| for qid in unique_questions: | |||
| if actuals_per_question[qid] == predictions_per_question[qid]: | |||
| em += 1 | |||
| em /= len(unique_questions) | |||
| return em | |||
| qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score) | |||
| qa_f1 = functools.partial(qa_evaluate, metric=f1_score) | |||
| @@ -0,0 +1,138 @@ | |||
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Race.""" | |||
| from collections import OrderedDict | |||
| from finetune_glm import finetune | |||
| from tasks.eval_utils import (accuracy_func_provider, accuracy_metric, | |||
| f1_macro_metric, f1_metric) | |||
| from tasks.superglue.dataset import (CLASSIFICATION_DATASETS, | |||
| MULTI_CHOICE_DATASETS, PROCESSORS, | |||
| SuperGlueDataset, get_output_func) | |||
| from tasks.superglue.evaluate import multirc_em, qa_exact_match, qa_f1 | |||
| from tasks.superglue.pvp import PVPS | |||
| DEFAULT_METRICS = { | |||
| 'record': [('EM', qa_exact_match), ('F1', qa_f1)], | |||
| 'copa': [('accuracy', accuracy_metric)], | |||
| 'rte': [('accuracy', accuracy_metric)], | |||
| 'boolq': [('accuracy', accuracy_metric)], | |||
| 'wic': [('accuracy', accuracy_metric)], | |||
| 'wsc': [('accuracy', accuracy_metric)], | |||
| 'cb': [('accuracy', accuracy_metric), ('f1-macro', f1_macro_metric)], | |||
| 'multirc': [('f1a', f1_metric), ('em', multirc_em), | |||
| ('acc', accuracy_metric)], | |||
| 'mnli': [('accuracy', accuracy_metric)], | |||
| 'sst2': [('accuracy', accuracy_metric)], | |||
| 'qnli': [('accuracy', accuracy_metric)], | |||
| 'qqp': [('accuracy', accuracy_metric)], | |||
| 'mrpc': [('accuracy', accuracy_metric)], | |||
| 'cola': [('accuracy', accuracy_metric)], | |||
| 'squad': [('accuracy', accuracy_metric)], | |||
| } | |||
| def train_valid_datasets_provider(args, tokenizer, pattern_text=False): | |||
| """Provide train and validation datasets.""" | |||
| task_name = args.task.lower() | |||
| data_dir = args.data_dir | |||
| train_dataset = SuperGlueDataset( | |||
| args, | |||
| task_name, | |||
| data_dir, | |||
| args.seq_length, | |||
| 'train', | |||
| tokenizer, | |||
| pattern_text=pattern_text) | |||
| valid_dataset = SuperGlueDataset( | |||
| args, | |||
| task_name, | |||
| data_dir, | |||
| args.seq_length, | |||
| 'dev', | |||
| tokenizer, | |||
| for_train=True, | |||
| pattern_text=pattern_text) | |||
| return train_dataset, valid_dataset | |||
| def metrics_func_provider(args, tokenizer, is_test): | |||
| """Privde metrics callback function.""" | |||
| def single_dataset_provider(split): | |||
| return SuperGlueDataset(args, args.task.lower(), args.data_dir, | |||
| args.seq_length, split, tokenizer) | |||
| output_func = get_output_func(args.task.lower(), args) | |||
| eval_func = None | |||
| if args.task.lower() in ['wsc', 'squad' | |||
| ] and args.cloze_eval and not args.wsc_negative: | |||
| from tasks.language_model.finetune import classify_evaluate | |||
| eval_func = classify_evaluate | |||
| metric_dict = OrderedDict(DEFAULT_METRICS[args.task.lower()]) | |||
| return accuracy_func_provider( | |||
| single_dataset_provider, | |||
| metric_dict, | |||
| args, | |||
| is_test=is_test, | |||
| eval_func=eval_func, | |||
| output_func=output_func, | |||
| only_rank0=False, | |||
| tokenizer=tokenizer) | |||
| def main(args): | |||
| model_kwargs = {} | |||
| processor = PROCESSORS[args.task.lower()](args) | |||
| pvp = PVPS[args.task.lower()]( | |||
| args, | |||
| None, | |||
| processor.get_labels(), | |||
| args.seq_length, | |||
| pattern_id=args.pattern_id, | |||
| is_multi_token=args.multi_token, | |||
| num_prompt_tokens=args.num_prompt_tokens) | |||
| if args.continuous_prompt: | |||
| model_kwargs['spell_length'] = pvp.spell_length | |||
| if args.task.lower() in ['wsc', 'squad' | |||
| ] and args.cloze_eval and not args.wsc_negative: | |||
| from tasks.language_model.finetune import lm_forward_step | |||
| finetune( | |||
| args, | |||
| train_valid_datasets_provider, | |||
| model_kwargs, | |||
| end_of_epoch_callback_provider=metrics_func_provider, | |||
| forward_step=lm_forward_step) | |||
| else: | |||
| if args.cloze_eval: | |||
| multi_token = pvp.is_multi_token | |||
| else: | |||
| multi_token = args.task.lower() in MULTI_CHOICE_DATASETS | |||
| args.multi_token = multi_token | |||
| if not multi_token: | |||
| model_kwargs[ | |||
| 'model_type'] = 'multiple_choice' if args.cloze_eval else 'classification' | |||
| model_kwargs['multi_token'] = False | |||
| model_kwargs['num_labels'] = len(processor.get_labels()) | |||
| else: | |||
| model_kwargs['model_type'] = 'multiple_choice' | |||
| model_kwargs['multi_token'] = True | |||
| model_kwargs['num_labels'] = 1 | |||
| finetune( | |||
| args, | |||
| train_valid_datasets_provider, | |||
| model_kwargs, | |||
| end_of_epoch_callback_provider=metrics_func_provider) | |||
| @@ -0,0 +1,36 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import random | |||
| from argparse import Namespace | |||
| import numpy as np | |||
| from blocklm_utils import ConstructBlockStrategy | |||
| # rng = random.Random() | |||
| # span_lengths = [2, 3, 4, 2, 3, 4] | |||
| # length = 100 | |||
| # | |||
| # counts = np.array([0] * length) | |||
| # for _ in range(10000): | |||
| # rng.shuffle(span_lengths) | |||
| # spans = ConstructBlockStrategy.sample_spans(span_lengths, length, rng) | |||
| # for start, end in spans: | |||
| # counts[start: end] += 1 | |||
| # print(counts) | |||
| def main(): | |||
| args = Namespace() | |||
| args.seq_length = 10 | |||
| args.eod_token = 0 | |||
| strategy = ConstructBlockStrategy( | |||
| args, None, bert_ratio=0.4, max_seq_length=128) | |||
| counts = np.array([0] * 10) | |||
| for _ in range(10000): | |||
| spans = strategy.sample_span_in_document( | |||
| np.array([1, 2, 3, 0, 4, 5, 6, 7, 9, 0], dtype=np.long), [1, 1], | |||
| random.Random()) | |||
| for start, end in spans: | |||
| counts[start:end] += 1 | |||
| print(counts) | |||
| @@ -0,0 +1,27 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| from learning_rates import AnnealingLR | |||
| from torch.nn.modules import Linear | |||
| from torch.optim import Adam | |||
| def main(): | |||
| model = Linear(10, 10) | |||
| optimizer = Adam(model.parameters()) | |||
| lr_scheduler = AnnealingLR( | |||
| optimizer, | |||
| start_lr=0.00015, | |||
| warmup_iter=3000, | |||
| num_iters=300000, | |||
| decay_style='cosine', | |||
| decay_ratio=0.1) | |||
| steps = np.arange(0, 400000, 10, dtype=np.long) | |||
| rates = [] | |||
| for step in steps: | |||
| lr_scheduler.num_iters = step | |||
| rates.append(lr_scheduler.get_lr()) | |||
| print(rates) | |||
| plt.plot(steps, rates) | |||
| plt.savefig('lr.pdf', format='pdf') | |||
| @@ -0,0 +1,472 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import deepspeed | |||
| import torch | |||
| from apex.optimizers import FusedAdam as Adam | |||
| from torch import distributed as dist | |||
| from . import mpu | |||
| from .fp16 import DynamicLossScaler, FP16_Module, FP16_Optimizer | |||
| from .model import DistributedDataParallel as LocalDDP | |||
| from .model import (GLMForMultiTokenCloze, GLMForMultiTokenClozeFast, | |||
| GLMForSequenceClassification, GLMForSingleTokenCloze, | |||
| GLMModel) | |||
| from .model import PyTorchDistributedDataParallel as TorchDDP | |||
| from .model import glm_get_params_for_weight_decay_optimization | |||
| from .utils import get_checkpoint_iteration, get_checkpoint_name, print_rank_0 | |||
| def load_pretrained(model, checkpoint_path, args, task_tokens=None): | |||
| load_dir, tag, release, success = get_checkpoint_iteration(checkpoint_path) | |||
| checkpoint_name = get_checkpoint_name(load_dir, tag, release) | |||
| if mpu.get_data_parallel_rank() == 0: | |||
| print('global rank {} is loading pretrained model {}'.format( | |||
| torch.distributed.get_rank(), checkpoint_name)) | |||
| # Load the checkpoint. | |||
| sd = torch.load(checkpoint_name, map_location='cpu') | |||
| if args.deepspeed: | |||
| model = model.module | |||
| if isinstance(model, TorchDDP): | |||
| model = model.module | |||
| if isinstance(model, FP16_Module): | |||
| model = model.module | |||
| if hasattr(model, 'model'): | |||
| model = model.model | |||
| # Model. | |||
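| # Copy pretrained embedding rows into the (possibly longer) table of the current model so that | |||
| # a larger max_position_embeddings can be used at finetuning time. | |||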
| def extend_embedding_weights(state_weights, model_weights): | |||
| original_length = state_weights.shape[0] | |||
| assert original_length <= args.max_position_embeddings + 1 | |||
| new_weights = model_weights.clone() | |||
| new_weights[:original_length] = state_weights | |||
| return new_weights | |||
| if args.block_lm: | |||
| if 'transformer.block_position_embeddings.weight' in sd['module']: | |||
| position_weights = sd['module'][ | |||
| 'transformer.position_embeddings.weight'] | |||
| if args.max_position_embeddings + 1 > position_weights.shape[0]: | |||
| sd['module'][ | |||
| 'transformer.position_embeddings.weight'] = extend_embedding_weights( | |||
| position_weights, | |||
| model.state_dict() | |||
| ['transformer.position_embeddings.weight'].data) | |||
| print_rank_0( | |||
| f'Extend position embedding to {args.max_position_embeddings + 1}' | |||
| ) | |||
| if 'transformer.block_position_embeddings.weight' in sd['module']: | |||
| block_position_weights = sd['module'][ | |||
| 'transformer.block_position_embeddings.weight'] | |||
| if args.max_position_embeddings + 1 > block_position_weights.shape[ | |||
| 0]: | |||
| sd['module'][ | |||
| 'transformer.block_position_embeddings.weight'] = extend_embedding_weights( | |||
| block_position_weights, | |||
| model.state_dict() | |||
| ['transformer.block_position_embeddings.weight'].data) | |||
| print_rank_0( | |||
| f'Extend block position embedding to {args.max_position_embeddings + 1}' | |||
| ) | |||
| for key in list(model.state_dict().keys()): | |||
| print(key) | |||
| model.state_dict()[key.replace( | |||
| 'mixins.block_position_embedding.block_position_embeddings.weight', | |||
| 'transformer.block_position_embeddings.weight').replace( | |||
| 'transformer.word_embeddings.weight', | |||
| 'word_embeddings.weight')] = model.state_dict().pop(key) | |||
| missing_keys, unexpected_keys = model.load_state_dict( | |||
| sd['module'], strict=False) | |||
| if missing_keys or unexpected_keys: | |||
| print_rank_0( | |||
| f'Missing keys {missing_keys}, unexpected keys {unexpected_keys}') | |||
| if args.continuous_prompt and args.prompt_init: | |||
| model.prompt_spell.init_embedding(model.word_embeddings.weight.data, | |||
| task_tokens) | |||
| def get_model(args, | |||
| model_type=None, | |||
| multi_token=True, | |||
| num_labels=None, | |||
| spell_length=None): | |||
| """Build the model.""" | |||
| print_rank_0('building GPT2 model ...') | |||
| if args.pretrained_bert: | |||
| if model_type == 'multiple_choice': | |||
| model = BertForMultipleChoice.from_pretrained( | |||
| args.tokenizer_model_type, | |||
| cache_dir=args.cache_dir, | |||
| fp32_layernorm=args.fp32_layernorm, | |||
| fp32_embedding=args.fp32_embedding, | |||
| layernorm_epsilon=args.layernorm_epsilon) | |||
| elif model_type == 'classification': | |||
| model = BertForSequenceClassification.from_pretrained( | |||
| args.tokenizer_model_type, | |||
| cache_dir=args.cache_dir, | |||
| fp32_layernorm=args.fp32_layernorm, | |||
| fp32_embedding=args.fp32_embedding, | |||
| layernorm_epsilon=args.layernorm_epsilon, | |||
| num_labels=num_labels) | |||
| else: | |||
| raise NotImplementedError | |||
| else: | |||
| output_predict, paralle_output = True, True | |||
| if (model_type == 'multiple_choice' | |||
| or model_type == 'classification') and not args.cloze_eval: | |||
| output_predict = False | |||
| if model_type is not None: | |||
| paralle_output = False | |||
| if spell_length is not None: | |||
| print_rank_0(f'Continuous spell length {spell_length}') | |||
| model = GLMModel( | |||
| num_layers=args.num_layers, | |||
| vocab_size=args.vocab_size, | |||
| hidden_size=args.hidden_size, | |||
| num_attention_heads=args.num_attention_heads, | |||
| embedding_dropout_prob=args.hidden_dropout, | |||
| attention_dropout_prob=args.attention_dropout, | |||
| output_dropout_prob=args.hidden_dropout, | |||
| max_sequence_length=args.max_position_embeddings, | |||
| max_memory_length=args.mem_length, | |||
| checkpoint_activations=args.checkpoint_activations, | |||
| checkpoint_num_layers=args.checkpoint_num_layers, | |||
| parallel_output=paralle_output, | |||
| relative_encoding=args.transformer_xl, | |||
| block_position_encoding=args.block_lm and not args.masked_lm, | |||
| output_predict=output_predict, | |||
| spell_length=spell_length, | |||
| spell_func=args.prompt_func, | |||
| attention_scale=args.attention_scale) | |||
| if args.freeze_transformer: | |||
| model.freeze_transformer( | |||
| tune_prefix_layers=args.tune_prefix_layers) | |||
| if model_type is not None: | |||
| if model_type == 'multiple_choice': | |||
| if args.cloze_eval: | |||
| if multi_token: | |||
| if args.fast_decode: | |||
| model = GLMForMultiTokenClozeFast( | |||
| model, length_penalty=args.length_penalty) | |||
| else: | |||
| model = GLMForMultiTokenCloze( | |||
| model, length_penalty=args.length_penalty) | |||
| else: | |||
| model = GLMForSingleTokenCloze( | |||
| model, take_softmax=args.adapet) | |||
| else: | |||
| model = GLMForSequenceClassification( | |||
| model, | |||
| args.hidden_size, | |||
| args.output_dropout, | |||
| args.pool_token, | |||
| num_class=num_labels) | |||
| elif model_type == 'classification': | |||
| model = GLMForSequenceClassification( | |||
| model, | |||
| args.hidden_size, | |||
| args.output_dropout, | |||
| args.pool_token, | |||
| num_class=num_labels) | |||
| elif model_type == 'generation': | |||
| pass | |||
| else: | |||
| raise NotImplementedError(model_type) | |||
| if mpu.get_data_parallel_rank() == 0: | |||
| print( | |||
| ' > number of parameters on model parallel rank {}: {}'.format( | |||
| mpu.get_model_parallel_rank(), | |||
| sum([p.nelement() for p in model.parameters()])), | |||
| flush=True) | |||
| # To prevent OOM for model sizes that cannot fit in GPU memory in full precision | |||
| if args.fp16: | |||
| model.half() | |||
| # GPU allocation. | |||
| model.cuda(torch.cuda.current_device()) | |||
| # Fp16 conversion. | |||
| if args.fp16: | |||
| model = FP16_Module(model) | |||
| # Wrap model for distributed training. | |||
| if not args.deepspeed and (args.train_iters or args.epochs): | |||
| if args.DDP_impl == 'torch': | |||
| i = torch.cuda.current_device() | |||
| model = TorchDDP( | |||
| model, | |||
| device_ids=[i], | |||
| output_device=i, | |||
| process_group=mpu.get_data_parallel_group()) | |||
| elif args.DDP_impl == 'local': | |||
| model = LocalDDP(model) | |||
| else: | |||
| print_rank_0('Skip DDP model') | |||
| return model | |||
| def get_optimizer_param_groups(model): | |||
| # Build parameter groups (weight decay and non-decay). | |||
| while isinstance(model, (LocalDDP, TorchDDP, FP16_Module)): | |||
| model = model.module | |||
| param_groups = glm_get_params_for_weight_decay_optimization(model) | |||
| # Add model parallel attribute if it is not set. | |||
| for param_group in param_groups: | |||
| # print('## param_group', len(param_group['params'])) | |||
| for param in param_group['params']: | |||
| if not hasattr(param, 'model_parallel'): | |||
| param.model_parallel = False | |||
| return param_groups | |||
| def get_optimizer(param_groups, args): | |||
| """Set up the optimizer.""" | |||
| if args.cpu_optimizer: | |||
| # Apex FusedAdam uses decoupled weight decay so use the same here | |||
| if args.cpu_torch_adam: | |||
| cpu_adam_optimizer = torch.optim.AdamW | |||
| else: | |||
| from deepspeed.ops.adam import DeepSpeedCPUAdam | |||
| cpu_adam_optimizer = DeepSpeedCPUAdam | |||
| optimizer = cpu_adam_optimizer( | |||
| param_groups, lr=args.lr, weight_decay=args.weight_decay) | |||
| else: | |||
| # Use FusedAdam. | |||
| if args.optimizer == 'adam': | |||
| optimizer = Adam( | |||
| param_groups, | |||
| lr=args.lr, | |||
| weight_decay=args.weight_decay, | |||
| betas=(args.adam_beta1, args.adam_beta2), | |||
| eps=args.adam_eps) | |||
| elif args.optimizer == 'adafactor': | |||
| from transformers import Adafactor | |||
| optimizer = Adafactor( | |||
| param_groups, | |||
| lr=args.lr, | |||
| relative_step=False, | |||
| warmup_init=False) | |||
| else: | |||
| raise NotImplementedError | |||
| print(f'Optimizer = {optimizer.__class__.__name__}') | |||
| if hasattr(args, 'deepspeed') and args.deepspeed: | |||
| raise NotImplementedError | |||
| # fp16 wrapper is not required for DeepSpeed. | |||
| # return optimizer | |||
| # Wrap into fp16 optimizer. | |||
| if args.fp16: | |||
| optimizer = FP16_Optimizer( | |||
| optimizer, | |||
| static_loss_scale=args.loss_scale, | |||
| dynamic_loss_scale=args.dynamic_loss_scale, | |||
| dynamic_loss_args={ | |||
| 'scale_window': args.loss_scale_window, | |||
| 'min_scale': args.min_scale, | |||
| 'delayed_shift': args.hysteresis | |||
| }) | |||
| return optimizer | |||
| def get_learning_rate_scheduler(optimizer, args): | |||
| """Build the learning rate scheduler.""" | |||
| # Add linear learning rate scheduler. | |||
| if args.lr_decay_iters is not None: | |||
| num_iters = args.lr_decay_iters | |||
| else: | |||
| num_iters = args.train_iters | |||
| if args.finetune: | |||
| num_iters = num_iters // args.gradient_accumulation_steps | |||
| num_iters = max(1, num_iters) | |||
| init_step = -1 | |||
| warmup_iter = args.warmup * num_iters | |||
| lr_scheduler = AnnealingLR( | |||
| optimizer, | |||
| start_lr=args.lr, | |||
| warmup_iter=warmup_iter, | |||
| num_iters=num_iters - warmup_iter, | |||
| decay_style=args.lr_decay_style, | |||
| last_iter=init_step, | |||
| decay_ratio=args.lr_decay_ratio) | |||
| return lr_scheduler | |||
| def setup_model_and_optimizer(args, | |||
| model_type=None, | |||
| multi_token=True, | |||
| num_labels=None, | |||
| spell_length=None): | |||
| """Setup model and optimizer.""" | |||
| model = get_model( | |||
| args, | |||
| model_type=model_type, | |||
| multi_token=multi_token, | |||
| num_labels=num_labels, | |||
| spell_length=spell_length) | |||
| param_groups = get_optimizer_param_groups(model) | |||
| if args.train_data is not None or args.data_dir is not None and ( | |||
| args.epochs > 0 or args.train_iters > 0): | |||
| if args.deepspeed: | |||
| print_rank_0('DeepSpeed is enabled.') | |||
| model, optimizer, _, _ = deepspeed.initialize( | |||
| model=model, | |||
| model_parameters=param_groups, | |||
| args=args, | |||
| mpu=mpu, | |||
| dist_init_required=False) | |||
| else: | |||
| optimizer = get_optimizer(param_groups, args) | |||
| lr_scheduler = get_learning_rate_scheduler(optimizer, args) | |||
| else: | |||
| optimizer, lr_scheduler = None, None | |||
| return model, optimizer, lr_scheduler | |||
| def backward_step(optimizer, model, lm_loss, args, timers): | |||
| """Backward step.""" | |||
| # Total loss. | |||
| loss = lm_loss | |||
| # Backward pass. | |||
| if args.deepspeed: | |||
| model.backward(loss) | |||
| else: | |||
| # optimizer.zero_grad() | |||
| if args.fp16: | |||
| optimizer.backward(loss, update_master_grads=False) | |||
| else: | |||
| loss.backward() | |||
| if args.deepspeed or args.DDP_impl == 'torch': | |||
| # DeepSpeed backward propagation already addressed all reduce communication. | |||
| # Reset the timer to avoid breaking timer logs below. | |||
| timers('allreduce').reset() | |||
| else: | |||
| timers('allreduce').start() | |||
| model.allreduce_params( | |||
| reduce_after=False, fp32_allreduce=args.fp32_allreduce) | |||
| timers('allreduce').stop() | |||
| # Update master gradients. | |||
| if not args.deepspeed: | |||
| if args.fp16: | |||
| optimizer.update_master_grads() | |||
| # Clipping gradients helps prevent exploding gradients. | |||
| if args.clip_grad > 0: | |||
| if not args.fp16: | |||
| mpu.clip_grad_norm(model.parameters(), args.clip_grad) | |||
| else: | |||
| optimizer.clip_master_grads(args.clip_grad) | |||
| return lm_loss | |||
| def see_memory_usage(message, force=False): | |||
| if not force: | |||
| return | |||
| dist.barrier() | |||
| if dist.get_rank() == 0: | |||
| print(message) | |||
| print('Memory Allocated ', | |||
| torch.cuda.memory_allocated() / (1024 * 1024 * 1024), | |||
| 'GigaBytes') | |||
| print('Max Memory Allocated ', | |||
| torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), | |||
| 'GigaBytes') | |||
| print('Cache Allocated ', | |||
| torch.cuda.memory_cached() / (1024 * 1024 * 1024), 'GigaBytes') | |||
| print('Max cache Allocated ', | |||
| torch.cuda.max_memory_cached() / (1024 * 1024 * 1024), | |||
| 'GigaBytes') | |||
| print(' ') | |||
| # input("Press Any Key To Continue ..") | |||
| def train_step(data_iterator, | |||
| model, | |||
| optimizer, | |||
| lr_scheduler, | |||
| args, | |||
| timers, | |||
| forward_step_func, | |||
| mems=None, | |||
| single_step=False): | |||
| """Single training step.""" | |||
| lm_loss_total, count = 0.0, 0 | |||
| mems = [] if mems is None else mems | |||
| if not args.deepspeed: | |||
| optimizer.zero_grad() | |||
| while True: | |||
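| # Accumulate micro-batches until a full gradient-accumulation window completes (or DeepSpeed | |||
| # signals a boundary), then apply a single optimizer step. | |||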
| skipped_iter, complete = 0, False | |||
| # Forward model for one step. | |||
| timers('forward').start() | |||
| lm_loss, mems, _ = forward_step_func(data_iterator, model, args, | |||
| timers, mems) | |||
| timers('forward').stop() | |||
| # print_rank_0("Forward step") | |||
| if not args.deepspeed: | |||
| lm_loss /= args.gradient_accumulation_steps | |||
| reduced_loss = lm_loss.detach().clone().view(1) | |||
| torch.distributed.all_reduce( | |||
| reduced_loss.data, group=mpu.get_data_parallel_group()) | |||
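| # Turn the all-reduced sum into an average over data-parallel replicas; | |||
| # world_size / model_parallel_size is the size of the data-parallel group. | |||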
| reduced_loss.data = reduced_loss.data / ( | |||
| args.world_size / args.model_parallel_size) | |||
| if not DynamicLossScaler._has_inf_or_nan(reduced_loss): | |||
| lm_loss_total += reduced_loss | |||
| count += 1 | |||
| # Calculate gradients, reduce across processes, and clip. | |||
| timers('backward').start() | |||
| backward_step(optimizer, model, lm_loss, args, timers) | |||
| timers('backward').stop() | |||
| # print_rank_0("Backward step") | |||
| # Update parameters. | |||
| timers('optimizer').start() | |||
| if args.deepspeed: | |||
| if model.is_gradient_accumulation_boundary(): | |||
| model.step() | |||
| complete = True | |||
| if not (args.fp16 and optimizer.overflow): | |||
| lr_scheduler.step() | |||
| else: | |||
| skipped_iter = 1 | |||
| else: | |||
| model.step() | |||
| else: | |||
| if count == args.gradient_accumulation_steps: | |||
| optimizer.step() | |||
| complete = True | |||
| # Update learning rate. | |||
| if not (args.fp16 and optimizer.overflow): | |||
| lr_scheduler.step() | |||
| else: | |||
| skipped_iter = 1 | |||
| # print_rank_0("Optimizer step") | |||
| timers('optimizer').stop() | |||
| if complete: | |||
| break | |||
| else: | |||
| print_rank_0('Found NaN loss, skip backward') | |||
| del lm_loss, reduced_loss | |||
| mems = [] | |||
| if single_step: | |||
| break | |||
| if args.deepspeed: | |||
| lm_loss_total = lm_loss_total / count | |||
| return lm_loss_total, skipped_iter, mems | |||
| @@ -0,0 +1,529 @@ | |||
| # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Utilities for logging and serialization""" | |||
| import os | |||
| import random | |||
| import subprocess | |||
| import time | |||
| import json | |||
| import numpy as np | |||
| import torch | |||
| from . import mpu | |||
| from .fp16 import FP16_Optimizer | |||
| SUMMARY_WRITER_DIR_NAME = 'runs' | |||
| def get_log_dir(name, base): | |||
| return os.path.join(base, SUMMARY_WRITER_DIR_NAME, name) | |||
| def print_rank_0(message): | |||
| if torch.distributed.is_initialized(): | |||
| if torch.distributed.get_rank() == 0: | |||
| print(message, flush=True) | |||
| else: | |||
| print(message, flush=True) | |||
| def get_hostname(): | |||
| hostname_cmd = ['hostname -I'] | |||
| result = subprocess.check_output(hostname_cmd, shell=True) | |||
| master_addr = result.decode('utf-8').split()[0] | |||
| return master_addr | |||
| def get_spare_port(args): | |||
| if torch.distributed.get_rank() == 0: | |||
| port = subprocess.check_output(['shuf -n 1 -i 10000-65535'], | |||
| shell=True) | |||
| port = int(port.strip()) | |||
| if port == args.master_port: | |||
| port = subprocess.check_output(['shuf -n 1 -i 10000-65535'], | |||
| shell=True) | |||
| port = int(port.strip()) | |||
| port = torch.cuda.LongTensor([port]) | |||
| else: | |||
| port = torch.cuda.LongTensor([0]) | |||
| torch.distributed.broadcast(port, 0) | |||
| port = port.item() | |||
| return port | |||
| def print_and_save_args(args, verbose=True, log_dir=None): | |||
| """Print arguments.""" | |||
| if verbose: | |||
| print('arguments:', flush=True) | |||
| for arg in vars(args): | |||
| dots = '.' * (29 - len(arg)) | |||
| print( | |||
| ' {} {} {}'.format(arg, dots, getattr(args, arg)), flush=True) | |||
| if log_dir is not None: | |||
| json_file = os.path.join(log_dir, 'config.json') | |||
| with open(json_file, 'w') as output: | |||
| json.dump(vars(args), output, sort_keys=True) | |||
| if args.deepspeed and args.deepspeed_config is not None: | |||
| with open(args.deepspeed_config) as file: | |||
| deepspeed_config = json.load(file) | |||
| deepspeed_json_file = os.path.join(log_dir, | |||
| 'config_gpt_large.json') | |||
| with open(deepspeed_json_file, 'w') as output: | |||
| json.dump(deepspeed_config, output) | |||
| def print_params_min_max_norm(optimizer, iteration): | |||
| """Print min, max, and norm of all parameters.""" | |||
| index = 0 | |||
| rank = torch.distributed.get_rank() | |||
| string = 'iteration, rank, index, model-parallel,min, max, norm\n' | |||
| optimizer_ = optimizer | |||
| if isinstance(optimizer, FP16_Optimizer): | |||
| optimizer_ = optimizer.optimizer | |||
| for param_group in optimizer_.param_groups: | |||
| for param in param_group['params']: | |||
| index += 1 | |||
| min_ = param.data.min() | |||
| max_ = param.data.max() | |||
| norm = param.data.norm() | |||
| string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format( | |||
| iteration, rank, index, int(param.model_parallel)) | |||
| string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm) | |||
| print(string, flush=True) | |||
| class Timers: | |||
| """Group of timers.""" | |||
| class Timer: | |||
| """Timer.""" | |||
| def __init__(self, name): | |||
| self.name_ = name | |||
| self.elapsed_ = 0.0 | |||
| self.started_ = False | |||
| self.start_time = time.time() | |||
| def start(self): | |||
| """Start the timer.""" | |||
| assert not self.started_, 'timer has already been started' | |||
| torch.cuda.synchronize() | |||
| self.start_time = time.time() | |||
| self.started_ = True | |||
| def stop(self): | |||
| """Stop the timer.""" | |||
| assert self.started_, 'timer is not started' | |||
| torch.cuda.synchronize() | |||
| self.elapsed_ += (time.time() - self.start_time) | |||
| self.started_ = False | |||
| def reset(self): | |||
| """Reset timer.""" | |||
| self.elapsed_ = 0.0 | |||
| self.started_ = False | |||
| def elapsed(self, reset=True): | |||
| """Calculate the elapsed time.""" | |||
| started_ = self.started_ | |||
| # If the timing in progress, end it first. | |||
| if self.started_: | |||
| self.stop() | |||
| # Get the elapsed time. | |||
| elapsed_ = self.elapsed_ | |||
| # Reset the elapsed time | |||
| if reset: | |||
| self.reset() | |||
| # If timing was in progress, set it back. | |||
| if started_: | |||
| self.start() | |||
| return elapsed_ | |||
| def __init__(self): | |||
| self.timers = {} | |||
| def __call__(self, name): | |||
| if name not in self.timers: | |||
| self.timers[name] = self.Timer(name) | |||
| return self.timers[name] | |||
| def log(self, names, normalizer=1.0, reset=True): | |||
| """Log a group of timers.""" | |||
| assert normalizer > 0.0 | |||
| string = 'time (ms)' | |||
| for name in names: | |||
| elapsed_time = self.timers[name].elapsed( | |||
| reset=reset) * 1000.0 / normalizer | |||
| string += ' | {}: {:.2f}'.format(name, elapsed_time) | |||
| print_rank_0(string) | |||
| def report_memory(name): | |||
| """Simple GPU memory report.""" | |||
| mega_bytes = 1024.0 * 1024.0 | |||
| string = name + ' memory (MB)' | |||
| string += ' | allocated: {}'.format(torch.cuda.memory_allocated() | |||
| / mega_bytes) | |||
| string += ' | max allocated: {}'.format(torch.cuda.max_memory_allocated() | |||
| / mega_bytes) | |||
| string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes) | |||
| string += ' | max cached: {}'.format(torch.cuda.memory_reserved() | |||
| / mega_bytes) | |||
| print_rank_0(string) | |||
| def get_checkpoint_name(checkpoints_path, | |||
| iteration, | |||
| release=False, | |||
| zero=False): | |||
| if release: | |||
| d = 'release' | |||
| else: | |||
| d = '{}'.format(iteration) | |||
| if zero: | |||
| dp_rank = mpu.get_data_parallel_rank() | |||
| d += '_zero_dp_rank_{}'.format(dp_rank) | |||
| return os.path.join( | |||
| checkpoints_path, d, | |||
| 'mp_rank_{:02d}_model_states.pt'.format(mpu.get_model_parallel_rank())) | |||
| def ensure_directory_exists(filename): | |||
| dirname = os.path.dirname(filename) | |||
| if not os.path.exists(dirname): | |||
| os.makedirs(dirname, exist_ok=True) | |||
| def get_checkpoint_tracker_filename(checkpoints_path): | |||
| return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt') | |||
| def save_zero_checkpoint(args, iteration, optimizer): | |||
| zero_sd = { | |||
| 'iteration': iteration, | |||
| 'optimizer_state_dict': optimizer.state_dict() | |||
| } | |||
| zero_checkpoint_name = get_checkpoint_name(args.save, iteration, zero=True) | |||
| ensure_directory_exists(zero_checkpoint_name) | |||
| torch.save(zero_sd, zero_checkpoint_name) | |||
| print(' successfully saved {}'.format(zero_checkpoint_name)) | |||
| def save_checkpoint(iteration, | |||
| model, | |||
| optimizer, | |||
| lr_scheduler, | |||
| args, | |||
| tag=None, | |||
| barrier=True, | |||
| only_changed_parameters=False, | |||
| no_deepspeed=False, | |||
| no_save_optim=False): | |||
| """Save a model checkpoint.""" | |||
| if tag is None: | |||
| tag = str(iteration) | |||
| if args.deepspeed and not no_deepspeed: | |||
| save_ds_checkpoint(iteration, model, lr_scheduler, args, tag=tag) | |||
| else: | |||
| # Only rank zero of the data-parallel group writes to disk. | |||
| if mpu.get_data_parallel_rank() == 0: | |||
| checkpoint_name = get_checkpoint_name(args.save, tag) | |||
| print( | |||
| 'global rank {} is saving checkpoint at iteration {:7d} to {}'. | |||
| format(torch.distributed.get_rank(), iteration, | |||
| checkpoint_name)) | |||
| sd = {'iteration': iteration} | |||
| if args.deepspeed: | |||
| model = model.module | |||
| state_dict = model.state_dict() | |||
| if only_changed_parameters: | |||
| requires_grad_dict = {} | |||
| for name, parameter in model.named_parameters(): | |||
| requires_grad_dict[name] = parameter.requires_grad | |||
| state_dict = { | |||
| key: value | |||
| for key, value in state_dict.items() | |||
| if requires_grad_dict[key] | |||
| } | |||
| sd['module'] = state_dict | |||
| # Optimizer stuff. | |||
| if not args.no_save_optim and not no_save_optim: | |||
| if optimizer is not None: | |||
| sd['optimizer'] = optimizer.state_dict() | |||
| if lr_scheduler is not None: | |||
| sd['lr_scheduler'] = lr_scheduler.state_dict() | |||
| # rng states. | |||
| if not args.no_save_rng: | |||
| sd['random_rng_state'] = random.getstate() | |||
| sd['np_rng_state'] = np.random.get_state() | |||
| sd['torch_rng_state'] = torch.get_rng_state() | |||
| sd['cuda_rng_state'] = torch.cuda.get_rng_state() | |||
| sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker( | |||
| ).get_states() | |||
| ensure_directory_exists(checkpoint_name) | |||
| torch.save(sd, checkpoint_name) | |||
| print(' successfully saved {}'.format(checkpoint_name)) | |||
| # Wait so everyone is done (necessary) | |||
| if barrier: | |||
| torch.distributed.barrier() | |||
| # And update the latest iteration | |||
| if torch.distributed.get_rank() == 0: | |||
| tracker_filename = get_checkpoint_tracker_filename(args.save) | |||
| with open(tracker_filename, 'w') as f: | |||
| f.write(tag) | |||
| def save_ds_checkpoint(iteration, model, lr_scheduler, args, tag): | |||
| """Save a model checkpoint.""" | |||
| sd = {} | |||
| sd['iteration'] = iteration | |||
| if lr_scheduler is not None: | |||
| sd['client_lr_scheduler'] = lr_scheduler.state_dict() | |||
| # rng states. | |||
| if not args.no_save_rng: | |||
| sd['random_rng_state'] = random.getstate() | |||
| sd['np_rng_state'] = np.random.get_state() | |||
| sd['torch_rng_state'] = torch.get_rng_state() | |||
| sd['cuda_rng_state'] = torch.cuda.get_rng_state() | |||
| sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states() | |||
| model.save_checkpoint(args.save, tag, client_state=sd) | |||
| def get_checkpoint_iteration(load_path): | |||
| # Read the tracker file and set the iteration. | |||
| tracker_filename = get_checkpoint_tracker_filename(load_path) | |||
| if not os.path.isfile(tracker_filename): | |||
| print_rank_0('WARNING: could not find the metadata file {} '.format( | |||
| tracker_filename)) | |||
| if os.path.isdir(load_path): | |||
| path = os.path.normpath(load_path) | |||
| load_dir, tag = os.path.split(path) | |||
| print_rank_0( | |||
| 'Trying to load the checkpoint directly from the directory') | |||
| return load_dir, tag, False, True | |||
| print_rank_0(' will not load any checkpoints and will start from ' | |||
| 'random') | |||
| return load_path, 0, False, False | |||
| with open(tracker_filename, 'r') as f: | |||
| metastring = f.read().strip() | |||
| release = metastring == 'release' | |||
| # try: | |||
| # iteration = int(metastring) | |||
| # except ValueError: | |||
| # release = metastring == 'release' | |||
| # if not release: | |||
| # print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format( | |||
| # tracker_filename)) | |||
| # exit() | |||
| # assert iteration > 0 or release, 'error parsing metadata file {}'.format( | |||
| # tracker_filename) | |||
| return load_path, metastring, release, True | |||
| def load_checkpoint(model, | |||
| optimizer, | |||
| lr_scheduler, | |||
| args, | |||
| no_deepspeed=False, | |||
| no_load_optim=False): | |||
| """Load a model checkpoint.""" | |||
| load_dir, tag, release, success = get_checkpoint_iteration(args.load) | |||
| if not success: | |||
| return 0 | |||
| if args.deepspeed and not no_deepspeed: | |||
| checkpoint_name, sd = model.load_checkpoint( | |||
| load_dir, | |||
| tag, | |||
| load_optimizer_states=not args.no_load_optim and not no_load_optim, | |||
| load_lr_scheduler_states=not args.no_load_lr_scheduler) | |||
| if not args.no_load_lr_scheduler and 'client_lr_scheduler' in sd: | |||
| lr_scheduler.load_state_dict(sd['client_lr_scheduler']) | |||
| print_rank_0('Load lr scheduler state') | |||
| if checkpoint_name is None: | |||
| if mpu.get_data_parallel_rank() == 0: | |||
| print('Unable to load checkpoint.') | |||
| return tag | |||
| else: | |||
| # Checkpoint. | |||
| checkpoint_name = get_checkpoint_name(load_dir, tag, release) | |||
| if mpu.get_data_parallel_rank() == 0: | |||
| print('global rank {} is loading checkpoint {}'.format( | |||
| torch.distributed.get_rank(), checkpoint_name)) | |||
| # Load the checkpoint. | |||
| sd = torch.load(checkpoint_name, map_location='cpu') | |||
| # Model. | |||
| if args.deepspeed: | |||
| model = model.module | |||
| missing_keys, unexpected_keys = model.load_state_dict( | |||
| sd['module'], strict=False) | |||
| if missing_keys or unexpected_keys: | |||
| print_rank_0( | |||
| f'Missing keys {missing_keys}, unexpected keys {unexpected_keys}' | |||
| ) | |||
| # Optimizer. | |||
| if not release and not args.finetune and not args.no_load_optim and not no_load_optim: | |||
| try: | |||
| if optimizer is not None: | |||
| optimizer.load_state_dict(sd['optimizer']) | |||
| if lr_scheduler is not None: | |||
| lr_scheduler.load_state_dict(sd['lr_scheduler']) | |||
| except KeyError: | |||
| print_rank_0( | |||
| 'Unable to load the optimizer state from checkpoint {}. ' | |||
| 'Specify --no-load-optim or --finetune to skip ' | |||
| 'loading the optimizer state.'.format(checkpoint_name)) | |||
| # Iterations. | |||
| if args.finetune or release: | |||
| iteration = 0 | |||
| else: | |||
| try: | |||
| iteration = sd['iteration'] | |||
| except KeyError: | |||
| try: # Backward compatible with older checkpoints | |||
| iteration = sd['total_iters'] | |||
| except KeyError: | |||
| print_rank_0( | |||
| 'A metadata file exists but the iteration could not be loaded ' | |||
| 'from checkpoint {}, starting from iteration 0'.format( | |||
| checkpoint_name)) | |||
| iteration = 0 | |||
| # rng states. | |||
| if not release and not args.finetune and not args.no_load_rng: | |||
| try: | |||
| random.setstate(sd['random_rng_state']) | |||
| np.random.set_state(sd['np_rng_state']) | |||
| torch.set_rng_state(sd['torch_rng_state']) | |||
| torch.cuda.set_rng_state(sd['cuda_rng_state']) | |||
| mpu.get_cuda_rng_tracker().set_states(sd['rng_tracker_states']) | |||
| except KeyError: | |||
| print_rank_0( | |||
| 'Unable to load the random state from checkpoint {}. ' | |||
| 'Specify --no-load-rng or --finetune to skip ' | |||
| 'loading the random state.'.format(checkpoint_name)) | |||
| if mpu.get_data_parallel_rank() == 0: | |||
| print(' successfully loaded {}'.format(checkpoint_name)) | |||
| return iteration | |||
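A minimal sketch of the intended save/load round trip, assuming `model`, `optimizer`, `lr_scheduler` and `args` come from the usual training setup (with `args.save` and `args.load` pointing at the checkpoint directory); all of these names are placeholders here:

    # Hypothetical round trip using the helpers above.
    iteration = 1000
    save_checkpoint(iteration, model, optimizer, lr_scheduler, args)

    # Later, e.g. when resuming: reads latest_checkpointed_iteration.txt
    # under args.load and restores model, optimizer and rng state.
    start_iteration = load_checkpoint(model, optimizer, lr_scheduler, args)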
| def load_weights(src, dst, dst2src=False): | |||
| """ | |||
| Load weights from src into dst via in-place copy. | |||
| src is a HuggingFace GPT-2 model, while dst is one of our models. | |||
| dst2src=True copies parameters from our model back into the HuggingFace | |||
| one; this direction is still untested. | |||
| """ | |||
| conv_layer = 'Conv1D' in str(type(src)) | |||
| for n, p in src.named_parameters(): | |||
| if dst2src: | |||
| data = dst._parameters[n].data | |||
| load = p.data | |||
| else: | |||
| data = p.data | |||
| load = dst._parameters[n].data | |||
| if conv_layer and 'weight' in n: | |||
| data = data.t().contiguous() | |||
| load.copy_(data) | |||
| # dst._parameters[n].data.copy_(data) | |||
| def load_mlp(our, oai, dst2src=False): | |||
| load_weights(oai.c_fc, our.dense_h_to_4h, dst2src) | |||
| load_weights(oai.c_proj, our.dense_4h_to_h, dst2src) | |||
| def load_attention(our, oai, dst2src=False): | |||
| load_weights(oai.c_attn, our.query_key_value, dst2src) | |||
| load_weights(oai.c_proj, our.dense, dst2src) | |||
| def load_transformer_layer(our, oai, dst2src=False): | |||
| load_weights(oai.ln_1, our.input_layernorm, dst2src) | |||
| load_weights(oai.ln_2, our.post_attention_layernorm, dst2src) | |||
| load_mlp(our.mlp, oai.mlp, dst2src) | |||
| load_attention(our.attention, oai.attn, dst2src) | |||
| def move_weights(our, oai, dst2src=False): | |||
| """ | |||
| Load weights from `oai` into `our` via in-place copy. | |||
| `oai` is a HuggingFace GPT-2 model, while `our` is one of our models. | |||
| dst2src=True copies parameters from our model back into the HuggingFace | |||
| one; this direction is still untested. | |||
| """ | |||
| # while isinstance(our, (torchDDP, model.distributed.DistributedDataParallel, FP16_Module)): | |||
| # our=our.module | |||
| transformer_model = oai.transformer | |||
| load_weights(transformer_model.ln_f, our.transformer.final_layernorm, | |||
| dst2src) | |||
| load_weights(transformer_model.wte, our.word_embeddings, dst2src) | |||
| load_weights(transformer_model.wpe, our.position_embeddings, dst2src) | |||
| for our_layer, oai_layer in zip(our.transformer.layers, oai.transformer.h): | |||
| load_transformer_layer(our_layer, oai_layer, dst2src) | |||
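A hedged sketch of how `move_weights` would be driven, assuming the source is a HuggingFace `GPT2LMHeadModel` and the destination is a locally built GPT-2-style model; `build_local_gpt2_model` is a placeholder, not something defined in this file:

    # Hypothetical usage of move_weights / load_weights above.
    from transformers import GPT2LMHeadModel

    oai = GPT2LMHeadModel.from_pretrained('gpt2')
    our = build_local_gpt2_model()  # placeholder for the local model

    # Copies word/position embeddings, the final layernorm and every
    # transformer layer from the HuggingFace model in place.
    move_weights(our, oai)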
| def debug_finetune_data(local_vars, batch_id, tokenizer): | |||
| tokens, target_ids = local_vars['tokens'], local_vars['target_ids'] | |||
| attention_mask, logit_mask, position_ids = local_vars[ | |||
| 'attention_mask'], local_vars['logit_mask'], local_vars['position_ids'] | |||
| output_tokens = [] | |||
| sep = attention_mask[batch_id].item() | |||
| for i, token in enumerate(tokens[batch_id][:sep].tolist()): | |||
| token = tokenizer.IdToToken(token) | |||
| if token == '[MASK]': | |||
| token = f'[{position_ids[batch_id][0, i].item()}]' | |||
| output_tokens.append(token) | |||
| print(' '.join(output_tokens)) | |||
| target_positions = [] | |||
| for i in range(sep, tokens.size(-1)): | |||
| if logit_mask[batch_id][i]: | |||
| target_positions.append(i) | |||
| print(target_positions) | |||
| print(tokenizer.DecodeIds(tokens[batch_id][target_positions].tolist())) | |||
| if len(target_ids.shape) > 2: | |||
| print( | |||
| tokenizer.DecodeIds( | |||
| target_ids[batch_id][target_positions].tolist())) | |||
| else: | |||
| print(tokenizer.DecodeIds(target_ids[batch_id].tolist())) | |||
| print(position_ids[batch_id][:, target_positions]) | |||
| @@ -516,6 +516,12 @@ TASK_OUTPUTS = { | |||
| # } | |||
| Tasks.text_generation: [OutputKeys.TEXT], | |||
| # summarization result for single sample | |||
| # { | |||
| # "text": "this is the text generated by a model." | |||
| # } | |||
| Tasks.text_summarization: [OutputKeys.TEXT], | |||
| # text generation result for single sample | |||
| # { | |||
| # "text": "北京" | |||
| @@ -31,6 +31,7 @@ if TYPE_CHECKING: | |||
| from .translation_pipeline import TranslationPipeline | |||
| from .word_segmentation_pipeline import WordSegmentationPipeline | |||
| from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | |||
| from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | |||
| from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ | |||
| WordSegmentationThaiPipeline | |||
| @@ -71,6 +72,7 @@ else: | |||
| 'word_segmentation_pipeline': ['WordSegmentationPipeline'], | |||
| 'zero_shot_classification_pipeline': | |||
| ['ZeroShotClassificationPipeline'], | |||
| 'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], | |||
| 'multilingual_word_segmentation_pipeline': [ | |||
| 'MultilingualWordSegmentationPipeline', | |||
| 'WordSegmentationThaiPipeline' | |||
| @@ -0,0 +1,43 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| from typing import Any, Dict, Optional, Union | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.base import Model | |||
| from modelscope.models.nlp import MGLMForTextSummarization | |||
| from modelscope.pipelines.base import Pipeline, Tensor | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import (MGLMSummarizationPreprocessor, | |||
| Preprocessor) | |||
| from modelscope.utils.constant import Tasks | |||
| __all__ = ['MGLMTextSummarizationPipeline'] | |||
| @PIPELINES.register_module( | |||
| group_key=Tasks.text_summarization, | |||
| module_name=Pipelines.mglm_text_summarization) | |||
| class MGLMTextSummarizationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[MGLMForTextSummarization, str], | |||
| preprocessor: Optional[Preprocessor] = None, | |||
| *args, | |||
| **kwargs): | |||
| model = MGLMForTextSummarization(model) if isinstance(model, | |||
| str) else model | |||
| self.model = model | |||
| self.model.eval() | |||
| if preprocessor is None: | |||
| preprocessor = MGLMSummarizationPreprocessor() | |||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||
| # define the forward pass | |||
| def forward(self, inputs: Union[Dict, str], | |||
| **forward_params) -> Dict[str, Any]: | |||
| inputs = {'text': inputs} if isinstance(inputs, str) else inputs | |||
| return self.model.generate(inputs) | |||
| # format the outputs from pipeline | |||
| def postprocess(self, input, **kwargs) -> Dict[str, Any]: | |||
| return input | |||
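For reference, a minimal usage sketch of the new pipeline through the standard entry point; it mirrors the unit test added later in this change, and the model id is the one used there:

    from modelscope.pipelines import pipeline
    from modelscope.preprocessors import MGLMSummarizationPreprocessor
    from modelscope.utils.constant import Tasks

    summarizer = pipeline(
        task=Tasks.text_summarization,
        model='ZhipuAI/Multilingual-GLM-Summarization-zh',
        preprocessor=MGLMSummarizationPreprocessor(),
    )
    print(summarizer('<long article text to summarize>'))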
| @@ -18,16 +18,16 @@ if TYPE_CHECKING: | |||
| from .nlp import ( | |||
| DocumentSegmentationPreprocessor, FaqQuestionAnsweringPreprocessor, | |||
| FillMaskPoNetPreprocessor, NLPPreprocessor, | |||
| NLPTokenizerPreprocessorBase, TextRankingPreprocessor, | |||
| RelationExtractionPreprocessor, SentenceEmbeddingPreprocessor, | |||
| SequenceClassificationPreprocessor, TokenClassificationPreprocessor, | |||
| TextErrorCorrectionPreprocessor, TextGenerationPreprocessor, | |||
| Text2TextGenerationPreprocessor, Tokenize, | |||
| NLPTokenizerPreprocessorBase, PassageRankingPreprocessor, | |||
| TextRankingPreprocessor, RelationExtractionPreprocessor, | |||
| SentenceEmbeddingPreprocessor, SequenceClassificationPreprocessor, | |||
| TokenClassificationPreprocessor, TextErrorCorrectionPreprocessor, | |||
| TextGenerationPreprocessor, Text2TextGenerationPreprocessor, Tokenize, | |||
| WordSegmentationBlankSetToLabelPreprocessor, | |||
| ZeroShotClassificationPreprocessor, TextGenerationJiebaPreprocessor, | |||
| SentencePiecePreprocessor, DialogIntentPredictionPreprocessor, | |||
| DialogModelingPreprocessor, DialogStateTrackingPreprocessor, | |||
| ConversationalTextToSqlPreprocessor, | |||
| MGLMSummarizationPreprocessor, ZeroShotClassificationPreprocessor, | |||
| TextGenerationJiebaPreprocessor, SentencePiecePreprocessor, | |||
| DialogIntentPredictionPreprocessor, DialogModelingPreprocessor, | |||
| DialogStateTrackingPreprocessor, ConversationalTextToSqlPreprocessor, | |||
| TableQuestionAnsweringPreprocessor, NERPreprocessorViet, | |||
| NERPreprocessorThai, WordSegmentationPreprocessorThai) | |||
| from .video import ReadVideoData, MovieSceneSegmentationPreprocessor | |||
| @@ -57,6 +57,7 @@ else: | |||
| 'TextErrorCorrectionPreprocessor', 'TextGenerationPreprocessor', | |||
| 'Tokenize', 'Text2TextGenerationPreprocessor', | |||
| 'WordSegmentationBlankSetToLabelPreprocessor', | |||
| 'MGLMSummarizationPreprocessor', | |||
| 'ZeroShotClassificationPreprocessor', | |||
| 'TextGenerationJiebaPreprocessor', 'SentencePiecePreprocessor', | |||
| 'NERPreprocessorViet', 'NERPreprocessorThai', | |||
| @@ -29,6 +29,7 @@ if TYPE_CHECKING: | |||
| MultiWOZBPETextField, IntentBPETextField) | |||
| from .space_T_en import ConversationalTextToSqlPreprocessor | |||
| from .space_T_cn import TableQuestionAnsweringPreprocessor | |||
| from .mglm_summarization_preprocessor import MGLMSummarizationPreprocessor | |||
| else: | |||
| _import_structure = { | |||
| 'nlp_base': [ | |||
| @@ -62,6 +63,7 @@ else: | |||
| 'text_error_correction': [ | |||
| 'TextErrorCorrectionPreprocessor', | |||
| ], | |||
| 'mglm_summarization_preprocessor': ['MGLMSummarizationPreprocessor'], | |||
| 'token_classification_thai_preprocessor': [ | |||
| 'NERPreprocessorThai', | |||
| 'WordSegmentationPreprocessorThai', | |||
| @@ -0,0 +1,32 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import os.path as osp | |||
| import re | |||
| from typing import Any, Dict, Iterable, Optional, Tuple, Union | |||
| from modelscope.metainfo import Models, Preprocessors | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.preprocessors.base import Preprocessor | |||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||
| from modelscope.utils.config import Config, ConfigFields | |||
| from modelscope.utils.constant import Fields, InputFields, ModeKeys, ModelFile | |||
| from modelscope.utils.hub import get_model_type, parse_label_mapping | |||
| from modelscope.utils.logger import get_logger | |||
| from modelscope.utils.nlp import import_external_nltk_data | |||
| from modelscope.utils.type_assert import type_assert | |||
| @PREPROCESSORS.register_module( | |||
| Fields.nlp, module_name=Preprocessors.mglm_summarization) | |||
| class MGLMSummarizationPreprocessor(Preprocessor): | |||
| def __init__(self, *args, **kwargs): | |||
| """preprocess the data | |||
| Args: | |||
| model_dir (str): model path | |||
| """ | |||
| super().__init__(*args, **kwargs) | |||
| @type_assert(object, (str, tuple, Dict)) | |||
| def __call__(self, data: Union[str, tuple, Dict]) -> Dict[str, Any]: | |||
| return data | |||
| @@ -1,18 +1,25 @@ | |||
| boto3 | |||
| en_core_web_sm>=2.3.5 | |||
| fasttext | |||
| filelock | |||
| ftfy | |||
| jieba>=0.42.1 | |||
| megatron_util | |||
| matplotlib | |||
| nltk | |||
| pai-easynlp | |||
| pandas | |||
| # protobuf versions beyond 3.20.0 are not compatible with TensorFlow 1.x and are therefore discouraged. | |||
| protobuf>=3.19.0,<3.21.0 | |||
| pythainlp | |||
| pyvi | |||
| # rouge-score was recently updated from 0.0.4 to 0.0.7, | |||
| # which introduced compatibility issues that are being investigated | |||
| rouge_score<=0.0.4 | |||
| regex | |||
| sacremoses>=0.0.41 | |||
| scikit_learn | |||
| sentencepiece | |||
| seqeval | |||
| spacy>=2.3.5 | |||
| subword_nmt>=0.3.8 | |||
| termcolor | |||
| text2sql_lgesql | |||
| tokenizers | |||
| transformers>=4.12.0 | |||
| @@ -0,0 +1,47 @@ | |||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||
| import os | |||
| import unittest | |||
| from modelscope.models import Model | |||
| from modelscope.pipelines import pipeline | |||
| from modelscope.preprocessors import MGLMSummarizationPreprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||
| from modelscope.utils.test_utils import test_level | |||
| class mGLMTest(unittest.TestCase, DemoCompatibilityCheck): | |||
| def setUp(self) -> None: | |||
| self.output_dir = 'unittest_output' | |||
| os.makedirs(self.output_dir, exist_ok=True) | |||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | |||
| def test_run_with_mglm_with_name(self): | |||
| model = 'ZhipuAI/Multilingual-GLM-Summarization-zh' | |||
| preprocessor = MGLMSummarizationPreprocessor() | |||
| pipe = pipeline( | |||
| task=Tasks.text_summarization, | |||
| model=model, | |||
| preprocessor=preprocessor, | |||
| ) | |||
| result = pipe( | |||
| '据中国载人航天工程办公室消息,北京时间2022年10月25日,梦天实验舱与长征五号B遥四运载火箭组合体已转运至发射区。后续将按计划开展发射前各项功能检查和联合测试等工作,计划于近日择机实施发射。目前,文昌航天发射场设施设备状态良好,参试各单位正在加紧开展任务准备,全力以赴确保空间站建造任务决战决胜。' # noqa | |||
| ) | |||
| print(result) | |||
| model = 'ZhipuAI/Multilingual-GLM-Summarization-en' | |||
| preprocessor = MGLMSummarizationPreprocessor() | |||
| pipe = pipeline( | |||
| task=Tasks.text_summarization, | |||
| model=model, | |||
| preprocessor=preprocessor, | |||
| ) | |||
| result = pipe( | |||
| '据中国载人航天工程办公室消息,北京时间2022年10月25日,梦天实验舱与长征五号B遥四运载火箭组合体已转运至发射区。后续将按计划开展发射前各项功能检查和联合测试等工作,计划于近日择机实施发射。目前,文昌航天发射场设施设备状态良好,参试各单位正在加紧开展任务准备,全力以赴确保空间站建造任务决战决胜。' # noqa | |||
| ) | |||
| print(result) | |||
| if __name__ == '__main__': | |||
| unittest.main() | |||