From c28dd0705b38bbc7df86a66da413549cce0ea59d Mon Sep 17 00:00:00 2001 From: Siddharth Bhat Date: Thu, 18 Dec 2025 18:54:51 +0000 Subject: [PATCH] chore: more defs around unpacked float --- Fp/Basic.lean | 120 +++++++++++++++++++++++++++++++++++++---------- Fp/Division.lean | 9 ++++ Fp/Rounding.lean | 67 ++++++++++++++++++++++++++ 3 files changed, 171 insertions(+), 25 deletions(-) diff --git a/Fp/Basic.lean b/Fp/Basic.lean index 3469fc0..73a791f 100644 --- a/Fp/Basic.lean +++ b/Fp/Basic.lean @@ -1,6 +1,45 @@ import Fp.Utils import Fp.ForLean.Dyadic +/-! +# Floating Point Parameters +-/ + +@[bv_normalize] +def Nat.ceilLog2 (n : Nat) : Nat := + if n.log2 * 2 = n then n.log2 else n.log2 + 1 /- NOTE(review): `n.log2 * 2 = n` only holds for n ∈ {0, 2, 4} (e.g. n = 8 yields 4); presumably `2 ^ n.log2 = n` was intended — confirm. -/ + +@[bv_normalize] +def bias (e : Nat) : Nat := + 2 ^ (e - 1) - 1 + +@[bv_normalize] +def minNormalExp (e : Nat) : Int := + -(bias e - 1) + +@[bv_normalize] +def maxNormalExp (e : Nat) : Int := (bias e) + +@[bv_normalize] +def subnormalExp (e : Nat) : Int := + minNormalExp e - 1 + +@[simp] +theorem subnormalExp_eq_minus_bias (e : Nat) : + subnormalExp e = -(bias e) := by + simp [subnormalExp, minNormalExp, bias] + grind + +-- This is a simpler (but less tight) bound than `exponentWidth`. +-- It's logarithmically larger. +@[bv_normalize] +def exponentWidth' (e s : Nat) : Nat := + e + s.ceilLog2 + +@[bv_normalize] +def exponentWidth (e s : Nat) : Nat := + (2 ^ (e - 1) + s - 2).log2 + 2 + /-! ## Packed Floating Point Numbers @@ -68,7 +107,10 @@ structure FixedPoint (width exOffset : Nat) where val : BitVec width -- | This should not be part of the structure, but a side invariant we keep in mind. hExOffset : exOffset < width -deriving DecidableEq + deriving DecidableEq, Repr + +-- TODO: We can create a theory of normalized fixed point numbers, that are either zero, +-- or have MSB = 1. attribute [bv_normalize] FixedPoint.ext_iff @@ -536,6 +578,7 @@ def toRat? 
(pf : PackedFloat e s) : Option Rat := end PackedFloat + /-- `UnpackedFloat e s` is the *working* (unpacked) representation of a floating-point number with exponent width `e` and significand width `s`. @@ -588,8 +631,9 @@ and in hardware floating-point pipelines. structure UnpackedFloat (e s : Nat) where sign : Bool ex : BitVec e - sig : BitVec s - deriving DecidableEq, Inhabited, Repr + /-- 'sig' morally represents a 'FixedPoint s (s - 1)'. -/ + sig : BitVec s + deriving DecidableEq, Repr attribute [bv_normalize] UnpackedFloat.ext_iff @@ -687,6 +731,54 @@ def toDyadic (uf : UnpackedFloat e s) : Dyadic := def toRat (uf : UnpackedFloat e s) : Rat := uf.toDyadic.toRat +def zero (e s : Nat) (sign : Bool) : UnpackedFloat e (s + 1) where + sign := sign + ex := BitVec.ofInt e 0 + sig := BitVec.ofNat (s + 1) 0 + +/-- The minimum subnormal exponent, that shows up after normalizing a subnormal number -/ +def normalizedMinSubnormalExp (e s : Nat) : Int := + subnormalExp e - s + +/-- Build the significand from the hidden bit and fraction bits. -/ +def mkSigFromHiddenBitAndFrac (hiddenBit : Bool) (fraction : BitVec s) : BitVec s := + ((BitVec.ofBool hiddenBit).zeroExtend s <<< (s - 1)) ||| ((fraction <<< 1) >>> 1) + +@[simp] +theorem getLsbD_mkSigFromHiddenBitAndFrac_of_lt {hidden : Bool} {frac : BitVec s} (i : Nat) (hi : i < s - 1): + (mkSigFromHiddenBitAndFrac hidden frac).getLsbD i = if i = s - 1 then hidden else frac.getLsbD i := by + grind [mkSigFromHiddenBitAndFrac] + +@[simp] +theorem getElem_mkSigFromHiddenBitAndFrac {hidden : Bool} {frac : BitVec s} {i : Fin s} : + (mkSigFromHiddenBitAndFrac hidden frac)[i] = if i = s - 1 then hidden else frac[i] := by + grind [mkSigFromHiddenBitAndFrac] + +/-- Minimum subnormal number that is unpacked-representable. -/ +def minSubnormal (e s : Nat) : UnpackedFloat e s where -- s includes the hidden bit. 
+ sign := false + sig := mkSigFromHiddenBitAndFrac false (BitVec.ofNat s 1) -- 0.000...01 + ex := BitVec.ofInt e (subnormalExp e) + + +/-- Maximum subnormal number that is unpacked-representable. -/ +def maxSubnormal (e s : Nat) : UnpackedFloat e s where + sign := false + sig := mkSigFromHiddenBitAndFrac false (BitVec.allOnes s) -- 0.111...11 + ex := BitVec.ofInt e (subnormalExp e) + +/-- Minimum positive normal that is unpacked-representable. -/ +def minNormal (e s : Nat) : UnpackedFloat e s where + sign := false + sig := mkSigFromHiddenBitAndFrac true (BitVec.ofNat s 0) -- 1.000...00 + ex := BitVec.ofInt e (minNormalExp e) + +/-- Maximum normal number that is unpacked-representable. -/ +def maxNormal (e s : Nat) : UnpackedFloat e s where + sign := false + sig := mkSigFromHiddenBitAndFrac true (BitVec.allOnes s) -- 1.111...11 + ex := BitVec.ofInt e (maxNormalExp e) + end UnpackedFloat namespace EUnpackedFloat @@ -797,28 +889,6 @@ def toRat? (ef : EUnpackedFloat e s) : Option Rat := end EUnpackedFloat -@[bv_normalize] -def Nat.ceilLog2 (n : Nat) : Nat := - if n.log2 * 2 = n then n.log2 else n.log2 + 1 - -@[bv_normalize] -def bias (e : Nat) : Nat := - 2 ^ (e - 1) - 1 - -@[bv_normalize] -def minNormalExp (e : Nat) : Int := - -(bias e - 1) - --- This is a simpler (but less tight) bound than `exponentWidth`. --- It's logarithmically larger. -@[bv_normalize] -def exponentWidth' (e s : Nat) : Nat := - e + s.ceilLog2 - -@[bv_normalize] -def exponentWidth (e s : Nat) : Nat := - (2 ^ (e - 1) + s - 2).log2 + 2 - -- Constants /-- E5M2 floating point representation of 1.0 -/ diff --git a/Fp/Division.lean b/Fp/Division.lean index ecdcf7d..a8b1f8a 100644 --- a/Fp/Division.lean +++ b/Fp/Division.lean @@ -259,6 +259,15 @@ def BitVec.monus (a : BitVec w) (b : BitVec w) : BitVec w := @[bv_normalize] def BitVec.sge (x y : BitVec w) : Bool := y.sle x +def roundUnpacked + (sig : BitVec sigWidth) + (ex : BitVec exWidth) (mode : RoundingMode) : + -- NOTE: we assume that 'sig' has sticky and guard bits. 
+ -- Given these, we round to get a significand of size 's + 1'. + UnpackedFloat (exponentWidth enew snew) (snew + 1) := + -- We will only do RNE for now. + sorry + @[bv_normalize] def div_on_unpackedFloat (a b : UnpackedFloat (exponentWidth e s) (s + 1)) (mode : RoundingMode) : PackedFloat e s := -- a.toRat = (-1)^a.sign * a.sig.toNat * 2 ^ (a.ex.toInt) * 2 ^(-s) diff --git a/Fp/Rounding.lean b/Fp/Rounding.lean index 153b3fd..9f3a092 100644 --- a/Fp/Rounding.lean +++ b/Fp/Rounding.lean @@ -137,6 +137,73 @@ theorem shouldRoundAway.match_eq_cond : cases m <;> rfl +@[bv_normalize] +def roundNew (mode : RoundingMode) (sig : BitVec sigWidth) (exOffset : BitVec exWidth) + : EUnpackedFloat eout sout := + if x.state = .NaN then + PackedFloat.getNaN _ _ + else if x.state = .Infinity then + PackedFloat.getInfinity _ _ x.num.sign + else + let exOffset' := 2^(exWidth - 1) + sigWidth - 2 + -- trim bitvector + let over := x.num.val >>> (exOffset + 2^(exWidth-1)) + let a := (x.num.val >>> exOffset).truncate (2^(exWidth-1)) + let b := truncateRight exOffset' (x.num.val.truncate exOffset) + let underWidth := exOffset - exOffset' + let under := x.num.val.truncate underWidth + let trimmed := a ++ b + if over != 0 then + -- Overflow to Infinity + -- Unless we're rounding RTN/RTP to the opposite sign, or RTZ + -- in which case we overflow to MAX + if (mode = .RTN ∧ ¬x.num.sign) ∨ (mode = .RTP ∧ x.num.sign) ∨ mode = .RTZ then + PackedFloat.getMax _ _ x.num.sign + else + PackedFloat.getInfinity _ _ x.num.sign + else + let index := fls trimmed + let sigWidthB := BitVec.ofNat _ sigWidth + let ex : BitVec exWidth := + if index ≤ sigWidthB then + 0 + else + (index - sigWidthB).truncate _ + let truncSig : BitVec sigWidth := + if ex = 0 then + trimmed.truncate _ + else + (trimmed >>> (ex - 1)).truncate _ + let rem : BitVec (2^exWidth + underWidth) := + if ex = 0 then + under.truncate _ <<< (1<<