Skip to content

Commit b97d9e1

Browse files
Auto merge of #149114 - BoxyUwU:mgca_adt_exprs, r=<try>
MGCA: Support struct expressions without intermediary anon consts
2 parents f794a08 + 8990d98 commit b97d9e1

File tree

62 files changed

+991
-329
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+991
-329
lines changed

compiler/rustc_ast_lowering/src/index.rs

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -281,6 +281,13 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
281281
});
282282
}
283283

284+
fn visit_const_arg_expr_field(&mut self, field: &'hir ConstArgExprField<'hir>) {
285+
self.insert(field.span, field.hir_id, Node::ConstArgExprField(field));
286+
self.with_parent(field.hir_id, |this| {
287+
intravisit::walk_const_arg_expr_field(this, field);
288+
})
289+
}
290+
284291
fn visit_stmt(&mut self, stmt: &'hir Stmt<'hir>) {
285292
self.insert(stmt.span, stmt.hir_id, Node::Stmt(stmt));
286293

compiler/rustc_ast_lowering/src/lib.rs

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2408,6 +2408,37 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
24082408

24092409
ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Path(qpath) }
24102410
}
2411+
ExprKind::Struct(se) => {
2412+
let path = self.lower_qpath(
2413+
expr.id,
2414+
&se.qself,
2415+
&se.path,
2416+
ParamMode::Explicit,
2417+
AllowReturnTypeNotation::No,
2418+
ImplTraitContext::Disallowed(ImplTraitPosition::Path),
2419+
None,
2420+
);
2421+
2422+
let fields = self.arena.alloc_from_iter(se.fields.iter().map(|f| {
2423+
let hir_id = self.lower_node_id(f.id);
2424+
self.lower_attrs(hir_id, &f.attrs, f.span, Target::ExprField);
2425+
2426+
let expr = if let ExprKind::ConstBlock(anon_const) = &f.expr.kind {
2427+
self.lower_anon_const_to_const_arg_direct(anon_const)
2428+
} else {
2429+
self.lower_expr_to_const_arg_direct(&f.expr)
2430+
};
2431+
2432+
&*self.arena.alloc(hir::ConstArgExprField {
2433+
hir_id,
2434+
field: self.lower_ident(f.ident),
2435+
expr: self.arena.alloc(expr),
2436+
span: self.lower_span(f.span),
2437+
})
2438+
}));
2439+
2440+
ConstArg { hir_id: self.next_id(), kind: hir::ConstArgKind::Struct(path, fields) }
2441+
}
24112442
ExprKind::Underscore => ConstArg {
24122443
hir_id: self.lower_node_id(expr.id),
24132444
kind: hir::ConstArgKind::Infer(expr.span, ()),

compiler/rustc_codegen_cranelift/src/intrinsics/simd.rs

Lines changed: 8 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -130,7 +130,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
130130
return;
131131
}
132132

133-
let idx = generic_args[2].expect_const().to_value().valtree.unwrap_branch();
133+
let idx = generic_args[2].expect_const().to_branch();
134134

135135
assert_eq!(x.layout(), y.layout());
136136
let layout = x.layout();
@@ -143,7 +143,7 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
143143

144144
let total_len = lane_count * 2;
145145

146-
let indexes = idx.iter().map(|idx| idx.unwrap_leaf().to_u32()).collect::<Vec<u32>>();
146+
let indexes = idx.iter().map(|idx| idx.to_leaf().to_u32()).collect::<Vec<u32>>();
147147

