Browse Source

Fixed _softmax_cross_entropy returning NaN values

tags/v1.0.0
liuluobin 5 years ago
parent
commit
fd3eb11b0e
1 changed files with 7 additions and 1 deletions
  1. +7
    -1
      mindarmour/diff_privacy/evaluation/membership_inference.py

+ 7
- 1
mindarmour/diff_privacy/evaluation/membership_inference.py View File

@@ -82,7 +82,12 @@ def _softmax_cross_entropy(logits, labels):
""" """
labels = np.eye(logits.shape[1])[labels].astype(np.int32) labels = np.eye(logits.shape[1])[labels].astype(np.int32)
logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
return -1*np.sum(labels*np.log(logits), axis=1)
loss = -1*np.sum(labels*np.log(logits), axis=1)

nan_index = np.isnan(loss)
if np.any(nan_index):
loss[nan_index] = 0
return loss




class MembershipInference:
@@ -243,6 +248,7 @@ class MembershipInference:
np.random.shuffle(shuffle_index)
features = features[shuffle_index]
labels = labels[shuffle_index]

return features, labels


def _generate(self, dataset_x, label):


Loading…
Cancel
Save