
update

tags/v0.1.0
FengZiYjun 7 years ago
commit 9aad8dff6e
3 changed files with 87 additions and 4 deletions:

  1. fastNLP/core/action.py (+78, -0)
  2. fastNLP/core/trainer.py (+2, -2)
  3. fastNLP/loader/dataset_loader.py (+7, -2)

fastNLP/core/action.py (+78, -0)

@@ -1,3 +1,5 @@
from collections import Counter

import numpy as np


@@ -10,6 +12,63 @@ class Action(object):
        super(Action, self).__init__()


def k_means_1d(x, k, max_iter=100):
"""

:param x: list of int, representing points in 1-D.
:param k: the number of clusters required.
:param max_iter: maximum iteration
:return centroids: numpy array, centroids of the k clusters
assignment: numpy array, 1-D, the bucket id assigned to each example.
"""
sorted_x = sorted(list(set(x)))
if len(sorted_x) < k:
raise ValueError("too few buckets")
gap = len(sorted_x) / k

centroids = np.array([sorted_x[int(x * gap)] for x in range(k)])
assign = None

for i in range(max_iter):
# Cluster Assignment step
assign = np.array([np.argmin([np.absolute(x_i - x) for x in centroids]) for x_i in x])
# Move centroids step
new_centroids = np.array([x[assign == k].mean() for k in range(k)])
if (new_centroids == centroids).all():
centroids = new_centroids
break
centroids = new_centroids
return np.array(centroids), assign


def k_means_bucketing(all_inst, buckets):
"""
:param all_inst: 3-level list
[
[[word_11, word_12, word_13], [label_11. label_12]], # sample 1
[[word_21, word_22, word_23], [label_21. label_22]], # sample 2
...
]
:param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
threshold for each bucket (This is usually None.).
:return data: 2-level list
[
[index_11, index_12, ...], # bucket 1
[index_21, index_22, ...], # bucket 2
...
]
"""
bucket_data = [[] for _ in buckets]
num_buckets = len(buckets)
lengths = np.array([len(inst[0]) for inst in all_inst])
_, assignments = k_means_1d(lengths, num_buckets)

for idx, bucket_id in enumerate(assignments):
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
bucket_data[bucket_id].append(idx)
return bucket_data


class BaseSampler(object):
"""
Base class for all samplers.
@@ -49,6 +108,25 @@ class RandomSampler(BaseSampler):
        return iter(np.random.permutation(self.data_set_length))


class BucketSampler(BaseSampler):
    """
    Partition all samples into multiple buckets, each of which contains sentences of roughly the same length.
    In sampling, a bucket is chosen at random first; the indices of that bucket are then returned in random order.
    Here the number of buckets is fixed at 10 and no length thresholds are applied.
    """

    def __init__(self, data_set):
        super(BucketSampler, self).__init__(data_set)
        BUCKETS = [None] * 10  # 10 buckets, none of them with a maximum-length threshold
        self.length_freq = dict(Counter([len(example[0]) for example in data_set]))
        self.buckets = k_means_bucketing(data_set, BUCKETS)

    def __iter__(self):
        # pick one bucket at random (np.random.randint's upper bound is exclusive)
        bucket_samples = self.buckets[np.random.randint(0, len(self.buckets))]
        np.random.shuffle(bucket_samples)
        return iter(bucket_samples)


class Batchifier(object):
"""
Wrap random or sequential sampler to generate a mini-batch.
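
Illustrative usage sketch (not part of the commit): the snippet below exercises the new bucketing utilities end to end on a toy data set in the 3-level [[words], [labels]] format that k_means_bucketing expects. The toy corpus, its 40 distinct sentence lengths (chosen so that BucketSampler's default 10 buckets can always be formed), and the variable names are assumptions made for illustration only.

import numpy as np

from fastNLP.core.action import k_means_1d, k_means_bucketing, BucketSampler

# toy corpus: 200 samples in [[words], [labels]] form, word-sequence lengths 5..44
data_set = [[["w"] * n, ["S"] * n] for n in list(range(5, 45)) * 5]

# cluster the sentence lengths directly with 1-D k-means
lengths = np.array([len(inst[0]) for inst in data_set])
centroids, assignments = k_means_1d(lengths, 3)        # three length clusters
print(centroids)                                       # the three centroid lengths
print(assignments[:10])                                # cluster id of the first samples

# group sample indices into buckets (None = no length threshold)
buckets = k_means_bucketing(data_set, [None, None, None])
print([len(b) for b in buckets])                       # how many samples per bucket

# BucketSampler builds 10 buckets internally and yields one bucket's indices, shuffled
sampler = BucketSampler(data_set)
indices = list(iter(sampler))
print(sorted({len(data_set[i][0]) for i in indices}))  # lengths within one bucket are close

Because all indices yielded in one pass come from a single length bucket, batches built from them contain sentences of similar length and need little padding.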


fastNLP/core/trainer.py (+2, -2)

@@ -8,7 +8,7 @@ import torch
import torch.nn as nn

from fastNLP.core.action import Action
-from fastNLP.core.action import RandomSampler, Batchifier
+from fastNLP.core.action import RandomSampler, Batchifier, BucketSampler
from fastNLP.core.tester import POSTester
from fastNLP.saver.model_saver import ModelSaver

@@ -89,7 +89,7 @@ class BaseTrainer(Action):

        # turn on network training mode; define optimizer; prepare batch iterator
        self.mode(test=False)
-        self.iterator = iter(Batchifier(RandomSampler(data_train), self.batch_size, drop_last=True))
+        self.iterator = iter(Batchifier(BucketSampler(data_train), self.batch_size, drop_last=True))

        # training iterations in one epoch
        for step in range(iterations):
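
For context, Batchifier's role here is to turn the stream of indices produced by a sampler into fixed-size mini-batches. The following is a minimal, hypothetical sketch of such a wrapper with the same call shape as Batchifier(sampler, batch_size, drop_last=True); it is not fastNLP's actual implementation, only an illustration of why swapping RandomSampler for BucketSampler changes what ends up in a batch.

class SimpleBatchifier(object):
    """A stand-in for Batchifier: group indices from a sampler into mini-batches."""

    def __init__(self, sampler, batch_size, drop_last=True):
        self.sampler = sampler        # e.g. BucketSampler(data_train)
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch           # with BucketSampler, these indices come from one
                batch = []            # length bucket, so the batch needs little padding
        if batch and not self.drop_last:
            yield batch               # optionally keep the final, smaller batch

With RandomSampler the indices of a batch are drawn from the whole data set, so sentence lengths inside a batch can vary widely; with BucketSampler they come from a single length bucket.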


fastNLP/loader/dataset_loader.py (+7, -2)

@@ -86,7 +86,7 @@ class TokenizeDatasetLoader(DatasetLoader):
    def __init__(self, data_name, data_path):
        super(TokenizeDatasetLoader, self).__init__(data_name, data_path)

-    def load_pku(self):
+    def load_pku(self, max_seq_len=64):
        """
        load pku dataset for Chinese word segmentation
        CWS (Chinese Word Segmentation) pku training dataset format:
@@ -98,8 +98,11 @@ class TokenizeDatasetLoader(DatasetLoader):
        E: ending of a word
        S: single character

+        :param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than this, it is
+            split into several sequences.
        :return: three-level lists
        """
+        assert isinstance(max_seq_len, int) and max_seq_len > 0
        with open(self.data_path, "r", encoding="utf-8") as f:
            sentences = f.readlines()
        data = []
@@ -107,7 +110,9 @@ class TokenizeDatasetLoader(DatasetLoader):
            words = []
            labels = []
            tokens = sent.strip().split()
-            for token in tokens:
+            for start in range(len(tokens) // max_seq_len):
+                token_seq = tokens[start * max_seq_len:(start + 1) * max_seq_len]
+                for token in token_seq:
                    if len(token) == 1:
                        words.append(token)
                        labels.append("S")

