import torch import numpy as np from torch import optim, nn from torch.utils.data import DataLoader, Dataset from learnware.utils import choose_device @torch.no_grad() def evaluate(model, evaluate_set: Dataset, device=None, distribution=True): device = choose_device(0) if device is None else device if isinstance(model, nn.Module): model.eval() mapping = lambda m, x: m(x) else: mapping = lambda m, x: m.predict(x) criterion = nn.CrossEntropyLoss(reduction="sum") total, correct, loss = 0, 0, torch.as_tensor(0.0, dtype=torch.float32, device=device) dataloader = DataLoader(evaluate_set, batch_size=1024, shuffle=True) for i, (X, y) in enumerate(dataloader): X, y = X.to(device), y.to(device) out = mapping(model, X) if not torch.is_tensor(out): out = torch.from_numpy(out).to(device) if distribution: loss += criterion(out, y) _, predicted = torch.max(out.data, 1) else: predicted = out total += y.size(0) correct += (predicted == y).sum().item() acc = correct / total * 100 loss = loss / total if isinstance(model, nn.Module): model.train() return loss.item(), acc def train_model( model: nn.Module, train_set: Dataset, valid_set: Dataset, save_path: str, epochs=35, batch_size=128, device=None, verbose=True, ): device = choose_device(0) if device is None else device model.train() optimizer = optim.SGD(model.parameters(), lr=1e-2, momentum=0.9) criterion = nn.CrossEntropyLoss() dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True) best_loss = 100000 for epoch in range(epochs): running_loss = [] model.train() for i, (X, y) in enumerate(dataloader): X, y = X.to(device=device), y.to(device=device) optimizer.zero_grad() out = model(X) loss = criterion(out, y) loss.backward() optimizer.step() running_loss.append(loss.item()) valid_loss, valid_acc = evaluate(model, valid_set, device=device) train_loss, train_acc = evaluate(model, train_set, device=device) if valid_loss < best_loss: best_loss = valid_loss torch.save(model.state_dict(), save_path) if verbose: print("Epoch: {}, Valid Best Accuracy: {:.3f}% ({:.3f})".format(epoch + 1, valid_acc, valid_loss)) if valid_acc > 99.0: if verbose: print("Early Stopping at 99% !") break if verbose and (epoch + 1) % 5 == 0: print( "Epoch: {}, Train Average Loss: {:.3f}, Accuracy {:.3f}%, Valid Average Loss: {:.3f}".format( epoch + 1, np.mean(running_loss), train_acc, valid_loss ) )