@@ -24,5 +24,5 @@ from .optimizer import Optimizer, SGD, Adam
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester
from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask
from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
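
A hedged usage sketch of the newly exported helper, assuming get_seq_len is the inverse-direction companion of seq_len_to_mask (it counts non-padding tokens per row; the pad_value keyword and the package-level re-export are assumptions):

    import torch
    from fastNLP import get_seq_len, seq_len_to_mask

    words = torch.tensor([[3, 5, 7, 0, 0],
                          [2, 4, 0, 0, 0]])    # 0 is the padding index
    seq_len = get_seq_len(words, pad_value=0)  # assumed signature -> tensor([3, 2])
    mask = seq_len_to_mask(seq_len)            # True for real tokens, False for padding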

@@ -376,7 +376,7 @@ class Vocabulary(object):
        :return: bool
        """
        return word in self._no_create_word

    def to_index(self, w):
        """
        Convert a word to its index. If the word is not recorded in the vocabulary, it is treated as unknown; if ``unknown=None``, a ``ValueError`` is raised::
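
A sketch of the behavior the docstring describes (labels illustrative):

    from fastNLP import Vocabulary

    vocab = Vocabulary()
    vocab.add_word_lst(["positive", "negative"])
    vocab.to_index("positive")   # index of "positive"
    vocab.to_index("neutral")    # not recorded -> index of the unknown token

    strict = Vocabulary(unknown=None)
    strict.add_word_lst(["positive", "negative"])
    strict.to_index("neutral")   # raises ValueError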

@@ -68,6 +68,10 @@ class BertEmbedding(ContextualEmbedding):
        else:
            raise ValueError(f"Cannot recognize {model_dir_or_name}.")

        self._word_sep_index = None
        if '[SEP]' in vocab:
            self._word_sep_index = vocab['[SEP]']

        self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
                                    pool_method=pool_method, include_cls_sep=include_cls_sep,
                                    pooled_cls=pooled_cls, auto_truncate=auto_truncate)

@@ -86,7 +90,11 @@ class BertEmbedding(ContextualEmbedding):
        :param torch.LongTensor words: [batch_size, max_len]
        :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
        """
        if self._word_sep_index:  # [SEP] must not be dropped
            sep_mask = words.eq(self._word_sep_index)
        words = self.drop_word(words)
        if self._word_sep_index:
            words.masked_fill_(sep_mask, self._word_sep_index)
        outputs = self._get_sent_reprs(words)
        if outputs is not None:
            return self.dropout(outputs)
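
The new lines snapshot where [SEP] sits before word dropout and write it back afterwards, so dropout can never delete the separator BERT relies on to split sentence pairs. A self-contained sketch of the same trick (drop_word below is a toy stand-in for the embedding's own word dropout):

    import torch

    def drop_word(words, unk_index=1, p=0.1):
        # toy word-level dropout: replace tokens with unk at probability p
        mask = torch.rand_like(words, dtype=torch.float) < p
        return words.masked_fill(mask, unk_index)

    def drop_word_keep_sep(words, sep_index):
        sep_mask = words.eq(sep_index)       # remember the [SEP] positions
        words = drop_word(words)             # may clobber arbitrary tokens
        return words.masked_fill(sep_mask, sep_index)  # restore [SEP]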

@@ -74,14 +74,10 @@ class StaticEmbedding(TokenEmbedding):
        if lower:
            lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
            for word, index in vocab:
                if not vocab._is_word_no_create_entry(word):
                if vocab._is_word_no_create_entry(word):
                    lowered_vocab.add_word(word.lower(), no_create_entry=True)
                else:
                    lowered_vocab.add_word(word.lower())  # first add the words that need their own entry
            for word in vocab._no_create_word.keys():  # words that do not need their own entry
                if word in vocab:
                    lowered_word = word.lower()
                    if lowered_word not in lowered_vocab.word_count:
                        lowered_vocab.add_word(lowered_word)
                    lowered_vocab._no_create_word[lowered_word] += 1
            print(f"All words in the vocab have been lowered before finding pretrained vectors. There are {len(vocab)} "
                  f"words, {len(lowered_vocab)} unique lowered words.")
            if model_path:
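
Net effect of the rewritten loop: a single pass over the vocab, with each word lowered and the original word's no_create_entry status carried over, so different casings collapse to one row. A toy check (it uses the same private helper the loop itself relies on):

    from fastNLP import Vocabulary

    vocab = Vocabulary()
    vocab.add_word("Apple")
    vocab.add_word("apple")
    lowered = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
    for word, index in vocab:
        lowered.add_word(word.lower(),
                         no_create_entry=vocab._is_word_no_create_entry(word))
    print(len(vocab), len(lowered))   # lowered is one entry smaller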

@@ -90,7 +86,7 @@ class StaticEmbedding(TokenEmbedding):
            embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
            # adapt the index mapping
            if not hasattr(self, 'words_to_words'):
                self.words_to_words = torch.arange(len(lowered_vocab, )).long()
                self.words_to_words = torch.arange(len(lowered_vocab)).long()
            if lowered_vocab.unknown:
                unknown_idx = lowered_vocab.unknown_idx
            else:

@@ -100,10 +96,11 @@ class StaticEmbedding(TokenEmbedding):
            for word, index in vocab:
                if word not in lowered_vocab:
                    word = word.lower()
                    if lowered_vocab._is_word_no_create_entry(word):  # if no entry is needed, it already defaults to unknown
                        continue
                    if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
                        continue  # if no entry is needed, it already defaults to unknown
                words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
            self.words_to_words = words_to_words
            self._word_unk_index = lowered_vocab.unknown_idx  # also swap in the lowered vocab's unknown index
        else:
            if model_path:
                embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
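
words_to_words is the indirection that keeps the caller's original indices valid while the embedding matrix only stores one row per lowered word: a lookup goes original index -> words_to_words -> matrix row. A toy illustration (all numbers hypothetical):

    import torch

    words_to_words = torch.tensor([0, 1, 1, 2, 2])  # e.g. "Apple"/"apple" share row 1
    embed_matrix = torch.randn(3, 4)                # 3 lowered words, dim 4

    words = torch.tensor([[0, 2, 4]])               # batch of original indices
    vectors = embed_matrix[words_to_words[words]]   # shape (1, 3, 4)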

@@ -211,12 +208,14 @@ class StaticEmbedding(TokenEmbedding):
                print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
                for word, index in vocab:
                    if index not in matrix and not vocab._is_word_no_create_entry(word):
                        if vocab.unknown_idx in matrix:  # if there is an unknown, initialize with unknown
                        if vocab.padding_idx == index:
                            matrix[index] = torch.zeros(dim)
                        elif vocab.unknown_idx in matrix:  # if there is an unknown, initialize with unknown
                            matrix[index] = matrix[vocab.unknown_idx]
                        else:
                            matrix[index] = None
                vectors = self._randomly_init_embed(len(matrix), dim, init_method)
                vectors = self._randomly_init_embed(len(vocab), dim, init_method)
                if vocab._no_create_word_length > 0:
                    if vocab.unknown is None:  # create a dedicated unknown

@@ -226,10 +225,13 @@ class StaticEmbedding(TokenEmbedding):
                        unknown_idx = vocab.unknown_idx
                    words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
                                                  requires_grad=False)
                    for order, (index, vec) in enumerate(matrix.items()):
                    for word, index in vocab:
                        vec = matrix.get(index, None)
                        if vec is not None:
                            vectors[order] = vec
                            words_to_words[index] = order
                            vectors[index] = vec
                            words_to_words[index] = index
                        else:
                            vectors[index] = vectors[unknown_idx]
                    self.words_to_words = words_to_words
                else:
                    for index, vec in matrix.items():

@@ -144,7 +144,7 @@ class DataBundle:
        """
        self.datasets[name] = dataset

    def get_dataset(self, name:str):
    def get_dataset(self, name:str)->DataSet:
        """
        Get the dataset named ``name``.

@@ -153,7 +153,7 @@ class DataBundle:
        """
        return self.datasets[name]

    def get_vocab(self, field_name:str):
    def get_vocab(self, field_name:str)->Vocabulary:
        """
        Get the vocabulary associated with the field named ``field_name``.
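
The added return annotations let IDEs and type checkers resolve the downstream calls. Usage is unchanged (loader and field names below are illustrative):

    data_bundle = SomeLoader().load(paths)         # hypothetical loader
    train_data = data_bundle.get_dataset('train')  # now known to be a DataSet
    word_vocab = data_bundle.get_vocab('words')    # now known to be a Vocabulary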

@@ -78,6 +78,17 @@ DATASET_DIR = {
    "rte": "RTE.zip"
}

PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
                "bert": PRETRAINED_BERT_MODEL_DIR,
                "static": PRETRAIN_STATIC_FILES}

# used to extend fastNLP's downloads
FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt'
FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt',
                                'bert': 'fastnlp_bert_url.txt',
                                'static': 'fastnlp_static_url.txt'
                                }

def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
    """

@@ -97,7 +108,7 @@ def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
    :return:
    """
    if cache_dir is None:
        data_cache = Path(get_default_cache_path())
        data_cache = Path(get_cache_path())
    else:
        data_cache = cache_dir
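
A hedged sketch of the renamed helper in use (the URL is illustrative only):

    from fastNLP.io.file_utils import get_cache_path, cached_path

    cache_dir = get_cache_path()   # honors FASTNLP_CACHE_PATH if it is set
    local_path = cached_path('http://example.com/RTE.zip',
                             cache_dir=cache_dir, name='dataset')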

@@ -146,7 +157,7 @@ def get_filepath(filepath):
        raise FileNotFoundError(f"{filepath} is not a valid file or directory.")

def get_default_cache_path():
def get_cache_path():
    """
    Get fastNLP's default storage path. If FASTNLP_CACHE_PATH is set as an environment variable, its value is used, so that not every user has to download the files.

@@ -188,27 +199,51 @@ def _get_base_url(name):
        return URLS[name.lower()]

def _get_embedding_url(type, name):
def _get_embedding_url(embed_type, name):
    """
    Given an embedding type and name, return the download url.

    :param str type: the embedding type; static, bert and elmo are supported
    :param str embed_type: the embedding type; static, bert and elmo are supported
    :param str name: the name of the embedding, e.g. en, cn, based
    :return: str, the download url
    """
    PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
                    "bert": PRETRAINED_BERT_MODEL_DIR,
                    "static": PRETRAIN_STATIC_FILES}
    map = PRETRAIN_MAP.get(type, None)
    # look for a download url among the extensions first
    _filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None)
    if _filename:
        url = _read_extend_url_file(_filename, name)
        if url:
            return url
    map = PRETRAIN_MAP.get(embed_type, None)
    if map:
        filename = map.get(name, None)
        if filename:
            url = _get_base_url('embedding') + filename
            return url
        raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys())))
    else:
        raise KeyError(f"There is no {type}. Only supports bert, elmo, static")
        raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")

def _read_extend_url_file(filename, name)->str:
    """
    The file's lines are tab-separated: the first column is a name, the second the download url.

    :param str filename: look for this file under the default path
    :param str name: the name of the resource to look up
    :return: str or None
    """
    cache_dir = get_cache_path()
    filepath = os.path.join(cache_dir, filename)
    if os.path.exists(filepath):
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if line:
                    parts = line.split('\t')
                    if len(parts) == 2:
                        if name == parts[0]:
                            return parts[1]
    return None
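
So a user can register extra resources by dropping a tab-separated file into the cache directory, with no code change. For example (name and url illustrative):

    # contents of <cache_dir>/fastnlp_dataset_url.txt:
    # my-corpus<TAB>http://example.com/my-corpus.zip

    _read_extend_url_file('fastnlp_dataset_url.txt', 'my-corpus')
    # -> 'http://example.com/my-corpus.zip'
    _read_extend_url_file('fastnlp_dataset_url.txt', 'other-name')
    # -> None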

def _get_dataset_url(name):
    """
@@ -217,6 +252,11 @@ def _get_dataset_url(name):
    :param str name: the name of the dataset, e.g. imdb, sst-2
    :return: str
    """
    # look for a download url among the extensions first
    url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name)
    if url:
        return url

    filename = DATASET_DIR.get(name, None)
    if filename:
        url = _get_base_url('dataset') + filename

@@ -3,7 +3,7 @@ from .. import DataBundle
from ..utils import check_loader_paths
from typing import Union, Dict
import os
from ..file_utils import _get_dataset_url, get_default_cache_path, cached_path
from ..file_utils import _get_dataset_url, get_cache_path, cached_path

class Loader:
    def __init__(self):

@@ -66,7 +66,7 @@ class Loader:
        :return: str, the directory of the dataset; read the data directly from that directory.
        """
        default_cache_path = get_default_cache_path()
        default_cache_path = get_cache_path()
        url = _get_dataset_url(dataset_name)
        output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')

@@ -24,7 +24,7 @@ class _NERPipe(Pipe):
        if encoding_type == 'bio':
            self.convert_tag = iob2
        else:
            self.convert_tag = iob2bioes
            self.convert_tag = lambda words: iob2bioes(iob2(words))
        self.lower = lower
        self.target_pad_val = int(target_pad_val)
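
iob2bioes assumes its input is already valid BIO, so routing raw tags through iob2 first normalizes IOB1-style sequences before the BIOES conversion; the old code skipped that step. A sketch on a toy tag list (assuming both helpers return the converted list):

    tags = ['I-PER', 'I-PER', 'O', 'I-LOC']  # IOB1-style input
    bio = iob2(tags)           # ['B-PER', 'I-PER', 'O', 'B-LOC']
    bioes = iob2bioes(bio)     # ['B-PER', 'E-PER', 'O', 'S-LOC']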

@@ -57,7 +57,7 @@ class MatchingBertPipe(Pipe):
            dataset[Const.INPUTS(0)].lower()
            dataset[Const.INPUTS(1)].lower()

        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUT(1)],
        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
                                     [Const.INPUTS(0), Const.INPUTS(1)])

        # concatenate the two word fields

@@ -61,7 +61,7 @@ def get_tokenizer(tokenizer:str, lang='en'):
    if tokenizer == 'spacy':
        import spacy
        spacy.prefer_gpu()
        if lang!='en':
        if lang != 'en':
            raise RuntimeError("Spacy only supports en right now.")
        en = spacy.load(lang)
        tokenizer = lambda x: [w.text for w in en.tokenizer(x)]
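
A hedged usage sketch (the module path is assumed from the surrounding hunks, and the exact tokens depend on the installed spaCy model):

    from fastNLP.io.pipe.utils import get_tokenizer

    tokenize = get_tokenizer('spacy', lang='en')
    tokenize("fastNLP is a toolkit.")   # e.g. ['fastNLP', 'is', 'a', 'toolkit', '.']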