Browse Source

- update dataloader

tags/v0.4.10
yunfan 5 years ago
parent
commit
e206cae45c
1 changed file with 8 additions and 7 deletions
  1. fastNLP/io/dataset_loader.py (+8, -7)

+ 8
- 7
fastNLP/io/dataset_loader.py View File

@@ -22,6 +22,8 @@ __all__ = [
'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
'VocabularyOption',
'EmbeddingOption',
]

from nltk.tree import Tree
@@ -32,8 +34,8 @@ from .file_reader import _read_csv, _read_json, _read_conll
from typing import Union, Dict, Iterable
import os
from ..core.utils import Example
from ..core import Vocabulary
from ..io import EmbedLoader
from ..core.vocabulary import Vocabulary
from ..io.embed_loader import EmbedLoader
import numpy as np


@@ -448,14 +450,13 @@ class SSTLoader(DataSetLoader):
info = DataInfo(datasets=self.load(paths))
_train_ds = [info.datasets[name]
for name in train_ds] if train_ds else info.datasets.values()

src_vocab.from_dataset(_train_ds, field_name=input_name)
tgt_vocab.from_dataset(_train_ds, field_name=target_name)
src_vocab.from_dataset(*_train_ds, field_name=input_name)
tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
src_vocab.index_dataset(
info.datasets.values(),
*info.datasets.values(),
field_name=input_name, new_field_name=input_name)
tgt_vocab.index_dataset(
info.datasets.values(),
*info.datasets.values(),
field_name=target_name, new_field_name=target_name)
info.vocabs = {
input_name: src_vocab,


Loading…
Cancel
Save