|
- import torch
- import torch.nn.functional as F
-
- from pytorch.mutator import Mutator
- from pytorch.mutables import LayerChoice, InputChoice
-
- # TODO: This class duplicates functionality in SPOS; consider merging the two.
class RandomMutator(Mutator):
    """
    Random mutator that samples a random candidate in the search space each time ``reset()`` is called.

    All randomness comes from PyTorch (``torch.randint`` / ``torch.randperm``), so seeding PyTorch
    (e.g. with ``torch.manual_seed``) makes the sequence of sampled candidates deterministic.
    """

    def sample_search(self):
        """
        Sample a random candidate.

        Returns
        -------
        dict
            Maps each mutable's ``key`` to a 1-D ``torch.bool`` mask over its candidates:

            - ``LayerChoice``: a one-hot mask of length ``len(mutable)`` selecting exactly one
              candidate, chosen uniformly.
            - ``InputChoice`` with ``n_chosen is None``: a mask of length ``n_candidates`` where each
              input is kept independently with probability 1/2. The mask may be all ``False``.
            - ``InputChoice`` with ``n_chosen`` set: a mask of length ``n_candidates`` with the first
              ``n_chosen`` entries of a uniform random permutation set to ``True`` (i.e.
              ``min(n_chosen, n_candidates)`` distinct inputs for non-negative ``n_chosen``).

            Mutables of any other type are skipped and get no entry in the result.
        """
        result = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                gen_index = torch.randint(high=len(mutable), size=(1, ))
                result[mutable.key] = F.one_hot(gen_index, num_classes=len(mutable)).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                if mutable.n_chosen is None:
                    # randint over {0, 1} gives an independent fair coin per candidate; the result is
                    # already 1-D, so no reshape is needed.
                    result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).bool()
                else:
                    # Scatter the chosen indices into a boolean mask in one vectorized step. This
                    # replaces a per-candidate Python membership test against the tensor (O(n * k)
                    # tensor ops). The RNG is consumed exactly as before: one randperm call.
                    perm = torch.randperm(mutable.n_candidates)
                    mask = torch.zeros(mutable.n_candidates, dtype=torch.bool)
                    mask[perm[:mutable.n_chosen]] = True
                    result[mutable.key] = mask
        return result

    def sample_final(self):
        """
        Same as :meth:`sample_search`: every call draws a fresh random candidate.
        """
        return self.sample_search()
|