|
- """
- Copyright 2020 Tianshu AI Platform. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- =============================================================
- """
-
- import torch
- import torch.nn as nn
- import abc, math, weakref, typing, time
- from typing import Any, Callable, Optional, Sequence
- import numpy as np
-
- from kamal.core.engine.events import DefaultEvents, Event
- from kamal.core import tasks
- from kamal.utils import set_mode, move_to_device, get_logger
- from collections import defaultdict
-
- import numbers
- import contextlib
-
class State(object):
    """Mutable progress record for an iteration loop.

    All fields are plain attributes. The derived epoch/batch properties
    return ``None`` whenever the values they depend on have not been set.
    """

    def __init__(self):
        self.iter = 0             # current iteration index
        self.max_iter = None      # iteration bound; None until configured
        self.epoch_length = None  # iterations per epoch; None until configured
        self.dataloader = None
        self.seed = None          # only initialised here, never read in this class

        self.metrics = dict()
        self.batch = None

    @property
    def current_epoch(self):
        """Index of the epoch containing ``iter``, or None without an epoch length."""
        if self.epoch_length is not None:
            return self.iter // self.epoch_length
        return None

    @property
    def max_epoch(self):
        """Number of whole epochs in ``max_iter``.

        Returns None when either ``epoch_length`` or ``max_iter`` is unset,
        instead of failing on ``None // int``.
        """
        if self.epoch_length is None or self.max_iter is None:
            return None
        return self.max_iter // self.epoch_length

    @property
    def current_batch_index(self):
        """Position of ``iter`` inside its epoch, or None without an epoch length."""
        if self.epoch_length is not None:
            return self.iter % self.epoch_length
        return None

    @property
    def max_batch_index(self):
        """Alias for ``epoch_length`` (may be None)."""
        return self.epoch_length

    def __repr__(self):
        """Tab-indented listing of every instance attribute.

        Numbers, strings and dicts are printed as-is; any other value
        (including None) is shown as its type to keep the output short.
        """
        lines = ["State:\n"]
        for attr, value in self.__dict__.items():
            if not isinstance(value, (numbers.Number, str, dict)):
                value = type(value)
            lines.append("\t{}: {}\n".format(attr, value))
        return "".join(lines)
