diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index b246c6a0..d92e8f62 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -24,5 +24,5 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
 from .tester import Tester
 from .trainer import Trainer
-from .utils import cache_results, seq_len_to_mask
+from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index a51c3f92..330d73dd 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -376,7 +376,7 @@ class Vocabulary(object):
         :return: bool
         """
         return word in self._no_create_word
-    
+
     def to_index(self, w):
         """
         Convert a word to its index. A word not recorded in the vocabulary is treated as unknown; if ``unknown=None``, a ``ValueError`` is raised::
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index db50f9f4..fa56419b 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -68,6 +68,10 @@ class BertEmbedding(ContextualEmbedding):
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
 
+        self._word_sep_index = None
+        if '[SEP]' in vocab:
+            self._word_sep_index = vocab['[SEP]']
+
         self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
                                     pooled_cls=pooled_cls, auto_truncate=auto_truncate)
@@ -86,7 +90,11 @@ class BertEmbedding(ContextualEmbedding):
         :param torch.LongTensor words: [batch_size, max_len]
         :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
         """
+        if self._word_sep_index:  # [SEP] must not be dropped
+            sep_mask = words.eq(self._word_sep_index)
         words = self.drop_word(words)
+        if self._word_sep_index:
+            words.masked_fill_(sep_mask, self._word_sep_index)
         outputs = self._get_sent_reprs(words)
         if outputs is not None:
             return self.dropout(words)
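Note on the BertEmbedding change: drop_word randomly replaces token ids with unk, so the new guard records where [SEP] sits before dropping and writes those positions back afterwards. A minimal sketch of the save-and-restore pattern, with hypothetical token ids and a stand-in for drop_word:

    import torch

    sep_index, unk_index = 3, 1                   # hypothetical ids
    words = torch.tensor([[5, 9, 3, 0]])          # batch_size x max_len, 3 == [SEP]
    sep_mask = words.eq(sep_index)                # remember where [SEP] sits
    drop = torch.rand(words.shape) < 0.1          # stand-in for drop_word's random mask
    words = words.masked_fill(drop, unk_index)    # dropping may clobber [SEP]
    words.masked_fill_(sep_mask, sep_index)       # restore [SEP] in place

Because masked_fill_ restores [SEP] in place, the rest of forward sees an unchanged sequence layout.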
diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
index d44d7087..78f615f6 100644
--- a/fastNLP/embeddings/static_embedding.py
+++ b/fastNLP/embeddings/static_embedding.py
@@ -74,14 +74,10 @@ class StaticEmbedding(TokenEmbedding):
         if lower:
             lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
             for word, index in vocab:
-                if not vocab._is_word_no_create_entry(word):
+                if vocab._is_word_no_create_entry(word):
+                    lowered_vocab.add_word(word.lower(), no_create_entry=True)
+                else:
                     lowered_vocab.add_word(word.lower())  # add the words that need an entry first
-            for word in vocab._no_create_word.keys():  # words that do not need an entry
-                if word in vocab:
-                    lowered_word = word.lower()
-                    if lowered_word not in lowered_vocab.word_count:
-                        lowered_vocab.add_word(lowered_word)
-                        lowered_vocab._no_create_word[lowered_word] += 1
             print(f"All words in the vocab have been lowered before finding pretrained vectors. There are {len(vocab)} "
                   f"words, {len(lowered_vocab)} unique lowered words.")
             if model_path:
@@ -90,7 +86,7 @@ class StaticEmbedding(TokenEmbedding):
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
             # adapt the indices
             if not hasattr(self, 'words_to_words'):
-                self.words_to_words = torch.arange(len(lowered_vocab, )).long()
+                self.words_to_words = torch.arange(len(lowered_vocab)).long()
             if lowered_vocab.unknown:
                 unknown_idx = lowered_vocab.unknown_idx
             else:
@@ -100,10 +96,11 @@ class StaticEmbedding(TokenEmbedding):
             for word, index in vocab:
                 if word not in lowered_vocab:
                     word = word.lower()
-                    if lowered_vocab._is_word_no_create_entry(word):  # if no entry is needed, it already defaults to unknown
-                        continue
+                    if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
+                        continue  # if no entry is needed, it already defaults to unknown
                 words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
             self.words_to_words = words_to_words
+            self._word_unk_index = lowered_vocab.unknown_idx  # swap in the lowered unknown index
         else:
             if model_path:
                 embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
@@ -211,12 +208,14 @@ class StaticEmbedding(TokenEmbedding):
             print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
             for word, index in vocab:
                 if index not in matrix and not vocab._is_word_no_create_entry(word):
-                    if vocab.unknown_idx in matrix:  # if there is an unknown entry, initialize with it
+                    if vocab.padding_idx == index:
+                        matrix[index] = torch.zeros(dim)
+                    elif vocab.unknown_idx in matrix:  # if there is an unknown entry, initialize with it
                         matrix[index] = matrix[vocab.unknown_idx]
                     else:
                         matrix[index] = None
 
-            vectors = self._randomly_init_embed(len(matrix), dim, init_method)
+            vectors = self._randomly_init_embed(len(vocab), dim, init_method)
 
             if vocab._no_create_word_length>0:
                 if vocab.unknown is None:  # create a dedicated unknown entry
@@ -226,10 +225,13 @@ class StaticEmbedding(TokenEmbedding):
                     unknown_idx = vocab.unknown_idx
                 words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
                                               requires_grad=False)
-                for order, (index, vec) in enumerate(matrix.items()):
+                for word, index in vocab:
+                    vec = matrix.get(index, None)
                     if vec is not None:
-                        vectors[order] = vec
-                        words_to_words[index] = order
+                        vectors[index] = vec
+                        words_to_words[index] = index
+                    else:
+                        vectors[index] = vectors[unknown_idx]
                 self.words_to_words = words_to_words
             else:
                 for index, vec in matrix.items():
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
index 429a8406..5cbd5bb1 100644
--- a/fastNLP/io/base_loader.py
+++ b/fastNLP/io/base_loader.py
@@ -144,7 +144,7 @@ class DataBundle:
         """
         self.datasets[name] = dataset
 
-    def get_dataset(self, name:str):
+    def get_dataset(self, name:str)->DataSet:
         """
         Get the dataset named ``name``.
 
@@ -153,7 +153,7 @@ class DataBundle:
         """
         return self.datasets[name]
 
-    def get_vocab(self, field_name:str):
+    def get_vocab(self, field_name:str)->Vocabulary:
         """
         Get the vocab for the field named ``field_name``.
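Note on the StaticEmbedding change: the old two-pass lowering could double-count case variants; the new single pass carries the no_create_entry flag along as it lowers. A sketch of the consolidated loop, assuming fastNLP's Vocabulary API as used in the diff:

    from fastNLP import Vocabulary

    vocab = Vocabulary()
    vocab.add_word_lst(['The', 'the', 'NLP'])
    vocab.add_word('Bert', no_create_entry=True)   # e.g. a word seen only in dev/test

    lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
    for word, index in vocab:
        if vocab._is_word_no_create_entry(word):
            lowered_vocab.add_word(word.lower(), no_create_entry=True)
        else:
            lowered_vocab.add_word(word.lower())
    # 'The' and 'the' now share one entry; 'bert' keeps its no-entry status.

The base_loader.py annotations above change no behavior; they let IDEs and type checkers resolve what get_dataset and get_vocab return.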
diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index 43fe2ab1..eb6dea1d 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -78,6 +78,17 @@ DATASET_DIR = {
     "rte": "RTE.zip"
 }
 
+PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
+                "bert": PRETRAINED_BERT_MODEL_DIR,
+                "static": PRETRAIN_STATIC_FILES}
+
+# used to extend fastNLP's downloadable resources
+FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt'
+FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt',
+                                'bert':'fastnlp_bert_url.txt',
+                                'static': 'fastnlp_static_url.txt'
+}
+
 
 def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
     """
@@ -97,7 +108,7 @@ def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
     :return:
     """
     if cache_dir is None:
-        data_cache = Path(get_default_cache_path())
+        data_cache = Path(get_cache_path())
     else:
         data_cache = cache_dir
 
@@ -146,7 +157,7 @@ def get_filepath(filepath):
     raise FileNotFoundError(f"{filepath} is not a valid file or directory.")
 
 
-def get_default_cache_path():
+def get_cache_path():
    """
     Get the default fastNLP cache path. If FASTNLP_CACHE_PATH is set as an environment variable, its value is used, so the files need not be downloaded by every user.
 
@@ -188,27 +199,51 @@ def _get_base_url(name):
     return URLS[name.lower()]
 
 
-def _get_embedding_url(type, name):
+def _get_embedding_url(embed_type, name):
     """
     Given an embedding type and name, return the download url.
 
-    :param str type: the embedding type; supports static, bert, elmo
+    :param str embed_type: the embedding type; supports static, bert, elmo
     :param str name: the embedding name, e.g. en, cn, based
     :return: str, the download url
     """
-    PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
-                    "bert": PRETRAINED_BERT_MODEL_DIR,
-                    "static":PRETRAIN_STATIC_FILES}
-    map = PRETRAIN_MAP.get(type, None)
+    # look for a download url among the extension files first
+    _filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None)
+    if _filename:
+        url = _read_extend_url_file(_filename, name)
+        if url:
+            return url
+    map = PRETRAIN_MAP.get(embed_type, None)
     if map:
+        filename = map.get(name, None)
         if filename:
             url = _get_base_url('embedding') + filename
             return url
         raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys())))
     else:
-        raise KeyError(f"There is no {type}. Only supports bert, elmo, static")
+        raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")
 
 
+def _read_extend_url_file(filename, name)->str:
+    """
+    The contents of filename are tab-separated: the first column is a name, the second a download url.
+
+    :param str filename: the file to look for under the default cache path
+    :param str name: the name of the resource to look up
+    :return: str or None
+    """
+    cache_dir = get_cache_path()
+    filepath = os.path.join(cache_dir, filename)
+    if os.path.exists(filepath):
+        with open(filepath, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    parts = line.split('\t')
+                    if len(parts) == 2:
+                        if name == parts[0]:
+                            return parts[1]
+    return None
 
 def _get_dataset_url(name):
     """
@@ -217,6 +252,11 @@ def _get_dataset_url(name):
     :param str name: the dataset name, e.g. imdb, sst-2
     :return: str
     """
+    # look for a download url among the extension files first
+    url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name)
+    if url:
+        return url
+
     filename = DATASET_DIR.get(name, None)
     if filename:
         url = _get_base_url('dataset') + filename
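Note on the file_utils.py changes: besides the get_default_cache_path -> get_cache_path rename, the FASTNLP_EXTEND_* hooks let a user register extra resources without touching the code: a tab-separated file under the cache path maps a name to a url, and _read_extend_url_file consults it before the built-in maps. A sketch of registering a dataset this way (the resource name and url are hypothetical):

    import os
    from fastNLP.io.file_utils import get_cache_path

    filepath = os.path.join(get_cache_path(), 'fastnlp_dataset_url.txt')
    with open(filepath, 'a', encoding='utf-8') as f:
        f.write('my-corpus\thttp://example.com/my-corpus.zip\n')
    # _get_dataset_url('my-corpus') should now resolve to the url above
    # instead of raising a KeyError.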
diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py
index c59de29f..607d6920 100644
--- a/fastNLP/io/loader/loader.py
+++ b/fastNLP/io/loader/loader.py
@@ -3,7 +3,7 @@ from .. import DataBundle
 from ..utils import check_loader_paths
 from typing import Union, Dict
 import os
-from ..file_utils import _get_dataset_url, get_default_cache_path, cached_path
+from ..file_utils import _get_dataset_url, get_cache_path, cached_path
 
 class Loader:
     def __init__(self):
@@ -66,7 +66,7 @@ class Loader:
 
         :return: str, the directory of the dataset; read the data directly from that directory.
         """
-        default_cache_path = get_default_cache_path()
+        default_cache_path = get_cache_path()
         url = _get_dataset_url(dataset_name)
         output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')
 
diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py
index b9007344..a49e68b1 100644
--- a/fastNLP/io/pipe/conll.py
+++ b/fastNLP/io/pipe/conll.py
@@ -24,7 +24,7 @@ class _NERPipe(Pipe):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         else:
-            self.convert_tag = iob2bioes
+            self.convert_tag = lambda words: iob2bioes(iob2(words))
         self.lower = lower
         self.target_pad_val = int(target_pad_val)
 
diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py
index 93e854b1..76116345 100644
--- a/fastNLP/io/pipe/matching.py
+++ b/fastNLP/io/pipe/matching.py
@@ -57,7 +57,7 @@ class MatchingBertPipe(Pipe):
             dataset[Const.INPUTS(0)].lower()
             dataset[Const.INPUTS(1)].lower()
 
-        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUT(1)],
+        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
                                      [Const.INPUTS(0), Const.INPUTS(1)])
 
         # concatenate the two words fields
diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py
index 5e9ff8dc..48454b67 100644
--- a/fastNLP/io/pipe/utils.py
+++ b/fastNLP/io/pipe/utils.py
@@ -61,7 +61,7 @@ def get_tokenizer(tokenizer:str, lang='en'):
     if tokenizer == 'spacy':
         import spacy
         spacy.prefer_gpu()
-        if lang!='en':
+        if lang != 'en':
            raise RuntimeError("Spacy only supports en right now.")
         en = spacy.load(lang)
         tokenizer = lambda x: [w.text for w in en.tokenizer(x)]
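Note on the conll.py change: iob2bioes assumes BIO2 input, so raw IOB1 tags must be normalized by iob2 first; the lambda now composes the two converters. A quick illustration, assuming the helpers are importable from fastNLP.io.pipe.utils:

    from fastNLP.io.pipe.utils import iob2, iob2bioes

    tags = ['I-ORG', 'I-ORG', 'O', 'I-PER']   # IOB1: an entity may open with I-
    bio = iob2(tags)                          # ['B-ORG', 'I-ORG', 'O', 'B-PER']
    print(iob2bioes(bio))                     # ['B-ORG', 'E-ORG', 'O', 'S-PER']

Applying iob2bioes directly to the IOB1 input would mislabel the entity boundaries, which is exactly what the old single-converter assignment did.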