Browse Source

fix trainer & dataset

tags/v0.2.0
yunfan 6 years ago
parent
commit
4f587f7561
3 changed files with 17 additions and 6 deletions
  1. +8
    -1
      fastNLP/core/dataset.py
  2. +9
    -3
      fastNLP/core/trainer.py
  3. +0
    -2
      fastNLP/modules/encoder/conv_maxpool.py

+ 8
- 1
fastNLP/core/dataset.py View File

@@ -1,4 +1,5 @@
import numpy as np
from copy import copy


from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
@@ -37,7 +38,7 @@ class DataSet(object):
self.idx += 1
if self.idx >= len(self.dataset):
raise StopIteration
return self
return copy(self)


def add_field(self, field_name, field):
"""Add a new field to the instance.
@@ -270,6 +271,12 @@ class DataSet(object):
else:
return results


def drop(self, func):
results = [ins for ins in self if not func(ins)]
for name, old_field in self.field_arrays.items():
self.field_arrays[name].content = [ins[name] for ins in results]
# print(self.field_arrays[name])

def split(self, dev_ratio):
"""Split the dataset into training and development(validation) set.




+ 9
- 3
fastNLP/core/trainer.py View File

@@ -27,7 +27,7 @@ class Trainer(object):
""" """
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save", dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), need_check_code=True,
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
**kwargs):
super(Trainer, self).__init__()


@@ -84,7 +84,14 @@ class Trainer(object):
start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)
if self.save_path is not None:
if self.save_path is None:
class psudoSW:
def __getattr__(self, item):
def pass_func(*args, **kwargs):
pass
return pass_func
self._summary_writer = psudoSW()
else:
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
self._summary_writer = SummaryWriter(path)


@@ -98,7 +105,6 @@ class Trainer(object):
# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self.do_validation()
self.save_model(self.model, 'training_model_' + self.start_time)
epoch += 1
finally:
self._summary_writer.close()


+ 0
- 2
fastNLP/modules/encoder/conv_maxpool.py View File

@@ -34,8 +34,6 @@ class ConvMaxpool(nn.Module):
bias=bias)
for oc, ks in zip(out_channels, kernel_sizes)])


for conv in self.convs:
xavier_uniform_(conv.weight) # weight initialization
else:
raise Exception(
'Incorrect kernel sizes: should be list, tuple or int')


Loading…
Cancel
Save