
- fix Dataset & Trainer

- update CNNText model
tags/v0.2.0
yunfan 6 years ago
commit 74a697651e
3 changed files with 32 additions and 33 deletions
  1. fastNLP/core/dataset.py  (+7, -8)
  2. fastNLP/core/trainer.py  (+7, -10)
  3. fastNLP/models/cnn_text_classification.py  (+18, -15)

fastNLP/core/dataset.py  (+7, -8)

@@ -254,19 +254,18 @@ class DataSet(object):
         :param str new_field_name: If not None, results of the function will be stored as a new field.
         :return results: returned values of the function over all instances.
         """
-        results = []
-        for ins in self:
-            results.append(func(ins))
+        results = [func(ins) for ins in self]
         if new_field_name is not None:
             if new_field_name in self.field_arrays:
                 # overwrite the field, keep same attributes
                 old_field = self.field_arrays[new_field_name]
-                padding_val = old_field.padding_val
-                need_tensor = old_field.need_tensor
-                is_target = old_field.is_target
-                self.add_field(new_field_name, results, padding_val, need_tensor, is_target)
+                self.add_field(name=new_field_name,
+                               fields=results,
+                               padding_val=old_field.padding_val,
+                               is_input=old_field.is_input,
+                               is_target=old_field.is_target)
             else:
-                self.add_field(new_field_name, results)
+                self.add_field(name=new_field_name, fields=results)
         else:
             return results
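
For context, a minimal usage sketch of the patched method, assuming this hunk belongs to DataSet.apply and that DataSet/Instance are importable as below (those names are not shown in the hunk):

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

ds = DataSet()
ds.append(Instance(raw_sentence="the quick brown fox"))
ds.append(Instance(raw_sentence="jumps over the lazy dog"))

# With new_field_name set, per-instance results are written back as a field;
# on overwrite the old field keeps its padding_val / is_input / is_target flags.
ds.apply(lambda ins: ins["raw_sentence"].split(), new_field_name="words")

# Without new_field_name, the per-instance results are simply returned.
lengths = ds.apply(lambda ins: len(ins["raw_sentence"].split()))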




fastNLP/core/trainer.py  (+7, -10)

@@ -1,10 +1,6 @@
 import time
-from datetime import timedelta
-from datetime import datetime
-
-import warnings
-from collections import defaultdict
-
+from datetime import timedelta, datetime
+import os
 
 import torch
 from tensorboardX import SummaryWriter


@@ -28,7 +24,7 @@ class Trainer(object):
 
     """
 
-    def __init__(self, train_data, model, n_epochs=1, batch_size=32, print_every=-1,
+    def __init__(self, train_data, model, n_epochs, batch_size, n_print=1,
                  dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
                  optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
                  evaluator=Evaluator(),
@@ -56,7 +52,7 @@ class Trainer(object):
for k, v in kwargs.items(): for k, v in kwargs.items():
setattr(self, k, v) setattr(self, k, v)


self._summary_writer = SummaryWriter(self.save_path + 'tensorboard_logs')
self._summary_writer = SummaryWriter(os.path.join(self.save_path, 'tensorboard_logs'))
self._graph_summaried = False self._graph_summaried = False
self.step = 0 self.step = 0
self.start_time = None # start timestamp self.start_time = None # start timestamp
@@ -112,9 +108,9 @@ class Trainer(object):
                     self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
-            if self.print_every > 0 and self.step % self.print_every == 0:
+            if n_print > 0 and self.step % n_print == 0:
                 end = time.time()
-                diff = timedelta(seconds=round(end - kwargs["start"]))
+                diff = timedelta(seconds=round(end - start))
                 print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
                     epoch, self.step, loss.data, diff)
                 print(print_output)
@@ -177,6 +173,7 @@ class Trainer(object):
         return self.loss_func(predict, truth)
 
     def save_model(self, model, model_name, only_param=False):
+        model_name = os.path.join(self.save_path, model_name)
         if only_param:
             torch.save(model.state_dict(), model_name)
         else:
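
A minimal sketch of constructing the Trainer with the revised signature. The import paths, the Trainer.train() entry point, and the train_set/dev_set/model variables are assumptions, not shown in this diff:

from fastNLP.core.trainer import Trainer
from fastNLP.core.optimizer import Optimizer

trainer = Trainer(train_data=train_set,   # a DataSet of training instances (placeholder)
                  model=model,            # any fastNLP model, e.g. CNNText (placeholder)
                  n_epochs=5,
                  batch_size=32,
                  n_print=100,            # print a progress line every 100 steps
                  dev_data=dev_set,       # placeholder
                  use_cuda=False,
                  save_path="./save",     # tensorboard logs now go to save_path/tensorboard_logs
                  optimizer=Optimizer("Adam", lr=0.001, weight_decay=0))
trainer.train()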


fastNLP/models/cnn_text_classification.py  (+18, -15)

@@ -15,25 +15,25 @@ class CNNText(torch.nn.Module):
     Classification.'
     """
 
-    def __init__(self, args):
+    def __init__(self, embed_num,
+                 embed_dim,
+                 num_classes,
+                 kernel_nums=(3,4,5),
+                 kernel_sizes=(3,4,5),
+                 padding=0,
+                 dropout=0.5):
         super(CNNText, self).__init__()
 
-        num_classes = args["num_classes"]
-        kernel_nums = [100, 100, 100]
-        kernel_sizes = [3, 4, 5]
-        vocab_size = args["vocab_size"]
-        embed_dim = 300
-        pretrained_embed = None
-        drop_prob = 0.5
-
         # no support for pre-trained embedding currently
-        self.embed = encoder.embedding.Embedding(vocab_size, embed_dim)
-        self.conv_pool = encoder.conv_maxpool.ConvMaxpool(
+        self.embed = encoder.Embedding(embed_num, embed_dim)
+        self.conv_pool = encoder.ConvMaxpool(
             in_channels=embed_dim,
             out_channels=kernel_nums,
-            kernel_sizes=kernel_sizes)
-        self.dropout = nn.Dropout(drop_prob)
-        self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)
+            kernel_sizes=kernel_sizes,
+            padding=padding)
+        self.dropout = nn.Dropout(dropout)
+        self.fc = encoder.Linear(sum(kernel_nums), num_classes)
+        self._loss = nn.CrossEntropyLoss()
 
     def forward(self, word_seq):
         """
@@ -44,4 +44,7 @@ class CNNText(torch.nn.Module):
         x = self.conv_pool(x)  # [N,L,C] -> [N,C]
         x = self.dropout(x)
         x = self.fc(x)  # [N,C] -> [N, N_class]
-        return x
+        return {'output':x}
+
+    def loss(self, output, label_seq):
+        return self._loss(output, label_seq)
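
A minimal sketch of the updated CNNText interface, following the new constructor arguments and the forward/loss contract above; the vocabulary size, batch shape, and class count are made-up values:

import torch
from fastNLP.models.cnn_text_classification import CNNText

model = CNNText(embed_num=3000,              # vocabulary size (made-up value)
                embed_dim=128,
                num_classes=5,
                kernel_nums=(100, 100, 100),
                kernel_sizes=(3, 4, 5),
                padding=0,
                dropout=0.5)

word_seq = torch.randint(0, 3000, (8, 20))   # [batch, seq_len] word indices
out = model(word_seq)                        # forward now returns {'output': logits}
labels = torch.randint(0, 5, (8,))
loss = model.loss(out['output'], labels)     # uses the new internal CrossEntropyLoss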
