| @@ -64,6 +64,8 @@ train_Y = data_Y[:train_size] | |||||
| test_X = data_X[train_size:] | test_X = data_X[train_size:] | ||||
| test_Y = data_Y[train_size:] | test_Y = data_Y[train_size:] | ||||
| train_Y.shape | |||||
| # 最后,我们需要将数据改变一下形状,因为 RNN 读入的数据维度是 (seq, batch, feature),所以要重新改变一下数据的维度,这里只有一个序列,所以 batch 是 1,而输入的 feature 就是我们希望依据的几个月份,这里我们定的是两个月份,所以 feature 就是 2. | # 最后,我们需要将数据改变一下形状,因为 RNN 读入的数据维度是 (seq, batch, feature),所以要重新改变一下数据的维度,这里只有一个序列,所以 batch 是 1,而输入的 feature 就是我们希望依据的几个月份,这里我们定的是两个月份,所以 feature 就是 2. | ||||
| # + | # + | ||||
| @@ -0,0 +1,117 @@ | |||||
| import numpy as np | |||||
| import pandas as pd | |||||
| import matplotlib.pyplot as plt | |||||
| import torch | |||||
| from torch import nn | |||||
| from torch.autograd import Variable | |||||
| """ | |||||
| Time-series forecasting with an LSTM model implemented in PyTorch. | |||||
| """ | |||||
# Load the series; only the CSV column at index 1 is read.
data_csv = pd.read_csv("./lstm_data.csv", usecols=[1])
# Uncomment to eyeball the raw series:
# plt.plot(data_csv)
# plt.show()

# Pre-processing: drop rows with missing values, cast to float32, then
# min-max scale so the smallest value maps to 0 and the largest to 1.
# val_min / val_scale are kept so the scaling can be inverted later.
data_csv = data_csv.dropna()
dataset = data_csv.values.astype("float32")
val_min, val_max = np.min(dataset), np.max(dataset)
val_scale = val_max - val_min
dataset = (dataset - val_min) / val_scale
| # generate dataset | |||||
def create_dataset(dataset, look_back=6):
    """Slice a time-ordered series into (window, next value) training pairs.

    Args:
        dataset: Array-like sequence of observations in time order. Each
            observation may be a scalar or a row of values; plain Python
            lists are accepted as well as numpy arrays.
        look_back: Number of consecutive observations in each input window.
            Must be at least 1.

    Returns:
        A tuple ``(dataX, dataY)`` of numpy arrays where ``dataX[i]`` holds
        ``dataset[i:i + look_back]`` and ``dataY[i]`` holds
        ``dataset[i + look_back]``. Both arrays are empty when the series
        has ``look_back`` or fewer observations.

    Raises:
        ValueError: If ``look_back`` is smaller than 1.
    """
    if look_back < 1:
        # A non-positive window would silently yield empty or wrapped-around
        # slices instead of real history.
        raise ValueError("look_back must be >= 1, got %r" % (look_back,))

    # np.asarray lets callers pass lists/tuples; tolist() then yields plain
    # Python containers, so the result dtype is inferred by np.array below.
    values = np.asarray(dataset).tolist()
    n_samples = len(values) - look_back

    dataX = [values[i:i + look_back] for i in range(n_samples)]
    dataY = [values[i + look_back] for i in range(n_samples)]
    return np.array(dataX), np.array(dataY)
# Build (window, next value) pairs; each window spans `look_back` steps.
look_back = 1
data_X, data_Y = create_dataset(dataset, look_back)

# Chronological split: the first 70% of the samples are used for training,
# the remainder for testing.
train_size = int(len(data_X) * 0.7)
test_size = len(data_X) - train_size
train_X, test_X = data_X[:train_size], data_X[train_size:]
train_Y, test_Y = data_Y[:train_size], data_Y[train_size:]

