support multiplication

This commit is contained in:
numzero 2025-11-02 02:56:15 +03:00
parent 2ea94bfda4
commit 7ff3893aff
2 changed files with 32 additions and 7 deletions

View File

@ -41,21 +41,21 @@ impl F8 {
Self((e << M_BITS) | m)
}
fn merge_unbias(in_m: u8, in_e: i8) -> Self {
fn merge_unbias(in_m: u32, in_e: i32) -> Self {
if in_m == 0 {
return Self(0);
}
let base_e = in_m.ilog2() as u8;
let off = base_e as i8 - M_BITS as i8;
let m = if off >= 0 { in_m >> off } else { in_m << -off };
let e = (base_e as i8) + in_e + (E_BIAS as i8);
let e = (base_e as i32) + in_e + (E_BIAS as i32);
if e < 0 {
return Self(0);
}
if e > E_STORAGE_MAX as i8 {
if e > E_STORAGE_MAX as i32 {
return Self(0xff);
}
Self::merge(m & M_STORAGE_MAX, e as u8)
Self::merge(m as u8 & M_STORAGE_MAX, e as u8)
}
}
@ -163,7 +163,7 @@ mod tests {
assert_eq!(f32::from(F8::merge_unbias(3, 1)), 6.0);
assert_eq!(f32::from(F8::merge_unbias(3, -1)), 1.5);
assert_eq!(
f32::from(F8::merge_unbias(EXACT_INT_MAX, 0)),
f32::from(F8::merge_unbias(EXACT_INT_MAX.into(), 0)),
EXACT_INT_MAX as f32
);
}

View File

@ -14,7 +14,7 @@ impl std::ops::Add for F8 {
}
let e_diff = e1 - e2;
let m = m1 + (m2 >> e_diff);
Self::merge_unbias(m, e1)
Self::merge_unbias(m.into(), e1.into())
}
}
@ -34,7 +34,20 @@ impl std::ops::Sub for F8 {
let Some(m) = m1.checked_sub(m2 >> e_diff) else {
return Self(0);
};
Self::merge_unbias(m, e1)
Self::merge_unbias(m.into(), e1.into())
}
}
impl std::ops::Mul for F8 {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
let (m1, e1) = self.split_unbias();
let (m2, e2) = rhs.split_unbias();
if self.0 == 0 || rhs.0 == 0 {
return Self(0);
}
Self::merge_unbias(m1 as u32 * m2 as u32, e1 as i32 + e2 as i32)
}
}
@ -63,4 +76,16 @@ mod tests {
assert_eq!(F8::from(7) - F8::from(3), F8::from(4));
assert_eq!(F8::from(12) - F8::from(1), F8::from(11));
}
#[test]
fn test_mul() {
assert_eq!(F8::from(3) * F8::from(0), F8::from(0));
assert_eq!(F8::from(0) * F8::from(7), F8::from(0));
assert_eq!(F8::from(3) * F8::from(5), F8::from(15));
assert_eq!(F8::from(5) * F8::from(3), F8::from(15));
assert_eq!(F8::from(3) * F8::from(6), F8::from(18));
assert_eq!(F8::from(3) * F8::from(7), F8::from(21));
assert_eq!(F8::from(12) * F8::from(1), F8::from(12));
assert_eq!(F8::from(12) * F8::from(2), F8::from(24));
}
}