From c3106431e3668ea288202f9ded80c6add9f58e99 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Sun, 11 Jan 2026 14:56:09 -0800 Subject: [PATCH] [opt] strength reduce umod-by-shifted-bit --- xls/ir/node_util.cc | 55 +++++++++++++++++++ xls/ir/node_util.h | 16 ++++++ xls/passes/arith_simplification_pass.cc | 46 ++++++++++++++++ xls/passes/arith_simplification_pass_test.cc | 22 ++++++++ xls/passes/optimization_pass_pipeline_test.cc | 45 +++++++++++++++ 5 files changed, 184 insertions(+) diff --git a/xls/ir/node_util.cc b/xls/ir/node_util.cc index cd4cc02558..352bcc35da 100644 --- a/xls/ir/node_util.cc +++ b/xls/ir/node_util.cc @@ -87,6 +87,61 @@ std::vector RemoveRedundantNodes( } // namespace +std::optional IsOneShiftedBit(Node* node) { + // Match: shll(zext(b), literal(k)) + if (node->op() == Op::kShll) { + Node* shift_base = node->operand(0); + Node* shift_amount = node->operand(1); + if (shift_base->op() == Op::kZeroExt && + IsSingleBitType(shift_base->operand(0)) && + shift_amount->Is()) { + absl::StatusOr k_u64 = + shift_amount->As()->value().bits().ToUint64(); + if (!k_u64.ok()) { + return std::nullopt; + } + return ShiftedBitView{.b = shift_base->operand(0), + .k = static_cast(*k_u64)}; + } + } + + // Match: concat(0..., b, 0...) + if (node->Is()) { + std::optional b_operand_index; + for (int64_t i = 0; i < node->operand_count(); ++i) { + Node* operand = node->operand(i); + if (!IsSingleBitType(operand)) { + continue; + } + if (b_operand_index.has_value()) { + // More than one 1-bit operand. + return std::nullopt; + } + b_operand_index = i; + } + if (!b_operand_index.has_value()) { + return std::nullopt; + } + + for (int64_t i = 0; i < node->operand_count(); ++i) { + if (i == *b_operand_index) { + continue; + } + if (!IsLiteralZero(node->operand(i))) { + return std::nullopt; + } + } + + int64_t k = 0; + for (int64_t i = *b_operand_index + 1; i < node->operand_count(); ++i) { + k += node->operand(i)->BitCountOrDie(); + } + return ShiftedBitView{.b = node->operand(*b_operand_index), .k = k}; + } + + return std::nullopt; +} + bool IsLiteralWithRunOfSetBits(Node* node, int64_t* leading_zero_count, int64_t* set_bit_count, int64_t* trailing_zero_count) { diff --git a/xls/ir/node_util.h b/xls/ir/node_util.h index af2885430e..1fe1bbf2ba 100644 --- a/xls/ir/node_util.h +++ b/xls/ir/node_util.h @@ -48,6 +48,22 @@ namespace xls { +struct ShiftedBitView { + Node* b; + int64_t k; +}; + +// Returns (b, k) if `node` is structurally equivalent to a value with a single +// potentially-set bit at position `k` (0 == LSb) controlled by the 1-bit value +// `b`. The recognized forms are: +// +// * shll(zext(b), literal(k)) +// * concat(0..., b, 0...) +// +// This is a structural matcher and only recognizes literal zeros / literal shift +// amounts (it does not use any query engine). +std::optional IsOneShiftedBit(Node* node); + inline bool IsLiteralZero(Node* node) { return node->Is() && node->As()->value().IsBits() && node->As()->value().bits().IsZero(); diff --git a/xls/passes/arith_simplification_pass.cc b/xls/passes/arith_simplification_pass.cc index f3532d51f0..0db69c05db 100644 --- a/xls/passes/arith_simplification_pass.cc +++ b/xls/passes/arith_simplification_pass.cc @@ -1009,6 +1009,52 @@ absl::StatusOr MatchArithPatterns(int64_t opt_level, Node* n, } } + // Pattern: + // + // umod(x, shll(zext(b), k)) -> sel(b, zext(bit_slice(x, 0, k), width(x)), 0) + // + // where b is a 1-bit value. + if (n->op() == Op::kUMod) { + Node* x = n->operand(0); + Node* divisor = n->operand(1); + const int64_t bit_count = x->BitCountOrDie(); + + auto replace_with_select = [&](Node* b, int64_t k) -> absl::StatusOr { + XLS_RET_CHECK_EQ(b->BitCountOrDie(), 1); + if (k <= 0 || k >= bit_count) { + XLS_RETURN_IF_ERROR( + n->ReplaceUsesWithNew(ZeroOfType(n->GetType())).status()); + return true; + } + + XLS_ASSIGN_OR_RETURN( + Node * slice, n->function_base()->MakeNode( + n->loc(), x, /*start=*/0, /*width=*/k)); + XLS_ASSIGN_OR_RETURN( + Node * narrowed, + n->function_base()->MakeNode(n->loc(), slice, bit_count, + Op::kZeroExt)); + XLS_ASSIGN_OR_RETURN( + Node * zero, n->function_base()->MakeNode( + n->loc(), Value(UBits(0, bit_count)))); + XLS_RETURN_IF_ERROR( + n->ReplaceUsesWithNew