Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 95 additions & 25 deletions Fp/Basic.lean
Original file line number Diff line number Diff line change
@@ -1,6 +1,45 @@
import Fp.Utils
import Fp.ForLean.Dyadic

/-!
# Floating Point Parmeters
-/

@[bv_normalize]
def Nat.ceilLog2 (n : Nat) : Nat :=
if n.log2 * 2 = n then n.log2 else n.log2 + 1

@[bv_normalize]
def bias (e : Nat) : Nat :=
2 ^ (e - 1) - 1

@[bv_normalize]
def minNormalExp (e : Nat) : Int :=
-(bias e - 1)

@[bv_normalize]
def maxNormalExp (e : Nat) : Int := (bias e)

@[bv_normalize]
def subnormalExp (e : Nat) : Int :=
minNormalExp e - 1

@[simp]
theorem subnormalExp_eq_minus_bias (e : Nat) :
subnormalExp e = -(bias e) := by
simp [subnormalExp, minNormalExp, bias]
grind

-- This is a simpler (but less tight) bound than `exponentWidth`.
-- It's logarithmically larger.
@[bv_normalize]
def exponentWidth' (e s : Nat) : Nat :=
e + s.ceilLog2

@[bv_normalize]
def exponentWidth (e s : Nat) : Nat :=
(2 ^ (e - 1) + s - 2).log2 + 2

/-!
## Packed Floating Point Numbers

Expand Down Expand Up @@ -68,7 +107,10 @@ structure FixedPoint (width exOffset : Nat) where
val : BitVec width
-- | This should not be part of the structure, but a side invariant we keep in mind.
hExOffset : exOffset < width
deriving DecidableEq
deriving DecidableEq, Repr

-- TODO: We can create a theory of normalized fixed point numbers, that are either zero,
-- or have MSB = 1.

attribute [bv_normalize] FixedPoint.ext_iff

Expand Down Expand Up @@ -536,6 +578,7 @@ def toRat? (pf : PackedFloat e s) : Option Rat :=

end PackedFloat


/--
`UnpackedFloat e s` is the *working* (unpacked) representation of a floating-point
number with exponent width `e` and significand width `s`.
Expand Down Expand Up @@ -588,8 +631,9 @@ and in hardware floating-point pipelines.
structure UnpackedFloat (e s : Nat) where
sign : Bool
ex : BitVec e
sig : BitVec s
deriving DecidableEq, Inhabited, Repr
/-# 'sig' morally represents a 'FixedPoint s (s - 1)'. -/
sig : BitVec s
deriving DecidableEq, Repr

attribute [bv_normalize] UnpackedFloat.ext_iff

Expand Down Expand Up @@ -687,6 +731,54 @@ def toDyadic (uf : UnpackedFloat e s) : Dyadic :=
def toRat (uf : UnpackedFloat e s) : Rat :=
uf.toDyadic.toRat

def zero (e s : Nat) (sign : Bool) : UnpackedFloat e (s + 1) where
sign := sign
ex := BitVec.ofInt e 0
sig := BitVec.ofNat (s + 1) 0

/-- The minimum subnormal exponent, that shows up after normalizing a subnormal number -/
def normalizedMinSubnormalExp (e s : Nat) : Int :=
subnormalExp e - s

/-- Build the significand from the hidden bit and fraction bits. -/
def mkSigFromHiddenBitAndFrac (hiddenBit : Bool) (fraction : BitVec s) : BitVec s :=
((BitVec.ofBool hiddenBit).zeroExtend s <<< (s - 1)) ||| ((fraction <<< 1) >>> 1)

@[simp]
theorem getLsbD_mkSigFromHiddenBitAndFrac_of_lt {hidden : Bool} {frac : BitVec s} (i : Nat) (hi : i < s - 1):
(mkSigFromHiddenBitAndFrac hidden frac).getLsbD i = if i = s - 1 then hidden else frac.getLsbD i := by
grind [mkSigFromHiddenBitAndFrac]

@[simp]
theorem getElem_mkSigFromHiddenBitAndFrac {hidden : Bool} {frac : BitVec s} {i : Fin s} :
(mkSigFromHiddenBitAndFrac hidden frac)[i] = if i = s - 1 then hidden else frac[i] := by
grind [mkSigFromHiddenBitAndFrac]

/- Miniminum subnormal number that is unpacked-representable. -/
def minSubnormal (e s : Nat) : UnpackedFloat e s where -- s includes the hidden bit.
sign := false
sig := mkSigFromHiddenBitAndFrac false (BitVec.ofNat s 1) -- 0.000...01
ex := BitVec.ofInt e (subnormalExp e)


/-- Maximum subnormal number that is unpacked-representable. -/
def maxSubnormal (e s : Nat) : UnpackedFloat e s where
sign := false
sig := mkSigFromHiddenBitAndFrac false (BitVec.allOnes s) -- 0.111...11
ex := BitVec.ofInt e (subnormalExp e)

/-- Minimum positive normal that is unpacked-representable. -/
def minNormal (e s : Nat) : UnpackedFloat e s where
sign := false
sig := mkSigFromHiddenBitAndFrac true (BitVec.ofNat s 0) -- 1.000...00
ex := BitVec.ofInt e (minNormalExp e)

