diff --git a/src/ops.rs b/src/ops.rs index 2ce284e..39c8b4c 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -18,6 +18,26 @@ impl std::ops::Add for F8 { } } +impl std::ops::Sub for F8 { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + let (m1, e1) = self.split_unbias(); + let (m2, e2) = rhs.split_unbias(); + if e1 < e2 { + return Self(0); + } + if rhs.0 == 0 { + return self; + } + let e_diff = e1 - e2; + let Some(m) = m1.checked_sub(m2 >> e_diff) else { + return Self(0); + }; + Self::merge_unbias(m, e1) + } +} + #[cfg(test)] mod tests { use super::*; @@ -32,4 +52,15 @@ mod tests { assert_eq!(F8::from(0b0011) + F8::from(0b0111), F8::from(0b1010)); assert_eq!(F8::from(0b1100) + F8::from(0b0001), F8::from(0b1101)); } + + #[test] + fn test_sub() { + assert_eq!(F8::from(0b0011) - F8::from(0b0000), F8::from(0b0011)); + assert_eq!(F8::from(0b0000) - F8::from(0b0111), F8::from(0)); + assert_eq!(F8::from(0b0011) - F8::from(0b0101), F8::from(0)); + assert_eq!(F8::from(0b0101) - F8::from(0b0011), F8::from(0b0010)); + assert_eq!(F8::from(0b0110) - F8::from(0b0011), F8::from(0b0011)); + assert_eq!(F8::from(0b0111) - F8::from(0b0011), F8::from(0b0100)); + assert_eq!(F8::from(0b1100) - F8::from(0b0001), F8::from(0b1011)); + } }