diff --git a/xls/dslx/fmt/ast_fmt.cc b/xls/dslx/fmt/ast_fmt.cc index 07f8d6bafd..2cdf6eaa27 100644 --- a/xls/dslx/fmt/ast_fmt.cc +++ b/xls/dslx/fmt/ast_fmt.cc @@ -3125,6 +3125,11 @@ absl::StatusOr Formatter::Format(const Module& n) { pieces.push_back(arena_.MakeText("#![feature(channel_attributes)]")); pieces.push_back(arena_.hard_line()); break; + case ModuleAttribute::kExplicitStateAccess: + pieces.push_back( + arena_.MakeText("#![feature(explicit_state_access)]")); + pieces.push_back(arena_.hard_line()); + break; case ModuleAttribute::kGenerics: pieces.push_back(arena_.MakeText("#![feature(generics)]")); pieces.push_back(arena_.hard_line()); diff --git a/xls/dslx/frontend/builtin_stubs.x b/xls/dslx/frontend/builtin_stubs.x index 1c1938928c..f3cb96a051 100644 --- a/xls/dslx/frontend/builtin_stubs.x +++ b/xls/dslx/frontend/builtin_stubs.x @@ -63,6 +63,10 @@ fn gate!(x: u1, y: T) -> T; // `join` can take zero or more `token` arguments; the varargs are handled by the type checker. fn join(t: token) -> token; +fn labeled_read(source: T, labels: u8[N]) -> (T); + +fn labeled_write(dest: T, value: T, label: u8[N]) -> (); + fn one_hot(x: uN[N], y: u1) -> uN[M]; fn one_hot_sel(x: uN[N], y: xN[S][M][N]) -> xN[S][M]; @@ -71,6 +75,8 @@ fn or_reduce(x: uN[N]) -> u1; fn priority_sel(x: uN[N], y: xN[S][M][N], z: xN[S][M]) -> xN[S][M]; +fn read(source: T) -> T; + fn recv_if_non_blocking(tok: token, channel: chan in, predicate: bool, value: T) -> (token, T, bool); fn recv_if(tok: token, channel: chan in, predicate: bool, value: T) -> (token, T); @@ -104,6 +110,8 @@ fn umulp(x: uN[N], y: uN[N]) -> (uN[N], uN[N]); fn widening_cast(x: SRC) -> DEST; +fn write(dest: T, value: T) -> (); + fn xor_reduce(x: uN[N]) -> u1; fn zip(lhs: LHS_TYPE[N], rhs: RHS_TYPE[N]) -> diff --git a/xls/dslx/frontend/builtins_metadata.cc b/xls/dslx/frontend/builtins_metadata.cc index 27dc1f04b8..31602d0f69 100644 --- a/xls/dslx/frontend/builtins_metadata.cc +++ b/xls/dslx/frontend/builtins_metadata.cc @@ -143,6 +143,11 @@ const absl::flat_hash_map& GetParametricBuiltins() { // -- Proc-oriented built-ins. + {"read", {.signature = "(T) -> T"}}, + {"write", {.signature = "(T, T) -> ()"}}, + {"labeled_read", {.signature = "(T, u8[N]) -> (T)"}}, + {"labeled_write", {.signature = "(T, T, u8[N]) -> ()"}}, + // send/recv (communication) builtins that can only be used within // proc scope. {"send", {.signature = "(token, send_chan, T) -> token"}}, diff --git a/xls/dslx/frontend/module.cc b/xls/dslx/frontend/module.cc index b6fdefa5c0..d9a620ad72 100644 --- a/xls/dslx/frontend/module.cc +++ b/xls/dslx/frontend/module.cc @@ -94,6 +94,9 @@ std::string Module::ToString() const { case ModuleAttribute::kChannelAttributes: absl::StrAppend(out, "#![feature(channel_attributes)]"); break; + case ModuleAttribute::kExplicitStateAccess: + absl::StrAppend(out, "#![feature(explicit_state_access)]"); + break; case ModuleAttribute::kGenerics: absl::StrAppend(out, "#![feature(generics)]"); break; diff --git a/xls/dslx/frontend/module.h b/xls/dslx/frontend/module.h index 185a769899..818ccf4af8 100644 --- a/xls/dslx/frontend/module.h +++ b/xls/dslx/frontend/module.h @@ -82,6 +82,10 @@ enum class ModuleAttribute : uint8_t { // Enable #channel() attributes for this module. kChannelAttributes, + // Enable read and write with optional labels for state access for this + // module. + kExplicitStateAccess, + kGenerics, // Enable `trait` declarations. diff --git a/xls/dslx/frontend/parser.cc b/xls/dslx/frontend/parser.cc index 11a42e538d..f9cfe1c717 100644 --- a/xls/dslx/frontend/parser.cc +++ b/xls/dslx/frontend/parser.cc @@ -311,6 +311,9 @@ absl::Status Parser::ParseModuleAttribute() { } else if (feature == "channel_attributes") { module_->AddAttribute(ModuleAttribute::kChannelAttributes, attribute_span); + } else if (feature == "explicit_state_access") { + module_->AddAttribute(ModuleAttribute::kExplicitStateAccess, + attribute_span); } else if (feature == "generics") { module_->AddAttribute(ModuleAttribute::kGenerics, attribute_span); } else if (feature == "traits") { @@ -3096,6 +3099,31 @@ absl::StatusOr Parser::ParseProcConfig( return config; } +absl::StatusOr Parser::ParseProcNextExplicitStateAccess( + std::vector next_params, std::string_view proc_name, + std::vector parametric_bindings, Token oparen, + Bindings& inner_bindings, bool is_public) { + for (Param* p : next_params) { + if (HasChannelElement(p->type_annotation())) { + return ParseErrorStatus(p->span(), + "Channels cannot be Proc next params."); + } + } + TypeAnnotation* return_type = module_->Make( + Span(GetPos(), GetPos()), std::vector{}); + XLS_ASSIGN_OR_RETURN(StatementBlock * body, + ParseBlockExpression(inner_bindings)); + Span span(oparen.span().start(), GetPos()); + NameDef* name_def = + module_->Make(span, absl::StrCat(proc_name, ".next"), nullptr); + Function* next = module_->Make( + span, name_def, std::move(parametric_bindings), next_params, return_type, + body, FunctionTag::kProcNext, is_public, + /*is_stub=*/false); + name_def->set_definer(next); + return next; +} + absl::StatusOr Parser::ParseProcNext( Bindings& bindings, std::vector parametric_bindings, std::string_view proc_name, bool is_public) { @@ -3116,6 +3144,11 @@ absl::StatusOr Parser::ParseProcNext( XLS_ASSIGN_OR_RETURN(std::vector next_params, ParseCommaSeq(parse_param, TokenKind::kCParen)); + if (module_->attributes().contains(ModuleAttribute::kExplicitStateAccess)) { + return ParseProcNextExplicitStateAccess(next_params, proc_name, + parametric_bindings, oparen, + inner_bindings, is_public); + } if (next_params.size() != 1) { std::string next_params_str = @@ -3135,9 +3168,11 @@ absl::StatusOr Parser::ParseProcNext( "Channels cannot be Proc next params."); } - XLS_ASSIGN_OR_RETURN(TypeAnnotation * return_type, + TypeAnnotation* return_type; + XLS_ASSIGN_OR_RETURN(return_type, CloneNode(state_param->type_annotation(), &PreserveTypeDefinitionsReplacer)); + XLS_ASSIGN_OR_RETURN(StatementBlock * body, ParseBlockExpression(inner_bindings)); Span span(oparen.span().start(), GetPos()); diff --git a/xls/dslx/frontend/parser.h b/xls/dslx/frontend/parser.h index 645dc66e89..1fb7fbc2ae 100644 --- a/xls/dslx/frontend/parser.h +++ b/xls/dslx/frontend/parser.h @@ -692,6 +692,10 @@ class Parser : public TokenParser { Bindings& bindings, std::vector parametric_bindings, const std::vector& proc_members, std::string_view proc_name, bool is_public); + absl::StatusOr ParseProcNextExplicitStateAccess( + std::vector next_params, std::string_view proc_name, + std::vector parametric_bindings, Token oparen, + Bindings& inner_bindings, bool is_public); absl::StatusOr ParseProcNext( Bindings& bindings, std::vector parametric_bindings, diff --git a/xls/dslx/frontend/parser_test.cc b/xls/dslx/frontend/parser_test.cc index f8ea1fe828..8348a50234 100644 --- a/xls/dslx/frontend/parser_test.cc +++ b/xls/dslx/frontend/parser_test.cc @@ -1539,6 +1539,50 @@ TEST_F(ParserTest, ParseRecvIfNb) { RoundTrip(std::string(kModule)); } +TEST_F(ParserTest, ParseMultipleNextArgsExplicitStateAccess) { + constexpr std::string_view kModule = R"(#![feature(explicit_state_access)] + +struct Point { + x: u32, + y: u32, +} +proc p { + config() {} + init { + () + } + next(state0: Point, state1: Point) { + let x = labeled_read(state0.x, ["label_x"]); + let y = labeled_read(state1.y, ["label_y"]); + } +})"; + + RoundTrip(std::string(kModule)); +} + +TEST_F(ParserTest, ParseReadWriteExplicitStateAccess) { + constexpr std::string_view kModule = R"(#![feature(explicit_state_access)] + +struct Point { + x: u32, + y: u32, +} +proc p { + config() {} + init { + () + } + next(state0: Point) { + let x = labeled_read(state0, ["label_x"]); + labeled_write(state0.x, x, "label_x"); + let y = read(state0.y); + write(state0.y, y); + } +})"; + + RoundTrip(std::string(kModule)); +} + TEST_F(ParserTest, ParseJoin) { constexpr std::string_view kModule = R"(proc foo { c0: chan out; @@ -3451,6 +3495,35 @@ TEST_F(ParserTest, ParseTypeInferenceV1AndV2Attributes) { "and `type_inference_v2` attributes"))); } +TEST_F(ParserTest, ParseExplicitStateAccessAttribute) { + XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Parse(R"( +#![feature(explicit_state_access)] +)")); + EXPECT_THAT(module->attributes(), + testing::ElementsAre(ModuleAttribute::kExplicitStateAccess)); +} + +TEST_F(ParserTest, ExplicitStateAccessProcNextReturnsEmptyTuple) { + constexpr std::string_view kProgram = R"(#![feature(explicit_state_access)] +proc simple { + config() { + () + } + init { + u32:0 + } + next(state: u32) { + state + } +})"; + XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, Parse(kProgram)); + XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, + module->GetMemberOrError("simple")); + const Function& next = proc->next(); + ASSERT_NE(next.return_type(), nullptr); + EXPECT_EQ(next.return_type()->ToString(), "()"); +} + // Verifies that we can walk backwards through a tree. In this case, from the // terminal node to the defining expr. TEST(ParserBackrefTest, CanFindDefiner) {