- improve code comments
- BaseLoader & its subclasses no longer need a data name
- update file tree
- add setup.py

Tag: v0.1.0
@@ -65,7 +65,7 @@ data_dir = 'save/' # directory to save data and model
 train_path = './data_for_tests/text_classify.txt' # training set file

 # load dataset
-ds_loader = ClassDatasetLoader("train", train_path)
+ds_loader = ClassDatasetLoader(train_path)
 data = ds_loader.load()

 # pre-process dataset
@@ -135,14 +135,15 @@ pip3 install torch torchvision
 ```
 FastNLP
 ├── docs
-│   └── quick_tutorial.md
 ├── fastNLP
-│   ├── action
+│   ├── core
 │   │   ├── action.py
-│   │   ├── inference.py
 │   │   ├── __init__.py
+│   │   ├── loss.py
 │   │   ├── metrics.py
 │   │   ├── optimizer.py
+│   │   ├── predictor.py
+│   │   ├── preprocess.py
 │   │   ├── README.md
 │   │   ├── tester.py
 │   │   └── trainer.py
@@ -154,71 +155,28 @@ FastNLP
 │   │   ├── dataset_loader.py
 │   │   ├── embed_loader.py
 │   │   ├── __init__.py
-│   │   ├── model_loader.py
-│   │   └── preprocess.py
+│   │   └── model_loader.py
 │   ├── models
-│   │   ├── base_model.py
-│   │   ├── char_language_model.py
-│   │   ├── cnn_text_classification.py
-│   │   ├── __init__.py
-│   │   └── sequence_modeling.py
 │   ├── modules
 │   │   ├── aggregation
-│   │   │   ├── attention.py
-│   │   │   ├── avg_pool.py
-│   │   │   ├── __init__.py
-│   │   │   ├── kmax_pool.py
-│   │   │   ├── max_pool.py
-│   │   │   └── self_attention.py
 │   │   ├── decoder
-│   │   │   ├── CRF.py
-│   │   │   └── __init__.py
 │   │   ├── encoder
-│   │   │   ├── char_embedding.py
-│   │   │   ├── conv_maxpool.py
-│   │   │   ├── conv.py
-│   │   │   ├── embedding.py
-│   │   │   ├── __init__.py
-│   │   │   ├── linear.py
-│   │   │   ├── lstm.py
-│   │   │   ├── masked_rnn.py
-│   │   │   └── variational_rnn.py
 │   │   ├── __init__.py
 │   │   ├── interaction
-│   │   │   └── __init__.py
 │   │   ├── other_modules.py
 │   │   └── utils.py
 │   └── saver
-│       ├── base_saver.py
-│       ├── __init__.py
-│       ├── logger.py
-│       └── model_saver.py
 ├── LICENSE
 ├── README.md
 ├── reproduction
-│   ├── Char-aware_NLM
-│   │
-│   ├── CNN-sentence_classification
-│   │
-│   ├── HAN-document_classification
-│   │
-│   └── LSTM+self_attention_sentiment_analysis
-│
 ├── requirements.txt
 ├── setup.py
 └── test
+    ├── core
     ├── data_for_tests
-    │   ├── charlm.txt
-    │   ├── config
-    │   ├── cws_test
-    │   ├── cws_train
-    │   ├── people_infer.txt
-    │   └── people.txt
-    ├── test_charlm.py
-    ├── test_cws.py
-    ├── test_fastNLP.py
-    ├── test_loader.py
-    ├── test_seq_labeling.py
-    ├── test_tester.py
-    └── test_trainer.py
+    ├── __init__.py
+    ├── loader
+    ├── modules
+    └── readme_example.py
 ```
@@ -9,7 +9,7 @@ class Loss(object):
     def __init__(self, args):
         if args is None:
-            # this is useful when
+            # this is useful when Trainer.__init__ performs type check
            self._loss = None
         elif isinstance(args, str):
            self._loss = self._borrow_from_pytorch(args)
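For context, the placeholder mentioned in the new comment is how the trainer's defaults survive the later type check. A minimal sketch of that pattern, assuming the import path `fastNLP.core.loss` from the file tree above:

```
from fastNLP.core.loss import Loss   # import path assumed from the package layout

# Trainer.__init__ keeps a placeholder Loss object so its isinstance-style type check passes
default_args = {"loss": Loss(None)}
assert isinstance(default_args["loss"], Loss)

# passing a string instead would route through Loss._borrow_from_pytorch;
# the accepted names live there and are not shown in this diff
```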
@@ -70,7 +70,7 @@ class Predictor(object):
     def predict(self, network, data):
         """Perform inference using the trained model.
-        :param network: a PyTorch model
+        :param network: a PyTorch model (cpu)
         :param data: list of list of strings
         :return: list of list of strings, [num_examples, tag_seq_length]
         """
@@ -38,7 +38,7 @@ class BaseTester(object):
         Obviously, "required_args" is the subset of "default_args".
         The value in "default_args" to the keys in "required_args" is simply for type check.
         """
-        # TODO: required arguments
+        # add required arguments here
         required_args = {}

         for req_key in required_args:
@@ -56,7 +56,7 @@ class BaseTester(object):
                 logger.error(msg)
                 raise ValueError(msg)
             else:
-                # BeseTester doesn't care about extra arguments
+                # BaseTester doesn't care about extra arguments
                 pass
         print(default_args)
@@ -69,8 +69,8 @@ class BaseTester(object):
         self.print_every_step = default_args["print_every_step"]

         self._model = None
-        self.eval_history = []
-        self.batch_output = []
+        self.eval_history = []  # evaluation results of all batches
+        self.batch_output = []  # outputs of all batches

     def test(self, network, dev_data):
         if torch.cuda.is_available() and self.use_cuda:
@@ -83,7 +83,7 @@ class BaseTester(object):
         self.eval_history.clear()
         self.batch_output.clear()
-        iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True))
+        iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=False))
         step = 0

         for batch_x, batch_y in self.make_batch(iterator):
@@ -99,7 +99,7 @@ class BaseTester(object):
             print_output = "[test step {}] {}".format(step, eval_results)
             logger.info(print_output)
             if self.print_every_step > 0 and step % self.print_every_step == 0:
-                print(print_output)
+                print(self.make_eval_output(prediction, eval_results))
             step += 1

     def mode(self, model, test):
@@ -115,16 +115,28 @@ class BaseTester(object):
         raise NotImplementedError

     def evaluate(self, predict, truth):
-        """Compute evaluation metrics for the model. """
+        """Compute evaluation metrics.
+        :param predict: Tensor
+        :param truth: Tensor
+        :return eval_results: can be anything. It will be stored in self.eval_history
+        """
         raise NotImplementedError

     @property
     def metrics(self):
-        """Return a list of metrics. """
+        """Compute and return metrics.
+        Use self.eval_history to compute metrics over the whole dev set.
+        Please refer to metrics.py for common metric functions.
+        :return : variable number of outputs
+        """
         raise NotImplementedError

     def show_metrics(self):
-        """This is called by Trainer to print evaluation results on dev set during training.
+        """Customize evaluation outputs in Trainer.
+        Called by Trainer to print evaluation results on dev set during training.
+        Use self.metrics to fetch available metrics.
         :return print_str: str
         """
@@ -133,6 +145,14 @@ class BaseTester(object):
     def make_batch(self, iterator):
         raise NotImplementedError

+    def make_eval_output(self, predictions, eval_results):
+        """Customize Tester outputs.
+        :param predictions: Tensor
+        :param eval_results: Tensor
+        :return: str, to be printed.
+        """
+        raise NotImplementedError
+

 class SeqLabelTester(BaseTester):
     """
@@ -211,7 +231,7 @@ class ClassificationTester(BaseTester):
     def __init__(self, **test_args):
         """
-        :param test_args: a dict-like object that has __getitem__ method, \
+        :param test_args: a dict-like object that has __getitem__ method.
            can be accessed by "test_args["key_str"]"
         """
         super(ClassificationTester, self).__init__(**test_args)
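Taken together, the expanded docstrings describe the subclassing contract for testers. A minimal sketch of a custom tester under that contract (class and metric names are illustrative, and `make_batch` is omitted):

```
from fastNLP.core.tester import BaseTester

class MyAccuracyTester(BaseTester):
    def evaluate(self, predict, truth):
        # per-batch result; whatever is returned here is stored in self.eval_history
        return float((predict.argmax(dim=-1) == truth).float().mean())

    @property
    def metrics(self):
        # aggregate over the whole dev set using self.eval_history
        return sum(self.eval_history) / len(self.eval_history)

    def show_metrics(self):
        # printed by the Trainer during validation
        return "dev accuracy={:.4f}".format(self.metrics)

    def make_eval_output(self, predictions, eval_results):
        # printed by BaseTester.test every print_every_step steps
        return "batch accuracy={:.4f}".format(eval_results)
```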
@@ -1,6 +1,4 @@
-import _pickle
 import copy
-import os
 import time
 from datetime import timedelta
@@ -15,16 +13,12 @@ from fastNLP.modules import utils
 from fastNLP.saver.logger import create_logger
 from fastNLP.saver.model_saver import ModelSaver

-DEFAULT_QUEUE_SIZE = 300
-
 logger = create_logger(__name__, "./train_test.log")


 class BaseTrainer(object):
-    """Operations to train a model, including data loading, SGD, and validation.
+    """Operations of training a model, including data loading, gradient descent, and validation.

-    Subclasses must implement the following abstract methods:
-    - grad_backward
-    - get_loss
     """

     def __init__(self, **kwargs):
@@ -47,7 +41,7 @@ class BaseTrainer(object):
         """
         default_args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
                         "save_best_dev": True, "model_name": "default_model_name.pkl", "print_every_step": 1,
-                        "loss": Loss(None),
+                        "loss": Loss(None),  # used to pass type check
                         "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0)
                         }
         """
@@ -56,7 +50,7 @@ class BaseTrainer(object):
         Obviously, "required_args" is the subset of "default_args".
         The value in "default_args" to the keys in "required_args" is simply for type check.
         """
-        # TODO: required arguments
+        # add required arguments here
         required_args = {}

         for req_key in required_args:
@@ -198,21 +192,6 @@ class BaseTrainer(object):
             network_copy = copy.deepcopy(network)
             self.train(network_copy, train_data_cv[i], dev_data_cv[i])

-    def load_train_data(self, pickle_path):
-        """
-        For task-specific processing.
-        :param pickle_path:
-        :return data_train
-        """
-        file_path = os.path.join(pickle_path, "data_train.pkl")
-        if os.path.exists(file_path):
-            with open(file_path, 'rb') as f:
-                data = _pickle.load(f)
-        else:
-            logger.error("cannot find training data {}. invalid input path for training data.".format(file_path))
-            raise RuntimeError("cannot find training data {}".format(file_path))
-        return data
-
     def make_batch(self, iterator):
         raise NotImplementedError
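The annotated `default_args` docstring above doubles as a constructor reference. A hedged sketch of building a trainer with exactly those keyword arguments (the import paths for `Loss` and `Optimizer` are assumed from the core/ layout in the file tree):

```
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import SeqLabelTrainer

trainer = SeqLabelTrainer(
    epochs=3, batch_size=8, validate=True, use_cuda=True,
    pickle_path="./save/", save_best_dev=True,
    model_name="default_model_name.pkl", print_every_step=1,
    loss=Loss(None),                                  # placeholder, only there to pass the type check
    optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
)
```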
@@ -1,9 +1,8 @@
 class BaseLoader(object):
     """docstring for BaseLoader"""

-    def __init__(self, data_name, data_path):
+    def __init__(self, data_path):
         super(BaseLoader, self).__init__()
-        self.data_name = data_name
         self.data_path = data_path

     def load(self):
@@ -25,8 +24,8 @@ class ToyLoader0(BaseLoader):
         For charLM
     """

-    def __init__(self, name, path):
-        super(ToyLoader0, self).__init__(name, path)
+    def __init__(self, data_path):
+        super(ToyLoader0, self).__init__(data_path)

     def load(self):
         with open(self.data_path, 'r') as f:
@@ -9,7 +9,7 @@ class ConfigLoader(BaseLoader):
     """loader for configuration files"""

     def __int__(self, data_name, data_path):
-        super(ConfigLoader, self).__init__(data_name, data_path)
+        super(ConfigLoader, self).__init__(data_path)
         self.config = self.parse(super(ConfigLoader, self).load())

     @staticmethod
@@ -100,7 +100,7 @@ class ConfigSection(object):

 if __name__ == "__main__":
-    config = ConfigLoader('configLoader', 'there is no data')
+    config = ConfigLoader('there is no data')
     section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
     """
@@ -6,8 +6,8 @@ from fastNLP.loader.base_loader import BaseLoader
 class DatasetLoader(BaseLoader):
     """"loader for data sets"""

-    def __init__(self, data_name, data_path):
-        super(DatasetLoader, self).__init__(data_name, data_path)
+    def __init__(self, data_path):
+        super(DatasetLoader, self).__init__(data_path)


 class POSDatasetLoader(DatasetLoader):
@@ -31,8 +31,8 @@ class POSDatasetLoader(DatasetLoader):
         to label5.
     """

-    def __init__(self, data_name, data_path):
-        super(POSDatasetLoader, self).__init__(data_name, data_path)
+    def __init__(self, data_path):
+        super(POSDatasetLoader, self).__init__(data_path)

     def load(self):
         assert os.path.exists(self.data_path)
@@ -84,8 +84,8 @@ class TokenizeDatasetLoader(DatasetLoader):
     Data set loader for tokenization data sets
     """

-    def __init__(self, data_name, data_path):
-        super(TokenizeDatasetLoader, self).__init__(data_name, data_path)
+    def __init__(self, data_path):
+        super(TokenizeDatasetLoader, self).__init__(data_path)

     def load_pku(self, max_seq_len=32):
         """
@@ -138,8 +138,8 @@ class TokenizeDatasetLoader(DatasetLoader):
 class ClassDatasetLoader(DatasetLoader):
     """Loader for classification data sets"""

-    def __init__(self, data_name, data_path):
-        super(ClassDatasetLoader, self).__init__(data_name, data_path)
+    def __init__(self, data_path):
+        super(ClassDatasetLoader, self).__init__(data_path)

     def load(self):
         assert os.path.exists(self.data_path)
@@ -177,7 +177,7 @@ class ConllLoader(DatasetLoader):
         :param str data_name: the name of the conll data set
         :param str data_path: the path to the conll data set
         """
-        super(ConllLoader, self).__init__(data_name, data_path)
+        super(ConllLoader, self).__init__(data_path)
         self.data_set = self.parse(self.load())

     def load(self):
@@ -209,8 +209,8 @@ class ConllLoader(DatasetLoader):
 class LMDatasetLoader(DatasetLoader):
-    def __init__(self, data_name, data_path):
-        super(LMDatasetLoader, self).__init__(data_name, data_path)
+    def __init__(self, data_path):
+        super(LMDatasetLoader, self).__init__(data_path)

     def load(self):
         if not os.path.exists(self.data_path):
@@ -226,7 +226,7 @@ class PeopleDailyCorpusLoader(DatasetLoader):
     """

     def __init__(self, data_path):
-        super(PeopleDailyCorpusLoader, self).__init__("people_daily_corpus", data_path)
+        super(PeopleDailyCorpusLoader, self).__init__(data_path)

     def load(self):
         with open(self.data_path, "r", encoding="utf-8") as f:
@@ -270,7 +270,7 @@ class PeopleDailyCorpusLoader(DatasetLoader):
         return pos_tag_examples, ner_examples

 if __name__ == "__main__":
-    loader = PeopleDailyCorpusLoader("/home/zyfeng/data/CWS_POS_TAG_NER_people_daily.txt")
+    loader = PeopleDailyCorpusLoader("./")
     pos, ner = loader.load()
     print(pos[:10])
     print(ner[:10])
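All dataset loaders now take only a path. A brief sketch of the updated calls, using paths that appear elsewhere in this changeset (the People's Daily path is illustrative):

```
from fastNLP.loader.dataset_loader import (ClassDatasetLoader,
                                            TokenizeDatasetLoader,
                                            PeopleDailyCorpusLoader)

# classification data: no data name argument any more
ds_loader = ClassDatasetLoader("./data_for_tests/text_classify.txt")
data = ds_loader.load()

# word segmentation data in PKU format
cws_loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
cws_data = cws_loader.load_pku(max_seq_len=32)

# People's Daily corpus: returns POS-tagging and NER examples
pd_loader = PeopleDailyCorpusLoader("./people_daily.txt")   # illustrative path
pos_examples, ner_examples = pd_loader.load()
```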
@@ -1,8 +1,50 @@
+import _pickle
+import os
+
+import numpy as np
+
 from fastNLP.loader.base_loader import BaseLoader


 class EmbedLoader(BaseLoader):
     """docstring for EmbedLoader"""

-    def __init__(self, data_name, data_path):
-        super(EmbedLoader, self).__init__(data_name, data_path)
+    def __init__(self, data_path):
+        super(EmbedLoader, self).__init__(data_path)
+
+    @staticmethod
+    def load_embedding(emb_dim, emb_file, word_dict, emb_pkl):
+        """Load the pre-trained embedding and combine with the given dictionary.
+
+        :param emb_file: str, the pre-trained embedding.
+            The embedding file should have the following format:
+                Each line is a word embedding, where a word string is followed by multiple floats.
+                Floats are separated by space. The word and the first float are separated by space.
+        :param word_dict: dict, a mapping from word to index.
+        :param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding.
+        :param emb_pkl: str, the embedding pickle file.
+        :return embedding_np: numpy array of shape (len(word_dict), emb_dim)
+
+        TODO: fragile code
+        """
+        # If the embedding pickle exists, load it and return.
+        if os.path.exists(emb_pkl):
+            with open(emb_pkl, "rb") as f:
+                embedding_np = _pickle.load(f)
+            return embedding_np
+        # Otherwise, load the pre-trained embedding.
+        with open(emb_file, "r", encoding="utf-8") as f:
+            # begin with a random embedding
+            embedding_np = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim))
+            for line in f:
+                line = line.strip().split()
+                if len(line) != emb_dim + 1:
+                    # skip this line if two embedding dimension not match
+                    continue
+                if line[0] in word_dict:
+                    # find the word and replace its embedding with a pre-trained one
+                    embedding_np[word_dict[line[0]]] = [float(i) for i in line[1:]]
+        # save and return the result
+        with open(emb_pkl, "wb") as f:
+            _pickle.dump(embedding_np, f)
+        return embedding_np
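A hedged usage sketch of the new `EmbedLoader.load_embedding` helper added above (the module path is assumed from the file tree, and the paths and toy `word_dict` are illustrative):

```
from fastNLP.loader.embed_loader import EmbedLoader

word_dict = {"the": 0, "of": 1, "<unk>": 2}
embedding_np = EmbedLoader.load_embedding(
    emb_dim=50,
    emb_file="./data_for_tests/emb50.txt",   # one word followed by 50 floats per line
    word_dict=word_dict,
    emb_pkl="./save/embedding.pkl",          # cached pickle, reused on the next call
)
print(embedding_np.shape)                     # (len(word_dict), emb_dim)
```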
@@ -8,8 +8,8 @@ class ModelLoader(BaseLoader):
     Loader for models.
     """

-    def __init__(self, data_name, data_path):
-        super(ModelLoader, self).__init__(data_name, data_path)
+    def __init__(self, data_path):
+        super(ModelLoader, self).__init__(data_path)

     @staticmethod
     def load_pytorch(empty_model, model_path):
@@ -27,7 +27,7 @@ data_infer_path = os.path.join(datadir, "infer.utf8")
 def infer():
     # Config Loader
     test_args = ConfigSection()
-    ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})
+    ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
@@ -47,7 +47,7 @@ def infer():
         raise

     # Data Loader
-    raw_data_loader = BaseLoader(data_name, data_infer_path)
+    raw_data_loader = BaseLoader(data_infer_path)
     infer_data = raw_data_loader.load_lines()
     print('data loaded')
@@ -63,10 +63,10 @@ def train():
     # Config Loader
     train_args = ConfigSection()
     test_args = ConfigSection()
-    ConfigLoader("good_name", "good_path").load_config(cfgfile, {"train": train_args, "test": test_args})
+    ConfigLoader("good_path").load_config(cfgfile, {"train": train_args, "test": test_args})

     # Data Loader
-    loader = TokenizeDatasetLoader(data_name, cws_data_path)
+    loader = TokenizeDatasetLoader(cws_data_path)
     train_data = loader.load_pku()

     # Preprocessor
@@ -100,7 +100,7 @@ def train():
 def test():
     # Config Loader
     test_args = ConfigSection()
-    ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})
+    ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
@@ -28,7 +28,7 @@ data_infer_path = os.path.join(datadir, "infer.utf8")
 def infer():
     # Config Loader
     test_args = ConfigSection()
-    ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})
+    ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
@@ -47,7 +47,7 @@ def infer():
         raise

     # Data Loader
-    raw_data_loader = BaseLoader(data_name, data_infer_path)
+    raw_data_loader = BaseLoader(data_infer_path)
     infer_data = raw_data_loader.load_lines()
     print('data loaded')
@@ -63,7 +63,7 @@ def train():
     # Config Loader
     train_args = ConfigSection()
     test_args = ConfigSection()
-    ConfigLoader("good_name", "good_path").load_config(cfgfile, {"train": train_args, "test": test_args})
+    ConfigLoader("good_name").load_config(cfgfile, {"train": train_args, "test": test_args})

     # Data Loader
     loader = PeopleDailyCorpusLoader(pos_tag_data_path)
@@ -100,7 +100,7 @@ def train():
 def test():
     # Config Loader
     test_args = ConfigSection()
-    ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})
+    ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
@@ -0,0 +1,24 @@
+#!/usr/bin/env python
+# coding=utf-8
+from setuptools import setup, find_packages
+
+with open('README.md') as f:
+    readme = f.read()
+
+with open('LICENSE') as f:
+    license = f.read()
+
+with open('requirements.txt') as f:
+    reqs = f.read()
+
+setup(
+    name='fastNLP',
+    version='1.0',
+    description=('fudan fastNLP '),
+    long_description=readme,
+    license=license,
+    author='fudanNLP',
+    python_requires='>=3.5',
+    packages=find_packages(),
+    install_requires=reqs.strip().split('\n'),
+)
@@ -1,13 +1,12 @@
-import os
 import configparser
 import json
+import os
 import unittest

 from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
 from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, POSDatasetLoader, LMDatasetLoader


 class TestConfigLoader(unittest.TestCase):
     def test_case_ConfigLoader(self):
@@ -33,8 +32,8 @@ class TestConfigLoader(unittest.TestCase):
             return dict

         test_arg = ConfigSection()
-        ConfigLoader("config", "").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
-        #ConfigLoader("config", "").load_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config",
+        ConfigLoader("config").load_config(os.path.join("./test/loader", "config"), {"test": test_arg})
+        # ConfigLoader("config").load_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config",
         #                                        {"test": test_arg})

         #dict = read_section_from_config("/home/ygxu/github/fastNLP_testing/fastNLP/test/loader/config", "test")
@@ -58,18 +57,18 @@ class TestConfigLoader(unittest.TestCase):
 class TestDatasetLoader(unittest.TestCase):
     def test_case_TokenizeDatasetLoader(self):
-        loader = TokenizeDatasetLoader("cws_pku_utf_8", "./test/data_for_tests/cws_pku_utf_8")
+        loader = TokenizeDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
         data = loader.load_pku(max_seq_len=32)
         print("pass TokenizeDatasetLoader test!")

     def test_case_POSDatasetLoader(self):
-        loader = POSDatasetLoader("people", "./test/data_for_tests/people.txt")
+        loader = POSDatasetLoader("./test/data_for_tests/people.txt")
         data = loader.load()
         datas = loader.load_lines()
         print("pass POSDatasetLoader test!")

     def test_case_LMDatasetLoader(self):
-        loader = LMDatasetLoader("cws_pku_utf_8", "./test/data_for_tests/cws_pku_utf_8")
+        loader = LMDatasetLoader("./test/data_for_tests/cws_pku_utf_8")
         data = loader.load()
         datas = loader.load_lines()
         print("pass TokenizeDatasetLoader test!")
@@ -1,138 +0,0 @@
-import _pickle
-import os
-
-import numpy as np
-import torch
-
-from fastNLP.core.preprocess import SeqLabelPreprocess
-from fastNLP.core.tester import SeqLabelTester
-from fastNLP.core.trainer import SeqLabelTrainer
-from fastNLP.models.sequence_modeling import AdvSeqLabel
-
-
-class MyNERTrainer(SeqLabelTrainer):
-    def __init__(self, train_args):
-        super(MyNERTrainer, self).__init__(train_args)
-        self.scheduler = None
-
-    def define_optimizer(self):
-        """
-        override
-        :return:
-        """
-        self.optimizer = torch.optim.Adam(self._model.parameters(), lr=0.001)
-        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=3000, gamma=0.5)
-
-    def update(self):
-        """
-        override
-        :return:
-        """
-        self.optimizer.step()
-        self.scheduler.step()
-
-    def _create_validator(self, valid_args):
-        return MyNERTester(valid_args)
-
-    def best_eval_result(self, validator):
-        accuracy = validator.metrics()
-        if accuracy > self.best_accuracy:
-            self.best_accuracy = accuracy
-            return True
-        else:
-            return False
-
-
-class MyNERTester(SeqLabelTester):
-    def __init__(self, test_args):
-        super(MyNERTester, self).__init__(test_args)
-
-    def _evaluate(self, prediction, batch_y, seq_len):
-        """
-        :param prediction: [batch_size, seq_len, num_classes]
-        :param batch_y: [batch_size, seq_len]
-        :param seq_len: [batch_size]
-        :return:
-        """
-        summ = 0
-        correct = 0
-        _, indices = torch.max(prediction, 2)
-        for p, y, l in zip(indices, batch_y, seq_len):
-            summ += l
-            correct += np.sum(p[:l].cpu().numpy() == y[:l].cpu().numpy())
-        return float(correct / summ)
-
-    def evaluate(self, predict, truth):
-        return self._evaluate(predict, truth, self.seq_len)
-
-    def metrics(self):
-        return np.mean(self.eval_history)
-
-    def show_metrics(self):
-        return "dev accuracy={:.2f}".format(float(self.metrics()))
-
-
-def embedding_process(emb_file, word_dict, emb_dim, emb_pkl):
-    if os.path.exists(emb_pkl):
-        with open(emb_pkl, "rb") as f:
-            embedding_np = _pickle.load(f)
-        return embedding_np
-    with open(emb_file, "r", encoding="utf-8") as f:
-        embedding_np = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim))
-        for line in f:
-            line = line.strip().split()
-            if len(line) != emb_dim + 1:
-                continue
-            if line[0] in word_dict:
-                embedding_np[word_dict[line[0]]] = [float(i) for i in line[1:]]
-    with open(emb_pkl, "wb") as f:
-        _pickle.dump(embedding_np, f)
-    return embedding_np
-
-
-def data_load(data_file):
-    with open(data_file, "r", encoding="utf-8") as f:
-        all_data = []
-        sent = []
-        label = []
-        for line in f:
-            line = line.strip().split()
-            if not len(line) <= 1:
-                sent.append(line[0])
-                label.append(line[1])
-            else:
-                all_data.append([sent, label])
-                sent = []
-                label = []
-    return all_data
-
-
-data_path = "data_for_tests/people.txt"
-pick_path = "data_for_tests/"
-emb_path = "data_for_tests/emb50.txt"
-save_path = "data_for_tests/"
-
-if __name__ == "__main__":
-    data = data_load(data_path)
-    preprocess = SeqLabelPreprocess()
-    data_train, data_dev = preprocess.run(data, pickle_path=pick_path, train_dev_split=0.3)
-    # emb = embedding_process(emb_path, p.word2index, 50, os.path.join(pick_path, "embedding.pkl"))
-    emb = None
-    args = {"epochs": 20,
-            "batch_size": 1,
-            "pickle_path": pick_path,
-            "validate": True,
-            "save_best_dev": True,
-            "model_saved_path": save_path,
-            "use_cuda": True,
-            "vocab_size": preprocess.vocab_size,
-            "num_classes": preprocess.num_classes,
-            "word_emb_dim": 50,
-            "rnn_hidden_units": 100
-            }
-    # emb = torch.Tensor(emb).float().cuda()
-    networks = AdvSeqLabel(args, emb)
-    trainer = MyNERTrainer(args)
-    trainer.train(networks, data_train, data_dev)
-    print("Training finished!")
@@ -1,129 +0,0 @@
-import _pickle
-import os
-
-import torch
-
-from fastNLP.core.predictor import SeqLabelInfer
-from fastNLP.core.trainer import SeqLabelTrainer
-from fastNLP.loader.model_loader import ModelLoader
-from fastNLP.models.sequence_modeling import AdvSeqLabel
-
-
-class Decode(SeqLabelTrainer):
-    def __init__(self, args):
-        super(Decode, self).__init__(args)
-
-    def decoder(self, network, sents, model_path):
-        self.model = network
-        self.model.load_state_dict(torch.load(model_path))
-        out_put = []
-        self.mode(network, test=True)
-        for batch_x in sents:
-            prediction = self.data_forward(self.model, batch_x)
-            seq_tag = self.model.prediction(prediction, batch_x[1])
-            out_put.append(list(seq_tag)[0])
-        return out_put
-
-
-def process_sent(sents, word2id):
-    sents_num = []
-    for s in sents:
-        sent_num = []
-        for c in s:
-            if c in word2id:
-                sent_num.append(word2id[c])
-            else:
-                sent_num.append(word2id["<unk>"])
-        sents_num.append(([sent_num], [len(sent_num)]))  # batch_size is 1
-    return sents_num
-
-
-def process_tag(sents, tags, id2class):
-    Tags = []
-    for ttt in tags:
-        Tags.append([id2class[t] for t in ttt])
-
-    Segs = []
-    PosNers = []
-    for sent, tag in zip(sents, tags):
-        word__ = []
-        lll__ = []
-        for c, t in zip(sent, tag):
-            t = id2class[t]
-            l = t.split("-")
-            split_ = l[0]
-            pn = l[1]
-            if split_ == "S":
-                word__.append(c)
-                lll__.append(pn)
-                word_1 = ""
-            elif split_ == "E":
-                word_1 += c
-                word__.append(word_1)
-                lll__.append(pn)
-                word_1 = ""
-            elif split_ == "B":
-                word_1 = ""
-                word_1 += c
-            else:
-                word_1 += c
-        Segs.append(word__)
-        PosNers.append(lll__)
-    return Segs, PosNers
-
-
-pickle_path = "data_for_tests/"
-model_path = "data_for_tests/model_best_dev.pkl"
-
-if __name__ == "__main__":
-    with open(os.path.join(pickle_path, "id2word.pkl"), "rb") as f:
-        id2word = _pickle.load(f)
-    with open(os.path.join(pickle_path, "word2id.pkl"), "rb") as f:
-        word2id = _pickle.load(f)
-    with open(os.path.join(pickle_path, "id2class.pkl"), "rb") as f:
-        id2class = _pickle.load(f)
-
-    sent = ["中共中央总书记、国家主席江泽民",
-            "逆向处理输入序列并返回逆序后的序列"]  # here is input
-
-    args = {"epochs": 1,
-            "batch_size": 1,
-            "pickle_path": "data_for_tests/",
-            "validate": True,
-            "save_best_dev": True,
-            "model_saved_path": "data_for_tests/",
-            "use_cuda": False,
-            "vocab_size": len(word2id),
-            "num_classes": len(id2class),
-            "word_emb_dim": 50,
-            "rnn_hidden_units": 100,
-            }
-    """
-    network = AdvSeqLabel(args, None)
-    decoder_ = Decode(args)
-    tags_num = decoder_.decoder(network, process_sent(sent, word2id), model_path=model_path)
-    output_seg, output_pn = process_tag(sent, tags_num, id2class)  # here is output
-    print(output_seg)
-    print(output_pn)
-    """
-
-    # Define the same model
-    model = AdvSeqLabel(args, None)
-
-    # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./data_for_tests/model_best_dev.pkl")
-    print("model loaded!")
-
-    # Inference interface
-    infer = SeqLabelInfer(pickle_path)
-    sent = [[ch for ch in s] for s in sent]
-    results = infer.predict(model, sent)
-
-    for res in results:
-        print(res)
-    print("Inference finished!")
@@ -36,7 +36,7 @@ data_dir = 'save/' # directory to save data and model
 train_path = './data_for_tests/text_classify.txt' # training set file

 # load dataset
-ds_loader = ClassDatasetLoader("train", train_path)
+ds_loader = ClassDatasetLoader(train_path)
 data = ds_loader.load()

 # pre-process dataset
@@ -33,7 +33,7 @@ data_infer_path = args.infer
 def infer():
     # Load infer configuration, the same as test
     test_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config(config_dir, {"POS_infer": test_args})
+    ConfigLoader("config.cfg").load_config(config_dir, {"POS_infer": test_args})

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
@@ -49,7 +49,7 @@ def infer():
     print("model loaded!")

     # Data Loader
-    raw_data_loader = BaseLoader("xxx", data_infer_path)
+    raw_data_loader = BaseLoader(data_infer_path)
     infer_data = raw_data_loader.load_lines()

     # Inference interface
@@ -65,11 +65,11 @@ def train_and_test():
     # Config Loader
     trainer_args = ConfigSection()
     model_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config(config_dir, {
+    ConfigLoader("config.cfg").load_config(config_dir, {
         "test_seq_label_trainer": trainer_args, "test_seq_label_model": model_args})

     # Data Loader
-    pos_loader = POSDatasetLoader("xxx", data_path)
+    pos_loader = POSDatasetLoader(data_path)
     train_data = pos_loader.load_lines()

     # Preprocessor
@@ -117,7 +117,7 @@ def train_and_test():
     # Load test configuration
     tester_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config(config_dir, {"test_seq_label_tester": tester_args})
+    ConfigLoader("config.cfg").load_config(config_dir, {"test_seq_label_tester": tester_args})

     # Tester
     tester = SeqLabelTester(save_output=False,
@@ -139,5 +139,5 @@ def train_and_test():

 if __name__ == "__main__":
-    train_and_test()
-    # infer()
+    # train_and_test()
+    infer()
@@ -22,7 +22,7 @@ data_infer_path = "data_for_tests/people_infer.txt"
 def infer():
     # Load infer configuration, the same as test
     test_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
+    ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})

     # fetch dictionary size and number of labels from pickle files
     word2index = load_pickle(pickle_path, "word2id.pkl")
@@ -38,7 +38,7 @@ def infer():
     print("model loaded!")

     # Data Loader
-    raw_data_loader = BaseLoader(data_name, data_infer_path)
+    raw_data_loader = BaseLoader(data_infer_path)
     infer_data = raw_data_loader.load_lines()
     """
     Transform strings into list of list of strings.
@@ -61,10 +61,10 @@ def infer():
 def train_test():
     # Config Loader
     train_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
+    ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})

     # Data Loader
-    loader = TokenizeDatasetLoader(data_name, cws_data_path)
+    loader = TokenizeDatasetLoader(cws_data_path)
     train_data = loader.load_pku()

     # Preprocessor
@@ -74,7 +74,7 @@ def train_test():
     train_args["num_classes"] = p.num_classes

     # Trainer
-    trainer = SeqLabelTrainer(train_args)
+    trainer = SeqLabelTrainer(**train_args.data)

     # Model
     model = SeqLabeling(train_args)
@@ -99,10 +99,10 @@ def train_test():
     # Load test configuration
     test_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
+    ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})

     # Tester
-    tester = SeqLabelTester(test_args)
+    tester = SeqLabelTester(**test_args.data)

     # Start testing
     tester.test(model, data_train)
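The switch from `SeqLabelTrainer(train_args)` to `SeqLabelTrainer(**train_args.data)` reflects the new keyword-argument constructors. A short sketch of the pattern (the `.data` attribute on `ConfigSection` is assumed from the calls above):

```
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.core.tester import SeqLabelTester

train_args = ConfigSection()
test_args = ConfigSection()
ConfigLoader("config.cfg").load_config("./data_for_tests/config",
                                       {"POS": train_args, "POS_test": test_args})

# the parsed key/value pairs are unpacked as keyword arguments
trainer = SeqLabelTrainer(**train_args.data)
tester = SeqLabelTester(**test_args.data)
```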
@@ -9,15 +9,15 @@ pickle_path = "data_for_tests"

 def foo():
-    loader = TokenizeDatasetLoader(data_name, "./data_for_tests/cws_pku_utf_8")
+    loader = TokenizeDatasetLoader("./data_for_tests/cws_pku_utf_8")
     train_data = loader.load_pku()

     train_args = ConfigSection()
-    ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args})
+    ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS": train_args})

     # Preprocessor
     p = SeqLabelPreprocess()
-    p.run(train_data)
+    train_data = p.run(train_data)
     train_args["vocab_size"] = p.vocab_size
     train_args["num_classes"] = p.num_classes
@@ -26,10 +26,10 @@ def foo():
     valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                   "save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/",
                   "use_cuda": True}
-    validator = SeqLabelTester(valid_args)
+    validator = SeqLabelTester(**valid_args)

     print("start validation.")
-    validator.test(model)
+    validator.test(model, train_data)
     print(validator.show_metrics())
@@ -34,7 +34,7 @@ config_dir = args.config
 def infer():
     # load dataset
     print("Loading data...")
-    ds_loader = ClassDatasetLoader("train", train_data_dir)
+    ds_loader = ClassDatasetLoader(train_data_dir)
     data = ds_loader.load()
     unlabeled_data = [x[0] for x in data]
@@ -69,7 +69,7 @@ def train():
     # load dataset
     print("Loading data...")
-    ds_loader = ClassDatasetLoader("train", train_data_dir)
+    ds_loader = ClassDatasetLoader(train_data_dir)
     data = ds_loader.load()
     print(data[0])