From fd3eb11b0e69d5fd9da4286a6c4af4b333a1279a Mon Sep 17 00:00:00 2001 From: liuluobin Date: Mon, 31 Aug 2020 17:24:51 +0800 Subject: [PATCH] fixed softmax_cross_entropy return NaN value --- .../diff_privacy/evaluation/membership_inference.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mindarmour/diff_privacy/evaluation/membership_inference.py b/mindarmour/diff_privacy/evaluation/membership_inference.py index c27882d..a91c5fb 100755 --- a/mindarmour/diff_privacy/evaluation/membership_inference.py +++ b/mindarmour/diff_privacy/evaluation/membership_inference.py @@ -82,7 +82,12 @@ def _softmax_cross_entropy(logits, labels): """ labels = np.eye(logits.shape[1])[labels].astype(np.int32) logits = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) - return -1*np.sum(labels*np.log(logits), axis=1) + loss = -1*np.sum(labels*np.log(logits), axis=1) + + nan_index = np.isnan(loss) + if np.any(nan_index): + loss[nan_index] = 0 + return loss class MembershipInference: @@ -243,6 +248,7 @@ class MembershipInference: np.random.shuffle(shuffle_index) features = features[shuffle_index] labels = labels[shuffle_index] + return features, labels def _generate(self, dataset_x, label):