|
@@ -78,11 +78,7 @@ |
|
|
" X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n", |
|
|
" X, y, categorical_indicator, attribute_names = dataset.get_data(target=dataset.default_target_attribute)\n", |
|
|
" self.attribute_names = attribute_names\n", |
|
|
" self.attribute_names = attribute_names\n", |
|
|
" self.target_names = y.cat.categories.tolist()\n", |
|
|
" self.target_names = y.cat.categories.tolist()\n", |
|
|
" print(\"Attribute names are: \", self.attribute_names)\n", |
|
|
|
|
|
" print(\"Target names are: \", self.target_names)\n", |
|
|
|
|
|
" # self.attribute_names = [\"hair\", \"feathers\", \"eggs\", \"milk\", \"airborne\", \"aquatic\", \"predator\", \"toothed\", \"backbone\", \"breathes\", \"venomous\", \"fins\", \"legs\", \"tail\", \"domestic\", \"catsize\"]\n", |
|
|
|
|
|
" # self.target_names = [\"mammal\", \"bird\", \"reptile\", \"fish\", \"amphibian\", \"insect\", \"invertebrate\"]\n", |
|
|
|
|
|
"\n", |
|
|
|
|
|
|
|
|
" \n", |
|
|
" # Define variables\n", |
|
|
" # Define variables\n", |
|
|
" for name in self.attribute_names+self.target_names:\n", |
|
|
" for name in self.attribute_names+self.target_names:\n", |
|
|
" exec(f\"globals()['{name}'] = Int('{name}')\") ## or use dict to create var and modify rules\n", |
|
|
" exec(f\"globals()['{name}'] = Int('{name}')\") ## or use dict to create var and modify rules\n", |
|
@@ -139,9 +135,6 @@ |
|
|
" if solver.check() == sat:\n", |
|
|
" if solver.check() == sat:\n", |
|
|
" model = solver.model()\n", |
|
|
" model = solver.model()\n", |
|
|
" total_weight = model.evaluate(total_violation_weight)\n", |
|
|
" total_weight = model.evaluate(total_violation_weight)\n", |
|
|
" # violated_rules = [str(rule) for rule in self.weights if model.evaluate(Not(rule))]\n", |
|
|
|
|
|
" # print(\"Total violation weight for the given data point:\", total_weight)\n", |
|
|
|
|
|
" # print(\"Violated rules:\", violated_rules)\n", |
|
|
|
|
|
" return total_weight.as_long()\n", |
|
|
" return total_weight.as_long()\n", |
|
|
" else:\n", |
|
|
" else:\n", |
|
|
" # No solution found\n", |
|
|
" # No solution found\n", |
|
|