- [tester][trainer] add cuda support - [preprocess] fix label2index for padding label seq - update README.md - [test] add test_tester.py - rename "action" to "core"tags/v0.1.0
@@ -1,58 +1,92 @@ | |||
# FastNLP | |||
``` | |||
FastNLP | |||
│ LICENSE | |||
│ README.md | |||
│ requirements.txt | |||
│ setup.py | |||
├── docs | |||
│ └── quick_tutorial.md | |||
├── fastNLP | |||
│ ├── action | |||
│ │ ├── action.py | |||
│ │ ├── inference.py | |||
│ │ ├── __init__.py | |||
│ │ ├── metrics.py | |||
│ │ ├── optimizer.py | |||
│ │ ├── README.md | |||
│ │ ├── tester.py | |||
│ │ └── trainer.py | |||
│ ├── fastnlp.py | |||
│ ├── __init__.py | |||
│ ├── loader | |||
│ │ ├── base_loader.py | |||
│ │ ├── config_loader.py | |||
│ │ ├── dataset_loader.py | |||
│ │ ├── embed_loader.py | |||
│ │ ├── __init__.py | |||
│ │ ├── model_loader.py | |||
│ │ └── preprocess.py | |||
│ ├── models | |||
│ │ ├── base_model.py | |||
│ │ ├── char_language_model.py | |||
│ │ ├── cnn_text_classification.py | |||
│ │ ├── __init__.py | |||
│ │ └── sequence_modeling.py | |||
│ ├── modules | |||
│ │ ├── aggregation | |||
│ │ │ ├── attention.py | |||
│ │ │ ├── avg_pool.py | |||
│ │ │ ├── __init__.py | |||
│ │ │ ├── kmax_pool.py | |||
│ │ │ ├── max_pool.py | |||
│ │ │ └── self_attention.py | |||
│ │ ├── decoder | |||
│ │ │ ├── CRF.py | |||
│ │ │ └── __init__.py | |||
│ │ ├── encoder | |||
│ │ │ ├── char_embedding.py | |||
│ │ │ ├── conv_maxpool.py | |||
│ │ │ ├── conv.py | |||
│ │ │ ├── embedding.py | |||
│ │ │ ├── __init__.py | |||
│ │ │ ├── linear.py | |||
│ │ │ ├── lstm.py | |||
│ │ │ ├── masked_rnn.py | |||
│ │ │ └── variational_rnn.py | |||
│ │ ├── __init__.py | |||
│ │ ├── interaction | |||
│ │ │ └── __init__.py | |||
│ │ ├── other_modules.py | |||
│ │ └── utils.py | |||
│ └── saver | |||
│ ├── base_saver.py | |||
│ ├── __init__.py | |||
│ ├── logger.py | |||
│ └── model_saver.py | |||
├── LICENSE | |||
├── README.md | |||
├── reproduction | |||
│ ├── Char-aware_NLM | |||
│ │ | |||
│ ├── CNN-sentence_classification | |||
│ │ | |||
│ ├── HAN-document_classification | |||
│ │ | |||
│ └── LSTM+self_attention_sentiment_analysis | |||
| | |||
├─docs (documentation) | |||
| | |||
└─tests (unit tests, intergrating tests, system tests) | |||
| │ test_charlm.py | |||
| │ test_loader.py | |||
| │ test_trainer.py | |||
| │ test_word_seg.py | |||
| │ | |||
| └─data_for_tests (test data used by models) | |||
| charlm.txt | |||
| cws_test | |||
| cws_train | |||
| | |||
└─fastNLP | |||
├─action (model independent process) | |||
│ │ action.py (base class) | |||
│ │ README.md | |||
│ │ tester.py (model testing, for deployment and validation) | |||
│ │ trainer.py (main logic for model training) | |||
│ │ __init__.py | |||
│ │ | |||
| | |||
│ | |||
├─loader (file loader for all loading operations) | |||
│ | base_loader.py (base class) | |||
│ | config_loader.py (model-specific configuration/parameter loader) | |||
│ | dataset_loader.py (data set loader, base class) | |||
│ | embed_loader.py (embedding loader, base class) | |||
│ | __init__.py | |||
│ | |||
├─model (definitions of PyTorch models) | |||
│ │ base_model.py (base class, abstract) | |||
│ │ char_language_model.py (derived class, to implement abstract methods) | |||
│ │ word_seg_model.py | |||
│ │ __init__.py | |||
│ │ | |||
│ | |||
├─reproduction (code library for paper reproduction) | |||
│ ├─Char-aware_NLM | |||
│ │ | |||
│ ├─CNN-sentence_classification | |||
│ │ | |||
│ └─HAN-document_classification | |||
│ | |||
├─saver (file saver for all saving operations) | |||
│ base_saver.py | |||
│ logger.py | |||
│ model_saver.py | |||
│ | |||
├── requirements.txt | |||
├── setup.py | |||
└── test | |||
├── data_for_tests | |||
│ ├── charlm.txt | |||
│ ├── config | |||
│ ├── cws_test | |||
│ ├── cws_train | |||
│ ├── people_infer.txt | |||
│ └── people.txt | |||
├── test_charlm.py | |||
├── test_cws.py | |||
├── test_fastNLP.py | |||
├── test_loader.py | |||
├── test_seq_labeling.py | |||
├── test_tester.py | |||
└── test_trainer.py | |||
``` |
@@ -1,6 +1,6 @@ | |||
import torch | |||
from fastNLP.action.action import Batchifier, SequentialSampler | |||
from fastNLP.core.action import Batchifier, SequentialSampler | |||
from fastNLP.loader.preprocess import load_pickle, DEFAULT_UNKNOWN_LABEL | |||
@@ -55,7 +55,7 @@ class Inference(object): | |||
def data_forward(self, network, x): | |||
""" | |||
This is only for sequence labeling with CRF decoder. To do: more general ? | |||
This is only for sequence labeling with CRF decoder. TODO: more general ? | |||
:param network: | |||
:param x: | |||
:return: |
@@ -4,8 +4,8 @@ import os | |||
import numpy as np | |||
import torch | |||
from fastNLP.action.action import Action | |||
from fastNLP.action.action import RandomSampler, Batchifier | |||
from fastNLP.core.action import Action | |||
from fastNLP.core.action import RandomSampler, Batchifier | |||
class BaseTester(Action): | |||
@@ -25,14 +25,17 @@ class BaseTester(Action): | |||
self.batch_size = test_args["batch_size"] | |||
self.pickle_path = test_args["pickle_path"] | |||
self.iterator = None | |||
self.use_cuda = test_args["use_cuda"] | |||
self.model = None | |||
self.eval_history = [] | |||
self.batch_output = [] | |||
def test(self, network): | |||
# print("--------------testing----------------") | |||
self.model = network | |||
if torch.cuda.is_available() and self.use_cuda: | |||
self.model = network.cuda() | |||
else: | |||
self.model = network | |||
# turn on the testing mode; clean up the history | |||
self.mode(network, test=True) | |||
@@ -44,7 +47,7 @@ class BaseTester(Action): | |||
num_iter = len(dev_data) // self.batch_size | |||
for step in range(num_iter): | |||
batch_x, batch_y = self.batchify(dev_data) | |||
batch_x, batch_y = self.make_batch(dev_data) | |||
prediction = self.data_forward(network, batch_x) | |||
eval_results = self.evaluate(prediction, batch_y) | |||
@@ -65,7 +68,7 @@ class BaseTester(Action): | |||
self.save_dev_data = data_dev | |||
return self.save_dev_data | |||
def batchify(self, data): | |||
def make_batch(self, data, output_length=True): | |||
""" | |||
1. Perform batching from data and produce a batch of training data. | |||
2. Add padding. | |||
@@ -83,8 +86,13 @@ class BaseTester(Action): | |||
batch = [data[idx] for idx in indices] | |||
batch_x = [sample[0] for sample in batch] | |||
batch_y = [sample[1] for sample in batch] | |||
batch_x = self.pad(batch_x) | |||
return batch_x, batch_y | |||
batch_x_pad = self.pad(batch_x) | |||
batch_y_pad = self.pad(batch_y) | |||
if output_length: | |||
seq_len = [len(x) for x in batch_x] | |||
return (batch_x_pad, seq_len), batch_y_pad | |||
else: | |||
return batch_x_pad, batch_y_pad | |||
@staticmethod | |||
def pad(batch, fill=0): | |||
@@ -97,7 +105,7 @@ class BaseTester(Action): | |||
max_length = max([len(x) for x in batch]) | |||
for idx, sample in enumerate(batch): | |||
if len(sample) < max_length: | |||
batch[idx] = sample + [fill * (max_length - len(sample))] | |||
batch[idx] = sample + ([fill] * (max_length - len(sample))) | |||
return batch | |||
def data_forward(self, network, data): | |||
@@ -111,7 +119,7 @@ class BaseTester(Action): | |||
raise NotImplementedError | |||
def mode(self, model, test=True): | |||
"""To do: combine this function with Trainer ?? """ | |||
"""TODO: combine this function with Trainer ?? """ | |||
if test: | |||
model.eval() | |||
else: | |||
@@ -140,26 +148,37 @@ class POSTester(BaseTester): | |||
self.mask = None | |||
self.batch_result = None | |||
def data_forward(self, network, x): | |||
"""To Do: combine with Trainer | |||
def data_forward(self, network, inputs): | |||
"""TODO: combine with Trainer | |||
:param network: the PyTorch model | |||
:param x: list of list, [batch_size, max_len] | |||
:return y: [batch_size, num_classes] | |||
""" | |||
self.seq_len = [len(seq) for seq in x] | |||
# unpack the returned value from make_batch | |||
if isinstance(inputs, tuple): | |||
x = inputs[0] | |||
self.seq_len = inputs[1] | |||
else: | |||
x = inputs | |||
x = torch.Tensor(x).long() | |||
if torch.cuda.is_available() and self.use_cuda: | |||
x = x.cuda() | |||
self.batch_size = x.size(0) | |||
self.max_len = x.size(1) | |||
# self.mask = seq_mask(seq_len, self.max_len) | |||
y = network(x) | |||
return y | |||
def evaluate(self, predict, truth): | |||
truth = torch.Tensor(truth) | |||
if torch.cuda.is_available() and self.use_cuda: | |||
truth = truth.cuda() | |||
loss = self.model.loss(predict, truth, self.seq_len) | |||
prediction = self.model.prediction(predict, self.seq_len) | |||
results = torch.Tensor(prediction).view(-1,) | |||
if torch.cuda.is_available() and self.use_cuda: | |||
results = results.cuda() | |||
accuracy = float(torch.sum(results == truth.view((-1,)))) / results.shape[0] | |||
return [loss.data, accuracy] | |||
@@ -256,7 +275,7 @@ class ClassTester(BaseTester): | |||
n_batches = len(data_test) // self.batch_size | |||
n_print = n_batches // 10 | |||
step = 0 | |||
for batch_x, batch_y in self.batchify(data_test, max_len=self.max_len): | |||
for batch_x, batch_y in self.make_batch(data_test, max_len=self.max_len): | |||
prediction = self.data_forward(network, batch_x) | |||
eval_results = self.evaluate(prediction, batch_y) | |||
@@ -277,7 +296,7 @@ class ClassTester(BaseTester): | |||
data = _pickle.load(f) | |||
return data | |||
def batchify(self, data, max_len=None): | |||
def make_batch(self, data, max_len=None): | |||
"""Batch and pad data.""" | |||
for indices in self.iterator: | |||
# generate batch and pad | |||
@@ -319,7 +338,7 @@ class ClassTester(BaseTester): | |||
return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc | |||
def mode(self, model, test=True): | |||
"""To do: combine this function with Trainer ?? """ | |||
"""TODO: combine this function with Trainer ?? """ | |||
if test: | |||
model.eval() | |||
else: |
@@ -7,9 +7,9 @@ import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from fastNLP.action.action import Action | |||
from fastNLP.action.action import RandomSampler, Batchifier | |||
from fastNLP.action.tester import POSTester | |||
from fastNLP.core.action import Action | |||
from fastNLP.core.action import RandomSampler, Batchifier | |||
from fastNLP.core.tester import POSTester | |||
from fastNLP.saver.model_saver import ModelSaver | |||
@@ -44,6 +44,7 @@ class BaseTrainer(Action): | |||
self.validate = train_args["validate"] | |||
self.save_best_dev = train_args["save_best_dev"] | |||
self.model_saved_path = train_args["model_saved_path"] | |||
self.use_cuda = train_args["use_cuda"] | |||
self.model = None | |||
self.iterator = None | |||
@@ -65,13 +66,19 @@ class BaseTrainer(Action): | |||
- update | |||
Subclasses must implement these methods with a specific framework. | |||
""" | |||
# prepare model and data | |||
self.model = network | |||
# prepare model and data, transfer model to gpu if available | |||
if torch.cuda.is_available() and self.use_cuda: | |||
self.model = network.cuda() | |||
else: | |||
self.model = network | |||
data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) | |||
# define tester over dev data | |||
# TODO: more flexible | |||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path} | |||
"save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, | |||
"use_cuda": self.use_cuda} | |||
validator = POSTester(valid_args) | |||
# main training epochs | |||
@@ -109,9 +116,6 @@ class BaseTrainer(Action): | |||
# finish training | |||
def prepare_input(self, data_path): | |||
""" | |||
To do: Load pkl files of train/dev/test and embedding | |||
""" | |||
data_train = _pickle.load(open(data_path + "data_train.pkl", "rb")) | |||
data_dev = _pickle.load(open(data_path + "data_dev.pkl", "rb")) | |||
data_test = _pickle.load(open(data_path + "data_test.pkl", "rb")) | |||
@@ -203,11 +207,12 @@ class BaseTrainer(Action): | |||
batch_x = [sample[0] for sample in batch] | |||
batch_y = [sample[1] for sample in batch] | |||
batch_x_pad = self.pad(batch_x) | |||
batch_y_pad = self.pad(batch_y) | |||
if output_length: | |||
seq_len = [len(x) for x in batch_x] | |||
return (batch_x_pad, seq_len), batch_y | |||
return (batch_x_pad, seq_len), batch_y_pad | |||
else: | |||
return batch_x_pad, batch_y | |||
return batch_x_pad, batch_y_pad | |||
@staticmethod | |||
def pad(batch, fill=0): | |||
@@ -288,9 +293,7 @@ class POSTrainer(BaseTrainer): | |||
self.best_accuracy = 0.0 | |||
def prepare_input(self, data_path): | |||
""" | |||
To do: Load pkl files of train/dev/test and embedding | |||
""" | |||
data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||
data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb")) | |||
return data_train, data_dev, 0, 1 | |||
@@ -309,6 +312,8 @@ class POSTrainer(BaseTrainer): | |||
else: | |||
x = inputs | |||
x = torch.Tensor(x).long() | |||
if torch.cuda.is_available() and self.use_cuda: | |||
x = x.cuda() | |||
self.batch_size = x.size(0) | |||
self.max_len = x.size(1) | |||
@@ -339,6 +344,8 @@ class POSTrainer(BaseTrainer): | |||
:return: a scalar | |||
""" | |||
truth = torch.Tensor(truth) | |||
if torch.cuda.is_available() and self.use_cuda: | |||
truth = truth.cuda() | |||
assert truth.shape == (self.batch_size, self.max_len) | |||
if self.loss_func is None: | |||
if hasattr(self.model, "loss"): | |||
@@ -380,11 +387,12 @@ class POSTrainer(BaseTrainer): | |||
batch_x = [sample[0] for sample in batch] | |||
batch_y = [sample[1] for sample in batch] | |||
batch_x_pad = self.pad(batch_x) | |||
batch_y_pad = self.pad(batch_y) | |||
if output_length: | |||
seq_len = [len(x) for x in batch_x] | |||
return (batch_x_pad, seq_len), batch_y | |||
return (batch_x_pad, seq_len), batch_y_pad | |||
else: | |||
return batch_x_pad, batch_y | |||
return batch_x_pad, batch_y_pad | |||
class LanguageModelTrainer(BaseTrainer): | |||
@@ -504,9 +512,6 @@ class ClassTrainer(BaseTrainer): | |||
# finish training | |||
def prepare_input(self, data_path): | |||
""" | |||
To do: Load pkl files of train/dev/test and embedding | |||
""" | |||
names = [ | |||
"data_train.pkl", "data_dev.pkl", |
@@ -1,4 +1,4 @@ | |||
from fastNLP.action.inference import Inference | |||
from fastNLP.core.inference import Inference | |||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
from fastNLP.loader.model_loader import ModelLoader | |||
@@ -110,8 +110,9 @@ class POSPreprocess(BasePreprocess): | |||
:return word2index: dict of {str, int} | |||
label2index: dict of {str, int} | |||
""" | |||
label2index = {} | |||
word2index = DEFAULT_WORD_TO_INDEX | |||
# In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch. | |||
label2index = DEFAULT_WORD_TO_INDEX.copy() | |||
word2index = DEFAULT_WORD_TO_INDEX.copy() | |||
for example in data: | |||
for word, label in zip(example[0], example[1]): | |||
if word not in word2index: | |||
@@ -3,7 +3,6 @@ import torch | |||
class BaseModel(torch.nn.Module): | |||
"""Base PyTorch model for all models. | |||
To do: add some useful common features | |||
""" | |||
def __init__(self): | |||
@@ -19,8 +19,6 @@ USE_GPU = True | |||
class CharLM(BaseModel): | |||
""" | |||
Controller of the Character-level Neural Language Model | |||
To do: | |||
- where the data goes, call data savers. | |||
""" | |||
def __init__(self, lstm_batch_size, lstm_seq_len): | |||
super(CharLM, self).__init__() | |||
@@ -51,6 +51,10 @@ class SeqLabeling(BaseModel): | |||
mask = utils.seq_mask(seq_length, max_len) | |||
mask = mask.byte().view(batch_size, max_len) | |||
# TODO: remove | |||
if torch.cuda.is_available(): | |||
mask = mask.cuda() | |||
# mask = x.new(batch_size, max_len) | |||
total_loss = self.Crf(x, y, mask) | |||
@@ -69,7 +73,10 @@ class SeqLabeling(BaseModel): | |||
mask = utils.seq_mask(seq_length, max_len) | |||
mask = mask.byte() | |||
# mask = x.new(batch_size, max_len) | |||
# TODO: remove | |||
if torch.cuda.is_available(): | |||
mask = mask.cuda() | |||
tag_seq = self.Crf.viterbi_decode(x, mask) | |||
@@ -18,7 +18,7 @@ MLP_HIDDEN = 2000 | |||
CLASSES_NUM = 5 | |||
from fastNLP.models.base_model import BaseModel | |||
from fastNLP.action.trainer import BaseTrainer | |||
from fastNLP.core.trainer import BaseTrainer | |||
class MyNet(BaseModel): | |||
@@ -66,6 +66,7 @@ rnn_bi_direction = true | |||
word_emb_dim = 100 | |||
dropout = 0.5 | |||
use_crf = true | |||
use_cuda = true | |||
[POS_test] | |||
save_output = true | |||
@@ -80,6 +81,7 @@ rnn_bi_direction = true | |||
word_emb_dim = 100 | |||
dropout = 0.5 | |||
use_crf = true | |||
use_cuda = true | |||
[POS_infer] | |||
pickle_path = "./data_for_tests/" | |||
@@ -1,31 +1,7 @@ | |||
from loader.base_loader import ToyLoader0 | |||
from model.char_language_model import CharLM | |||
from fastNLP.action import Tester | |||
from fastNLP.action.trainer import Trainer | |||
def test_charlm(): | |||
train_config = Trainer.TrainConfig(epochs=1, validate=True, save_when_better=True, | |||
log_per_step=10, log_validation=True, batch_size=160) | |||
trainer = Trainer(train_config) | |||
model = CharLM(lstm_batch_size=16, lstm_seq_len=10) | |||
train_data = ToyLoader0("load_train", "./data_for_tests/charlm.txt").load() | |||
valid_data = ToyLoader0("load_valid", "./data_for_tests/charlm.txt").load() | |||
trainer.train(model, train_data, valid_data) | |||
trainer.save_model(model) | |||
test_config = Tester.TestConfig(save_output=True, validate_in_training=True, | |||
save_dev_input=True, save_loss=True, batch_size=160) | |||
tester = Tester(test_config) | |||
test_data = ToyLoader0("load_test", "./data_for_tests/charlm.txt").load() | |||
tester.test(model, test_data) | |||
pass | |||
if __name__ == "__main__": | |||
@@ -3,14 +3,14 @@ import sys | |||
sys.path.append("..") | |||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
from fastNLP.action.trainer import POSTrainer | |||
from fastNLP.core.trainer import POSTrainer | |||
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader | |||
from fastNLP.loader.preprocess import POSPreprocess, load_pickle | |||
from fastNLP.saver.model_saver import ModelSaver | |||
from fastNLP.loader.model_loader import ModelLoader | |||
from fastNLP.action.tester import POSTester | |||
from fastNLP.core.tester import POSTester | |||
from fastNLP.models.sequence_modeling import SeqLabeling | |||
from fastNLP.action.inference import Inference | |||
from fastNLP.core.inference import Inference | |||
data_name = "pku_training.utf8" | |||
cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8" | |||
@@ -3,14 +3,14 @@ import sys | |||
sys.path.append("..") | |||
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection | |||
from fastNLP.action.trainer import POSTrainer | |||
from fastNLP.core.trainer import POSTrainer | |||
from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader | |||
from fastNLP.loader.preprocess import POSPreprocess, load_pickle | |||
from fastNLP.saver.model_saver import ModelSaver | |||
from fastNLP.loader.model_loader import ModelLoader | |||
from fastNLP.action.tester import POSTester | |||
from fastNLP.core.tester import POSTester | |||
from fastNLP.models.sequence_modeling import SeqLabeling | |||
from fastNLP.action.inference import Inference | |||
from fastNLP.core.inference import Inference | |||
data_name = "people.txt" | |||
data_path = "data_for_tests/people.txt" | |||
@@ -0,0 +1,35 @@ | |||
from fastNLP.core.tester import POSTester | |||
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader | |||
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader | |||
from fastNLP.loader.preprocess import POSPreprocess | |||
from fastNLP.models.sequence_modeling import SeqLabeling | |||
data_name = "pku_training.utf8" | |||
cws_data_path = "/home/zyfeng/Desktop/data/pku_training.utf8" | |||
pickle_path = "data_for_tests" | |||
def foo(): | |||
loader = TokenizeDatasetLoader(data_name, cws_data_path) | |||
train_data = loader.load_pku() | |||
train_args = ConfigSection() | |||
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS": train_args}) | |||
# Preprocessor | |||
p = POSPreprocess(train_data, pickle_path) | |||
train_args["vocab_size"] = p.vocab_size | |||
train_args["num_classes"] = p.num_classes | |||
model = SeqLabeling(train_args) | |||
valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, | |||
"save_loss": True, "batch_size": 8, "pickle_path": "./data_for_tests/", | |||
"use_cuda": True} | |||
validator = POSTester(valid_args) | |||
validator.test(model) | |||
validator.show_matrices() | |||
if __name__ == "__main__": | |||
foo() |
@@ -1,12 +1,5 @@ | |||
def test_trainer(): | |||
Config = namedtuple("config", ["epochs", "validate", "save_when_better"]) | |||
train_config = Config(epochs=5, validate=True, save_when_better=True) | |||
trainer = Trainer(train_config) | |||
net = ToyModel() | |||
data = np.random.rand(20, 6) | |||
dev_data = np.random.rand(20, 6) | |||
trainer.train(net, data, dev_data) | |||
pass | |||
if __name__ == "__main__": | |||