Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions xls/ir/node_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,61 @@ std::vector<Node*> RemoveRedundantNodes(

} // namespace

std::optional<ShiftedBitView> IsOneShiftedBit(Node* node) {
// Match: shll(zext(b), literal(k))
if (node->op() == Op::kShll) {
Node* shift_base = node->operand(0);
Node* shift_amount = node->operand(1);
if (shift_base->op() == Op::kZeroExt &&
IsSingleBitType(shift_base->operand(0)) &&
shift_amount->Is<Literal>()) {
absl::StatusOr<uint64_t> k_u64 =
shift_amount->As<Literal>()->value().bits().ToUint64();
if (!k_u64.ok()) {
return std::nullopt;
}
return ShiftedBitView{.b = shift_base->operand(0),
.k = static_cast<int64_t>(*k_u64)};
}
}

// Match: concat(0..., b, 0...)
if (node->Is<Concat>()) {
std::optional<int64_t> b_operand_index;
for (int64_t i = 0; i < node->operand_count(); ++i) {
Node* operand = node->operand(i);
if (!IsSingleBitType(operand)) {
continue;
}
if (b_operand_index.has_value()) {
// More than one 1-bit operand.
return std::nullopt;
}
b_operand_index = i;
}
if (!b_operand_index.has_value()) {
return std::nullopt;
}

for (int64_t i = 0; i < node->operand_count(); ++i) {
if (i == *b_operand_index) {
continue;
}
if (!IsLiteralZero(node->operand(i))) {
return std::nullopt;
}
}

int64_t k = 0;
for (int64_t i = *b_operand_index + 1; i < node->operand_count(); ++i) {
k += node->operand(i)->BitCountOrDie();
}
return ShiftedBitView{.b = node->operand(*b_operand_index), .k = k};
}

return std::nullopt;
}

bool IsLiteralWithRunOfSetBits(Node* node, int64_t* leading_zero_count,
int64_t* set_bit_count,
int64_t* trailing_zero_count) {
Expand Down
16 changes: 16 additions & 0 deletions xls/ir/node_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@

namespace xls {

struct ShiftedBitView {
Node* b;
int64_t k;
};

// Returns (b, k) if `node` is structurally equivalent to a value with a single
// potentially-set bit at position `k` (0 == LSb) controlled by the 1-bit value
// `b`. The recognized forms are:
//
// * shll(zext(b), literal(k))
// * concat(0..., b, 0...)
//
// This is a structural matcher and only recognizes literal zeros / literal shift
// amounts (it does not use any query engine).
std::optional<ShiftedBitView> IsOneShiftedBit(Node* node);

inline bool IsLiteralZero(Node* node) {
return node->Is<Literal>() && node->As<Literal>()->value().IsBits() &&
node->As<Literal>()->value().bits().IsZero();
Expand Down
46 changes: 46 additions & 0 deletions xls/passes/arith_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,52 @@ absl::StatusOr<bool> MatchArithPatterns(int64_t opt_level, Node* n,
}
}

// Pattern:
//
// umod(x, shll(zext(b), k)) -> sel(b, zext(bit_slice(x, 0, k), width(x)), 0)
//
// where b is a 1-bit value.
if (n->op() == Op::kUMod) {
Node* x = n->operand(0);
Node* divisor = n->operand(1);
const int64_t bit_count = x->BitCountOrDie();

auto replace_with_select = [&](Node* b, int64_t k) -> absl::StatusOr<bool> {
XLS_RET_CHECK_EQ(b->BitCountOrDie(), 1);
if (k <= 0 || k >= bit_count) {
XLS_RETURN_IF_ERROR(
n->ReplaceUsesWithNew<Literal>(ZeroOfType(n->GetType())).status());
return true;
}

XLS_ASSIGN_OR_RETURN(
Node * slice, n->function_base()->MakeNode<BitSlice>(
n->loc(), x, /*start=*/0, /*width=*/k));
XLS_ASSIGN_OR_RETURN(
Node * narrowed,
n->function_base()->MakeNode<ExtendOp>(n->loc(), slice, bit_count,
Op::kZeroExt));
XLS_ASSIGN_OR_RETURN(
Node * zero, n->function_base()->MakeNode<Literal>(
n->loc(), Value(UBits(0, bit_count))));
XLS_RETURN_IF_ERROR(
n->ReplaceUsesWithNew<Select>(
b, std::vector<Node*>{zero, narrowed},
/*default_value=*/std::nullopt)
.status());
return true;
};

std::optional<ShiftedBitView> shifted_bit = IsOneShiftedBit(divisor);
if (shifted_bit.has_value()) {
XLS_ASSIGN_OR_RETURN(bool changed,
replace_with_select(shifted_bit->b, shifted_bit->k));
if (changed) {
return true;
}
}
}

// Pattern: UMod/SMod by a literal.
if (n->OpIn({Op::kUMod, Op::kSMod}) &&
query_engine.IsFullyKnown(n->operand(1))) {
Expand Down
22 changes: 22 additions & 0 deletions xls/passes/arith_simplification_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,28 @@ TEST_F(ArithSimplificationPassTest, UModByVariable) {
EXPECT_THAT(f->return_value(), m::UMod());
}

TEST_F(ArithSimplificationPassTest, UModShiftedOneBitDivisor) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* bits32 = p->GetBitsType(32);
BValue x = fb.Param("x", bits32);
BValue b = fb.Param("b", p->GetBitsType(1));
const int64_t k = 2;
BValue divisor = fb.Shll(fb.ZeroExtend(b, 32), fb.Literal(UBits(k, 32)));
fb.UMod(x, divisor);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

ScopedVerifyEquivalence sve(f, kProverTimeout);
ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));

EXPECT_THAT(
f->return_value(),
m::Select(m::Param("b"),
/*cases=*/{m::Literal(UBits(0, 32)),
m::ZeroExt(m::BitSlice(m::Param("x"),
/*start=*/0, /*width=*/k))}));
}