148148
for &idx in &indexes {
149149
assert!(u64::from(idx) < total_len, "idx {} out of range 0..{}", idx, total_len);
@@ -961,9 +961,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
961961
let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap();
962962
let ptr_val = ptr.load_scalar(fx);
963963

964-
let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
965-
.unwrap_leaf()
966-
.to_simd_alignment();
964+
let alignment =
965+
generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment();
967966

968967
let memflags = match alignment {
969968
SimdAlign::Unaligned => MemFlags::new().with_notrap(),
@@ -1006,9 +1005,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
10061005
let lane_clif_ty = fx.clif_type(val_lane_ty).unwrap();
10071006
let ret_lane_layout = fx.layout_of(ret_lane_ty);
10081007

1009-
let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1010-
.unwrap_leaf()
1011-
.to_simd_alignment();
1008+
let alignment =
1009+
generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment();
10121010

10131011
let memflags = match alignment {
10141012
SimdAlign::Unaligned => MemFlags::new().with_notrap(),
@@ -1059,9 +1057,8 @@ pub(super) fn codegen_simd_intrinsic_call<'tcx>(
10591057
let ret_lane_layout = fx.layout_of(ret_lane_ty);
10601058
let ptr_val = ptr.load_scalar(fx);
10611059

1062-
let alignment = generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1063-
.unwrap_leaf()
1064-
.to_simd_alignment();
1060+
let alignment =
1061+
generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment();
10651062

10661063
let memflags = match alignment {
10671064
SimdAlign::Unaligned => MemFlags::new().with_notrap(),

compiler/rustc_codegen_llvm/src/intrinsic.rs

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -351,7 +351,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
351351
_ => bug!(),
352352
};
353353
let ptr = args[0].immediate();
354-
let locality = fn_args.const_at(1).to_value().valtree.unwrap_leaf().to_i32();
354+
let locality = fn_args.const_at(1).to_leaf().to_i32();
355355
self.call_intrinsic(
356356
"llvm.prefetch",
357357
&[self.val_ty(ptr)],
@@ -1536,7 +1536,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
15361536
}
15371537

15381538
if name == sym::simd_shuffle_const_generic {
1539-
let idx = fn_args[2].expect_const().to_value().valtree.unwrap_branch();
1539+
let idx = fn_args[2].expect_const().to_branch();
15401540
let n = idx.len() as u64;
15411541

15421542
let (out_len, out_ty) = require_simd!(ret_ty, SimdReturn);
@@ -1555,7 +1555,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
15551555
.iter()
15561556
.enumerate()
15571557
.map(|(arg_idx, val)| {
1558-
let idx = val.unwrap_leaf().to_i32();
1558+
let idx = val.to_leaf().to_i32();
15591559
if idx >= i32::try_from(total_len).unwrap() {
15601560
bx.sess().dcx().emit_err(InvalidMonomorphization::SimdIndexOutOfBounds {
15611561
span,
@@ -1967,9 +1967,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
19671967
// those lanes whose `mask` bit is enabled.
19681968
// The memory addresses corresponding to the “off” lanes are not accessed.
19691969

1970-
let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
1971-
.unwrap_leaf()
1972-
.to_simd_alignment();
1970+
let alignment = fn_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment();
19731971

19741972
// The element type of the "mask" argument must be a signed integer type of any width
19751973
let mask_ty = in_ty;
@@ -2062,9 +2060,7 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
20622060
// those lanes whose `mask` bit is enabled.
20632061
// The memory addresses corresponding to the “off” lanes are not accessed.
20642062

2065-
let alignment = fn_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
2066-
.unwrap_leaf()
2067-
.to_simd_alignment();
2063+
let alignment = fn_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment();
20682064

20692065
// The element type of the "mask" argument must be a signed integer type of any width
20702066
let mask_ty = in_ty;

compiler/rustc_codegen_ssa/src/mir/constant.rs

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -77,22 +77,21 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
7777
.flatten()
7878
.map(|val| {
7979
// A SIMD type has a single field, which is an array.
80-
let fields = val.unwrap_branch();
80+
let fields = val.to_branch();
8181
assert_eq!(fields.len(), 1);
82-
let array = fields[0].unwrap_branch();
82+
let array = fields[0].to_branch();
8383
// Iterate over the array elements to obtain the values in the vector.
8484
let values: Vec<_> = array
8585
.iter()
8686
.map(|field| {
87-
if let Some(prim) = field.try_to_scalar() {
88-
let layout = bx.layout_of(field_ty);
89-
let BackendRepr::Scalar(scalar) = layout.backend_repr else {
90-
bug!("from_const: invalid ByVal layout: {:#?}", layout);
91-
};
92-
bx.scalar_to_backend(prim, scalar, bx.immediate_backend_type(layout))
93-
} else {
87+
let Some(prim) = field.try_to_scalar() else {
9488
bug!("field is not a scalar {:?}", field)
95-
}
89+
};
90+
let layout = bx.layout_of(field_ty);
91+
let BackendRepr::Scalar(scalar) = layout.backend_repr else {
92+
bug!("from_const: invalid ByVal layout: {:#?}", layout);
93+
};
94+
bx.scalar_to_backend(prim, scalar, bx.immediate_backend_type(layout))
9695
})
9796
.collect();
9897
bx.const_vector(&values)

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -102,7 +102,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
102102
};
103103

104104
let parse_atomic_ordering = |ord: ty::Value<'tcx>| {
105-
let discr = ord.valtree.unwrap_branch()[0].unwrap_leaf();
105+
let discr = ord.to_branch()[0].to_leaf();
106106
discr.to_atomic_ordering()
107107
};
108108

compiler/rustc_const_eval/src/const_eval/valtrees.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,17 @@ fn branches<'tcx>(
3636
// For enums, we prepend their variant index before the variant's fields so we can figure out
3737
// the variant again when just seeing a valtree.
3838
if let Some(variant) = variant {
39-
branches.push(ty::ValTree::from_scalar_int(*ecx.tcx, variant.as_u32().into()));
39+
branches.push(ty::Const::new_value(
40+
*ecx.tcx,
41+
ty::ValTree::from_scalar_int(*ecx.tcx, variant.as_u32().into()),
42+
ecx.tcx.types.u32,
43+
));
4044
}
4145

4246
for i in 0..field_count {
4347
let field = ecx.project_field(&place, FieldIdx::from_usize(i)).unwrap();
4448
let valtree = const_to_valtree_inner(ecx, &field, num_nodes)?;
45-
branches.push(valtree);
49+
branches.push(ty::Const::new_value(*ecx.tcx, valtree, field.layout.ty));
4650
}
4751

4852
// Have to account for ZSTs here
@@ -65,7 +69,7 @@ fn slice_branches<'tcx>(
6569
for i in 0..n {
6670
let place_elem = ecx.project_index(place, i).unwrap();
6771
let valtree = const_to_valtree_inner(ecx, &place_elem, num_nodes)?;
68-
elems.push(valtree);
72+
elems.push(ty::Const::new_value(*ecx.tcx, valtree, place_elem.layout.ty));
6973
}
7074

7175
Ok(ty::ValTree::from_branches(*ecx.tcx, elems))
@@ -200,8 +204,8 @@ fn reconstruct_place_meta<'tcx>(
200204
&ObligationCause::dummy(),
201205
|ty| ty,
202206
|| {
203-
let branches = last_valtree.unwrap_branch();
204-
last_valtree = *branches.last().unwrap();
207+
let branches = last_valtree.to_branch();
208+
last_valtree = branches.last().unwrap().to_value().valtree;
205209
debug!(?branches, ?last_valtree);
206210
},
207211
);
@@ -212,7 +216,7 @@ fn reconstruct_place_meta<'tcx>(
212216
};
213217

214218
// Get the number of elements in the unsized field.
215-
let num_elems = last_valtree.unwrap_branch().len();
219+
let num_elems = last_valtree.to_branch().len();
216220
MemPlaceMeta::Meta(Scalar::from_target_usize(num_elems as u64, &tcx))
217221
}
218222

@@ -274,7 +278,7 @@ pub fn valtree_to_const_value<'tcx>(
274278
mir::ConstValue::ZeroSized
275279
}
276280
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char | ty::RawPtr(_, _) => {
277-
mir::ConstValue::Scalar(Scalar::Int(cv.valtree.unwrap_leaf()))
281+
mir::ConstValue::Scalar(Scalar::Int(cv.to_leaf()))
278282
}
279283
ty::Pat(ty, _) => {
280284
let cv = ty::Value { valtree: cv.valtree, ty };
@@ -301,12 +305,13 @@ pub fn valtree_to_const_value<'tcx>(
301305
|| matches!(cv.ty.kind(), ty::Adt(def, _) if def.is_struct()))
302306
{
303307
// A Scalar tuple/struct; we can avoid creating an allocation.
304-
let branches = cv.valtree.unwrap_branch();
308+
let branches = cv.to_branch();
305309
// Find the non-ZST field. (There can be aligned ZST!)
306310
for (i, &inner_valtree) in branches.iter().enumerate() {
307311
let field = layout.field(&LayoutCx::new(tcx, typing_env), i);
308312
if !field.is_zst() {
309-
let cv = ty::Value { valtree: inner_valtree, ty: field.ty };
313+
let cv =
314+
ty::Value { valtree: inner_valtree.to_value().valtree, ty: field.ty };
310315
return valtree_to_const_value(tcx, typing_env, cv);
311316
}
312317
}
@@ -381,7 +386,7 @@ fn valtree_into_mplace<'tcx>(
381386
// Zero-sized type, nothing to do.
382387
}
383388
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char | ty::RawPtr(..) => {
384-
let scalar_int = valtree.unwrap_leaf();
389+
let scalar_int = valtree.to_leaf();
385390
debug!("writing trivial valtree {:?} to place {:?}", scalar_int, place);
386391
ecx.write_immediate(Immediate::Scalar(scalar_int.into()), place).unwrap();
387392
}
@@ -391,13 +396,13 @@ fn valtree_into_mplace<'tcx>(
391396
ecx.write_immediate(imm, place).unwrap();
392397
}
393398
ty::Adt(_, _) | ty::Tuple(_) | ty::Array(_, _) | ty::Str | ty::Slice(_) => {
394-
let branches = valtree.unwrap_branch();
399+
let branches = valtree.to_branch();
395400

396401
// Need to downcast place for enums
397402
let (place_adjusted, branches, variant_idx) = match ty.kind() {
398403
ty::Adt(def, _) if def.is_enum() => {
399404
// First element of valtree corresponds to variant
400-
let scalar_int = branches[0].unwrap_leaf();
405+
let scalar_int = branches[0].to_leaf();
401406
let variant_idx = VariantIdx::from_u32(scalar_int.to_u32());
402407
let variant = def.variant(variant_idx);
403408
debug!(?variant);
@@ -425,7 +430,7 @@ fn valtree_into_mplace<'tcx>(
425430
};
426431

427432
debug!(?place_inner);
428-
valtree_into_mplace(ecx, &place_inner, *inner_valtree);
433+
valtree_into_mplace(ecx, &place_inner, inner_valtree.to_value().valtree);
429434
dump_place(ecx, &place_inner);
430435
}
431436

compiler/rustc_const_eval/src/interpret/intrinsics/simd.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -545,15 +545,15 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
545545
let (right, right_len) = self.project_to_simd(&args[1])?;
546546
let (dest, dest_len) = self.project_to_simd(&dest)?;
547547

548-
let index = generic_args[2].expect_const().to_value().valtree.unwrap_branch();
548+
let index = generic_args[2].expect_const().to_branch();
549549
let index_len = index.len();
550550

551551
assert_eq!(left_len, right_len);
552552
assert_eq!(u64::try_from(index_len).unwrap(), dest_len);
553553

554554
for i in 0..dest_len {
555555
let src_index: u64 =
556-
index[usize::try_from(i).unwrap()].unwrap_leaf().to_u32().into();
556+
index[usize::try_from(i).unwrap()].to_leaf().to_u32().into();
557557
let dest = self.project_index(&dest, i)?;
558558

559559
let val = if src_index < left_len {
@@ -657,9 +657,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
657657
self.check_simd_ptr_alignment(
658658
ptr,
659659
dest_layout,
660-
generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
661-
.unwrap_leaf()
662-
.to_simd_alignment(),
660+
generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(),
663661
)?;
664662

665663
for i in 0..dest_len {
@@ -689,9 +687,7 @@ impl<'tcx, M: Machine<'tcx>> InterpCx<'tcx, M> {
689687
self.check_simd_ptr_alignment(
690688
ptr,
691689
args[2].layout,
692-
generic_args[3].expect_const().to_value().valtree.unwrap_branch()[0]
693-
.unwrap_leaf()
694-
.to_simd_alignment(),
690+
generic_args[3].expect_const().to_branch()[0].to_leaf().to_simd_alignment(),
695691
)?;
696692

697693
for i in 0..vals_len {

compiler/rustc_hir/src/hir.rs

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -494,6 +494,7 @@ impl<'hir, Unambig> ConstArg<'hir, Unambig> {
494494

495495
pub fn span(&self) -> Span {
496496
match self.kind {
497+
ConstArgKind::Struct(path, _) => path.span(),
497498
ConstArgKind::Path(path) => path.span(),
498499
ConstArgKind::Anon(anon) => anon.span,
499500
ConstArgKind::Error(span, _) => span,
@@ -513,13 +514,23 @@ pub enum ConstArgKind<'hir, Unambig = ()> {
513514
/// However, in the future, we'll be using it for all of those.
514515
Path(QPath<'hir>),
515516
Anon(&'hir AnonConst),
517+
/// Represents construction of struct/struct variants
518+
Struct(QPath<'hir>, &'hir [&'hir ConstArgExprField<'hir>]),
516519
/// Error const
517520
Error(Span, ErrorGuaranteed),
518521
/// This variant is not always used to represent inference consts, sometimes
519522
/// [`GenericArg::Infer`] is used instead.
520523
Infer(Span, Unambig),
521524
}
522525

526+
#[derive(Clone, Copy, Debug, HashStable_Generic)]
527+
pub struct ConstArgExprField<'hir> {
528+
pub hir_id: HirId,
529+
pub span: Span,
530+
pub field: Ident,
531+
pub expr: &'hir ConstArg<'hir>,
532+
}
533+
523534
#[derive(Clone, Copy, Debug, HashStable_Generic)]
524535
pub struct InferArg {
525536
#[stable_hasher(ignore)]
@@ -4714,6 +4725,7 @@ pub enum Node<'hir> {
47144725
ConstArg(&'hir ConstArg<'hir>),
47154726
Expr(&'hir Expr<'hir>),
47164727
ExprField(&'hir ExprField<'hir>),
4728+
ConstArgExprField(&'hir ConstArgExprField<'hir>),
47174729
Stmt(&'hir Stmt<'hir>),
47184730
PathSegment(&'hir PathSegment<'hir>),
47194731
Ty(&'hir Ty<'hir>),
@@ -4773,6 +4785,7 @@ impl<'hir> Node<'hir> {
47734785
Node::AssocItemConstraint(c) => Some(c.ident),
47744786
Node::PatField(f) => Some(f.ident),
47754787
Node::ExprField(f) => Some(f.ident),
4788+
Node::ConstArgExprField(f) => Some(f.field),
47764789
Node::PreciseCapturingNonLifetimeArg(a) => Some(a.ident),
47774790
Node::Param(..)
47784791
| Node::AnonConst(..)

0 commit comments

Comments (0)