Browse Source

- update dataloader

tags/v0.4.10
yunfan 5 years ago
parent
commit
e206cae45c
1 changed file with 8 additions and 7 deletions
  1. +8
    -7
      fastNLP/io/dataset_loader.py

+ 8
- 7
fastNLP/io/dataset_loader.py View File

@@ -22,6 +22,8 @@ __all__ = [
'SSTLoader', 'SSTLoader',
'PeopleDailyCorpusLoader', 'PeopleDailyCorpusLoader',
'Conll2003Loader', 'Conll2003Loader',
'VocabularyOption',
'EmbeddingOption',
] ]


from nltk.tree import Tree from nltk.tree import Tree
@@ -32,8 +34,8 @@ from .file_reader import _read_csv, _read_json, _read_conll
from typing import Union, Dict, Iterable from typing import Union, Dict, Iterable
import os import os
from ..core.utils import Example from ..core.utils import Example
from ..core import Vocabulary
from ..io import EmbedLoader
from ..core.vocabulary import Vocabulary
from ..io.embed_loader import EmbedLoader
import numpy as np import numpy as np




@@ -448,14 +450,13 @@ class SSTLoader(DataSetLoader):
info = DataInfo(datasets=self.load(paths)) info = DataInfo(datasets=self.load(paths))
_train_ds = [info.datasets[name] _train_ds = [info.datasets[name]
for name in train_ds] if train_ds else info.datasets.values() for name in train_ds] if train_ds else info.datasets.values()

src_vocab.from_dataset(_train_ds, field_name=input_name)
tgt_vocab.from_dataset(_train_ds, field_name=target_name)
src_vocab.from_dataset(*_train_ds, field_name=input_name)
tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
src_vocab.index_dataset( src_vocab.index_dataset(
info.datasets.values(),
*info.datasets.values(),
field_name=input_name, new_field_name=input_name) field_name=input_name, new_field_name=input_name)
tgt_vocab.index_dataset( tgt_vocab.index_dataset(
info.datasets.values(),
*info.datasets.values(),
field_name=target_name, new_field_name=target_name) field_name=target_name, new_field_name=target_name)
info.vocabs = { info.vocabs = {
input_name: src_vocab, input_name: src_vocab,


Loading…
Cancel
Save