Browse Source

- fix Dataset & Trainer

- update CNNText model
tags/v0.2.0
yunfan 6 years ago
parent
commit
74a697651e
3 changed files with 32 additions and 33 deletions
  1. +7
    -8
      fastNLP/core/dataset.py
  2. +7
    -10
      fastNLP/core/trainer.py
  3. +18
    -15
      fastNLP/models/cnn_text_classification.py

+ 7
- 8
fastNLP/core/dataset.py View File

@@ -254,19 +254,18 @@ class DataSet(object):
:param str new_field_name: If not None, results of the function will be stored as a new field.
:return results: returned values of the function over all instances.
"""
results = []
for ins in self:
results.append(func(ins))
results = [func(ins) for ins in self]
if new_field_name is not None:
if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes
old_field = self.field_arrays[new_field_name]
padding_val = old_field.padding_val
need_tensor = old_field.need_tensor
is_target = old_field.is_target
self.add_field(new_field_name, results, padding_val, need_tensor, is_target)
self.add_field(name=new_field_name,
fields=results,
padding_val=old_field.padding_val,
is_input=old_field.is_input,
is_target=old_field.is_target)
else:
self.add_field(new_field_name, results)
self.add_field(name=new_field_name, fields=results)
else:
return results



+ 7
- 10
fastNLP/core/trainer.py View File

@@ -1,10 +1,6 @@
import time
from datetime import timedelta
from datetime import datetime

import warnings
from collections import defaultdict

from datetime import timedelta, datetime
import os
import torch
from tensorboardX import SummaryWriter

@@ -28,7 +24,7 @@ class Trainer(object):

"""

def __init__(self, train_data, model, n_epochs=1, batch_size=32, print_every=-1,
def __init__(self, train_data, model, n_epochs, batch_size, n_print=1,
dev_data=None, use_cuda=False, loss=Loss(None), save_path="./save",
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
evaluator=Evaluator(),
@@ -56,7 +52,7 @@ class Trainer(object):
for k, v in kwargs.items():
setattr(self, k, v)

self._summary_writer = SummaryWriter(self.save_path + 'tensorboard_logs')
self._summary_writer = SummaryWriter(os.path.join(self.save_path, 'tensorboard_logs'))
self._graph_summaried = False
self.step = 0
self.start_time = None # start timestamp
@@ -112,9 +108,9 @@ class Trainer(object):
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if self.print_every > 0 and self.step % self.print_every == 0:
if n_print > 0 and self.step % n_print == 0:
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
epoch, self.step, loss.data, diff)
print(print_output)
@@ -177,6 +173,7 @@ class Trainer(object):
return self.loss_func(predict, truth)

def save_model(self, model, model_name, only_param=False):
model_name = os.path.join(self.save_path, model_name)
if only_param:
torch.save(model.state_dict(), model_name)
else:


+ 18
- 15
fastNLP/models/cnn_text_classification.py View File

@@ -15,25 +15,25 @@ class CNNText(torch.nn.Module):
Classification.'
"""

def __init__(self, args):
def __init__(self, embed_num,
embed_dim,
num_classes,
kernel_nums=(3,4,5),
kernel_sizes=(3,4,5),
padding=0,
dropout=0.5):
super(CNNText, self).__init__()

num_classes = args["num_classes"]
kernel_nums = [100, 100, 100]
kernel_sizes = [3, 4, 5]
vocab_size = args["vocab_size"]
embed_dim = 300
pretrained_embed = None
drop_prob = 0.5

# no support for pre-trained embedding currently
self.embed = encoder.embedding.Embedding(vocab_size, embed_dim)
self.conv_pool = encoder.conv_maxpool.ConvMaxpool(
self.embed = encoder.Embedding(embed_num, embed_dim)
self.conv_pool = encoder.ConvMaxpool(
in_channels=embed_dim,
out_channels=kernel_nums,
kernel_sizes=kernel_sizes)
self.dropout = nn.Dropout(drop_prob)
self.fc = encoder.linear.Linear(sum(kernel_nums), num_classes)
kernel_sizes=kernel_sizes,
padding=padding)
self.dropout = nn.Dropout(dropout)
self.fc = encoder.Linear(sum(kernel_nums), num_classes)
self._loss = nn.CrossEntropyLoss()

def forward(self, word_seq):
"""
@@ -44,4 +44,7 @@ class CNNText(torch.nn.Module):
x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]
return x
return {'output':x}

def loss(self, output, label_seq):
    """Compute the training loss for a batch.

    Delegates to the ``nn.CrossEntropyLoss`` instance stored in
    ``self._loss``.

    :param output: model logits — the tensor returned by ``forward``
        under the ``'output'`` key; presumably shape [N, N_class]
        per the comment in ``forward`` — TODO confirm against caller.
    :param label_seq: gold class indices for the batch.
    :return: scalar loss tensor from ``nn.CrossEntropyLoss``.
    """
    return self._loss(output, label_seq)

Loading…
Cancel
Save