Choose between conditional or biconditional in testing rules
pietrobarbiero committed Jun 28, 2022
1 parent 265568d commit dfcff89
Showing 3 changed files with 19 additions and 8 deletions.
8 changes: 6 additions & 2 deletions tests/test_logic_layer.py
@@ -149,7 +149,8 @@ def test_entropy_multi_target(self):
         # train_mask = test_mask = torch.arange(len(y))
         explanations = entropy.explain_classes(model, x, y, train_mask, test_mask,
                                                c_threshold=0.5, y_threshold=0.5, verbose=True,
-                                               concept_names=concept_names, class_names=class_names)
+                                               concept_names=concept_names, class_names=class_names,
+                                               material=True)
 
         return

@@ -214,7 +215,10 @@ def forward(self, x, edge_index):
         # extract logic formulas
         explanations = entropy.explain_classes(model, x, y, train_mask, test_mask,
                                                edge_index=edge_index, c_threshold=0,
-                                               topk_explanations=3, verbose=True)
+                                               topk_explanations=3, verbose=True, material=False)
+        explanations = entropy.explain_classes(model, x, y, train_mask, test_mask,
+                                               edge_index=edge_index, c_threshold=0,
+                                               topk_explanations=3, verbose=True, material=True)
 
         return

14 changes: 10 additions & 4 deletions torch_explain/logic/metrics.py
@@ -4,12 +4,13 @@
 
 import torch
 import numpy as np
-from sklearn.metrics import f1_score
+from sklearn.metrics import f1_score, accuracy_score
 from sympy import to_dnf, lambdify
 
 
 def test_explanation(formula: str, x: torch.Tensor, y: torch.Tensor, target_class: int,
-                     mask: torch.Tensor = None, threshold: float = 0.5) -> Tuple[float, torch.Tensor]:
+                     mask: torch.Tensor = None, threshold: float = 0.5,
+                     material: bool = False) -> Tuple[float, torch.Tensor]:
     """
     Tests a logic formula.
@@ -34,8 +35,13 @@ def test_explanation(formula: str, x: torch.Tensor, y: torch.Tensor, target_class: int,
     x = x.cpu().detach().numpy()
     predictions = fun(*[x[:, i] > threshold for i in range(x.shape[1])])
     predictions = torch.LongTensor(predictions)
-    # material implication: (p=>q) <=> (not p or q)
-    accuracy = torch.sum(torch.logical_or(torch.logical_not(predictions[mask]), y2[mask])) / len(y2[mask])
+    if material:
+        # material implication: (p=>q) <=> (not p or q)
+        accuracy = torch.sum(torch.logical_or(torch.logical_not(predictions[mask]), y2[mask])) / len(y2[mask])
+        accuracy = accuracy.item()
+    else:
+        # material biconditional: (p<=>q) <=> ((p and q) or (not p and not q))
+        accuracy = accuracy_score(predictions[mask], y2[mask])
     return accuracy, predictions


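A toy illustration of how the two settings diverge (hypothetical values, not from the commit), where p is the formula's truth value and q the ground-truth label: the implication score only counts a sample as an error when the formula fires but the label is false, while the biconditional score also penalizes samples the formula fails to cover.

import torch
from sklearn.metrics import accuracy_score

# Hypothetical toy values: formula predictions vs. ground-truth labels.
predictions = torch.tensor([1, 1, 0, 0])
y2 = torch.tensor([1, 0, 0, 1])

# material=True: (p => q) <=> (not p or q); only sample 2 (p=1, q=0) fails.
material_acc = (torch.sum(torch.logical_or(torch.logical_not(predictions), y2))
                / len(y2)).item()                     # 3/4 = 0.75

# material=False: (p <=> q); samples 2 and 4 both fail.
biconditional_acc = accuracy_score(predictions, y2)   # 2/4 = 0.5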
5 changes: 3 additions & 2 deletions torch_explain/logic/nn/entropy.py
@@ -19,7 +19,7 @@ def explain_classes(model: torch.nn.Module, c: torch.Tensor, y: torch.Tensor,
                     edge_index: torch.Tensor = None, max_minterm_complexity: int = 1000,
                     topk_explanations: int = 1000, try_all: bool = False,
                     c_threshold: float = 0.5, y_threshold: float = 0.,
-                    concept_names: List[str] = None, class_names: List[str] = None,
+                    concept_names: List[str] = None, class_names: List[str] = None, material: bool = False,
                     verbose: bool = False) -> Dict:
     """
     Explain LENs predictions with concept-based logic explanations.
@@ -38,6 +38,7 @@ def explain_classes(model: torch.nn.Module, c: torch.Tensor, y: torch.Tensor,
     :param y_threshold: threshold to get truth values for class predictions (i.e. pred<threshold = false, pred>threshold = true)
     :param concept_names: list of concept names
     :param class_names: list of class names
+    :param material: if True, the explanation accuracy is computed for the material implication (p => q); otherwise for the material biconditional (p <=> q)
     :param verbose: if True, then prints the explanations
     :return: Global explanations
     """
@@ -54,7 +55,7 @@ def explain_classes(model: torch.nn.Module, c: torch.Tensor, y: torch.Tensor,
                                          topk_explanations=topk_explanations, try_all=try_all,
                                          c_threshold=c_threshold, y_threshold=y_threshold)
 
-        explanation_accuracy, _ = test_explanation(explanation, c, y, class_id, test_mask, c_threshold)
+        explanation_accuracy, _ = test_explanation(explanation, c, y, class_id, test_mask, c_threshold, material)
         explanation_complexity = complexity(explanation)
 
         explanations[str(class_id)] = {'explanation': explanation,
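Usage sketch (hypothetical, mirroring the updated tests; model, x, y, and the masks are assumed to be prepared as in tests/test_logic_layer.py):

# Default: material=False scores explanations with the biconditional p <=> q.
explanations = entropy.explain_classes(model, x, y, train_mask, test_mask,
                                       c_threshold=0.5, y_threshold=0.5,
                                       material=False)

# material=True scores the same explanations with the implication p => q.
explanations = entropy.explain_classes(model, x, y, train_mask, test_mask,
                                       c_threshold=0.5, y_threshold=0.5,
                                       material=True)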