/-- Maximum normal number that is unpacked-representable. -/
def maxNormal (e s : Nat) : UnpackedFloat e s where
sign := false
sig := mkSigFromHiddenBitAndFrac true (BitVec.allOnes s) -- 1.111...11
ex := maxNormalExp e

end UnpackedFloat

namespace EUnpackedFloat
Expand Down Expand Up @@ -797,28 +889,6 @@ def toRat? (ef : EUnpackedFloat e s) : Option Rat :=

end EUnpackedFloat

@[bv_normalize]
def Nat.ceilLog2 (n : Nat) : Nat :=
if n.log2 * 2 = n then n.log2 else n.log2 + 1

@[bv_normalize]
def bias (e : Nat) : Nat :=
2 ^ (e - 1) - 1

@[bv_normalize]
def minNormalExp (e : Nat) : Int :=
-(bias e - 1)

-- This is a simpler (but less tight) bound than `exponentWidth`.
-- It's logarithmically larger.
@[bv_normalize]
def exponentWidth' (e s : Nat) : Nat :=
e + s.ceilLog2

@[bv_normalize]
def exponentWidth (e s : Nat) : Nat :=
(2 ^ (e - 1) + s - 2).log2 + 2

-- Constants

/-- E5M2 floating point representation of 1.0 -/
Expand Down
9 changes: 9 additions & 0 deletions Fp/Division.lean
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,15 @@ def BitVec.monus (a : BitVec w) (b : BitVec w) : BitVec w :=
@[bv_normalize]
def BitVec.sge (x y : BitVec w) : Bool := y.sle x

def roundUnpacked
(sig : BitVec sigWidth)
(ex : BitVec exWidth) (mode : RoundingMode) :
-- NOTE: we assume that 'sig' has sticky and guard bits.
-- Given these, we round to get a significand of size 's + 1'.
UnpackedFloat (exponentWidth enew snew) (snew + 1) :=
-- We will only do RNE for now.
sorry

@[bv_normalize]
def div_on_unpackedFloat (a b : UnpackedFloat (exponentWidth e s) (s + 1)) (mode : RoundingMode) : PackedFloat e s :=
-- a.toRat = (-1)^a.sign * a.sig.toNat * 2 ^ (a.ex.toInt) * 2 ^(-s)
Expand Down
67 changes: 67 additions & 0 deletions Fp/Rounding.lean
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,73 @@ theorem shouldRoundAway.match_eq_cond :
cases m <;> rfl


@[bv_normalize]
def roundNew (mode : RoundingMode) (sig : BitVec sigWidth) (exOffset : BitVec exWidth)
: EUnpackedFloat eout sout :=
if x.state = .NaN then
PackedFloat.getNaN _ _
else if x.state = .Infinity then
PackedFloat.getInfinity _ _ x.num.sign
else
let exOffset' := 2^(exWidth - 1) + sigWidth - 2
-- trim bitvector
let over := x.num.val >>> (exOffset + 2^(exWidth-1))
let a := (x.num.val >>> exOffset).truncate (2^(exWidth-1))
let b := truncateRight exOffset' (x.num.val.truncate exOffset)
let underWidth := exOffset - exOffset'
let under := x.num.val.truncate underWidth
let trimmed := a ++ b
if over != 0 then
-- Overflow to Infinity
-- Unless we're rounding RTN/RTP to the opposite sign, or RTZ
-- in which case we overflow to MAX
if (mode = .RTN ∧ ¬x.num.sign) ∨ (mode = .RTP ∧ x.num.sign) ∨ mode = .RTZ then
PackedFloat.getMax _ _ x.num.sign
else
PackedFloat.getInfinity _ _ x.num.sign
else
let index := fls trimmed
let sigWidthB := BitVec.ofNat _ sigWidth
let ex : BitVec exWidth :=
if index ≤ sigWidthB then
0
else
(index - sigWidthB).truncate _
let truncSig : BitVec sigWidth :=
if ex = 0 then
trimmed.truncate _
else
(trimmed >>> (ex - 1)).truncate _
let rem : BitVec (2^exWidth + underWidth) :=
if ex = 0 then
under.truncate _ <<< (1<<<exWidth)
else
let totalShift : BitVec (exWidth+1) := ex.truncate _ - 1
truncateRight _ (trimmed <<< ((1<<<exWidth) + sigWidth - 2 - totalShift)) |||
(under.truncate _ <<< ((1<<<exWidth) - totalShift))
if shouldRoundAway mode x.num.sign (truncSig.getLsbD 0) rem then
if truncSig = BitVec.allOnes _ then
-- overflow to next exponent
{
sign := x.num.sign
ex := ex+1
sig := 0
}
else
-- add 1 to significand
{
sign := x.num.sign
ex
sig := truncSig + 1
}
else
-- leave everything the same
{
sign := x.num.sign
ex
sig := truncSig
}

-- Round is less well-behaved when exWidth = 0. This shouldn't be an issue?
/--
Round an extended fixed-point number to its nearest floating point number of
Expand Down
Loading