@@ -0,0 +1,74 @@ | |||
#! /usr/bin/python | |||
# -*- coding: utf-8 -*- | |||
# The same code can run on a different backend by switching this one line
import os | |||
os.environ['TL_BACKEND'] = 'tensorflow' | |||
# os.environ['TL_BACKEND'] = 'mindspore' | |||
# os.environ['TL_BACKEND'] = 'paddle' | |||
import tensorlayer as tl | |||
from tensorlayer.layers import Module | |||
from tensorlayer.layers import Dense, LSTM, Embedding | |||
from tensorlayer.dataflow import Dataset | |||
import numpy as np | |||
X_train, y_train, X_test, y_test = tl.files.load_imdb_dataset('data', nb_words=20000, test_split=0.2) | |||
Seq_Len = 200 | |||
vocab_size = len(X_train) + 1 | |||
class imdbdataset(Dataset): | |||
def __init__(self, X, y): | |||
self.X = X | |||
self.y = y | |||
def __getitem__(self, index): | |||
data = self.X[index] | |||
            data = np.concatenate([data[:Seq_Len], [0] * (Seq_Len - len(data))]).astype('int64')  # pad or truncate every sample to Seq_Len tokens
label = self.y[index].astype('int64') | |||
return data, label | |||
def __len__(self): | |||
return len(self.y) | |||
class ImdbNet(Module): | |||
def __init__(self): | |||
super(ImdbNet, self).__init__() | |||
self.embedding = Embedding(vocabulary_size=vocab_size, embedding_size=64) | |||
self.lstm = LSTM(input_size=64, hidden_size=64) | |||
self.dense1 = Dense(in_channels=64, n_units=64, act=tl.ReLU) | |||
self.dense2 = Dense(in_channels=64, n_units=2) | |||
def forward(self, x): | |||
x = self.embedding(x) | |||
x, _ = self.lstm(x) | |||
x = tl.ops.reduce_mean(x, axis=1) | |||
x = self.dense1(x) | |||
x = self.dense2(x) | |||
return x | |||
n_epoch = 5 | |||
batch_size = 64 | |||
print_freq = 2 | |||
train_dataset = imdbdataset(X=X_train, y=y_train) | |||
train_dataset = tl.dataflow.FromGenerator( | |||
train_dataset, output_types=[tl.int64, tl.int64], column_names=['data', 'label'] | |||
) | |||
train_loader = tl.dataflow.Dataloader(train_dataset, batch_size=batch_size, shuffle=True) | |||
net = ImdbNet() | |||
train_weights = net.trainable_weights | |||
optimizer = tl.optimizers.Adam(1e-3) | |||
metric = tl.metric.Accuracy() | |||
loss_fn = tl.cost.softmax_cross_entropy_with_logits | |||
model = tl.models.Model(network=net, loss_fn=loss_fn, optimizer=optimizer, metrics=metric) | |||
model.train(n_epoch=n_epoch, train_dataset=train_loader, print_freq=print_freq, print_train_batch=True) |
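For reference, the fixed-length padding done in `imdbdataset.__getitem__` above can be exercised on its own. This is a minimal NumPy sketch of the same pad-or-truncate step; the helper name `pad_or_truncate` is illustrative only:

```python
import numpy as np

Seq_Len = 200

def pad_or_truncate(seq, seq_len=Seq_Len):
    # Keep at most `seq_len` token ids, then right-pad with 0 so every
    # sample in a batch has the same length, as in __getitem__ above.
    seq = list(seq)[:seq_len]
    return np.concatenate([seq, [0] * (seq_len - len(seq))]).astype('int64')

print(pad_or_truncate([5, 8, 13]).shape)    # (200,)
print(pad_or_truncate(range(500)).shape)    # (200,)
```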
@@ -141,3 +141,7 @@ from .load_backend import Maximum | |||
from .load_backend import Meshgrid | |||
from .load_backend import BatchToSpace | |||
from .load_backend import DepthToSpace | |||
from .load_backend import rnncell | |||
from .load_backend import lstmcell | |||
from .load_backend import grucell | |||
from .load_backend import rnnbase |
@@ -720,15 +720,26 @@ def reduce_min(input_tensor, axis=None): | |||
class Pad(Cell): | |||
def __init__(self, paddings, mode="REFLECT"): | |||
def __init__(self, paddings, mode="REFLECT", constant_values=0): | |||
super(Pad, self).__init__() | |||
if mode not in ["REFLECT", "SYMMETRIC"]: | |||
if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']: | |||
raise Exception("Unsupported mode: {}".format(mode)) | |||
self.pad = P.MirrorPad(mode=mode) | |||
self.paddings = Tensor(paddings) | |||
if mode == 'CONSTANT': | |||
self.pad = P.Pad(paddings) | |||
            if constant_values != 0:
                raise NotImplementedError("constant_values can only be equal to 0.")
else: | |||
self.pad = P.MirrorPad(mode=mode) | |||
            self.paddings = Tensor(np.array(paddings))  # use the constructor argument; self.paddings is not yet set here
self.mode = mode | |||
def construct(self, x): | |||
return self.pad(x, self.paddings) | |||
if self.mode == 'CONSTANT': | |||
return self.pad(x) | |||
else: | |||
return self.pad(x, self.paddings) | |||
def pad(tensor, paddings, mode='CONSTANT', constant_values=0): | |||
@@ -1800,3 +1800,152 @@ class DorefaConv2D(Cell): | |||
outputs = nchw_to_nhwc(outputs) | |||
return outputs | |||
class rnncell(Cell): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act): | |||
super(rnncell, self).__init__() | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.act_fn = P.ReLU() if act == 'relu' else P.Tanh() | |||
self.transpose = P.Transpose() | |||
def construct(self, input, h): | |||
        # transpose into a local variable instead of overwriting the stored weight on every call
        weight_ih = self.transpose(self.weight_ih, (1, 0))
        i2h = P.matmul(input, weight_ih)
if self.bias_ih is not None: | |||
i2h += self.bias_ih | |||
        weight_hh = self.transpose(self.weight_hh, (1, 0))
        h2h = P.matmul(h, weight_hh)
if self.bias_hh is not None: | |||
h2h += self.bias_hh | |||
h = self.act_fn(i2h + h2h) | |||
return h, h | |||
class lstmcell(Cell): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh): | |||
super(lstmcell, self).__init__() | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.gate_act_fn = P.Sigmoid() | |||
self.act_fn = P.Tanh() | |||
self.transpose = P.Transpose() | |||
self.split = P.Split(axis=-1, output_num=4) | |||
def construct(self, input, h, c): | |||
        # transpose into a local variable instead of overwriting the stored weight on every call
        weight_ih = self.transpose(self.weight_ih, (1, 0))
        gates = P.matmul(input, weight_ih)
if self.bias_ih is not None: | |||
gates += self.bias_ih | |||
        weight_hh = self.transpose(self.weight_hh, (1, 0))
        gates += P.matmul(h, weight_hh)
if self.bias_hh is not None: | |||
gates += self.bias_hh | |||
gate_slices = self.split(gates) | |||
i = self.gate_act_fn(gate_slices[0]) | |||
f = self.gate_act_fn(gate_slices[1]) | |||
o = self.gate_act_fn(gate_slices[3]) | |||
c = f * c + i * self.act_fn(gate_slices[2]) | |||
h = o * self.act_fn(c) | |||
return h, h, c | |||
class grucell(Cell): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh): | |||
super(grucell, self).__init__() | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.gate_act_fn = P.Sigmoid() | |||
self.act_fn = P.Tanh() | |||
self.transpose = P.Transpose() | |||
self.split = P.Split(axis=-1, output_num=3) | |||
def construct(self, input, h): | |||
        # transpose into a local variable instead of overwriting the stored weight on every call
        weight_ih = self.transpose(self.weight_ih, (1, 0))
        x_gates = P.matmul(input, weight_ih)
if self.bias_ih is not None: | |||
x_gates += self.bias_ih | |||
        weight_hh = self.transpose(self.weight_hh, (1, 0))
        h_gates = P.matmul(h, weight_hh)
if self.bias_hh is not None: | |||
h_gates += self.bias_hh | |||
x_r, x_z, x_c = self.split(x_gates) | |||
h_r, h_z, h_c = self.split(h_gates) | |||
r = self.gate_act_fn(x_r + h_r) | |||
        z = self.gate_act_fn(x_z + h_z)
c = self.act_fn(x_c + r * h_c) | |||
h = (h - c) * z + c | |||
return h, h | |||
class rnnbase(Cell): | |||
def __init__( | |||
self, | |||
mode, | |||
input_size, | |||
hidden_size, | |||
num_layers, | |||
bias, | |||
batch_first, | |||
dropout, | |||
bidirectional, | |||
is_train, | |||
): | |||
super(rnnbase, self).__init__() | |||
self.mode = mode | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.bidirect = 2 if bidirectional else 1 | |||
self.batch_first = batch_first | |||
if mode == 'LSTM': | |||
self.lstm = ms.nn.LSTM( | |||
input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, has_bias=bias, | |||
batch_first=batch_first, dropout=dropout, bidirectional=bidirectional | |||
) | |||
elif mode == 'GRU': | |||
raise NotImplementedError | |||
elif mode == 'RNN_TANH': | |||
raise NotImplementedError | |||
elif mode == 'RNN_RELU': | |||
raise NotImplementedError | |||
self.zeros = P.Zeros() | |||
def construct(self, input, states): | |||
input_shape = input.shape | |||
input_dtype = input.dtype | |||
if self.mode == 'LSTM': | |||
if self.batch_first: | |||
batch_size = input_shape[0] | |||
else: | |||
batch_size = input_shape[1] | |||
if states is None: | |||
h = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype) | |||
c = self.zeros((self.bidirect * self.num_layers, batch_size, self.hidden_size), input_dtype) | |||
states = (h, c) | |||
output, (h, c) = self.lstm(input, states) | |||
return output, (h, c) |
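For reference, the gate arithmetic that `lstmcell` above implements (one projection split into four chunks in the order i, f, g, o, matching the `P.Split(axis=-1, output_num=4)` call) can be sketched in plain NumPy, independent of MindSpore; the helper names here are illustrative only:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, w_ih, w_hh, b_ih=None, b_hh=None):
    # Same computation as lstmcell.construct: project input and hidden state,
    # split the result into (i, f, g, o) gates, then update cell and hidden.
    gates = x @ w_ih.T + h @ w_hh.T
    if b_ih is not None:
        gates = gates + b_ih
    if b_hh is not None:
        gates = gates + b_hh
    i, f, g, o = np.split(gates, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    c_new = f * c + i * np.tanh(g)
    h_new = o * np.tanh(c_new)
    return h_new, c_new

batch, in_dim, hid = 2, 8, 16
rng = np.random.default_rng(0)
h = np.zeros((batch, hid))
c = np.zeros((batch, hid))
h, c = lstm_step(rng.standard_normal((batch, in_dim)), h, c,
                 rng.standard_normal((4 * hid, in_dim)),
                 rng.standard_normal((4 * hid, hid)))
print(h.shape, c.shape)   # (2, 16) (2, 16)
```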
@@ -4,6 +4,7 @@ | |||
from __future__ import absolute_import, division, print_function | |||
import paddle as pd | |||
import paddle.nn as nn | |||
import numpy as np | |||
_dtypeDict = ["float16", "float32", "float64", "int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"] | |||
# TODO NotImplemented | |||
@@ -325,7 +326,7 @@ class Reshape(object): | |||
self.shape = shape | |||
def __call__(self, tensor): | |||
raise NotImplementedError | |||
return pd.reshape(tensor, shape=self.shape) | |||
def reshape(tensor, shape): | |||
@@ -352,7 +353,7 @@ class Concat(object): | |||
self.axis = axis | |||
def __call__(self, values): | |||
raise NotImplementedError | |||
return pd.concat(values, axis=self.axis) | |||
def concat(values, axis): | |||
@@ -369,7 +370,7 @@ def concat(values, axis): | |||
------- | |||
A Tensor resulting from concatenation of the input tensors. | |||
""" | |||
raise NotImplementedError | |||
return pd.concat(values, axis) | |||
def convert_to_tensor(value, dtype=float32): | |||
@@ -407,16 +408,16 @@ def sqrt(x): | |||
------- | |||
A Tensor. Has the same type as x. | |||
""" | |||
raise NotImplementedError | |||
return pd.sqrt(x) | |||
class ReduceSum(object): | |||
def __init__(self, axis): | |||
pass | |||
self.axis = axis | |||
def construct(self, input): | |||
pass | |||
return pd.sum(input, axis=self.axis) | |||
class ReduceMean(object): | |||
@@ -447,7 +448,7 @@ def reduce_mean(input_tensor, axis=None): | |||
The reduced tensor. | |||
""" | |||
raise NotImplementedError | |||
return pd.mean(input_tensor, axis) | |||
class ReduceMax(object): | |||
@@ -478,7 +479,7 @@ def reduce_max(input_tensor, axis=None): | |||
The reduced tensor. | |||
""" | |||
raise NotImplementedError | |||
return pd.max(input_tensor, axis) | |||
def reduce_min(input_tensor, axis=None): | |||
@@ -499,21 +500,47 @@ def reduce_min(input_tensor, axis=None): | |||
------- | |||
The reduced tensor. | |||
""" | |||
raise NotImplementedError | |||
return pd.min(input_tensor, axis) | |||
class Pad(object): | |||
def __init__(self, paddings, mode="REFLECT"): | |||
def __init__(self, paddings, mode="REFLECT", constant_values=0): | |||
if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']: | |||
raise Exception("Unsupported mode: {}".format(mode)) | |||
        if mode == 'SYMMETRIC':
            raise NotImplementedError("SYMMETRIC padding is not supported in the Paddle backend.")
self.paddings = paddings | |||
self.mode = mode | |||
self.mode = mode.lower() | |||
self.constant_values = constant_values | |||
def __call__(self, x): | |||
raise NotImplementedError | |||
        if len(x.shape) == 3:
            data_format = 'NLC'
        elif len(x.shape) == 4:
            data_format = 'NHWC'
        elif len(x.shape) == 5:
            data_format = 'NDHWC'
        else:
            raise NotImplementedError('Please check the input shape.')
        # map the TF-style per-dimension paddings to Paddle's flat list without
        # mutating self.paddings, so the op can be applied more than once
        paddings = self.correct_paddings(len(x.shape), self.paddings, data_format)
        return pd.nn.functional.pad(x, paddings, self.mode, value=self.constant_values, data_format=data_format)
def correct_paddings(self, in_shape, paddings, data_format): | |||
if in_shape == 3 and data_format == 'NLC': | |||
correct_output = [paddings[1][0], paddings[1][1]] | |||
elif in_shape == 4 and data_format == 'NHWC': | |||
correct_output = [paddings[2][0], paddings[2][1], | |||
paddings[1][0], paddings[1][1]] | |||
elif in_shape == 5 and data_format == 'NDHWC': | |||
correct_output = [paddings[3][0], paddings[3][1], | |||
paddings[2][0], paddings[2][1], | |||
paddings[1][0], paddings[1][1]] | |||
else: | |||
raise NotImplementedError('Does not support channels first') | |||
return correct_output | |||
def pad(tensor, paddings, mode='CONSTANT', constant_values=0): | |||
@@ -535,7 +562,7 @@ def pad(tensor, paddings, mode='CONSTANT', constant_values=0): | |||
------- | |||
A Tensor. Has the same type as tensor. | |||
""" | |||
raise NotImplementedError | |||
return Pad(paddings, mode, constant_values)(tensor) | |||
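A small usage sketch for the Paddle `Pad` wrapper above. It assumes a channels-last NHWC input and the TensorFlow-style per-dimension `[[before, after], ...]` paddings that `correct_paddings` converts to Paddle's flat list; the expected output shape in the comment follows from the code rather than a verified run:

```python
import numpy as np
import paddle as pd

x = pd.to_tensor(np.ones([1, 4, 4, 3], dtype='float32'))   # NHWC input
paddings = [[0, 0], [1, 1], [2, 2], [0, 0]]                 # pad H by 1, W by 2 per side
y = Pad(paddings, mode='CONSTANT', constant_values=0)(x)
print(y.shape)                                              # expected [1, 6, 8, 3]
```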
class Unstack(object): | |||
@@ -545,7 +572,7 @@ class Unstack(object): | |||
self.num = num | |||
def __call__(self, values): | |||
raise NotImplementedError | |||
return pd.unstack(values, self.axis, self.num) | |||
class Stack(object): | |||
@@ -554,7 +581,7 @@ class Stack(object): | |||
self.axis = axis | |||
def __call__(self, values): | |||
raise NotImplementedError | |||
return pd.stack(values, self.axis) | |||
def stack(values, axis=0): | |||
@@ -563,7 +590,7 @@ def stack(values, axis=0): | |||
Parameters | |||
---------- | |||
values : list | |||
values : list or tuple | |||
A list of Tensor objects with the same shape and type. | |||
axis : int | |||
An int. The axis to stack along. Defaults to the first dimension. | |||
@@ -573,7 +600,7 @@ def stack(values, axis=0): | |||
------- | |||
A stacked Tensor with the same type as values. | |||
""" | |||
raise NotImplementedError | |||
return pd.stack(values, axis=axis) | |||
class Meshgrid(object): | |||
@@ -583,10 +610,10 @@ class Meshgrid(object): | |||
self.index = indexing | |||
def __call__(self, inputs): | |||
pass | |||
return pd.meshgrid(inputs) | |||
def meshgrid(x, y): | |||
def meshgrid(*args, **kwargs): | |||
""" | |||
Broadcasts parameters for evaluation on an N-D grid. | |||
@@ -602,7 +629,7 @@ def meshgrid(x, y): | |||
A list of N Tensors with rank N. | |||
""" | |||
pass | |||
return pd.meshgrid(*args, **kwargs) | |||
def range(start, limit=None, delta=1, dtype=None): | |||
@@ -626,16 +653,19 @@ def range(start, limit=None, delta=1, dtype=None): | |||
------- | |||
An 1-D Tensor of type dtype. | |||
""" | |||
raise NotImplementedError | |||
    return pd.arange(start, end=limit, step=delta, dtype=dtype)
class ExpandDims(object): | |||
def __init__(self, axis): | |||
pass | |||
self.axis = axis | |||
def construct(self, input): | |||
pass | |||
input = convert_to_numpy(input) | |||
output = np.expand_dims(input, axis=self.axis) | |||
output = convert_to_tensor(output) | |||
return output | |||
def expand_dims(input, axis): | |||
@@ -655,7 +685,10 @@ def expand_dims(input, axis): | |||
A Tensor with the same data as input, but its shape has an additional dimension of size 1 added. | |||
""" | |||
raise NotImplementedError | |||
input = convert_to_numpy(input) | |||
output = np.expand_dims(input, axis=axis) | |||
output = convert_to_tensor(output) | |||
return output | |||
class Tile(object): | |||
@@ -664,7 +697,7 @@ class Tile(object): | |||
pass | |||
def __call__(self, input, multiples): | |||
raise NotImplementedError | |||
return pd.tile(input, multiples) | |||
def tile(input, multiples): | |||
@@ -683,16 +716,16 @@ def tile(input, multiples): | |||
------- | |||
A Tensor. Has the same type as input. | |||
""" | |||
raise NotImplementedError | |||
return pd.tile(input, multiples) | |||
class Cast(object): | |||
def __init__(self, dtype): | |||
pass | |||
self.dtype = dtype | |||
def __call__(self, input): | |||
pass | |||
return pd.cast(input, self.dtype) | |||
def cast(x, dtype): | |||
@@ -711,7 +744,7 @@ def cast(x, dtype): | |||
------- | |||
A Tensor or SparseTensor or IndexedSlices with same shape as x and same type as dtype. | |||
""" | |||
raise NotImplementedError | |||
return pd.cast(x, dtype) | |||
class Transpose(object): | |||
@@ -722,7 +755,7 @@ class Transpose(object): | |||
raise ("The conjugate Parameters not supported") | |||
def __call__(self, a): | |||
raise NotImplementedError | |||
return pd.transpose(a, self.perm) | |||
def transpose(a, perm=None, conjugate=False): | |||
@@ -743,7 +776,7 @@ def transpose(a, perm=None, conjugate=False): | |||
A transposed Tensor. | |||
""" | |||
raise NotImplementedError | |||
return pd.transpose(a, perm) | |||
def gather_nd(params, indices, batch_dims=0): | |||
@@ -764,7 +797,7 @@ def gather_nd(params, indices, batch_dims=0): | |||
A Tensor. Has the same type as params. | |||
""" | |||
pass | |||
return pd.gather_nd(params, indices) | |||
def clip_by_value(t, clip_value_min, clip_value_max): | |||
@@ -785,7 +818,7 @@ def clip_by_value(t, clip_value_min, clip_value_max): | |||
A clipped Tensor or IndexedSlices. | |||
""" | |||
pass | |||
return pd.clip(t, clip_value_min, clip_value_max) | |||
def split(value, num_or_size_splits, axis=0, num=None): | |||
@@ -796,7 +829,7 @@ def split(value, num_or_size_splits, axis=0, num=None): | |||
---------- | |||
value : tensor | |||
The Tensor to split. | |||
num_or_size_splits : list | |||
num_or_size_splits : list or tuple | |||
Either an integer indicating the number of splits along split_dim or a 1-D integer Tensor or | |||
Python list containing the sizes of each output tensor along split_dim. | |||
axis : int | |||
@@ -808,33 +841,33 @@ def split(value, num_or_size_splits, axis=0, num=None): | |||
------- | |||
Tensor objects resulting from splitting value. | |||
""" | |||
pass | |||
    return pd.split(value, num_or_size_splits, axis)
class Floor(object): | |||
def __call__(self, *args, **kwargs): | |||
raise NotImplementedError | |||
def __call__(self, x): | |||
return pd.floor(x) | |||
def floor(x): | |||
raise NotImplementedError | |||
return pd.floor(x) | |||
def gather(params, indices): | |||
raise NotImplementedError | |||
return pd.gather(params, indices) | |||
def linspace(start, stop, num): | |||
raise NotImplementedError | |||
return pd.linspace(start, stop, num) | |||
def slice(inputs, starts, sizes): | |||
raise NotImplementedError | |||
    # Paddle's slice expects explicit axes and exclusive end indices rather than sizes.
    ends = [start + size for start, size in zip(starts, sizes)]
    return pd.slice(inputs, axes=list(range(len(starts))), starts=starts, ends=ends)
def add_n(inputs): | |||
raise NotImplementedError | |||
return pd.add_n(inputs) | |||
class OneHot(object): | |||
@@ -844,17 +877,19 @@ class OneHot(object): | |||
self.dtype = dtype | |||
def __call__(self, indices): | |||
raise NotImplementedError | |||
output = pd.nn.functional.one_hot(indices, self.depth) | |||
return output | |||
class L2Normalize(object): | |||
def __init__(self, axis=None, epsilon=1e-12): | |||
super(L2Normalize, self).__init__() | |||
pass | |||
self.axis = axis | |||
self.epsilon = epsilon | |||
def __call__(self, input, *args, **kwargs): | |||
pass | |||
def __call__(self, input): | |||
return pd.nn.functional.normalize(x=input, p=2, axis=self.axis, epsilon=self.epsilon) | |||
class EmbeddingLookup(object): | |||
@@ -862,7 +897,7 @@ class EmbeddingLookup(object): | |||
def __init__(self, max_norm=None): | |||
self.max_norm = max_norm | |||
def __call__(self, params, ids, *args, **kwargs): | |||
def __call__(self, params, ids): | |||
pass | |||
@@ -3,6 +3,12 @@ | |||
import paddle as pd | |||
import paddle.nn.functional as F | |||
import numpy as np | |||
import paddle.fluid as fluid | |||
from paddle.nn import initializer as I | |||
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as | |||
from paddle.fluid.data_feeder import convert_dtype | |||
from paddle.fluid.dygraph import Layer | |||
def padding_format(padding): | |||
@@ -1308,3 +1314,386 @@ class DorefaConv2D(object): | |||
def __call__(self, inputs, filters): | |||
raise NotImplementedError | |||
class rnncell(object): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, bias, act): | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.bias = bias | |||
self.act_fn = F.relu if act == 'relu' else F.tanh | |||
def __call__(self, input, h): | |||
i2h = pd.matmul(input, self.weight_ih, transpose_y=True) | |||
if self.bias_ih is not None: | |||
i2h += self.bias_ih | |||
h2h = pd.matmul(h, self.weight_hh, transpose_y=True) | |||
if self.bias_hh is not None: | |||
h2h += self.bias_hh | |||
h = self.act_fn(i2h + h2h) | |||
return h, h | |||
class lstmcell(object): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, bias): | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.bias = bias | |||
self.gate_act_fn = F.sigmoid | |||
self.act_fn = F.tanh | |||
def __call__(self, inputs, h, c): | |||
gates = pd.matmul(inputs, self.weight_ih, transpose_y=True) | |||
if self.bias_ih is not None: | |||
gates += self.bias_ih | |||
gates += pd.matmul(h, self.weight_hh, transpose_y=True) | |||
if self.bias_hh is not None: | |||
gates += self.bias_hh | |||
gates_slices = pd.split(gates, num_or_sections=4, axis=-1) | |||
i = self.gate_act_fn(gates_slices[0]) | |||
f = self.gate_act_fn(gates_slices[1]) | |||
o = self.gate_act_fn(gates_slices[3]) | |||
c = f * c + i * self.act_fn(gates_slices[2]) | |||
h = o * self.act_fn(c) | |||
return h, h, c | |||
class grucell(object): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, bias): | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.bias = bias | |||
self.gate_act_fn = F.sigmoid | |||
self.act_fn = F.tanh | |||
def __call__(self, input, h): | |||
x_gates = pd.matmul(input, self.weight_ih, transpose_y=True) | |||
if self.bias_ih is not None: | |||
x_gates = x_gates + self.bias_ih | |||
h_gates = pd.matmul(h, self.weight_hh, transpose_y=True) | |||
if self.bias_hh is not None: | |||
h_gates = h_gates + self.bias_hh | |||
x_r, x_z, x_c = pd.split(x_gates, num_or_sections=3, axis=-1) | |||
h_r, h_z, h_c = pd.split(h_gates, num_or_sections=3, axis=-1) | |||
r = self.gate_act_fn(x_r + h_r) | |||
z = self.gate_act_fn(x_z + h_z) | |||
c = self.act_fn(x_c + r * h_c) # apply reset gate after mm | |||
h = (h - c) * z + c | |||
return h, h | |||
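The closing update `h = (h - c) * z + c` in `grucell` above is just an algebraic rearrangement of the usual GRU interpolation `z * h + (1 - z) * c`; a quick NumPy check with arbitrary values:

```python
import numpy as np

h, c, z = np.random.rand(3, 4), np.random.rand(3, 4), np.random.rand(3, 4)
assert np.allclose((h - c) * z + c, z * h + (1 - z) * c)
```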
def split_states(states, bidirectional=False, state_components=1): | |||
r""" | |||
Split states of RNN network into possibly nested list or tuple of | |||
states of each RNN cells of the RNN network. | |||
Parameters: | |||
states (Tensor|tuple|list): the concatenated states for RNN network. | |||
When `state_components` is 1, states in a Tensor with shape | |||
`(L*D, N, C)` where `L` is the number of layers of the RNN | |||
network, `D` is the number of directions of the RNN network(1 | |||
for unidirectional RNNs and 2 for bidirectional RNNs), `N` is | |||
the batch size of the input to the RNN network, `C` is the | |||
hidden size of the RNN network. | |||
When `state_components` is larger than 1, `states` is a tuple of | |||
`state_components` Tensors that meet the requirements described | |||
above. | |||
For SimpleRNNs and GRUs, `state_components` is 1, and for LSTMs, | |||
`state_components` is 2. | |||
bidirectional (bool): whether the state is of a bidirectional RNN | |||
network. Defaults to False. | |||
state_components (int): the number of the components of the states. see | |||
`states` above. Defaults to 1. | |||
Returns: | |||
A nested list or tuple of RNN cell states. | |||
If `bidirectional` is True, it can be indexed twice to get an RNN | |||
cell state. The first index indicates the layer, the second index | |||
indicates the direction. | |||
If `bidirectional` is False, it can be indexed once to get an RNN | |||
cell state. The index indicates the layer. | |||
Note that if `state_components` is larger than 1, an RNN cell state | |||
can be indexed one more time to get a tensor of shape(N, C), where | |||
`N` is the batch size of the input to the RNN cell, and `C` is the | |||
hidden size of the RNN cell. | |||
""" | |||
if state_components == 1: | |||
states = pd.unstack(states) | |||
if not bidirectional: | |||
return states | |||
else: | |||
return list(zip(states[::2], states[1::2])) | |||
else: | |||
assert len(states) == state_components | |||
states = tuple([pd.unstack(item) for item in states]) | |||
if not bidirectional: | |||
return list(zip(*states)) | |||
else: | |||
states = list(zip(*states)) | |||
return list(zip(states[::2], states[1::2])) | |||
def concat_states(states, bidirectional=False, state_components=1): | |||
r""" | |||
Concatenate a possibly nested list or tuple of RNN cell states into a | |||
compact form. | |||
Parameters: | |||
states (list|tuple): a possibly nested list or tuple of RNN cell | |||
states. | |||
If `bidirectional` is True, it can be indexed twice to get an | |||
RNN cell state. The first index indicates the layer, the second | |||
index indicates the direction. | |||
If `bidirectional` is False, it can be indexed once to get an RNN | |||
cell state. The index indicates the layer. | |||
Note that if `state_components` is larger than 1, an RNN cell | |||
state can be indexed one more time to get a tensor of shape(N, C), | |||
where `N` is the batch size of the input to the RNN cell, and | |||
`C` is the hidden size of the RNN cell. | |||
bidirectional (bool): whether the state is of a bidirectional RNN | |||
network. Defaults to False. | |||
state_components (int): the number of the components of the states. see | |||
`states` above. Defaults to 1. | |||
Returns: | |||
Concatenated states for RNN network. | |||
When `state_components` is 1, states in a Tensor with shape | |||
`(L\*D, N, C)` where `L` is the number of layers of the RNN | |||
network, `D` is the number of directions of the RNN network(1 for | |||
unidirectional RNNs and 2 for bidirectional RNNs), `N` is the batch | |||
size of the input to the RNN network, `C` is the hidden size of the | |||
RNN network. | |||
""" | |||
if state_components == 1: | |||
return pd.stack(flatten(states)) | |||
else: | |||
states = flatten(states) | |||
        components = []
        for i in range(state_components):
            components.append(states[i::state_components])
        return tuple([pd.stack(item) for item in components])
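A quick round-trip sketch for `split_states` / `concat_states` above, assuming the `(L*D, N, C)` state layout described in the docstrings (values and sizes are illustrative):

```python
import paddle as pd

L, D, N, C = 2, 2, 4, 8                 # layers, directions, batch, hidden size
h = pd.randn([L * D, N, C])
c = pd.randn([L * D, N, C])

# split into per-layer, per-direction (h, c) cell states and back again
per_cell = split_states((h, c), bidirectional=True, state_components=2)
assert len(per_cell) == L and len(per_cell[0]) == D
h2, c2 = concat_states(per_cell, bidirectional=True, state_components=2)
assert h2.shape == h.shape and c2.shape == c.shape
```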
class rnnbase(Layer): | |||
def __init__( | |||
self, | |||
mode, | |||
input_size, | |||
hidden_size, | |||
num_layers, | |||
bias, | |||
batch_first, | |||
dropout, | |||
bidirectional, | |||
is_train, | |||
): | |||
super(rnnbase, self).__init__() | |||
self.mode = mode | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.time_major = False if batch_first else True | |||
self.dropout = dropout | |||
self.bidirect = 2 if bidirectional else 1 | |||
self.state_components = 2 if mode == 'LSTM' else 1 | |||
self.rnn = pd.nn.LayerList() | |||
RNN = pd.nn.RNN | |||
BiRNN = pd.nn.BiRNN | |||
weight_ih_attr = None | |||
weight_hh_attr = None | |||
if bias: | |||
bias_ih_attr = None | |||
bias_hh_attr = None | |||
else: | |||
bias_ih_attr = False | |||
bias_hh_attr = False | |||
kwargs = { | |||
"weight_ih_attr": weight_ih_attr, | |||
"weight_hh_attr": weight_hh_attr, | |||
"bias_ih_attr": bias_ih_attr, | |||
"bias_hh_attr": bias_hh_attr | |||
} | |||
if mode == "LSTM": | |||
rnn_cls = pd.nn.LSTMCell | |||
elif mode == "GRU": | |||
rnn_cls = pd.nn.GRUCell | |||
elif mode == 'RNN_TANH': | |||
rnn_cls = pd.nn.SimpleRNNCell | |||
kwargs["activation"] = 'tanh' | |||
elif mode == 'RNN_RELU': | |||
rnn_cls = pd.nn.SimpleRNNCell | |||
kwargs["activation"] = 'relu' | |||
if not bidirectional: | |||
is_reverse = False | |||
cell = rnn_cls(input_size, hidden_size, **kwargs) | |||
self.rnn.append(RNN(cell, is_reverse, self.time_major)) | |||
for i in range(1, num_layers): | |||
cell = rnn_cls(hidden_size, hidden_size, **kwargs) | |||
self.rnn.append(RNN(cell, is_reverse, self.time_major)) | |||
else: | |||
cell_fw = rnn_cls(input_size, hidden_size, **kwargs) | |||
cell_bw = rnn_cls(input_size, hidden_size, **kwargs) | |||
self.rnn.append(BiRNN(cell_fw, cell_bw, self.time_major)) | |||
for i in range(1, num_layers): | |||
cell_fw = rnn_cls(2 * hidden_size, hidden_size, **kwargs) | |||
cell_bw = rnn_cls(2 * hidden_size, hidden_size, **kwargs) | |||
self.rnn.append(BiRNN(cell_fw, cell_bw, self.time_major)) | |||
self.could_use_cudnn = True | |||
self.could_use_cudnn &= len(self.rnn.parameters()) == num_layers * 4 * self.bidirect | |||
param_names = [] | |||
for layer in range(self.num_layers): | |||
for direction in range(self.bidirect): | |||
suffix = '_reverse' if direction == 1 else '' | |||
param_names.extend(['weight_ih_l{}{}', 'weight_hh_l{}{}']) | |||
if bias_ih_attr != False: param_names.append('bias_ih_l{}{}') | |||
if bias_hh_attr != False: param_names.append('bias_hh_l{}{}') | |||
param_names = [x.format(layer, suffix) for x in param_names] | |||
for name, param in zip(param_names, self.rnn.parameters()): | |||
setattr(self.rnn, name, param) | |||
self.flatten_parameters() | |||
def flatten_parameters(self): | |||
""" | |||
Resets parameter data pointer to address in continuous memory block for | |||
cudnn usage. | |||
""" | |||
if self.could_use_cudnn: | |||
# layer.parameters() is depth first and ordered | |||
# for i in layer: for j in direct: w_ih, w_hh, b_ih, b_hh | |||
# need to reorganize to cudnn param layout: | |||
# all bias following all weights | |||
params = self.rnn.parameters(include_sublayers=False) | |||
shape = [np.prod(param.shape) for param in params] | |||
self._all_weights = [None] * len(params) | |||
for i, param in enumerate(params): | |||
offset = 0 if i % 4 < 2 else (2 * self.num_layers * self.bidirect) | |||
layer_idx = i // 4 | |||
self._all_weights[offset + layer_idx * 2 + i % 2] = param | |||
            # Wrap in a list so it is not registered into params and saved; this
            # may need a better solution later. Use `create_parameter` to add it to
            # both main_program and startup_program for static-graph mode.
            # Use a Constant initializer to avoid affecting the random generator.
self._flat_weight = [ | |||
self.rnn.create_parameter( | |||
shape=[np.sum(shape)], dtype=params[0].dtype, default_initializer=I.Constant(0.0) | |||
) | |||
] | |||
            # The dropout state could likewise be hidden to avoid saving it;
            # should the dropout state be persistable for static-graph mode?
self._dropout_state = self.rnn.create_variable(dtype=fluid.core.VarDesc.VarType.UINT8) | |||
# for static-graph, append coalesce_tensor into startup program | |||
with fluid.program_guard(fluid.default_startup_program(), fluid.default_startup_program()): | |||
with pd.framework.no_grad(): | |||
self.rnn._helper.append_op( | |||
type="coalesce_tensor", inputs={"Input": self._all_weights}, outputs={ | |||
"Output": self._all_weights, | |||
"FusedOutput": self._flat_weight | |||
}, attrs={ | |||
"copy_data": True, | |||
"use_align": False, | |||
"dtype": params[0].dtype | |||
} | |||
) | |||
def _cudnn_impl(self, inputs, initial_states, sequence_length): | |||
if not self.time_major: | |||
inputs = pd.tensor.transpose(inputs, [1, 0, 2]) | |||
out = self.rnn._helper.create_variable_for_type_inference(inputs.dtype) | |||
state = [ | |||
self.rnn._helper.create_variable_for_type_inference(inputs.dtype) for i in range(self.state_components) | |||
] | |||
reserve = self.rnn._helper.create_variable_for_type_inference( | |||
dtype=fluid.core.VarDesc.VarType.UINT8, stop_gradient=True | |||
) | |||
inputs = { | |||
'Input': inputs, | |||
'WeightList': self._all_weights, | |||
'PreState': initial_states, | |||
'SequenceLength': sequence_length | |||
} | |||
attrs = { | |||
'dropout_prob': self.dropout, | |||
'is_bidirec': self.bidirect == 2, | |||
'input_size': self.input_size, | |||
'hidden_size': self.hidden_size, | |||
'num_layers': self.num_layers, | |||
'mode': self.mode, | |||
'is_test': not self.rnn.training | |||
} | |||
outputs = { | |||
'Out': out, | |||
'State': state, | |||
'Reserve': reserve, | |||
'DropoutState': self._dropout_state, | |||
} | |||
self.rnn._helper.append_op(type="rnn", inputs=inputs, outputs=outputs, attrs=attrs) | |||
out = pd.tensor.transpose(out, [1, 0, 2]) if not self.time_major else out | |||
return out, tuple(state) if len(state) > 1 else state[0] | |||
def forward(self, inputs, initial_states=None, sequence_length=None): | |||
batch_index = 1 if self.time_major else 0 | |||
dtype = inputs.dtype | |||
if initial_states is None: | |||
state_shape = [self.num_layers * self.bidirect, -1, self.hidden_size] | |||
if self.state_components == 1: | |||
initial_states = fluid.layers.fill_constant_batch_size_like( | |||
inputs, state_shape, dtype, 0, batch_index, 1 | |||
) | |||
else: | |||
initial_states = tuple( | |||
[ | |||
fluid.layers.fill_constant_batch_size_like(inputs, state_shape, dtype, 0, batch_index, 1) | |||
for _ in range(self.state_components) | |||
] | |||
) | |||
if self.could_use_cudnn: | |||
# Add CPU kernel and dispatch in backend later | |||
return self._cudnn_impl(inputs, initial_states, sequence_length) | |||
states = split_states(initial_states, self.bidirect == 2, self.state_components) | |||
final_states = [] | |||
for i, rnn_layer in enumerate(self.rnn): | |||
if i > 0: | |||
inputs = F.dropout(inputs, self.dropout, training=self.rnn.training, mode="upscale_in_train") | |||
outputs, final_state = rnn_layer(inputs, states[i], sequence_length) | |||
final_states.append(final_state) | |||
inputs = outputs | |||
final_states = concat_states(final_states, self.bidirect == 2, self.state_components) | |||
return outputs, final_states |
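A hypothetical smoke test for the Paddle `rnnbase` wrapper above (argument names follow its `__init__` signature; the shapes in the comments are what the code implies for a single-layer, unidirectional LSTM with `batch_first=True`, not a verified run):

```python
import paddle as pd

rnn = rnnbase(
    mode='LSTM', input_size=32, hidden_size=64, num_layers=1,
    bias=True, batch_first=True, dropout=0.0, bidirectional=False,
    is_train=True,
)
x = pd.randn([8, 5, 32])            # (batch, time, input_size)
y, (h, c) = rnn(x)                  # initial states default to zeros
print(y.shape, h.shape, c.shape)    # expected [8, 5, 64] [1, 8, 64] [1, 8, 64]
```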
@@ -531,14 +531,15 @@ def reduce_min(input_tensor, axis=None): | |||
class Pad(object): | |||
def __init__(self, paddings, mode="REFLECT"): | |||
def __init__(self, paddings, mode="REFLECT", constant_values=0): | |||
if mode not in ['CONSTANT', 'REFLECT', 'SYMMETRIC']: | |||
raise Exception("Unsupported mode: {}".format(mode)) | |||
self.paddings = paddings | |||
self.mode = mode | |||
self.constant_values = constant_values | |||
def __call__(self, x): | |||
outputs = tf.pad(x, self.paddings, mode=self.mode, constant_values=0) | |||
outputs = tf.pad(x, self.paddings, mode=self.mode, constant_values=self.constant_values) | |||
return outputs | |||
@@ -884,7 +885,7 @@ class OneHot(object): | |||
self.axis = axis | |||
self.dtype = dtype | |||
def __call__(self, inputs, *args, **kwargs): | |||
def __call__(self, inputs): | |||
outputs = tf.one_hot( | |||
inputs, self.depth, on_value=self.on_value, off_value=self.off_value, axis=self.axis, dtype=self.dtype | |||
) | |||
@@ -907,7 +908,7 @@ class EmbeddingLookup(object): | |||
def __init__(self, max_norm=None): | |||
self.max_norm = max_norm | |||
def __call__(self, params, ids, *args, **kwargs): | |||
def __call__(self, params, ids): | |||
outputs = tf.nn.embedding_lookup(params=params, ids=ids, max_norm=self.max_norm) | |||
return outputs | |||
@@ -6,6 +6,7 @@ from tensorflow.python.framework import ops | |||
from tensorflow.python.ops import math_ops | |||
from tensorflow.python.training import moving_averages | |||
from math import floor, ceil | |||
import numpy as np | |||
# loss function | |||
sparse_softmax_cross_entropy_with_logits = tf.nn.sparse_softmax_cross_entropy_with_logits | |||
sigmoid_cross_entropy_with_logits = tf.nn.sigmoid_cross_entropy_with_logits | |||
@@ -1913,3 +1914,342 @@ class DorefaConv2D(object): | |||
) | |||
return outputs | |||
class rnncell(object): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act): | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.act_fn = tf.nn.relu if act == 'relu' else tf.nn.tanh | |||
def __call__(self, input, h, c=None): | |||
i2h = tf.matmul(input, self.weight_ih, transpose_b=True) | |||
if self.bias_ih is not None: | |||
i2h += self.bias_ih | |||
h2h = tf.matmul(h, self.weight_hh, transpose_b=True) | |||
if self.bias_hh is not None: | |||
h2h += self.bias_hh | |||
h = self.act_fn(i2h + h2h) | |||
return h, h | |||
class lstmcell(object): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act=None): | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.gate_act_fn = tf.sigmoid | |||
self.act_fn = tf.tanh | |||
def __call__(self, input, h, c): | |||
gates = tf.matmul(input, self.weight_ih, transpose_b=True) | |||
if self.bias_ih is not None: | |||
gates = gates + self.bias_ih | |||
gates += tf.matmul(h, self.weight_hh, transpose_b=True) | |||
if self.bias_hh is not None: | |||
gates += self.bias_hh | |||
gate_slices = tf.split(gates, num_or_size_splits=4, axis=-1) | |||
i = self.gate_act_fn(gate_slices[0]) | |||
f = self.gate_act_fn(gate_slices[1]) | |||
o = self.gate_act_fn(gate_slices[3]) | |||
c = f * c + i * self.act_fn(gate_slices[2]) | |||
h = o * self.act_fn(c) | |||
return h, h, c | |||
class grucell(object): | |||
def __init__(self, weight_ih, weight_hh, bias_ih, bias_hh, act=None): | |||
self.weight_ih = weight_ih | |||
self.weight_hh = weight_hh | |||
self.bias_ih = bias_ih | |||
self.bias_hh = bias_hh | |||
self.gate_act_fn = tf.sigmoid | |||
self.act_fn = tf.tanh | |||
def __call__(self, input, h, c=None): | |||
x_gates = tf.matmul(input, self.weight_ih, transpose_b=True) | |||
if self.bias_ih is not None: | |||
x_gates = x_gates + self.bias_ih | |||
h_gates = tf.matmul(h, self.weight_hh, transpose_b=True) | |||
if self.bias_hh is not None: | |||
h_gates = h_gates + self.bias_hh | |||
x_r, x_z, x_c = tf.split(x_gates, num_or_size_splits=3, axis=-1) | |||
h_r, h_z, h_c = tf.split(h_gates, num_or_size_splits=3, axis=-1) | |||
r = self.gate_act_fn(x_r + h_r) | |||
        z = self.gate_act_fn(x_z + h_z)
c = self.act_fn(x_c + r * h_c) | |||
h = (h - c) * z + c | |||
return h, h | |||
class rnnbase(object): | |||
def __init__( | |||
self, | |||
mode, | |||
input_size, | |||
hidden_size, | |||
num_layers, | |||
bias, | |||
batch_first, | |||
dropout, | |||
bidirectional, | |||
is_train, | |||
weights_fw, | |||
weights_bw, | |||
bias_fw, | |||
bias_bw, | |||
): | |||
self.mode = mode | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.bias = bias | |||
self.batch_first = batch_first | |||
self.dropout = float(dropout) | |||
self.train = is_train | |||
if not 0 <= dropout < 1: | |||
raise ValueError("dropout should be a number in range [0, 1).") | |||
if dropout > 0 and num_layers == 1: | |||
raise ValueError( | |||
"dropout option adds dropout after all but last " | |||
"recurrent layer, so non-zero dropout expects " | |||
"num_layers greater than 1, but got dropout={} and " | |||
"num_layers={}".format(dropout, num_layers) | |||
) | |||
self.bidirect = 2 if bidirectional else 1 | |||
self.weights_fw = weights_fw | |||
self.bias_fw = bias_fw | |||
self.weights_bw = weights_bw | |||
self.bias_bw = bias_bw | |||
# stdv = 1.0 / np.sqrt(self.hidden_size) | |||
# _init = tf.random_uniform_initializer(minval=-stdv, maxval=stdv) | |||
self.act_fn = None | |||
if mode == 'LSTM': | |||
# gate_size = 4 * hidden_size | |||
self.rnn_cell = lstmcell | |||
elif mode == 'GRU': | |||
# gate_size = 3 * hidden_size | |||
self.rnn_cell = grucell | |||
elif mode == 'RNN_TANH': | |||
# gate_size = hidden_size | |||
self.rnn_cell = rnncell | |||
self.act_fn = 'tanh' | |||
elif mode == 'RNN_RELU': | |||
# gate_size = hidden_size | |||
self.rnn_cell = rnncell | |||
self.act_fn = 'relu' | |||
# for layer in range(num_layers): | |||
# for direction in range(self.bidirect): | |||
# layer_input_size = input_size if layer==0 else hidden_size*self.bidirect | |||
# if direction == 0: | |||
# self.w_ih = tf.Variable(initial_value= _init(shape=(gate_size, layer_input_size)),name = 'weight_ih_l'+str(layer), trainable=True) | |||
# self.w_hh = tf.Variable(initial_value=_init(shape=(gate_size, hidden_size)), | |||
# name='weight_hh_l'+str(layer), trainable=True) | |||
# # self.w_ih = self.weights_init('weight_ih_l'+str(layer), shape = (gate_size, layer_input_size), init = _init) | |||
# # self.w_hh = self.weights_init('weight_ih_l' + str(layer), shape=(gate_size, hidden_size), | |||
# # init=_init) | |||
# self.weights_fw.append(self.w_ih) | |||
# self.weights_fw.append(self.w_hh) | |||
# if bias: | |||
# self.b_ih = tf.Variable(initial_value=_init(shape=(gate_size,)), | |||
# name='bias_ih_l'+str(layer), trainable=True) | |||
# self.b_hh = tf.Variable(initial_value=_init(shape=(gate_size,)), | |||
# name='bias_hh_l'+str(layer), trainable=True) | |||
# # self.b_ih = self.weights_init('bias_ih_l'+str(layer), shape=(gate_size,), init=_init) | |||
# # self.b_hh = self.weights_init('bias_hh_l'+str(layer), shape=(gate_size,), init=_init) | |||
# self.bias_fw.append(self.b_ih) | |||
# self.bias_fw.append(self.b_hh) | |||
# else: | |||
# self.w_ih = tf.Variable(initial_value= _init(shape=(gate_size, layer_input_size)),name = 'weight_ih_l'+str(layer)+'_reverse', trainable=True) | |||
# self.w_hh = tf.Variable(initial_value=_init(shape=(gate_size, hidden_size)), | |||
# name='weight_hh_l'+str(layer)+'_reverse', trainable=True) | |||
# # self.w_ih = self.weights_init('weight_ih_l'+str(layer)+'_reverse', shape = (gate_size, layer_input_size), init = _init) | |||
# # self.w_hh = self.weights_init('weight_hh_l'+str(layer)+'_reverse', shape=(gate_size, hidden_size), | |||
# # init=_init) | |||
# self.weights_bw.append(self.w_ih) | |||
# self.weights_bw.append(self.w_hh) | |||
# if bias: | |||
# self.b_ih = tf.Variable(initial_value=_init(shape=(gate_size,)), | |||
# name='bias_ih_l'+str(layer)+'_reverse', trainable=True) | |||
# self.b_hh = tf.Variable(initial_value=_init(shape=(gate_size,)), | |||
# name='bias_hh_l'+str(layer)+'_reverse', trainable=True) | |||
# # self.b_ih = self.weights_init('bias_ih_l'+str(layer)+'_reverse', shape=(gate_size,), init=_init) | |||
# # self.b_hh = self.weights_init('bias_hh_l'+str(layer)+'_reverse', shape=(gate_size,), init=_init) | |||
# self.bias_bw.append(self.b_ih) | |||
# self.bias_bw.append(self.b_hh) | |||
def _bi_rnn_forward(self, x, h, c=None): | |||
time_step, batch_size, input_size = x.shape | |||
h_out = [] | |||
c_out = [] | |||
y = [] | |||
pre_layer = x | |||
for i in range(self.num_layers): | |||
weight_ih_fw = self.weights_fw[2 * i] | |||
weight_hh_fw = self.weights_fw[2 * i + 1] | |||
weight_ih_bw = self.weights_bw[2 * i] | |||
weight_hh_bw = self.weights_bw[2 * i + 1] | |||
if self.bias: | |||
bias_ih_fw = self.bias_fw[2 * i] | |||
bias_hh_fw = self.bias_fw[2 * i + 1] | |||
bias_ih_bw = self.bias_bw[2 * i] | |||
bias_hh_bw = self.bias_bw[2 * i + 1] | |||
else: | |||
bias_ih_fw = None | |||
bias_hh_fw = None | |||
bias_ih_bw = None | |||
bias_hh_bw = None | |||
            h_i_fw = h[2 * i, :, :]
            h_i_bw = h[2 * i + 1, :, :]
if i != 0 and self.train: | |||
pre_layer = tf.nn.dropout(pre_layer, rate=self.dropout) | |||
if c is not None: | |||
                c_i_fw = c[2 * i, :, :]
                c_i_bw = c[2 * i + 1, :, :]
for j in range(time_step): | |||
input = pre_layer[j, :, :] | |||
cell_fw = self.rnn_cell(weight_ih_fw, weight_hh_fw, bias_ih_fw, bias_hh_fw, self.act_fn) | |||
cell_bw = self.rnn_cell(weight_ih_bw, weight_hh_bw, bias_ih_bw, bias_hh_bw, self.act_fn) | |||
bw_input = tf.reverse(input, axis=[0]) | |||
step_out_fw, h_i_fw, c_i_fw = cell_fw(input, h_i_fw, c_i_fw) | |||
step_out_bw, h_i_bw, c_i_bw = cell_bw(bw_input, h_i_bw, c_i_bw) | |||
step_out_bw = tf.reverse(step_out_bw, axis=[0]) | |||
step_out = tf.concat([step_out_fw, step_out_bw], axis=-1) | |||
y.append(step_out) | |||
h_out.append(h_i_fw) | |||
h_out.append(h_i_bw) | |||
c_out.append(c_i_fw) | |||
c_out.append(c_i_bw) | |||
pre_layer = tf.stack(y) | |||
y = [] | |||
else: | |||
for j in range(time_step): | |||
input = pre_layer[j, :, :] | |||
cell_fw = self.rnn_cell(weight_ih_fw, weight_hh_fw, bias_ih_fw, bias_hh_fw, self.act_fn) | |||
cell_bw = self.rnn_cell(weight_ih_bw, weight_hh_bw, bias_ih_bw, bias_hh_bw, self.act_fn) | |||
bw_input = tf.reverse(input, axis=[0]) | |||
step_out_fw, h_i_fw = cell_fw(input, h_i_fw) | |||
step_out_bw, h_i_bw = cell_bw(bw_input, h_i_bw) | |||
step_out_bw = tf.reverse(step_out_bw, axis=[0]) | |||
step_out = tf.concat([step_out_fw, step_out_bw], axis=-1) | |||
y.append(step_out) | |||
h_out.append(h_i_fw) | |||
h_out.append(h_i_bw) | |||
pre_layer = tf.stack(y) | |||
y = [] | |||
h_out = tf.stack(h_out) | |||
c_out = tf.stack(c_out) if c is not None else None | |||
return pre_layer, h_out, c_out | |||
def _rnn_forward(self, x, h, c=None): | |||
pre_layer = x | |||
h_out = [] | |||
c_out = [] | |||
y = [] | |||
time_step, batch_size, input_size = x.shape | |||
for i in range(self.num_layers): | |||
weight_ih = self.weights_fw[2 * i] | |||
weight_hh = self.weights_fw[2 * i + 1] | |||
if self.bias: | |||
bias_ih = self.bias_fw[2 * i] | |||
bias_hh = self.bias_fw[2 * i + 1] | |||
else: | |||
bias_ih = None | |||
bias_hh = None | |||
h_i = h[i, :, :] | |||
if i != 0 and self.train: | |||
pre_layer = tf.nn.dropout(pre_layer, rate=self.dropout) | |||
if c is not None: | |||
c_i = c[i, :, :] | |||
for j in range(time_step): | |||
input = pre_layer[j, :, :] | |||
cell = self.rnn_cell(weight_ih, weight_hh, bias_ih, bias_hh, self.act_fn) | |||
step_out, h_i, c_i = cell(input, h_i, c_i) | |||
y.append(step_out) | |||
h_out.append(h_i) | |||
c_out.append(c_i) | |||
pre_layer = tf.stack(y) | |||
y = [] | |||
else: | |||
for j in range(time_step): | |||
input = pre_layer[j, :, :] | |||
                    cell = self.rnn_cell(weight_ih, weight_hh, bias_ih, bias_hh, self.act_fn)
step_out, h_i = cell(input, h_i) | |||
y.append(step_out) | |||
h_out.append(h_i) | |||
pre_layer = tf.stack(y) | |||
y = [] | |||
h_out = tf.stack(h_out) | |||
c_out = tf.stack(c_out) if c is not None else None | |||
return pre_layer, h_out, c_out | |||
def check_input(self, input_shape): | |||
if len(input_shape) != 3: | |||
raise ValueError("input must have 3 dimensions. But got {}.".format(len(input_shape))) | |||
if self.input_size != input_shape[-1]: | |||
raise ValueError( | |||
"The last dimension of input should be equal to input_size {}.But got {}".format( | |||
self.input_size, input_shape[-1] | |||
) | |||
) | |||
def check_hidden(self, h, batch_size): | |||
expected_hidden_size = (self.num_layers * self.bidirect, batch_size, self.hidden_size) | |||
if h.shape != expected_hidden_size: | |||
raise ValueError('Expected hidden size {}, got {}.'.format(expected_hidden_size, h.shape)) | |||
def __call__(self, input, states): | |||
if self.batch_first: | |||
input = tf.transpose(input, perm=(1, 0, 2)) | |||
input_dtype = input.dtype | |||
input_shape = input.shape | |||
time_step, batch_size, input_size = input_shape | |||
self.check_input(input_shape) | |||
if self.mode == "LSTM": | |||
if states is not None: | |||
h, c = states | |||
self.check_hidden(h, batch_size) | |||
self.check_hidden(c, batch_size) | |||
else: | |||
h = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype) | |||
c = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype) | |||
if self.bidirect == 1: | |||
y, new_h, new_c = self._rnn_forward(input, h, c) | |||
else: | |||
y, new_h, new_c = self._bi_rnn_forward(input, h, c) | |||
new_states = (new_h, new_c) | |||
else: | |||
if states is not None: | |||
h = states | |||
self.check_hidden(h, batch_size) | |||
else: | |||
h = tf.zeros(shape=(self.num_layers * self.bidirect, batch_size, self.hidden_size), dtype=input_dtype) | |||
if self.bidirect == 1: | |||
y, new_h, _ = self._rnn_forward(input, h) | |||
else: | |||
y, new_h, _ = self._bi_rnn_forward(input, h) | |||
new_states = new_h | |||
if self.batch_first: | |||
y = tf.transpose(y, perm=(1, 0, 2)) | |||
return y, new_states |
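A hypothetical smoke test for the TensorFlow `rnnbase` above. The flat weight lists follow the layout documented by the commented-out initialisation code in `__init__` (per layer and direction: `weight_ih` of shape `(gate_size, layer_input_size)`, `weight_hh` of shape `(gate_size, hidden_size)`, and matching bias vectors, with `gate_size = 4 * hidden_size` for LSTM); treat the names and shapes here as assumptions drawn from that comment:

```python
import tensorflow as tf

batch, time_step, in_dim, hid = 4, 10, 32, 64
gate_size = 4 * hid                                     # LSTM gates: i, f, g, o

init = tf.random_uniform_initializer(-0.1, 0.1)
weights_fw = [tf.Variable(init((gate_size, in_dim))),   # weight_ih, layer 0
              tf.Variable(init((gate_size, hid)))]      # weight_hh, layer 0
bias_fw = [tf.Variable(init((gate_size,))),             # bias_ih, layer 0
           tf.Variable(init((gate_size,)))]             # bias_hh, layer 0

rnn = rnnbase(
    mode='LSTM', input_size=in_dim, hidden_size=hid, num_layers=1,
    bias=True, batch_first=True, dropout=0.0, bidirectional=False,
    is_train=True, weights_fw=weights_fw, weights_bw=[],
    bias_fw=bias_fw, bias_bw=[],
)
x = tf.random.normal([batch, time_step, in_dim])
y, (h, c) = rnn(x, states=None)
print(y.shape, h.shape, c.shape)    # expected (4, 10, 64) (1, 4, 64) (1, 4, 64)
```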
@@ -1963,6 +1963,8 @@ def save_npz(save_list=None, name='model.npz'): | |||
save_list_var = tf_variables_to_numpy(save_list) | |||
elif tl.BACKEND == 'mindspore': | |||
save_list_var = ms_variables_to_numpy(save_list) | |||
elif tl.BACKEND == 'paddle': | |||
save_list_var = pd_variables_to_numpy(save_list) | |||
else: | |||
raise NotImplementedError("This backend is not supported") | |||
# print(name, save_list_var) | |||
@@ -2050,6 +2052,11 @@ def assign_weights(weights, network): | |||
# net = Assign_net(network.all_weights[idx]) | |||
# net(assign_param) | |||
Assign()(network.all_weights[idx], assign_param) | |||
elif tl.BACKEND == 'paddle': | |||
for idx, param in enumerate(weights): | |||
assign_pd_variable(network.all_weights[idx], param) | |||
else: | |||
        raise NotImplementedError("This backend is not supported")
return ops | |||
@@ -41,11 +41,13 @@ class PadLayer(Module): | |||
self, | |||
padding=None, | |||
mode='CONSTANT', | |||
constant_values=0, | |||
name=None, # 'pad_layer', | |||
): | |||
super().__init__(name) | |||
self.padding = padding | |||
self.mode = mode | |||
self.constant_values = constant_values | |||
logging.info("PadLayer %s: padding: %s mode: %s" % (self.name, self.padding, self.mode)) | |||
@@ -65,7 +67,7 @@ class PadLayer(Module): | |||
return s.format(classname=self.__class__.__name__, **self.__dict__) | |||
def build(self, inputs_shape=None): | |||
self.pad = tl.ops.Pad(paddings=self.padding, mode=self.mode) | |||
self.pad = tl.ops.Pad(paddings=self.padding, mode=self.mode, constant_values=self.constant_values) | |||
def forward(self, inputs): | |||
outputs = self.pad(inputs) | |||
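With `constant_values` threaded through to the backend `Pad` op, `PadLayer` can now pad with a non-zero constant. A usage sketch in the docstring style used elsewhere in TensorLayer (names and shapes are illustrative, not a verified run):

```python
import tensorlayer as tl

ni = tl.layers.Input([5, 100, 100, 3], name='input')
nn = tl.layers.PadLayer(
    padding=[[0, 0], [3, 3], [3, 3], [0, 0]],   # NHWC: pad H and W by 3 per side
    mode='CONSTANT', constant_values=1.0, name='pad'
)(ni)
print(nn.shape)                                  # expected (5, 106, 106, 3)
```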