diff --git a/saliency/core/base.py b/saliency/core/base.py
index bb886173..8d7bd1cb 100644
--- a/saliency/core/base.py
+++ b/saliency/core/base.py
@@ -110,18 +110,22 @@ def GetSmoothedMask(self,
       magnitude: If true, computes the sum of squares of gradients instead of
                  just the sum. Defaults to true.
     """
-    stdev = stdev_spread * (np.max(x_value) - np.min(x_value))
+    stdev = stdev_spread * (np.max(x_value) - np.min(x_value))
 
-    total_gradients = np.zeros_like(x_value, dtype=np.float32)
-    for _ in range(nsamples):
-      noise = np.random.normal(0, stdev, x_value.shape)
-      x_plus_noise = x_value + noise
-      grad = self.GetMask(x_plus_noise, call_model_function, call_model_args,
+    # Build all noisy samples up front so the model is called only once.
+    shape = (nsamples,) + x_value.shape
+    noisy_samples = np.zeros(shape)
+    for i in range(nsamples):
+      noise = np.random.normal(0, stdev, x_value.shape)
+      x_plus_noise = x_value + noise
+      noisy_samples[i] = x_plus_noise
+
+    grads = self.GetMask(noisy_samples, call_model_function, call_model_args,
                          **kwargs)
-      if magnitude:
-        total_gradients += (grad * grad)
-      else:
-        total_gradients += grad
+    if magnitude:
+      total_gradients = np.sum((grads * grads), axis=0)
+    else:
+      total_gradients = np.sum(grads, axis=0)
 
     return total_gradients / nsamples
 
diff --git a/saliency/core/gradients.py b/saliency/core/gradients.py
index 8ac89947..224f1cd6 100644
--- a/saliency/core/gradients.py
+++ b/saliency/core/gradients.py
@@ -47,7 +47,11 @@ def GetMask(self, x_value, call_model_function, call_model_args=None):
      call_model_args: The arguments that will be passed to the call model
        function, for every call of the model.
    """
-    x_value_batched = np.expand_dims(x_value, axis=0)
+    # Rank-3 (unbatched) inputs get a batch dimension; batches pass through.
+    if len(x_value.shape) == 3:
+      x_value_batched = np.expand_dims(x_value, axis=0)
+    else:
+      x_value_batched = x_value
     call_model_output = call_model_function(
         x_value_batched,
         call_model_args=call_model_args,
@@ -57,4 +61,8 @@
         x_value_batched.shape,
         self.expected_keys)
-    return call_model_output[INPUT_OUTPUT_GRADIENTS][0]
+    # Unbatched inputs keep the old contract; batches return batched grads.
+    if x_value_batched.shape[0] == 1:
+      return call_model_output[INPUT_OUTPUT_GRADIENTS][0]
+    else:
+      return call_model_output[INPUT_OUTPUT_GRADIENTS]
 
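
For reviewers: below is a minimal smoke-test sketch of the batched path, not part of the diff itself. The toy call_model_function is hypothetical: it stands in for a real model whose output is sum(x ** 2), so the input gradients it returns are simply 2 * x. The import paths mirror the two files this PR touches.

import numpy as np

from saliency.core.base import INPUT_OUTPUT_GRADIENTS
from saliency.core.gradients import GradientSaliency


def call_model_function(x_value_batch, call_model_args=None, expected_keys=None):
  # Hypothetical stand-in for a real model: gradients of sum(x ** 2) are 2 * x.
  # With this patch, x_value_batch carries all nsamples noisy inputs at once,
  # so this function runs once per GetSmoothedMask call, not once per sample.
  return {INPUT_OUTPUT_GRADIENTS: 2 * x_value_batch}


x_value = np.random.rand(32, 32, 3).astype(np.float32)  # one unbatched image
saliency_object = GradientSaliency()

# Rank-3 input: GetMask still returns an unbatched gradient, as before.
grad = saliency_object.GetMask(x_value, call_model_function)
assert grad.shape == x_value.shape

# GetSmoothedMask now feeds GetMask a (nsamples,) + x_value.shape batch and
# reduces the returned gradient batch over axis 0.
mask = saliency_object.GetSmoothedMask(x_value, call_model_function, nsamples=8)
assert mask.shape == x_value.shape

One behavioral note worth flagging: because the new return branch keys off x_value_batched.shape[0] == 1, a caller who deliberately passes a batch of size one still gets an unbatched gradient back, and any input whose rank is not 3 is assumed to already be batched.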