TEST_F(ArithSimplificationPassTest, UModOf13) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
Expand Down
45 changes: 45 additions & 0 deletions xls/passes/optimization_pass_pipeline_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,51 @@ TEST_F(OptimizationPipelineTest, MultiplyBy16StrengthReduction) {
EXPECT_THAT(f->return_value(), m::Concat());
}

TEST_F(OptimizationPipelineTest, UModShiftedOneBitDivisorSimplified) {
auto p = CreatePackage();
FunctionBuilder fb(TestName(), p.get());
Type* bits32 = p->GetBitsType(32);
BValue x = fb.Param("x", bits32);
BValue b = fb.Param("b", p->GetBitsType(1));

// Construct the pattern:
// umod(x, shll(zext(b), k))
const int64_t k = 2;
BValue divisor = fb.Shll(fb.ZeroExtend(b, 32), fb.Literal(UBits(k, 32)));
fb.UMod(x, divisor);
XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build());

ASSERT_THAT(Run(p.get()), IsOkAndHolds(true));

// In the pipeline we expect:
//
// umod(x, shll(zext(b), k)) -> sel(b, zext(bit_slice(x, 0, k), width(x)), 0)
//
// which will often then become:
//
// and(sign_ext(b), zext(bit_slice(x, 0, k), width(x))).
auto reduced =
m::ZeroExt(m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/k));
auto and_form = m::And(m::SignExt(m::Param("b")), reduced);
auto and_form_flipped = m::And(reduced, m::SignExt(m::Param("b")));
auto sel_form = m::Select(m::Param("b"),
/*cases=*/{m::Literal(UBits(0, 32)), reduced});
auto narrowed_and = m::And(m::SignExt(m::Param("b")),
m::BitSlice(m::Param("x"), /*start=*/0,
/*width=*/k));
auto narrowed_and_flipped =
m::And(m::BitSlice(m::Param("x"), /*start=*/0, /*width=*/k),
m::SignExt(m::Param("b")));
auto concat_form =
m::Concat(m::Literal(UBits(0, 32 - k)), narrowed_and);
auto concat_form_flipped =
m::Concat(m::Literal(UBits(0, 32 - k)), narrowed_and_flipped);
EXPECT_THAT(f->return_value(),
::testing::AnyOf(and_form, and_form_flipped, sel_form, concat_form,
concat_form_flipped))
<< f->DumpIr();
}

TEST_F(OptimizationPipelineTest, LogicAbsorption) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(
Expand Down