You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

utils.py 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import torch
  2. import numpy as np
  3. from torch import optim, nn
  4. from torch.utils.data import DataLoader, Dataset
  5. from learnware.utils import choose_device
  6. @torch.no_grad()
  7. def evaluate(model, evaluate_set: Dataset, device=None, distribution=True):
  8. device = choose_device(0) if device is None else device
  9. if isinstance(model, nn.Module):
  10. model.eval()
  11. mapping = lambda m, x: m(x)
  12. else:
  13. mapping = lambda m, x: m.predict(x)
  14. criterion = nn.CrossEntropyLoss(reduction="sum")
  15. total, correct, loss = 0, 0, torch.as_tensor(0.0, dtype=torch.float32, device=device)
  16. dataloader = DataLoader(evaluate_set, batch_size=1024, shuffle=True)
  17. for i, (X, y) in enumerate(dataloader):
  18. X, y = X.to(device), y.to(device)
  19. out = mapping(model, X)
  20. if not torch.is_tensor(out):
  21. out = torch.from_numpy(out).to(device)
  22. if distribution:
  23. loss += criterion(out, y)
  24. _, predicted = torch.max(out.data, 1)
  25. else:
  26. predicted = out
  27. total += y.size(0)
  28. correct += (predicted == y).sum().item()
  29. acc = correct / total * 100
  30. loss = loss / total
  31. if isinstance(model, nn.Module):
  32. model.train()
  33. return loss.item(), acc
  34. def train_model(
  35. model: nn.Module,
  36. train_set: Dataset,
  37. valid_set: Dataset,
  38. save_path: str,
  39. epochs=35,
  40. batch_size=128,
  41. device=None,
  42. verbose=True,
  43. ):
  44. device = choose_device(0) if device is None else device
  45. model.train()
  46. optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
  47. criterion = nn.CrossEntropyLoss()
  48. dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
  49. best_loss = 100000
  50. for epoch in range(epochs):
  51. running_loss = []
  52. model.train()
  53. for i, (X, y) in enumerate(dataloader):
  54. X, y = X.to(device=device), y.to(device=device)
  55. optimizer.zero_grad()
  56. out = model(X)
  57. loss = criterion(out, y)
  58. loss.backward()
  59. optimizer.step()
  60. running_loss.append(loss.item())
  61. valid_loss, valid_acc = evaluate(model, valid_set, device=device)
  62. train_loss, train_acc = evaluate(model, train_set, device=device)
  63. if valid_loss < best_loss:
  64. best_loss = valid_loss
  65. torch.save(model.state_dict(), save_path)
  66. if verbose:
  67. print("Epoch: {}, Valid Best Accuracy: {:.3f}% ({:.3f})".format(epoch + 1, valid_acc, valid_loss))
  68. if valid_acc > 99.0:
  69. if verbose:
  70. print("Early Stopping at 99% !")
  71. break
  72. if verbose and (epoch + 1) % 5 == 0:
  73. print(
  74. "Epoch: {}, Train Average Loss: {:.3f}, Accuracy {:.3f}%, Valid Average Loss: {:.3f}".format(
  75. epoch + 1, np.mean(running_loss), train_acc, valid_loss
  76. )
  77. )