You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. import numpy as np
  2. import torch
  3. from torch import nn, optim
  4. from torch.utils.data import DataLoader, Dataset
  5. from learnware.utils import choose_device
  6. @torch.no_grad()
  7. def evaluate(model, evaluate_set: Dataset, device=None, distribution=True):
  8. device = choose_device(0) if device is None else device
  9. if isinstance(model, nn.Module):
  10. model.eval()
  11. criterion = nn.CrossEntropyLoss(reduction="sum")
  12. total, correct, loss = 0, 0, torch.as_tensor(0.0, dtype=torch.float32, device=device)
  13. dataloader = DataLoader(evaluate_set, batch_size=1024, shuffle=True)
  14. for i, (X, y) in enumerate(dataloader):
  15. X, y = X.to(device), y.to(device)
  16. out = model(X) if isinstance(model, nn.Module) else model.predict(X)
  17. if not torch.is_tensor(out):
  18. out = torch.from_numpy(out).to(device)
  19. if distribution:
  20. loss += criterion(out, y)
  21. _, predicted = torch.max(out.data, 1)
  22. else:
  23. predicted = out
  24. total += y.size(0)
  25. correct += (predicted == y).sum().item()
  26. acc = correct / total * 100
  27. loss = loss / total
  28. if isinstance(model, nn.Module):
  29. model.train()
  30. return loss.item(), acc
  31. def train_model(
  32. model: nn.Module,
  33. train_set: Dataset,
  34. valid_set: Dataset,
  35. save_path: str,
  36. epochs=35,
  37. batch_size=128,
  38. device=None,
  39. verbose=True,
  40. ):
  41. device = choose_device(0) if device is None else device
  42. model.train()
  43. optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
  44. criterion = nn.CrossEntropyLoss()
  45. dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
  46. best_loss = 100000
  47. for epoch in range(epochs):
  48. running_loss = []
  49. model.train()
  50. for i, (X, y) in enumerate(dataloader):
  51. X, y = X.to(device=device), y.to(device=device)
  52. optimizer.zero_grad()
  53. out = model(X)
  54. loss = criterion(out, y)
  55. loss.backward()
  56. optimizer.step()
  57. running_loss.append(loss.item())
  58. valid_loss, valid_acc = evaluate(model, valid_set, device=device)
  59. train_loss, train_acc = evaluate(model, train_set, device=device)
  60. if valid_loss < best_loss:
  61. best_loss = valid_loss
  62. torch.save(model.state_dict(), save_path)
  63. if verbose:
  64. print("Epoch: {}, Valid Best Accuracy: {:.3f}% ({:.3f})".format(epoch + 1, valid_acc, valid_loss))
  65. if valid_acc > 99.0:
  66. if verbose:
  67. print("Early Stopping at 99% !")
  68. break
  69. if verbose and (epoch + 1) % 5 == 0:
  70. print(
  71. "Epoch: {}, Train Average Loss: {:.3f}, Accuracy {:.3f}%, Valid Average Loss: {:.3f}".format(
  72. epoch + 1, np.mean(running_loss), train_acc, valid_loss
  73. )
  74. )