| @@ -0,0 +1,148 @@ | |||||
| #!/usr/bin/env python3 | |||||
| # -*- coding: utf-8 -*- | |||||
| """ | |||||
| Created on Tue Jul 7 11:42:48 2020 | |||||
| @author: ljia | |||||
| """ | |||||
| import numpy as np | |||||
| import cvxpy as cp | |||||
| import time | |||||
| from gklearn.ged.learning.costs_learner import CostsLearner | |||||
| from gklearn.ged.util import compute_geds_cml | |||||
class CostMatricesLearner(CostsLearner):
    """Learn node- and edge-label edit costs with constrained least squares.

    The learned costs are kept as one flat vector: the node label costs
    followed by the edge label costs (see ``init_geds_and_nb_eo``).

    NOTE(review): attributes such as ``_cost_list``, ``_residual_list``,
    ``_runtime_list``, ``_ged_options``, ``_epsilon_ec``,
    ``_epsilon_residual``, ``_itrs``, ``_itrs_without_update``,
    ``_num_updates_ecs``, ``_parallel`` and ``_verbose`` are read here but
    never initialised in this class; presumably ``CostsLearner`` owns them.
    """

    def __init__(self, edit_cost='CONSTANT', triangle_rule=False, allow_zeros=True, parallel=False, verbose=2):
        """Store the optimisation settings.

        Parameters
        ----------
        edit_cost : str
            Name of the edit cost model; only ``'CONSTANT'`` can be fitted.
        triangle_rule : bool
            Whether ``fit`` adds the triangle-inequality constraints.
        allow_zeros : bool
            Whether learned costs may be exactly zero.
        parallel, verbose
            Forwarded unchanged to ``CostsLearner.__init__``.
        """
        super().__init__(parallel, verbose)
        self._edit_cost = edit_cost
        self._triangle_rule = triangle_rule
        self._allow_zeros = allow_zeros

    def fit(self, X, y):
        """Solve ``min_w ||X w - y||^2`` under the configured constraints.

        The solution vector is appended to ``self._cost_list``.

        Parameters
        ----------
        X : matrix with one column per cost to learn (``X.shape[1]`` costs).
        y : target vector compatible with ``X @ w``.

        Raises
        ------
        Exception
            If ``self._edit_cost`` is not ``'CONSTANT'``.
        """
        if self._edit_cost in ('LETTER', 'LETTER2', 'NON_SYMBOLIC'):
            raise Exception('Cannot compute for cost "{}".'.format(self._edit_cost))
        if self._edit_cost != 'CONSTANT':  # @todo: node/edge may not labeled.
            # Report the value actually dispatched on, as a single message.
            raise Exception('The edit cost "{}" is not supported for update progress.'.format(self._edit_cost))

        n_costs = X.shape[1]
        w = cp.Variable(n_costs)
        objective = cp.Minimize(cp.sum_squares(X @ w - y))
        prob = cp.Problem(objective, self._build_constraints(w, n_costs))
        # ``execute_cvx`` is the solver entry point provided outside this
        # class; a double-underscore spelling would be name-mangled to
        # ``_CostMatricesLearner__execute_cvx`` and fail to resolve.
        self.execute_cvx(prob)
        self._cost_list.append(w.value)

    def _build_constraints(self, w, n_costs):
        """Return the list of cvxpy constraints for the current settings."""
        min_positive = 0.01
        lower_bound = 0.0 if self._allow_zeros else min_positive
        constraints = [w >= np.full(n_costs, lower_bound)]

        if self._triangle_rule:  # @todo
            # NOTE(review): the vectors below are hard-coded for exactly six
            # costs, presumably (insert, delete, substitute) for nodes and
            # then for edges -- confirm. Any other ``n_costs`` makes cvxpy
            # reject the product because of the dimension mismatch.
            if self._allow_zeros:
                # Keep entries 0, 1, 3 and 4 strictly positive even though
                # zeros are allowed in general.
                for idx in (0, 1, 3, 4):
                    selector = np.zeros(6)
                    selector[idx] = 1.0
                    constraints.append(selector @ w >= min_positive)
            # w[0] + w[1] >= w[2] and w[3] + w[4] >= w[5].
            constraints.append(np.array([1.0, 1.0, -1.0, 0.0, 0.0, 0.0]) @ w >= 0.0)
            constraints.append(np.array([0.0, 0.0, 0.0, 1.0, 1.0, -1.0]) @ w >= 0.0)

        return constraints

    def _split_current_costs(self):
        """Split the latest cost vector into (node costs, edge costs).

        The split point is the current length of
        ``self._ged_options['node_label_costs']``.
        """
        n_node_costs = len(self._ged_options['node_label_costs'])
        current = self._cost_list[-1]
        return current[0:n_node_costs], current[n_node_costs:]

    @staticmethod
    def _residual_norm(ged_vec, y):
        """Return the Euclidean norm of ``ged_vec - y``."""
        return np.sqrt(np.sum(np.square(np.array(ged_vec) - y)))

    @staticmethod
    def _has_changed(current, previous, tolerance):
        """Tell whether ``current`` differs from ``previous`` beyond ``tolerance``.

        A zero ``current`` is compared absolutely (``previous > tolerance``);
        otherwise the change is measured relative to ``abs(current)`` so that
        a slightly negative solver output cannot flip the sign of the ratio.
        """
        if current == 0:
            return previous > tolerance
        return abs(current - previous) / abs(current) > tolerance

    def init_geds_and_nb_eo(self, y, graphs):
        """Record the initial costs, residual and runtime before optimising.

        Appends the concatenated node/edge label costs from
        ``self._ged_options`` to ``self._cost_list``, computes the GEDs and
        numbers of edit operations for ``graphs``, and appends the residual
        against ``y`` and the elapsed time to their respective lists.
        """
        time0 = time.time()
        self._cost_list.append(np.concatenate((self._ged_options['node_label_costs'],
                                               self._ged_options['edge_label_costs'])))
        ged_vec, self._nb_eo = self.compute_geds_and_nb_eo(graphs)
        self._residual_list.append(self._residual_norm(ged_vec, y))
        self._runtime_list.append(time.time() - time0)

        if self._verbose >= 2:
            node_costs, edge_costs = self._split_current_costs()
            print('Current node label costs:', node_costs)
            print('Current edge label costs:', edge_costs)
            print('Residual list:', self._residual_list)

    def update_geds_and_nb_eo(self, y, graphs, time0):
        """Recompute GEDs with the latest learned costs.

        Writes the latest cost vector back into ``self._ged_options``, then
        appends the new residual against ``y`` and the time elapsed since
        ``time0`` (a ``time.time()`` stamp supplied by the caller).
        """
        node_costs, edge_costs = self._split_current_costs()
        self._ged_options['node_label_costs'] = node_costs
        self._ged_options['edge_label_costs'] = edge_costs
        ged_vec, self._nb_eo = self.compute_geds_and_nb_eo(graphs)
        self._residual_list.append(self._residual_norm(ged_vec, y))
        self._runtime_list.append(time.time() - time0)

    def compute_geds_and_nb_eo(self, graphs):
        """Return ``(ged_vec, n_edit_operations as ndarray)`` for ``graphs``.

        The second value returned by ``compute_geds_cml`` is not used.
        """
        ged_vec, _, n_edit_operations = compute_geds_cml(
            graphs, options=self._ged_options, parallel=self._parallel,
            verbose=(self._verbose > 1))
        return ged_vec, np.array(n_edit_operations)

    def check_convergency(self):
        """Update the convergence flags from the two latest iterations.

        Sets ``_ec_changed``, ``_residual_changed`` and ``_converged``; then
        either increments ``_itrs_without_update`` (converged) or resets it
        and increments ``_num_updates_ecs`` (not converged).
        """
        latest_costs = self._cost_list[-1]
        previous_costs = self._cost_list[-2]
        # ``any`` stops at the first cost that moved, like an early ``break``.
        self._ec_changed = any(
            self._has_changed(cost, previous_costs[i], self._epsilon_ec)
            for i, cost in enumerate(latest_costs))

        self._residual_changed = bool(self._has_changed(
            self._residual_list[-1], self._residual_list[-2], self._epsilon_residual))

        self._converged = not (self._ec_changed or self._residual_changed)
        if self._converged:
            self._itrs_without_update += 1
        else:
            self._itrs_without_update = 0
            self._num_updates_ecs += 1

    def print_current_states(self):
        """Print a summary of the current optimisation state to stdout."""
        separator = '-------------------------------------------------------------------------'
        node_costs, edge_costs = self._split_current_costs()
        print()
        print(separator)
        print('States of iteration', self._itrs + 1)
        print(separator)
        print('Total number of iterations for optimizing:', self._itrs + 1)
        print('Total number of updating edit costs:', self._num_updates_ecs)
        print('Was optimization of edit costs converged:', self._converged)
        print('Did edit costs change:', self._ec_changed)
        print('Did residual change:', self._residual_changed)
        print('Iterations without update:', self._itrs_without_update)
        print('Current node label costs:', node_costs)
        print('Current edge label costs:', edge_costs)
        print('Residual list:', self._residual_list)
        print(separator)