
1. git add fastNLP/io/loader/loader.py

tags/v0.4.10
yh 6 years ago
commit 015376d235
10 changed files with 83 additions and 33 deletions
  1. +1  -1   fastNLP/core/__init__.py
  2. +1  -1   fastNLP/core/vocabulary.py
  3. +8  -0   fastNLP/embeddings/bert_embedding.py
  4. +17 -15  fastNLP/embeddings/static_embedding.py
  5. +2  -2   fastNLP/io/base_loader.py
  6. +49 -9   fastNLP/io/file_utils.py
  7. +2  -2   fastNLP/io/loader/loader.py
  8. +1  -1   fastNLP/io/pipe/conll.py
  9. +1  -1   fastNLP/io/pipe/matching.py
  10. +1 -1   fastNLP/io/pipe/utils.py

fastNLP/core/__init__.py  +1 -1

@@ -24,5 +24,5 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
 from .tester import Tester
 from .trainer import Trainer
-from .utils import cache_results, seq_len_to_mask
+from .utils import cache_results, seq_len_to_mask, get_seq_len
 from .vocabulary import Vocabulary
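
The body of the newly exported get_seq_len is not part of this diff; a minimal sketch of what such a helper presumably does (counting non-pad tokens per row, with the pad value assumed to be 0) is:

    import torch

    def get_seq_len(words, pad_value=0):
        # words: [batch_size, max_len] LongTensor padded with pad_value (assumed default 0)
        # returns a [batch_size] LongTensor with the true length of each sequence
        return words.ne(pad_value).sum(dim=-1)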

fastNLP/core/vocabulary.py  +1 -1

@@ -376,7 +376,7 @@ class Vocabulary(object):
         :return: bool
         """
         return word in self._no_create_word

     def to_index(self, w):
         """
         Convert a word to its index. If the word is not recorded in the vocabulary it is treated as unknown; if ``unknown=None``, a ``ValueError`` is raised::
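
The docstring's example block is truncated in this view; a small usage sketch of the behaviour it describes, assuming the usual fastNLP Vocabulary API:

    from fastNLP import Vocabulary

    vocab = Vocabulary(unknown=None, padding=None)
    vocab.add_word_lst(["positive", "negative"])
    vocab.to_index("positive")   # returns the index assigned to "positive"
    vocab.to_index("neutral")    # raises ValueError because unknown=None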


fastNLP/embeddings/bert_embedding.py  +8 -0

@@ -68,6 +68,10 @@ class BertEmbedding(ContextualEmbedding):
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")

+        self._word_sep_index = None
+        if '[SEP]' in vocab:
+            self._word_sep_index = vocab['[SEP]']
+
         self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
                                     pool_method=pool_method, include_cls_sep=include_cls_sep,
                                     pooled_cls=pooled_cls, auto_truncate=auto_truncate)
@@ -86,7 +90,11 @@ class BertEmbedding(ContextualEmbedding):
         :param torch.LongTensor words: [batch_size, max_len]
         :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
         """
+        if self._word_sep_index:  # [SEP] must not be dropped
+            sep_mask = words.eq(self._word_sep_index)
         words = self.drop_word(words)
+        if self._word_sep_index:
+            words.masked_fill_(sep_mask, self._word_sep_index)
         outputs = self._get_sent_reprs(words)
         if outputs is not None:
             return self.dropout(words)
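
The new guards implement a mask-and-restore pattern so that word dropout can never remove [SEP]. The idea in isolation, with made-up token ids and dropout rate:

    import torch

    words = torch.tensor([[101, 1037, 4937, 102]])   # hypothetical ids; 102 stands in for [SEP]
    sep_index = 102

    sep_mask = words.eq(sep_index)                   # remember where [SEP] sits
    drop_mask = torch.rand(words.shape) < 0.1        # stand-in for drop_word's word dropout
    words = words.masked_fill(drop_mask, 0)
    words = words.masked_fill(sep_mask, sep_index)   # put [SEP] back so it is never dropped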


fastNLP/embeddings/static_embedding.py  +17 -15

@@ -74,14 +74,10 @@ class StaticEmbedding(TokenEmbedding):
         if lower:
             lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
             for word, index in vocab:
-                if not vocab._is_word_no_create_entry(word):
+                if vocab._is_word_no_create_entry(word):
+                    lowered_vocab.add_word(word.lower(), no_create_entry=True)
+                else:
                     lowered_vocab.add_word(word.lower())  # first add the words that need an entry
-            for word in vocab._no_create_word.keys():  # words that do not need an entry
-                if word in vocab:
-                    lowered_word = word.lower()
-                    if lowered_word not in lowered_vocab.word_count:
-                        lowered_vocab.add_word(lowered_word)
-                        lowered_vocab._no_create_word[lowered_word] += 1
             print(f"All word in the vocab have been lowered before finding pretrained vectors. There are {len(vocab)} "
                   f"words, {len(lowered_vocab)} unique lowered words.")
             if model_path:
@@ -90,7 +86,7 @@ class StaticEmbedding(TokenEmbedding):
                 embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
             # a small adaptation is needed here
             if not hasattr(self, 'words_to_words'):
-                self.words_to_words = torch.arange(len(lowered_vocab, )).long()
+                self.words_to_words = torch.arange(len(lowered_vocab)).long()
             if lowered_vocab.unknown:
                 unknown_idx = lowered_vocab.unknown_idx
             else:
@@ -100,10 +96,11 @@ class StaticEmbedding(TokenEmbedding):
             for word, index in vocab:
                 if word not in lowered_vocab:
                     word = word.lower()
-                    if lowered_vocab._is_word_no_create_entry(word):  # if no entry is needed, it already defaults to unknown
-                        continue
+                    if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
+                        continue  # if no entry is needed, it already defaults to unknown
                 words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
             self.words_to_words = words_to_words
+            self._word_unk_index = lowered_vocab.unknown_idx  # replace the unknown index here
         else:
             if model_path:
                 embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
@@ -211,12 +208,14 @@ class StaticEmbedding(TokenEmbedding):
         print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
         for word, index in vocab:
             if index not in matrix and not vocab._is_word_no_create_entry(word):
-                if vocab.unknown_idx in matrix:  # if there is an unknown entry, initialize from it
+                if vocab.padding_idx == index:
+                    matrix[index] = torch.zeros(dim)
+                elif vocab.unknown_idx in matrix:  # if there is an unknown entry, initialize from it
                     matrix[index] = matrix[vocab.unknown_idx]
                 else:
                     matrix[index] = None

-        vectors = self._randomly_init_embed(len(matrix), dim, init_method)
+        vectors = self._randomly_init_embed(len(vocab), dim, init_method)

         if vocab._no_create_word_length>0:
             if vocab.unknown is None:  # create a dedicated unknown entry
@@ -226,10 +225,13 @@ class StaticEmbedding(TokenEmbedding):
                 unknown_idx = vocab.unknown_idx
             words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
                                           requires_grad=False)
-            for order, (index, vec) in enumerate(matrix.items()):
+            for word, index in vocab:
+                vec = matrix.get(index, None)
                 if vec is not None:
-                    vectors[order] = vec
-                    words_to_words[index] = order
+                    vectors[index] = vec
+                    words_to_words[index] = index
+                else:
+                    vectors[index] = vectors[unknown_idx]
             self.words_to_words = words_to_words
         else:
             for index, vec in matrix.items():
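
After this change vectors holds one row per vocab entry and words_to_words redirects words without a pretrained vector to the unknown row. A simplified illustration of how a lookup goes through the remapping (not the class's actual code):

    import torch

    vectors = torch.randn(4, 50)                 # one row per vocab entry; row 3 plays the role of unknown
    words_to_words = torch.tensor([0, 3, 2, 3])  # words with no pretrained vector point at row 3
    batch = torch.tensor([[1, 2, 0]])            # a batch of original vocab indices
    embeds = vectors[words_to_words[batch]]      # shape [1, 3, 50]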


fastNLP/io/base_loader.py  +2 -2

@@ -144,7 +144,7 @@ class DataBundle:
         """
         self.datasets[name] = dataset

-    def get_dataset(self, name:str):
+    def get_dataset(self, name:str)->DataSet:
         """
         Get the dataset named name.
@@ -153,7 +153,7 @@ class DataBundle:
         """
         return self.datasets[name]

-    def get_vocab(self, field_name:str):
+    def get_vocab(self, field_name:str)->Vocabulary:
         """
         Get the vocab for the field named field_name.
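
With the added annotations the two getters read naturally in downstream code; a hypothetical usage, assuming a data_bundle already produced by a Loader or Pipe:

    train_ds = data_bundle.get_dataset('train')    # -> DataSet
    words_vocab = data_bundle.get_vocab('words')   # -> Vocabulary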



fastNLP/io/file_utils.py  +49 -9

@@ -78,6 +78,17 @@ DATASET_DIR = {
     "rte": "RTE.zip"
 }

+PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
+                "bert": PRETRAINED_BERT_MODEL_DIR,
+                "static": PRETRAIN_STATIC_FILES}
+
+# used to extend fastNLP's downloadable resources
+FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt'
+FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt',
+                                'bert':'fastnlp_bert_url.txt',
+                                'static': 'fastnlp_static_url.txt'
+                                }
+

 def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
     """
@@ -97,7 +108,7 @@ def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
     :return:
     """
     if cache_dir is None:
-        data_cache = Path(get_default_cache_path())
+        data_cache = Path(get_cache_path())
     else:
         data_cache = cache_dir

@@ -146,7 +157,7 @@ def get_filepath(filepath):
     raise FileNotFoundError(f"{filepath} is not a valid file or directory.")


-def get_default_cache_path():
+def get_cache_path():
     """
     Get fastNLP's default storage path. If FASTNLP_CACHE_PATH is set as an environment variable, its value is used, so that not every user has to download the files.

@@ -188,27 +199,51 @@ def _get_base_url(name):
     return URLS[name.lower()]


-def _get_embedding_url(type, name):
+def _get_embedding_url(embed_type, name):
     """
     Given the embedding type and name, return the download url.

-    :param str type: one of static, bert, elmo, i.e. the type of the embedding
+    :param str embed_type: one of static, bert, elmo, i.e. the type of the embedding
     :param str name: the name of the embedding, e.g. en, cn, based, etc.
     :return: str, the download url
     """
-    PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
-                    "bert": PRETRAINED_BERT_MODEL_DIR,
-                    "static":PRETRAIN_STATIC_FILES}
-    map = PRETRAIN_MAP.get(type, None)
+    # look for the download url in the extension file first
+    _filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None)
+    if _filename:
+        url = _read_extend_url_file(_filename, name)
+        if url:
+            return url
+    map = PRETRAIN_MAP.get(embed_type, None)
     if map:
         filename = map.get(name, None)
         if filename:
             url = _get_base_url('embedding') + filename
             return url
         raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys())))
     else:
-        raise KeyError(f"There is no {type}. Only supports bert, elmo, static")
+        raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")

+def _read_extend_url_file(filename, name)->str:
+    """
+    Each line of filename is tab-separated: the first column is a name, the second is the download url.
+
+    :param str filename: the file to look for under the default cache path
+    :param str name: the name of the resource to look up
+    :return: str or None
+    """
+    cache_dir = get_cache_path()
+    filepath = os.path.join(cache_dir, filename)
+    if os.path.exists(filepath):
+        with open(filepath, 'r', encoding='utf-8') as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    parts = line.split('\t')
+                    if len(parts) == 2:
+                        if name == parts[0]:
+                            return parts[1]
+    return None

 def _get_dataset_url(name):
     """
@@ -217,6 +252,11 @@ def _get_dataset_url(name):
     :param str name: the name of the dataset, e.g. imdb, sst-2, etc.
     :return: str
     """
+    # look for the download url in the extension file first
+    url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name)
+    if url:
+        return url
+
     filename = DATASET_DIR.get(name, None)
     if filename:
         url = _get_base_url('dataset') + filename
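
Taken together, these additions let a user register extra downloads without touching the code: drop a tab-separated file into the cache directory (assumed here to resolve to ~/.fastNLP) whose lines map a name to a url. A hypothetical entry:

    # ~/.fastNLP/fastnlp_dataset_url.txt -- columns separated by a tab
    my-corpus	http://example.com/my-corpus.zip

With such a file in place, _get_dataset_url('my-corpus') returns the registered url before falling back to the built-in DATASET_DIR table.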


fastNLP/io/loader/loader.py  +2 -2

@@ -3,7 +3,7 @@ from .. import DataBundle
 from ..utils import check_loader_paths
 from typing import Union, Dict
 import os
-from ..file_utils import _get_dataset_url, get_default_cache_path, cached_path
+from ..file_utils import _get_dataset_url, get_cache_path, cached_path

 class Loader:
     def __init__(self):
@@ -66,7 +66,7 @@ class Loader:
         :return: str, the directory of the dataset; the data can be read directly from that directory.
         """

-        default_cache_path = get_default_cache_path()
+        default_cache_path = get_cache_path()
         url = _get_dataset_url(dataset_name)
         output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')
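
A short summary of how a dataset path is now resolved by this method (the first step being the new extension mechanism):

    # 1. get_cache_path()        -> FASTNLP_CACHE_PATH from the environment if set, else the default cache dir
    # 2. _get_dataset_url(name)  -> url from fastnlp_dataset_url.txt when the name is registered there,
    #                               otherwise the built-in DATASET_DIR entry
    # 3. cached_path(url, ...)   -> fetches into the cache when missing and returns the local path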



fastNLP/io/pipe/conll.py  +1 -1

@@ -24,7 +24,7 @@ class _NERPipe(Pipe):
         if encoding_type == 'bio':
             self.convert_tag = iob2
         else:
-            self.convert_tag = iob2bioes
+            self.convert_tag = lambda words: iob2bioes(iob2(words))
         self.lower = lower
         self.target_pad_val = int(target_pad_val)
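
The fix makes the 'bioes' branch normalize tags to BIO first and only then convert to BIOES; with the usual semantics of these helpers the composition behaves like:

    tags = ["I-PER", "I-PER", "O", "I-LOC"]       # IOB1-style input
    # iob2(tags)            -> ["B-PER", "I-PER", "O", "B-LOC"]
    # iob2bioes(iob2(tags)) -> ["B-PER", "E-PER", "O", "S-LOC"]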



fastNLP/io/pipe/matching.py  +1 -1

@@ -57,7 +57,7 @@ class MatchingBertPipe(Pipe):
                 dataset[Const.INPUTS(0)].lower()
                 dataset[Const.INPUTS(1)].lower()

-        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUT(1)],
+        data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
                                      [Const.INPUTS(0), Const.INPUTS(1)])

         # concatenate the two words fields


fastNLP/io/pipe/utils.py  +1 -1

@@ -61,7 +61,7 @@ def get_tokenizer(tokenizer:str, lang='en'):
     if tokenizer == 'spacy':
         import spacy
         spacy.prefer_gpu()
-        if lang!='en':
+        if lang != 'en':
             raise RuntimeError("Spacy only supports en right right.")
         en = spacy.load(lang)
         tokenizer = lambda x: [w.text for w in en.tokenizer(x)]

