@@ -24,5 +24,5 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
 from .tester import Tester
 from .trainer import Trainer
-from .utils import cache_results, seq_len_to_mask
+from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
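Note on the new export: `get_seq_len` complements the existing `seq_len_to_mask` helper. A minimal usage sketch follows; the exact signature of `get_seq_len` (a `pad_value` argument) is an assumption not shown in this hunk, so the sketch computes the expected result by hand.

```python
# Sketch only: get_seq_len(words, pad_value=0) is assumed to return per-sample
# lengths of a padded batch; here the equivalent value is computed directly.
import torch
from fastNLP.core import seq_len_to_mask  # exported alongside the helpers above

words = torch.tensor([[3, 5, 7, 0, 0],
                      [2, 4, 0, 0, 0]])   # 0 is the padding index
seq_len = words.ne(0).sum(dim=-1)         # tensor([3, 2]) -- what get_seq_len should produce
mask = seq_len_to_mask(seq_len)           # bool mask of shape [2, 5]
print(seq_len.tolist(), mask.sum(dim=-1).tolist())
```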
@@ -376,7 +376,7 @@ class Vocabulary(object):
         :return: bool
         """
         return word in self._no_create_word
 
     def to_index(self, w):
         """
         Convert a word to its index. If the word is not recorded in the vocabulary it is treated as unknown; if ``unknown=None``, a ``ValueError`` is raised::
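The hunk above only touches the surroundings of `to_index`; for readers unfamiliar with the API, here is a small illustration of the behaviour the docstring describes (the vocabulary contents are made up):

```python
from fastNLP.core import Vocabulary

vocab = Vocabulary()                 # has <pad>/<unk> entries by default
vocab.add_word_lst(["apple", "banana"])
print(vocab.to_index("apple"))       # a regular index
print(vocab.to_index("durian"))      # unseen word -> index of the unknown token

strict = Vocabulary(unknown=None, padding=None)
strict.add_word_lst(["apple"])
# strict.to_index("durian")          # raises ValueError because unknown=None
```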
@@ -68,6 +68,10 @@ class BertEmbedding(ContextualEmbedding):
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
 
+        self._word_sep_index = None
+        if '[SEP]' in vocab:
+            self._word_sep_index = vocab['[SEP]']
+
         self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
                                     pooled_cls=pooled_cls, auto_truncate=auto_truncate)
@@ -86,7 +90,11 @@ class BertEmbedding(ContextualEmbedding):
         :param torch.LongTensor words: [batch_size, max_len]
         :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
         """
+        if self._word_sep_index:  # [SEP] must not be dropped
+            sep_mask = words.eq(self._word_sep_index)
         words = self.drop_word(words)
+        if self._word_sep_index:
+            words.masked_fill_(sep_mask, self._word_sep_index)
         outputs = self._get_sent_reprs(words)
         if outputs is not None:
             return self.dropout(words)
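The `__init__` lines above record the id of `[SEP]` when it is present in the vocabulary, and the `forward` change remembers where `[SEP]` occurs before word dropout and writes it back afterwards, so sentence boundaries survive `drop_word`. A stand-alone illustration of that mask-and-restore pattern (the token id and the dropout rate are invented):

```python
import torch

SEP = 102                                        # stand-in for vocab['[SEP]']
words = torch.tensor([[5, 9, SEP, 7, SEP]])

sep_mask = words.eq(SEP)                         # remember the [SEP] positions
drop = torch.rand(words.shape) < 0.5             # crude stand-in for drop_word
dropped = words.masked_fill(drop, 0)             # some [SEP] tokens may be zeroed here...
dropped.masked_fill_(sep_mask, SEP)              # ...and are restored here
assert dropped.eq(SEP).sum() == words.eq(SEP).sum()
```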
@@ -74,14 +74,10 @@ class StaticEmbedding(TokenEmbedding):
         if lower:
             lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
             for word, index in vocab:
-                if not vocab._is_word_no_create_entry(word):
+                if vocab._is_word_no_create_entry(word):
+                    lowered_vocab.add_word(word.lower(), no_create_entry=True)
+                else:
                     lowered_vocab.add_word(word.lower())  # add the words that need their own entries
-            for word in vocab._no_create_word.keys():  # words that do not need their own entries
-                if word in vocab:
-                    lowered_word = word.lower()
-                    if lowered_word not in lowered_vocab.word_count:
-                        lowered_vocab.add_word(lowered_word)
-                    lowered_vocab._no_create_word[lowered_word] += 1
             print(f"All word in the vocab have been lowered before finding pretrained vectors. There are {len(vocab)} "
                   f"words, {len(lowered_vocab)} unique lowered words.")
             if model_path:
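The rewritten loop builds the lowercased vocabulary in a single pass and carries each word's `no_create_entry` status over via `add_word(..., no_create_entry=True)`, instead of patching `_no_create_word` by hand in a second loop. A toy illustration of the preserved flag, mirroring the new loop (the words are made up):

```python
from fastNLP.core import Vocabulary

vocab = Vocabulary(padding='<pad>', unknown='<unk>')
vocab.add_word('Apple')                           # regular training word
vocab.add_word('Paris', no_create_entry=True)     # e.g. only seen in dev/test

lowered = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
for word, index in vocab:
    if vocab._is_word_no_create_entry(word):
        lowered.add_word(word.lower(), no_create_entry=True)
    else:
        lowered.add_word(word.lower())

print(lowered._is_word_no_create_entry('paris'))  # True: the flag survives lowering
print(lowered._is_word_no_create_entry('apple'))  # False
```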
@@ -90,7 +86,7 @@ class StaticEmbedding(TokenEmbedding):
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
             # the index mapping needs adapting
             if not hasattr(self, 'words_to_words'):
-                self.words_to_words = torch.arange(len(lowered_vocab, )).long()
+                self.words_to_words = torch.arange(len(lowered_vocab)).long()
             if lowered_vocab.unknown:
                 unknown_idx = lowered_vocab.unknown_idx
             else:
@@ -100,10 +96,11 @@ class StaticEmbedding(TokenEmbedding):
             for word, index in vocab:
                 if word not in lowered_vocab:
                     word = word.lower()
-                    if lowered_vocab._is_word_no_create_entry(word):  # if no entry is needed, it already maps to unknown
-                        continue
+                    if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
+                        continue  # if no entry is needed, it already maps to unknown
                 words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
             self.words_to_words = words_to_words
+            self._word_unk_index = lowered_vocab.unknown_idx  # update the unknown index
         else:
             if model_path:
                 embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
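With lowering enabled, the embedding matrix is built over the lowered vocabulary, and `words_to_words` redirects every index of the original (cased) vocabulary to the row of its lowercase form; the newly added line also records the lowered vocabulary's unknown index. The indirection itself looks like this (both vocabularies here are invented):

```python
import torch

# original vocab: 0 '<pad>', 1 '<unk>', 2 'Apple', 3 'apple', 4 'Berlin'
# lowered vocab:  0 '<pad>', 1 '<unk>', 2 'apple', 3 'berlin'
original = ['<pad>', '<unk>', 'Apple', 'apple', 'Berlin']
lowered_index = {'<pad>': 0, '<unk>': 1, 'apple': 2, 'berlin': 3}

words_to_words = torch.full((len(original),), fill_value=lowered_index['<unk>']).long()
for index, word in enumerate(original):
    words_to_words[index] = lowered_index.get(word.lower(), lowered_index['<unk>'])

batch = torch.tensor([[2, 4, 3]])    # ids for 'Apple', 'Berlin', 'apple'
print(words_to_words[batch])         # tensor([[2, 3, 2]]) -> rows of the lowered embedding
```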
@@ -211,12 +208,14 @@ class StaticEmbedding(TokenEmbedding):
             print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
             for word, index in vocab:
                 if index not in matrix and not vocab._is_word_no_create_entry(word):
-                    if vocab.unknown_idx in matrix:  # if there is an unknown entry, initialize with the unknown vector
+                    if vocab.padding_idx == index:
+                        matrix[index] = torch.zeros(dim)
+                    elif vocab.unknown_idx in matrix:  # if there is an unknown entry, initialize with the unknown vector
                         matrix[index] = matrix[vocab.unknown_idx]
                     else:
                         matrix[index] = None
 
-        vectors = self._randomly_init_embed(len(matrix), dim, init_method)
+        vectors = self._randomly_init_embed(len(vocab), dim, init_method)
 
         if vocab._no_create_word_length>0:
             if vocab.unknown is None:  # create a dedicated unknown entry
@@ -226,10 +225,13 @@ class StaticEmbedding(TokenEmbedding):
                 unknown_idx = vocab.unknown_idx
             words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
                                           requires_grad=False)
-            for order, (index, vec) in enumerate(matrix.items()):
+            for word, index in vocab:
+                vec = matrix.get(index, None)
                 if vec is not None:
-                    vectors[order] = vec
-                words_to_words[index] = order
+                    vectors[index] = vec
+                    words_to_words[index] = index
+                else:
+                    vectors[index] = vectors[unknown_idx]
             self.words_to_words = words_to_words
         else:
             for index, vec in matrix.items():
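The replaced loop used the enumeration order of `matrix.items()` as the row id, which only lines up with the vocabulary when `matrix` covers every index in order; together with sizing `vectors` by `len(vocab)` in the previous hunk, the new loop keeps each found vector at its own vocabulary index and points everything else at the unknown row. Condensed version with a made-up five-word vocabulary:

```python
import torch

dim, unknown_idx, vocab_size = 4, 1, 5
matrix = {0: torch.zeros(dim),             # padding row, zeroed by the change further up
          1: None,                         # unknown token without a pre-trained vector
          2: torch.ones(dim)}              # the only word found in the embedding file
vectors = torch.randn(vocab_size, dim)     # stands in for _randomly_init_embed(len(vocab), dim, ...)
words_to_words = torch.full((vocab_size,), fill_value=unknown_idx).long()

for index in range(vocab_size):            # stands in for `for word, index in vocab`
    vec = matrix.get(index, None)
    if vec is not None:
        vectors[index] = vec
        words_to_words[index] = index
    else:
        vectors[index] = vectors[unknown_idx]

print(words_to_words.tolist())             # [0, 1, 2, 1, 1] -- unmatched words share the unknown row
```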
@@ -144,7 +144,7 @@ class DataBundle:
         """
         self.datasets[name] = dataset
 
-    def get_dataset(self, name:str):
+    def get_dataset(self, name:str)->DataSet:
         """
         Get the dataset named ``name``.
@@ -153,7 +153,7 @@ class DataBundle:
         """
         return self.datasets[name]
 
-    def get_vocab(self, field_name:str):
+    def get_vocab(self, field_name:str)->Vocabulary:
         """
         Get the Vocabulary associated with the field named ``field_name``.
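These two `DataBundle` changes only add return annotations (`DataSet`, `Vocabulary`) so that editors and type checkers can resolve the returned objects; behaviour is unchanged. Usage sketch (the `'train'`/`'words'` names are placeholders):

```python
from fastNLP.io import DataBundle

def describe(data_bundle: DataBundle):
    train = data_bundle.get_dataset('train')   # now annotated as -> DataSet
    vocab = data_bundle.get_vocab('words')     # now annotated as -> Vocabulary
    print(len(train), 'instances,', len(vocab), 'tokens in the vocab')
```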
@@ -78,6 +78,17 @@ DATASET_DIR = {
     "rte": "RTE.zip"
 }
 
+PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
+                "bert": PRETRAINED_BERT_MODEL_DIR,
+                "static": PRETRAIN_STATIC_FILES}
+
+# used to extend fastNLP's download sources
+FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt'
+FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt',
+                                'bert':'fastnlp_bert_url.txt',
+                                'static': 'fastnlp_static_url.txt'
+                                }
+
 
 def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
     """
@@ -97,7 +108,7 @@ def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
     :return:
     """
     if cache_dir is None:
-        data_cache = Path(get_default_cache_path())
+        data_cache = Path(get_cache_path())
     else:
         data_cache = cache_dir
@@ -146,7 +157,7 @@ def get_filepath(filepath):
         raise FileNotFoundError(f"{filepath} is not a valid file or directory.")
 
 
-def get_default_cache_path():
+def get_cache_path():
     """
     Get the default fastNLP cache path. If FASTNLP_CACHE_PATH is set as an environment variable, its value is used, so that not every user has to download the files again.
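`get_default_cache_path` is renamed to `get_cache_path` here, and the remaining call sites elsewhere in this diff are updated to match. Per the docstring, it honours the `FASTNLP_CACHE_PATH` environment variable, so a shared download cache can be configured once per machine; a small sketch (the directory is a placeholder):

```python
import os
os.environ['FASTNLP_CACHE_PATH'] = '/data/fastnlp_cache'   # must be set before the lookup

from fastNLP.io.file_utils import get_cache_path
print(get_cache_path())                                    # -> '/data/fastnlp_cache'
```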
@@ -188,27 +199,51 @@ def _get_base_url(name):
     return URLS[name.lower()]
 
 
-def _get_embedding_url(type, name):
+def _get_embedding_url(embed_type, name):
     """
     Given the embedding type and name, return the download url.
 
-    :param str type: the embedding type; supports static, bert, elmo
+    :param str embed_type: the embedding type; supports static, bert, elmo
     :param str name: the name of the embedding, e.g. en, cn, based
     :return: str, the download url
     """
-    PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
-                    "bert": PRETRAINED_BERT_MODEL_DIR,
-                    "static":PRETRAIN_STATIC_FILES}
-    map = PRETRAIN_MAP.get(type, None)
+    # look for the download url in the user extension file first
+    _filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None)
+    if _filename:
+        url = _read_extend_url_file(_filename, name)
+        if url:
+            return url
+
+    map = PRETRAIN_MAP.get(embed_type, None)
     if map:
         filename = map.get(name, None)
         if filename:
             url = _get_base_url('embedding') + filename
             return url
         raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys())))
     else:
-        raise KeyError(f"There is no {type}. Only supports bert, elmo, static")
+        raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")
+
+
+def _read_extend_url_file(filename, name)->str:
+    """
+    Each line of filename is tab-separated: the first column is a name, the second is the download url.
+
+    :param str filename: the file to look for under the default cache path
+    :param str name: the name of the resource to look up
+    :return: str or None
+    """
+    cache_dir = get_cache_path()
+    filepath = os.path.join(cache_dir, filename)
+    if os.path.exists(filepath):
+        with open(filepath, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    parts = line.split('\t')
+                    if len(parts) == 2:
+                        if name == parts[0]:
+                            return parts[1]
+    return None
 
 
 def _get_dataset_url(name):
     """
@@ -217,6 +252,11 @@ def _get_dataset_url(name):
     :param str name: the name of the dataset, e.g. imdb, sst-2
     :return: str
     """
+    # look for the download url in the user extension file first
+    url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name)
+    if url:
+        return url
+
     filename = DATASET_DIR.get(name, None)
     if filename:
         url = _get_base_url('dataset') + filename
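Both `_get_embedding_url` and `_get_dataset_url` now consult the plain-text extension files named by `FASTNLP_EXTEND_EMBEDDING_URL` / `FASTNLP_EXTEND_DATASET_URL` before falling back to the built-in tables, so custom mirrors can be registered without touching the code. Per `_read_extend_url_file`, each line of such a file is a tab-separated name/url pair in the cache directory. Sketch (the resource name and URL are invented):

```python
import os
from fastNLP.io.file_utils import get_cache_path, _get_embedding_url

cache_dir = get_cache_path()
os.makedirs(cache_dir, exist_ok=True)
with open(os.path.join(cache_dir, 'fastnlp_static_url.txt'), 'w', encoding='utf-8') as f:
    f.write('my-glove\thttps://example.com/my-glove.zip\n')   # one "name<TAB>url" per line

print(_get_embedding_url('static', 'my-glove'))   # -> https://example.com/my-glove.zip
```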
@@ -3,7 +3,7 @@ from .. import DataBundle
 from ..utils import check_loader_paths
 from typing import Union, Dict
 import os
-from ..file_utils import _get_dataset_url, get_default_cache_path, cached_path
+from ..file_utils import _get_dataset_url, get_cache_path, cached_path
 
 
 class Loader:
     def __init__(self):
@@ -66,7 +66,7 @@ class Loader:
         :return: str, the directory of the dataset; the data can be read directly from that directory.
         """
-        default_cache_path = get_default_cache_path()
+        default_cache_path = get_cache_path()
         url = _get_dataset_url(dataset_name)
         output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')
@@ -24,7 +24,7 @@ class _NERPipe(Pipe):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         else:
-            self.convert_tag = iob2bioes
+            self.convert_tag = lambda words: iob2bioes(iob2(words))
         self.lower = lower
         self.target_pad_val = int(target_pad_val)
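The old `bioes` branch fed raw tags straight into `iob2bioes`, which assumes its input is already valid BIO2; CoNLL-style source files are typically IOB1, where an entity may open with an `I-` tag. The new lambda normalizes with `iob2` first. The toy stand-ins below (not the library implementations) show the difference:

```python
def iob2(tags):
    # IOB1 -> BIO2: an I- tag that opens an entity becomes B-
    out = []
    for i, tag in enumerate(tags):
        if tag.startswith('I-'):
            prev = out[i - 1] if i > 0 else 'O'
            if prev == 'O' or prev[2:] != tag[2:]:
                tag = 'B-' + tag[2:]
        out.append(tag)
    return out

def iob2bioes(tags):
    # BIO2 -> BIOES, assuming the input is well-formed BIO2
    out = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            out.append(tag)
            continue
        prefix, label = tag.split('-', 1)
        nxt = tags[i + 1] if i + 1 < len(tags) else 'O'
        continues = nxt.startswith('I-') and nxt[2:] == label
        if prefix == 'B':
            out.append(('B-' if continues else 'S-') + label)
        else:
            out.append(('I-' if continues else 'E-') + label)
    return out

raw = ['I-PER', 'I-PER', 'O', 'I-LOC']       # IOB1-style input
print(iob2bioes(raw))                        # ['I-PER', 'E-PER', 'O', 'E-LOC']  <- entities never opened
print(iob2bioes(iob2(raw)))                  # ['B-PER', 'E-PER', 'O', 'S-LOC']  <- what the fixed pipe yields
```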
@@ -57,7 +57,7 @@ class MatchingBertPipe(Pipe):
                 dataset[Const.INPUTS(0)].lower()
                 dataset[Const.INPUTS(1)].lower()
 
-        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUT(1)],
+        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
                                      [Const.INPUTS(0), Const.INPUTS(1)])
 
         # concatenate the two words fields
@@ -61,7 +61,7 @@ def get_tokenizer(tokenizer:str, lang='en'):
     if tokenizer == 'spacy':
         import spacy
         spacy.prefer_gpu()
-        if lang!='en':
+        if lang != 'en':
             raise RuntimeError("Spacy only supports en right right.")
         en = spacy.load(lang)
         tokenizer = lambda x: [w.text for w in en.tokenizer(x)]