diff --git a/xls/contrib/xlscc/BUILD b/xls/contrib/xlscc/BUILD index 56607eb297..607fa7fc9c 100644 --- a/xls/contrib/xlscc/BUILD +++ b/xls/contrib/xlscc/BUILD @@ -160,6 +160,7 @@ cc_library( "@com_google_absl//absl/container:btree", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", @@ -253,6 +254,7 @@ cc_library( "//xls/passes:inlining_pass", "//xls/passes:node_source_analysis", "//xls/passes:optimization_pass", + "//xls/passes:optimization_pass_pipeline", "//xls/passes:partial_info_query_engine", "//xls/passes:pass_base", "//xls/solvers:z3_ir_translator", diff --git a/xls/contrib/xlscc/continuations.cc b/xls/contrib/xlscc/continuations.cc index 62a253b0d3..b9f4e21c67 100644 --- a/xls/contrib/xlscc/continuations.cc +++ b/xls/contrib/xlscc/continuations.cc @@ -113,11 +113,59 @@ std::string GraphvizEscape(std::string_view s) { absl::StrReplaceAll(label, {{"\"", "\\\""}})); }; +absl::Status ValidateContinuations(GeneratedFunction& func, + const xls::SourceInfo& loc) { + absl::flat_hash_map + slice_index_by_continuation_out; + absl::flat_hash_map + slice_index_by_slice; + + for (GeneratedFunctionSlice& slice : func.slices) { + const int64_t slice_index = slice_index_by_slice.size(); + slice_index_by_slice[&slice] = slice_index; + for (const ContinuationValue& continuation_out : slice.continuations_out) { + slice_index_by_continuation_out[&continuation_out] = slice_index; + } + } + + for (GeneratedFunctionSlice& slice : func.slices) { + const int64_t slice_index = slice_index_by_slice.at(&slice); + + absl::flat_hash_map + num_upstream_inputs_by_param; + + for (const ContinuationInput& continuation_in : slice.continuations_in) { + const int64_t upstream_slice_index = + slice_index_by_continuation_out.at(continuation_in.continuation_out); + + const bool is_feedback = slice_index <= upstream_slice_index; + int64_t& num_upstream_inputs_for_param = + num_upstream_inputs_by_param[continuation_in.input_node]; + + if (!is_feedback) { + ++num_upstream_inputs_for_param; + } + } + + for (const auto& [param, num_upstream_inputs] : + num_upstream_inputs_by_param) { + if (num_upstream_inputs != 1) { + return absl::InvalidArgumentError(absl::StrFormat( + "Param %s to slice %s has %i upstream inputs, should have exactly " + "1", + param->name(), slice.function->name(), num_upstream_inputs)); + } + } + } + return absl::OkStatus(); +} + } // namespace SourcesSetNodeInfo::SourcesSetNodeInfo() : xls::DataFlowLazyNodeInfo( - /*compute_tree_for_source=*/false, /*default_to_leaf=*/false) {} + /*compute_tree_for_source=*/false, /*default_to_leaf=*/false, + /*include_selectors=*/true) {} ParamSet SourcesSetNodeInfo::ComputeInfoForBitsLiteral( const xls::Bits& literal) const { @@ -149,7 +197,8 @@ ParamSet SourcesSetNodeInfo::MergeInfos( SourcesSetTreeNodeInfo::SourcesSetTreeNodeInfo() : xls::DataFlowLazyNodeInfo( - /*compute_tree_for_source=*/true, /*default_to_leaf=*/true) {} + /*compute_tree_for_source=*/true, /*default_to_leaf=*/true, + /*include_selectors=*/false) {} NodeSourceSet SourcesSetTreeNodeInfo::ComputeInfoForBitsLiteral( const xls::Bits& literal) const { @@ -294,6 +343,7 @@ Translator::ConvertBValuesToContinuationOutputsForCurrentSlice( if (!cval.rvalue().valid()) { continue; } + decls_by_bval_top_context[&cval.rvalue()] = decl; } @@ -322,7 +372,8 @@ Translator::ConvertBValuesToContinuationOutputsForCurrentSlice( CHECK(!continuation_outputs_by_bval.contains(bval)); continuation_outputs_by_bval[bval] = &new_continuation; if (decls_by_bval_top_context.contains(bval)) { - new_continuation.decls.insert(decls_by_bval_top_context.at(bval)); + new_continuation.decls.insert( + DeclLeaf{.decl = decls_by_bval_top_context.at(bval)}); } } } @@ -787,19 +838,328 @@ absl::Status Translator::RemoveMaskedOpParams(GeneratedFunction& func, return absl::OkStatus(); } +absl::Status Translator::DecomposeContinuationValues( + GeneratedFunction& func, bool& changed, const xls::SourceInfo& loc) { + absl::flat_hash_map + original_output_decomposable; + + absl::flat_hash_map params_decomposable; + + absl::flat_hash_map> + params_by_output; + + { + for (GeneratedFunctionSlice& slice : func.slices) { + // Ensure params_by_output is initialized for all outputs + for (ContinuationValue& continuation_out : slice.continuations_out) { + params_by_output[&continuation_out] = {}; + } + } + + // Ensure params_decomposable is initialized + for (GeneratedFunctionSlice& slice : func.slices) { + for (ContinuationInput& continuation_in : slice.continuations_in) { + params_by_output[continuation_in.continuation_out].push_back( + continuation_in.input_node); + params_decomposable[continuation_in.input_node] = true; + } + } + + for (GeneratedFunctionSlice& slice : func.slices) { + for (ContinuationValue& continuation_out : slice.continuations_out) { + bool decomposable = + TypeIsDecomposable(continuation_out.output_node->GetType()); + + if (continuation_out.direct_in) { + decomposable = false; + } + + original_output_decomposable[&continuation_out] = decomposable; + } + } + + // Iteratively mark things not decomposable + for (bool iter_changed = true; iter_changed;) { + iter_changed = false; + + for (GeneratedFunctionSlice& slice : func.slices) { + for (ContinuationInput& continuation_in : slice.continuations_in) { + if (!params_decomposable.at(continuation_in.input_node)) { + continue; + } + if (!original_output_decomposable.at( + continuation_in.continuation_out)) { + params_decomposable[continuation_in.input_node] = false; + iter_changed = true; + } + } + } + + for (GeneratedFunctionSlice& slice : func.slices) { + for (ContinuationValue& continuation_out : slice.continuations_out) { + if (!original_output_decomposable.at(&continuation_out)) { + continue; + } + + // The output is not decomposable if it feeds any parameter fed by a + // direct-in + for (const xls::Param* feeds_param : + params_by_output.at(&continuation_out)) { + if (!params_decomposable.at(feeds_param)) { + original_output_decomposable[&continuation_out] = false; + iter_changed = true; + break; + } + } + } + } + } + } + + absl::flat_hash_map> + decomposed_cont_values_by_original_output; + + for (GeneratedFunctionSlice& slice : func.slices) { + absl::InlinedVector new_returns; + + // Last slice has no continuation outputs + if (&slice == &func.slices.back()) { + XLSCC_CHECK(slice.continuations_out.empty(), loc); + break; + } + + xls::Node* return_tuple = slice.function->return_value(); + XLSCC_CHECK(return_tuple->Is(), loc); + const int64_t extra_returns = + return_tuple->operand_count() - slice.continuations_out.size(); + XLSCC_CHECK_GE(extra_returns, 0, loc); + + // NOTE: This iterates over the output continuations while adding more at + // the end + std::list original_continuations_out; + for (ContinuationValue& continuation_out : slice.continuations_out) { + original_continuations_out.push_back(&continuation_out); + } + + for (ContinuationValue* continuation_out_ptr : original_continuations_out) { + ContinuationValue& continuation_out = *continuation_out_ptr; + + // Don't reference the output identity, so that it can be removed later. + XLSCC_CHECK_EQ(continuation_out.output_node->op(), xls::Op::kIdentity, + loc); + xls::Node* output_source = continuation_out.output_node->operand(0); + + if (!original_output_decomposable.at(&continuation_out)) { + continue; + } + + if (!TypeIsDecomposable(output_source->GetType())) { + continue; + } + + for (const xls::Param* param : params_by_output.at(&continuation_out)) { + XLSCC_CHECK(params_decomposable.at(param), loc); + } + + absl::InlinedVector decomposed_nodes; + XLS_ASSIGN_OR_RETURN(decomposed_nodes, DecomposeTuples(output_source)); + + absl::InlinedVector decomposed_literals; + if (continuation_out.literal.has_value()) { + XLS_ASSIGN_OR_RETURN(decomposed_literals, + DecomposeValue(output_source->GetType(), + *continuation_out.literal)); + } + + // Create decomposed outputs + absl::InlinedVector decomposed_cont_values; + + for (int64_t di = 0; di < decomposed_nodes.size(); ++di) { + xls::Node* node = decomposed_nodes.at(di); + ContinuationValue new_continuation_out = continuation_out; + + if (continuation_out.literal.has_value()) { + new_continuation_out.literal = decomposed_literals.at(di); + } + + absl::flat_hash_set new_decls; + for (const DeclLeaf& decl : new_continuation_out.decls) { + new_decls.insert(DeclLeaf{.decl = decl.decl, .leaf_index = di}); + } + new_continuation_out.decls = new_decls; + + XLS_ASSIGN_OR_RETURN( + xls::UnOp * output_identity, + slice.function->MakeNodeWithName( + loc, node, xls::Op::kIdentity, + /*name=*/absl::StrFormat("%s_ident", node->GetName()))); + new_continuation_out.output_node = output_identity; + new_returns.push_back(new_continuation_out.output_node); + slice.continuations_out.push_back(new_continuation_out); + decomposed_cont_values.push_back(&slice.continuations_out.back()); + changed = true; + } + + decomposed_cont_values_by_original_output[&continuation_out] = + decomposed_cont_values; + } + + // Update return value + if (!new_returns.empty()) { + std::vector all_returns; + all_returns.insert(all_returns.end(), return_tuple->operands().begin(), + return_tuple->operands().begin() + + return_tuple->operands().size() - extra_returns); + all_returns.insert(all_returns.end(), new_returns.begin(), + new_returns.end()); + + CHECK_EQ(all_returns.size(), slice.continuations_out.size()); + + for (int64_t i = return_tuple->operand_count() - extra_returns; + i < return_tuple->operand_count(); ++i) { + all_returns.push_back(return_tuple->operand(i)); + } + + XLS_ASSIGN_OR_RETURN( + xls::Node * new_return_node, + slice.function->MakeNode(loc, all_returns)); + XLS_RETURN_IF_ERROR(slice.function->set_return_value(new_return_node)); + XLS_RETURN_IF_ERROR(slice.function->RemoveNode(return_tuple)); + changed = true; + } + } + + // To accomodate phis, first decompose all slices' outputs, then all slices' + // inputs + for (GeneratedFunctionSlice& slice : func.slices) { + absl::flat_hash_map> + decomposed_params_by_original; + absl::flat_hash_set all_cont_params; + std::vector all_params_decomposed; + + for (ContinuationInput& continuation_in : slice.continuations_in) { + all_cont_params.insert(continuation_in.input_node->As()); + } + + // Add decomposed continuation inputs + // NOTE: This loop iterates over continuation inputs while adding more at + // the end. + std::list original_continuations_in; + for (ContinuationInput& continuation_in : slice.continuations_in) { + original_continuations_in.push_back(&continuation_in); + } + for (ContinuationInput* continuation_in_ptr : original_continuations_in) { + ContinuationInput& continuation_in = *continuation_in_ptr; + XLSCC_CHECK_EQ(continuation_in.input_node->op(), xls::Op::kParam, loc); + xls::Param* input_param = continuation_in.input_node->As(); + + if (!params_decomposable.at(input_param)) { + continue; + } + + CHECK(TypeIsDecomposable(input_param->GetType())); + CHECK(!continuation_in.continuation_out->direct_in); + + CHECK(original_output_decomposable.at(continuation_in.continuation_out)); + + const absl::InlinedVector& decomposed_cont_values = + decomposed_cont_values_by_original_output.at( + continuation_in.continuation_out); + + // Create params to replace original param, if not done already + if (!decomposed_params_by_original.contains(input_param)) { + decomposed_params_by_original[input_param] = {}; + + all_params_decomposed.push_back(input_param); + + for (int64_t di = 0; di < decomposed_cont_values.size(); ++di) { + const ContinuationValue& decomposed_value = + *decomposed_cont_values.at(di); + xls::Type* type = decomposed_value.output_node->GetType(); + XLSCC_CHECK(!TypeIsDecomposable(type), loc); + + std::string decomposed_name = + absl::StrFormat("%s_%d", input_param->GetName(), di); + + xls::Node* decomposed_param_node = + slice.function->AddNode(std::make_unique( + loc, decomposed_value.output_node->GetType(), + /*name=*/decomposed_name, slice.function)); + + xls::Param* decomposed_param = + decomposed_param_node->As(); + + XLS_RETURN_IF_ERROR(slice.function->MoveParamToIndex( + decomposed_param, all_cont_params.size())); + + decomposed_params_by_original[input_param].push_back( + decomposed_param); + all_cont_params.insert(decomposed_param); + + changed = true; + } + } + + // Create inputs to replace original input. + const absl::InlinedVector& decomposed_params = + decomposed_params_by_original.at(input_param); + + for (int64_t di = 0; di < decomposed_cont_values.size(); ++di) { + ContinuationValue* decomposed_cont_value = + decomposed_cont_values.at(di); + + std::string decomposed_name = + absl::StrFormat("%s_%d", continuation_in.name, di); + + ContinuationInput new_continuation_in = continuation_in; + new_continuation_in.continuation_out = decomposed_cont_value; + new_continuation_in.input_node = decomposed_params.at(di); + new_continuation_in.name = decomposed_name; + new_continuation_in.decls = decomposed_cont_value->decls; + + slice.continuations_in.push_back(new_continuation_in); + changed = true; + } + } + + // Replace uses of original param with a tuple of new params. + for (xls::Param* orig_param : all_params_decomposed) { + const absl::InlinedVector& decomposed_params = + decomposed_params_by_original.at(orig_param); + + absl::InlinedVector decomposed_param_nodes; + decomposed_param_nodes.reserve(decomposed_params.size()); + for (xls::Param* param : decomposed_params) { + decomposed_param_nodes.push_back(param); + } + + XLS_ASSIGN_OR_RETURN( + xls::Node * param_replace_tuple, + ComposeTuples(orig_param->GetName(), orig_param->GetType(), + slice.function, loc, decomposed_param_nodes)); + + XLS_RETURN_IF_ERROR(orig_param->ReplaceUsesWith(param_replace_tuple)); + changed = true; + } + } + + return absl::OkStatus(); +} + absl::Status Translator::FinishLastSlice(TrackedBValue return_bval, const xls::SourceInfo& loc) { XLS_RETURN_IF_ERROR(FinishSlice(return_bval, loc)); XLS_RETURN_IF_ERROR(RemoveMaskedOpParams(*context().sf, loc)); + // Direct-inness is used in optimization OptimizationContext optimization_context; + XLS_RETURN_IF_ERROR(MarkDirectIns(*context().sf, optimization_context, loc)); XLS_RETURN_IF_ERROR( OptimizeContinuations(*context().sf, optimization_context, loc)); - XLS_RETURN_IF_ERROR(MarkDirectIns(*context().sf, optimization_context, loc)); - if (debug_ir_trace_flags_ & DebugIrTraceFlags_FSMStates) { LogContinuations(*context().sf); } @@ -838,6 +1198,7 @@ absl::Status RemoveUnusedContinuationOutputs(GeneratedFunction& func, for (auto cont_out_it = slice.continuations_out.begin(); cont_out_it != slice.continuations_out.end();) { ContinuationValue& continuation_out = *cont_out_it; + if (outputs_used_by_inputs.contains(&continuation_out)) { ++cont_out_it; new_output_elems.push_back(continuation_out.output_node); @@ -880,6 +1241,7 @@ absl::Status RemoveUnusedContinuationOutputs(GeneratedFunction& func, } absl::Status RemoveUnusedContinuationInputs(GeneratedFunction& func, + OptimizationContext& context, bool& changed, const xls::SourceInfo& loc) { // Multiple inputs can share a parameter in the case of a phi / @@ -888,6 +1250,13 @@ absl::Status RemoveUnusedContinuationInputs(GeneratedFunction& func, absl::flat_hash_set deleted_params; for (GeneratedFunctionSlice& slice : func.slices) { + XLS_ASSIGN_OR_RETURN( + SourcesSetNodeInfo * node_info, + context.GetSourcesSetNodeInfoForFunction(slice.function)); + + ParamSet return_value_from_params = + node_info->GetSingleInfoForNode(slice.function->return_value()); + for (auto cont_in_it = slice.continuations_in.begin(); cont_in_it != slice.continuations_in.end();) { ContinuationInput& continuation_in = *cont_in_it; @@ -900,12 +1269,19 @@ absl::Status RemoveUnusedContinuationInputs(GeneratedFunction& func, CHECK_EQ(continuation_in.input_node->function_base(), slice.function); - if (!continuation_in.input_node->users().empty() || - slice.function->HasImplicitUse(continuation_in.input_node)) { + if (return_value_from_params.contains(continuation_in.input_node)) { ++cont_in_it; continue; } + // There may still be uses like forming a tuple, the element of which is + // never indexed + XLS_RETURN_IF_ERROR( + continuation_in.input_node + ->ReplaceUsesWithNew( + xls::ZeroOfType(continuation_in.input_node->GetType())) + .status()); + XLS_RETURN_IF_ERROR( slice.function->RemoveNode(continuation_in.input_node)); @@ -918,6 +1294,13 @@ absl::Status RemoveUnusedContinuationInputs(GeneratedFunction& func, return absl::OkStatus(); } +// For each continuation value, finds the whole parameters that feed it. +// For a simple pass through, there will just be one. However, for example, +// in the case of a select, there could be several. +// +// Whole parameter means that the entire value is passed through. For example, +// a tuple parameter's entire tuple value must be passed through, with all +// elements in the original positions. absl::StatusOr>> FindPassThroughs(GeneratedFunction& func, OptimizationContext& context) { @@ -961,6 +1344,16 @@ FindPassThroughs(GeneratedFunction& func, OptimizationContext& context) { disallowed = true; break; } + + // Check that param has the same number of elements as + // continuation_out. This avoids marking slices as pass-throughs. + xls::LeafTypeTree param_tree( + source_node->GetType()); + if (param_tree.elements().size() != sources.elements().size()) { + disallowed = true; + break; + } + xls::Param* from_param = source_node->As(); if (!continuation_params.contains(from_param)) { disallowed = true; @@ -1398,11 +1791,13 @@ absl::Status RemovePassthroughFeedbacks(GeneratedFunction& func, bool& changed, for (const xls::Param* param : params_per_slice.at(&slice)) { ParamInputs& param_inputs = all_continuation_params.at(param); + CHECK(!param_inputs.all_inputs.empty()); + CHECK_NE(param_inputs.upstream_input, nullptr); + if (param_inputs.multiple_upstream_inputs) { continue; } - CHECK(!param_inputs.all_inputs.empty()); if (param_inputs.all_inputs.size() == 1) { continue; } @@ -1506,6 +1901,7 @@ absl::Status RemoveDuplicateParams(GeneratedFunction& func, bool& changed, this_param_upstream_values = upstream_values_by_param.at(this_param); const absl::flat_hash_set& params_for_upstream_values = params_by_upstream_values.at(this_param_upstream_values); + CHECK(params_for_upstream_values.contains(this_param)); if (params_for_upstream_values.size() == 1) { continue; @@ -1596,20 +1992,31 @@ absl::Status Translator::OptimizeContinuations(GeneratedFunction& func, bool changed = true; xls::OptimizationContext xls_opt_context; + XLS_RETURN_IF_ERROR(ValidateContinuations(func, loc)); + do { - changed = false; - XLS_RETURN_IF_ERROR(RemoveUnusedContinuationInputs(func, changed, loc)); - XLS_RETURN_IF_ERROR(RemoveUnusedContinuationOutputs(func, changed, loc)); - XLS_RETURN_IF_ERROR(RemovePassThroughs(func, changed, context, loc)); - XLS_RETURN_IF_ERROR( - RemoveDeadCode(func, changed, package_, xls_opt_context, loc)); - XLS_RETURN_IF_ERROR(RemoveDuplicateInputs(func, changed, loc)); - XLS_RETURN_IF_ERROR(RemoveDuplicateParams(func, changed, loc)); - XLS_RETURN_IF_ERROR( - RemovePassthroughFeedbacks(func, changed, context, loc)); - XLS_RETURN_IF_ERROR(SubstituteLiterals(func, changed, loc)); + do { + changed = false; + XLS_RETURN_IF_ERROR( + RemoveUnusedContinuationInputs(func, context, changed, loc)); + XLS_RETURN_IF_ERROR(RemoveUnusedContinuationOutputs(func, changed, loc)); + XLS_RETURN_IF_ERROR(RemovePassThroughs(func, changed, context, loc)); + XLS_RETURN_IF_ERROR( + RemoveDeadCode(func, changed, package_, xls_opt_context, loc)); + XLS_RETURN_IF_ERROR(RemoveDuplicateInputs(func, changed, loc)); + XLS_RETURN_IF_ERROR(RemoveDuplicateParams(func, changed, loc)); + XLS_RETURN_IF_ERROR( + RemovePassthroughFeedbacks(func, changed, context, loc)); + XLS_RETURN_IF_ERROR(SubstituteLiterals(func, changed, loc)); + } while (changed); + + // For efficiency's sake, do a round of optimization before decomposing + // Decompose relies on other passes to clean up + XLS_RETURN_IF_ERROR(DecomposeContinuationValues(func, changed, loc)); } while (changed); + XLS_RETURN_IF_ERROR(ValidateContinuations(func, loc)); + return absl::OkStatus(); } @@ -1635,6 +2042,7 @@ absl::Status Translator::GetDirectInSourcesForSlice( for (int64_t p = 0; p < slice.function->params().size(); ++p) { const xls::Param* param = slice.function->params().at(p); + // Don't mark statics direct-in if (static_param_names.contains(param->name())) { continue; } diff --git a/xls/contrib/xlscc/generate_fsm.cc b/xls/contrib/xlscc/generate_fsm.cc index 2ec7c5da6a..1e2d9b7f57 100644 --- a/xls/contrib/xlscc/generate_fsm.cc +++ b/xls/contrib/xlscc/generate_fsm.cc @@ -81,7 +81,7 @@ absl::Status NewFSMGenerator::SetupNewFSMGenerationContext( absl::StatusOr NewFSMGenerator::LayoutNewFSM( const GeneratedFunction& func, - const absl::flat_hash_map& + const absl::flat_hash_map& state_element_for_static, const xls::SourceInfo& body_loc) { NewFSMLayout ret; @@ -189,25 +189,20 @@ absl::StatusOr NewFSMGenerator::LayoutNewFSM( // For optimization purposes, such as narrowing, it is better that the values // saved in a state element share semantics. Therefore, Clang NamedDecls are // used to identify values that may share a state element. - absl::flat_hash_map> + absl::flat_hash_map> state_element_indices_by_decl; // Inject statics: this enables state element sharing with statics - std::vector static_decls = - func.GetDeterministicallyOrderedStaticValues(); - - for (const clang::NamedDecl* decl : static_decls) { - xls::StateElement* existing_state_element = - state_element_for_static.at(decl); - + for (const auto& [decl_leaf, existing_state_element] : + state_element_for_static) { NewFSMStateElement state_element = { .name = existing_state_element->name(), .type = existing_state_element->type(), .existing_state_element = existing_state_element, }; ret.state_elements.push_back(state_element); - state_element_indices_by_decl[decl].push_back(ret.state_elements.size() - - 1); + state_element_indices_by_decl[decl_leaf].push_back( + ret.state_elements.size() - 1); } for (const NewFSMState& state : ret.states) { @@ -235,26 +230,30 @@ absl::StatusOr NewFSMGenerator::LayoutNewFSM( // This value has not already been assigned a state element // Try to find state elements to share by decl std::optional found_element_by_decl = std::nullopt; - std::vector decls; + std::vector decls; - for (const clang::NamedDecl* decl : value->decls) { + for (const DeclLeaf& decl : value->decls) { decls.push_back(decl); } func.SortNamesDeterministically(decls); - for (const clang::NamedDecl* decl : decls) { + for (const DeclLeaf& decl : decls) { if (!state_element_indices_by_decl.contains(decl)) { continue; } const std::vector& elements_for_this_decl = state_element_indices_by_decl.at(decl); + for (const int64_t element_for_decl_index : elements_for_this_decl) { if (used_state_element_indices.contains(element_for_decl_index)) { continue; } - XLSCC_CHECK(ret.state_elements.at(element_for_decl_index) - .type->IsEqualTo(value->output_node->GetType()), - body_loc); + + if (!ret.state_elements.at(element_for_decl_index) + .type->IsEqualTo(value->output_node->GetType())) { + continue; + } + found_element_by_decl = element_for_decl_index; break; } @@ -266,9 +265,12 @@ absl::StatusOr NewFSMGenerator::LayoutNewFSM( // Create a new state element if none were found to share if (!found_element_by_decl.has_value()) { + std::string elem_name = value->name; + if (element_index >= 0) { + elem_name = absl::StrFormat("%s_el%li", elem_name, element_index); + } NewFSMStateElement state_element = { - .name = - absl::StrFormat("%s_slc_%li", value->name, state.slice_index), + .name = absl::StrFormat("%s_slc%li", elem_name, state.slice_index), .type = value->output_node->GetType(), }; ret.state_elements.push_back(state_element); @@ -281,7 +283,7 @@ absl::StatusOr NewFSMGenerator::LayoutNewFSM( XLSCC_CHECK_LT(element_index, ret.state_elements.size(), body_loc); ret.state_element_by_continuation_value[value] = element_index; used_state_element_indices.insert(element_index); - for (const clang::NamedDecl* decl : decls) { + for (const DeclLeaf& decl : decls) { state_element_indices_by_decl[decl].push_back(element_index); } } @@ -311,8 +313,9 @@ absl::StatusOr NewFSMGenerator::LayoutNewFSM( absl::StrFormat("%s from slice %li", value->name, ret.output_slice_index_by_value.at(value))); } - LOG(INFO) << absl::StrFormat(" %s (%s), values: %s", elem.name, - elem.type->ToString(), + LOG(INFO) << absl::StrFormat(" %s type %s (%i bits), values: %s", + elem.name, elem.type->ToString(), + elem.type->GetFlatBitCount(), absl::StrJoin(value_names, ", ")); } } @@ -615,8 +618,10 @@ absl::StatusOr NewFSMGenerator::GenerateNewFSMInvocation( const GeneratedFunction* xls_func, const std::vector& direct_in_args, - const absl::flat_hash_map& + const absl::flat_hash_map& state_element_for_static, + const absl::flat_hash_map& + type_for_static, const absl::flat_hash_map& return_index_for_static, xls::ProcBuilder& pb, const xls::SourceInfo& body_loc) { @@ -847,9 +852,12 @@ NewFSMGenerator::GenerateNewFSMInvocation( // Add statics for (const clang::NamedDecl* decl : slice.static_values) { - xls::StateElement* state_element = state_element_for_static.at(decl); - xls::StateRead* state_read = pb.proc()->GetStateRead(state_element); - TrackedBValue prev_val(state_read, &pb); + TrackedBValue prev_val; + XLS_ASSIGN_OR_RETURN( + prev_val, ComposeStaticValueInput(decl, + /*generate_new_fsm=*/true, + state_element_for_static, + type_for_static, pb, body_loc)); invoke_params.push_back(prev_val); } diff --git a/xls/contrib/xlscc/generate_fsm.h b/xls/contrib/xlscc/generate_fsm.h index f2f1cfd6c8..14427b5b6e 100644 --- a/xls/contrib/xlscc/generate_fsm.h +++ b/xls/contrib/xlscc/generate_fsm.h @@ -115,7 +115,7 @@ class NewFSMGenerator : public GeneratorBase { // XLS IR. absl::StatusOr LayoutNewFSM( const GeneratedFunction& func, - const absl::flat_hash_map& + const absl::flat_hash_map& state_element_for_static, const xls::SourceInfo& body_loc); @@ -123,8 +123,10 @@ class NewFSMGenerator : public GeneratorBase { absl::StatusOr GenerateNewFSMInvocation( const GeneratedFunction* xls_func, const std::vector& direct_in_args, - const absl::flat_hash_map& + const absl::flat_hash_map& state_element_for_static, + const absl::flat_hash_map& + type_for_static, const absl::flat_hash_map& return_index_for_static, xls::ProcBuilder& pb, const xls::SourceInfo& body_loc); diff --git a/xls/contrib/xlscc/translate_block.cc b/xls/contrib/xlscc/translate_block.cc index beb56b4c5d..9f0c1d0de7 100644 --- a/xls/contrib/xlscc/translate_block.cc +++ b/xls/contrib/xlscc/translate_block.cc @@ -27,6 +27,7 @@ #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -246,6 +247,62 @@ absl::Status Translator::GenerateExternalChannels( return absl::OkStatus(); } +absl::StatusOr ComposeStaticValueInput( + const clang::NamedDecl* namedecl, bool generate_new_fsm, + const absl::flat_hash_map& + state_element_for_static, + const absl::flat_hash_map& + type_for_static, + xls::ProcBuilder& pb, const xls::SourceInfo& loc) { + xls::Type* xls_type = type_for_static.at(namedecl); + if (!generate_new_fsm || !TypeIsDecomposable(xls_type)) { + xls::StateElement* state_element = state_element_for_static.at( + DeclLeaf{.decl = namedecl, .leaf_index = -1}); + return TrackedBValue(pb.proc()->GetStateRead(state_element), &pb); + } + absl::InlinedVector decomposed_types = + DecomposeTupleTypes(xls_type); + absl::InlinedVector nodes; + + for (int64_t i = 0; i < decomposed_types.size(); ++i) { + xls::StateElement* decomposed_element = state_element_for_static.at( + DeclLeaf{.decl = namedecl, .leaf_index = i}); + nodes.push_back(pb.proc()->GetStateRead(decomposed_element)); + } + + XLS_ASSIGN_OR_RETURN(xls::Node * node, + ComposeTuples(namedecl->getNameAsString(), xls_type, + pb.proc(), loc, nodes)); + return TrackedBValue(node, &pb); +} + +absl::StatusOr> +Translator::DecomposeStaticValueInput(PreparedBlock& prepared, + const clang::NamedDecl* namedecl, + xls::ProcBuilder& pb, + const xls::SourceInfo& loc) { + xls::Type* xls_type = prepared.type_for_variable.at(namedecl); + if (!generate_new_fsm_ || !TypeIsDecomposable(xls_type)) { + xls::StateElement* state_element = prepared.state_element_for_variable.at( + DeclLeaf{.decl = namedecl, .leaf_index = -1}); + return absl::InlinedVector{state_element}; + } + absl::InlinedVector decomposed_types = + DecomposeTupleTypes(xls_type); + if (decomposed_types.empty()) { + decomposed_types.push_back(xls_type); + } + absl::InlinedVector ret; + ret.reserve(decomposed_types.size()); + + for (int64_t i = 0; i < decomposed_types.size(); ++i) { + ret.push_back(prepared.state_element_for_variable.at( + DeclLeaf{.decl = namedecl, .leaf_index = i})); + } + + return ret; +} + absl::StatusOr Translator::GenerateIR_Block( xls::Package* package, const HLSBlock& block, int top_level_init_interval, const ChannelOptions& channel_options) { @@ -530,7 +587,7 @@ absl::StatusOr Translator::GenerateIR_Block( generator.GenerateNewFSMInvocation( prepared.xls_func, /*direct_in_args=*/prepared.args, - /*state_element_for_static=*/prepared.state_element_for_variable, + prepared.state_element_for_variable, prepared.type_for_variable, prepared.return_index_for_static, pb, body_loc)); } else { XLS_ASSIGN_OR_RETURN( @@ -573,27 +630,49 @@ absl::StatusOr Translator::GenerateIR_Block( TrackedBValue next_val = GetFlexTupleField(fsm_ret.return_value, ret_idx, prepared.xls_func->return_value_count, body_loc); - xls::StateElement* state_elem = - prepared.state_element_for_variable.at(namedecl); - xls::StateRead* state_read = pb.proc()->GetStateRead(state_elem); - TrackedBValue prev_val(state_read, &pb); + + absl::InlinedVector decomposed_elems; + XLS_ASSIGN_OR_RETURN( + decomposed_elems, + DecomposeStaticValueInput(prepared, namedecl, pb, body_loc)); + absl::InlinedVector decomposed_next_vals; + XLS_ASSIGN_OR_RETURN(decomposed_next_vals, + DecomposeTuples(next_val.node())); + + // Empty tuple case + if (decomposed_next_vals.empty() || !generate_new_fsm_) { + decomposed_next_vals.clear(); + decomposed_next_vals.push_back(next_val.node()); + XLSCC_CHECK_EQ(decomposed_elems.size(), 1, body_loc); + } else { + XLSCC_CHECK_EQ(decomposed_elems.size(), decomposed_next_vals.size(), + body_loc); + } XLS_ASSIGN_OR_RETURN(bool is_on_reset, DeclIsOnReset(namedecl)); - NextStateValue next_state_value = {.priority = 0, - .extra_label = std::string{block_name}}; + for (int64_t i = 0; i < decomposed_elems.size(); ++i) { + xls::StateElement* decomposed_elem = decomposed_elems.at(i); + xls::Node* decomposed_next_val = decomposed_next_vals.at(i); - if (!is_on_reset) { - next_state_value.value = next_val; - } else { - next_state_value.value = - pb.And(prev_val, - pb.Not(fsm_ret.returns_this_activation, body_loc, - /*name=*/"does_not_return_this_activation"), - body_loc, /*name=*/"next_on_reset"); + NextStateValue next_state_value = { + .priority = 0, .extra_label = std::string{block_name}}; + + if (!is_on_reset) { + next_state_value.value = TrackedBValue(decomposed_next_val, &pb); + } else { + XLSCC_CHECK_EQ(decomposed_elems.size(), 1, body_loc); + xls::StateRead* state_read = pb.proc()->GetStateRead(decomposed_elem); + TrackedBValue prev_val(state_read, &pb); + next_state_value.value = + pb.And(prev_val, + pb.Not(fsm_ret.returns_this_activation, body_loc, + /*name=*/"does_not_return_this_activation"), + body_loc, /*name=*/"next_on_reset"); + } + next_state_value.condition = fsm_ret.returns_this_activation; + next_state_values.insert({decomposed_elem, next_state_value}); } - next_state_value.condition = fsm_ret.returns_this_activation; - next_state_values.insert({state_elem, next_state_value}); } for (const auto& [state_elem, bval] : fsm_ret.extra_next_state_values) { @@ -1662,15 +1741,18 @@ absl::StatusOr Translator::GenerateSubFSM( const clang::NamedDecl* caller_decl = *unique_caller_decls_for_param.begin(); - if (!outer_prepared.state_element_for_variable.contains(caller_decl)) { + if (!outer_prepared.state_element_for_variable.contains( + DeclLeaf{.decl = caller_decl})) { continue; } xls::StateElement* state_elem = - outer_prepared.state_element_for_variable.at(caller_decl); + outer_prepared.state_element_for_variable.at( + DeclLeaf{.decl = caller_decl}); // Avoid [] = .at() - outer_prepared.state_element_for_variable[callee_param] = state_elem; + outer_prepared.state_element_for_variable[DeclLeaf{.decl = callee_param}] = + state_elem; } // Generate inner FSM @@ -1970,6 +2052,45 @@ Translator::GenerateIRBlockPrepare( context().sf = temp_sf.get(); context().fb = dynamic_cast(&pb); + auto generate_decomposed_elements = + [this, &prepared, &pb, body_loc]( + const xls::Value& init_value, const std::string& name, + const clang::NamedDecl* decl) -> absl::Status { + absl::InlinedVector decomposed_values; + XLS_ASSIGN_OR_RETURN(xls::TypeProto type_proto, init_value.TypeAsProto()); + XLS_ASSIGN_OR_RETURN(xls::Type * xls_type, + package_->GetTypeFromProto(type_proto)); + prepared.type_for_variable[decl] = xls_type; + const bool do_decompose = generate_new_fsm_ && TypeIsDecomposable(xls_type); + if (do_decompose) { + XLS_ASSIGN_OR_RETURN(decomposed_values, + DecomposeValue(xls_type, init_value)); + + // Empty tuple case + if (decomposed_values.empty()) { + decomposed_values.push_back(init_value); + } + } else { + decomposed_values.push_back(init_value); + } + for (int64_t i = 0; i < decomposed_values.size(); ++i) { + const xls::Value& decomposed_value = decomposed_values.at(i); + DeclLeaf decl_leaf = {.decl = decl, .leaf_index = -1}; + std::string decomposed_name = name; + if (do_decompose) { + decomposed_name = absl::StrFormat("%s_%d", name, i); + decl_leaf.leaf_index = i; + } + TrackedBValue elem_bval = + pb.StateElement(decomposed_name, decomposed_value, body_loc); + xls::StateElement* state_elem = + elem_bval.node()->As()->state_element(); + prepared.state_element_for_variable[decl_leaf] = state_elem; + } + + return absl::OkStatus(); + }; + // This state and argument if (this_decl != nullptr) { XLS_ASSIGN_OR_RETURN(CValue this_cval, GenerateTopClassInitValue( @@ -1979,29 +2100,28 @@ Translator::GenerateIRBlockPrepare( XLS_ASSIGN_OR_RETURN(xls::Value this_init_val, EvaluateBVal(this_cval.rvalue(), body_loc)); - TrackedBValue elem_bval = pb.StateElement("this", this_init_val, body_loc); - xls::StateElement* state_elem = - elem_bval.node()->As()->state_element(); + XLS_RETURN_IF_ERROR( + generate_decomposed_elements(this_init_val, "this", this_decl)); - // Don't need to worry about sharing for this, as it's only used at the top - // level (block as class) - prepared.state_element_for_variable[this_decl] = state_elem; - prepared.args.push_back(elem_bval); + XLS_ASSIGN_OR_RETURN( + TrackedBValue bval, + ComposeStaticValueInput(this_decl, + /*generate_new_fsm=*/generate_new_fsm_, + prepared.state_element_for_variable, + prepared.type_for_variable, pb, body_loc)); + + prepared.args.push_back(bval); } for (const clang::NamedDecl* namedecl : prepared.xls_func->GetDeterministicallyOrderedStaticValues()) { const ConstValue& initval = prepared.xls_func->static_values.at(namedecl); - // Don't need to worry about sharing for this, as it's only used when a - // static variable is declared in the loop body - TrackedBValue elem_bval = pb.StateElement( - XLSNameMangle(clang::GlobalDecl(namedecl)), initval.rvalue(), body_loc); - xls::StateElement* state_elem = - elem_bval.node()->As()->state_element(); + XLS_RETURN_IF_ERROR(generate_decomposed_elements( + initval.rvalue(), XLSNameMangle(clang::GlobalDecl(namedecl)), + namedecl)); prepared.return_index_for_static[namedecl] = next_return_index++; - prepared.state_element_for_variable[namedecl] = state_elem; } // This return @@ -2092,10 +2212,11 @@ Translator::GenerateIRBlockPrepare( break; } case xlscc::SideEffectingParameterType::kStatic: { - TrackedBValue bval( - pb.proc()->GetStateRead( - prepared.state_element_for_variable.at(param.static_value)), - &pb); + XLS_ASSIGN_OR_RETURN( + TrackedBValue bval, + ComposeStaticValueInput(param.static_value, generate_new_fsm_, + prepared.state_element_for_variable, + prepared.type_for_variable, pb, body_loc)); prepared.args.push_back(bval); break; } @@ -2151,6 +2272,11 @@ absl::StatusOr Translator::BuildWithNextStateValueMap( next_state_values.find(elem)->second; TrackedBValue next_state_value_bval; if (next_state_value.condition.valid()) { + XLSCC_CHECK_EQ(next_state_value.condition.GetType()->GetFlatBitCount(), + 1, loc); + XLSCC_CHECK( + next_state_value.value.GetType()->IsEqualTo(read_bval.GetType()), + loc); next_state_value_bval = pb.Select(next_state_value.condition, /*on_true=*/next_state_value.value, diff --git a/xls/contrib/xlscc/translate_loops.cc b/xls/contrib/xlscc/translate_loops.cc index 5e84dc38c3..b5baf13f90 100644 --- a/xls/contrib/xlscc/translate_loops.cc +++ b/xls/contrib/xlscc/translate_loops.cc @@ -1274,7 +1274,7 @@ Translator::GenerateIR_PipelinedLoopContents( const PipelinedLoopSubProc& pipelined_loop_proc, xls::ProcBuilder& pb, TrackedBValue token_in, TrackedBValue received_context_tuple, TrackedBValue in_state_condition, bool in_fsm, - absl::flat_hash_map* + absl::flat_hash_map* state_element_for_variable, int nesting_level) { const std::shared_ptr& context_in_cvars_struct_ctype = @@ -1346,7 +1346,7 @@ Translator::GenerateIR_PipelinedLoopContents( } const bool do_create_state_element = - !prepared.state_element_for_variable.contains(decl); + !prepared.state_element_for_variable.contains(DeclLeaf{.decl = decl}); if (debug_ir_trace_flags_ & DebugIrTraceFlags_LoopContext) { LOG(INFO) << absl::StrFormat( @@ -1367,10 +1367,10 @@ Translator::GenerateIR_PipelinedLoopContents( state_read_bval.node()->As()->state_element(); state_reads_by_decl[decl] = state_read_bval; - prepared.state_element_for_variable[decl] = state_elem; + prepared.state_element_for_variable[DeclLeaf{.decl = decl}] = state_elem; } else { xls::StateElement* state_elem = - prepared.state_element_for_variable.at(decl); + prepared.state_element_for_variable.at(DeclLeaf{.decl = decl}); state_reads_by_decl[decl] = TrackedBValue(pb.proc()->GetStateRead(state_elem), &pb); } @@ -1602,7 +1602,8 @@ Translator::GenerateIR_PipelinedLoopContents( } next_state_values.insert( - {prepared.state_element_for_variable[decl], next_state_value}); + {prepared.state_element_for_variable[DeclLeaf{.decl = decl}], + next_state_value}); if (context_in_field_indices.contains(decl)) { out_tuple_values[context_in_field_indices.at(decl)] = out_bval; @@ -1624,7 +1625,7 @@ Translator::GenerateIR_PipelinedLoopContents( namedecl->getNameAsString())); next_state_values.insert( - {prepared.state_element_for_variable.at(namedecl), + {prepared.state_element_for_variable.at(DeclLeaf{.decl = namedecl}), NextStateValue{.priority = nesting_level, .extra_label = name_prefix, .value = ret_next, @@ -1641,7 +1642,8 @@ Translator::GenerateIR_PipelinedLoopContents( // Can't re-use state elements that are fed into context output, // as the context output must be kept steady outside of the state // containing the loop. - if (context_in_field_indices.contains(decl)) { + XLSCC_CHECK_EQ(decl.leaf_index, -1, loc); + if (context_in_field_indices.contains(decl.decl)) { continue; } (*state_element_for_variable)[decl] = param; diff --git a/xls/contrib/xlscc/translator.cc b/xls/contrib/xlscc/translator.cc index 7f13e5436c..c2df47d0fb 100644 --- a/xls/contrib/xlscc/translator.cc +++ b/xls/contrib/xlscc/translator.cc @@ -3294,7 +3294,7 @@ absl::StatusOr Translator::GenerateIR_Call(const clang::CallExpr* call, XLS_ASSIGN_OR_RETURN( CValue call_res, - GenerateIR_Call(funcdecl, args, pthisval, &this_lval, loc)); + GenerateIR_Call(funcdecl, args, pthisval, &this_lval, this_expr, loc)); if (add_this_return) { const int64_t reads_masked_before = @@ -3394,7 +3394,8 @@ absl::StatusOr Translator::ApplyArrayAssignHack( XLS_ASSIGN_OR_RETURN( CValue f_return, GenerateIR_Call(to_call, {ivalue, rvalue}, &this_inout, - /*this_lval=*/nullptr, loc)); + /*this_lval=*/nullptr, + /*this_expr_for_decl_translation=*/nullptr, loc)); XLS_RETURN_IF_ERROR( Assign(lvalue, CValue(this_inout, lvalue_initial.type()), loc)); *output = f_return; @@ -3485,7 +3486,9 @@ absl::Status Translator::GetChannelsForLValue( absl::StatusOr Translator::GenerateIR_Call( const clang::FunctionDecl* funcdecl, std::vector expr_args, TrackedBValue* this_inout, - std::shared_ptr* this_lval, const xls::SourceInfo& loc) { + std::shared_ptr* this_lval, + std::optional this_expr_for_decl_translation, + const xls::SourceInfo& loc) { // Ensure callee has been parsed XLS_RETURN_IF_ERROR(GetFunctionBody(funcdecl).status()); @@ -3947,6 +3950,30 @@ absl::StatusOr Translator::GenerateIR_Call( std::list phi_inputs_to_insert; + auto translate_decls = [this, funcdecl, this_expr_for_decl_translation]( + const absl::flat_hash_set& decls) + -> absl::flat_hash_set { + absl::flat_hash_set ret; + for (const DeclLeaf& decl : decls) { + if (decl.decl != funcdecl) { + ret.insert(decl); + continue; + } + if (this_expr_for_decl_translation.has_value()) { + const clang::Expr* this_expr_to_use = + RemoveParensAndCasts(this_expr_for_decl_translation.value()); + const clang::DeclRefExpr* var_ref = + clang::dyn_cast(this_expr_to_use); + if (var_ref != nullptr) { + ret.insert(DeclLeaf(var_ref->getDecl(), decl.leaf_index)); + } + continue; + } + ret.insert(decl); + } + return ret; + }; + // Add N-1 slices to the caller, where N is the number of slices in the // callee. A callee with only one slice does not need to add slices to the // caller. When adding caller slices, copy the (optimized) continuation @@ -4017,6 +4044,9 @@ absl::StatusOr Translator::GenerateIR_Call( last_callee_slice.continuations_out) { ContinuationValue caller_continuation_out = callee_continuation_out; + caller_continuation_out.decls = + translate_decls(caller_continuation_out.decls); + NATIVE_BVAL return_bval = returns_by_continuation_value.at(&callee_continuation_out); @@ -4198,6 +4228,9 @@ absl::StatusOr Translator::GenerateIR_Call( ContinuationInput caller_continuation_in = *phi_input.callee_continuation_in; + caller_continuation_in.decls = + translate_decls(caller_continuation_in.decls); + caller_continuation_in.input_node = caller_params_for_callee_params.at( phi_input.callee_continuation_in->input_node); caller_continuation_in.continuation_out = @@ -4692,9 +4725,12 @@ absl::StatusOr Translator::HandleConstructors( args.push_back(ctor->getArg(pi)); } std::shared_ptr this_lval; + XLS_ASSIGN_OR_RETURN( - CValue ret, GenerateIR_Call(ctor->getConstructor(), args, &single_element, - /*this_lval=*/&this_lval, loc)); + CValue ret, + GenerateIR_Call(ctor->getConstructor(), args, &single_element, + /*this_lval=*/&this_lval, + /*this_expr_for_decl_translation=*/nullptr, loc)); TrackedBValue result; if (ctype->Is()) { XLS_ASSIGN_OR_RETURN(result, @@ -6097,6 +6133,21 @@ void GeneratedFunction::SortNamesDeterministically( }); } +void GeneratedFunction::SortNamesDeterministically( + std::vector& decls) const { + std::sort(decls.begin(), decls.end(), + [this](const DeclLeaf& a, const DeclLeaf& b) { + CHECK(declaration_order_by_name_.contains(a.decl)); + CHECK(declaration_order_by_name_.contains(b.decl)); + const int64_t a_order = declaration_order_by_name_.at(a.decl); + const int64_t b_order = declaration_order_by_name_.at(b.decl); + if (a_order != b_order) { + return a_order < b_order; + } + return a.leaf_index < b.leaf_index; + }); +} + std::vector GeneratedFunction::GetDeterministicallyOrderedStaticValues() const { std::vector ret; diff --git a/xls/contrib/xlscc/translator.h b/xls/contrib/xlscc/translator.h index 809572a3c8..31f36f8fb2 100644 --- a/xls/contrib/xlscc/translator.h +++ b/xls/contrib/xlscc/translator.h @@ -757,7 +757,9 @@ class Translator final : public GeneratorBase, absl::StatusOr GenerateIR_Call( const clang::FunctionDecl* funcdecl, std::vector expr_args, TrackedBValue* this_inout, - std::shared_ptr* this_lval, const xls::SourceInfo& loc); + std::shared_ptr* this_lval, + std::optional this_expr_for_decl_translation, + const xls::SourceInfo& loc); absl::Status AddIOOpForSliceForCall( const GeneratedFunction& func, GeneratedFunctionSlice& slice, @@ -834,8 +836,10 @@ class Translator final : public GeneratorBase, absl::flat_hash_map return_index_for_op; absl::flat_hash_map return_index_for_static; - absl::flat_hash_map + + absl::flat_hash_map state_element_for_variable; + absl::flat_hash_map type_for_variable; TrackedBValue orig_token; TrackedBValue token; bool contains_fsm = false; @@ -903,6 +907,11 @@ class Translator final : public GeneratorBase, const GeneratedFunction* caller_sub_function, const xls::SourceInfo& body_loc); + absl::StatusOr> + DecomposeStaticValueInput(PreparedBlock& prepared, + const clang::NamedDecl* namedecl, + xls::ProcBuilder& pb, const xls::SourceInfo& loc); + // Generates a dummy no-op with condition 0 for channels in // unused_external_channels_ absl::Status GenerateDefaultIOOps(PreparedBlock& prepared, @@ -1086,6 +1095,9 @@ class Translator final : public GeneratorBase, const xls::SourceInfo& loc); absl::Status FinishLastSlice(TrackedBValue return_bval, const xls::SourceInfo& loc); + absl::Status DecomposeContinuationValues(GeneratedFunction& func, + bool& changed, + const xls::SourceInfo& loc); absl::Status OptimizeContinuations(GeneratedFunction& func, OptimizationContext& context, const xls::SourceInfo& loc); @@ -1240,7 +1252,7 @@ class Translator final : public GeneratorBase, const PipelinedLoopSubProc& pipelined_loop_proc, xls::ProcBuilder& pb, TrackedBValue token_in, TrackedBValue received_context_tuple, TrackedBValue in_state_condition, bool in_fsm, - absl::flat_hash_map* + absl::flat_hash_map* state_element_for_variable = nullptr, int nesting_level = -1); diff --git a/xls/contrib/xlscc/translator_types.cc b/xls/contrib/xlscc/translator_types.cc index 78e4c38265..c4f3ece419 100644 --- a/xls/contrib/xlscc/translator_types.cc +++ b/xls/contrib/xlscc/translator_types.cc @@ -29,6 +29,7 @@ #include "absl/container/btree_map.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -87,7 +88,8 @@ absl::StatusOr CType::ContainsLValues( CType::operator std::string() const { return "CType"; } xls::Type* CType::GetXLSType(xls::Package* /*package*/) const { - LOG(FATAL) << "GetXLSType() unsupported in CType base class"; + LOG(FATAL) << "GetXLSType() unsupported in CType base class: " + << this->debug_string(); return nullptr; } @@ -1235,14 +1237,14 @@ void LogContinuations(const xlscc::GeneratedFunction& func) { slices_by_continuation_out[&continuation_out] = &slice; } } - auto decl_names_string = - [](const absl::flat_hash_set& decls) { - std::vector decl_names; - for (const clang::NamedDecl* decl : decls) { - decl_names.push_back(decl->getNameAsString()); - } - return absl::StrJoin(decl_names, ","); - }; + auto decl_names_string = [](const absl::flat_hash_set& decls) { + std::vector decl_names; + for (const DeclLeaf& decl : decls) { + decl_names.push_back(absl::StrFormat( + "%s:%li", decl.decl->getNameAsString(), decl.leaf_index)); + } + return absl::StrJoin(decl_names, ","); + }; int64_t slice_index = 0; for (const GeneratedFunctionSlice& slice : func.slices) { LOG(INFO) << ""; @@ -1439,4 +1441,119 @@ absl::Status ShortCircuitBVal(TrackedBValue& bval, const xls::SourceInfo& loc) { return absl::OkStatus(); } +namespace { + +void DecomposeTupleTypes(xls::Type* type, + absl::InlinedVector& out) { + if (!type->IsTuple()) { + out.push_back(type); + return; + } + for (xls::Type* element_type : type->AsTupleOrDie()->element_types()) { + DecomposeTupleTypes(element_type, out); + } +} + +absl::Status DecomposeTuples(xls::Node* node, + absl::InlinedVector& out) { + if (!TypeIsDecomposable(node->GetType())) { + out.push_back(node); + return absl::OkStatus(); + } + xls::TupleType* tuple_type = node->GetType()->AsTupleOrDie(); + xls::FunctionBase* func = node->function_base(); + for (int64_t ti = 0; ti < tuple_type->size(); ++ti) { + XLS_ASSIGN_OR_RETURN( + xls::Node * new_index, + func->MakeNodeWithName( + node->loc(), node, ti, + /*name=*/absl::StrFormat("%s_%li", node->GetName(), ti))); + XLS_RETURN_IF_ERROR(DecomposeTuples(new_index, out)); + } + return absl::OkStatus(); +} + +absl::StatusOr ComposeTuples( + std::string_view name, xls::Type* to_type, xls::FunctionBase* in_func, + const xls::SourceInfo& loc, const absl::InlinedVector& nodes, + int64_t& index_offset) { + absl::InlinedVector operands; + + if (!(to_type->IsTuple() && to_type->AsTupleOrDie()->size() == 0)) { + if (!TypeIsDecomposable(to_type)) { + xls::Node* first_node = nodes.at(index_offset); + ++index_offset; + return first_node; + } + + xls::TupleType* tuple_type = to_type->AsTupleOrDie(); + + for (int64_t elem = 0; elem < tuple_type->size(); ++elem) { + XLS_ASSIGN_OR_RETURN( + xls::Node * new_operand, + ComposeTuples(name, tuple_type->element_types().at(elem), in_func, + loc, nodes, + + index_offset)); + operands.push_back(new_operand); + } + } + + XLS_ASSIGN_OR_RETURN( + xls::Node * new_tuple, + in_func->MakeNodeWithName( + loc, operands, + /*name=*/absl::StrFormat("%s_%li", name, index_offset))); + + CHECK(new_tuple->GetType()->IsEqualTo(to_type)); + return new_tuple; +} + +absl::Status DecomposeValue(xls::Type* type, const xls::Value& value, + absl::InlinedVector& ret) { + if (!TypeIsDecomposable(type)) { + ret.push_back(value); + return absl::OkStatus(); + } + + for (int64_t ei = 0; ei < type->AsTupleOrDie()->size(); ++ei) { + XLS_RETURN_IF_ERROR(DecomposeValue(type->AsTupleOrDie()->element_type(ei), + value.element(ei), ret)); + } + + return absl::OkStatus(); +} + +} // namespace + +bool TypeIsDecomposable(xls::Type* type) { return type->IsTuple(); } + +absl::InlinedVector DecomposeTupleTypes(xls::Type* type) { + absl::InlinedVector ret; + DecomposeTupleTypes(type, ret); + return ret; +} + +absl::StatusOr> DecomposeTuples( + xls::Node* node) { + absl::InlinedVector ret; + XLS_RETURN_IF_ERROR(DecomposeTuples(node, ret)); + return ret; +} + +absl::StatusOr ComposeTuples( + std::string_view name, xls::Type* to_type, xls::FunctionBase* in_func, + const xls::SourceInfo& loc, + const absl::InlinedVector& nodes) { + int64_t index_offset = 0; + return ComposeTuples(name, to_type, in_func, loc, nodes, index_offset); +} + +absl::StatusOr> DecomposeValue( + xls::Type* type, const xls::Value& value) { + absl::InlinedVector ret; + XLS_RETURN_IF_ERROR(DecomposeValue(type, value, ret)); + return ret; +} + } // namespace xlscc diff --git a/xls/contrib/xlscc/translator_types.h b/xls/contrib/xlscc/translator_types.h index 43873b4f68..ab7446e7fc 100644 --- a/xls/contrib/xlscc/translator_types.h +++ b/xls/contrib/xlscc/translator_types.h @@ -1006,6 +1006,25 @@ struct PipelinedLoopSubProc { std::vector vars_to_save_between_iters; }; +struct DeclLeaf { + friend bool operator==(const DeclLeaf& lhs, const DeclLeaf& rhs) { + return lhs.decl == rhs.decl && lhs.leaf_index == rhs.leaf_index; + } + + template + friend H AbslHashValue(H h, const DeclLeaf& c) { + return H::combine(std::move(h), c.decl, c.leaf_index); + } + + std::string ToString() const { + return absl::StrFormat("%s_%li", decl->getNameAsString().c_str(), + leaf_index); + } + + const clang::NamedDecl* decl = nullptr; + int64_t leaf_index = -1L; +}; + // A value outputted from a function slice for potential later use. // One of these is generated per node that is referred to by a TrackedBValue // at the time the function gets "sliced" during generation. @@ -1015,9 +1034,10 @@ struct PipelinedLoopSubProc { struct ContinuationValue { xls::Node* output_node = nullptr; - // name, decls are for test/debug only + // name is for test/debug only std::string name; - absl::flat_hash_set decls; + // decls are used for state element sharing + absl::flat_hash_set decls; // Precomputed literal, used for unrolling, IO pruning, etc std::optional literal = std::nullopt; @@ -1039,7 +1059,7 @@ struct ContinuationInput { // name, decls are for test/debug only std::string name; - absl::flat_hash_set decls; + absl::flat_hash_set decls; }; // One "slice" of a C++ function. When there are side-effecting operations @@ -1165,6 +1185,7 @@ struct GeneratedFunction { } void SortNamesDeterministically( std::vector& names) const; + void SortNamesDeterministically(std::vector& decls) const; std::vector GetDeterministicallyOrderedStaticValues() const; }; @@ -1217,6 +1238,8 @@ absl::Status ShortCircuitBVal(TrackedBValue& bval, const xls::SourceInfo& loc); typedef absl::flat_hash_set ParamSet; +// This class sees through all operation types, and includes dynamic selectors +// and indices in its output. class SourcesSetNodeInfo : public xls::DataFlowLazyNodeInfo { public: @@ -1234,12 +1257,13 @@ class SourcesSetNodeInfo const absl::Span& infos) const override final; }; -// typedef absl::flat_hash_set NodeSet; typedef absl::flat_hash_set NodeSourceSet; // This class sees through compound type operations, ie on tuples, // but it does not see through other operations, ie add. // +// It also does not include dynamic selectors or indices in its output. +// // It returns a set of nodes, as some sources will be of the unsupported types, // ie the aforementioned add. class SourcesSetTreeNodeInfo @@ -1287,6 +1311,36 @@ class OptimizationContext { query_engines_by_function_; }; +absl::InlinedVector DecomposeTupleTypes(xls::Type* type); + +bool TypeIsDecomposable(xls::Type* type); +absl::StatusOr> DecomposeTuples( + xls::Node* node); + +absl::StatusOr ComposeTuples( + std::string_view name, xls::Type* to_type, xls::FunctionBase* in_func, + const xls::SourceInfo& loc, + const absl::InlinedVector& nodes); + +absl::StatusOr> DecomposeValue( + xls::Type* type, const xls::Value& value); + +struct GeneratedFunctionSliceMaps { + absl::flat_hash_map slice_index_map; + absl::flat_hash_map + slice_by_continuation_input; + absl::flat_hash_map + slice_by_continuation_output; +}; + +absl::StatusOr ComposeStaticValueInput( + const clang::NamedDecl* namedecl, bool generate_new_fsm, + const absl::flat_hash_map& + state_element_for_static, + const absl::flat_hash_map& + type_for_static, + xls::ProcBuilder& pb, const xls::SourceInfo& loc); + } // namespace xlscc #endif // XLS_CONTRIB_XLSCC_TRANSLATOR_TYPES_H_ diff --git a/xls/contrib/xlscc/unit_tests/BUILD b/xls/contrib/xlscc/unit_tests/BUILD index ca52bc2d51..bdd21dc8f8 100644 --- a/xls/contrib/xlscc/unit_tests/BUILD +++ b/xls/contrib/xlscc/unit_tests/BUILD @@ -175,9 +175,12 @@ cc_test( "//xls/ir", "//xls/ir:function_builder", "//xls/ir:state_element", + "//xls/ir:value", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@googletest//:gtest", "@llvm-project//clang:ast", ], @@ -266,7 +269,9 @@ cc_test( "//xls/common/status:matchers", "//xls/contrib/xlscc:hls_block_cc_proto", "//xls/contrib/xlscc:translator_types", + "//xls/ir", "//xls/ir:bits", + "//xls/ir:state_element", "//xls/ir:value", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status", diff --git a/xls/contrib/xlscc/unit_tests/continuations_test.cc b/xls/contrib/xlscc/unit_tests/continuations_test.cc index b4edb85708..b2f64a83e1 100644 --- a/xls/contrib/xlscc/unit_tests/continuations_test.cc +++ b/xls/contrib/xlscc/unit_tests/continuations_test.cc @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include +#include #include #include #include @@ -81,8 +83,8 @@ class ContinuationsTest : public XlsccTestBase { std::optional direct_in = std::nullopt) { for (const xlscc::ContinuationValue& continuation_out : slice.continuations_out) { - for (const clang::NamedDecl* decl : continuation_out.decls) { - if (decl->getNameAsString() == name && + for (const DeclLeaf& decl : continuation_out.decls) { + if (decl.decl->getNameAsString() == name && (!direct_in.has_value() || direct_in.value() == continuation_out.direct_in)) { return true; @@ -97,8 +99,8 @@ class ContinuationsTest : public XlsccTestBase { int64_t count = 0; for (const xlscc::ContinuationValue& continuation_out : slice.continuations_out) { - for (const clang::NamedDecl* decl : continuation_out.decls) { - if (decl->getNameAsString() == name) { + for (const DeclLeaf& decl : continuation_out.decls) { + if (decl.decl->getNameAsString() == name) { ++count; } } @@ -106,16 +108,68 @@ class ContinuationsTest : public XlsccTestBase { return count; }; - bool SliceInputsDecl(const xlscc::GeneratedFunctionSlice& slice, - std::string_view name, - std::optional direct_in = std::nullopt) { + static absl::flat_hash_map + GetSliceIndicesByValue(const GeneratedFunction& func) { + absl::flat_hash_map slice_index_by_value; + { + int64_t slice_index = 0; + for (const GeneratedFunctionSlice& full_slice : func.slices) { + for (const ContinuationValue& cont_out : full_slice.continuations_out) { + slice_index_by_value[&cont_out] = slice_index; + } + ++slice_index; + } + } + + return slice_index_by_value; + } + + // func must be specified when is_feedback is specified + bool SliceInputsDecl( + const xlscc::GeneratedFunctionSlice& slice, std::string_view name, + std::optional direct_in = std::nullopt, + std::optional is_feedback = std::nullopt, + std::optional func = std::nullopt, + std::optional decl_index = std::nullopt) { + auto check_is_feedback = + [&func, + &slice](const xlscc::ContinuationInput& continuation_in) -> bool { + CHECK(func.has_value()); + const xlscc::GeneratedFunction& full_func = *func.value(); + std::optional input_slice_index = std::nullopt; + absl::flat_hash_map + slice_index_by_value = GetSliceIndicesByValue(full_func); + { + int64_t slice_index = 0; + for (const GeneratedFunctionSlice& full_slice : full_func.slices) { + if (&slice == &full_slice) { + input_slice_index = slice_index; + } + ++slice_index; + } + } + CHECK(input_slice_index.has_value()); + CHECK(slice_index_by_value.contains(continuation_in.continuation_out)); + return input_slice_index <= + slice_index_by_value.at(continuation_in.continuation_out); + }; + for (const xlscc::ContinuationInput& continuation_in : slice.continuations_in) { - for (const clang::NamedDecl* decl : continuation_in.decls) { - if (decl->getNameAsString() == name && + absl::flat_hash_set all_decls = continuation_in.decls; + + all_decls.insert(continuation_in.continuation_out->decls.begin(), + continuation_in.continuation_out->decls.end()); + + for (const DeclLeaf& decl : all_decls) { + if (decl.decl->getNameAsString() == name && (!direct_in.has_value() || direct_in.value() == - continuation_in.continuation_out->direct_in)) { + continuation_in.continuation_out->direct_in) && + (!is_feedback.has_value() || + is_feedback.value() == check_is_feedback(continuation_in)) && + (!decl_index.has_value() || + decl_index.value() == decl.leaf_index)) { return true; } } @@ -128,8 +182,13 @@ class ContinuationsTest : public XlsccTestBase { int64_t count = 0; for (const xlscc::ContinuationInput& continuation_in : slice.continuations_in) { - for (const clang::NamedDecl* decl : continuation_in.decls) { - if (decl->getNameAsString() == name) { + absl::flat_hash_set all_decls = continuation_in.decls; + + all_decls.insert(continuation_in.continuation_out->decls.begin(), + continuation_in.continuation_out->decls.end()); + + for (const DeclLeaf& decl : all_decls) { + if (decl.decl->getNameAsString() == name) { ++count; } } @@ -140,12 +199,17 @@ class ContinuationsTest : public XlsccTestBase { bool SliceInputDoesNotInputBothDecls( const xlscc::GeneratedFunctionSlice& slice, std::string_view name_a, std::string_view name_b) { - absl::flat_hash_set decls_found; + absl::flat_hash_set decls_found; bool found = false; for (const xlscc::ContinuationInput& continuation_in : slice.continuations_in) { - for (const clang::NamedDecl* decl : continuation_in.decls) { - if (decl->getNameAsString() == name_a) { + absl::flat_hash_set all_decls = continuation_in.decls; + + all_decls.insert(continuation_in.continuation_out->decls.begin(), + continuation_in.continuation_out->decls.end()); + + for (const DeclLeaf& decl : all_decls) { + if (decl.decl->getNameAsString() == name_a) { decls_found = continuation_in.decls; found = true; continue; @@ -155,8 +219,8 @@ class ContinuationsTest : public XlsccTestBase { if (!found) { return true; } - for (const clang::NamedDecl* decl : decls_found) { - if (decl->getNameAsString() == name_b) { + for (const DeclLeaf& decl : decls_found) { + if (decl.decl->getNameAsString() == name_b) { return false; } } @@ -367,8 +431,9 @@ TEST_F(ContinuationsTest, PassthroughsRemovedScoped) { EXPECT_FALSE(SliceOutputsDecl(first_slice, "x")); EXPECT_TRUE(SliceOutputsDecl(second_slice, "x")); - EXPECT_TRUE(SliceOutputsDecl(third_slice, "x")); + EXPECT_TRUE(SliceInputsDecl(third_slice, "x")); EXPECT_FALSE(SliceOutputsDecl(fourth_slice, "x")); + EXPECT_TRUE(SliceInputsDecl(fourth_slice, "x")); EXPECT_FALSE(SliceOutputsDecl(fifth_slice, "x")); } @@ -400,52 +465,12 @@ TEST_F(ContinuationsTest, SwizzleNotRemoved) { ++slice_it; const xlscc::GeneratedFunctionSlice& third_slice = *slice_it; ++slice_it; + const xlscc::GeneratedFunctionSlice& fourth_slice = *slice_it; ++slice_it; EXPECT_FALSE(SliceOutputsDecl(second_slice, "y")); EXPECT_TRUE(SliceInputsDecl(third_slice, "x")); - EXPECT_TRUE(SliceOutputsDecl(third_slice, "y")); -} - -TEST_F(ContinuationsTest, SmallPassThroughNotRemoved) { - const std::string content = R"( - struct Small { - int x; - int y; - }; - - struct Big { - int x; - int y; - int z; - }; - - #pragma hls_top - void my_package(__xls_channel& in, - __xls_channel& out) { - const Small x = in.read(); - out.write(Big()); - const Big y = {.x = x.y, .y = x.x, .z = x.x}; - out.write(y); - out.write(y); - })"; - - XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, - GenerateTopFunction(content)); - - ASSERT_EQ(func->slices.size(), 5); - - auto slice_it = func->slices.begin(); - ++slice_it; - const xlscc::GeneratedFunctionSlice& second_slice = *slice_it; - ++slice_it; - const xlscc::GeneratedFunctionSlice& third_slice = *slice_it; - ++slice_it; - ++slice_it; - - EXPECT_FALSE(SliceOutputsDecl(second_slice, "y")); - EXPECT_TRUE(SliceInputsDecl(third_slice, "x")); - EXPECT_TRUE(SliceOutputsDecl(third_slice, "y")); + EXPECT_TRUE(SliceInputsDecl(fourth_slice, "y")); } TEST_F(ContinuationsTest, UnusedContinuationOutputsRemoved) { @@ -1032,8 +1057,14 @@ TEST_F(ContinuationsTest, PipelinedLoopBackwardsPropagation) { EXPECT_EQ(SliceInputsDeclCount(third_slice, "i"), 2); EXPECT_EQ(SliceInputsDeclCount(third_slice, "a"), 2); - EXPECT_TRUE(SliceInputsDecl(third_slice, "a")); - EXPECT_TRUE(SliceInputsDecl(third_slice, "i")); + EXPECT_TRUE(SliceInputsDecl(third_slice, "a", /*direct_in=*/true, + /*is_feedback=*/false, /*func=*/func)); + EXPECT_TRUE(SliceInputsDecl(third_slice, "i", /*direct_in=*/true, + /*is_feedback=*/false, /*func=*/func)); + EXPECT_TRUE(SliceInputsDecl(third_slice, "a", /*direct_in=*/false, + /*is_feedback=*/true, /*func=*/func)); + EXPECT_TRUE(SliceInputsDecl(third_slice, "i", /*direct_in=*/false, + /*is_feedback=*/true, /*func=*/func)); EXPECT_TRUE(SliceInputsDecl(fourth_slice, "a")); EXPECT_FALSE(SliceInputsDecl(fourth_slice, "i")); @@ -1253,7 +1284,7 @@ TEST_F(ContinuationsTest, PipelinedLoopSameNodeOneBypass) { EXPECT_TRUE(SliceInputsDecl(fourth_slice, "r")); EXPECT_TRUE(SliceInputDoesNotInputBothDecls(fourth_slice, "a", "r")); EXPECT_EQ(SliceInputsDeclCount(fourth_slice, "i"), 2); - EXPECT_EQ(SliceInputsDeclCount(fourth_slice, "a"), 2); + EXPECT_EQ(SliceInputsDeclCount(fourth_slice, "a"), 3); EXPECT_TRUE(SliceOutputsDecl(fourth_slice, "a")); EXPECT_TRUE(SliceOutputsDecl(fourth_slice, "i")); @@ -1454,12 +1485,16 @@ TEST_F(ContinuationsTest, PipelinedLoopNested) { EXPECT_FALSE(SliceOutputsDecl(third_slice, "i")); EXPECT_EQ(SliceInputsDeclCount(fourth_slice, "ctrl"), 1); + EXPECT_TRUE(SliceInputsDecl(fourth_slice, "ctrl", /*direct_in=*/false, + /*is_feedback=*/false, /*func=*/func)); EXPECT_EQ(SliceInputsDeclCount(fourth_slice, "a"), 3); EXPECT_EQ(SliceInputsDeclCount(fourth_slice, "j"), 2); EXPECT_EQ(SliceOutputsDeclCount(fourth_slice, "a"), 1); EXPECT_EQ(SliceOutputsDeclCount(fourth_slice, "j"), 1); EXPECT_EQ(SliceInputsDeclCount(fifth_slice, "ctrl"), 1); + EXPECT_TRUE(SliceInputsDecl(fifth_slice, "ctrl", /*direct_in=*/false, + /*is_feedback=*/false, /*func=*/func)); EXPECT_EQ(SliceInputsDeclCount(fifth_slice, "i"), 1); EXPECT_EQ(SliceOutputsDeclCount(fifth_slice, "a"), 1); EXPECT_EQ(SliceInputsDeclCount(fifth_slice, "a"), 1); @@ -1566,12 +1601,50 @@ TEST_F(ContinuationsTest, DirectInMarked) { EXPECT_TRUE(SliceOutputsDecl(first_slice, "direct_in", /*direct_in=*/true)); EXPECT_TRUE(SliceOutputsDecl(second_slice, "x", /*direct_in=*/false)); - EXPECT_TRUE(SliceInputsDecl(second_slice, "direct_in", /*direct_in=*/true)); - EXPECT_TRUE(SliceOutputsDecl(second_slice, "dval", /*direct_in=*/true)); EXPECT_TRUE(SliceInputsDecl(third_slice, "dval", /*direct_in=*/true)); } +TEST_F(ContinuationsTest, DirectInNotDecomposed) { + const std::string content = R"( + struct DirectIn { + int x; + int y; + }; + + #pragma hls_top + void my_package(const DirectIn&direct_in, + __xls_channel& in, + __xls_channel& out) { + const int x = in.read(); + out.write(x); + out.write(x * direct_in.x); + })"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + + ASSERT_EQ(func->slices.size(), 4); + + auto slice_it = func->slices.begin(); + const xlscc::GeneratedFunctionSlice& first_slice = *slice_it; + ++slice_it; + ++slice_it; + const xlscc::GeneratedFunctionSlice& third_slice = *slice_it; + + EXPECT_TRUE(SliceOutputsDecl(first_slice, "direct_in", /*direct_in=*/true)); + + EXPECT_TRUE(SliceInputsDecl(third_slice, "direct_in", /*direct_in=*/true)); + + for (const ContinuationInput& continuation_in : + third_slice.continuations_in) { + if (!continuation_in.continuation_out->direct_in) { + continue; + } + EXPECT_TRUE(continuation_in.input_node->GetType()->IsTuple()); + } +} + TEST_F(ContinuationsTest, DirectInShouldNotFeedback) { const std::string content = R"( struct DirectIn { @@ -1734,8 +1807,6 @@ TEST_F(ContinuationsTest, DirectInMarkedInSubroutine) { EXPECT_TRUE(SliceOutputsDecl(first_slice, "direct_in", /*direct_in=*/true)); EXPECT_TRUE(SliceOutputsDecl(second_slice, "x", /*direct_in=*/false)); - EXPECT_TRUE(SliceInputsDecl(second_slice, "direct_in", /*direct_in=*/true)); - EXPECT_TRUE(SliceOutputsDecl(second_slice, "dval", /*direct_in=*/true)); EXPECT_TRUE(SliceInputsDecl(third_slice, "dval", /*direct_in=*/true)); } @@ -1790,8 +1861,6 @@ TEST_F(ContinuationsTest, DirectInMarkedInSubroutineMultiCall) { EXPECT_TRUE(SliceOutputsDecl(first_slice, "direct_in", /*direct_in=*/true)); EXPECT_TRUE(SliceOutputsDecl(second_slice, "x", /*direct_in=*/false)); - EXPECT_TRUE(SliceInputsDecl(second_slice, "direct_in", /*direct_in=*/true)); - EXPECT_TRUE(SliceOutputsDecl(second_slice, "dval", /*direct_in=*/true)); EXPECT_TRUE(SliceInputsDecl(third_slice, "dval", /*direct_in=*/true)); @@ -1837,8 +1906,6 @@ TEST_F(ContinuationsTest, DirectInMarkedInSubroutineNonConst) { EXPECT_TRUE(SliceOutputsDecl(first_slice, "direct_in", /*direct_in=*/true)); EXPECT_TRUE(SliceOutputsDecl(second_slice, "x", /*direct_in=*/false)); - EXPECT_TRUE(SliceInputsDecl(second_slice, "direct_in", /*direct_in=*/true)); - EXPECT_TRUE(SliceOutputsDecl(second_slice, "dval", /*direct_in=*/true)); EXPECT_TRUE(SliceInputsDecl(third_slice, "dval", /*direct_in=*/true)); } @@ -1887,8 +1954,6 @@ TEST_F(ContinuationsTest, DirectInMarkedInSubroutineNested) { EXPECT_TRUE(SliceOutputsDecl(first_slice, "direct_in", /*direct_in=*/true)); EXPECT_TRUE(SliceOutputsDecl(second_slice, "x", /*direct_in=*/false)); - EXPECT_TRUE(SliceInputsDecl(second_slice, "direct_in", /*direct_in=*/true)); - EXPECT_TRUE(SliceOutputsDecl(second_slice, "dval", /*direct_in=*/true)); EXPECT_TRUE(SliceInputsDecl(third_slice, "dval", /*direct_in=*/true)); } @@ -1943,12 +2008,46 @@ TEST_F(ContinuationsTest, DirectInMarkedInSubroutineNested2) { EXPECT_TRUE(SliceOutputsDecl(first_slice, "direct_in", /*direct_in=*/true)); EXPECT_TRUE(SliceOutputsDecl(second_slice, "x", /*direct_in=*/false)); - EXPECT_TRUE(SliceInputsDecl(second_slice, "direct_in", /*direct_in=*/true)); - EXPECT_TRUE(SliceOutputsDecl(second_slice, "dval", /*direct_in=*/true)); EXPECT_TRUE(SliceInputsDecl(third_slice, "dval", /*direct_in=*/true)); } +TEST_F(ContinuationsTest, DirectInSelectNotMarked) { + const std::string content = R"( + struct DirectIn { + int x; + int y; + }; + + #pragma hls_top + void my_package(const DirectIn&direct_in, + __xls_channel& in, + __xls_channel& out) { + const int x = in.read(); + const int dval = x == 5 ? direct_in.x : direct_in.y; + out.write(x); + out.write(x * dval); + })"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + + ASSERT_EQ(func->slices.size(), 4); + + auto slice_it = func->slices.begin(); + const xlscc::GeneratedFunctionSlice& first_slice = *slice_it; + ++slice_it; + const xlscc::GeneratedFunctionSlice& second_slice = *slice_it; + ++slice_it; + const xlscc::GeneratedFunctionSlice& third_slice = *slice_it; + + EXPECT_TRUE(SliceOutputsDecl(first_slice, "direct_in", /*direct_in=*/true)); + + EXPECT_TRUE(SliceOutputsDecl(second_slice, "x", /*direct_in=*/false)); + + EXPECT_TRUE(SliceInputsDecl(third_slice, "dval", /*direct_in=*/false)); +} + TEST_F(ContinuationsTest, SplitOnChannelOps) { const std::string content = R"( #pragma hls_top @@ -2043,5 +2142,335 @@ TEST_F(ContinuationsTest, PipelinedLoopBackwardsPropagationInSubroutine) { EXPECT_FALSE(SliceInputsDecl(fourth_slice, "i", /*direct_in=*/false)); } +TEST_F(ContinuationsTest, PassthroughSliceNotRemoved) { + const std::string content = R"( + struct Big { + int v[7]; + }; + struct Small { + int v[3]; + bool flag; + }; + + class Block { + public: + __xls_channel& in; + __xls_channel& in2; + __xls_channel& out; + + #pragma hls_top + void Run() { + Big big = in.read(); + Small small; + (void)in2.read(); + small.v[0] = big.v[0]; + small.v[1] = big.v[1]; + small.v[2] = big.v[2]; + [[hls_pipeline_init_interval(1)]] + for (int i = 0; i < 24; ++i) { + small.flag = false; + } + } + };)"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + (void)func; + // (Just check that it doesn't crash) +} + +TEST_F(ContinuationsTest, StructSingleElementContinued) { + const std::string content = R"( + struct Test { + int x = 0; + int y = 0; + }; + + #pragma hls_top + void my_package(__xls_channel& in, + __xls_channel& out) { + Test test; + test.x = in.read(); + test.y = test.x * 3; + out.write(test.y); + out.write(test.x); + out.write(test.x); + })"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + + ASSERT_EQ(func->slices.size(), 5); + + auto slice_it = func->slices.begin(); + ++slice_it; + const xlscc::GeneratedFunctionSlice& second_slice = *slice_it; + + EXPECT_EQ(SliceOutputsDeclCount(second_slice, "test"), 1); +} + +TEST_F(ContinuationsTest, PipelinedLoopSimpleFeedback) { + const std::string content = R"( + class Block { + public: + __xls_channel& in; + __xls_channel& out; + + #pragma hls_top + void Run() { + int sum = 0; + + #pragma hls_pipeline_init_interval 1 + while(true) { + sum += in.read(); + out.write(sum); + } + } + };)"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + ASSERT_EQ(func->slices.size(), 5); + + auto slice_it = func->slices.begin(); + const xlscc::GeneratedFunctionSlice& first_slice = *slice_it; + ++slice_it; + ++slice_it; + const xlscc::GeneratedFunctionSlice& third_slice = *slice_it; + ++slice_it; + const xlscc::GeneratedFunctionSlice& fourth_slice = *slice_it; + ++slice_it; + + EXPECT_TRUE(SliceOutputsDecl(first_slice, "sum")); + EXPECT_EQ(SliceInputsDeclCount(third_slice, "sum"), 2); + EXPECT_TRUE(SliceInputsDecl(third_slice, "sum", /*direct_in=*/false, + /*is_feedback=*/true, /*func=*/func)); + EXPECT_TRUE(SliceOutputsDecl(third_slice, "sum")); + EXPECT_EQ(SliceInputsDeclCount(fourth_slice, "sum"), 1); + EXPECT_TRUE(SliceOutputsDecl(fourth_slice, "sum")); +} + +TEST_F(ContinuationsTest, PipelinedLoopSimpleNoFeedback) { + const std::string content = R"( + class Block { + public: + __xls_channel& in; + __xls_channel& out; + + #pragma hls_top + void Run() { + const int v = in.read(); + + #pragma hls_pipeline_init_interval 1 + while(true) { + out.write(v); + } + } + };)"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + ASSERT_EQ(func->slices.size(), 5); + + auto slice_it = func->slices.begin(); + ++slice_it; + const xlscc::GeneratedFunctionSlice& second_slice = *slice_it; + ++slice_it; + const xlscc::GeneratedFunctionSlice& third_slice = *slice_it; + ++slice_it; + ++slice_it; + + EXPECT_TRUE(SliceOutputsDecl(second_slice, "v")); + EXPECT_TRUE(SliceInputsDecl(third_slice, "v", /*direct_in=*/false, + /*is_feedback=*/false, /*func=*/func)); +} + +TEST_F(ContinuationsTest, PassthroughFeedbackOrder) { + const std::string content = R"( + class Block { + public: + __xls_channel& in; + __xls_channel& out; + + #pragma hls_top + void Run() { + int value = in.read(); + int set_it = 10; + + #pragma hls_pipeline_init_interval 1 + for(int i=0;i<6;++i) { + #pragma hls_pipeline_init_interval 1 + for(int j=0;j<6;++j) { + out.write(value); + } + value = set_it; + } + } + };)"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + ASSERT_GE(func->slices.size(), 4); + + auto slice_it = func->slices.begin(); + ++slice_it; + ++slice_it; + ++slice_it; + const xlscc::GeneratedFunctionSlice& fourth_slice = *slice_it; + + ASSERT_EQ(fourth_slice.continuations_in.size(), 3); + + absl::flat_hash_map + slice_indices_by_value = GetSliceIndicesByValue(*func); + + std::list inputs_sorted_by_slice_index = + fourth_slice.continuations_in; + + inputs_sorted_by_slice_index.sort( + [&slice_indices_by_value](const ContinuationInput& a, + const ContinuationInput& b) -> bool { + return slice_indices_by_value.at(a.continuation_out) < + slice_indices_by_value.at(b.continuation_out); + }); + + const absl::flat_hash_set& first_decls = + inputs_sorted_by_slice_index.front().continuation_out->decls; + EXPECT_FALSE(std::any_of(first_decls.begin(), first_decls.end(), + [](const DeclLeaf& decl) { + return decl.decl->getNameAsString() == "set_it"; + })); + + const absl::flat_hash_set& last_decls = + inputs_sorted_by_slice_index.back().continuation_out->decls; + EXPECT_TRUE(std::any_of(last_decls.begin(), last_decls.end(), + [](const DeclLeaf& decl) { + return decl.decl->getNameAsString() == "set_it"; + })); +} + +TEST_F(ContinuationsTest, ContinuationDecomposed) { + const std::string content = R"( + struct Thing { + int x; + long y; + }; + + #pragma hls_top + void my_package(__xls_channel& in, + __xls_channel& out) { + const Thing x = in.read(); + out.write(x); + out.write(x); + })"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + + ASSERT_EQ(func->slices.size(), 4); + + auto slice_it = func->slices.begin(); + ++slice_it; + const xlscc::GeneratedFunctionSlice& second_slice = *slice_it; + ++slice_it; + const xlscc::GeneratedFunctionSlice& third_slice = *slice_it; + ++slice_it; + + EXPECT_TRUE(SliceOutputsDecl(second_slice, "x")); + EXPECT_TRUE(SliceInputsDecl(third_slice, "x", /*direct_in=*/std::nullopt, + /*is_feedback=*/std::nullopt, + /*func=*/std::nullopt, /*decl_index=*/0)); + EXPECT_TRUE(SliceInputsDecl(third_slice, "x", /*direct_in=*/std::nullopt, + /*is_feedback=*/std::nullopt, + /*func=*/std::nullopt, /*decl_index=*/1)); + EXPECT_FALSE(SliceInputsDecl(third_slice, "x", /*direct_in=*/std::nullopt, + /*is_feedback=*/std::nullopt, + /*func=*/std::nullopt, /*decl_index=*/-1)); + EXPECT_FALSE(SliceInputsDecl(third_slice, "x", /*direct_in=*/std::nullopt, + /*is_feedback=*/std::nullopt, + /*func=*/std::nullopt, /*decl_index=*/2)); +} + +TEST_F(ContinuationsTest, ContinuationLiteralDecomposed) { + const std::string content = R"( + struct Thing { + int x; + long y; + }; + + #pragma hls_top + void my_package(__xls_channel& in, + __xls_channel& out) { + Thing x = {.x = 5, .y = 10}; + [[hls_pipeline_init_interval(1)]] + for (int i=0;i<2;++i) { + out.write(x); + x = in.read(); + } + })"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + + ASSERT_GE(func->slices.size(), 2); + + auto slice_it = func->slices.begin(); + ++slice_it; + const xlscc::GeneratedFunctionSlice& second_slice = *slice_it; + + // TODO(seanhaskell): Calling literals direct-in is stopped them from being + // decomposed. This makes loops less efficient + + EXPECT_TRUE(SliceInputsDecl(second_slice, "x", /*direct_in=*/true, + /*is_feedback=*/false, /*func=*/func, + /*decl_index=*/-1)); + EXPECT_TRUE(SliceInputsDecl(second_slice, "x", /*direct_in=*/false, + /*is_feedback=*/true, /*func=*/func, + /*decl_index=*/-1)); +} + +TEST_F(ContinuationsTest, StaticThisDecomposed) { + const std::string content = R"( + struct Block { + int x; + long y; + + void Run(__xls_channel& in, + __xls_channel& out) { + [[hls_pipeline_init_interval(1)]] + for (int i=0;i<8;++i) { + [[hls_pipeline_init_interval(1)]] + for (int j=0;j<8;++j) { + y += x; + } + } + [[hls_pipeline_init_interval(1)]] + for (int x = 0; x < 4; ++x) { + } + } + }; + + #pragma hls_top + void my_package(__xls_channel& in, + __xls_channel& out) { + static Block block; + block.Run(in, out); + })"; + + XLS_ASSERT_OK_AND_ASSIGN(const xlscc::GeneratedFunction* func, + GenerateTopFunction(content)); + + (void)func; + ASSERT_EQ(func->slices.size(), 7); + + const xlscc::GeneratedFunctionSlice& last_slice = func->slices.back(); + + EXPECT_TRUE(SliceInputsDecl(last_slice, "block", /*direct_in=*/std::nullopt, + /*is_feedback=*/std::nullopt, + /*func=*/std::nullopt, /*decl_index=*/0)); + EXPECT_TRUE(SliceInputsDecl(last_slice, "block", /*direct_in=*/std::nullopt, + /*is_feedback=*/std::nullopt, + /*func=*/std::nullopt, /*decl_index=*/1)); +} + } // namespace } // namespace xlscc diff --git a/xls/contrib/xlscc/unit_tests/fsm_layout_test.cc b/xls/contrib/xlscc/unit_tests/fsm_layout_test.cc index 3f5bcf0e5c..9c6d1730a9 100644 --- a/xls/contrib/xlscc/unit_tests/fsm_layout_test.cc +++ b/xls/contrib/xlscc/unit_tests/fsm_layout_test.cc @@ -22,7 +22,9 @@ #include "gtest/gtest.h" #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/container/inlined_vector.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "clang/include/clang/AST/Decl.h" #include "xls/common/file/temp_file.h" #include "xls/common/status/matchers.h" @@ -36,6 +38,7 @@ #include "xls/ir/nodes.h" #include "xls/ir/package.h" #include "xls/ir/state_element.h" +#include "xls/ir/value.h" namespace xlscc { @@ -43,7 +46,7 @@ class FSMLayoutTest : public XlsccTestBase { public: absl::StatusOr GenerateTopFunction( std::string_view content, - absl::flat_hash_map* + absl::flat_hash_map* state_element_for_static_out = nullptr) { generate_new_fsm_ = true; @@ -69,16 +72,35 @@ class FSMLayoutTest : public XlsccTestBase { DebugIrTraceFlags_FSMStates, split_states_on_channel_ops_); - absl::flat_hash_map - state_element_for_static; + absl::flat_hash_map state_element_for_static; xls::ProcBuilder pb("test", package_.get()); + for (const auto& [decl, init_cvalue] : func->static_values) { + const xls::Value& init_value = init_cvalue.rvalue(); + absl::InlinedVector decomposed_values; - for (const auto& [decl, init_val] : func->static_values) { - NATIVE_BVAL state_read_bval = pb.StateElement( - decl->getNameAsString(), init_val.rvalue(), xls::SourceInfo()); - state_element_for_static[decl] = - state_read_bval.node()->As()->state_element(); + XLS_ASSIGN_OR_RETURN(xls::TypeProto type_proto, init_value.TypeAsProto()); + XLS_ASSIGN_OR_RETURN(xls::Type * xls_type, + package_->GetTypeFromProto(type_proto)); + + XLS_ASSIGN_OR_RETURN(decomposed_values, + DecomposeValue(xls_type, init_value)); + + // Empty tuple case + if (decomposed_values.empty()) { + decomposed_values.push_back(init_value); + } + + for (int64_t i = 0; i < decomposed_values.size(); ++i) { + const xls::Value& decomposed_value = decomposed_values.at(i); + DeclLeaf decl_leaf = {.decl = decl, .leaf_index = i}; + const std::string& name = decl->getNameAsString(); + std::string decomposed_name = absl::StrFormat("%s_%d", name, i); + NATIVE_BVAL state_read_bval = pb.StateElement( + decomposed_name, decomposed_value, xls::SourceInfo()); + state_element_for_static[decl_leaf] = + state_read_bval.node()->As()->state_element(); + } } if (state_element_for_static_out != nullptr) { @@ -89,6 +111,21 @@ class FSMLayoutTest : public XlsccTestBase { xls::SourceInfo()); } + bool TypeContainsArray(xls::Type* type) { + if (type->IsArray()) { + return true; + } + if (type->IsTuple()) { + for (xls::Type* elem_type : type->AsTupleOrDie()->element_types()) { + if (TypeContainsArray(elem_type)) { + return true; + } + } + return false; + } + return false; + } + std::vector FilterStates( const NewFSMLayout& layout, const int64_t find_slice_index, const absl::flat_hash_set& find_jumped_from_slice_indices) { @@ -124,7 +161,7 @@ class FSMLayoutTest : public XlsccTestBase { continue; } for (const auto& decl : continuation_value->decls) { - if (decl->getNameAsString() == name) { + if (decl.decl->getNameAsString() == name) { return true; } } @@ -149,7 +186,7 @@ class FSMLayoutTest : public XlsccTestBase { continue; } for (const auto& decl : value->decls) { - if (decl->getNameAsString() == name) { + if (decl.decl->getNameAsString() == name) { return value; } } @@ -162,7 +199,7 @@ class FSMLayoutTest : public XlsccTestBase { for (const NewFSMState& state : layout.states) { for (const ContinuationValue* value : state.values_to_save) { for (const auto& decl : value->decls) { - if (decl->getNameAsString() == name) { + if (decl.decl->getNameAsString() == name) { return true; } } @@ -177,7 +214,7 @@ class FSMLayoutTest : public XlsccTestBase { for (const ContinuationValue* value : state.values_to_save) { bool count_value = false; for (const auto& decl : value->decls) { - if (decl->getNameAsString() == name) { + if (decl.decl->getNameAsString() == name) { count_value = true; } } @@ -736,30 +773,119 @@ TEST_F(FSMLayoutTest, IOActivationTransitionsNoTransition) { EXPECT_TRUE(layout.transition_by_slice_from_index.empty()); } -TEST_F(FSMLayoutTest, StaticStateElementShared) { +TEST_F(FSMLayoutTest, StaticStateElementLeafShared) { const std::string content = R"( + struct Test { + int x = 0; + int y = 0; + }; + #pragma hls_top void my_package(__xls_channel& in, __xls_channel& out) { - static int count = 0; - const int x = in.read(); - count += x; - const int y = in.read(); - out.write(count + y); + static Test count; + int z; + + [[hls_pipeline_init_interval(1)]] + for (int i=0;i<1;++i) { + count.x = in.read(); + z = 3 * count.x; + count.y = in.read(); + } + out.write(count.y + z); + })"; + split_states_on_channel_ops_ = false; + + absl::flat_hash_map state_element_for_static; + + XLS_ASSERT_OK_AND_ASSIGN( + NewFSMLayout layout, + GenerateTopFunction(content, &state_element_for_static)); + + ASSERT_EQ(layout.state_elements.size(), 3); + + for (const NewFSMStateElement& elem : layout.state_elements) { + EXPECT_TRUE(elem.type->IsBits()); + EXPECT_EQ(elem.type->GetFlatBitCount(), 32L); + } +} + +TEST_F(FSMLayoutTest, StateElementSharedForAlias) { + const std::string content = R"( + struct TestBase { + long arr[16]; + + void WriteIt(__xls_channel& out) { + [[hls_pipeline_init_interval(1)]] + for (int i=0;i<16;++i) { + out.write(arr[i]); + arr[i] += 5; + } + } + }; + + struct Test : TestBase { + }; + + #pragma hls_top + void my_package(__xls_channel& in, + __xls_channel& out) { + static Test count = in.read(); + count.WriteIt(out); + [[hls_pipeline_init_interval(1)]] + for (int i=0;i<16;++i) { + out.write(count.arr[i]); + count.arr[i] -= 1; + } })"; split_states_on_channel_ops_ = true; - absl::flat_hash_map - state_element_for_static; + absl::flat_hash_map state_element_for_static; + + XLS_ASSERT_OK_AND_ASSIGN( + NewFSMLayout layout, + GenerateTopFunction(content, &state_element_for_static)); + + int64_t n_array_state_elements = 0; + for (const NewFSMStateElement& elem : layout.state_elements) { + if (TypeContainsArray(elem.type)) { + ++n_array_state_elements; + } + } + EXPECT_EQ(n_array_state_elements, 1); +} + +TEST_F(FSMLayoutTest, StaticStateElementLeafSharedByRef) { + const std::string content = R"( + struct Test { + int x = 0; + int y = 0; + }; + + #pragma hls_top + void my_package(__xls_channel& in, + __xls_channel& out) { + Test count; + int z; + + count.x = in.read(); + z = 3 * count.x; + + [[hls_pipeline_init_interval(1)]] + for (int i=0;i<1;++i) { + count.y = in.read(); + } + out.write(count.y + z); + })"; + split_states_on_channel_ops_ = false; + + absl::flat_hash_map state_element_for_static; XLS_ASSERT_OK_AND_ASSIGN( NewFSMLayout layout, GenerateTopFunction(content, &state_element_for_static)); - ASSERT_EQ(layout.state_elements.size(), 1); - ASSERT_EQ(state_element_for_static.size(), 1); - EXPECT_EQ(layout.state_elements.front().existing_state_element, - state_element_for_static.begin()->second); + ASSERT_EQ(layout.state_elements.size(), 2); } } // namespace diff --git a/xls/contrib/xlscc/unit_tests/testdata/translator_verilog_test_IOProcComboGenNToOneMux.svtxt b/xls/contrib/xlscc/unit_tests/testdata/translator_verilog_test_IOProcComboGenNToOneMux.svtxt index a8ce380ab3..596b1f92c5 100644 --- a/xls/contrib/xlscc/unit_tests/testdata/translator_verilog_test_IOProcComboGenNToOneMux.svtxt +++ b/xls/contrib/xlscc/unit_tests/testdata/translator_verilog_test_IOProcComboGenNToOneMux.svtxt @@ -16,16 +16,16 @@ module foo_proc( wire [31:0] in2_select; wire [31:0] in1_select; wire p0_all_active_inputs_valid; - wire [31:0] out_send_value; + wire [31:0] out_op0_0__2; assign ctx_2__x_literal = 32'h0000_0000; assign continuation_1_ctx_3__full_condi_output = dir == ctx_2__x_literal; assign ctx_3__full_condition_ctx_3__rel_output__1 = ~continuation_1_ctx_3__full_condi_output; assign in2_select = ctx_3__full_condition_ctx_3__rel_output__1 ? in2 : 32'h0000_0000; assign in1_select = continuation_1_ctx_3__full_condi_output ? in1 : 32'h0000_0000; assign p0_all_active_inputs_valid = (~continuation_1_ctx_3__full_condi_output | in1_vld) & (~ctx_3__full_condition_ctx_3__rel_output__1 | in2_vld); - assign out_send_value = continuation_1_ctx_3__full_condi_output ? in1_select : in2_select; + assign out_op0_0__2 = continuation_1_ctx_3__full_condi_output ? in1_select : in2_select; assign in1_rdy = out_rdy & continuation_1_ctx_3__full_condi_output; assign in2_rdy = out_rdy & ctx_3__full_condition_ctx_3__rel_output__1; - assign out = out_send_value; + assign out = out_op0_0__2; assign out_vld = p0_all_active_inputs_valid & 1'h1 & 1'h1; endmodule diff --git a/xls/contrib/xlscc/unit_tests/testdata/translator_verilog_test_IOProcComboGenNToOneMux.vtxt b/xls/contrib/xlscc/unit_tests/testdata/translator_verilog_test_IOProcComboGenNToOneMux.vtxt index a8ce380ab3..596b1f92c5 100644 --- a/xls/contrib/xlscc/unit_tests/testdata/translator_verilog_test_IOProcComboGenNToOneMux.vtxt +++ b/xls/contrib/xlscc/unit_tests/testdata/translator_verilog_test_IOProcComboGenNToOneMux.vtxt @@ -16,16 +16,16 @@ module foo_proc( wire [31:0] in2_select; wire [31:0] in1_select; wire p0_all_active_inputs_valid; - wire [31:0] out_send_value; + wire [31:0] out_op0_0__2; assign ctx_2__x_literal = 32'h0000_0000; assign continuation_1_ctx_3__full_condi_output = dir == ctx_2__x_literal; assign ctx_3__full_condition_ctx_3__rel_output__1 = ~continuation_1_ctx_3__full_condi_output; assign in2_select = ctx_3__full_condition_ctx_3__rel_output__1 ? in2 : 32'h0000_0000; assign in1_select = continuation_1_ctx_3__full_condi_output ? in1 : 32'h0000_0000; assign p0_all_active_inputs_valid = (~continuation_1_ctx_3__full_condi_output | in1_vld) & (~ctx_3__full_condition_ctx_3__rel_output__1 | in2_vld); - assign out_send_value = continuation_1_ctx_3__full_condi_output ? in1_select : in2_select; + assign out_op0_0__2 = continuation_1_ctx_3__full_condi_output ? in1_select : in2_select; assign in1_rdy = out_rdy & continuation_1_ctx_3__full_condi_output; assign in2_rdy = out_rdy & ctx_3__full_condition_ctx_3__rel_output__1; - assign out = out_send_value; + assign out = out_op0_0__2; assign out_vld = p0_all_active_inputs_valid & 1'h1 & 1'h1; endmodule diff --git a/xls/contrib/xlscc/unit_tests/translator_proc_test.cc b/xls/contrib/xlscc/unit_tests/translator_proc_test.cc index 0ca77a26f8..559de294ea 100644 --- a/xls/contrib/xlscc/unit_tests/translator_proc_test.cc +++ b/xls/contrib/xlscc/unit_tests/translator_proc_test.cc @@ -10274,6 +10274,7 @@ TEST_P(TranslatorProcTest, SubroutineNotDuplicated) { function_names.insert(func->name()); } } + TEST_P(TranslatorProcTest, PassthroughCrossingFeedback) { const std::string content = R"( class Block { @@ -10311,44 +10312,6 @@ TEST_P(TranslatorProcTest, PassthroughCrossingFeedback) { } } -TEST_P(TranslatorProcTest, PassthroughFeedbackOrder) { - const std::string content = R"( - class Block { - public: - __xls_channel& in; - __xls_channel& out; - - #pragma hls_top - void Run() { - int value = in.read(); - int set_it = 10; - - #pragma hls_pipeline_init_interval 1 - for(int i=0;i<2;++i) { - #pragma hls_pipeline_init_interval 1 - for(int j=0;j<3;++j) { - out.write(value); - } - value = set_it; - } - } - };)"; - - absl::flat_hash_map> inputs; - inputs["in"] = {xls::Value(xls::SBits(34, 32))}; - - { - absl::flat_hash_map> outputs; - outputs["out"] = { - xls::Value(xls::SBits(34, 32)), xls::Value(xls::SBits(34, 32)), - xls::Value(xls::SBits(34, 32)), xls::Value(xls::SBits(10, 32)), - xls::Value(xls::SBits(10, 32)), xls::Value(xls::SBits(10, 32))}; - ProcTest(content, /*block_spec=*/std::nullopt, inputs, outputs, - /* min_ticks = */ 1, - /* max_ticks = */ 100, - /* top_level_init_interval = */ 1); - } -} } // namespace diff --git a/xls/contrib/xlscc/unit_tests/translator_static_test.cc b/xls/contrib/xlscc/unit_tests/translator_static_test.cc index d2226c68cd..93d60fd899 100644 --- a/xls/contrib/xlscc/unit_tests/translator_static_test.cc +++ b/xls/contrib/xlscc/unit_tests/translator_static_test.cc @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -27,6 +28,8 @@ #include "xls/contrib/xlscc/translator_types.h" #include "xls/contrib/xlscc/unit_tests/unit_test.h" #include "xls/ir/bits.h" +#include "xls/ir/proc.h" +#include "xls/ir/state_element.h" #include "xls/ir/value.h" namespace xlscc { @@ -415,7 +418,9 @@ TEST_F(TranslatorStaticTest, IOProcStaticNoIO) { ch_out->set_type(CHANNEL_TYPE_FIFO); } - { ProcTest(content, block_spec, {}, {}); } + { + ProcTest(content, block_spec, {}, {}); + } } TEST_F(TranslatorStaticTest, IOProcStaticStruct) { @@ -1232,6 +1237,62 @@ TEST_F(TranslatorStaticTest, ReturnStaticLValueTemplated) { RunWithStatics(args, expected_vals, content); } +TEST_F(TranslatorStaticTest, StaticDecomposed) { + const std::string content = R"( + struct Thing { + int x; + long y; + }; + + class Block { + __xls_channel in; + __xls_channel out; + + #pragma hls_top + void foo() { + static Thing thing; + thing.x += in.read(); + thing.y = thing.x * 3; + out.write(thing.y); + } + }; + )"; + + generate_new_fsm_ = true; + + absl::flat_hash_map> inputs; + inputs["in"] = {xls::Value(xls::SBits(80, 32)), + xls::Value(xls::SBits(100, 32))}; + + { + absl::flat_hash_map> outputs; + outputs["out"] = {xls::Value(xls::SBits(3 * 80, 32)), + xls::Value(xls::SBits(3 * (80 + 100), 32))}; + + ProcTest(content, /*block_spec=*/std::nullopt, inputs, outputs); + } + + ASSERT_EQ(package_->procs().size(), 1); + const xls::Proc* proc = package_->procs().at(0).get(); + + int64_t total_32bit_elems = 0; + int64_t total_64bit_elems = 0; + + for (const xls::StateElement* state : proc->StateElements()) { + EXPECT_FALSE(state->type()->IsTuple() && + state->type()->GetFlatBitCount() > 0); + if (state->type()->IsBits() && state->type()->GetFlatBitCount() == 32) { + ++total_32bit_elems; + } + if (state->type()->IsBits() && state->type()->GetFlatBitCount() == 64) { + ++total_64bit_elems; + } + } + + EXPECT_EQ(total_32bit_elems, 1); + EXPECT_EQ(total_64bit_elems, 1); +} + } // namespace } // namespace xlscc diff --git a/xls/passes/data_flow_node_info.h b/xls/passes/data_flow_node_info.h index fc0628082f..51207f5a9e 100644 --- a/xls/passes/data_flow_node_info.h +++ b/xls/passes/data_flow_node_info.h @@ -47,6 +47,10 @@ namespace xls { // CRTP is used to instantiate sub-caches for each invoke, since their // parameters can vary. The CRTP type is called Derived. // +// The include_selectors flag controls whether or not dynamic selectors or +// indices are included in the output info. Even when this is turned on, +// selectors known at compile time (eg literals) will not be included. +// // The compute_tree_for_source flag controls which API is used to originate // Infos. // @@ -78,15 +82,18 @@ class DataFlowLazyNodeInfo : public LazyNodeInfo { virtual Info MergeInfos(const absl::Span& infos) const = 0; explicit DataFlowLazyNodeInfo(bool compute_tree_for_source, - bool default_info_source) + bool default_info_source, + bool include_selectors) : LazyNodeInfo(DagCacheInvalidateDirection::kInvalidatesUsers), compute_tree_for_source_(compute_tree_for_source), - default_info_source_(default_info_source) {} + default_info_source_(default_info_source), + include_selectors_(include_selectors) {} DataFlowLazyNodeInfo(const DataFlowLazyNodeInfo& o) : LazyNodeInfo(o), compute_tree_for_source_(o.compute_tree_for_source_), default_info_source_(o.default_info_source_), + include_selectors_(o.include_selectors_), query_engine_(o.query_engine_), parent_(o.parent_), parent_node_(o.parent_node_) { @@ -125,8 +132,8 @@ class DataFlowLazyNodeInfo : public LazyNodeInfo { LOG(FATAL) << "Unsupported value type"; } - void DuplicateInfo(xls::Type* type, const Info& info, - absl::InlinedVector& infos) const { + static void DuplicateInfo(xls::Type* type, const Info& info, + absl::InlinedVector& infos) { if (type->IsBits()) { infos.push_back(info); return; @@ -160,6 +167,9 @@ class DataFlowLazyNodeInfo : public LazyNodeInfo { bool is_info_source = false; + // A place to put a synthetic tree. + LeafTypeTree selector_tree; + switch (node->op()) { case xls::Op::kParam: { is_info_source = true; @@ -200,9 +210,15 @@ class DataFlowLazyNodeInfo : public LazyNodeInfo { return ret.AsShared().ToOwned(); } - // With dynamic indexing, merge all infos in the array (but not the - // index) + // With dynamic indexing, merge all infos in the array, and optionally + // the index operand_infos_out.clear(); + if (include_selectors_) { + for (int64_t i = 0; i < indices.size(); ++i) { + operand_infos_out.push_back( + operand_infos[xls::ArrayIndex::kIndexOperandStart + i]); + } + } operand_infos_out.push_back( operand_infos[xls::ArrayIndex::kArgOperand]); @@ -238,6 +254,12 @@ class DataFlowLazyNodeInfo : public LazyNodeInfo { operand_infos_out.clear(); operand_infos_out.push_back(to_update_info); operand_infos_out.push_back(replace_info); + if (include_selectors_) { + for (int64_t i = 0; i < array_update->indices().size(); ++i) { + operand_infos_out.push_back( + operand_infos[xls::ArrayUpdate::kIndexOperandStart + i]); + } + } // Fall through to default handling break; @@ -349,6 +371,16 @@ class DataFlowLazyNodeInfo : public LazyNodeInfo { } else { for (int64_t op = 0; op < operand_infos.size(); ++op) { if (op == xls::Select::kSelectorOperand) { + if (!include_selectors_) { + continue; + } + CHECK_EQ(operand_infos.at(op)->elements().size(), 1); + absl::InlinedVector infos; + DuplicateInfo(node->GetType(), + operand_infos.at(op)->elements().at(0), infos); + selector_tree = LeafTypeTree::CreateFromVector( + node->GetType(), std::move(infos)); + operand_infos_out.push_back(&selector_tree); continue; } operand_infos_out.push_back(operand_infos.at(op)); @@ -486,7 +518,6 @@ class DataFlowLazyNodeInfo : public LazyNodeInfo { if (operand_infos.empty()) { return LeafTypeTree(); } - const LeafTypeTree* const first_info = operand_infos.at(0); for (int64_t op = 0; op < operand_infos.size(); ++op) { @@ -510,6 +541,7 @@ class DataFlowLazyNodeInfo : public LazyNodeInfo { bool default_info_source_ = false; bool compute_tree_for_source_ = false; + bool include_selectors_ = false; DataFlowLazyNodeInfo* parent_ = nullptr; xls::Node* parent_node_ = nullptr; diff --git a/xls/passes/data_flow_node_info_test.cc b/xls/passes/data_flow_node_info_test.cc index b09363dbc4..d7e330ab7e 100644 --- a/xls/passes/data_flow_node_info_test.cc +++ b/xls/passes/data_flow_node_info_test.cc @@ -41,15 +41,16 @@ namespace xls { namespace { -template +template class TestParamCountInfo - : public DataFlowLazyNodeInfo, - int64_t> { + : public DataFlowLazyNodeInfo< + TestParamCountInfo, int64_t> { public: TestParamCountInfo() : DataFlowLazyNodeInfo( /*compute_tree_for_source=*/false, - /*default_info_source=*/default_info_source) {} + /*default_info_source=*/default_info_source, + /*include_selectors=*/include_selectors) {} int64_t ComputeInfoForNode(Node* node) const override final { if (node->op() == xls::Op::kAdd) { @@ -79,15 +80,17 @@ class TestParamCountInfo typedef absl::flat_hash_set NodeSourceSet; -template +template class TestNodeSourceInfo - : public DataFlowLazyNodeInfo, - NodeSourceSet> { + : public DataFlowLazyNodeInfo< + TestNodeSourceInfo, + NodeSourceSet> { public: TestNodeSourceInfo() : DataFlowLazyNodeInfo( /*compute_tree_for_source=*/true, - /*default_info_source=*/default_info_source) {} + /*default_info_source=*/default_info_source, + /*include_selectors=*/include_selectors) {} NodeSourceSet ComputeInfoForNode(Node* node) const override final { LOG(FATAL) << "ComputeInfoForNode should be unused for TestNodeSourceInfo"; @@ -141,7 +144,8 @@ TEST_F(DataFlowNodeInfoTest, Identity) { BValue id = fb.Identity(x, SourceInfo(), "id"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(id)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -159,7 +163,8 @@ TEST_F(DataFlowNodeInfoTest, Literal) { BValue l = fb.Literal(xls::Value(xls::UBits(5, 32))); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(l)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -178,7 +183,8 @@ TEST_F(DataFlowNodeInfoTest, LiteralTuple) { {xls::Value(xls::UBits(5, 32)), xls::Value(xls::UBits(7, 32))})); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(l)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -198,7 +204,8 @@ TEST_F(DataFlowNodeInfoTest, Add) { BValue add = fb.Add(x, y, SourceInfo(), "add"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(add)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -218,7 +225,8 @@ TEST_F(DataFlowNodeInfoTest, AddDefaultInfoSource) { BValue add = fb.Add(x, y, SourceInfo(), "add"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(add)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -240,7 +248,8 @@ TEST_F(DataFlowNodeInfoTest, ModifyNode) { BValue add = fb.Add(x, y, SourceInfo(), "add"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(add)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -272,7 +281,8 @@ TEST_F(DataFlowNodeInfoTest, AddLiteral) { BValue add = fb.Add(x, l, SourceInfo(), "add"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(add)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -296,7 +306,8 @@ TEST_F(DataFlowNodeInfoTest, Select) { BValue sel = fb.Select(eq, x, y, SourceInfo(), "sel"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -316,6 +327,39 @@ TEST_F(DataFlowNodeInfoTest, Select) { p->GetBitsType(1), 1)); } +TEST_F(DataFlowNodeInfoTest, SelectIncludeSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue c = fb.Param("c", p->GetBitsType(32)); + BValue l = fb.Literal(xls::Value(xls::UBits(5, 32))); + BValue eq = fb.Eq(c, l, SourceInfo(), "eq"); + + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue sel = fb.Select(eq, x, y, SourceInfo(), "sel"); + XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); + + TestParamCountInfo + node_info; + node_info.set_query_engine(query_engine()); + XLS_ASSERT_OK(node_info.Attach(f)); + XLS_ASSERT_OK(query_engine()->Populate(f).status()); + int64_t sel_count = node_info.GetSingleInfoForNode(sel.node()); + SharedLeafTypeTree sel_tree = node_info.GetInfo(sel.node()); + + // Should not include the selector + EXPECT_EQ(sel_count, 3); + EXPECT_EQ(sel_tree, LeafTypeTree::CreateSingleElementTree( + p->GetBitsType(32), 3)); + + int64_t eq_count = node_info.GetSingleInfoForNode(eq.node()); + SharedLeafTypeTree eq_tree = node_info.GetInfo(eq.node()); + + EXPECT_EQ(eq_count, 1); + EXPECT_EQ(eq_tree, LeafTypeTree::CreateSingleElementTree( + p->GetBitsType(1), 1)); +} + TEST_F(DataFlowNodeInfoTest, SelectDefault) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); @@ -326,7 +370,8 @@ TEST_F(DataFlowNodeInfoTest, SelectDefault) { BValue sel = fb.Select(l, {x, y}, /*default_value=*/l, SourceInfo(), "sel"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -366,7 +411,8 @@ TEST_F(DataFlowNodeInfoTest, Invoke) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sum_returned)); - TestParamCountInfo sub_node_info; + TestParamCountInfo + sub_node_info; sub_node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(sub_node_info.Attach(sub_fn)); XLS_ASSERT_OK(query_engine()->Populate(sub_fn).status()); @@ -381,7 +427,8 @@ TEST_F(DataFlowNodeInfoTest, Invoke) { p->GetBitsType(32)}), {1, 1, 2})); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -462,7 +509,8 @@ TEST_F(DataFlowNodeInfoTest, ModifyInvoke) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sum_returned)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -527,7 +575,8 @@ TEST_F(DataFlowNodeInfoTest, ModifyInvokeCallee) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sum_returned)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -591,7 +640,8 @@ TEST_F(DataFlowNodeInfoTest, ModifyInvokeParam) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sum_returned)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -658,7 +708,8 @@ TEST_F(DataFlowNodeInfoTest, DeleteInvoke) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sum_returned)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -710,7 +761,8 @@ TEST_F(DataFlowNodeInfoTest, Tuple) { BValue tuple = fb.Tuple({x, y}, SourceInfo(), "tuple"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(tuple)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -736,7 +788,8 @@ TEST_F(DataFlowNodeInfoTest, TupleOfTuples) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(tuple_outer)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -774,7 +827,8 @@ TEST_F(DataFlowNodeInfoTest, TupleParam) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(tuple_index0)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -807,7 +861,8 @@ TEST_F(DataFlowNodeInfoTest, TupleWithLiteral) { BValue tuple = fb.Tuple({x, l, y}, SourceInfo(), "tuple"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(tuple)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -832,7 +887,8 @@ TEST_F(DataFlowNodeInfoTest, TupleIdentity) { BValue tuple_index0 = fb.TupleIndex(id, 0, SourceInfo(), "tuple_index0"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(id)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -873,7 +929,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayParam) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_index0)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -906,7 +963,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayUpdateDynamic) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_update)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -942,7 +1000,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayUpdateDynamic2) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_update)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -967,6 +1026,36 @@ TEST_F(DataFlowNodeInfoTest, ArrayUpdateDynamic2) { {11, 11, 11, 11, 11, 11, 11, 11, 11, 11})); } +TEST_F(DataFlowNodeInfoTest, ArrayUpdateDynamicWithSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetArrayType(10, p->GetBitsType(32))); + BValue i = fb.Param("i", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue array_update = fb.ArrayUpdate(x, /*update_value=*/y, /*indices=*/{i}, + SourceInfo(), /*name=*/"array_update"); + XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, + fb.BuildWithReturnValue(array_update)); + + TestParamCountInfo + node_info; + node_info.set_query_engine(query_engine()); + XLS_ASSERT_OK(node_info.Attach(f)); + XLS_ASSERT_OK(query_engine()->Populate(f).status()); + + SharedLeafTypeTree array_update_tree = + node_info.GetInfo(array_update.node()); + int64_t array_update_count = + node_info.GetSingleInfoForNode(array_update.node()); + + // Each element in the array ends up with a count of 12 + EXPECT_EQ(array_update_count, 12 * 10); + + EXPECT_EQ(array_update_tree, LeafTypeTree::CreateFromVector( + p->GetArrayType(10, p->GetBitsType(32)), + {12, 12, 12, 12, 12, 12, 12, 12, 12, 12})); +} + TEST_F(DataFlowNodeInfoTest, ArrayUpdateLiteral) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); @@ -977,7 +1066,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayUpdateLiteral) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_update)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1010,7 +1100,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayUpdateLiteral2) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_update)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1042,7 +1133,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayUpdateLiteralOutOfBounds) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_update)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1084,7 +1176,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayUpdateLiteralOutOfBounds2) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_update)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1125,7 +1218,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayWithLiteral) { BValue array = fb.Array({x, l, y}, x.GetType(), SourceInfo(), "array"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1151,7 +1245,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayIndexDynamic) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_index0)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1174,6 +1269,36 @@ TEST_F(DataFlowNodeInfoTest, ArrayIndexDynamic) { EXPECT_EQ(array_index0_count, 2); } +TEST_F(DataFlowNodeInfoTest, ArrayIndexDynamicWithSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue i = fb.Param("i", p->GetBitsType(32)); + BValue l = fb.Literal(xls::Value(xls::UBits(55, 32))); + BValue array = fb.Array({x, l, y}, x.GetType(), SourceInfo(), "array"); + BValue array_index0 = fb.ArrayIndex(array, {i}, /*assumed_in_bounds=*/false, + SourceInfo(), "array_index0"); + XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, + fb.BuildWithReturnValue(array_index0)); + + TestParamCountInfo + node_info; + node_info.set_query_engine(query_engine()); + XLS_ASSERT_OK(node_info.Attach(f)); + XLS_ASSERT_OK(query_engine()->Populate(f).status()); + + SharedLeafTypeTree array_index0_tree = + node_info.GetInfo(array_index0.node()); + int64_t array_index0_count = + node_info.GetSingleInfoForNode(array_index0.node()); + EXPECT_EQ(array_index0_tree, LeafTypeTree::CreateSingleElementTree( + p->GetBitsType(32), 3)); + + // The index variable should not be included + EXPECT_EQ(array_index0_count, 3); +} + TEST_F(DataFlowNodeInfoTest, ArrayIndexWithLiteral) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); @@ -1189,7 +1314,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayIndexWithLiteral) { SourceInfo(), "array_index1"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1230,7 +1356,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayIndexWithLiteralOutOfBounds) { SourceInfo(), "array_index0"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1262,7 +1389,8 @@ TEST_F(DataFlowNodeInfoTest, TupleNested) { BValue tuple2 = fb.Tuple({tuple, z}, SourceInfo(), "tuple"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(tuple2)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1289,7 +1417,8 @@ TEST_F(DataFlowNodeInfoTest, TupleIndex) { BValue tuple_index1 = fb.TupleIndex(tuple, 1, SourceInfo(), "tuple_index1"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(tuple)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1329,7 +1458,8 @@ TEST_F(DataFlowNodeInfoTest, LiteralTupleIndex) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(tuple_index)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1357,7 +1487,8 @@ TEST_F(DataFlowNodeInfoTest, TupleOfArrays) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(tuple_outer)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1398,7 +1529,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayOfTuples) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_index)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1439,7 +1571,8 @@ TEST_F(DataFlowNodeInfoTest, ArrayOfTuplesDynamicIndex) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_index)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1477,7 +1610,8 @@ TEST_F(DataFlowNodeInfoTest, ArraySlice) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_slice)); - TestParamCountInfo node_info; + TestParamCountInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1512,7 +1646,8 @@ TEST_F(DataFlowNodeInfoTest, AddComputeTreeForLeaf) { BValue add = fb.Add(x, y, SourceInfo(), "add"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(add)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1536,7 +1671,8 @@ TEST_F(DataFlowNodeInfoTest, AddComputeTreeForLeafNoDefaultInfoSource) { BValue add = fb.Add(x, y, SourceInfo(), "add"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(add)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1566,7 +1702,8 @@ TEST_F(DataFlowNodeInfoTest, AddInSelectComputeTreeForLeaf) { BValue sel = fb.Select(eq, x, add, SourceInfo(), "sel"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1596,7 +1733,8 @@ TEST_F(DataFlowNodeInfoTest, SelectComputeTreeForLeaf) { BValue sel = fb.Select(eq, x, y, SourceInfo(), "sel"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1614,6 +1752,38 @@ TEST_F(DataFlowNodeInfoTest, SelectComputeTreeForLeaf) { p->GetBitsType(32), sel_ref_sources)); } +TEST_F(DataFlowNodeInfoTest, SelectComputeTreeForLeafWithSelector) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue c = fb.Param("c", p->GetBitsType(32)); + BValue l = fb.Literal(xls::Value(xls::UBits(5, 32))); + BValue eq = fb.Eq(c, l, SourceInfo(), "eq"); + + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue sel = fb.Select(eq, x, y, SourceInfo(), "sel"); + XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); + + TestNodeSourceInfo + node_info; + node_info.set_query_engine(query_engine()); + XLS_ASSERT_OK(node_info.Attach(f)); + XLS_ASSERT_OK(query_engine()->Populate(f).status()); + NodeSourceSet sel_sources = node_info.GetSingleInfoForNode(sel.node()); + SharedLeafTypeTree sel_sources_tree = + node_info.GetInfo(sel.node()); + + NodeSourceSet sel_ref_sources = {NodeSource(x.node(), /*tree_index=*/{}), + NodeSource(y.node(), /*tree_index=*/{}), + NodeSource(eq.node(), /*tree_index=*/{})}; + + // Should not include the selector + EXPECT_EQ(sel_sources, sel_ref_sources); + EXPECT_EQ(sel_sources_tree, + LeafTypeTree::CreateSingleElementTree( + p->GetBitsType(32), sel_ref_sources)); +} + TEST_F(DataFlowNodeInfoTest, SelectConstantSelectorComputeTreeForLeaf) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); @@ -1626,7 +1796,8 @@ TEST_F(DataFlowNodeInfoTest, SelectConstantSelectorComputeTreeForLeaf) { fb.Select(l1, /*on_true=*/x, /*on_false=*/y, SourceInfo(), "sel"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1676,7 +1847,8 @@ TEST_F(DataFlowNodeInfoTest, SelectConstExprComputeTreeForLeaf) { fb.Select(eq, /*on_true=*/x, /*on_false=*/y, SourceInfo(), "sel"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1694,6 +1866,111 @@ TEST_F(DataFlowNodeInfoTest, SelectConstExprComputeTreeForLeaf) { p->GetBitsType(32), sel_ref_sources)); } +TEST_F(DataFlowNodeInfoTest, ArrayIndexDynamicComputeTreeForLeaf) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue i = fb.Param("i", p->GetBitsType(32)); + BValue l = fb.Literal(xls::Value(xls::UBits(55, 32))); + BValue array = fb.Array({x, l, y}, x.GetType(), SourceInfo(), "array"); + BValue array_index = fb.ArrayIndex(array, {i}, /*assumed_in_bounds=*/false, + SourceInfo(), "array_index0"); + XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, + fb.BuildWithReturnValue(array_index)); + + TestNodeSourceInfo + node_info; + node_info.set_query_engine(query_engine()); + XLS_ASSERT_OK(node_info.Attach(f)); + XLS_ASSERT_OK(query_engine()->Populate(f).status()); + NodeSourceSet array_index_sources = + node_info.GetSingleInfoForNode(array_index.node()); + SharedLeafTypeTree sel_sources_tree = + node_info.GetInfo(array_index.node()); + + NodeSourceSet array_index_ref_sources = { + NodeSource(x.node(), /*tree_index=*/{}), + NodeSource(y.node(), /*tree_index=*/{}), + NodeSource(l.node(), /*tree_index=*/{})}; + + // Should not include the selector + EXPECT_EQ(array_index_sources, array_index_ref_sources); + EXPECT_EQ(sel_sources_tree, + LeafTypeTree::CreateSingleElementTree( + p->GetBitsType(32), array_index_ref_sources)); +} + +TEST_F(DataFlowNodeInfoTest, ArrayIndexDynamicWithSelectorComputeTreeForLeaf) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue i = fb.Param("i", p->GetBitsType(32)); + BValue l = fb.Literal(xls::Value(xls::UBits(55, 32))); + BValue array = fb.Array({x, l, y}, x.GetType(), SourceInfo(), "array"); + BValue array_index = fb.ArrayIndex(array, {i}, /*assumed_in_bounds=*/false, + SourceInfo(), "array_index"); + XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, + fb.BuildWithReturnValue(array_index)); + + TestNodeSourceInfo + node_info; + node_info.set_query_engine(query_engine()); + XLS_ASSERT_OK(node_info.Attach(f)); + XLS_ASSERT_OK(query_engine()->Populate(f).status()); + NodeSourceSet array_index_sources = + node_info.GetSingleInfoForNode(array_index.node()); + SharedLeafTypeTree sel_sources_tree = + node_info.GetInfo(array_index.node()); + + NodeSourceSet array_index_ref_sources = { + NodeSource(x.node(), /*tree_index=*/{}), + NodeSource(y.node(), /*tree_index=*/{}), + NodeSource(l.node(), /*tree_index=*/{}), + NodeSource(i.node(), /*tree_index=*/{})}; + + // Should not include the selector + EXPECT_EQ(array_index_sources, array_index_ref_sources); + EXPECT_EQ(sel_sources_tree, + LeafTypeTree::CreateSingleElementTree( + p->GetBitsType(32), array_index_ref_sources)); +} + +TEST_F(DataFlowNodeInfoTest, + ArrayIndexDynamicWithLiteralSelectorComputeTreeForLeaf) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(32)); + BValue y = fb.Param("y", p->GetBitsType(32)); + BValue i = fb.Literal(xls::Value(xls::UBits(0, 32))); + BValue l = fb.Literal(xls::Value(xls::UBits(55, 32))); + BValue array = fb.Array({x, l, y}, x.GetType(), SourceInfo(), "array"); + BValue array_index = fb.ArrayIndex(array, {i}, /*assumed_in_bounds=*/false, + SourceInfo(), "array_index"); + XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, + fb.BuildWithReturnValue(array_index)); + + TestNodeSourceInfo + node_info; + node_info.set_query_engine(query_engine()); + XLS_ASSERT_OK(node_info.Attach(f)); + XLS_ASSERT_OK(query_engine()->Populate(f).status()); + NodeSourceSet array_index_sources = + node_info.GetSingleInfoForNode(array_index.node()); + SharedLeafTypeTree sel_sources_tree = + node_info.GetInfo(array_index.node()); + + NodeSourceSet array_index_ref_sources = { + NodeSource(x.node(), /*tree_index=*/{})}; + + // Should not include the selector + EXPECT_EQ(array_index_sources, array_index_ref_sources); + EXPECT_EQ(sel_sources_tree, + LeafTypeTree::CreateSingleElementTree( + p->GetBitsType(32), array_index_ref_sources)); +} + TEST_F(DataFlowNodeInfoTest, SwizzleComputeTreeForLeaf) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); @@ -1704,7 +1981,8 @@ TEST_F(DataFlowNodeInfoTest, SwizzleComputeTreeForLeaf) { BValue swizzle = fb.Tuple({ti1, ti0}, SourceInfo(), "swizzle"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(swizzle)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1737,7 +2015,8 @@ TEST_F(DataFlowNodeInfoTest, SelectTupleComputeTreeForLeaf) { fb.Select(eq, /*on_true=*/x, /*on_false=*/y, SourceInfo(), "sel"); XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(sel)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status()); @@ -1769,7 +2048,8 @@ TEST_F(DataFlowNodeInfoTest, ArraySliceComputeTreeForLeaf) { XLS_ASSERT_OK_AND_ASSIGN(xls::Function * f, fb.BuildWithReturnValue(array_slice)); - TestNodeSourceInfo node_info; + TestNodeSourceInfo + node_info; node_info.set_query_engine(query_engine()); XLS_ASSERT_OK(node_info.Attach(f)); XLS_ASSERT_OK(query_engine()->Populate(f).status());