-
class Engine(abc.ABC):
    """Iteration loop that fires callbacks around runs, epochs and steps.

    ``run`` pulls batches from a dataloader, passes each one to a step
    function and triggers the ``DefaultEvents`` hooks at the matching
    points. Callbacks are stored per event and invoked as ``callback(engine)``.
    """

    def __init__(self, logger=None, tb_writer=None):
        """Create an engine with no callbacks and a fresh ``State``.

        Args:
            logger: Logger to expose via ``self.logger``; when falsy a
                default one is obtained from ``get_logger``.
            tb_writer: Object exposed unchanged via ``self.tb_writer``.
        """
        self._logger = logger if logger else get_logger(name='kamal', color=True)
        self._tb_writer = tb_writer
        self._callbacks = defaultdict(list)
        self._allowed_events = [*DefaultEvents]
        self._state = State()

    def reset(self):
        """Replace the current state with a fresh ``State``."""
        self._state = State()

    def run(self, step_fn: Callable, dataloader, max_iter, start_iter=0, epoch_length=None):
        """Run ``step_fn`` for iterations ``start_iter`` .. ``max_iter - 1``.

        Args:
            step_fn: Called as ``step_fn(engine, batch)``. A dict return
                value is merged into ``state.metrics``; anything else is
                ignored.
            dataloader: Iterable of batches; restarted when exhausted.
            max_iter: Exclusive upper bound of the iteration counter.
            start_iter: First iteration index.
            epoch_length: Iterations per epoch. When falsy,
                ``len(dataloader)`` is used.
        """
        state = self.state
        state.iter = state.start_iter = start_iter
        state.max_iter = max_iter
        state.epoch_length = epoch_length if epoch_length else len(dataloader)
        state.dataloader = dataloader
        state.dataloader_iter = iter(dataloader)
        state.step_fn = step_fn

        self.trigger_events(DefaultEvents.BEFORE_RUN)
        # The loop target is the state attribute itself, so callbacks always
        # observe the current iteration through ``engine.state.iter``.
        for self.state.iter in range(start_iter, max_iter):
            if self.state.epoch_length is not None and \
                    self.state.iter % self.state.epoch_length == 0:  # Epoch Start
                self.trigger_events(DefaultEvents.BEFORE_EPOCH)
            self.trigger_events(DefaultEvents.BEFORE_STEP)
            self.state.batch = self._get_batch()
            step_output = step_fn(self, self.state.batch)
            if isinstance(step_output, dict):
                self.state.metrics.update(step_output)
            self.trigger_events(DefaultEvents.AFTER_STEP)
            if self.state.epoch_length is not None and \
                    (self.state.iter + 1) % self.state.epoch_length == 0:  # Epoch End
                self.trigger_events(DefaultEvents.AFTER_EPOCH)
        self.trigger_events(DefaultEvents.AFTER_RUN)

    def _get_batch(self):
        """Return the next batch, restarting the dataloader when it runs out.

        A batch that is not a list or tuple is wrapped in a one-element list.
        """
        try:
            batch = next(self.state.dataloader_iter)
        except StopIteration:
            self.state.dataloader_iter = iter(self.state.dataloader)  # reset iterator
            batch = next(self.state.dataloader_iter)
        if not isinstance(batch, (list, tuple)):
            batch = [batch, ]  # no targets
        return batch

    @property
    def state(self):
        """The current ``State`` object."""
        return self._state

    @property
    def logger(self):
        """The logger chosen at construction."""
        return self._logger

    @property
    def tb_writer(self):
        """The ``tb_writer`` object passed at construction (may be None)."""
        return self._tb_writer

    def add_callback(self, event: "Event", callbacks):
        """Register one or more callbacks for ``event``.

        Args:
            event: Event to attach to. Nothing is registered if it is not
                in the allowed events, but handles are still returned.
            callbacks: A single callable or a sequence of callables. A
                callable already present in the event's list is not added
                again.

        Returns:
            A single ``RemovableCallback`` when exactly one callback was
            given, otherwise a list of them (one per given callback).
        """
        if not isinstance(callbacks, Sequence):
            callbacks = [callbacks]
        if event in self._allowed_events:
            for callback in callbacks:
                if callback not in self._callbacks[event]:
                    if event.trigger != event.default_trigger:
                        # Custom trigger: store a gate that only forwards
                        # to the callback when the trigger fires.
                        callback = self._trigger_wrapper(self, event.trigger, callback)
                    self._callbacks[event].append(callback)
        handles = [RemovableCallback(self, event, c) for c in callbacks]
        return handles[0] if len(handles) == 1 else handles

    def remove_callback(self, event, callback):
        """Unregister ``callback`` from ``event``.

        Matches either the stored callable itself or a trigger wrapper
        created around it, so handles returned by ``add_callback`` work for
        events with custom triggers too.

        Returns:
            True if an entry was removed, False if none matched.
        """
        registered = self._callbacks[event]
        for entry in registered:
            if entry == callback or getattr(entry, '__wrapped__', None) == callback:
                registered.remove(entry)
                return True
        return False

    @staticmethod
    def _trigger_wrapper(engine, trigger, callback):
        """Return a callable that runs ``callback(engine)`` only if ``trigger(engine)`` is truthy.

        The wrapper ignores its own arguments and returns None when the
        trigger does not fire. The original callable is recorded on
        ``__wrapped__`` so it can be located again for removal.
        """
        def wrapper(*args, **kwargs) -> Any:
            if trigger(engine):
                return callback(engine)
        wrapper.__wrapped__ = callback
        return wrapper

    def trigger_events(self, *events):
        """Invoke every callback registered for each allowed event, in order.

        Events not in the allowed list are skipped.
        """
        for e in events:
            if e in self._allowed_events:
                # Iterate over a snapshot: a callback may unregister itself
                # (or others) while running, and mutating the live list
                # mid-iteration would skip the following entry.
                for callback in list(self._callbacks[e]):
                    callback(self)

    def register_events(self, *events):
        """Add events to the allowed list, skipping ones already present."""
        for e in events:
            if e not in self._allowed_events:
                self._allowed_events.append(e)

    @contextlib.contextmanager
    def save_current_callbacks(self):
        """Context manager that temporarily swaps in an empty callback table.

        The previous callbacks are restored on exit, including when the
        ``with`` body raises.
        """
        saved = self._callbacks
        self._callbacks = defaultdict(list)
        try:
            yield
        finally:
            self._callbacks = saved
-
class RemovableCallback:
    """Handle that can later unregister a callback from an engine.

    The engine, event and callback are held through weak references only,
    so keeping a handle around does not keep any of them alive.
    """

    def __init__(self, engine, event, callback):
        """Store weak references to ``engine``, ``event`` and ``callback``.

        All three must support weak references.
        """
        self._engine = weakref.ref(engine)
        self._callback = weakref.ref(callback)
        self._event = weakref.ref(event)

    @property
    def callback(self):
        """The referenced callback, or None once it has been garbage-collected."""
        return self._callback()

    def remove(self):
        """Ask the engine to unregister the callback.

        Returns:
            The result of ``engine.remove_callback(event, callback)``, or
            False when the engine, event or callback no longer exists
            (there is then nothing left to remove).
        """
        engine = self._engine()
        callback = self._callback()
        event = self._event()
        if engine is None or callback is None or event is None:
            return False
        return engine.remove_callback(event, callback)
-
-
|