Browse Source

fixed softmax_cross_entropy return NaN value

tags/v1.0.0
liuluobin 5 years ago
parent
commit
fd3eb11b0e
1 changed files with 7 additions and 1 deletions
  1. +7
    -1
      mindarmour/diff_privacy/evaluation/membership_inference.py

+ 7
- 1
mindarmour/diff_privacy/evaluation/membership_inference.py View File

@@ -82,7 +82,12 @@ def _softmax_cross_entropy(logits, labels):
"""
labels = np.eye(logits.shape[1])[labels].astype(np.int32)
logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
return -1*np.sum(labels*np.log(logits), axis=1)
loss = -1*np.sum(labels*np.log(logits), axis=1)

nan_index = np.isnan(loss)
if np.any(nan_index):
loss[nan_index] = 0
return loss


class MembershipInference:
@@ -243,6 +248,7 @@ class MembershipInference:
np.random.shuffle(shuffle_index)
features = features[shuffle_index]
labels = labels[shuffle_index]

return features, labels

def _generate(self, dataset_x, label):


Loading…
Cancel
Save