|
|
@@ -131,21 +131,19 @@ class BasicModel(): |
|
|
|
|
|
|
|
model.train() |
|
|
|
|
|
|
|
loss_value = 0 |
|
|
|
for _, data in enumerate(data_loader): |
|
|
|
X = data[0].to(device) |
|
|
|
Y = data[1].to(device) |
|
|
|
pred_Y = model(X) |
|
|
|
|
|
|
|
loss = criterion(pred_Y, Y) |
|
|
|
total_loss, total_num = 0.0, 0 |
|
|
|
for data, target in data_loader: |
|
|
|
data, target = data.to(device), target.to(device) |
|
|
|
out = model(data) |
|
|
|
loss = criterion(out, target) |
|
|
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
loss.backward() |
|
|
|
optimizer.step() |
|
|
|
|
|
|
|
loss_value += loss.item() |
|
|
|
total_loss += loss.item() * data.size(0) |
|
|
|
|
|
|
|
return loss_value |
|
|
|
return total_loss / total_num |
|
|
|
|
|
|
|
def _predict(self, data_loader): |
|
|
|
model = self.model |
|
|
@@ -155,9 +153,9 @@ class BasicModel(): |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
results = [] |
|
|
|
for _, data in enumerate(data_loader): |
|
|
|
X = data[0].to(device) |
|
|
|
pred_Y = model(X) |
|
|
|
for data, _ in data_loader: |
|
|
|
data = data.to(device) |
|
|
|
pred_Y = model(data) |
|
|
|
results.append(pred_Y) |
|
|
|
|
|
|
|
return torch.cat(results, axis=0) |
|
|
|