Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions xls/dslx/fmt/ast_fmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1270,6 +1270,10 @@ DocRef FmtMatchArm(const MatchArm& n, Comments& comments, DocArena& arena) {

DocRef Fmt(const Match& n, Comments& comments, DocArena& arena) {
std::vector<DocRef> pieces;
if (n.IsConst()) {
pieces.push_back(arena.Make(Keyword::kConst));
pieces.push_back(arena.space());
}
pieces.push_back(ConcatNGroup(
arena,
{arena.Make(Keyword::kMatch), arena.space(),
Expand Down
14 changes: 14 additions & 0 deletions xls/dslx/fmt/ast_fmt_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,20 @@ TEST_F(FunctionFmtTest, SimpleMatchOnBool) {
EXPECT_EQ(got, want);
}

TEST_F(FunctionFmtTest, SimpleConstMatchOnBool) {
const std::string_view original =
"fn f(b:bool)->u32{const match b{true=>u32:42,_=>u32:64}}";
XLS_ASSERT_OK_AND_ASSIGN(std::string got, DoFmt(original));
const std::string_view want =
R"(fn f(b: bool) -> u32 {
const match b {
true => u32:42,
_ => u32:64,
}
})";
EXPECT_EQ(got, want);
}

TEST_F(FunctionFmtTest, SimpleLetEqualsMatchOnBool) {
const std::string_view original =
"fn f(b:bool)->u32{let x=match b{true=>u32:42,_=>u32:64};x}";
Expand Down
6 changes: 3 additions & 3 deletions xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,7 @@ std::vector<AstNode*> Match::GetChildren(bool want_types) const {
}

std::string Match::ToStringInternal() const {
std::string result = absl::StrFormat("match %s {\n", matched_->ToString());
std::string result = absl::StrFormat("%smatch %s {\n", IsConst() ? "const " : "", matched_->ToString());
for (MatchArm* arm : arms_) {
absl::StrAppend(&result, Indent(absl::StrCat(arm->ToString(), ",\n"),
kRustSpacesPerIndent));
Expand Down Expand Up @@ -2378,8 +2378,8 @@ Span MatchArm::GetPatternSpan() const {
}

Match::Match(Module* owner, Span span, Expr* matched,
std::vector<MatchArm*> arms, bool in_parens)
: Expr(owner, std::move(span), in_parens),
std::vector<MatchArm*> arms, bool in_parens, bool is_const)
: Expr(owner, std::move(span), in_parens, is_const),
matched_(matched),
arms_(std::move(arms)) {}

Expand Down
8 changes: 5 additions & 3 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1151,8 +1151,8 @@ inline bool WeakerThan(Precedence x, Precedence y) {
// (i.e. can produce runtime values).
class Expr : public AstNode {
public:
Expr(Module* owner, Span span, bool in_parens = false)
: AstNode(owner), span_(span), in_parens_(in_parens) {}
Expr(Module* owner, Span span, bool in_parens = false, bool is_const = false)
: AstNode(owner), span_(span), in_parens_(in_parens), is_const_(is_const) {}

~Expr() override;

Expand Down Expand Up @@ -1202,6 +1202,7 @@ class Expr : public AstNode {
// (x == y) == z
bool in_parens() const { return in_parens_; }
void set_in_parens(bool enabled) { in_parens_ = enabled; }
bool IsConst() const { return is_const_; }

protected:
virtual std::string ToStringInternal() const = 0;
Expand All @@ -1211,6 +1212,7 @@ class Expr : public AstNode {
private:
Span span_;
bool in_parens_ = false;
bool is_const_;
};

// ChannelTypeAnnotation has to be placed after the definition of Expr, so it
Expand Down Expand Up @@ -2633,7 +2635,7 @@ class MatchArm : public AstNode {
class Match : public Expr {
public:
Match(Module* owner, Span span, Expr* matched, std::vector<MatchArm*> arms,
bool in_parens = false);
bool in_parens = false, bool is_const = false);

~Match() override;

Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/frontend/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ class AstCloner : public AstNodeVisitor {

old_to_new_[n] = module(n)->Make<Match>(
n->span(), down_cast<Expr*>(old_to_new_.at(n->matched())), new_arms,
n->in_parens());
n->in_parens(), n->IsConst());
return absl::OkStatus();
}

Expand Down
35 changes: 30 additions & 5 deletions xls/dslx/frontend/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1989,7 +1989,10 @@ absl::StatusOr<NameDefTree*> Parser::ParsePattern(Bindings& bindings,
absl::StrFormat("Expected pattern; got %s", peek->ToErrorString()));
}

absl::StatusOr<Match*> Parser::ParseMatch(Bindings& bindings) {
absl::StatusOr<Match*> Parser::ParseMatch(Bindings& bindings, bool is_const) {
if (is_const) {
XLS_RETURN_IF_ERROR(DropKeywordOrError(Keyword::kConst));
}
XLS_ASSIGN_OR_RETURN(Token match, PopKeywordOrError(Keyword::kMatch));
XLS_ASSIGN_OR_RETURN(Expr * matched, ParseExpression(bindings));
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kOBrace));
Expand Down Expand Up @@ -2046,7 +2049,7 @@ absl::StatusOr<Match*> Parser::ParseMatch(Bindings& bindings) {
must_end = !dropped_comma;
}
Span span(match.span().start(), GetPos());
return module_->Make<Match>(span, matched, std::move(arms));
return module_->Make<Match>(span, matched, std::move(arms), false, is_const);
}

absl::StatusOr<UseTreeEntry*> Parser::ParseUseTreeEntry(Bindings& bindings) {
Expand Down Expand Up @@ -2414,12 +2417,22 @@ absl::StatusOr<Expr*> Parser::ParseTermLhs(Bindings& outer_bindings,
XLS_ASSIGN_OR_RETURN(
lhs, ParseParentheticalOrCastLhs(outer_bindings, start_pos));
} else if (peek->IsKeyword(Keyword::kMatch)) { // Match expression.
XLS_ASSIGN_OR_RETURN(lhs, ParseMatch(outer_bindings));
XLS_ASSIGN_OR_RETURN(lhs, ParseMatch(outer_bindings, false));
} else if (peek->kind() == TokenKind::kOBrack) { // Array expression.
XLS_ASSIGN_OR_RETURN(lhs, ParseArray(outer_bindings));
} else if (peek->IsKeyword(Keyword::kIf)) { // Conditional expression.
XLS_ASSIGN_OR_RETURN(lhs,
ParseRangeExpression(outer_bindings, kNoRestrictions));
} else if (peek->IsKeyword(Keyword::kConst)) {
XLS_ASSIGN_OR_RETURN(const Token* peek_1, PeekToken(1));
if (peek_1->IsKeyword(Keyword::kMatch)) { // constexpr match
XLS_ASSIGN_OR_RETURN(lhs, ParseMatch(outer_bindings, true));
} else {
return ParseErrorStatus(
peek_1->span(),
absl::StrFormat("Expected start of a const expression; got: %s",
peek_1->ToErrorString()));
}
} else {
return ParseErrorStatus(
peek->span(),
Expand Down Expand Up @@ -4131,8 +4144,7 @@ absl::StatusOr<StatementBlock*> Parser::ParseBlockExpression(
ParseTypeAlias(GetPos(), /*is_public=*/false, block_bindings));
stmts.push_back(module_->Make<Statement>(alias));
last_expr_had_trailing_semi = true;
} else if (peek->IsKeyword(Keyword::kLet) ||
peek->IsKeyword(Keyword::kConst)) {
} else if (peek->IsKeyword(Keyword::kLet)) {
XLS_ASSIGN_OR_RETURN(Let * let, ParseLet(block_bindings));
stmts.push_back(module_->Make<Statement>(let));
last_expr_had_trailing_semi = true;
Expand All @@ -4142,6 +4154,19 @@ absl::StatusOr<StatementBlock*> Parser::ParseBlockExpression(
stmts.push_back(module_->Make<Statement>(const_assert));
last_expr_had_trailing_semi = true;
} else {
// const can be a constant or a modifier
if (peek->IsKeyword(Keyword::kConst)) {
XLS_ASSIGN_OR_RETURN(const Token* peek_1, PeekToken(1));
// handle the case when const is a regular constant, otherwise
// it is a modifier and should be handled as expression with bindings
if (!peek_1->IsKeyword(Keyword::kMatch)) {
XLS_ASSIGN_OR_RETURN(Let * let, ParseLet(block_bindings));
stmts.push_back(module_->Make<Statement>(let));
last_expr_had_trailing_semi = true;
continue;
}
}

VLOG(5) << "ParseBlockExpression; parsing expression with bindings: ["
<< absl::StrJoin(block_bindings.GetLocalBindings(), ", ") << "]";
XLS_ASSIGN_OR_RETURN(Expr * e, ParseExpression(block_bindings));
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/frontend/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ class Parser : public TokenParser {
bool within_tuple_pattern);

// Parses a match expression.
absl::StatusOr<Match*> ParseMatch(Bindings& bindings);
absl::StatusOr<Match*> ParseMatch(Bindings& bindings, bool is_const);

// Parses a channel declaration.
absl::StatusOr<ChannelDecl*> ParseChannelDecl(
Expand Down
12 changes: 7 additions & 5 deletions xls/dslx/frontend/token_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,14 @@ class TokenParser {
// token is returned.
//
// Returns an error status in the case of things like scan errors.
absl::StatusOr<const Token*> PeekToken() {
if (index_ >= tokens_.size()) {
XLS_ASSIGN_OR_RETURN(Token token, scanner_->Pop());
tokens_.push_back(std::make_unique<Token>(std::move(token)));
absl::StatusOr<const Token*> PeekToken(int skip_count = 0) {
if (index_ + skip_count >= tokens_.size()) {
for (int i = 0; i < skip_count + 1; ++i) {
XLS_ASSIGN_OR_RETURN(Token token, scanner_->Pop());
tokens_.push_back(std::make_unique<Token>(std::move(token)));
}
}
return tokens_[index_].get();
return tokens_[index_ + skip_count].get();
}

// Returns a token that has been popped destructively from the token stream.
Expand Down
104 changes: 104 additions & 0 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,106 @@ absl::Status FunctionConverter::HandleBuiltinWideningCast(
return absl::OkStatus();
}

absl::StatusOr<uint64_t> FunctionConverter::ConstMatchWhichArm(const Match* node) {
ParametricEnv bindings(parametric_env_map_);
std::vector<MatchArm*> construct_match_arms;
construct_match_arms.reserve(node->arms().size());

// Construct a new Match object, which has the same `matched` and `patterns` as the original.
// Create a new expression for each arm
// so that the whole match can be evaluated to know which arm is selected.
for (int64_t i = 0; i < node->arms().size(); ++i) {
MatchArm* arm = node->arms()[i];
// create a new expression for this arm - a number with the index of the arm.
Number* expr = module_->Make<Number>(
Span::Fake(),
absl::StrFormat("%d", i),
NumberKind::kOther,
CreateU32Annotation(*module_, Span::Fake()));

current_type_info_->SetItem(expr, BitsType::MakeU32());
current_type_info_->NoteConstExpr(expr, InterpValue::MakeUBits(32, i));

construct_match_arms.push_back(
module_->Make<MatchArm>(arm->span(), arm->patterns(), expr));
}
Match *fake_match = module_->Make<Match>(node->span(), node->matched(), construct_match_arms);

XLS_ASSIGN_OR_RETURN(InterpValue interp_match,
ConstexprEvaluator::EvaluateToValue(
import_data_, current_type_info_,
kNoWarningCollector, bindings, fake_match));

XLS_ASSIGN_OR_RETURN(uint64_t arm_id, interp_match.GetBitValueUnsigned());

return arm_id;
}

bool pattern_has_namedef(const NameDefTree* pattern) {
if (pattern->is_leaf()) {
return absl::visit(
Visitor{
[&](NameDef* name_def) -> bool {
return true;
},
[&](AstNode* node) -> bool {
return false;
},
},
pattern->leaf());
} else {
return std::any_of(pattern->nodes().begin(), pattern->nodes().end(), pattern_has_namedef);
}
}

bool patterns_have_namedef(const std::vector<NameDefTree*>& patterns) {
return std::any_of(patterns.begin(), patterns.end(), pattern_has_namedef);
}

absl::Status FunctionConverter::HandleConstMatch(const Match* node) {
ParametricEnv bindings(parametric_env_map_);
std::optional<Value> matched_val;


XLS_RETURN_IF_ERROR(Visit(node->matched()));
XLS_ASSIGN_OR_RETURN(BValue matched, Use(node->matched()));
XLS_ASSIGN_OR_RETURN(InterpValue matched_const,
ConstexprEvaluator::EvaluateToValue(
import_data_, current_type_info_,
kNoWarningCollector, bindings, node->matched()));

XLS_ASSIGN_OR_RETURN(matched_val, InterpValueToValue(matched_const));
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> matched_type,
ResolveType(node->matched()));

XLS_ASSIGN_OR_RETURN(uint64_t arm_id, ConstMatchWhichArm(node));

const MatchArm* arm = node->arms()[arm_id];
bool has_namedef = patterns_have_namedef(arm->patterns());
if (!has_namedef) { // simple case when the arm's expression can be converted to IR
XLS_RETURN_IF_ERROR(Visit(node->arms()[arm_id]->expr()));
XLS_ASSIGN_OR_RETURN(BValue bval, Use(node->arms()[arm_id]->expr()));
SetNodeToIr(node, bval);
return absl::OkStatus();
} else { }

BValue final_val = function_builder_->Literal(matched_val.value());
std::vector<BValue> arm_selectors;

for (NameDefTree* pattern : arm->patterns()) {
XLS_ASSIGN_OR_RETURN(BValue selector,
HandleMatcher(pattern, final_val, *matched_type));
XLS_RET_CHECK(selector.valid());
arm_selectors.push_back(selector);
}

XLS_RETURN_IF_ERROR(Visit(arm->expr()));
XLS_ASSIGN_OR_RETURN(BValue arm_rhs_value, Use(arm->expr()));
SetNodeToIr(node, arm_rhs_value);

return absl::OkStatus();
}

absl::Status FunctionConverter::HandleMatch(const Match* node) {
if (node->arms().empty()) {
return IrConversionErrorStatus(
Expand All @@ -1291,6 +1391,10 @@ absl::Status FunctionConverter::HandleMatch(const Match* node) {
file_table());
}

if (node->IsConst()) {
return HandleConstMatch(node);
}

XLS_RETURN_IF_ERROR(Visit(node->matched()));
XLS_ASSIGN_OR_RETURN(BValue matched, Use(node->matched()));
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> matched_type,
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/ir_convert/function_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ class FunctionConverter {
->GetInvocationCalleeBindings(invocation, key);
}

// Helper to evaluate which arm of a const match should be used.
absl::StatusOr<uint64_t> ConstMatchWhichArm(const Match* node);

// Helpers for HandleBinop().
absl::Status HandleConcat(const Binop* node, BValue lhs, BValue rhs);
absl::Status HandleEq(const Binop* node, BValue lhs, BValue rhs);
Expand Down Expand Up @@ -423,6 +426,7 @@ class FunctionConverter {
absl::Status HandleLet(const Let* node);
absl::Status HandleLetChannelDecl(const Let* node);
absl::Status HandleMatch(const Match* node);
absl::Status HandleConstMatch(const Match* node);
absl::Status HandleRange(const Range* node);
absl::Status HandleSplatStructInstance(const SplatStructInstance* node);
absl::Status HandleStatement(const Statement* node);
Expand Down
14 changes: 14 additions & 0 deletions xls/dslx/type_system_v2/typecheck_module_v2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3339,6 +3339,20 @@ fn repro(x: u3) -> u2 {
TypecheckSucceeds(HasNodeWithType("upper", "uN[2]")));
}

TEST(TypecheckV2Test, ConstMatch) {
EXPECT_THAT(R"(
fn main(a: u32, b: u32) -> u32 {
const A = true;
let retval = const match A {
true => a,
false => b
};
retval
}
)",
TypecheckSucceeds(HasNodeWithType("retval", "uN[32]")));

}
TEST(TypecheckV2Test, MatchMismatch) {
EXPECT_THAT(R"(
const X = u32:1;
Expand Down
27 changes: 27 additions & 0 deletions xls/examples/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1412,3 +1412,30 @@ build_test(
name = "xls_pipeline_build_test",
targets = [":xls_pipeline"],
)

xls_dslx_library(
name = "const_match_dslx",
srcs = ["const_match.x"],
)

xls_dslx_test(
name = "const_match_test",
size = "small",
srcs = ["const_match.x"],
dslx_test_args = {"compare": "jit"},
)

xls_dslx_ir(
name = "const_match_ir",
dslx_top = "main",
ir_conv_args = {"lower_to_proc_scoped_channels": "true"},
ir_file = "const_match.ir",
library = ":const_match_dslx",
)

xls_dslx_opt_ir(
name = "const_match_opt_ir",
srcs = ["const_match.x"],
ir_conv_args = {"lower_to_proc_scoped_channels": "true"},
dslx_top = "main",
)
Loading