# Reshape to 3-D with a middle axis of length 1. The model below is built
# with input_size=look_back, so the last axis of the inputs is the window.
train_X = train_X.reshape(-1, 1, look_back)
train_Y = train_Y.reshape(-1, 1, 1)
test_X = test_X.reshape(-1, 1, look_back)

# Wrap the numpy arrays as float32 torch tensors.
train_x, train_y, test_x = (
    torch.from_numpy(arr).float() for arr in (train_X, train_Y, test_X)
)
| # define LSTM model | |||||
class LSTM_Reg(nn.Module):
    """LSTM followed by a linear layer applied to every time step.

    Args:
        input_size: Number of features per time step fed to the LSTM.
        hidden_size: Size of the LSTM hidden state.
        output_size: Number of values produced per time step.
        num_layer: Number of stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, output_size=1, num_layer=2):
        super(LSTM_Reg, self).__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layer)
        self.reg = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Map a ``(seq, batch, input_size)`` tensor to ``(seq, batch, output_size)``."""
        # Only the per-step outputs are needed; the final (h, c) state is dropped.
        hidden, _ = self.rnn(x)
        # nn.Linear acts on the last dimension and keeps all leading ones, so
        # no flatten/unflatten round trip (which needs contiguous memory) is
        # required to apply it to every time step.
        return self.reg(hidden)
# Model, loss and optimiser: a single LSTM layer with 4 hidden units.
net = LSTM_Reg(look_back, 4, num_layer=1)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)

# Full-batch training: the whole training tensor is passed through the
# network once per epoch. Tensors are used directly; the Variable wrapper
# is a deprecated no-op and did not need to be rebuilt every iteration.
for e in range(1000):
    # forward
    out = net(train_x)
    loss = criterion(out, train_y)
    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # print progress
    if e % 100 == 0:
        # The loss is a 0-dim tensor: indexing it with [0] raises on current
        # PyTorch, .item() is the supported way to get the Python float.
        print("epoch: %5d, loss: %.5f" % (e, loss.item()))

# Evaluate on the complete series (training and held-out part together).
net.eval()
data_X = torch.from_numpy(data_X.reshape(-1, 1, look_back)).float()
with torch.no_grad():  # inference only, no autograd graph needed
    pred_test = net(data_X)
pred_test = pred_test.view(-1).numpy()

