Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions xls/ir/node_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1587,4 +1587,28 @@ absl::StatusOr<Node*> GenericSelect::MakePredicateForDefault() const {
absl::StrFormat("%s is not a select like operation.", n->ToString()));
}
}

absl::StatusOr<Node*> GenericSelect::MakeSelectLikeWithNewArms(
absl::Span<Node* const> new_cases, std::optional<Node*> new_default_value,
const SourceInfo& loc) const {
XLS_RET_CHECK(valid());
XLS_RET_CHECK_EQ(new_cases.size(), cases().size());
FunctionBase* fb = AsNode()->function_base();
return std::visit(
Visitor{[&](Select* /*unused*/) -> absl::StatusOr<Node*> {
return fb->MakeNode<Select>(loc, selector(), new_cases,
new_default_value);
},
[&](PrioritySelect* /*unused*/) -> absl::StatusOr<Node*> {
XLS_RET_CHECK(new_default_value.has_value());
return fb->MakeNode<PrioritySelect>(loc, selector(), new_cases,
*new_default_value);
},
[&](OneHotSelect* /*unused*/) -> absl::StatusOr<Node*> {
XLS_RET_CHECK(!new_default_value.has_value());
return fb->MakeNode<OneHotSelect>(loc, selector(), new_cases);
}},
sel_);
}

} // namespace xls
19 changes: 19 additions & 0 deletions xls/ir/node_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,8 @@ absl::StatusOr<Node*> RemoveFromTuple(Node* tuple,
// select nodes.
class GenericSelect {
public:
enum class Kind { kSel, kPrioritySel, kOneHotSel };

// noop constructor for lists and such
GenericSelect() : sel_(static_cast<Select*>(nullptr)) {}
explicit GenericSelect(Select* select) : sel_(select) {}
Expand Down Expand Up @@ -635,6 +637,13 @@ class GenericSelect {
}

bool valid() const { return AsNode() != nullptr; }
Kind kind() const {
return std::visit(
Visitor{[](Select*) { return Kind::kSel; },
[](PrioritySelect*) { return Kind::kPrioritySel; },
[](OneHotSelect*) { return Kind::kOneHotSel; }},
sel_);
}
absl::Span<Node* const> cases() const {
return std::visit(
Visitor{
Expand Down Expand Up @@ -669,6 +678,16 @@ class GenericSelect {
// Make and return a new node which is true if the default case is selected.
absl::StatusOr<Node*> MakePredicateForDefault() const;

// Creates a new select-like node of the same kind as this GenericSelect.
//
// `new_cases` must have the same size as `cases()`.
//
// - For `priority_sel`, `new_default_value` must be present.
// - For `one_hot_sel`, `new_default_value` must be empty.
absl::StatusOr<Node*> MakeSelectLikeWithNewArms(
absl::Span<Node* const> new_cases, std::optional<Node*> new_default_value,
const SourceInfo& loc) const;

friend bool operator==(const GenericSelect& lhs, const GenericSelect& rhs) {
return lhs.AsNode() == rhs.AsNode();
}
Expand Down
129 changes: 129 additions & 0 deletions xls/passes/select_simplification_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1202,9 +1202,138 @@ absl::StatusOr<bool> MaybeReorderSelect(Node* node,
return true;
}

absl::StatusOr<bool> TryHoistPredicateThroughSelectLike(
Node* node, const QueryEngine& query_engine) {
enum class HoistKind : uint8_t { kCompare, kBitwiseReduce };
HoistKind kind;
CompareOp* cmp = nullptr;
Op bitwise_reduce_op;

std::optional<GenericSelect> sel;
Node* other = nullptr;
bool select_like_is_lhs = false;
if (node->Is<CompareOp>()) {
kind = HoistKind::kCompare;
cmp = node->As<CompareOp>();

// Look for `cmp(sel_like, literal)` or `cmp(literal, sel_like)` where
// `sel_like` is any select-like node wrapped by `GenericSelect`. We keep
// track of which side the select-like node is on so we preserve operand
// ordering for non-commutative compares (e.g. ULT/UGT).
if (auto gs = GenericSelect::From(cmp->operand(0)); gs.ok()) {
sel = *gs;
other = cmp->operand(1);
select_like_is_lhs = true;
} else if (auto gs = GenericSelect::From(cmp->operand(1)); gs.ok()) {
sel = *gs;
other = cmp->operand(0);
select_like_is_lhs = false;
} else {
return false;
}
} else if (node->Is<BitwiseReductionOp>()) {
kind = HoistKind::kBitwiseReduce;
bitwise_reduce_op = node->op();
// Treat `reduce(sel_like(...))` similarly to applying the reduction to the
// selected value.
if (auto gs = GenericSelect::From(node->operand(0)); gs.ok()) {
sel = *gs;
} else {
return false;
}
} else {
return false;
}
XLS_RET_CHECK(sel.has_value() && sel->valid());
if (sel->kind() == GenericSelect::Kind::kOneHotSel &&
!query_engine.ExactlyOneBitTrue(sel->selector())) {
// For `one_hot_sel`, hoisting the compare into the arms is only semantics-
// preserving when the selector is exactly-one-hot. Otherwise, `one_hot_sel`
// has "multiple active arms" behavior which is not equivalent to selecting
// exactly one arm and then comparing.
return false;
}

// Guardrails:
//
// - Single use: avoids duplicating the select tree under multiple
// compares/users (this rewrite clones the compare into each arm). Use
// `HasSingleUse()` so we account for implicit uses.
// - Literal other operand: maximizes folding opportunities after hoisting the
// comparison through the select(-like) into the arms.
// This keeps the rewrite roughly area-preserving in the intended cases:
// all-but-(at most)-one arm becomes a constant after folding.
if (!HasSingleUse(sel->AsNode()) ||
(kind == HoistKind::kCompare && !other->Is<Literal>())) {
return false;
}

auto count_non_literal_arms =
[&](absl::Span<Node* const> cases,
std::optional<Node*> default_value) -> int64_t {
int64_t non_literal_arm_count = 0;
for (Node* c : cases) {
if (!c->Is<Literal>()) {
++non_literal_arm_count;
}
}
if (default_value.has_value() && !default_value.value()->Is<Literal>()) {
++non_literal_arm_count;
}
return non_literal_arm_count;
};

// Profitability: we only rewrite when there is at most one non-literal arm
// (including the default). In these cases, hoisting the predicate into the
// arms enables aggressive folding of constant arms and can reduce critical-
// path delay by moving the predicate "closer" to the data.
if (count_non_literal_arms(sel->cases(), sel->default_value()) > 1) {
return false;
}

FunctionBase* f = node->function_base();
auto make_predicate = [&](Node* arm) -> absl::StatusOr<Node*> {
if (kind == HoistKind::kCompare) {
// Preserve operand order for non-commutative compares by placing `arm` on
// the same side where the select-like result originally appeared.
Node* lhs = select_like_is_lhs ? arm : other;
Node* rhs = select_like_is_lhs ? other : arm;
return f->MakeNode<CompareOp>(node->loc(), lhs, rhs, cmp->op());
}
XLS_RET_CHECK(kind == HoistKind::kBitwiseReduce);
return f->MakeNode<BitwiseReductionOp>(node->loc(), arm, bitwise_reduce_op);
};

std::vector<Node*> new_cases;
new_cases.reserve(sel->cases().size());
for (Node* c : sel->cases()) {
XLS_ASSIGN_OR_RETURN(Node * new_case, make_predicate(c));
new_cases.push_back(new_case);
}
std::optional<Node*> new_default = std::nullopt;
if (sel->default_value().has_value()) {
XLS_ASSIGN_OR_RETURN(new_default, make_predicate(*sel->default_value()));
}

XLS_ASSIGN_OR_RETURN(
Node * replacement,
sel->MakeSelectLikeWithNewArms(new_cases, new_default, node->loc()));
VLOG(2) << "Hoisting "
<< (kind == HoistKind::kCompare ? "compare" : OpToString(node->op()))
<< " through " << sel->AsNode()->op() << ": " << node->ToString();
XLS_RETURN_IF_ERROR(node->ReplaceUsesWith(replacement));
return true;
}

absl::StatusOr<bool> SimplifyNode(Node* node, const QueryEngine& query_engine,
BitProvenanceAnalysis& provenance,
int64_t opt_level, bool range_analysis) {
XLS_ASSIGN_OR_RETURN(bool hoisted_cmp,
TryHoistPredicateThroughSelectLike(node, query_engine));
if (hoisted_cmp) {
return true;
}

// Select with a constant selector can be replaced with the respective
// case.
if (node->Is<Select>() &&
Expand Down
Loading