diff --git a/Fp/Basic.lean b/Fp/Basic.lean index 3469fc0..0b92107 100644 --- a/Fp/Basic.lean +++ b/Fp/Basic.lean @@ -805,10 +805,20 @@ def Nat.ceilLog2 (n : Nat) : Nat := def bias (e : Nat) : Nat := 2 ^ (e - 1) - 1 +/-- The minimum value the exponent can take when unbiased. -/ @[bv_normalize] def minNormalExp (e : Nat) : Int := -(bias e - 1) +/-- The max value the exponent can take when unbiased. -/ +@[bv_normalize] +def maxNormalExp (e : Nat) : Int := (bias e) + +/-- The value the subnormal exponent can take. -/ +@[bv_normalize] +def subnormalExp (e : Nat) : Int := + minNormalExp e - 1 + -- This is a simpler (but less tight) bound than `exponentWidth`. -- It's logarithmically larger. @[bv_normalize] diff --git a/Fp/Division.lean b/Fp/Division.lean index ecdcf7d..4ac247d 100644 --- a/Fp/Division.lean +++ b/Fp/Division.lean @@ -255,10 +255,244 @@ def BitVec.monus (a : BitVec w) (b : BitVec w) : BitVec w := if a ≤ b then 0#w else a - b +@[bv_normalize] +def BitVec.addExtending (a : BitVec v) (b : BitVec w) : BitVec (max v w + 1) := + let a' := a.signExtend (max v w + 1) + let b' := b.signExtend (max v w + 1) + a' + b' + +@[bv_normalize] +def BitVec.subExtending (a : BitVec v) (b : BitVec w) : BitVec (max v w + 1) := + let a' := a.signExtend (max v w + 1) + let b' := b.signExtend (max v w + 1) + a' - b' + +@[bv_normalize] +def BitVec.eqExtending (a : BitVec v) (b : BitVec w) : Prop := + let a' := a.signExtend (max v w) + let b' := b.signExtend (max v w) + a' = b' + +@[bv_normalize] +def BitVec.sltExtending (a : BitVec v) (b : BitVec w) : Prop := + let a' := a.signExtend (max v w) + let b' := b.signExtend (max v w) + a'.slt b' + + +instance : Decidable (BitVec.eqExtending a b) := by + unfold BitVec.eqExtending + simp + infer_instance + +instance : Decidable (BitVec.sltExtending a b) := by + unfold BitVec.sltExtending + simp + infer_instance + /-- x ≥ y ↔ y ≤ x-/ @[bv_normalize] def BitVec.sge (x y : BitVec w) : Bool := y.sle x +#check round + + +@[bv_normalize] +def BitVec.dropLsbs (bv : BitVec w) (n : Nat) : BitVec (w - n) := + (bv >>> n).setWidth _ + +@[bv_normalize] +theorem getMsbD_dropLsbs {w n : Nat} (h : n < w) (bv : BitVec w) : + bv.getMsbD i = (bv.dropLsbs n).getMsbD i := by + simp [BitVec.dropLsbs, BitVec.getMsbD] + by_cases hi : i < w + · simp [hi] + sorry + · simp [hi] + sorry + +@[bv_normalize] +def BitVec.dropMsb (bv : BitVec w) (n : Nat) : BitVec (w - n) := + bv.setWidth _ + +@[bv_normalize] +def BitVec.takeMsb (bv : BitVec w) (n : Nat) : BitVec n := + (bv >>> (w - n)).setWidth _ + +@[bv_normalize] +def BitVec.splitAtMsbs (bv : BitVec w) (n : Nat) : BitVec n × BitVec (w - n) := + (bv.takeMsb n, bv.dropMsb n) + +/-- Shift left 'bv' by 'shAmt', extending 'bv' to width 'v' first. -/ +@[bv_normalize] +def BitVec.shlExtending (bv : BitVec w) (shAmt : BitVec v) : BitVec v := + (bv.zeroExtend v) <<< shAmt + +@[bv_normalize] +def BitVec.shrExtending (bv : BitVec w) (shAmt : BitVec v) : BitVec v := + (bv.zeroExtend v) >>> shAmt + + +@[bv_normalize] +def EUnpackedFloat.incrSignificand {e s : Nat} (eu : EUnpackedFloat e s) : EUnpackedFloat e s := + match eu.state with + | .NaN => eu + | .Infinity => eu + | .Number => + if eu.num.sig = BitVec.allOnes _ then + -- we overflowed the significand, so we need to adjust exponent. + let newExp := eu.num.ex + 1 + EUnpackedFloat.mkNumber { + sign := eu.num.sign + ex := newExp + sig := 1#s -- since we overflowed, the new significand is 1.000... (is that right?) + } + else + EUnpackedFloat.mkNumber { + sign := eu.num.sign + ex := eu.num.ex + sig := eu.num.sig + 1 + } + +/- + +The correct way to think of an UnpackedFloat is as a fixed point number, +written in scientific notation. + +So, consider fixed point numbers with three digits, dot right after the first digit (as in normal floating point). +This gives us: + +→ 0.00 0.01 0.10 0.11 +→ 1.00 1.01 1.10 1.11 + +→ 0.00 = _ | 0.01 = 1.00 * 10^(-2) | 0.10 = 1.00 * 10^(-1) | 0.11 = 1.10 * 10^(-1) +→ 1.00 = 1.00 * 10^(0) | 1.01 = 1.01 * 10^(0) | 1.10 = 1.10 * 10^(0) | 1.11 = 1.11 * 10^(0) + +Next, to encode the `1.00`, we write this as: + +→ 0.00 = _ | 0.01 = (100 * 10^(-2)) * 10^(-2) | 0.10 = (100 * 10^(-2)) * 10^(-1) | 0.11 = (110 * 10^(-2)) * 10^(-1) +→ 1.00 = (100 * 10^(-2)) * 10^(0) | 1.01 = (101 * 10^(-2)) * 10^(0) | 1.10 = (110*10^(-2)) * 10^(0) | 1.11 = (111 * 10^(-2)) * 10^(0) + +In this way, we can see our number as a composition of scientific notation with fixed-width exponent. +In the UnpackedFloat, we only store the exponent, not the constant fixed-point shift (10^(-2)), +since this is common to all numbers, and also, morally, it is absorbed into the significand as a fixed-point. +-/ + +/- +#### Why guard and sticky bits: + +Suppose we want to round to 0 bits of precision, using round to nearest even. Then, if we have: + +- 10.1 -> 10 +- 10.5???, we need to know if the ??? is zero or not. + + 10.5000 -> 10 + * If it is zero, then we round to the nearest even, which is 10. + + 10.5001 -> 11 + * If it is non-zero, we *always* round up to 11, since it's strictly closer to 11 than 10. +- 10.9 -> 11 +- 11.5????? + + 11.5000 -> 12 + * If it is zero, then we round to the nearest even, which is 12. + + 11.5001 -> 12 + * If it is non-zero, we *always* round up to 12, since it's strictly closer to 12 than 11. + +Thus, see that just having one sticky bit is enough, since it tells us whether we are at the exact halfway point or not, and that's all we need to know. + + +We will return an EUnpackedFloat, +which can signal NaN/Infinity/Zero, in addition to normal numbers. + + +Rounding normalized numbers +--------------------------- + +Suppose we have 3 digits of precision, and the exponent can go from -9 to 7. +Then, suppose we have an unnormalized number: + +##### Smallest Exponent +- 0.001 * 10^7 + This upon normalization becomes: 1.000 * 10^(7-3) = 1.000 * 10^4 + +- 0.001 * 10^(-9) + This upon normalization becomes: 1.000 * 10^(-9-3) = 1.000 * 10^(-12) + + +So, even though our scale can go from -9 to 7, after normalization, we can have exponents from -12 to 7. +If our exponent is smaller than -12, then we cannot fit it. + +##### Largest Exponent + +- 1.111 * 10^7 | This upon normalization stays the same: 1.111 * 10^7. +-/ + +-- the number is sig.toNat * 2 ^(-sigPrec) * 2 ^ (ex.toInt) +-- choose e = 0, s = 1 then see that after normalization, we e. +-- If our exponent is +@[bv_normalize] +def roundRNEFastUF (inUf : UnpackedFloat e s) : EUnpackedFloat e' s' := + let inUf := inUf.normalize + -- Great, we have a number of the form <1.xxxxx> * 2^(ex) or, 0 * 2^(ex). + if inUf.sig = 0 + -- we are in the <0> case. + then EUnpackedFloat.mkNumber { + sign := inUf.sign + ex := 0#e' + sig := 0#s' + } + else + -- we are in the <1.-----> case. + if hExTooBig : inUf.ex > BitVec.ofInt e (maxNormalExp e') then -- + BitVec.ofInt e s' then + EUnpackedFloat.mkInfinity inUf.sign + -- | TODO: why does this computation fit? In particular, why does adding `s'` not overflow? + -- We may need an extension here. + else if hExTooSmall : (inUf.ex).slt (BitVec.ofInt e (minNormalExp e') + BitVec.ofInt e s') then + -- See example, where `0.001 * 10^-5 -> 1.00 * 10^-7`, which is the smallest exponent we can use. + EUnpackedFloat.mkZero inUf.sign + else + -- Adjust the exponent to be in range. + let adjustedExp : BitVec e' := + if inUf.ex > BitVec.ofInt e (maxNormalExp e') then + BitVec.ofInt e' (maxNormalExp e') + else if inUf.ex < BitVec.ofInt e (minNormalExp e') then + BitVec.ofInt e' (minNormalExp e') + else + inUf.ex.signExtend e' + -- Adjust the significand. + let (truncatedSig, remainder) : BitVec s' × BitVec (s - s') := + if BitVec.eqExtending adjustedExp inUf.ex then + (inUf.sig.splitAtMsbs s') + -- e' < e + else if BitVec.sltExtending adjustedExp inUf.ex then + -- we decreased exponent, so increase significand + -- | TODO: can this overflow? + let sig' := (inUf.sig <<< (BitVec.subExtending inUf.ex adjustedExp)) + sig'.splitAtMsbs s' + else + -- e' > e + -- we increased exponent, so decrease significand + -- | TODO: can this overflow? + let shift := BitVec.subExtending adjustedExp inUf.ex + let sig' := (inUf.sig >>> shift).truncate _ + sig'.splitAtMsbs s' + let guardBit := remainder.getMsbD 0 + let stickyBit : Bool := (remainder.dropMsb 1) ≠ 0 + let isOdd := truncatedSig.getLsbD 0 + let shouldRoundAway : Bool := guardBit && (stickyBit || isOdd) + let ufOut : UnpackedFloat e' s' := + { + sign := inUf.sign + ex := adjustedExp + sig := truncatedSig + } + let eufOut : EUnpackedFloat e' s' := + EUnpackedFloat.mkNumber ufOut + if shouldRoundAway then + -- we should round up. + eufOut.incrSignificand + else + eufOut + + @[bv_normalize] def div_on_unpackedFloat (a b : UnpackedFloat (exponentWidth e s) (s + 1)) (mode : RoundingMode) : PackedFloat e s := -- a.toRat = (-1)^a.sign * a.sig.toNat * 2 ^ (a.ex.toInt) * 2 ^(-s) @@ -280,6 +514,14 @@ def div_on_unpackedFloat (a b : UnpackedFloat (exponentWidth e s) (s + 1)) (mode let expNumerator := a.ex -- if a.ex > 0 then a.ex - 1 else 0 let expDenominator := b.ex -- if b.ex > 0 then b.ex - 1 else 0 -- Shift and round + let ufIn : UnpackedFloat _ _ := { + sign + ex := BitVec.subExtending expNumerator expDenominator + sig := quot_with_sticky + } + roundRNEFastUF ufIn |>.pack + + /- -- | TODO: For the rounding, we still expand out into fixed point. -- We should instead use the cleverer rounder. if expNumerator.sge expDenominator then @@ -317,6 +559,7 @@ def div_on_unpackedFloat (a b : UnpackedFloat (exponentWidth e s) (s + 1)) (mode } } round _ _ mode quot_rshift + -/ /-- Division of two floating-point numbers, rounded to a floating point number @@ -373,4 +616,6 @@ theorem div_one_is_id_on_numerator_ex_smaller (a : PackedFloat 5 2) (h : ¬ a.un theorem div_self_is_one' (a : PackedFloat 5 2) (h : ¬a.isNaN ∧ ¬a.isInfinite ∧ ¬a.isZero) : (div a a .RTZ) = oneE5M2 := by + bv_normalize + simp at *; bv_decide diff --git a/Fp/Tactics/Simps.lean b/Fp/Tactics/Simps.lean index e3bc9c3..66d8467 100644 --- a/Fp/Tactics/Simps.lean +++ b/Fp/Tactics/Simps.lean @@ -24,6 +24,12 @@ open Lean Meta Simp in dsimproc [seval, simp, bv_normalize] reduceLog2 (Nat.log2 _) := Nat.reduceUnary ``Nat.log2 1 Nat.log2 +dsimproc [seval, simp, bv_normalize] reduceMax (Nat.max _ _) := + Nat.reduceBin ``Nat.max 2 Nat.max + +dsimproc [seval, simp, bv_normalize] reduceAdd (Nat.add _ _) := + Nat.reduceBin ``Nat.add 2 Nat.add + open Lean Meta Simp in simproc ↓ [bv_normalize] ite_eq_cond_proc (@ite _ _ _ _ _) := fun e => do let mkApp5 (.const ``ite [u]) α c hc t e := e | return .continue