|
@@ -22,6 +22,8 @@ __all__ = [ |
|
|
'SSTLoader', |
|
|
'SSTLoader', |
|
|
'PeopleDailyCorpusLoader', |
|
|
'PeopleDailyCorpusLoader', |
|
|
'Conll2003Loader', |
|
|
'Conll2003Loader', |
|
|
|
|
|
'VocabularyOption', |
|
|
|
|
|
'EmbeddingOption', |
|
|
] |
|
|
] |
|
|
|
|
|
|
|
|
from nltk.tree import Tree |
|
|
from nltk.tree import Tree |
|
@@ -32,8 +34,8 @@ from .file_reader import _read_csv, _read_json, _read_conll |
|
|
from typing import Union, Dict, Iterable |
|
|
from typing import Union, Dict, Iterable |
|
|
import os |
|
|
import os |
|
|
from ..core.utils import Example |
|
|
from ..core.utils import Example |
|
|
from ..core import Vocabulary |
|
|
|
|
|
from ..io import EmbedLoader |
|
|
|
|
|
|
|
|
from ..core.vocabulary import Vocabulary |
|
|
|
|
|
from ..io.embed_loader import EmbedLoader |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -448,14 +450,13 @@ class SSTLoader(DataSetLoader): |
|
|
info = DataInfo(datasets=self.load(paths)) |
|
|
info = DataInfo(datasets=self.load(paths)) |
|
|
_train_ds = [info.datasets[name] |
|
|
_train_ds = [info.datasets[name] |
|
|
for name in train_ds] if train_ds else info.datasets.values() |
|
|
for name in train_ds] if train_ds else info.datasets.values() |
|
|
|
|
|
|
|
|
src_vocab.from_dataset(_train_ds, field_name=input_name) |
|
|
|
|
|
tgt_vocab.from_dataset(_train_ds, field_name=target_name) |
|
|
|
|
|
|
|
|
src_vocab.from_dataset(*_train_ds, field_name=input_name) |
|
|
|
|
|
tgt_vocab.from_dataset(*_train_ds, field_name=target_name) |
|
|
src_vocab.index_dataset( |
|
|
src_vocab.index_dataset( |
|
|
info.datasets.values(), |
|
|
|
|
|
|
|
|
*info.datasets.values(), |
|
|
field_name=input_name, new_field_name=input_name) |
|
|
field_name=input_name, new_field_name=input_name) |
|
|
tgt_vocab.index_dataset( |
|
|
tgt_vocab.index_dataset( |
|
|
info.datasets.values(), |
|
|
|
|
|
|
|
|
*info.datasets.values(), |
|
|
field_name=target_name, new_field_name=target_name) |
|
|
field_name=target_name, new_field_name=target_name) |
|
|
info.vocabs = { |
|
|
info.vocabs = { |
|
|
input_name: src_vocab, |
|
|
input_name: src_vocab, |
|
|