| @@ -0,0 +1,175 @@ | |||
| #!/usr/bin/env python3 | |||
| # -*- coding: utf-8 -*- | |||
| """ | |||
| Created on Tue Jul 7 11:30:31 2020 | |||
| @author: ljia | |||
| """ | |||
| import numpy as np | |||
| import cvxpy as cp | |||
| import time | |||
| from gklearn.utils import Timer | |||
| class CostsLearner(object): | |||
| def __init__(self, parallel, verbose): | |||
| ### To set. | |||
| self._parallel = parallel | |||
| self._verbose = verbose | |||
| # For update(). | |||
| self._time_limit_in_sec = 0 | |||
| self._max_itrs = 100 | |||
| self._max_itrs_without_update = 3 | |||
| self._epsilon_residual = 0.01 | |||
| self._epsilon_ec = 0.1 | |||
| ### To compute. | |||
| self._residual_list = [] | |||
| self._runtime_list = [] | |||
| self._cost_list = [] | |||
| self._nb_eo = None | |||
| # For update(). | |||
| self._itrs = 0 | |||
| self._converged = False | |||
| self._num_updates_ecs = 0 | |||
| self._ec_changed = None | |||
| self._residual_changed = None | |||
| self._itrs_without_update = 0 | |||
| ### Both set and get. | |||
| self._ged_options = None | |||
| def fit(self, X, y): | |||
| pass | |||
| def preprocess(self): | |||
| pass # @todo: remove the zero numbers of edit costs. | |||
| def postprocess(self): | |||
| for i in range(len(self._cost_list[-1])): | |||
| if -1e-9 <= self._cost_list[-1][i] <= 1e-9: | |||
| self._cost_list[-1][i] = 0 | |||
| if self._cost_list[-1][i] < 0: | |||
| raise ValueError('The edit cost is negative.') | |||
| def set_update_params(self, **kwargs): | |||
| self._time_limit_in_sec = kwargs.get('time_limit_in_sec', self._time_limit_in_sec) | |||
| self._max_itrs = kwargs.get('max_itrs', self._max_itrs) | |||
| self._max_itrs_without_update = kwargs.get('max_itrs_without_update', self._max_itrs_without_update) | |||
| self._epsilon_residual = kwargs.get('epsilon_residual', self._epsilon_residual) | |||
| self._epsilon_ec = kwargs.get('epsilon_ec', self._epsilon_ec) | |||
| def update(self, y, graphs, ged_options, **kwargs): | |||
| # Set parameters. | |||
| self._ged_options = ged_options | |||
| if kwargs != {}: | |||
| self.set_update_params(**kwargs) | |||
| # The initial iteration. | |||
| if self._verbose >= 2: | |||
| print('\ninitial:') | |||
| self.init_geds_and_nb_eo(y, graphs) | |||
| self._converged = False | |||
| self._itrs_without_update = 0 | |||
| self._itrs = 0 | |||
| self._num_updates_ecs = 0 | |||
| timer = Timer(self._time_limit_in_sec) | |||
| # Run iterations from initial edit costs. | |||
| while not self.termination_criterion_met(self._converged, timer, self._itrs, self._itrs_without_update): | |||
| if self._verbose >= 2: | |||
| print('\niteration', self._itrs + 1) | |||
| time0 = time.time() | |||
| # Fit GED space to the target space. | |||
| self.preprocess() | |||
| self.fit(self._nb_eo, y) | |||
| self.postprocess() | |||
| # Compute new GEDs and numbers of edit operations. | |||
| self.update_geds_and_nb_eo(y, graphs, time0) | |||
| # Check convergency. | |||
| self.check_convergency() | |||
| # Print current states. | |||
| if self._verbose >= 2: | |||
| self.print_current_states() | |||
| self._itrs += 1 | |||
| def init_geds_and_nb_eo(self, y, graphs): | |||
| pass | |||
| def update_geds_and_nb_eo(self, y, graphs, time0): | |||
| pass | |||
| def compute_geds_and_nb_eo(self, graphs): | |||
| pass | |||
| def check_convergency(self): | |||
| pass | |||
| def print_current_states(self): | |||
| pass | |||
| def termination_criterion_met(self, converged, timer, itr, itrs_without_update): | |||
| if timer.expired() or (itr >= self._max_itrs if self._max_itrs >= 0 else False): | |||
| # if self.__state == AlgorithmState.TERMINATED: | |||
| # self.__state = AlgorithmState.INITIALIZED | |||
| return True | |||
| return converged or (itrs_without_update > self._max_itrs_without_update if self._max_itrs_without_update >= 0 else False) | |||
| def execute_cvx(self, prob): | |||
| try: | |||
| prob.solve(verbose=(self._verbose>=2)) | |||
| except MemoryError as error0: | |||
| if self._verbose >= 2: | |||
| print('\nUsing solver "OSQP" caused a memory error.') | |||
| print('the original error message is\n', error0) | |||
| print('solver status: ', prob.status) | |||
| print('trying solver "CVXOPT" instead...\n') | |||
| try: | |||
| prob.solve(solver=cp.CVXOPT, verbose=(self._verbose>=2)) | |||
| except Exception as error1: | |||
| if self._verbose >= 2: | |||
| print('\nAn error occured when using solver "CVXOPT".') | |||
| print('the original error message is\n', error1) | |||
| print('solver status: ', prob.status) | |||
| print('trying solver "MOSEK" instead. Notice this solver is commercial and a lisence is required.\n') | |||
| prob.solve(solver=cp.MOSEK, verbose=(self._verbose>=2)) | |||
| else: | |||
| if self._verbose >= 2: | |||
| print('solver status: ', prob.status) | |||
| else: | |||
| if self._verbose >= 2: | |||
| print('solver status: ', prob.status) | |||
| if self._verbose >= 2: | |||
| print() | |||
| def get_results(self): | |||
| results = {} | |||
| results['residual_list'] = self._residual_list | |||
| results['runtime_list'] = self._runtime_list | |||
| results['cost_list'] = self._cost_list | |||
| results['nb_eo'] = self._nb_eo | |||
| results['itrs'] = self._itrs | |||
| results['converged'] = self._converged | |||
| results['num_updates_ecs'] = self._num_updates_ecs | |||
| results['ec_changed'] = self._ec_changed | |||
| results['residual_changed'] = self._residual_changed | |||
| results['itrs_without_update'] = self._itrs_without_update | |||
| return results | |||