diff --git a/fastNLP/core/action.py b/fastNLP/core/action.py index d0dce5a6..f3fe5f41 100644 --- a/fastNLP/core/action.py +++ b/fastNLP/core/action.py @@ -22,9 +22,7 @@ class Action(object): @staticmethod def make_batch(iterator, data, output_length=True): - """ - 1. Perform batching from data and produce a batch of training data. - 2. Add padding. + """Batch and Pad data. :param iterator: an iterator, (object that implements __next__ method) which returns the next sample. :param data: list. Each entry is a sample, which is also a list of features and label(s). E.g. @@ -41,17 +39,17 @@ class Action(object): return batch_x and batch_y, if output_length is False """ - indices = next(iterator) - batch = [data[idx] for idx in indices] - batch_x = [sample[0] for sample in batch] - batch_y = [sample[1] for sample in batch] - batch_x_pad = Action.pad(batch_x) - batch_y_pad = Action.pad(batch_y) - if output_length: - seq_len = [len(x) for x in batch_x] - return (batch_x_pad, seq_len), batch_y_pad - else: - return batch_x_pad, batch_y_pad + for indices in iterator: + batch = [data[idx] for idx in indices] + batch_x = [sample[0] for sample in batch] + batch_y = [sample[1] for sample in batch] + batch_x_pad = Action.pad(batch_x) + batch_y_pad = Action.pad(batch_y) + if output_length: + seq_len = [len(x) for x in batch_x] + yield (batch_x_pad, seq_len), batch_y_pad + else: + yield batch_x_pad, batch_y_pad @staticmethod def pad(batch, fill=0): @@ -208,11 +206,10 @@ class Batchifier(object): def __iter__(self): batch = [] - while True: - for idx in self.sampler: - batch.append(idx) - if len(batch) == self.batch_size: - yield batch - batch = [] - if 0 < len(batch) < self.batch_size and self.drop_last is False: + for idx in self.sampler: + batch.append(idx) + if len(batch) == self.batch_size: yield batch + batch = [] + if 0 < len(batch) < self.batch_size and self.drop_last is False: + yield batch diff --git a/fastNLP/core/inference.py b/fastNLP/core/inference.py index 1bbcaf3a..d1d63b78 100644 --- a/fastNLP/core/inference.py +++ b/fastNLP/core/inference.py @@ -11,7 +11,7 @@ class Inference(object): This is an interface focusing on predicting output based on trained models. It does not care about evaluations of the model, which is different from Tester. This is a high-level model wrapper to be called by FastNLP. - + This class does not share any operations with Trainer and Tester. """ def __init__(self, pickle_path): diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index 27225cdb..d0208dfd 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -48,10 +48,7 @@ class BaseTester(Action): iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=True)) - num_iter = len(dev_data) // self.batch_size - - for step in range(num_iter): - batch_x, batch_y = self.action.make_batch(iterator, dev_data) + for batch_x, batch_y in self.action.make_batch(iterator, dev_data): prediction = self.data_forward(network, batch_x) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index d941536b..7024029f 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -72,28 +72,32 @@ class BaseTrainer(Action): else: self.model = network - data_train = self.prepare_input(self.pickle_path) + data_train, data_dev, data_test, embedding = self.prepare_input(self.pickle_path) # define tester over dev data - # TODO: more flexible - default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, - "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, - "use_cuda": self.use_cuda} - validator = POSTester(default_valid_args, self.action) + if self.validate: + default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True, + "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path, + "use_cuda": self.use_cuda} + validator = POSTester(default_valid_args, self.action) - # main training epochs - iterations = len(data_train) // self.batch_size self.define_optimizer() + # main training epochs + start = time() + n_samples = len(data_train) + n_batches = n_samples // self.batch_size + n_print = 1 + for epoch in range(1, self.n_epochs + 1): - # turn on network training mode; define optimizer; prepare batch iterator - self.action.mode(self.model, test=False) - iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True)) + # turn on network training mode; prepare batch iterator + self.action.mode(network, test=False) + iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=False)) # training iterations in one epoch - for step in range(iterations): - batch_x, batch_y = self.action.make_batch(iterator, data_train) + step = 0 + for batch_x, batch_y in self.action.make_batch(iterator, data_train, output_length=True): prediction = self.data_forward(network, batch_x) @@ -101,8 +105,12 @@ class BaseTrainer(Action): self.grad_backward(loss) self.update() - if step % 10 == 0: - print("[epoch {} step {}] train loss={:.2f}".format(epoch, step, loss.data)) + if step % n_print == 0: + end = time() + diff = timedelta(seconds=round(end - start)) + print("[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format( + epoch, step, loss.data, diff)) + step += 1 if self.validate: validator.test(network) @@ -114,15 +122,25 @@ class BaseTrainer(Action): print("[epoch {}]".format(epoch), end=" ") print(validator.show_matrices()) - # finish training - def prepare_input(self, pickle_path): """ - This is reserved for task-specific processing. - :param data_path: - :return: + For task-specific processing. + :param pickle_path: + :return data_train, data_dev, data_test, embedding: """ - return _pickle.load(open(pickle_path + "/data_train.pkl", "rb")) + names = [ + "data_train.pkl", "data_dev.pkl", + "data_test.pkl", "embedding.pkl"] + files = [] + for name in names: + file_path = os.path.join(pickle_path, name) + if os.path.exists(file_path): + with open(file_path, 'rb') as f: + data = _pickle.load(f) + else: + data = [] + files.append(data) + return tuple(files) def define_optimizer(self): """ @@ -222,14 +240,14 @@ class ToyTrainer(BaseTrainer): self.optimizer.step() -class POSTrainer(BaseTrainer): +class SeqLabelTrainer(BaseTrainer): """ Trainer for Sequence Modeling """ def __init__(self, train_args, action=None): - super(POSTrainer, self).__init__(train_args, action) + super(SeqLabelTrainer, self).__init__(train_args, action) self.vocab_size = train_args["vocab_size"] self.num_classes = train_args["num_classes"] self.max_len = None @@ -251,11 +269,12 @@ class POSTrainer(BaseTrainer): raise RuntimeError("[fastnlp] output_length must be true for sequence modeling.") # unpack the returned value from make_batch x, seq_len = inputs[0], inputs[1] + x = torch.Tensor(x).long() + batch_size, max_len = x.size(0), x.size(1) mask = utils.seq_mask(seq_len, max_len) mask = mask.byte().view(batch_size, max_len) - x = torch.Tensor(x).long() if torch.cuda.is_available() and self.use_cuda: x = x.cuda() mask = mask.cuda() @@ -304,8 +323,8 @@ class LanguageModelTrainer(BaseTrainer): class ClassTrainer(BaseTrainer): """Trainer for classification.""" - def __init__(self, train_args): - # super(ClassTrainer, self).__init__(train_args) + def __init__(self, train_args, action=None): + super(ClassTrainer, self).__init__(train_args, action) self.n_epochs = train_args["epochs"] self.batch_size = train_args["batch_size"] self.pickle_path = train_args["pickle_path"] @@ -332,117 +351,8 @@ class ClassTrainer(BaseTrainer): self.loss_func = None self.optimizer = None - def train(self, network): - """General Training Steps - :param network: a model - - The method is framework independent. - Work by calling the following methods: - - prepare_input - - mode - - define_optimizer - - data_forward - - get_loss - - grad_backward - - update - Subclasses must implement these methods with a specific framework. - """ - # prepare model and data, transfer model to gpu if available - if torch.cuda.is_available() and self.use_cuda: - self.model = network.cuda() - else: - self.model = network - data_train, data_dev, data_test, embedding = self.prepare_input( - self.pickle_path) - - # define tester over dev data - # valid_args = { - # "save_output": True, "validate_in_training": True, - # "save_dev_input": True, "save_loss": True, - # "batch_size": self.batch_size, "pickle_path": self.pickle_path} - # validator = POSTester(valid_args) - - # urn on network training mode, define loss and optimizer - self.define_loss() - self.define_optimizer() - self.mode(test=False) - - # main training epochs - start = time() - n_samples = len(data_train) - n_batches = n_samples // self.batch_size - n_print = n_batches // 10 - for epoch in range(self.n_epochs): - # prepare batch iterator - self.iterator = iter(Batchifier( - RandomSampler(data_train), self.batch_size, drop_last=False)) - - # training iterations in one epoch - step = 0 - for batch_x, batch_y in self.make_batch(data_train): - prediction = self.data_forward(network, batch_x) - - loss = self.get_loss(prediction, batch_y) - self.grad_backward(loss) - self.update() - - if step % n_print == 0: - acc = self.get_acc(prediction, batch_y) - end = time() - diff = timedelta(seconds=round(end - start)) - print("epoch: {:>3} step: {:>4} loss: {:>4.2}" - " train acc: {:>5.1%} time: {}".format( - epoch, step, loss, acc, diff)) - - step += 1 - - # if self.validate: - # if data_dev is None: - # raise RuntimeError("No validation data provided.") - # validator.test(network) - # print("[epoch {}]".format(epoch), end=" ") - # print(validator.show_matrices()) - - # finish training - - def prepare_input(self, data_path): - - names = [ - "data_train.pkl", "data_dev.pkl", - "data_test.pkl", "embedding.pkl"] - - files = [] - for name in names: - file_path = os.path.join(data_path, name) - if os.path.exists(file_path): - with open(file_path, 'rb') as f: - data = _pickle.load(f) - else: - data = [] - files.append(data) - - return tuple(files) - - def mode(self, test=False): - """ - Tell the network to be trained or not. - :param test: bool - """ - if test: - self.model.eval() - else: - self.model.train() - def define_loss(self): - """ - Assign an instance of loss function to self.loss_func - E.g. self.loss_func = nn.CrossEntropyLoss() - """ - if self.loss_func is None: - if hasattr(self.model, "loss"): - self.loss_func = self.model.loss - else: - self.loss_func = nn.CrossEntropyLoss() + self.loss_func = nn.CrossEntropyLoss() def define_optimizer(self): """ @@ -455,13 +365,12 @@ class ClassTrainer(BaseTrainer): def data_forward(self, network, x): """Forward through network.""" + x = torch.Tensor(x).long() + if torch.cuda.is_available() and self.use_cuda: + x = x.cuda() logits = network(x) return logits - def get_loss(self, predict, truth): - """Calculate loss.""" - return self.loss_func(predict, truth) - def grad_backward(self, loss): """Compute gradient backward.""" self.model.zero_grad() @@ -471,21 +380,22 @@ class ClassTrainer(BaseTrainer): """Apply gradient.""" self.optimizer.step() + """ def make_batch(self, data): - """Batch and pad data.""" for indices in self.iterator: batch = [data[idx] for idx in indices] batch_x = [sample[0] for sample in batch] batch_y = [sample[1] for sample in batch] batch_x = self.pad(batch_x) - batch_x = torch.tensor(batch_x, dtype=torch.long) - batch_y = torch.tensor(batch_y, dtype=torch.long) + batch_x = torch.Tensor(batch_x).long() + batch_y = torch.Tensor(batch_y).long() if torch.cuda.is_available() and self.use_cuda: batch_x = batch_x.cuda() batch_y = batch_y.cuda() yield batch_x, batch_y + """ def get_acc(self, y_logit, y_true): """Compute accuracy.""" diff --git a/reproduction/chinese_word_seg/cws_train.py b/reproduction/chinese_word_seg/cws_train.py index ff549eb9..3378a8a5 100644 --- a/reproduction/chinese_word_seg/cws_train.py +++ b/reproduction/chinese_word_seg/cws_train.py @@ -3,7 +3,7 @@ import sys sys.path.append("..") from fastNLP.loader.config_loader import ConfigLoader, ConfigSection -from fastNLP.core.trainer import POSTrainer +from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader from fastNLP.loader.preprocess import POSPreprocess, load_pickle from fastNLP.saver.model_saver import ModelSaver @@ -64,7 +64,7 @@ def train(): train_args["num_classes"] = p.num_classes # Trainer - trainer = POSTrainer(train_args) + trainer = SeqLabelTrainer(train_args) # Model model = SeqLabeling(train_args) diff --git a/test/data_for_tests/people.txt b/test/data_for_tests/people.txt index f34c85cb..e4909679 100644 --- a/test/data_for_tests/people.txt +++ b/test/data_for_tests/people.txt @@ -64,4 +64,90 @@ 3 B-t 1 M-t 日 E-t -, S-w \ No newline at end of file +, S-w +迈 B-v +向 E-v +充 B-v +满 E-v +希 B-n +望 E-n +的 S-u +新 S-a +世 B-n +纪 E-n +— B-w +— E-w +一 B-t +九 M-t +九 M-t +八 M-t +年 E-t +新 B-t +年 E-t +讲 B-n +话 E-n +( S-w +附 S-v +图 B-n +片 E-n +1 S-m +张 S-q +) S-w + +迈 B-v +向 E-v +充 B-v +满 E-v +希 B-n +望 E-n +的 S-u +新 S-a +世 B-n +纪 E-n +— B-w +— E-w +一 B-t +九 M-t +九 M-t +八 M-t +年 E-t +新 B-t +年 E-t +讲 B-n +话 E-n +( S-w +附 S-v +图 B-n +片 E-n +1 S-m +张 S-q +) S-w + +迈 B-v +向 E-v +充 B-v +满 E-v +希 B-n +望 E-n +的 S-u +新 S-a +世 B-n +纪 E-n +— B-w +— E-w +一 B-t +九 M-t +九 M-t +八 M-t +年 E-t +新 B-t +年 E-t +讲 B-n +话 E-n +( S-w +附 S-v +图 B-n +片 E-n +1 S-m +张 S-q +) S-w \ No newline at end of file diff --git a/test/data_for_tests/text_classify.txt b/test/data_for_tests/text_classify.txt new file mode 100644 index 00000000..24a51ce9 --- /dev/null +++ b/test/data_for_tests/text_classify.txt @@ -0,0 +1,100 @@ +entertainment 台 媒 预 测 周 冬 雨 金 马 奖 封 后 , 大 气 的 倪 妮 却 佳 作 难 出 +food 农 村 就 是 好 , 能 吃 到 纯 天 然 无 添 加 的 野 生 蜂 蜜 , 营 养 又 健 康 +fashion 1 4 款 知 性 美 装 , 时 尚 惊 艳 搁 浅 的 阳 光 轻 熟 的 优 雅 +history 火 焰 喷 射 器 1 0 0 0 度 火 焰 烧 死 鬼 子 4 连 拍 +society 1 8 岁 青 年 砍 死 8 8 岁 老 兵 +fashion 醋 洗 脸 的 正 确 方 法 洗 对 了 不 仅 美 容 肌 肤 还 能 收 缩 毛 孔 +game 大 家 都 说 说 除 了 这 1 0 个 英 雄 , L O L 还 有 哪 些 英 雄 可 以 单 挑 男 爵 +sports 王 仕 鹏 退 役 担 任 N B A 总 决 赛 现 场 解 说 嘉 宾 +regimen 天 天 吃 “ 洋 快 餐 ” , 5 岁 女 童 患 上 肝 炎 +food 汤 里 的 蛋 花 怎 样 才 能 如 花 朵 般 漂 亮 , 注 意 这 一 点 即 可 ! +tech 英 退 休 人 士 把 谷 歌 当 活 人 以 礼 貌 搜 索 请 求 征 服 整 个 互 联 网 +discovery N A S A 探 测 器 拍 摄 地 球 、 火 星 和 冥 王 星 合 影 +society 当 骗 子 遇 上 撒 贝 宁 ! 几 句 话 过 后 骗 子 赔 礼 道 歉 . . . . . +history 红 军 长 征 在 中 国 革 命 史 上 的 地 位 +world 实 拍 神 秘 之 国 , 带 你 走 进 真 实 的 朝 鲜 +tech 逼 格 爆 表 ! 古 文 版 2 0 1 6 网 络 热 词 : 燃 尽 洪 荒 之 力 +story 因 为 一 样 东 西 这 个 后 娘 竟 然 给 孩 子 磕 头 +game L O L : 皮 肤 对 操 作 没 影 响 ? 细 数 那 些 有 加 成 效 果 的 皮 肤 +fashion 冬 天 想 穿 裙 子 又 怕 冷 ? 学 了 这 些 搭 配 就 能 好 看 又 温 暖 ! +entertainment 贾 建 军 少 林 三 光 剑 视 频 +food 再 也 不 用 出 去 吃 羊 肉 串 , 自 己 做 又 卫 生 又 健 康 +regimen 男 人 多 吃 这 几 道 菜 , 效 果 胜 “ 伟 哥 ” +baby 宝 贝 厨 房 丨 肉 类 辅 食 第 一 步 宝 宝 的 生 长 发 育 每 天 都 离 不 开 它 ! +travel 近 8 0 亿 的 顶 级 豪 华 邮 轮 上 到 底 有 什 么 ? +sports 厄 齐 尔 心 中 最 想 签 约 的 三 个 人 +food 东 北 的 粘 豆 包 啊 , 想 死 你 们 了 ! +military 强 军 足 音 +sports 奥 运 赛 场 上 , 被 喷 子 痛 批 的 十 大 知 名 运 动 员 +game 老 玩 家 分 享 对 2 0 1 6 L P L 夏 季 赛 R N G 的 分 析 +military 揭 秘 : 关 于 战 争 的 五 大 真 相 , 不 要 再 被 影 视 所 欺 骗 了 ! +food 小 丫 厨 房 : 夏 天 怎 么 吃 辣 不 长 痘 ? 告 诉 你 火 锅 鸡 、 香 辣 鱼 的 正 确 做 法 +travel 中 国 首 个 内 陆 城 市 群 上 的 9 座 城 市 , 看 看 有 你 的 家 乡 吗 +fashion 李 小 璐 做 榜 样 接 亲 吻 脚 大 流 行 新 娘 玉 足 怎 样 才 有 好 味 道 ? +game 黄 金 吊 打 钻 石 ? L O L 最 强 刷 钱 毒 瘤 打 法 诞 生 +history 奇 事 ! 上 万 只 青 蛙 拦 路 告 状 , 竟 然 牵 扯 出 一 桩 命 案 +baby 奶 奶 , 你 为 什 么 不 让 我 用 尿 不 湿 +game L O L 当 5 个 大 发 明 家 炮 台 围 住 泉 水 的 时 候 : 这 是 真 虐 泉 ! +essay 文 友 忠 告 暖 人 心 : 人 到 中 年 “ 不 交 五 友 ” +travel 这 一 年 , 我 们 去 日 本 +food 好 吃 早 饭 近 似 吃 补 药 +fashion 夏 天 太 热 , 唇 膏 化 了 如 何 办 ? +society 厂 里 面 的 9 0 后 打 工 妹 , 辛 苦 来 之 不 易 +history 罕 见 老 照 片 展 示 美 国 大 萧 条 时 期 景 象 +world 美 国 总 统 奥 巴 马 , 是 童 心 未 泯 的 温 情 奥 大 大 , 还 是 个 超 级 老 顽 童 +finance 脱 欧 公 投 前 一 天 抛 售 英 镑 这 一 次 索 罗 斯 也 被 “ 打 败 ” 了 . . . +history 翻 越 长 征 路 上 第 一 座 大 山 +world 朝 鲜 批 奥 巴 马 涉 朝 言 论 , 称 只 要 核 威 胁 存 在 将 继 续 强 化 核 武 力 量 +game 《 巫 师 3 : 狂 猎 》 不 良 因 素 解 析 攻 略 +travel 在 郑 州 有 个 地 方 , 时 光 仿 佛 在 那 儿 停 下 脚 步 +history 它 号 称 “ 天 下 第 一 团 ” , 走 出 过 1 4 位 共 和 国 将 军 以 及 一 位 著 名 作 家 +car 煤 老 板 去 黄 江 买 车 , 以 为 占 了 便 宜 没 想 被 坑 了 1 0 0 多 万 +society “ 试 管 婴 儿 之 母 ” 张 丽 珠 遗 体 告 别 仪 式 8 日 举 行 +sports 东 京 奥 运 会 , 中 国 女 排 卫 冕 的 几 率 有 多 大 ? +travel 成 都 我 们 永 远 依 恋 的 城 市 +tech 雷 布 斯 除 了 小 米 还 有 这 些 秘 密 , 你 知 道 吗 ? +world “ 仲 裁 庭 损 害 国 际 法 体 系 公 正 性 ” — — 访 武 汉 大 学 中 国 边 界 与 海 洋 研 究 院 首 席 专 家 易 显 河 +entertainment 上 海 观 众 和 欧 洲 三 大 影 展 之 间 的 距 离 : 零 时 差 +essay 关 系 好 , 一 切 便 好 +baby 刚 出 生 不 到 1 小 时 的 白 鲸 宝 宝 被 冲 上 岸 , 被 救 后 对 恩 人 露 出 微 笑 +tech 赚 足 眼 球 , 诺 基 亚 五 边 形 W i n 1 0 M o b i l e 概 念 手 机 : 棱 镜 +essay 2 4 句 经 典 语 录 : 穷 三 年 可 以 怨 命 , 穷 十 年 就 得 自 省 +food 这 道 菜 真 下 饭 ! 做 法 简 单 , 防 辐 射 、 抗 衰 老 , 关 键 还 便 宜 +entertainment 《 继 承 者 们 》 要 拍 中 国 版 , 众 角 色 你 期 待 谁 来 演 ? +game D N F 暴 走 改 版 后 怎 么 样 D N F 暴 走 改 版 红 眼 变 弱 了 吗 +entertainment 郑 佩 佩 自 曝 与 李 小 龙 的 过 去 他 是 个 “ 疯 子 ” +baby 女 性 只 有 8 4 次 最 佳 受 孕 机 会 +travel 月 初 一 个 人 去 了 日 本 . . +military 不 为 人 知 的 8 0 万 苏 联 女 兵 ! 最 后 一 张 很 美 ! +tech 网 络 商 家 提 供 小 米 5 运 存 升 级 服 务 : 3 G B 秒 变 6 G B +history 宋 太 祖 、 宋 太 宗 凌 辱 亡 国 皇 后 , 徽 钦 二 帝 后 宫 被 金 人 凌 辱 +history 人 有 三 面 最 “ 难 吃 ” ! 黑 帮 大 佬 杜 月 笙 论 江 湖 规 矩 ! 一 生 只 怕 这 一 人 +game 来 了 ! 索 尼 P S 4 独 占 大 作 《 战 神 4 》 正 式 公 布 +discovery 延 时 视 频 显 示 珊 瑚 如 何 “ 驱 逐 ” 共 生 藻 类 +car 传 祺 G A 8 和 东 风 A 9 谁 才 是 自 主 “ 豪 车 ” 大 佬 +fashion 娶 老 婆 就 要 娶 这 种 ! 蒋 欣 这 样 微 胖 的 女 人 好 看 又 实 用 +sports 黄 山 姑 娘 吕 秀 芝 勇 夺 奥 运 铜 牌 数 百 父 老 彻 夜 为 她 加 油 +military [ 每 日 军 图 ] 土 豪 补 仓 ! 沙 特 再 次 购 买 上 百 辆 美 国 M 1 A 2 主 战 坦 克 +military 美 军 这 款 武 器 号 称 能 让 半 个 中 国 陷 入 黑 暗 , 解 放 军 少 将 : 我 们 也 有 +world 邓 小 平 与 日 本 天 皇 的 历 史 性 会 谈 , 对 中 日 两 国 都 具 有 深 远 的 意 义 啊 ! +baby 为 什 么 有 人 上 个 厕 所 都 能 生 出 孩 子 ? +fashion 欣 宜 举 行 首 次 个 唱 十 万 颗 宝 仕 奥 莎 仿 水 晶 闪 耀 全 场 +food 小 两 口 上 周 的 晚 餐 +society 在 北 京 就 要 守 规 矩 +entertainment 知 情 人 曝 翰 爽 分 手 内 幕 : 郑 爽 想 结 婚 却 被 一 直 拖 着 +military 中 国 反 舰 导 弹 世 界 第 一 远 远 超 过 美 国 但 为 何 却 还 不 如 俄 罗 斯 ? +entertainment 他 除 了 是 《 我 歌 》 音 乐 总 监 , 还 曾 组 乐 队 玩 摇 滚 , 是 黄 家 驹 旧 日 知 己 +baby 长 鹅 口 疮 的 孩 子 怎 么 照 顾 ? 不 要 再 说 拿 他 没 办 法 了 ! +discovery 微 重 力 不 需 使 用 肌 肉 , 太 空 人 返 回 地 球 后 脊 椎 旁 肌 肉 萎 缩 约 1 9 % +regimen 这 6 种 人 将 来 会 得 老 年 痴 呆 ! 预 防 老 年 痴 呆 症 , 这 些 办 法 被 全 世 界 公 认 +society 2 0 1 6 年 上 海 即 将 发 生 哪 些 大 事 件 。 。 。 。 +car 北 汽 自 主 品 牌 亏 损 3 3 . 4 1 亿 额 外 促 销 成 主 因 +car 在 那 山 的 那 边 海 的 那 边 , 有 一 群 自 由 侠 +history 一 个 小 城 就 屠 杀 了 4 0 0 0 苏 军 战 俘 , 希 特 勒 死 神 战 队 的 崛 起 与 覆 灭 +baby 给 孩 子 洗 澡 时 , 这 些 部 位 再 脏 也 不 要 碰 ! +essay 好 久 不 见 , 你 还 好 么 +baby 被 娃 误 伤 的 9 种 痛 , 数 一 数 你 中 了 几 枪 ? +food 初 秋 的 小 炖 品 放 冰 糖 就 比 较 滋 润 , 放 红 糖 就 补 血 又 不 燥 热 +game 佩 服 佩 服 ! 羊 驼 D e f t 单 排 重 回 韩 服 最 强 王 者 第 一 名 ! +game 三 个 时 代 的 标 志 炉 石 传 说 三 大 远 古 毒 瘤 卡 组 +discovery 2 0 世 纪 最 伟 大 科 学 发 现 — — 魔 术 般 的 超 导 材 料 ! \ No newline at end of file diff --git a/test/seq_labeling.py b/test/seq_labeling.py index ce31f0e8..904720ca 100644 --- a/test/seq_labeling.py +++ b/test/seq_labeling.py @@ -3,7 +3,7 @@ import sys sys.path.append("..") from fastNLP.loader.config_loader import ConfigLoader, ConfigSection -from fastNLP.core.trainer import POSTrainer +from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.loader.dataset_loader import POSDatasetLoader, BaseLoader from fastNLP.loader.preprocess import POSPreprocess, load_pickle from fastNLP.saver.model_saver import ModelSaver @@ -73,7 +73,7 @@ def train_and_test(): train_args["num_classes"] = p.num_classes # Trainer - trainer = POSTrainer(train_args) + trainer = SeqLabelTrainer(train_args) # Model model = SeqLabeling(train_args) @@ -112,5 +112,5 @@ def train_and_test(): if __name__ == "__main__": - # train_and_test() + train_and_test() infer() diff --git a/test/test_cws.py b/test/test_cws.py index 8f6c1211..5d0dc3c2 100644 --- a/test/test_cws.py +++ b/test/test_cws.py @@ -3,7 +3,7 @@ import sys sys.path.append("..") from fastNLP.loader.config_loader import ConfigLoader, ConfigSection -from fastNLP.core.trainer import POSTrainer +from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader from fastNLP.loader.preprocess import POSPreprocess, load_pickle from fastNLP.saver.model_saver import ModelSaver @@ -73,7 +73,7 @@ def train_test(): train_args["num_classes"] = p.num_classes # Trainer - trainer = POSTrainer(train_args) + trainer = SeqLabelTrainer(train_args) # Model model = SeqLabeling(train_args) @@ -113,4 +113,4 @@ def train_test(): if __name__ == "__main__": train_test() - #infer() + infer() diff --git a/test/text_classify.py b/test/text_classify.py new file mode 100644 index 00000000..88d11b02 --- /dev/null +++ b/test/text_classify.py @@ -0,0 +1,44 @@ +# Python: 3.5 +# encoding: utf-8 + +import os + +from fastNLP.core.trainer import ClassTrainer +from fastNLP.loader.dataset_loader import ClassDatasetLoader +from fastNLP.loader.preprocess import ClassPreprocess +from fastNLP.models.cnn_text_classification import CNNText + +if __name__ == "__main__": + data_dir = "./data_for_tests/" + train_file = 'text_classify.txt' + model_name = "model_class.pkl" + + # load dataset + print("Loading data...") + ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file)) + data = ds_loader.load() + print(data[0]) + + # pre-process data + pre = ClassPreprocess(data_dir) + vocab_size, n_classes = pre.process(data, "data_train.pkl") + print("vocabulary size:", vocab_size) + print("number of classes:", n_classes) + + # construct model + print("Building model...") + cnn = CNNText(class_num=n_classes, embed_num=vocab_size) + + # train + print("Training...") + train_args = { + "epochs": 1, + "batch_size": 10, + "pickle_path": data_dir, + "validate": False, + "save_best_dev": False, + "model_saved_path": "./data_for_tests/", + "use_cuda": True + } + trainer = ClassTrainer(train_args) + trainer.train(cnn)