Browse Source

1.修复MNLILoader中的bug; 2.修复field中的tensor warning

tags/v0.4.10
yh_cc 5 years ago
parent
commit
1756e3ffdf
3 changed files with 6 additions and 6 deletions
  1. +3
    -3
      fastNLP/core/field.py
  2. +2
    -2
      fastNLP/core/vocabulary.py
  3. +1
    -1
      fastNLP/io/loader/matching.py

+ 3
- 3
fastNLP/core/field.py View File

@@ -595,7 +595,7 @@ class AutoPadder(Padder):
max_len = max(map(len, contents))
tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
tensor[i, :len(content_i)] = torch.tensor(content_i)
tensor[i, :len(content_i)] = content_i.clone().detach()
elif dim == 2:
max_len = max(map(len, contents))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
@@ -604,7 +604,7 @@ class AutoPadder(Padder):
dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
for j, content_ii in enumerate(content_i):
tensor[i, j, :len(content_ii)] = torch.tensor(content_ii)
tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
else:
shapes = set([np.shape(content_i) for content_i in contents])
if len(shapes) > 1:
@@ -615,7 +615,7 @@ class AutoPadder(Padder):
tensor = torch.full([len(contents)] + list(shape), fill_value=self.pad_val,
dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype)
tensor[i] = content_i.clone().detach().to(field_ele_dtype)
else:
raise RuntimeError(
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")


+ 2
- 2
fastNLP/core/vocabulary.py View File

@@ -253,7 +253,7 @@ class Vocabulary(object):
if self.unknown is not None:
return self.word2idx[self.unknown]
else:
raise ValueError("word {} not in vocabulary".format(w))
raise ValueError("word `{}` not in vocabulary".format(w))
@_check_build_vocab
def index_dataset(self, *datasets, field_name, new_field_name=None):
@@ -360,7 +360,7 @@ class Vocabulary(object):
try:
dataset.apply(construct_vocab)
except BaseException as e:
log("When processing the `{}` dataset, the following error occurred:".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e
else:
raise TypeError("Only DataSet type is allowed.")


+ 1
- 1
fastNLP/io/loader/matching.py View File

@@ -41,7 +41,7 @@ class MNLILoader(Loader):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
warnings.warn("RTE's test file has no target.")
for line in f:
line = line.strip()


Loading…
Cancel
Save