diff --git a/.travis.yml b/.travis.yml
index eb5cc5cd..11239eb4 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -5,7 +5,6 @@ python:
install:
- pip install --quiet -r requirements.txt
- pip install pytest pytest-cov
- - pip install -U scikit-learn
# command to run tests
script:
- pytest --cov=./
diff --git a/README.md b/README.md
index 84d658fd..8169520a 100644
--- a/README.md
+++ b/README.md
@@ -30,77 +30,36 @@ Run the following commands to install fastNLP package.
pip install fastNLP
```
-### Cloning From GitHub
-
-If you just want to use fastNLP, use:
-```shell
-git clone https://github.com/fastnlp/fastNLP
-cd fastNLP
-```
-
-### PyTorch Installation
-
-Visit the [PyTorch official website] for installation instructions based on your system. In general, you could use:
-```shell
-# using conda
-conda install pytorch torchvision -c pytorch
-# or using pip
-pip3 install torch torchvision
-```
-
-### TensorboardX Installation
-
-```shell
-pip3 install tensorboardX
-```
## Project Structure
-```
-FastNLP
-├── docs
-├── fastNLP
-│ ├── core
-│ │ ├── action.py
-│ │ ├── __init__.py
-│ │ ├── loss.py
-│ │ ├── metrics.py
-│ │ ├── optimizer.py
-│ │ ├── predictor.py
-│ │ ├── preprocess.py
-│ │ ├── README.md
-│ │ ├── tester.py
-│ │ └── trainer.py
-│ ├── fastnlp.py
-│ ├── __init__.py
-│ ├── loader
-│ │ ├── base_loader.py
-│ │ ├── config_loader.py
-│ │ ├── dataset_loader.py
-│ │ ├── embed_loader.py
-│ │ ├── __init__.py
-│ │ └── model_loader.py
-│ ├── models
-│ ├── modules
-│ │ ├── aggregation
-│ │ ├── decoder
-│ │ ├── encoder
-│ │ ├── __init__.py
-│ │ ├── interaction
-│ │ ├── other_modules.py
-│ │ └── utils.py
-│ └── saver
-├── LICENSE
-├── README.md
-├── reproduction
-├── requirements.txt
-├── setup.py
-└── test
- ├── core
- ├── data_for_tests
- ├── __init__.py
- ├── loader
- ├── modules
- └── readme_example.py
-
-```
+
+| Module | Description |
+| --- | --- |
+| fastNLP | an open-source NLP library |
+| fastNLP.core | trainer, tester, predictor |
+| fastNLP.loader | all kinds of loaders/readers |
+| fastNLP.models | a collection of NLP models |
+| fastNLP.modules | a collection of PyTorch sub-models/components/wheels |
+| fastNLP.saver | all kinds of savers/writers |
+| fastNLP.fastnlp | a high-level interface for prediction |
\ No newline at end of file
diff --git a/docs/source/user/quickstart.rst b/docs/source/user/quickstart.rst
index 21f0855f..24c7363c 100644
--- a/docs/source/user/quickstart.rst
+++ b/docs/source/user/quickstart.rst
@@ -18,7 +18,7 @@ pre-processing data, constructing model and training model.
from fastNLP.modules import aggregation
from fastNLP.modules import decoder
- from fastNLP.loader.dataset_loader import ClassDatasetLoader
+ from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.loader.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.inference import ClassificationInfer
@@ -50,7 +50,7 @@ pre-processing data, constructing model and training model.
train_path = 'test/data_for_tests/text_classify.txt' # training set file
# load dataset
- ds_loader = ClassDatasetLoader("train", train_path)
+ ds_loader = ClassDataSetLoader("train", train_path)
data = ds_loader.load()
# pre-process dataset
diff --git a/examples/readme_example.py b/examples/readme_example.py
index 74e20c57..9da2787b 100644
--- a/examples/readme_example.py
+++ b/examples/readme_example.py
@@ -3,7 +3,7 @@ from fastNLP.core.optimizer import Optimizer
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
-from fastNLP.loader.dataset_loader import ClassDatasetLoader
+from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregator
from fastNLP.modules import decoder
@@ -36,7 +36,7 @@ data_dir = 'save/' # directory to save data and model
train_path = './data_for_tests/text_classify.txt' # training set file
# load dataset
-ds_loader = ClassDatasetLoader(train_path)
+ds_loader = ClassDataSetLoader()
data = ds_loader.load()
# pre-process dataset
diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 98d7c8da..bf837d0f 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -17,7 +17,7 @@ class Batch(object):
:param dataset: a DataSet object
:param batch_size: int, the size of the batch
:param sampler: a Sampler object
- :param use_cuda: bool, whetjher to use GPU
+ :param use_cuda: bool, whether to use GPU
"""
self.dataset = dataset
@@ -37,15 +37,12 @@ class Batch(object):
"""
:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
- batch_x also contains an item (str: list of int) about origin lengths,
- which means ("field_name_origin_len": origin lengths).
E.g.
::
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]})
batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
- The names of fields are defined in preprocessor's convert_to_dataset method.
"""
if self.curidx >= len(self.idx_list):
@@ -54,10 +51,9 @@ class Batch(object):
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
padding_length = {field_name: max(field_length[self.curidx: endidx])
for field_name, field_length in self.lengths.items()}
- origin_lengths = {field_name: field_length[self.curidx: endidx]
- for field_name, field_length in self.lengths.items()}
-
batch_x, batch_y = defaultdict(list), defaultdict(list)
+
+ # transform index to tensor and do padding for sequences
for idx in range(self.curidx, endidx):
x, y = self.dataset.to_tensor(idx, padding_length)
for name, tensor in x.items():
@@ -65,8 +61,7 @@ class Batch(object):
for name, tensor in y.items():
batch_y[name].append(tensor)
- batch_origin_length = {}
- # combine instances into a batch
+ # combine instances to form a batch
for batch in (batch_x, batch_y):
for name, tensor_list in batch.items():
if self.use_cuda:
@@ -74,14 +69,6 @@ class Batch(object):
else:
batch[name] = torch.stack(tensor_list, dim=0)
- # add origin lengths in batch_x
- for name, tensor in batch_x.items():
- if self.use_cuda:
- batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda()
- else:
- batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name])
- batch_x.update(batch_origin_length)
-
self.curidx = endidx
return batch_x, batch_y
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index bb1a1890..13370969 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -1,7 +1,12 @@
+import random
+import sys
from collections import defaultdict
+from copy import deepcopy
-from fastNLP.core.field import TextField
+from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.loader.dataset_loader import POSDataSetLoader, ClassDataSetLoader
def create_dataset_from_lists(str_lists: list, word_vocab: dict, has_target: bool = False, label_vocab: dict = None):
@@ -65,17 +70,19 @@ class DataSet(list):
"""A DataSet object is a list of Instance objects.
"""
- def __init__(self, name="", instances=None):
+
+ def __init__(self, name="", instances=None, load_func=None):
"""
:param name: str, the name of the dataset. (default: "")
:param instances: list of Instance objects. (default: None)
-
+ :param load_func: a function that takes the dataset path (string) as input and returns multi-level lists.
"""
list.__init__([])
self.name = name
if instances is not None:
self.extend(instances)
+ self.data_set_load_func = load_func
def index_all(self, vocab):
for ins in self:
@@ -109,3 +116,191 @@ class DataSet(list):
for field_name, field_length in ins.get_length().items():
lengths[field_name].append(field_length)
return lengths
+
+ def convert(self, data):
+ """Convert lists of strings into Instances with Fields, creating Vocabulary for labeled data. Used in Training."""
+ raise NotImplementedError
+
+ def convert_with_vocabs(self, data, vocabs):
+ """Convert lists of strings into Instances with Fields, using existing Vocabulary, with labels. Used in Testing."""
+ raise NotImplementedError
+
+ def convert_for_infer(self, data, vocabs):
+ """Convert lists of strings into Instances with Fields, using existing Vocabulary, without labels. Used in predicting."""
+
+ def load(self, data_path, vocabs=None, infer=False):
+ """Load data from the given files.
+
+ :param data_path: str, the path to the data
+        :param vocabs: dict of (name: Vocabulary object), used to index data. If not provided, a new vocabulary will be constructed.
+        :param infer: bool. If True, there is no label information in the data. Default: False.
+
+ """
+ raw_data = self.data_set_load_func(data_path)
+ if infer is True:
+ self.convert_for_infer(raw_data, vocabs)
+ else:
+ if vocabs is not None:
+ self.convert_with_vocabs(raw_data, vocabs)
+ else:
+ self.convert(raw_data)
+
+ def load_raw(self, raw_data, vocabs):
+ """Load raw data without loader. Used in FastNLP class.
+
+        :param raw_data: 2-D lists of strings, i.e. a list of tokenized sequences
+        :param vocabs: dict of (name: Vocabulary object), used to index the data
+        :return: None. The DataSet is populated in place.
+ """
+ self.convert_for_infer(raw_data, vocabs)
+
+ def split(self, ratio, shuffle=True):
+ """Train/dev splitting
+
+        :param ratio: float, between 0 and 1. The proportion of the original data set that goes into the development set.
+        :param shuffle: bool, whether to shuffle the data set before splitting. Default: True.
+ :return train_set: a DataSet object, representing the training set
+ dev_set: a DataSet object, representing the validation set
+
+ """
+ assert 0 < ratio < 1
+ if shuffle:
+ random.shuffle(self)
+ split_idx = int(len(self) * ratio)
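+        # the first split_idx instances become the dev set; the rest form the training set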
+ dev_set = deepcopy(self)
+ train_set = deepcopy(self)
+ del train_set[:split_idx]
+ del dev_set[split_idx:]
+ return train_set, dev_set
+
+
+class SeqLabelDataSet(DataSet):
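+    """A DataSet for sequence labeling, where each example pairs a word sequence with a label sequence of equal length."""
+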
+ def __init__(self, instances=None, load_func=POSDataSetLoader().load):
+ super(SeqLabelDataSet, self).__init__(name="", instances=instances, load_func=load_func)
+ self.word_vocab = Vocabulary()
+ self.label_vocab = Vocabulary()
+
+ def convert(self, data):
+ """Convert lists of strings into Instances with Fields.
+
+ :param data: 3-level lists. Entries are strings.
+ """
+ bar = ProgressBar(total=len(data))
+ for example in data:
+ word_seq, label_seq = example[0], example[1]
+ # list, list
+ self.word_vocab.update(word_seq)
+ self.label_vocab.update(label_seq)
+ x = TextField(word_seq, is_target=False)
+ x_len = LabelField(len(word_seq), is_target=False)
+ y = TextField(label_seq, is_target=False)
+ instance = Instance()
+ instance.add_field("word_seq", x)
+ instance.add_field("truth", y)
+ instance.add_field("word_seq_origin_len", x_len)
+ self.append(instance)
+ bar.move()
+ self.index_field("word_seq", self.word_vocab)
+ self.index_field("truth", self.label_vocab)
+ # no need to index "word_seq_origin_len"
+
+ def convert_with_vocabs(self, data, vocabs):
+ for example in data:
+ word_seq, label_seq = example[0], example[1]
+ # list, list
+ x = TextField(word_seq, is_target=False)
+ x_len = LabelField(len(word_seq), is_target=False)
+ y = TextField(label_seq, is_target=False)
+ instance = Instance()
+ instance.add_field("word_seq", x)
+ instance.add_field("truth", y)
+ instance.add_field("word_seq_origin_len", x_len)
+ self.append(instance)
+ self.index_field("word_seq", vocabs["word_vocab"])
+ self.index_field("truth", vocabs["label_vocab"])
+ # no need to index "word_seq_origin_len"
+
+ def convert_for_infer(self, data, vocabs):
+ for word_seq in data:
+ # list
+ x = TextField(word_seq, is_target=False)
+ x_len = LabelField(len(word_seq), is_target=False)
+ instance = Instance()
+ instance.add_field("word_seq", x)
+ instance.add_field("word_seq_origin_len", x_len)
+ self.append(instance)
+ self.index_field("word_seq", vocabs["word_vocab"])
+ # no need to index "word_seq_origin_len"
+
+
+class TextClassifyDataSet(DataSet):
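+    """A DataSet for text classification, where each example pairs a word sequence with a single label."""
+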
+ def __init__(self, instances=None, load_func=ClassDataSetLoader().load):
+ super(TextClassifyDataSet, self).__init__(name="", instances=instances, load_func=load_func)
+ self.word_vocab = Vocabulary()
+ self.label_vocab = Vocabulary(need_default=False)
+
+ def convert(self, data):
+ for example in data:
+ word_seq, label = example[0], example[1]
+ # list, str
+ self.word_vocab.update(word_seq)
+ self.label_vocab.update(label)
+ x = TextField(word_seq, is_target=False)
+ y = LabelField(label, is_target=True)
+ instance = Instance()
+ instance.add_field("word_seq", x)
+ instance.add_field("label", y)
+ self.append(instance)
+ self.index_field("word_seq", self.word_vocab)
+ self.index_field("label", self.label_vocab)
+
+ def convert_with_vocabs(self, data, vocabs):
+ for example in data:
+ word_seq, label = example[0], example[1]
+ # list, str
+ x = TextField(word_seq, is_target=False)
+ y = LabelField(label, is_target=True)
+ instance = Instance()
+ instance.add_field("word_seq", x)
+ instance.add_field("label", y)
+ self.append(instance)
+ self.index_field("word_seq", vocabs["word_vocab"])
+ self.index_field("label", vocabs["label_vocab"])
+
+ def convert_for_infer(self, data, vocabs):
+ for word_seq in data:
+ # list
+ x = TextField(word_seq, is_target=False)
+ instance = Instance()
+ instance.add_field("word_seq", x)
+ self.append(instance)
+ self.index_field("word_seq", vocabs["word_vocab"])
+
+
+def change_field_is_target(data_set, field_name, new_target):
+ """Change the flag of is_target in a field.
+
+ :param data_set: a DataSet object
+ :param field_name: str, the name of the field
+ :param new_target: one of (True, False, None), representing this field is batch_x / is batch_y / neither.
+
+ """
+ for inst in data_set:
+ inst.fields[field_name].is_target = new_target
+
+
+class ProgressBar:
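+    """A simple console progress bar, used to show progress while converting examples into Instances."""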
+
+ def __init__(self, count=0, total=0, width=100):
+ self.count = count
+ self.total = total
+ self.width = width
+
+ def move(self):
+ self.count += 1
+ progress = self.width * self.count // self.total
+ sys.stdout.write('{0:3}/{1:3}: '.format(self.count, self.total))
+ sys.stdout.write('#' * progress + '-' * (self.width - progress) + '\r')
+ if progress == self.width:
+ sys.stdout.write('\n')
+ sys.stdout.flush()
diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
index f5347bd6..b57b9bb6 100644
--- a/fastNLP/core/field.py
+++ b/fastNLP/core/field.py
@@ -59,6 +59,9 @@ class TextField(Field):
class LabelField(Field):
+ """The Field representing a single label. Can be a string or integer.
+
+ """
def __init__(self, label, is_target=True):
super(LabelField, self).__init__(is_target)
self.label = label
@@ -73,13 +76,14 @@ class LabelField(Field):
def index(self, vocab):
if self._index is None:
- self._index = vocab[self.label]
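+            # only string labels are looked up in the vocabulary; integer labels are used directly in to_tensor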
+ if isinstance(self.label, str):
+ self._index = vocab[self.label]
return self._index
def to_tensor(self, padding_length):
if self._index is None:
if isinstance(self.label, int):
- return torch.LongTensor([self.label])
+ return torch.tensor(self.label)
elif isinstance(self.label, str):
raise RuntimeError("Field {} not indexed. Call index method.".format(self.label))
else:
diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py
index 32f95197..ebf01912 100644
--- a/fastNLP/core/instance.py
+++ b/fastNLP/core/instance.py
@@ -46,8 +46,11 @@ class Instance(object):
tensor_x = {}
tensor_y = {}
for name, field in self.fields.items():
- if field.is_target:
+ if field.is_target is True:
tensor_y[name] = field.to_tensor(padding_length[name])
- else:
+ elif field.is_target is False:
tensor_x[name] = field.to_tensor(padding_length[name])
+ else:
+ # is_target is None
+ continue
return tensor_x, tensor_y
diff --git a/fastNLP/core/loss.py b/fastNLP/core/loss.py
index 8a0eedd7..16b5eac2 100644
--- a/fastNLP/core/loss.py
+++ b/fastNLP/core/loss.py
@@ -33,10 +33,25 @@ class Loss(object):
"""Given a name of a loss function, return it from PyTorch.
:param loss_name: str, the name of a loss function
+
+ - cross_entropy: combines log softmax and nll loss in a single function.
+ - nll: negative log likelihood
+
:return loss: a PyTorch loss
"""
+
+ class InnerCrossEntropy:
+ """A simple wrapper to guarantee input shapes."""
+
+ def __init__(self):
+ self.f = torch.nn.CrossEntropyLoss()
+
+ def __call__(self, predict, truth):
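+                # flatten the target tensor so CrossEntropyLoss receives a 1-D vector of class indices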
+ truth = truth.view(-1, )
+ return self.f(predict, truth)
+
if loss_name == "cross_entropy":
- return torch.nn.CrossEntropyLoss()
+ return InnerCrossEntropy()
elif loss_name == 'nll':
return torch.nn.NLLLoss()
else:
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 7bf4b034..6eedd214 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -4,6 +4,59 @@ import numpy as np
import torch
+class Evaluator(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, predict, truth):
+ """
+
+ :param predict: list of tensors, the network outputs from all batches.
+ :param truth: list of dict, the ground truths from all batch_y.
+ :return:
+ """
+ raise NotImplementedError
+
+
+class ClassifyEvaluator(Evaluator):
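+    """Compute classification accuracy: apply softmax to the logits, take the argmax, and compare with the true labels."""
+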
+ def __init__(self):
+ super(ClassifyEvaluator, self).__init__()
+
+ def __call__(self, predict, truth):
+ y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
+ y_prob = torch.cat(y_prob, dim=0)
+ y_pred = torch.argmax(y_prob, dim=-1)
+ y_true = torch.cat(truth, dim=0)
+ acc = float(torch.sum(y_pred == y_true)) / len(y_true)
+ return {"accuracy": acc}
+
+
+class SeqLabelEvaluator(Evaluator):
+ def __init__(self):
+ super(SeqLabelEvaluator, self).__init__()
+
+ def __call__(self, predict, truth):
+ """
+
+ :param predict: list of List, the network outputs from all batches.
+ :param truth: list of dict, the ground truths from all batch_y.
+ :return accuracy:
+ """
+ truth = [item["truth"] for item in truth]
+        total_correct, total_count = 0., 0.
+ for x, y in zip(predict, truth):
+ x = torch.Tensor(x)
+            y = y.to(x)  # make sure they are on the same device
+ mask = x.ge(1).float()
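+            # padding positions (predicted index 0) match trivially under the mask, so they are subtracted from the correct count below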
+ # correct = torch.sum(x * mask.float() == (y * mask.long()).float())
+ correct = torch.sum(x * mask == y * mask)
+ correct -= torch.sum(x.le(0))
+ total_correct += float(correct)
+ total_count += float(torch.sum(mask))
+ accuracy = total_correct / total_count
+ return {"accuracy": float(accuracy)}
+
+
def _conver_numpy(x):
"""convert input data to numpy array
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index 6bbb1bee..14c4e8c1 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -16,43 +16,42 @@ class Predictor(object):
Currently, Predictor does not support GPU.
"""
- def __init__(self, pickle_path, task):
+ def __init__(self, pickle_path, post_processor):
"""
:param pickle_path: str, the path to the pickle files.
- :param task: str, specify which task the predictor will perform. One of ("seq_label", "text_classify").
+        :param post_processor: a function or callable object that takes (a list of batch outputs, the label Vocabulary) and returns the final predictions
"""
self.batch_size = 1
self.batch_output = []
self.pickle_path = pickle_path
- self._task = task # one of ("seq_label", "text_classify")
- self.label_vocab = load_pickle(self.pickle_path, "class2id.pkl")
+ self._post_processor = post_processor
+ self.label_vocab = load_pickle(self.pickle_path, "label2id.pkl")
self.word_vocab = load_pickle(self.pickle_path, "word2id.pkl")
def predict(self, network, data):
"""Perform inference using the trained model.
:param network: a PyTorch model (cpu)
- :param data: list of list of strings, [num_examples, seq_len]
+ :param data: a DataSet object.
:return: list of list of strings, [num_examples, tag_seq_length]
"""
# transform strings into DataSet object
- data = self.prepare_input(data)
+ # data = self.prepare_input(data)
# turn on the testing mode; clean up the history
self.mode(network, test=True)
- self.batch_output.clear()
+ batch_output = []
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
for batch_x, _ in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)
+ batch_output.append(prediction)
- self.batch_output.append(prediction)
-
- return self.prepare_output(self.batch_output)
+ return self._post_processor(batch_output, self.label_vocab)
def mode(self, network, test=True):
if test:
@@ -62,13 +61,7 @@ class Predictor(object):
def data_forward(self, network, x):
"""Forward through network."""
- if self._task == "seq_label":
- y = network(x["word_seq"], x["word_seq_origin_len"])
- y = network.prediction(y)
- elif self._task == "text_classify":
- y = network(x["word_seq"])
- else:
- raise NotImplementedError("Unknown task type {}.".format(self._task))
+ y = network(**x)
return y
def prepare_input(self, data):
@@ -88,39 +81,32 @@ class Predictor(object):
assert isinstance(data, list)
return create_dataset_from_lists(data, self.word_vocab, has_target=False)
- def prepare_output(self, data):
- """Transform list of batch outputs into strings."""
- if self._task == "seq_label":
- return self._seq_label_prepare_output(data)
- elif self._task == "text_classify":
- return self._text_classify_prepare_output(data)
- else:
- raise NotImplementedError("Unknown task type {}".format(self._task))
-
- def _seq_label_prepare_output(self, batch_outputs):
- results = []
- for batch in batch_outputs:
- for example in np.array(batch):
- results.append([self.label_vocab.to_word(int(x)) for x in example])
- return results
-
- def _text_classify_prepare_output(self, batch_outputs):
- results = []
- for batch_out in batch_outputs:
- idx = np.argmax(batch_out.detach().numpy(), axis=-1)
- results.extend([self.label_vocab.to_word(i) for i in idx])
- return results
-
class SeqLabelInfer(Predictor):
def __init__(self, pickle_path):
print(
- "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor with argument 'task'='seq_label'.")
- super(SeqLabelInfer, self).__init__(pickle_path, "seq_label")
+ "[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
+ super(SeqLabelInfer, self).__init__(pickle_path, seq_label_post_processor)
class ClassificationInfer(Predictor):
def __init__(self, pickle_path):
print(
- "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor with argument 'task'='text_classify'.")
- super(ClassificationInfer, self).__init__(pickle_path, "text_classify")
+ "[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
+ super(ClassificationInfer, self).__init__(pickle_path, text_classify_post_processor)
+
+
+def seq_label_post_processor(batch_outputs, label_vocab):
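+    """Map each predicted index sequence back to a list of label strings, one list per example."""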
+ results = []
+ for batch in batch_outputs:
+ for example in np.array(batch):
+ results.append([label_vocab.to_word(int(x)) for x in example])
+ return results
+
+
+def text_classify_post_processor(batch_outputs, label_vocab):
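+    """Take the argmax over the class logits of every example and map the resulting indices to label strings."""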
+ results = []
+ for batch_out in batch_outputs:
+ idx = np.argmax(batch_out.detach().numpy(), axis=-1)
+ results.extend([label_vocab.to_word(i) for i in idx])
+ return results
diff --git a/fastNLP/core/preprocess.py b/fastNLP/core/preprocess.py
index 619d3967..5e77649e 100644
--- a/fastNLP/core/preprocess.py
+++ b/fastNLP/core/preprocess.py
@@ -18,6 +18,9 @@ def save_pickle(obj, pickle_path, file_name):
:param pickle_path: str, the directory where the pickle file is to be saved
:param file_name: str, the name of the pickle file. In general, it should be ended by "pkl".
"""
+ if not os.path.exists(pickle_path):
+ os.mkdir(pickle_path)
+ print("make dir {} before saving pickle file".format(pickle_path))
with open(os.path.join(pickle_path, file_name), "wb") as f:
_pickle.dump(obj, f)
print("{} saved in {}".format(file_name, pickle_path))
@@ -66,14 +69,27 @@ class Preprocessor(object):
Preprocessors will check if those files are already in the directory and will reuse them in future calls.
"""
- def __init__(self, label_is_seq=False):
+ def __init__(self, label_is_seq=False, share_vocab=False, add_char_field=False):
"""
:param label_is_seq: bool, whether label is a sequence. If True, label vocabulary will preserve
several special tokens for sequence processing.
+ :param share_vocab: bool, whether word sequence and label sequence share the same vocabulary. Typically, this
+ is only available when label_is_seq is True. Default: False.
+ :param add_char_field: bool, whether to add character representations to all TextFields. Default: False.
"""
+        print("Preprocessor is about to be deprecated. Please use the DataSet class.")
self.data_vocab = Vocabulary()
- self.label_vocab = Vocabulary(need_default=label_is_seq)
+ if label_is_seq is True:
+ if share_vocab is True:
+ self.label_vocab = self.data_vocab
+ else:
+ self.label_vocab = Vocabulary()
+ else:
+ self.label_vocab = Vocabulary(need_default=False)
+
+ self.character_vocab = Vocabulary(need_default=False)
+ self.add_char_field = add_char_field
@property
def vocab_size(self):
@@ -83,6 +99,12 @@ class Preprocessor(object):
def num_classes(self):
return len(self.label_vocab)
+ @property
+ def char_vocab_size(self):
+ if self.character_vocab is None:
+ self.build_char_dict()
+ return len(self.character_vocab)
+
def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
"""Main pre-processing pipeline.
@@ -96,7 +118,6 @@ class Preprocessor(object):
If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset
is a list of DataSet objects; Otherwise, each dataset is a DataSet object.
"""
-
if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"):
self.data_vocab = load_pickle(pickle_path, "word2id.pkl")
self.label_vocab = load_pickle(pickle_path, "class2id.pkl")
@@ -176,6 +197,16 @@ class Preprocessor(object):
self.label_vocab.update(label)
return self.data_vocab, self.label_vocab
+ def build_char_dict(self):
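+        """Collect every character that appears in the word vocabulary into the character vocabulary."""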
+ char_collection = set()
+ for word in self.data_vocab.word2idx:
+ if len(word) == 0:
+ continue
+ for ch in word:
+ if ch not in char_collection:
+ char_collection.add(ch)
+ self.character_vocab.update(list(char_collection))
+
def build_reverse_dict(self):
self.data_vocab.build_reverse_vocab()
self.label_vocab.build_reverse_vocab()
@@ -277,11 +308,3 @@ class ClassPreprocess(Preprocessor):
print("[FastNLP warning] ClassPreprocess is about to deprecate. Please use Preprocess directly.")
super(ClassPreprocess, self).__init__()
-
-if __name__ == "__main__":
- p = Preprocessor()
- train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"],
- [["You", "are", "pretty", "."], "1"]
- ]
- training_set = p.run(train_dev_data)
- print(training_set)
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 0a75f46a..0e74145b 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -1,7 +1,7 @@
-import numpy as np
import torch
from fastNLP.core.batch import Batch
+from fastNLP.core.metrics import Evaluator
from fastNLP.core.sampler import RandomSampler
from fastNLP.saver.logger import create_logger
@@ -22,28 +22,23 @@ class Tester(object):
"kwargs" must have the same type as "default_args" on corresponding keys.
Otherwise, error will raise.
"""
- default_args = {"save_output": True, # collect outputs of validation set
- "save_loss": True, # collect losses in validation
- "save_best_dev": False, # save best model during validation
- "batch_size": 8,
+ default_args = {"batch_size": 8,
"use_cuda": False,
"pickle_path": "./save/",
"model_name": "dev_best_model.pkl",
- "print_every_step": 1,
+ "evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
- required_args = {"task" # one of ("seq_label", "text_classify")
- }
+ required_args = {}
for req_key in required_args:
if req_key not in kwargs:
logger.error("Tester lacks argument {}".format(req_key))
raise ValueError("Tester lacks argument {}".format(req_key))
- self._task = kwargs["task"]
for key in default_args:
if key in kwargs:
@@ -59,17 +54,13 @@ class Tester(object):
pass
print(default_args)
- self.save_output = default_args["save_output"]
- self.save_best_dev = default_args["save_best_dev"]
- self.save_loss = default_args["save_loss"]
self.batch_size = default_args["batch_size"]
self.pickle_path = default_args["pickle_path"]
self.use_cuda = default_args["use_cuda"]
- self.print_every_step = default_args["print_every_step"]
+ self._evaluator = default_args["evaluator"]
self._model = None
self.eval_history = [] # evaluation results of all batches
- self.batch_output = [] # outputs of all batches
def test(self, network, dev_data):
if torch.cuda.is_available() and self.use_cuda:
@@ -80,26 +71,18 @@ class Tester(object):
# turn on the testing mode; clean up the history
self.mode(network, is_test=True)
self.eval_history.clear()
- self.batch_output.clear()
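+        # collect predictions and ground truths over all batches, then evaluate them in a single pass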
+ output_list = []
+ truth_list = []
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda)
- step = 0
for batch_x, batch_y in data_iterator:
with torch.no_grad():
prediction = self.data_forward(network, batch_x)
- eval_results = self.evaluate(prediction, batch_y)
-
- if self.save_output:
- self.batch_output.append(prediction)
- if self.save_loss:
- self.eval_history.append(eval_results)
-
- print_output = "[test step {}] {}".format(step, eval_results)
- logger.info(print_output)
- if self.print_every_step > 0 and step % self.print_every_step == 0:
- print(self.make_eval_output(prediction, eval_results))
- step += 1
+ output_list.append(prediction)
+ truth_list.append(batch_y)
+ eval_results = self.evaluate(output_list, truth_list)
+ print("[tester] {}".format(self.print_eval_results(eval_results)))
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
@@ -121,104 +104,30 @@ class Tester(object):
def evaluate(self, predict, truth):
"""Compute evaluation metrics.
- :param predict: Tensor
- :param truth: Tensor
+ :param predict: list of Tensor
+ :param truth: list of dict
:return eval_results: can be anything. It will be stored in self.eval_history
"""
- if "label_seq" in truth:
- truth = truth["label_seq"]
- elif "label" in truth:
- truth = truth["label"]
- else:
- raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
+ return self._evaluator(predict, truth)
- if self._task == "seq_label":
- return self._seq_label_evaluate(predict, truth)
- elif self._task == "text_classify":
- return self._text_classify_evaluate(predict, truth)
- else:
- raise NotImplementedError("Unknown task type {}.".format(self._task))
-
- def _seq_label_evaluate(self, predict, truth):
- batch_size, max_len = predict.size(0), predict.size(1)
- loss = self._model.loss(predict, truth) / batch_size
- prediction = self._model.prediction(predict)
- # pad prediction to equal length
- for pred in prediction:
- if len(pred) < max_len:
- pred += [0] * (max_len - len(pred))
- results = torch.Tensor(prediction).view(-1, )
-
- # make sure "results" is in the same device as "truth"
- results = results.to(truth)
- accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
- return [float(loss), float(accuracy)]
-
- def _text_classify_evaluate(self, y_logit, y_true):
- y_prob = torch.nn.functional.softmax(y_logit, dim=-1)
- return [y_prob, y_true]
-
- @property
- def metrics(self):
- """Compute and return metrics.
- Use self.eval_history to compute metrics over the whole dev set.
- Please refer to metrics.py for common metric functions.
-
- :return : variable number of outputs
- """
- if self._task == "seq_label":
- return self._seq_label_metrics
- elif self._task == "text_classify":
- return self._text_classify_metrics
- else:
- raise NotImplementedError("Unknown task type {}.".format(self._task))
-
- @property
- def _seq_label_metrics(self):
- batch_loss = np.mean([x[0] for x in self.eval_history])
- batch_accuracy = np.mean([x[1] for x in self.eval_history])
- return batch_loss, batch_accuracy
-
- @property
- def _text_classify_metrics(self):
- y_prob, y_true = zip(*self.eval_history)
- y_prob = torch.cat(y_prob, dim=0)
- y_pred = torch.argmax(y_prob, dim=-1)
- y_true = torch.cat(y_true, dim=0)
- acc = float(torch.sum(y_pred == y_true)) / len(y_true)
- return y_true.cpu().numpy(), y_prob.cpu().numpy(), acc
-
- def show_metrics(self):
- """Customize evaluation outputs in Trainer.
- Called by Trainer to print evaluation results on dev set during training.
- Use self.metrics to fetch available metrics.
-
- :return print_str: str
- """
- loss, accuracy = self.metrics
- return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy)
+ def print_eval_results(self, results):
+ """Override this method to support more print formats.
- def make_eval_output(self, predictions, eval_results):
- """Customize Tester outputs.
+        :param results: dict of (str: float), i.e. (metric name: value)
- :param predictions: Tensor
- :param eval_results: Tensor
- :return: str, to be printed.
"""
- return self.show_metrics()
+ return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
class SeqLabelTester(Tester):
def __init__(self, **test_args):
- test_args.update({"task": "seq_label"})
print(
- "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.")
+ "[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.")
super(SeqLabelTester, self).__init__(**test_args)
class ClassificationTester(Tester):
def __init__(self, **test_args):
- test_args.update({"task": "text_classify"})
print(
- "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester with argument 'task'='text_classify'.")
+ "[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
super(ClassificationTester, self).__init__(**test_args)
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index a73229b2..957a4757 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -1,4 +1,3 @@
-import copy
import os
import time
from datetime import timedelta
@@ -8,6 +7,7 @@ from tensorboardX import SummaryWriter
from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
+from fastNLP.core.metrics import Evaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.tester import SeqLabelTester, ClassificationTester
@@ -43,21 +43,20 @@ class Trainer(object):
default_args = {"epochs": 1, "batch_size": 2, "validate": False, "use_cuda": False, "pickle_path": "./save/",
"save_best_dev": False, "model_name": "default_model_name.pkl", "print_every_step": 1,
"loss": Loss(None), # used to pass type check
- "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0)
+ "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0),
+ "evaluator": Evaluator()
}
"""
"required_args" is the collection of arguments that users must pass to Trainer explicitly.
This is used to warn users of essential settings in the training.
Specially, "required_args" does not have default value, so they have nothing to do with "default_args".
"""
- required_args = {"task" # one of ("seq_label", "text_classify")
- }
+ required_args = {}
for req_key in required_args:
if req_key not in kwargs:
logger.error("Trainer lacks argument {}".format(req_key))
raise ValueError("Trainer lacks argument {}".format(req_key))
- self._task = kwargs["task"]
for key in default_args:
if key in kwargs:
@@ -86,6 +85,7 @@ class Trainer(object):
self._loss_func = default_args["loss"].get() # return a pytorch loss function or None
self._optimizer = None
self._optimizer_proto = default_args["optimizer"]
+ self._evaluator = default_args["evaluator"]
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs')
self._graph_summaried = False
self._best_accuracy = 0.0
@@ -106,9 +106,8 @@ class Trainer(object):
# define Tester over dev data
if self.validate:
- default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
- "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
- "use_cuda": self.use_cuda, "print_every_step": 0}
+ default_valid_args = {"batch_size": self.batch_size, "pickle_path": self.pickle_path,
+ "use_cuda": self.use_cuda, "evaluator": self._evaluator}
validator = self._create_validator(default_valid_args)
logger.info("validator defined as {}".format(str(validator)))
@@ -142,15 +141,6 @@ class Trainer(object):
logger.info("validation started")
validator.test(network, dev_data)
- if self.save_best_dev and self.best_eval_result(validator):
- self.save_model(network, self.model_name)
- print("Saved better model selected by validation.")
- logger.info("Saved better model selected by validation.")
-
- valid_results = validator.show_metrics()
- print("[epoch {}] {}".format(epoch, valid_results))
- logger.info("[epoch {}] {}".format(epoch, valid_results))
-
def _train_step(self, data_iterator, network, **kwargs):
"""Training process in one epoch.
@@ -178,31 +168,6 @@ class Trainer(object):
logger.info(print_output)
step += 1
- def cross_validate(self, network, train_data_cv, dev_data_cv):
- """Training with cross validation.
-
- :param network: the model
- :param train_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
- :param dev_data_cv: four-level list, of shape [num_folds, num_examples, 2, ?]
-
- """
- if len(train_data_cv) != len(dev_data_cv):
- logger.error("the number of folds in train and dev data unequals {}!={}".format(len(train_data_cv),
- len(dev_data_cv)))
- raise RuntimeError("the number of folds in train and dev data unequals")
- if self.validate is False:
- logger.warn("Cross validation requires self.validate to be True. Please turn it on. ")
- print("[warning] Cross validation requires self.validate to be True. Please turn it on. ")
- self.validate = True
-
- n_fold = len(train_data_cv)
- logger.info("perform {} folds cross validation.".format(n_fold))
- for i in range(n_fold):
- print("CV:", i)
- logger.info("running the {} of {} folds cross validation".format(i + 1, n_fold))
- network_copy = copy.deepcopy(network)
- self.train(network_copy, train_data_cv[i], dev_data_cv[i])
-
def mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
@@ -229,18 +194,9 @@ class Trainer(object):
self._optimizer.step()
def data_forward(self, network, x):
- if self._task == "seq_label":
- y = network(x["word_seq"], x["word_seq_origin_len"])
- elif self._task == "text_classify":
- y = network(x["word_seq"])
- else:
- raise NotImplementedError("Unknown task type {}.".format(self._task))
-
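+        # unpack batch_x so that every field is fed to the network as a keyword argument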
+ y = network(**x)
if not self._graph_summaried:
- if self._task == "seq_label":
- self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False)
- elif self._task == "text_classify":
- self._summary_writer.add_graph(network, x["word_seq"], verbose=False)
+ # self._summary_writer.add_graph(network, x, verbose=False)
self._graph_summaried = True
return y
@@ -261,13 +217,9 @@ class Trainer(object):
:param truth: ground truth label vector
:return: a scalar
"""
- if "label_seq" in truth:
- truth = truth["label_seq"]
- elif "label" in truth:
- truth = truth["label"]
- truth = truth.view((-1,))
- else:
- raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys()))
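+        # batch_y is expected to contain exactly one target field; its tensor is used as the ground truth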
+ if len(truth) > 1:
+ raise NotImplementedError("Not ready to handle multi-labels.")
+ truth = list(truth.values())[0] if len(truth) > 0 else None
return self._loss_func(predict, truth)
def define_loss(self):
@@ -278,8 +230,8 @@ class Trainer(object):
These two losses cannot be defined at the same time.
Trainer does not handle loss definition or choose default losses.
"""
- if hasattr(self._model, "loss") and self._loss_func is not None:
- raise ValueError("Both the model and Trainer define loss. Please take out your loss.")
+ # if hasattr(self._model, "loss") and self._loss_func is not None:
+ # raise ValueError("Both the model and Trainer define loss. Please take out your loss.")
if hasattr(self._model, "loss"):
self._loss_func = self._model.loss
@@ -322,9 +274,8 @@ class SeqLabelTrainer(Trainer):
"""
def __init__(self, **kwargs):
- kwargs.update({"task": "seq_label"})
print(
- "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.")
+ "[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.")
super(SeqLabelTrainer, self).__init__(**kwargs)
def _create_validator(self, valid_args):
@@ -335,9 +286,8 @@ class ClassificationTrainer(Trainer):
"""Trainer for text classification."""
def __init__(self, **train_args):
- train_args.update({"task": "text_classify"})
print(
- "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer with argument 'task'='text_classify'.")
+ "[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.")
super(ClassificationTrainer, self).__init__(**train_args)
def _create_validator(self, valid_args):
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index ad618ff9..08c00644 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -10,13 +10,15 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}
+
def isiterable(p_object):
try:
it = iter(p_object)
- except TypeError:
+ except TypeError:
return False
return True
+
class Vocabulary(object):
"""Use for word and index one to one mapping
@@ -28,9 +30,11 @@ class Vocabulary(object):
vocab["word"]
vocab.to_word(5)
"""
+
def __init__(self, need_default=True):
"""
- :param bool need_default: set if the Vocabulary has default labels reserved.
+ :param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.
+
"""
if need_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
@@ -50,20 +54,19 @@ class Vocabulary(object):
def update(self, word):
"""add word or list of words into Vocabulary
- :param word: a list of str or str
+        :param word: a list of strings or a single string
"""
if not isinstance(word, str) and isiterable(word):
- # it's a nested list
+ # it's a nested list
for w in word:
self.update(w)
else:
- # it's a word to be added
+ # it's a word to be added
if word not in self.word2idx:
self.word2idx[word] = len(self)
if self.idx2word is not None:
self.idx2word = None
-
def __getitem__(self, w):
"""To support usage like::
@@ -81,12 +84,12 @@ class Vocabulary(object):
:param str w:
"""
return self[w]
-
+
def unknown_idx(self):
- if self.unknown_label is None:
+ if self.unknown_label is None:
return None
return self.word2idx[self.unknown_label]
-
+
def padding_idx(self):
if self.padding_label is None:
return None
@@ -95,8 +98,8 @@ class Vocabulary(object):
def build_reverse_vocab(self):
"""build 'index to word' dict based on 'word to index' dict
"""
- self.idx2word = {self.word2idx[w] : w for w in self.word2idx}
-
+ self.idx2word = {self.word2idx[w]: w for w in self.word2idx}
+
def to_word(self, idx):
"""given a word's index, return the word itself
@@ -105,7 +108,7 @@ class Vocabulary(object):
if self.idx2word is None:
self.build_reverse_vocab()
return self.idx2word[idx]
-
+
def __getstate__(self):
"""use to prepare data for pickle
"""
@@ -113,12 +116,9 @@ class Vocabulary(object):
# no need to pickle idx2word as it can be constructed from word2idx
del state['idx2word']
return state
-
+
def __setstate__(self, state):
"""use to restore state from pickle
"""
self.__dict__.update(state)
self.idx2word = None
-
-
-
\ No newline at end of file
diff --git a/fastNLP/fastnlp.py b/fastNLP/fastnlp.py
index 4643c247..0bd56d18 100644
--- a/fastNLP/fastnlp.py
+++ b/fastNLP/fastnlp.py
@@ -1,5 +1,6 @@
import os
+from fastNLP.core.dataset import SeqLabelDataSet, TextClassifyDataSet
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
@@ -71,11 +72,13 @@ class FastNLP(object):
:param model_dir: this directory should contain the following files:
1. a trained model
2. a config file, which is a fastNLP's configuration.
- 3. a Vocab file, which is a pickle object of a Vocab instance.
+ 3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs.
"""
self.model_dir = model_dir
self.model = None
self.infer_type = None # "seq_label"/"text_class"
+ self.word_vocab = None
+ self.label_vocab = None
def load(self, model_name, config_file="config", section_name="model"):
"""
@@ -100,10 +103,10 @@ class FastNLP(object):
print("Restore model hyper-parameters {}".format(str(model_args.data)))
# fetch dictionary size and number of labels from pickle files
- word_vocab = load_pickle(self.model_dir, "word2id.pkl")
- model_args["vocab_size"] = len(word_vocab)
- label_vocab = load_pickle(self.model_dir, "class2id.pkl")
- model_args["num_classes"] = len(label_vocab)
+ self.word_vocab = load_pickle(self.model_dir, "word2id.pkl")
+ model_args["vocab_size"] = len(self.word_vocab)
+ self.label_vocab = load_pickle(self.model_dir, "label2id.pkl")
+ model_args["num_classes"] = len(self.label_vocab)
# Construct the model
model = model_class(model_args)
@@ -130,8 +133,11 @@ class FastNLP(object):
# tokenize: list of string ---> 2-D list of string
infer_input = self.tokenize(raw_input, language="zh")
- # 2-D list of string ---> 2-D list of tags
- results = infer.predict(self.model, infer_input)
+ # create DataSet: 2-D list of strings ----> DataSet
+ infer_data = self._create_data_set(infer_input)
+
+ # DataSet ---> 2-D list of tags
+ results = infer.predict(self.model, infer_data)
# 2-D list of tags ---> list of final answers
outputs = self._make_output(results, infer_input)
@@ -154,6 +160,11 @@ class FastNLP(object):
return module
def _create_inference(self, model_dir):
+ """Specify which task to perform.
+
+ :param model_dir:
+ :return:
+ """
if self.infer_type == "seq_label":
return SeqLabelInfer(model_dir)
elif self.infer_type == "text_class":
@@ -161,8 +172,26 @@ class FastNLP(object):
else:
raise ValueError("fail to create inference instance")
+ def _create_data_set(self, infer_input):
+ """Create a DataSet object given the raw inputs.
+
+ :param infer_input: 2-D lists of strings
+ :return data_set: a DataSet object
+ """
+ if self.infer_type == "seq_label":
+ data_set = SeqLabelDataSet()
+ data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
+ return data_set
+ elif self.infer_type == "text_class":
+ data_set = TextClassifyDataSet()
+ data_set.load_raw(infer_input, {"word_vocab": self.word_vocab})
+ return data_set
+ else:
+            raise RuntimeError("fail to create data set with infer type {}".format(self.infer_type))
+
def _load(self, model_dir, model_name):
- # To do
+        # TODO: load the model from local files
return 0
def _download(self, model_name, url):
@@ -172,7 +201,7 @@ class FastNLP(object):
:param url:
"""
print("Downloading {} from {}".format(model_name, url))
- # To do
+ # TODO: download model via url
def model_exist(self, model_dir):
"""
diff --git a/fastNLP/loader/base_loader.py b/fastNLP/loader/base_loader.py
index 808567fb..fc2814c8 100644
--- a/fastNLP/loader/base_loader.py
+++ b/fastNLP/loader/base_loader.py
@@ -1,27 +1,24 @@
class BaseLoader(object):
- """docstring for BaseLoader"""
- def __init__(self, data_path):
+ def __init__(self):
super(BaseLoader, self).__init__()
- self.data_path = data_path
-
- def load(self):
- """
- :return: string
- """
- with open(self.data_path, "r", encoding="utf-8") as f:
- text = f.read()
- return text
- def load_lines(self):
- with open(self.data_path, "r", encoding="utf=8") as f:
+ @staticmethod
+ def load_lines(data_path):
+        with open(data_path, "r", encoding="utf-8") as f:
text = f.readlines()
return [line.strip() for line in text]
+ @staticmethod
+ def load(data_path):
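+        """Read a text file and return a 2-D list with one list of characters per line."""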
+ with open(data_path, "r", encoding="utf-8") as f:
+ text = f.readlines()
+ return [[word for word in sent.strip()] for sent in text]
+
class ToyLoader0(BaseLoader):
"""
- For charLM
+ For CharLM
"""
def __init__(self, data_path):
diff --git a/fastNLP/loader/config_loader.py b/fastNLP/loader/config_loader.py
index 94871222..9818d411 100644
--- a/fastNLP/loader/config_loader.py
+++ b/fastNLP/loader/config_loader.py
@@ -8,9 +8,9 @@ from fastNLP.loader.base_loader import BaseLoader
class ConfigLoader(BaseLoader):
"""loader for configuration files"""
- def __int__(self, data_name, data_path):
- super(ConfigLoader, self).__init__(data_path)
- self.config = self.parse(super(ConfigLoader, self).load())
+ def __int__(self, data_path):
+ super(ConfigLoader, self).__init__()
+ self.config = self.parse(super(ConfigLoader, self).load(data_path))
@staticmethod
def parse(string):
diff --git a/fastNLP/loader/dataset_loader.py b/fastNLP/loader/dataset_loader.py
index 72da209c..a6a0fb77 100644
--- a/fastNLP/loader/dataset_loader.py
+++ b/fastNLP/loader/dataset_loader.py
@@ -3,14 +3,17 @@ import os
from fastNLP.loader.base_loader import BaseLoader
-class DatasetLoader(BaseLoader):
+class DataSetLoader(BaseLoader):
""""loader for data sets"""
- def __init__(self, data_path):
- super(DatasetLoader, self).__init__(data_path)
+ def __init__(self):
+ super(DataSetLoader, self).__init__()
+ def load(self, path):
+ raise NotImplementedError
-class POSDatasetLoader(DatasetLoader):
+
+class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets.
In these datasets, each line are divided by '\t'
@@ -31,16 +34,10 @@ class POSDatasetLoader(DatasetLoader):
to label5.
"""
- def __init__(self, data_path):
- super(POSDatasetLoader, self).__init__(data_path)
-
- def load(self):
- assert os.path.exists(self.data_path)
- with open(self.data_path, "r", encoding="utf-8") as f:
- line = f.read()
- return line
+ def __init__(self):
+ super(POSDataSetLoader, self).__init__()
- def load_lines(self):
+ def load(self, data_path):
"""
:return data: three-level list
[
@@ -49,7 +46,7 @@ class POSDatasetLoader(DatasetLoader):
...
]
"""
- with open(self.data_path, "r", encoding="utf-8") as f:
+ with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
return self.parse(lines)
@@ -79,15 +76,16 @@ class POSDatasetLoader(DatasetLoader):
return data
-class TokenizeDatasetLoader(DatasetLoader):
+class TokenizeDataSetLoader(DataSetLoader):
"""
Data set loader for tokenization data sets
"""
- def __init__(self, data_path):
- super(TokenizeDatasetLoader, self).__init__(data_path)
+ def __init__(self):
+ super(TokenizeDataSetLoader, self).__init__()
- def load_pku(self, max_seq_len=32):
+ @staticmethod
+ def load(data_path, max_seq_len=32):
"""
load pku dataset for Chinese word segmentation
CWS (Chinese Word Segmentation) pku training dataset format:
@@ -104,7 +102,7 @@ class TokenizeDatasetLoader(DatasetLoader):
:return: three-level lists
"""
assert isinstance(max_seq_len, int) and max_seq_len > 0
- with open(self.data_path, "r", encoding="utf-8") as f:
+ with open(data_path, "r", encoding="utf-8") as f:
sentences = f.readlines()
data = []
for sent in sentences:
@@ -135,15 +133,15 @@ class TokenizeDatasetLoader(DatasetLoader):
return data
-class ClassDatasetLoader(DatasetLoader):
+class ClassDataSetLoader(DataSetLoader):
"""Loader for classification data sets"""
- def __init__(self, data_path):
- super(ClassDatasetLoader, self).__init__(data_path)
+ def __init__(self):
+ super(ClassDataSetLoader, self).__init__()
- def load(self):
- assert os.path.exists(self.data_path)
- with open(self.data_path, "r", encoding="utf-8") as f:
+ def load(self, data_path):
+ assert os.path.exists(data_path)
+ with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
return self.parse(lines)
@@ -169,21 +167,21 @@ class ClassDatasetLoader(DatasetLoader):
return dataset
-class ConllLoader(DatasetLoader):
+class ConllLoader(DataSetLoader):
"""loader for conll format files"""
def __int__(self, data_path):
"""
:param str data_path: the path to the conll data set
"""
- super(ConllLoader, self).__init__(data_path)
- self.data_set = self.parse(self.load())
+ super(ConllLoader, self).__init__()
+ self.data_set = self.parse(self.load(data_path))
- def load(self):
+ def load(self, data_path):
"""
:return: list lines: all lines in a conll file
"""
- with open(self.data_path, "r", encoding="utf-8") as f:
+ with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
return lines
@@ -207,28 +205,48 @@ class ConllLoader(DatasetLoader):
return sentences
-class LMDatasetLoader(DatasetLoader):
- def __init__(self, data_path):
- super(LMDatasetLoader, self).__init__(data_path)
+class LMDataSetLoader(DataSetLoader):
+ """Language Model Dataset Loader
- def load(self):
- if not os.path.exists(self.data_path):
- raise FileNotFoundError("file {} not found.".format(self.data_path))
- with open(self.data_path, "r", encoding="utf=8") as f:
- text = " ".join(f.readlines())
- return text.strip().split()
+    This loader produces data for language model training in a supervised way,
+    i.e. each example pairs an input token sequence (X) with the target sequence (Y) shifted by one position.
+ """
+
+ def __init__(self):
+ super(LMDataSetLoader, self).__init__()
-class PeopleDailyCorpusLoader(DatasetLoader):
+ def load(self, data_path):
+ if not os.path.exists(data_path):
+ raise FileNotFoundError("file {} not found.".format(data_path))
+        with open(data_path, "r", encoding="utf-8") as f:
+ text = " ".join(f.readlines())
+ tokens = text.strip().split()
+ return self.sentence_cut(tokens)
+
+ def sentence_cut(self, tokens, sentence_length=15):
+        data_set = []
+        for idx in range(len(tokens) // sentence_length):
+            start_idx = idx * sentence_length
+            x = tokens[start_idx: start_idx + sentence_length]
+            y = tokens[start_idx + 1: start_idx + sentence_length + 1]
+            if start_idx + sentence_length + 1 >= len(tokens):
+                # pad the last target so that x and y have the same length
+                y.extend([""])
+            data_set.append([x, y])
+        return data_set
+
+
+class PeopleDailyCorpusLoader(DataSetLoader):
"""
People Daily Corpus: Chinese word segmentation, POS tag, NER
"""
- def __init__(self, data_path):
- super(PeopleDailyCorpusLoader, self).__init__(data_path)
+ def __init__(self):
+ super(PeopleDailyCorpusLoader, self).__init__()
- def load(self):
- with open(self.data_path, "r", encoding="utf-8") as f:
+ def load(self, data_path):
+ with open(data_path, "r", encoding="utf-8") as f:
sents = f.readlines()
pos_tag_examples = []
diff --git a/fastNLP/models/char_language_model.py b/fastNLP/models/char_language_model.py
index f678070e..2ad49abe 100644
--- a/fastNLP/models/char_language_model.py
+++ b/fastNLP/models/char_language_model.py
@@ -1,215 +1,8 @@
-import os
-
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.optim as optim
-from torch.autograd import Variable
-
-from fastNLP.models.base_model import BaseModel
-
-USE_GPU = True
-
-"""
- To be deprecated.
-"""
-
-
-class CharLM(BaseModel):
- """
- Controller of the Character-level Neural Language Model
- """
- def __init__(self, lstm_batch_size, lstm_seq_len):
- super(CharLM, self).__init__()
- """
- Settings: should come from config loader or pre-processing
- """
- self.word_embed_dim = 300
- self.char_embedding_dim = 15
- self.cnn_batch_size = lstm_batch_size * lstm_seq_len
- self.lstm_seq_len = lstm_seq_len
- self.lstm_batch_size = lstm_batch_size
- self.num_epoch = 10
- self.old_PPL = 100000
- self.best_PPL = 100000
-
- """
- These parameters are set by pre-processing.
- """
- self.max_word_len = None
- self.num_char = None
- self.vocab_size = None
- self.preprocess("./data_for_tests/charlm.txt")
-
- self.data = None # named tuple to store all data set
- self.data_ready = False
- self.criterion = nn.CrossEntropyLoss()
- self._loss = None
- self.use_gpu = USE_GPU
-
- # word_emb_dim == hidden_size / num of hidden units
- self.hidden = (to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim)),
- to_var(torch.zeros(2, self.lstm_batch_size, self.word_embed_dim)))
-
- self.model = charLM(self.char_embedding_dim,
- self.word_embed_dim,
- self.vocab_size,
- self.num_char,
- use_gpu=self.use_gpu)
- for param in self.model.parameters():
- nn.init.uniform(param.data, -0.05, 0.05)
-
- self.learning_rate = 0.1
- self.optimizer = None
-
- def prepare_input(self, raw_text):
- """
- :param raw_text: raw input text consisting of words
- :return: torch.Tensor, torch.Tensor
- feature matrix, label vector
- This function is only called once in Trainer.train, but may called multiple times in Tester.test
- So Tester will save test input for frequent calls.
- """
- if os.path.exists("cache/prep.pt") is False:
- self.preprocess("./data_for_tests/charlm.txt") # To do: This is not good. Need to fix..
- objects = torch.load("cache/prep.pt")
- word_dict = objects["word_dict"]
- char_dict = objects["char_dict"]
- max_word_len = self.max_word_len
- print("word/char dictionary built. Start making inputs.")
-
- words = raw_text
- input_vec = np.array(text2vec(words, char_dict, max_word_len))
- # Labels are next-word index in word_dict with the same length as inputs
- input_label = np.array([word_dict[w] for w in words[1:]] + [word_dict[words[-1]]])
- feature_input = torch.from_numpy(input_vec)
- label_input = torch.from_numpy(input_label)
- return feature_input, label_input
-
- def mode(self, test=False):
- if test:
- self.model.eval()
- else:
- self.model.train()
-
- def data_forward(self, x):
- """
- :param x: Tensor of size [lstm_batch_size, lstm_seq_len, max_word_len+2]
- :return: Tensor of size [num_words, ?]
- """
- # additional processing of inputs after batching
- num_seq = x.size()[0] // self.lstm_seq_len
- x = x[:num_seq * self.lstm_seq_len, :]
- x = x.view(-1, self.lstm_seq_len, self.max_word_len + 2)
-
- # detach hidden state of LSTM from last batch
- hidden = [state.detach() for state in self.hidden]
- output, self.hidden = self.model(to_var(x), hidden)
- return output
-
- def grad_backward(self):
- self.model.zero_grad()
- self._loss.backward()
- torch.nn.utils.clip_grad_norm(self.model.parameters(), 5, norm_type=2)
- self.optimizer.step()
-
- def get_loss(self, predict, truth):
- self._loss = self.criterion(predict, to_var(truth))
- return self._loss.data # No pytorch data structure exposed outsides
-
- def define_optimizer(self):
- # redefine optimizer for every new epoch
- self.optimizer = optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.85)
- def save(self):
- print("network saved")
- # torch.save(self.models, "cache/models.pkl")
-
- def preprocess(self, all_text_files):
- word_dict, char_dict = create_word_char_dict(all_text_files)
- num_char = len(char_dict)
- self.vocab_size = len(word_dict)
- char_dict["BOW"] = num_char + 1
- char_dict["EOW"] = num_char + 2
- char_dict["PAD"] = 0
- self.num_char = num_char + 3
- # char_dict is a dict of (int, string), int counting from 0 to 47
- reverse_word_dict = {value: key for key, value in word_dict.items()}
- self.max_word_len = max([len(word) for word in word_dict])
- objects = {
- "word_dict": word_dict,
- "char_dict": char_dict,
- "reverse_word_dict": reverse_word_dict,
- }
- if not os.path.exists("cache"):
- os.mkdir("cache")
- torch.save(objects, "cache/prep.pt")
- print("Preprocess done.")
-
-
-"""
- Global Functions
-"""
-
-
-def batch_generator(x, batch_size):
- # x: [num_words, in_channel, height, width]
- # partitions x into batches
- num_step = x.size()[0] // batch_size
- for t in range(num_step):
- yield x[t * batch_size:(t + 1) * batch_size]
-
-
-def text2vec(words, char_dict, max_word_len):
- """ Return list of list of int """
- word_vec = []
- for word in words:
- vec = [char_dict[ch] for ch in word]
- if len(vec) < max_word_len:
- vec += [char_dict["PAD"] for _ in range(max_word_len - len(vec))]
- vec = [char_dict["BOW"]] + vec + [char_dict["EOW"]]
- word_vec.append(vec)
- return word_vec
-
-
-def read_data(file_name):
- with open(file_name, 'r') as f:
- corpus = f.read().lower()
- import re
-    corpus = re.sub(r"<unk>", "unk", corpus)
- return corpus.split()
-
-
-def get_char_dict(vocabulary):
- char_dict = dict()
- count = 1
- for word in vocabulary:
- for ch in word:
- if ch not in char_dict:
- char_dict[ch] = count
- count += 1
- return char_dict
-
-
-def create_word_char_dict(*file_name):
- text = []
- for file in file_name:
- text += read_data(file)
- word_dict = {word: ix for ix, word in enumerate(set(text))}
- char_dict = get_char_dict(word_dict)
- return word_dict, char_dict
-
-
-def to_var(x):
- if torch.cuda.is_available() and USE_GPU:
- x = x.cuda()
- return Variable(x)
-
-
-"""
- Neural Network
-"""
+from fastNLP.modules.encoder.lstm import LSTM
class Highway(nn.Module):
@@ -225,9 +18,8 @@ class Highway(nn.Module):
return torch.mul(t, F.relu(self.fc2(x))) + torch.mul(1 - t, x)
-class charLM(nn.Module):
- """Character-level Neural Language Model
- CNN + highway network + LSTM
+class CharLM(nn.Module):
+ """CNN + highway network + LSTM
# Input:
4D tensor with shape [batch_size, in_channel, height, width]
# Output:
@@ -241,8 +33,8 @@ class charLM(nn.Module):
"""
def __init__(self, char_emb_dim, word_emb_dim,
- vocab_size, num_char, use_gpu):
- super(charLM, self).__init__()
+ vocab_size, num_char):
+ super(CharLM, self).__init__()
self.char_emb_dim = char_emb_dim
self.word_emb_dim = word_emb_dim
self.vocab_size = vocab_size
@@ -254,8 +46,7 @@ class charLM(nn.Module):
self.convolutions = []
# list of tuples: (the number of filter, width)
- # self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)]
- self.filter_num_width = [(25, 1), (50, 2), (75, 3)]
+ self.filter_num_width = [(25, 1), (50, 2), (75, 3), (100, 4), (125, 5), (150, 6)]
for out_channel, filter_width in self.filter_num_width:
self.convolutions.append(
@@ -278,29 +69,13 @@ class charLM(nn.Module):
# LSTM
self.lstm_num_layers = 2
- self.lstm = nn.LSTM(input_size=self.highway_input_dim,
- hidden_size=self.word_emb_dim,
- num_layers=self.lstm_num_layers,
- bias=True,
- dropout=0.5,
- batch_first=True)
-
+ self.lstm = LSTM(self.highway_input_dim, hidden_size=self.word_emb_dim, num_layers=self.lstm_num_layers,
+ dropout=0.5)
# output layer
self.dropout = nn.Dropout(p=0.5)
self.linear = nn.Linear(self.word_emb_dim, self.vocab_size)
- if use_gpu is True:
- for x in range(len(self.convolutions)):
- self.convolutions[x] = self.convolutions[x].cuda()
- self.highway1 = self.highway1.cuda()
- self.highway2 = self.highway2.cuda()
- self.lstm = self.lstm.cuda()
- self.dropout = self.dropout.cuda()
- self.char_embed = self.char_embed.cuda()
- self.linear = self.linear.cuda()
- self.batch_norm = self.batch_norm.cuda()
-
- def forward(self, x, hidden):
+ def forward(self, x):
# Input: Variable of Tensor with shape [num_seq, seq_len, max_word_len+2]
# Return: Variable of Tensor with shape [num_words, len(word_dict)]
lstm_batch_size = x.size()[0]
@@ -313,7 +88,7 @@ class charLM(nn.Module):
# [num_seq*seq_len, max_word_len+2, char_emb_dim]
x = torch.transpose(x.view(x.size()[0], 1, x.size()[1], -1), 2, 3)
- # [num_seq*seq_len, 1, char_emb_dim, max_word_len+2]
+ # [num_seq*seq_len, 1, max_word_len+2, char_emb_dim]
x = self.conv_layers(x)
# [num_seq*seq_len, total_num_filters]
@@ -328,7 +103,7 @@ class charLM(nn.Module):
x = x.contiguous().view(lstm_batch_size, lstm_seq_len, -1)
# [num_seq, seq_len, total_num_filters]
- x, hidden = self.lstm(x, hidden)
+        x = self.lstm(x)
# [seq_len, num_seq, hidden_size]
x = self.dropout(x)
@@ -339,7 +114,7 @@ class charLM(nn.Module):
x = self.linear(x)
# [num_seq*seq_len, vocab_size]
- return x, hidden
+ return x
def conv_layers(self, x):
chosen_list = list()
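
# A minimal shape sketch for the refactored CharLM, assuming the toy sizes below
# (vocab of 100 words, 30 character ids, max word length 8); the model now owns its
# LSTM state, so forward() takes only the character-id tensor and returns per-word logits.
import torch

from fastNLP.models.char_language_model import CharLM

model = CharLM(char_emb_dim=15, word_emb_dim=50, vocab_size=100, num_char=30)
x = torch.randint(0, 30, (4, 10, 8 + 2))   # [num_seq, seq_len, max_word_len + 2] character ids
logits = model(x)                          # expected shape: [num_seq * seq_len, vocab_size]
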
diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py
index c2bcc693..464f99be 100644
--- a/fastNLP/models/sequence_modeling.py
+++ b/fastNLP/models/sequence_modeling.py
@@ -31,16 +31,18 @@ class SeqLabeling(BaseModel):
num_classes = args["num_classes"]
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim)
- self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim)
+ self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim)
self.Linear = encoder.linear.Linear(hidden_dim, num_classes)
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
self.mask = None
- def forward(self, word_seq, word_seq_origin_len):
+ def forward(self, word_seq, word_seq_origin_len, truth=None):
"""
        :param word_seq: LongTensor, [batch_size, max_len]
        :param word_seq_origin_len: LongTensor, [batch_size,], the original lengths of the sequences.
- :return y: [batch_size, mex_len, tag_size]
+ :param truth: LongTensor, [batch_size, max_len]
+        :return y: If truth is None, return a list of decode paths (each a list of tag indices). Used in testing and predicting.
+                   If truth is not None, return the loss, a scalar. Used in training.
"""
self.mask = self.make_mask(word_seq, word_seq_origin_len)
@@ -50,9 +52,16 @@ class SeqLabeling(BaseModel):
# [batch_size, max_len, hidden_size * direction]
x = self.Linear(x)
# [batch_size, max_len, num_classes]
- return x
+ if truth is not None:
+ return self._internal_loss(x, truth)
+ else:
+ return self.decode(x)
def loss(self, x, y):
+ """ Since the loss has been computed in forward(), this function simply returns x."""
+ return x
+
+ def _internal_loss(self, x, y):
"""
Negative log likelihood loss.
:param x: Tensor, [batch_size, max_len, tag_size]
@@ -74,12 +83,19 @@ class SeqLabeling(BaseModel):
mask = mask.to(x)
return mask
- def prediction(self, x):
+ def decode(self, x, pad=True):
"""
:param x: FloatTensor, [batch_size, max_len, tag_size]
+ :param pad: pad the output sequence to equal lengths
:return prediction: list of [decode path(list)]
"""
+ max_len = x.shape[1]
tag_seq = self.Crf.viterbi_decode(x, self.mask)
+ # pad prediction to equal length
+ if pad is True:
+ for pred in tag_seq:
+ if len(pred) < max_len:
+ pred += [0] * (max_len - len(pred))
return tag_seq
@@ -97,7 +113,7 @@ class AdvSeqLabel(SeqLabeling):
num_classes = args["num_classes"]
self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
- self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True)
+ self.Rnn = encoder.lstm.LSTM(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True)
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
self.relu = torch.nn.ReLU()
@@ -106,11 +122,12 @@ class AdvSeqLabel(SeqLabeling):
self.Crf = decoder.CRF.ConditionalRandomField(num_classes)
- def forward(self, word_seq, word_seq_origin_len):
+ def forward(self, word_seq, word_seq_origin_len, truth=None):
"""
        :param word_seq: LongTensor, [batch_size, max_len]
:param word_seq_origin_len: list of int.
- :return y: [batch_size, mex_len, tag_size]
+ :param truth: LongTensor, [batch_size, max_len]
+        :return y: If truth is None, return the decoded tag paths; otherwise return the loss, a scalar.
"""
self.mask = self.make_mask(word_seq, word_seq_origin_len)
@@ -129,4 +146,7 @@ class AdvSeqLabel(SeqLabeling):
x = self.Linear2(x)
x = x.view(batch_size, max_len, -1)
# [batch_size, max_len, num_classes]
- return x
+ if truth is not None:
+ return self._internal_loss(x, truth)
+ else:
+ return self.decode(x)
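
# A minimal sketch of the new forward contract for SeqLabeling (and AdvSeqLabel),
# assuming `model` is an already-constructed SeqLabeling instance; the tensors use toy
# sizes (batch of 2, max length 6, 5 tag classes). With `truth` the call returns a scalar
# loss, without it a list of decode paths padded to max_len.
import torch

word_seq = torch.randint(1, 20, (2, 6))      # [batch_size, max_len] word indices
seq_len = torch.LongTensor([6, 4])           # original sequence lengths
truth = torch.randint(0, 5, (2, 6))          # [batch_size, max_len] gold tag indices

loss = model(word_seq, seq_len, truth)       # training: scalar loss via _internal_loss()
paths = model(word_seq, seq_len)             # testing/predicting: padded decode paths
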
diff --git a/fastNLP/modules/aggregator/self_attention.py b/fastNLP/modules/aggregator/self_attention.py
index b56e869b..981f34c6 100644
--- a/fastNLP/modules/aggregator/self_attention.py
+++ b/fastNLP/modules/aggregator/self_attention.py
@@ -55,14 +55,13 @@ class SelfAttention(nn.Module):
input = input.contiguous()
size = input.size() # [bsz, len, nhid]
-
input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len]
- input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]
+ input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]
- y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
- attention = self.ws2(y1).transpose(1,2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
+ y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
+ attention = self.ws2(y1).transpose(1,
+ 2).contiguous() # [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token.
- attention = F.softmax(attention,2) # [baz ,hop, len]
- return torch.bmm(attention, input), self.penalization(attention) # output1 --> [baz ,hop ,nhid]
-
+ attention = F.softmax(attention, 2) # [baz ,hop, len]
+ return torch.bmm(attention, input), self.penalization(attention) # output1 --> [baz ,hop ,nhid]
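
# A small standalone sketch of the padding trick kept above: adding a large negative
# constant to the attention scores at padded positions (marked by 0 in input_origin)
# before the softmax drives their weights to ~0. Shapes are illustrative ([baz, hop, len]).
import torch
import torch.nn.functional as F

scores = torch.randn(1, 2, 4)                                             # [baz, hop, len]
input_origin = torch.tensor([[3, 7, 5, 0]]).unsqueeze(1).expand(1, 2, 4)  # 0 marks padding
masked = scores + (-999999 * (input_origin == 0).float())
weights = F.softmax(masked, 2)                                            # padded position weight ~ 0
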
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index 71b786b9..b00a0ae9 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -1,10 +1,10 @@
-from .embedding import Embedding
-from .linear import Linear
-from .lstm import Lstm
from .conv import Conv
from .conv_maxpool import ConvMaxpool
+from .embedding import Embedding
+from .linear import Linear
+from .lstm import LSTM
-__all__ = ["Lstm",
+__all__ = ["LSTM",
"Embedding",
"Linear",
"Conv",
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 5af09f29..e48960a8 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -1,9 +1,10 @@
import torch.nn as nn
from fastNLP.modules.utils import initial_parameter
-class Lstm(nn.Module):
- """
- LSTM module
+
+
+class LSTM(nn.Module):
+ """Long Short Term Memory
Args:
input_size : input size
@@ -13,13 +14,17 @@ class Lstm(nn.Module):
bidirectional : If True, becomes a bidirectional RNN. Default: False.
"""
- def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0, bidirectional=False , initial_method = None):
- super(Lstm, self).__init__()
+ def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, bidirectional=False,
+ initial_method=None):
+ super(LSTM, self).__init__()
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=True,
dropout=dropout, bidirectional=bidirectional)
initial_parameter(self, initial_method)
+
def forward(self, x):
x, _ = self.lstm(x)
return x
+
+
if __name__ == "__main__":
- lstm = Lstm(10)
+ lstm = LSTM(10)
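
# A minimal usage sketch of the renamed encoder wrapper (Lstm -> LSTM): it is batch-first
# and its forward() returns only the output sequence, dropping the hidden state. Sizes
# below are illustrative.
import torch

from fastNLP.modules.encoder.lstm import LSTM

lstm = LSTM(input_size=10, hidden_size=16, num_layers=1, bidirectional=False)
x = torch.randn(4, 7, 10)      # [batch_size, seq_len, input_size]
out = lstm(x)                  # [batch_size, seq_len, hidden_size]; doubled if bidirectional
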
diff --git a/fastNLP/modules/other_modules.py b/fastNLP/modules/other_modules.py
index 0cd32d3b..ea1423be 100644
--- a/fastNLP/modules/other_modules.py
+++ b/fastNLP/modules/other_modules.py
@@ -196,30 +196,3 @@ class BiAffine(nn.Module):
output = output * mask_d.unsqueeze(1).unsqueeze(3) * mask_e.unsqueeze(1).unsqueeze(2)
return output
-
-
-class Transpose(nn.Module):
- def __init__(self, x, y):
- super(Transpose, self).__init__()
- self.x = x
- self.y = y
-
- def forward(self, x):
- return x.transpose(self.x, self.y)
-
-
-class WordDropout(nn.Module):
- def __init__(self, dropout_rate, drop_to_token):
- super(WordDropout, self).__init__()
- self.dropout_rate = dropout_rate
- self.drop_to_token = drop_to_token
-
- def forward(self, word_idx):
- if not self.training:
- return word_idx
- drop_mask = torch.rand(word_idx.shape) < self.dropout_rate
- if word_idx.device.type == 'cuda':
- drop_mask = drop_mask.cuda()
- drop_mask = drop_mask.long()
- output = drop_mask * self.drop_to_token + (1 - drop_mask) * word_idx
- return output
diff --git a/fastNLP/saver/config_saver.py b/fastNLP/saver/config_saver.py
index e05e865e..83ef0e4b 100644
--- a/fastNLP/saver/config_saver.py
+++ b/fastNLP/saver/config_saver.py
@@ -18,7 +18,7 @@ class ConfigSaver(object):
:return: The section.
"""
sect = ConfigSection()
- ConfigLoader(self.file_path).load_config(self.file_path, {sect_name: sect})
+ ConfigLoader().load_config(self.file_path, {sect_name: sect})
return sect
def _read_section(self):
@@ -104,7 +104,8 @@ class ConfigSaver(object):
:return:
"""
section_file = self._get_section(section_name)
- if len(section_file.__dict__.keys()) == 0:#the section not in file before
+ if len(section_file.__dict__.keys()) == 0: # the section not in the file before
+ # append this section to config file
with open(self.file_path, 'a') as f:
f.write('[' + section_name + ']\n')
for k in section.__dict__.keys():
@@ -114,9 +115,11 @@ class ConfigSaver(object):
else:
f.write(str(section[k]) + '\n\n')
else:
+ # the section exists
change_file = False
for k in section.__dict__.keys():
if k not in section_file:
+ # find a new key in this section
change_file = True
break
if section_file[k] != section[k]:
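
# A minimal sketch of the updated ConfigLoader convention used above: the loader is now
# constructed without arguments, and the config path goes to load_config() together with a
# {section_name: ConfigSection()} mapping that is filled in place. "train.cfg" and "train"
# are placeholder names, and ConfigSection is assumed to live next to ConfigLoader.
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

section = ConfigSection()
ConfigLoader().load_config("train.cfg", {"train": section})   # populates `section` in place
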
diff --git a/reproduction/Char-aware_NLM/main.py b/reproduction/Char-aware_NLM/main.py
new file mode 100644
index 00000000..03810650
--- /dev/null
+++ b/reproduction/Char-aware_NLM/main.py
@@ -0,0 +1,25 @@
+from fastNLP.core.loss import Loss
+from fastNLP.core.preprocess import Preprocessor
+from fastNLP.core.trainer import Trainer
+from fastNLP.loader.dataset_loader import LMDataSetLoader
+from fastNLP.models.char_language_model import CharLM
+
+PICKLE = "./save/"
+
+
+def train():
+ loader = LMDataSetLoader()
+ train_data = loader.load()
+
+ pre = Preprocessor(label_is_seq=True, share_vocab=True)
+ train_set = pre.run(train_data, pickle_path=PICKLE)
+
+ model = CharLM(50, 50, pre.vocab_size, pre.char_vocab_size)
+
+ trainer = Trainer(task="language_model", loss=Loss("cross_entropy"))
+
+ trainer.train(model, train_set)
+
+
+if __name__ == "__main__":
+ train()
diff --git a/reproduction/Char-aware_NLM/test.txt b/reproduction/Char-aware_NLM/test.txt
index 8079cdf3..92aaec44 100644
--- a/reproduction/Char-aware_NLM/test.txt
+++ b/reproduction/Char-aware_NLM/test.txt
@@ -1,3761 +1,320 @@
- no it was n't black monday
- but while the new york stock exchange did n't fall apart friday as the dow jones industrial average plunged N points most of it in the final hour it barely managed to stay this side of chaos
- some circuit breakers installed after the october N crash failed their first test traders say unable to cool the selling panic in both stocks and futures
- the N stock specialist firms on the big board floor the buyers and sellers of last resort who were criticized after the N crash once again could n't handle the selling pressure
- big investment banks refused to step up to the plate to support the beleaguered floor traders by buying big blocks of stock traders say
- heavy selling of standard & poor 's 500-stock index futures in chicago beat stocks downward
- seven big board stocks ual amr bankamerica walt disney capital cities\/abc philip morris and pacific telesis group stopped trading and never resumed
- the has already begun
- the equity market was
- once again the specialists were not able to handle the imbalances on the floor of the new york stock exchange said christopher senior vice president at securities corp
- james chairman of specialists henderson brothers inc. it is easy to say the specialist is n't doing his job
- when the dollar is in a even central banks ca n't stop it
- speculators are calling for a degree of liquidity that is not there in the market
- many money managers and some traders had already left their offices early friday afternoon on a warm autumn day because the stock market was so quiet
- then in a plunge the dow jones industrials in barely an hour surrendered about a third of their gains this year up a 190.58-point or N N loss on the day in trading volume
- trading accelerated to N million shares a record for the big board
- at the end of the day N million shares were traded
- the dow jones industrials closed at N
- the dow 's decline was second in point terms only to the black monday crash that occurred oct. N N
- in percentage terms however the dow 's dive was the ever and the sharpest since the market fell N or N N a week after black monday
- the dow fell N N on black monday
- shares of ual the parent of united airlines were extremely active all day friday reacting to news and rumors about the proposed $ N billion buy-out of the airline by an group
- wall street 's takeover-stock speculators or risk arbitragers had placed unusually large bets that a takeover would succeed and ual stock would rise
- at N p.m. edt came the news the big board was trading in ual pending news
- on the exchange floor as soon as ual stopped trading we for a panic said one top floor trader
- several traders could be seen shaking their heads when the news
- for weeks the market had been nervous about takeovers after campeau corp. 's cash crunch spurred concern about the prospects for future highly leveraged takeovers
- and N minutes after the ual trading halt came news that the ual group could n't get financing for its bid
- at this point the dow was down about N points
- the market
- arbitragers could n't dump their ual stock but they rid themselves of nearly every rumor stock they had
- for example their selling caused trading halts to be declared in usair group which closed down N N to N N delta air lines which fell N N to N N and industries which sank N to N N
- these stocks eventually reopened
- but as panic spread speculators began to sell blue-chip stocks such as philip morris and international business machines to offset their losses
- when trading was halted in philip morris the stock was trading at N down N N while ibm closed N N lower at N
- selling because of waves of automatic stop-loss orders which are triggered by computer when prices fall to certain levels
- most of the stock selling pressure came from wall street professionals including computer-guided program traders
- traders said most of their major institutional investors on the other hand sat tight
- now at N one of the market 's post-crash reforms took hold as the s&p N futures contract had plunged N points equivalent to around a drop in the dow industrials
- under an agreement signed by the big board and the chicago mercantile exchange trading was temporarily halted in chicago
- after the trading halt in the s&p N pit in chicago waves of selling continued to hit stocks themselves on the big board and specialists continued to prices down
- as a result the link between the futures and stock markets apart
- without the of stock-index futures the barometer of where traders think the overall stock market is headed many traders were afraid to trust stock prices quoted on the big board
- the futures halt was even by big board floor traders
- it things up said one major specialist
- this confusion effectively halted one form of program trading stock index arbitrage that closely links the futures and stock markets and has been blamed by some for the market 's big swings
- in a stock-index arbitrage sell program traders buy or sell big baskets of stocks and offset the trade in futures to lock in a price difference
- when the airline information came through it every model we had for the marketplace said a managing director at one of the largest program-trading firms
- we did n't even get a chance to do the programs we wanted to do
- but stocks kept falling
- the dow industrials were down N points at N p.m. before the halt
- at N p.m. at the end of the cooling off period the average was down N points
- meanwhile during the the s&p trading halt s&p futures sell orders began up while stocks in new york kept falling sharply
- big board chairman john j. phelan said yesterday the circuit breaker worked well
- i just think it 's at this point to get into a debate if index arbitrage would have helped or hurt things
- under another post-crash system big board president richard mr. phelan was flying to as the market was falling was talking on an hot line to the other exchanges the securities and exchange commission and the federal reserve board
- he out at a high-tech center on the floor of the big board where he could watch on prices and pending stock orders
- at about N p.m. edt s&p futures resumed trading and for a brief time the futures and stock markets started to come back in line
- buyers stepped in to the futures pit
- but the of s&p futures sell orders weighed on the market and the link with stocks began to fray again
- at about N the s&p market to still another limit of N points down and trading was locked again
- futures traders say the s&p was that the dow could fall as much as N points
- during this time small investors began ringing their brokers wondering whether another crash had begun
- at prudential-bache securities inc. which is trying to cater to small investors some brokers thought this would be the final
- that 's when george l. ball chairman of the prudential insurance co. of america unit took to the internal system to declare that the plunge was only mechanical
- i have a that this particular decline today is something more about less
- it would be my to advise clients not to sell to look for an opportunity to buy mr. ball told the brokers
- at merrill lynch & co. the nation 's biggest brokerage firm a news release was prepared merrill lynch comments on market drop
- the release cautioned that there are significant differences between the current environment and that of october N and that there are still attractive investment opportunities in the stock market
- however jeffrey b. lane president of shearson lehman hutton inc. said that friday 's plunge is going to set back relations with customers because it the concern of volatility
- and i think a lot of people will on program trading
- it 's going to bring the debate right back to the
- as the dow average ground to its final N loss friday the s&p pit stayed locked at its trading limit
- jeffrey of program trader investment group said N s&p contracts were for sale on the close the equivalent of $ N million in stock
- but there were no buyers
- while friday 's debacle involved mainly professional traders rather than investors it left the market vulnerable to continued selling this morning traders said
- stock-index futures contracts settled at much lower prices than indexes of the stock market itself
- at those levels stocks are set up to be by index arbitragers who lock in profits by buying futures when futures prices fall and simultaneously sell off stocks
- but nobody knows at what level the futures and stocks will open today
- the between the stock and futures markets friday will undoubtedly cause renewed debate about whether wall street is properly prepared for another crash situation
- the big board 's mr. said our performance was good
- but the exchange will look at the performance of all specialists in all stocks
- obviously we 'll take a close look at any situation in which we think the obligations were n't met he said
- see related story fed ready to big funds wsj oct. N N
- but specialists complain privately that just as in the N crash the firms big investment banks that support the market by trading big blocks of stock stayed on the sidelines during friday 's
- mr. phelan said it will take another day or two to analyze who was buying and selling friday
- concerning your sept. N page-one article on prince charles and the it 's a few hundred years since england has been a kingdom
- it 's now the united kingdom of great britain and northern ireland northern ireland scotland and oh yes england too
- just thought you 'd like to know
- george
- ports of call inc. reached agreements to sell its remaining seven aircraft to buyers that were n't disclosed
- the agreements bring to a total of nine the number of planes the travel company has sold this year as part of a restructuring
- the company said a portion of the $ N million realized from the sales will be used to repay its bank debt and other obligations resulting from the currently suspended operations
- earlier the company announced it would sell its aging fleet of boeing co. because of increasing maintenance costs
- a consortium of private investors operating as funding co. said it has made a $ N million cash bid for most of l.j. hooker corp. 's real-estate and holdings
- the $ N million bid includes the assumption of an estimated $ N million in secured liabilities on those properties according to those making the bid
- the group is led by jay chief executive officer of investment corp. in and a. boyd simpson chief executive of the atlanta-based simpson organization inc
- mr. 's company specializes in commercial real-estate investment and claims to have $ N billion in assets mr. simpson is a developer and a former senior executive of l.j. hooker
- the assets are good but they require more money and management than can be provided in l.j. hooker 's current situation said mr. simpson in an interview
- hooker 's philosophy was to build and sell
- we want to build and hold
- l.j. hooker based in atlanta is operating with protection from its creditors under chapter N of the u.s. bankruptcy code
- its parent company hooker corp. of sydney australia is currently being managed by a court-appointed provisional
- sanford chief executive of l.j. hooker said yesterday in a statement that he has not yet seen the bid but that he would review it and bring it to the attention of the creditors committee
- the $ N million bid is estimated by mr. simpson as representing N N of the value of all hooker real-estate holdings in the u.s.
- not included in the bid are teller or b. altman & co. l.j. hooker 's department-store chains
- the offer covers the massive N forest fair mall in cincinnati the N fashion mall in columbia s.c. and the N town center mall in
- the mall opened sept. N with a 's as its the columbia mall is expected to open nov. N
- other hooker properties included are a office tower in atlanta expected to be completed next february vacant land sites in florida and ohio l.j. hooker international the commercial real-estate brokerage company that once did business as merrill lynch commercial real estate plus other shopping centers
- the consortium was put together by the london-based investment banking company that is a subsidiary of security pacific corp
- we do n't anticipate any problems in raising the funding for the bid said campbell the head of mergers and acquisitions at in an interview
- is acting as the consortium 's investment bankers
- according to people familiar with the consortium the bid was project a reference to the film in which a played by actress is saved from a businessman by a police officer named john
- l.j. hooker was a small company based in atlanta in N when mr. simpson was hired to push it into commercial development
- the company grew modestly until N when a majority position in hooker corp. was acquired by australian developer george currently hooker 's chairman
- mr. to launch an ambitious but $ N billion acquisition binge that included teller and b. altman & co. as well as majority positions in merksamer jewelers a sacramento chain inc. the retailer and inc. the southeast department-store chain
- eventually mr. simpson and mr. had a falling out over the direction of the company and mr. simpson said he resigned in N
- since then hooker corp. has sold its interest in the chain back to 's management and is currently attempting to sell the b. altman & co. chain
- in addition robert chief executive of the chain is seeking funds to buy out the hooker interest in his company
- the merksamer chain is currently being offered for sale by first boston corp
- reached in mr. said that he believes the various hooker can become profitable with new management
- these are n't mature assets but they have the potential to be so said mr.
- managed properly and with a long-term outlook these can become investment-grade quality properties
- canadian production totaled N metric tons in the week ended oct. N up N N from the preceding week 's total of N tons statistics canada a federal agency said
- the week 's total was up N N from N tons a year earlier
- the total was N tons up N N from N tons a year earlier
- the treasury plans to raise $ N million in new cash thursday by selling about $ N billion of 52-week bills and $ N billion of maturing bills
- the bills will be dated oct. N and will mature oct. N N
- they will be available in minimum denominations of $ N
- bids must be received by N p.m. edt thursday at the treasury or at federal reserve banks or branches
- as small investors their mutual funds with phone calls over the weekend big fund managers said they have a strong defense against any wave of withdrawals cash
- unlike the weekend before black monday the funds were n't with heavy withdrawal requests
- and many fund managers have built up cash levels and say they will be buying stock this week
- at fidelity investments the nation 's largest fund company telephone volume was up sharply but it was still at just half the level of the weekend preceding black monday in N
- the boston firm said redemptions were running at less than one-third the level two years ago
- as of yesterday afternoon the redemptions represented less than N N of the total cash position of about $ N billion of fidelity 's stock funds
- two years ago there were massive redemption levels over the weekend and a lot of fear around said c. bruce who runs fidelity investments ' $ N billion fund
- this feels more like a deal
- people are n't
- the test may come today
- friday 's stock market sell-off came too late for many investors to act
- some shareholders have held off until today because any fund exchanges made after friday 's close would take place at today 's closing prices
- stock fund redemptions during the N debacle did n't begin to until after the market opened on black monday
- but fund managers say they 're ready
- many have raised cash levels which act as a buffer against steep market declines
- mario for instance holds cash positions well above N N in several of his funds
- windsor fund 's john and mutual series ' michael price said they had raised their cash levels to more than N N and N N respectively this year
- even peter lynch manager of fidelity 's $ N billion fund the nation 's largest stock fund built up cash to N N or $ N million
- one reason is that after two years of monthly net redemptions the fund posted net inflows of money from investors in august and september
- i 've let the money build up mr. lynch said who added that he has had trouble finding stocks he likes
- not all funds have raised cash levels of course
- as a group stock funds held N N of assets in cash as of august the latest figures available from the investment company institute
- that was modestly higher than the N N and N N levels in august and september of N
- also persistent redemptions would force some fund managers to dump stocks to raise cash
- but a strong level of investor withdrawals is much more unlikely this time around fund managers said
- a major reason is that investors already have sharply scaled back their purchases of stock funds since black monday
- sales have rebounded in recent months but monthly net purchases are still running at less than half N levels
- there 's not nearly as much said john chairman of vanguard group inc. a big valley forge pa. fund company
- many fund managers argue that now 's the time to buy
- vincent manager of the $ N billion wellington fund added to his positions in bristol-myers squibb woolworth and dun & bradstreet friday
- and today he 'll be looking to buy drug stocks like eli lilly pfizer and american home products whose dividend yields have been bolstered by stock declines
- fidelity 's mr. lynch for his part snapped up southern co. shares friday after the stock got
- if the market drops further today he said he 'll be buying blue chips such as bristol-myers and kellogg
- if they stocks like that he said it presents an opportunity that is the kind of thing you dream about
- major mutual-fund groups said phone calls were at twice the normal weekend pace yesterday
- but most investors were seeking share prices and other information
- trading volume was only modestly higher than normal
- still fund groups are n't taking any chances
- they hope to avoid the phone lines and other that some fund investors in october N
- fidelity on saturday opened its N investor centers across the country
- the centers normally are closed through the weekend
- in addition east coast centers will open at N edt this morning instead of the normal N
- t. rowe price associates inc. increased its staff of phone representatives to handle investor requests
- the group noted that some investors moved money from stock funds to money-market funds
- but most investors seemed to be in an information mode rather than in a transaction mode said steven a vice president
- and vanguard among other groups said it was adding more phone representatives today to help investors get through
- in an unusual move several funds moved to calm investors with on their phone lines
- we view friday 's market decline as offering us a buying opportunity as long-term investors a recording at & co. funds said over the weekend
- the group had a similar recording for investors
- several fund managers expect a rough market this morning before prices stabilize
- some early selling is likely to stem from investors and portfolio managers who want to lock in this year 's fat profits
- stock funds have averaged a staggering gain of N N through september according to lipper analytical services inc
- who runs shearson lehman hutton inc. 's $ N million sector analysis portfolio predicts the market will open down at least N points on technical factors and some panic selling
- but she expects prices to rebound soon and is telling investors she expects the stock market wo n't decline more than N N to N N from recent highs
- this is not a major crash she said
- nevertheless ms. said she was with phone calls over the weekend from nervous shareholders
- half of them are really scared and want to sell she said but i 'm trying to talk them out of it
- she added if they all were bullish i 'd really be upset
- the backdrop to friday 's slide was different from that of the october N crash fund managers argue
- two years ago unlike today the dollar was weak interest rates were rising and the market was very they say
- from the investors ' standpoint institutions and individuals learned a painful lesson by selling at the lows on black monday said stephen boesel manager of the $ N million t. rowe price growth and income fund
- this time i do n't think we 'll get a panic reaction
- newport corp. said it expects to report earnings of between N cents and N cents a share somewhat below analysts ' estimates of N cents to N cents
- the maker of scientific instruments and laser parts said orders fell below expectations in recent months
- a spokesman added that sales in the current quarter will about equal the quarter 's figure when newport reported net income of $ N million or N cents a share on $ N million in sales
- from the strike by N machinists union members against boeing co. reached air carriers friday as america west airlines announced it will postpone its new service out of houston because of delays in receiving aircraft from the seattle jet maker
- peter vice president for planning at the phoenix ariz. carrier said in an interview that the work at boeing now entering its 13th day has caused some turmoil in our scheduling and that more than N passengers who were booked to fly out of houston on america west would now be put on other airlines
- mr. said boeing told america west that the N it was supposed to get this thursday would n't be delivered until nov. N the day after the airline had been planning to service at houston with four daily flights including three to phoenix and one to las vegas
- now those routes are n't expected to begin until jan
- boeing is also supposed to send to america west another N aircraft as well as a N by year 's end
- those too are almost certain to arrive late
- at this point no other america west flights including its new service at san antonio texas newark n.j. and calif. have been affected by the delays in boeing deliveries
- nevertheless the company 's reaction the effect that a huge manufacturer such as boeing can have on other parts of the economy
- it also is sure to help the machinists put added pressure on the company
- i just do n't feel that the company can really stand or would want a prolonged tom baker president of machinists ' district N said in an interview yesterday
- i do n't think their customers would like it very much
- america west though is a smaller airline and therefore more affected by the delayed delivery of a single plane than many of its competitors would be
- i figure that american and united probably have such a hard time counting all the planes in their fleets they might not miss one at all mr. said
- indeed a random check friday did n't seem to indicate that the strike was having much of an effect on other airline operations
- southwest airlines has a boeing N set for delivery at the end of this month and expects to have the plane on time
- it 's so close to completion boeing 's told us there wo n't be a problem said a southwest spokesman
- a spokesman for amr corp. said boeing has assured american airlines it will deliver a N on time later this month
- american is preparing to take delivery of another N in early december and N more next year and is n't anticipating any changes in that timetable
- in seattle a boeing spokesman explained that the company has been in constant communication with all of its customers and that it was impossible to predict what further disruptions might be triggered by the strike
- meanwhile supervisors and employees have been trying to finish some N aircraft mostly N and N jumbo jets at the company 's wash. plant that were all but completed before the
- as of friday four had been delivered and a fifth plane a N was supposed to be out over the weekend to air china
- no date has yet been set to get back to the bargaining table
- we want to make sure they know what they want before they come back said doug hammond the federal mediator who has been in contact with both sides since the strike began
- the investment community for one has been anticipating a resolution
- though boeing 's stock price was battered along with the rest of the market friday it actually has risen over the last two weeks on the strength of new orders
- the market has taken two views that the labor situation will get settled in the short term and that things look very for boeing in the long term said howard an analyst at j. lawrence inc
- boeing 's shares fell $ N friday to close at $ N in composite trading on the new york stock exchange
- but mr. baker said he thinks the earliest a pact could be struck would be the end of this month that the company and union may resume negotiations as early as this week
- still he said it 's possible that the strike could last considerably longer
- i would n't expect an immediate resolution to anything
- last week boeing chairman frank sent striking workers a letter saying that to my knowledge boeing 's offer represents the best overall three-year contract of any major u.s. industrial firm in recent history
- but mr. baker called the letter and the company 's offer of a N N wage increase over the life of the pact plus bonuses very weak
- he added that the company the union 's resolve and the workers ' with being forced to work many hours overtime
- in separate developments talks have broken off between machinists representatives at lockheed corp. and the calif. aerospace company
- the union is continuing to work through its expired contract however
- it had planned a strike vote for next sunday but that has been pushed back indefinitely
- united auto workers local N which represents N workers at boeing 's helicopter unit in delaware county pa. said it agreed to extend its contract on a basis with a notification to cancel while it continues bargaining
- the accord expired yesterday
- and boeing on friday said it received an order from for four model N valued at a total of about $ N million
- the planes long range versions of the will be delivered with & engines
- & is a unit of united technologies inc
- is based in amsterdam
- a boeing spokeswoman said a delivery date for the planes is still being worked out for a variety of reasons but not because of the strike
- contributed to this article
- ltd. said its utilities arm is considering building new electric power plants some valued at more than one billion canadian dollars us$ N million in great britain and elsewhere
- 's senior vice president finance said its canadian utilities ltd. unit is reviewing projects in eastern canada and conventional electric power generating plants elsewhere including britain where the british government plans to allow limited competition in electrical generation from private-sector suppliers as part of its privatization program
- the projects are big
- they can be c$ N billion plus mr. said
- but we would n't go into them alone and canadian utilities ' equity stake would be small he said
- we 'd like to be the operator of the project and a modest equity investor
- our long suit is our proven ability to operate power plants he said
- mr. would n't offer regarding 's proposed british project but he said it would compete for customers with two huge british power generating companies that would be formed under the country 's plan to its massive water and electric utilities
- britain 's government plans to raise about # N billion $ N billion from the sale of most of its giant water and electric utilities beginning next month
- the planned electric utility sale scheduled for next year is alone expected to raise # N billion making it the world 's largest public offering
- under terms of the plan independent would be able to compete for N N of customers until N and for another N N between N and N
- canadian utilities had N revenue of c$ N billion mainly from its natural gas and electric utility businesses in alberta where the company serves about N customers
- there seems to be a move around the world to the generation of electricity mr. said and canadian utilities hopes to capitalize on it
- this is a real thrust on our utility side he said adding that canadian utilities is also projects in countries though he would be specific
- canadian utilities is n't alone in exploring power generation opportunities in britain in anticipation of the privatization program
- we 're certainly looking at some power generating projects in england said bruce vice president corporate strategy and corporate planning with enron corp. houston a big natural gas producer and pipeline operator
- mr. said enron is considering building power plants in the u.k. capable of producing about N of power at a cost of about $ N million to $ N million
- pse inc. said it expects to report third earnings of $ N million to $ N million or N cents to N cents a share
- in the year-ago quarter the designer and operator of and waste heat recovery plants had net income of $ N or four cents a share on revenue of about $ N million
- the company said the improvement is related to additional facilities that have been put into operation
- flights are $ N to paris and $ N to london
- in a centennial journal article oct. N the fares were reversed
- diamond offshore partners said it had discovered gas offshore louisiana
- the well at a rate of N million cubic feet of gas a day through a N opening at between N and N feet
- diamond is the operator with a N N interest in the well
- diamond offshore 's stock rose N cents friday to close at $ N in new york stock exchange composite trading
- & broad home corp. said it formed a $ N million limited partnership subsidiary to buy land in california suitable for residential development
- the partnership & broad land development venture limited partnership is a N joint venture with a trust created by institutional clients of advisory corp. a unit of financial corp. a real estate advisory management and development company with offices in chicago and beverly hills calif
- & broad a home building company declined to identify the institutional investors
- the land to be purchased by the joint venture has n't yet received and other approvals required for development and part of & broad 's job will be to obtain such approvals
- the partnership runs the risk that it may not get the approvals for development but in return it can buy land at wholesale rather than retail prices which can result in sizable savings said bruce president and chief executive officer of & broad
- there are really very few companies that have adequate capital to buy properties in a raw state for cash
- typically developers option property and then once they get the administrative approvals they buy it said mr. adding that he believes the joint venture is the first of its kind
- we usually operate in that conservative manner
- by setting up the joint venture & broad can take the more aggressive approach of buying raw land while avoiding the negative to its own balance sheet mr. said
- the company is putting up only N N of the capital although it is responsible for providing management planning and processing services to the joint venture
- this is one of the best ways to assure a pipeline of land to fuel our growth at a minimum risk to our company mr. said
- when the price of plastics took off in N quantum chemical corp. went along for the ride
- the timing of quantum 's chief executive officer john