Skip to content

Commit d28f1a5

Browse files
committed
Add softmax support for int8 in Cortex M (dim=-1)
- Integrate the CMSIS-NN softmax kernel into the Cortex-M backend.
- Add a fusion pass and tests for quantized (int8) softmax.
- Apply lint cleanup passes.
- Resolve merge conflicts.
Change-Id: I0ec19f011069fa1482e2de2ab62b9e7d7f56b2a8
Signed-off-by: Xingguo Li <xingguo.li@arm.com>
1 parent 288edb4 commit d28f1a5

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

backends/cadence/aot/program_builder.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -77,13 +77,18 @@ def placeholder(
7777
return placeholder
7878

7979
def output(
80-
self, results: list[ProxyValue], output_kinds: Optional[list[OutputKind]] = None
80+
self,
81+
results: list[ProxyValue],
82+
output_kinds: Optional[list[OutputKind]] = None,
83+
output_targets: Optional[list[str | None]] = None,
8184
) -> ProxyValue:
8285
if output_kinds is None:
8386
output_kinds = [OutputKind.USER_OUTPUT] * len(results)
84-
for result, out_kind in zip(results, output_kinds):
87+
if output_targets is None:
88+
output_targets = [None] * len(results)
89+
for result, out_kind, target in zip(results, output_kinds, output_targets):
8590
self.output_specs.append(
86-
OutputSpec(out_kind, TensorArgument(result.node.name), target=None)
91+
OutputSpec(out_kind, TensorArgument(result.node.name), target=target)
8792
)
8893
return super().output(results)
8994

backends/cortex_m/ops/op_softmax.cpp

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -71,14 +71,15 @@ Tensor& softmax_out(
7171
return out;
7272
}
7373

74-
const int32_t input_zp_val = static_cast<int32_t>(input_zero_point);
7574
const int32_t output_zp_val = static_cast<int32_t>(output_zero_point);
76-
(void)input_zp_val; // Zero-point difference cancels out during subtraction.
75+
const int32_t input_multiplier_val = static_cast<int32_t>(input_multiplier);
76+
const int32_t input_shift_val = static_cast<int32_t>(input_shift);
77+
const int32_t diff_min_val = static_cast<int32_t>(diff_min);
7778

7879
validate_single_quant_params(
79-
Scalar(input_zp_val),
80-
Scalar(input_multiplier),
81-
Scalar(input_shift),
80+
Scalar(static_cast<int32_t>(input_zero_point)),
81+
Scalar(input_multiplier_val),
82+
Scalar(input_shift_val),
8283
"softmax input");
8384

8485
const auto positive_dim = normalize_dim(input, dim);
@@ -118,10 +119,6 @@ Tensor& softmax_out(
118119
return out;
119120
}
120121

121-
const int32_t input_multiplier_val = static_cast<int32_t>(input_multiplier);
122-
const int32_t input_shift_val = static_cast<int32_t>(input_shift);
123-
const int32_t diff_min_val = static_cast<int32_t>(diff_min);
124-
125122
if (output_zp_val != kCmsisSoftmaxZeroPoint) {
126123
ET_LOG(
127124
Error,

backends/cortex_m/test/ops/test_softmax.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4747
CortexMSoftmax(dim=1),
4848
(ramp_tensor(-2, 2, (2, 3, 4)),),
4949
),
50+
"large_tensor": McuTestCase(
51+
CortexMSoftmax(dim=-1),
52+
(ramp_tensor(-10, 10, (8, 1024)),),
53+
),
5054
}
5155

5256

0 commit comments

Comments
 (0)