Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions xls/dslx/fmt/ast_fmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3125,6 +3125,11 @@ absl::StatusOr<DocRef> Formatter::Format(const Module& n) {
pieces.push_back(arena_.MakeText("#![feature(channel_attributes)]"));
pieces.push_back(arena_.hard_line());
break;
case ModuleAttribute::kExplicitStateAccess:
pieces.push_back(
arena_.MakeText("#![feature(explicit_state_access)]"));
pieces.push_back(arena_.hard_line());
break;
case ModuleAttribute::kGenerics:
pieces.push_back(arena_.MakeText("#![feature(generics)]"));
pieces.push_back(arena_.hard_line());
Expand Down
8 changes: 8 additions & 0 deletions xls/dslx/frontend/builtin_stubs.x
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ fn gate!<T: type>(x: u1, y: T) -> T;
// `join` can take zero or more `token` arguments; the varargs are handled by the type checker.
fn join(t: token) -> token;

fn labeled_read<T: type, N: u32, M: u32>(source: T, labels: u8[N]) -> (T);

fn labeled_write<T: type, N: u32>(dest: T, value: T, label: u8[N]) -> ();

fn one_hot<N: u32, M: u32={N + 1}>(x: uN[N], y: u1) -> uN[M];

fn one_hot_sel<N: u32, M: u32, S: bool>(x: uN[N], y: xN[S][M][N]) -> xN[S][M];
Expand All @@ -71,6 +75,8 @@ fn or_reduce<N: u32>(x: uN[N]) -> u1;

fn priority_sel<N: u32, M: u32, S: bool>(x: uN[N], y: xN[S][M][N], z: xN[S][M]) -> xN[S][M];

fn read<T: type>(source: T) -> T;

fn recv_if_non_blocking<T: type>(tok: token, channel: chan<T> in, predicate: bool, value: T) -> (token, T, bool);

fn recv_if<T: type>(tok: token, channel: chan<T> in, predicate: bool, value: T) -> (token, T);
Expand Down Expand Up @@ -104,6 +110,8 @@ fn umulp<N: u32>(x: uN[N], y: uN[N]) -> (uN[N], uN[N]);

fn widening_cast<DEST: type, SRC: type>(x: SRC) -> DEST;

fn write<T: type>(dest: T, value: T) -> ();

fn xor_reduce<N: u32>(x: uN[N]) -> u1;

fn zip<LHS_TYPE: type, N: u32, RHS_TYPE: type>(lhs: LHS_TYPE[N], rhs: RHS_TYPE[N]) ->
Expand Down
5 changes: 5 additions & 0 deletions xls/dslx/frontend/builtins_metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ const absl::flat_hash_map<std::string, BuiltinsData>& GetParametricBuiltins() {

// -- Proc-oriented built-ins.

{"read", {.signature = "(T) -> T"}},
{"write", {.signature = "(T, T) -> ()"}},
{"labeled_read", {.signature = "(T, u8[N]) -> (T)"}},
{"labeled_write", {.signature = "(T, T, u8[N]) -> ()"}},

// send/recv (communication) builtins that can only be used within
// proc scope.
{"send", {.signature = "(token, send_chan<T>, T) -> token"}},
Expand Down
3 changes: 3 additions & 0 deletions xls/dslx/frontend/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ std::string Module::ToString() const {
case ModuleAttribute::kChannelAttributes:
absl::StrAppend(out, "#![feature(channel_attributes)]");
break;
case ModuleAttribute::kExplicitStateAccess:
absl::StrAppend(out, "#![feature(explicit_state_access)]");
break;
case ModuleAttribute::kGenerics:
absl::StrAppend(out, "#![feature(generics)]");
break;
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/frontend/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ enum class ModuleAttribute : uint8_t {
// Enable #channel() attributes for this module.
kChannelAttributes,

// Enable read and write with optional labels for state access for this
// module.
kExplicitStateAccess,

kGenerics,

// Enable `trait` declarations.
Expand Down
37 changes: 36 additions & 1 deletion xls/dslx/frontend/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,9 @@ absl::Status Parser::ParseModuleAttribute() {
} else if (feature == "channel_attributes") {
module_->AddAttribute(ModuleAttribute::kChannelAttributes,
attribute_span);
} else if (feature == "explicit_state_access") {
module_->AddAttribute(ModuleAttribute::kExplicitStateAccess,
attribute_span);
} else if (feature == "generics") {
module_->AddAttribute(ModuleAttribute::kGenerics, attribute_span);
} else if (feature == "traits") {
Expand Down Expand Up @@ -3096,6 +3099,31 @@ absl::StatusOr<Function*> Parser::ParseProcConfig(
return config;
}

absl::StatusOr<Function*> Parser::ParseProcNextExplicitStateAccess(
std::vector<Param*> next_params, std::string_view proc_name,
std::vector<ParametricBinding*> parametric_bindings, Token oparen,
Bindings& inner_bindings, bool is_public) {
for (Param* p : next_params) {
if (HasChannelElement(p->type_annotation())) {
return ParseErrorStatus(p->span(),
"Channels cannot be Proc next params.");
}
}
TypeAnnotation* return_type = module_->Make<TupleTypeAnnotation>(
Span(GetPos(), GetPos()), std::vector<TypeAnnotation*>{});
XLS_ASSIGN_OR_RETURN(StatementBlock * body,
ParseBlockExpression(inner_bindings));
Span span(oparen.span().start(), GetPos());
NameDef* name_def =
module_->Make<NameDef>(span, absl::StrCat(proc_name, ".next"), nullptr);
Function* next = module_->Make<Function>(
span, name_def, std::move(parametric_bindings), next_params, return_type,
body, FunctionTag::kProcNext, is_public,
/*is_stub=*/false);
name_def->set_definer(next);
return next;
}

absl::StatusOr<Function*> Parser::ParseProcNext(
Bindings& bindings, std::vector<ParametricBinding*> parametric_bindings,
std::string_view proc_name, bool is_public) {
Expand All @@ -3116,6 +3144,11 @@ absl::StatusOr<Function*> Parser::ParseProcNext(

XLS_ASSIGN_OR_RETURN(std::vector<Param*> next_params,
ParseCommaSeq<Param*>(parse_param, TokenKind::kCParen));
if (module_->attributes().contains(ModuleAttribute::kExplicitStateAccess)) {
return ParseProcNextExplicitStateAccess(next_params, proc_name,
parametric_bindings, oparen,
inner_bindings, is_public);
}

if (next_params.size() != 1) {
std::string next_params_str =
Expand All @@ -3135,9 +3168,11 @@ absl::StatusOr<Function*> Parser::ParseProcNext(
"Channels cannot be Proc next params.");
}

XLS_ASSIGN_OR_RETURN(TypeAnnotation * return_type,
TypeAnnotation* return_type;
XLS_ASSIGN_OR_RETURN(return_type,
CloneNode(state_param->type_annotation(),
&PreserveTypeDefinitionsReplacer));

XLS_ASSIGN_OR_RETURN(StatementBlock * body,
ParseBlockExpression(inner_bindings));
Span span(oparen.span().start(), GetPos());
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/frontend/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,10 @@ class Parser : public TokenParser {
Bindings& bindings, std::vector<ParametricBinding*> parametric_bindings,
const std::vector<ProcMember*>& proc_members, std::string_view proc_name,
bool is_public);
absl::StatusOr<Function*> ParseProcNextExplicitStateAccess(
std::vector<Param*> next_params, std::string_view proc_name,
std::vector<ParametricBinding*> parametric_bindings, Token oparen,
Bindings& inner_bindings, bool is_public);

absl::StatusOr<Function*> ParseProcNext(
Bindings& bindings, std::vector<ParametricBinding*> parametric_bindings,
Expand Down
73 changes: 73 additions & 0 deletions xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,50 @@ TEST_F(ParserTest, ParseRecvIfNb) {
RoundTrip(std::string(kModule));
}

TEST_F(ParserTest, ParseMultipleNextArgsExplicitStateAccess) {
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]

struct Point {
x: u32,
y: u32,
}
proc p {
config() {}
init {
()
}
next(state0: Point, state1: Point) {
let x = labeled_read(state0.x, ["label_x"]);
let y = labeled_read(state1.y, ["label_y"]);
}
})";

RoundTrip(std::string(kModule));
}

TEST_F(ParserTest, ParseReadWriteExplicitStateAccess) {
constexpr std::string_view kModule = R"(#![feature(explicit_state_access)]

struct Point {
x: u32,
y: u32,
}
proc p {
config() {}
init {
()
}
next(state0: Point) {
let x = labeled_read(state0, ["label_x"]);
labeled_write(state0.x, x, "label_x");
let y = read(state0.y);
write(state0.y, y);
}
})";

RoundTrip(std::string(kModule));
}

TEST_F(ParserTest, ParseJoin) {
constexpr std::string_view kModule = R"(proc foo {
c0: chan<u32> out;
Expand Down Expand Up @@ -3451,6 +3495,35 @@ TEST_F(ParserTest, ParseTypeInferenceV1AndV2Attributes) {
"and `type_inference_v2` attributes")));
}

TEST_F(ParserTest, ParseExplicitStateAccessAttribute) {
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Module> module, Parse(R"(
#![feature(explicit_state_access)]
)"));
EXPECT_THAT(module->attributes(),
testing::ElementsAre(ModuleAttribute::kExplicitStateAccess));
}

TEST_F(ParserTest, ExplicitStateAccessProcNextReturnsEmptyTuple) {
constexpr std::string_view kProgram = R"(#![feature(explicit_state_access)]
proc simple {
config() {
()
}
init {
u32:0
}
next(state: u32) {
state
}
})";
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Module> module, Parse(kProgram));
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc,
module->GetMemberOrError<Proc>("simple"));
const Function& next = proc->next();
ASSERT_NE(next.return_type(), nullptr);
EXPECT_EQ(next.return_type()->ToString(), "()");
}

// Verifies that we can walk backwards through a tree. In this case, from the
// terminal node to the defining expr.
TEST(ParserBackrefTest, CanFindDefiner) {
Expand Down