diff --git a/xls/ir/node_util.cc b/xls/ir/node_util.cc index cd4cc02558..289885ddde 100644 --- a/xls/ir/node_util.cc +++ b/xls/ir/node_util.cc @@ -1587,4 +1587,28 @@ absl::StatusOr GenericSelect::MakePredicateForDefault() const { absl::StrFormat("%s is not a select like operation.", n->ToString())); } } + +absl::StatusOr GenericSelect::MakeSelectLikeWithNewArms( + absl::Span new_cases, std::optional new_default_value, + const SourceInfo& loc) const { + XLS_RET_CHECK(valid()); + XLS_RET_CHECK_EQ(new_cases.size(), cases().size()); + FunctionBase* fb = AsNode()->function_base(); + return std::visit( + Visitor{[&](Select* /*unused*/) -> absl::StatusOr { + return fb->MakeNode() && diff --git a/xls/passes/select_simplification_pass_test.cc b/xls/passes/select_simplification_pass_test.cc index cb6f24329b..048ef66fa5 100644 --- a/xls/passes/select_simplification_pass_test.cc +++ b/xls/passes/select_simplification_pass_test.cc @@ -58,13 +58,14 @@ enum class AnalysisType { kTernary, kRange, }; -std::ostream& operator<<(std::ostream& os, AnalysisType a) { +[[maybe_unused]] std::ostream& operator<<(std::ostream& os, AnalysisType a) { switch (a) { case AnalysisType::kTernary: return os << "Ternary"; case AnalysisType::kRange: return os << "Range"; } + return os; } class SelectSimplificationPassTest : public IrTestBase, @@ -163,6 +164,261 @@ TEST_P(SelectSimplificationPassTest, BinaryTuplePrioritySelect) { /*default_value=*/m::TupleIndex(m::Tuple(), 1)))); } +TEST_P(SelectSimplificationPassTest, + HoistCompareThroughSelectLikeWithMostlyLiteralArms) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue x = fb.Param("x", p->GetBitsType(8)); + BValue param_p = fb.Param("p", p->GetBitsType(1)); + + BValue lit0 = fb.Literal(Value(UBits(0, 8))); + BValue lit1 = fb.Literal(Value(UBits(1, 8))); + BValue lit2 = fb.Literal(Value(UBits(2, 8))); + + BValue pr = fb.PrioritySelect(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + BValue eq_pr = fb.Eq(pr, lit0); + + BValue se = fb.Select(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + BValue eq_se = fb.Eq(se, lit0); + + // Selector is provably exactly-one-hot: concat(not(p), p). + BValue not_p = fb.Not(param_p); + BValue oh = fb.Concat({not_p, param_p}); + BValue ohs = fb.OneHotSelect(oh, /*cases=*/{lit1, x}); + BValue eq_ohs = fb.Eq(ohs, lit0); + + XLS_ASSERT_OK_AND_ASSIGN( + Function * f, fb.BuildWithReturnValue(fb.Tuple({eq_pr, eq_se, eq_ohs}))); + EXPECT_TRUE(f->return_value()->Is()); + + solvers::z3::ScopedVerifyEquivalence stays_equivalent{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::Tuple(m::PrioritySelect(m::Param("s"), + {m::Eq(m::Literal(1), m::Literal(0)), + m::Eq(m::Param("x"), m::Literal(0))}, + m::Eq(m::Literal(2), m::Literal(0))), + m::Select(m::Param("s"), + {m::Eq(m::Literal(1), m::Literal(0)), + m::Eq(m::Param("x"), m::Literal(0))}, + m::Eq(m::Literal(2), m::Literal(0))), + m::OneHotSelect(m::Concat(m::Not(m::Param("p")), m::Param("p")), + {m::Eq(m::Literal(1), m::Literal(0)), + m::Eq(m::Param("x"), m::Literal(0))}))); +} + +TEST_P(SelectSimplificationPassTest, + HoistOrReduceThroughSelectLikeWithMostlyLiteralArms) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue x = fb.Param("x", p->GetBitsType(8)); + BValue param_p = fb.Param("p", p->GetBitsType(1)); + + BValue lit1 = fb.Literal(Value(UBits(1, 8))); + BValue lit2 = fb.Literal(Value(UBits(2, 8))); + + BValue pr = fb.PrioritySelect(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + BValue or_pr = fb.OrReduce(pr); + + BValue se = fb.Select(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + BValue or_se = fb.OrReduce(se); + + // Selector is provably exactly-one-hot: concat(not(p), p). + BValue not_p = fb.Not(param_p); + BValue oh = fb.Concat({not_p, param_p}); + BValue ohs = fb.OneHotSelect(oh, /*cases=*/{lit1, x}); + BValue or_ohs = fb.OrReduce(ohs); + + XLS_ASSERT_OK_AND_ASSIGN( + Function * f, fb.BuildWithReturnValue(fb.Tuple({or_pr, or_se, or_ohs}))); + EXPECT_TRUE(f->return_value()->Is()); + + solvers::z3::ScopedVerifyEquivalence stays_equivalent{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::Tuple( + m::PrioritySelect( + m::Param("s"), + {m::OrReduce(m::Literal(1)), m::OrReduce(m::Param("x"))}, + m::OrReduce(m::Literal(2))), + m::Select(m::Param("s"), + {m::OrReduce(m::Literal(1)), m::OrReduce(m::Param("x"))}, + m::OrReduce(m::Literal(2))), + m::OneHotSelect( + m::Concat(m::Not(m::Param("p")), m::Param("p")), + {m::OrReduce(m::Literal(1)), m::OrReduce(m::Param("x"))}))); +} + +TEST_P(SelectSimplificationPassTest, + HoistAndXorReduceThroughSelectLikeWithMostlyLiteralArms) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue x = fb.Param("x", p->GetBitsType(8)); + + BValue lit1 = fb.Literal(Value(UBits(1, 8))); + BValue lit2 = fb.Literal(Value(UBits(2, 8))); + + // Keep the single-use profitability guard enabled by giving each select-like + // node exactly one use. + BValue pr_and = + fb.PrioritySelect(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + BValue and_pr = fb.AndReduce(pr_and); + BValue pr_xor = + fb.PrioritySelect(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + BValue xor_pr = fb.XorReduce(pr_xor); + + BValue se_and = fb.Select(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + BValue and_se = fb.AndReduce(se_and); + BValue se_xor = fb.Select(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + BValue xor_se = fb.XorReduce(se_xor); + + XLS_ASSERT_OK_AND_ASSIGN( + Function * f, + fb.BuildWithReturnValue(fb.Tuple({and_pr, xor_pr, and_se, xor_se}))); + EXPECT_TRUE(f->return_value()->Is()); + + solvers::z3::ScopedVerifyEquivalence stays_equivalent{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT( + f->return_value(), + m::Tuple( + m::PrioritySelect( + m::Param("s"), + {m::AndReduce(m::Literal(1)), m::AndReduce(m::Param("x"))}, + m::AndReduce(m::Literal(2))), + m::PrioritySelect( + m::Param("s"), + {m::XorReduce(m::Literal(1)), m::XorReduce(m::Param("x"))}, + m::XorReduce(m::Literal(2))), + m::Select(m::Param("s"), + {m::AndReduce(m::Literal(1)), m::AndReduce(m::Param("x"))}, + m::AndReduce(m::Literal(2))), + m::Select(m::Param("s"), + {m::XorReduce(m::Literal(1)), m::XorReduce(m::Param("x"))}, + m::XorReduce(m::Literal(2))))); +} + +TEST_P(SelectSimplificationPassTest, + HoistNonCommutativeCompareThroughSelectPreservesOrder) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue x = fb.Param("x", p->GetBitsType(8)); + BValue lit0 = fb.Literal(Value(UBits(0, 8))); + BValue lit1 = fb.Literal(Value(UBits(1, 8))); + BValue lit2 = fb.Literal(Value(UBits(2, 8))); + + BValue se = fb.Select(s, /*cases=*/{lit1, x}, /*default_value=*/lit2); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, + fb.BuildWithReturnValue(fb.ULt(lit0, se))); + EXPECT_TRUE(f->return_value()->Is()); + + solvers::z3::ScopedVerifyEquivalence stays_equivalent{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + m::Select(m::Param("s"), + {m::ULt(m::Literal(0), m::Literal(1)), + m::ULt(m::Literal(0), m::Param("x"))}, + m::ULt(m::Literal(0), m::Literal(2)))); +} + +TEST_P(SelectSimplificationPassTest, + HoistCompareThroughSelectLikeWithAllLiteralArms) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue param_p = fb.Param("p", p->GetBitsType(1)); + + BValue lit0 = fb.Literal(Value(UBits(0, 8))); + BValue lit1 = fb.Literal(Value(UBits(1, 8))); + BValue lit2 = fb.Literal(Value(UBits(2, 8))); + + BValue pr = + fb.PrioritySelect(s, /*cases=*/{lit0, lit1}, /*default_value=*/lit2); + BValue eq_pr = fb.Eq(pr, lit0); + + BValue se = fb.Select(s, /*cases=*/{lit0, lit1}, /*default_value=*/lit2); + BValue eq_se = fb.Eq(se, lit0); + + BValue not_p = fb.Not(param_p); + BValue oh = fb.Concat({not_p, param_p}); + BValue ohs = fb.OneHotSelect(oh, /*cases=*/{lit0, lit1}); + BValue eq_ohs = fb.Eq(ohs, lit0); + + XLS_ASSERT_OK_AND_ASSIGN( + Function * f, fb.BuildWithReturnValue(fb.Tuple({eq_pr, eq_se, eq_ohs}))); + EXPECT_TRUE(f->return_value()->Is()); + + solvers::z3::ScopedVerifyEquivalence stays_equivalent{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(true)); + + // Key invariant: after hoisting, no compare should directly compare a + // select-like node (sel/priority_sel/one_hot_sel) against the literal. + for (Node* n : f->nodes()) { + if (!n->Is()) { + continue; + } + if (n->operand(0)->OpIn({Op::kPrioritySel, Op::kSel, Op::kOneHotSel}) || + n->operand(1)->OpIn({Op::kPrioritySel, Op::kSel, Op::kOneHotSel})) { + // The hoisted compares can still exist (e.g. eq(lit0, lit0)), but the + // compare itself should no longer be comparing the select-like result to + // the literal. + FAIL() << "Unexpected compare of select-like value: " << n->ToString(); + } + } +} + +TEST_P(SelectSimplificationPassTest, + CompareThroughSelectLikeWithTwoNonLiteralArmsDoesNotHoist) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue x = fb.Param("x", p->GetBitsType(8)); + BValue y = fb.Param("y", p->GetBitsType(8)); + BValue lit0 = fb.Literal(Value(UBits(0, 8))); + BValue lit2 = fb.Literal(Value(UBits(2, 8))); + + // Two non-literal arms => profitability guard should block the rewrite. + BValue se = fb.Select(s, /*cases=*/{x, y}, /*default_value=*/lit2); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, + fb.BuildWithReturnValue(fb.Eq(se, lit0))); + + solvers::z3::ScopedVerifyEquivalence stays_equivalent{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(false)); + EXPECT_THAT(f->return_value(), + m::Eq(m::Select(m::Param("s"), + /*cases=*/{m::Param("x"), m::Param("y")}, + /*default_value=*/m::Literal(2)), + m::Literal(0))); +} + +TEST_P(SelectSimplificationPassTest, + OrReduceThroughSelectLikeWithTwoNonLiteralArmsDoesNotHoist) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue s = fb.Param("s", p->GetBitsType(2)); + BValue x = fb.Param("x", p->GetBitsType(8)); + BValue y = fb.Param("y", p->GetBitsType(8)); + BValue lit2 = fb.Literal(Value(UBits(2, 8))); + + // Two non-literal arms => profitability guard should block the rewrite. + BValue se = fb.Select(s, /*cases=*/{x, y}, /*default_value=*/lit2); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, + fb.BuildWithReturnValue(fb.OrReduce(se))); + + solvers::z3::ScopedVerifyEquivalence stays_equivalent{f}; + EXPECT_THAT(Run(f), IsOkAndHolds(false)); + EXPECT_THAT(f->return_value(), + m::OrReduce(m::Select(m::Param("s"), + /*cases=*/{m::Param("x"), m::Param("y")}, + /*default_value=*/m::Literal(2)))); +} + TEST_P(SelectSimplificationPassTest, FourWayTupleSelect) { auto p = CreatePackage(); XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"( diff --git a/xls/passes/stateless_query_engine.cc b/xls/passes/stateless_query_engine.cc index 2c41d40957..214185512c 100644 --- a/xls/passes/stateless_query_engine.cc +++ b/xls/passes/stateless_query_engine.cc @@ -262,6 +262,10 @@ bool StatelessQueryEngine::AtMostOneBitTrue(Node* node) const { if (node->Is()) { return true; } + // A <=1-bit value can never have more than one bit set. + if (node->GetType()->IsBits() && node->BitCountOrDie() <= 1) { + return true; + } return QueryEngine::AtMostOneBitTrue(node); } @@ -277,6 +281,49 @@ bool StatelessQueryEngine::ExactlyOneBitTrue(Node* node) const { return true; } + // Fast pattern matches for selectors which are exactly-one-hot by + // construction. + // + // Note: This is intentionally local/structural reasoning; stateless query + // engine does not propagate facts through the graph. + if (node->op() == Op::kConcat && node->operand_count() == 2) { + Node* msb = node->operand(0); + Node* lsb = node->operand(1); + auto is_single_bit = [](Node* n) { + return n->GetType()->IsBits() && n->BitCountOrDie() == 1; + }; + auto is_not_of = [](Node* n, Node* x) { + return n->op() == Op::kNot && n->operand_count() == 1 && + n->operand(0) == x; + }; + + // concat(not(x), x) or concat(x, not(x)) is exactly-one-hot when x is a + // single-bit value. + if (is_single_bit(lsb) && is_not_of(msb, lsb)) { + return true; + } + if (is_single_bit(msb) && is_not_of(lsb, msb)) { + return true; + } + + // OneHot(x) may be rewritten into concat(eq(x, 0), x) (e.g. by select + // simplification). This concat is exactly-one-hot when x is mutually + // exclusive (at most one bit set). + Node* maybe_eq = msb; + Node* x = lsb; + if (maybe_eq->op() == Op::kEq && maybe_eq->operand_count() == 2) { + Node* eq_lhs = maybe_eq->operand(0); + Node* eq_rhs = maybe_eq->operand(1); + // We only handle the literal-zero comparison in stateless mode. + if ((eq_lhs == x && IsAllZeros(eq_rhs)) || + (eq_rhs == x && IsAllZeros(eq_lhs))) { + if (AtMostOneBitTrue(x)) { + return true; + } + } + } + } + return QueryEngine::ExactlyOneBitTrue(node); } diff --git a/xls/passes/stateless_query_engine_test.cc b/xls/passes/stateless_query_engine_test.cc index 87259a0031..bc9e267790 100644 --- a/xls/passes/stateless_query_engine_test.cc +++ b/xls/passes/stateless_query_engine_test.cc @@ -117,6 +117,31 @@ TEST_F(StatelessQueryEngineTest, OneHotMsb) { EXPECT_TRUE(query_engine.ExactlyOneBitTrue(f->return_value())); } +TEST_F(StatelessQueryEngineTest, ExactlyOneBitTrueConcatNotXAndX) { + Package p("test_package"); + FunctionBuilder fb("f", &p); + BValue x = fb.Param("x", p.GetBitsType(1)); + BValue selector = fb.Concat({fb.Not(x), x}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(selector)); + VLOG(3) << f->DumpIr(); + + StatelessQueryEngine query_engine; + EXPECT_TRUE(query_engine.ExactlyOneBitTrue(f->return_value())); +} + +TEST_F(StatelessQueryEngineTest, ExactlyOneBitTrueConcatEqXZeroAndX) { + Package p("test_package"); + FunctionBuilder fb("f", &p); + BValue x = fb.Param("x", p.GetBitsType(1)); + BValue x_eq_zero = fb.Eq(x, fb.Literal(UBits(0, 1))); + BValue selector = fb.Concat({x_eq_zero, x}); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(selector)); + VLOG(3) << f->DumpIr(); + + StatelessQueryEngine query_engine; + EXPECT_TRUE(query_engine.ExactlyOneBitTrue(f->return_value())); +} + TEST_F(StatelessQueryEngineTest, ZeroExtend) { Package p("test_package"); FunctionBuilder fb("f", &p); diff --git a/xls/passes/union_query_engine.cc b/xls/passes/union_query_engine.cc index 3df0501d10..4fb9a4e273 100644 --- a/xls/passes/union_query_engine.cc +++ b/xls/passes/union_query_engine.cc @@ -14,6 +14,7 @@ #include "xls/passes/union_query_engine.h" +#include #include #include #include @@ -237,6 +238,27 @@ bool UnownedUnionQueryEngine::AtLeastOneTrue( return false; } +bool UnownedUnionQueryEngine::AtMostOneBitTrue(Node* node) const { + return std::any_of(engines_.begin(), engines_.end(), + [&](const QueryEngine* engine) { + return engine->AtMostOneBitTrue(node); + }); +} + +bool UnownedUnionQueryEngine::AtLeastOneBitTrue(Node* node) const { + return std::any_of(engines_.begin(), engines_.end(), + [&](const QueryEngine* engine) { + return engine->AtLeastOneBitTrue(node); + }); +} + +bool UnownedUnionQueryEngine::ExactlyOneBitTrue(Node* node) const { + return std::any_of(engines_.begin(), engines_.end(), + [&](const QueryEngine* engine) { + return engine->ExactlyOneBitTrue(node); + }); +} + bool UnownedUnionQueryEngine::KnownEquals(const TreeBitLocation& a, const TreeBitLocation& b) const { for (const auto& engine : engines_) { diff --git a/xls/passes/union_query_engine.h b/xls/passes/union_query_engine.h index d76f3704fc..104d21cf6c 100644 --- a/xls/passes/union_query_engine.h +++ b/xls/passes/union_query_engine.h @@ -101,6 +101,12 @@ class UnownedUnionQueryEngine : public QueryEngine { bool IsAllZeros(Node* n) const override; bool IsAllOnes(Node* n) const override; + // Returns true if at most/at least/exactly one of the bits in 'node' is true. + // 'node' must be bits-typed. + bool AtMostOneBitTrue(Node* node) const override; + bool AtLeastOneBitTrue(Node* node) const override; + bool ExactlyOneBitTrue(Node* node) const override; + Bits MaxUnsignedValue(Node* node) const override; Bits MinUnsignedValue(Node* node) const override; std::optional KnownLeadingZeros(Node* node) const override;