Browse Source

update basic model

pull/3/head
Tony-HYX 2 years ago
parent
commit
1523101266
1 changed files with 10 additions and 12 deletions
  1. +10
    -12
      models/basic_model.py

+ 10
- 12
models/basic_model.py View File

@@ -131,21 +131,19 @@ class BasicModel():
model.train()

loss_value = 0
for _, data in enumerate(data_loader):
X = data[0].to(device)
Y = data[1].to(device)
pred_Y = model(X)

loss = criterion(pred_Y, Y)
total_loss, total_num = 0.0, 0
for data, target in data_loader:
data, target = data.to(device), target.to(device)
out = model(data)
loss = criterion(out, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

loss_value += loss.item()
total_loss += loss.item() * data.size(0)

return loss_value
return total_loss / total_num

def _predict(self, data_loader):
model = self.model
@@ -155,9 +153,9 @@ class BasicModel():

with torch.no_grad():
results = []
for _, data in enumerate(data_loader):
X = data[0].to(device)
pred_Y = model(X)
for data, _ in data_loader:
data = data.to(device)
pred_Y = model(data)
results.append(pred_Y)
return torch.cat(results, axis=0)


Loading…
Cancel
Save