# plot
# Sample i predicts dataset[i + look_back], so shift the prediction curve by
# look_back steps to line it up with the values it is estimating.
pred_index = np.arange(look_back, look_back + len(pred_test))
plt.plot(pred_index, pred_test, 'r', label="Prediction")
plt.plot(dataset, 'b', label="Real")
plt.legend(loc="best")
plt.show()
| @@ -0,0 +1,148 @@ | |||||
| "Month","International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60" | |||||
| "1949-01",112 | |||||
| "1949-02",118 | |||||
| "1949-03",132 | |||||
| "1949-04",129 | |||||
| "1949-05",121 | |||||
| "1949-06",135 | |||||
| "1949-07",148 | |||||
| "1949-08",148 | |||||
| "1949-09",136 | |||||
| "1949-10",119 | |||||
| "1949-11",104 | |||||
| "1949-12",118 | |||||
| "1950-01",115 | |||||
| "1950-02",126 | |||||
| "1950-03",141 | |||||
| "1950-04",135 | |||||
| "1950-05",125 | |||||
| "1950-06",149 | |||||
| "1950-07",170 | |||||
| "1950-08",170 | |||||
| "1950-09",158 | |||||
| "1950-10",133 | |||||
| "1950-11",114 | |||||
| "1950-12",140 | |||||
| "1951-01",145 | |||||
| "1951-02",150 | |||||
| "1951-03",178 | |||||
| "1951-04",163 | |||||
| "1951-05",172 | |||||
| "1951-06",178 | |||||
| "1951-07",199 | |||||
| "1951-08",199 | |||||
| "1951-09",184 | |||||
| "1951-10",162 | |||||
| "1951-11",146 | |||||
| "1951-12",166 | |||||
| "1952-01",171 | |||||
| "1952-02",180 | |||||
| "1952-03",193 | |||||
| "1952-04",181 | |||||
| "1952-05",183 | |||||
| "1952-06",218 | |||||
| "1952-07",230 | |||||
| "1952-08",242 | |||||
| "1952-09",209 | |||||
| "1952-10",191 | |||||
| "1952-11",172 | |||||
| "1952-12",194 | |||||
| "1953-01",196 | |||||
| "1953-02",196 | |||||
| "1953-03",236 | |||||
| "1953-04",235 | |||||
| "1953-05",229 | |||||
| "1953-06",243 | |||||
| "1953-07",264 | |||||
| "1953-08",272 | |||||
| "1953-09",237 | |||||
| "1953-10",211 | |||||
| "1953-11",180 | |||||
| "1953-12",201 | |||||
| "1954-01",204 | |||||
| "1954-02",188 | |||||
| "1954-03",235 | |||||
| "1954-04",227 | |||||
| "1954-05",234 | |||||
| "1954-06",264 | |||||
| "1954-07",302 | |||||
| "1954-08",293 | |||||
| "1954-09",259 | |||||
| "1954-10",229 | |||||
| "1954-11",203 | |||||
| "1954-12",229 | |||||
| "1955-01",242 | |||||
| "1955-02",233 | |||||
| "1955-03",267 | |||||
| "1955-04",269 | |||||
| "1955-05",270 | |||||
| "1955-06",315 | |||||
| "1955-07",364 | |||||
| "1955-08",347 | |||||
| "1955-09",312 | |||||
| "1955-10",274 | |||||
| "1955-11",237 | |||||
| "1955-12",278 | |||||
| "1956-01",284 | |||||
| "1956-02",277 | |||||
| "1956-03",317 | |||||
| "1956-04",313 | |||||
| "1956-05",318 | |||||
| "1956-06",374 | |||||
| "1956-07",413 | |||||
| "1956-08",405 | |||||
| "1956-09",355 | |||||
| "1956-10",306 | |||||
| "1956-11",271 | |||||
| "1956-12",306 | |||||
| "1957-01",315 | |||||
| "1957-02",301 | |||||
| "1957-03",356 | |||||
| "1957-04",348 | |||||
| "1957-05",355 | |||||
| "1957-06",422 | |||||
| "1957-07",465 | |||||
| "1957-08",467 | |||||
| "1957-09",404 | |||||
| "1957-10",347 | |||||
| "1957-11",305 | |||||
| "1957-12",336 | |||||
| "1958-01",340 | |||||
| "1958-02",318 | |||||
| "1958-03",362 | |||||
| "1958-04",348 | |||||
| "1958-05",363 | |||||
| "1958-06",435 | |||||
| "1958-07",491 | |||||
| "1958-08",505 | |||||
| "1958-09",404 | |||||
| "1958-10",359 | |||||
| "1958-11",310 | |||||
| "1958-12",337 | |||||
| "1959-01",360 | |||||
| "1959-02",342 | |||||
| "1959-03",406 | |||||
| "1959-04",396 | |||||
| "1959-05",420 | |||||
| "1959-06",472 | |||||
| "1959-07",548 | |||||
| "1959-08",559 | |||||
| "1959-09",463 | |||||
| "1959-10",407 | |||||
| "1959-11",362 | |||||
| "1959-12",405 | |||||
| "1960-01",417 | |||||
| "1960-02",391 | |||||
| "1960-03",419 | |||||
| "1960-04",461 | |||||
| "1960-05",472 | |||||
| "1960-06",535 | |||||
| "1960-07",622 | |||||
| "1960-08",606 | |||||
| "1960-09",508 | |||||
| "1960-10",461 | |||||
| "1960-11",390 | |||||
| "1960-12",432 | |||||
| International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60 | |||||