Skip to main content

use_modular/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4//! Small modular arithmetic primitives for `RustUse`.
5
/// Validates `modulus` and widens it to `i128`.
///
/// Returns `None` unless `modulus` is strictly positive.
fn checked_modulus(modulus: i64) -> Option<i128> {
    if modulus <= 0 {
        return None;
    }

    Some(i128::from(modulus))
}
9
10fn normalized_i128(value: i64, modulus: i64) -> Option<i128> {
11    let modulus = checked_modulus(modulus)?;
12    Some(i128::from(value).rem_euclid(modulus))
13}
14
15/// Basic modular arithmetic helpers.
16pub mod arithmetic {
17    use crate::{checked_modulus, normalized_i128};
18
19    /// Normalizes `value` into the residue class `0..modulus`.
20    ///
21    /// Returns `None` when `modulus <= 0`.
22    #[must_use]
23    pub fn mod_normalize(value: i64, modulus: i64) -> Option<i64> {
24        i64::try_from(normalized_i128(value, modulus)?).ok()
25    }
26
27    /// Computes `(a + b) mod modulus` and returns the normalized residue.
28    ///
29    /// Returns `None` when `modulus <= 0`.
30    #[must_use]
31    pub fn mod_add(a: i64, b: i64, modulus: i64) -> Option<i64> {
32        let modulus_i128 = checked_modulus(modulus)?;
33        let sum = normalized_i128(a, modulus)? + normalized_i128(b, modulus)?;
34        i64::try_from(sum.rem_euclid(modulus_i128)).ok()
35    }
36
37    /// Computes `(a - b) mod modulus` and returns the normalized residue.
38    ///
39    /// Returns `None` when `modulus <= 0`.
40    #[must_use]
41    pub fn mod_sub(a: i64, b: i64, modulus: i64) -> Option<i64> {
42        let modulus_i128 = checked_modulus(modulus)?;
43        let difference = normalized_i128(a, modulus)? - normalized_i128(b, modulus)?;
44        i64::try_from(difference.rem_euclid(modulus_i128)).ok()
45    }
46
47    /// Computes `(a * b) mod modulus` and returns the normalized residue.
48    ///
49    /// Uses `i128` internally to reduce overflow risk for large `i64` inputs.
50    /// Returns `None` when `modulus <= 0`.
51    #[must_use]
52    pub fn mod_mul(a: i64, b: i64, modulus: i64) -> Option<i64> {
53        let modulus_i128 = checked_modulus(modulus)?;
54        let product = normalized_i128(a, modulus)? * normalized_i128(b, modulus)?;
55        i64::try_from(product.rem_euclid(modulus_i128)).ok()
56    }
57}
58
59/// Modular exponentiation helpers.
60pub mod power {
61    use crate::{checked_modulus, normalized_i128};
62
63    /// Computes `base.pow(exponent) mod modulus` using exponentiation by squaring.
64    ///
65    /// Returns the normalized residue in `0..modulus`, or `None` when
66    /// `modulus <= 0`.
67    #[must_use]
68    pub fn mod_pow(base: i64, exponent: u64, modulus: i64) -> Option<i64> {
69        let modulus_i128 = checked_modulus(modulus)?;
70        let mut result = i128::from(1 % modulus);
71        let mut factor = normalized_i128(base, modulus)?;
72        let mut power = exponent;
73
74        while power > 0 {
75            if power & 1 == 1 {
76                result = (result * factor).rem_euclid(modulus_i128);
77            }
78
79            factor = (factor * factor).rem_euclid(modulus_i128);
80            power >>= 1;
81        }
82
83        i64::try_from(result).ok()
84    }
85}
86
87/// Modular inverse helpers.
88pub mod inverse {
89    use crate::{checked_modulus, normalized_i128};
90
91    /// Computes the multiplicative inverse of `value` modulo `modulus`.
92    ///
93    /// Returns `Some(inverse)` only when the inverse exists. The returned
94    /// residue is normalized to `0..modulus`. Returns `None` when
95    /// `modulus <= 0` or when `value` and `modulus` are not coprime.
96    #[must_use]
97    pub fn mod_inverse(value: i64, modulus: i64) -> Option<i64> {
98        let modulus_i128 = checked_modulus(modulus)?;
99        let value_i128 = normalized_i128(value, modulus)?;
100        let (gcd, coefficient, _) = extended_gcd(value_i128, modulus_i128);
101
102        (gcd == 1)
103            .then(|| coefficient.rem_euclid(modulus_i128))
104            .and_then(|inverse| i64::try_from(inverse).ok())
105    }
106
107    const fn extended_gcd(left: i128, right: i128) -> (i128, i128, i128) {
108        let (mut old_remainder, mut remainder) = (left, right);
109        let (mut old_left_coefficient, mut left_coefficient) = (1_i128, 0_i128);
110        let (mut old_right_coefficient, mut right_coefficient) = (0_i128, 1_i128);
111
112        while remainder != 0 {
113            let quotient = old_remainder / remainder;
114
115            (old_remainder, remainder) = (remainder, old_remainder - quotient * remainder);
116            (old_left_coefficient, left_coefficient) = (
117                left_coefficient,
118                old_left_coefficient - quotient * left_coefficient,
119            );
120            (old_right_coefficient, right_coefficient) = (
121                right_coefficient,
122                old_right_coefficient - quotient * right_coefficient,
123            );
124        }
125
126        (
127            old_remainder.abs(),
128            old_left_coefficient,
129            old_right_coefficient,
130        )
131    }
132}
133
134/// Modular congruence helpers.
135pub mod congruence {
136    use crate::arithmetic::mod_normalize;
137
138    /// Returns `true` when `a` and `b` are congruent modulo `modulus`.
139    ///
140    /// Returns `false` when `modulus <= 0`.
141    #[must_use]
142    pub fn is_congruent(a: i64, b: i64, modulus: i64) -> bool {
143        match (mod_normalize(a, modulus), mod_normalize(b, modulus)) {
144            (Some(left), Some(right)) => left == right,
145            _ => false,
146        }
147    }
148}
149
150pub use arithmetic::{mod_add, mod_mul, mod_normalize, mod_sub};
151pub use congruence::is_congruent;
152pub use inverse::mod_inverse;
153pub use power::mod_pow;
154
/// A normalized modular residue paired with its positive modulus.
///
/// Equality and hashing take both the residue and the modulus into account,
/// so the same residue under two different moduli counts as two distinct
/// values. Deriving `Hash` alongside `Eq` lets the type be used as a
/// `HashMap` key or `HashSet` element.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Modular {
    /// Residue; `Modular::new` stores it normalized into `0..modulus`.
    value: i64,
    /// Modulus the residue is taken against; `Modular::new` only accepts positive values.
    modulus: i64,
}
161
162impl Modular {
163    /// Creates a normalized modular value.
164    ///
165    /// Returns `None` when `modulus <= 0`.
166    #[must_use]
167    pub fn new(value: i64, modulus: i64) -> Option<Self> {
168        Some(Self {
169            value: mod_normalize(value, modulus)?,
170            modulus,
171        })
172    }
173
174    /// Returns the normalized residue in `0..modulus`.
175    #[must_use]
176    pub const fn value(self) -> i64 {
177        self.value
178    }
179
180    /// Returns the positive modulus carried by this value.
181    #[must_use]
182    pub const fn modulus(self) -> i64 {
183        self.modulus
184    }
185
186    /// Adds two modular values with the same modulus.
187    ///
188    /// Returns `None` when the moduli differ.
189    #[must_use]
190    #[allow(clippy::should_implement_trait)]
191    pub fn add(self, other: Self) -> Option<Self> {
192        let modulus = self.same_modulus(other)?;
193        Self::new(mod_add(self.value, other.value, modulus)?, modulus)
194    }
195
196    /// Subtracts two modular values with the same modulus.
197    ///
198    /// Returns `None` when the moduli differ.
199    #[must_use]
200    #[allow(clippy::should_implement_trait)]
201    pub fn sub(self, other: Self) -> Option<Self> {
202        let modulus = self.same_modulus(other)?;
203        Self::new(mod_sub(self.value, other.value, modulus)?, modulus)
204    }
205
206    /// Multiplies two modular values with the same modulus.
207    ///
208    /// Returns `None` when the moduli differ.
209    #[must_use]
210    #[allow(clippy::should_implement_trait)]
211    pub fn mul(self, other: Self) -> Option<Self> {
212        let modulus = self.same_modulus(other)?;
213        Self::new(mod_mul(self.value, other.value, modulus)?, modulus)
214    }
215
216    /// Raises the modular value to `exponent` using modular exponentiation.
217    #[must_use]
218    pub fn pow(self, exponent: u64) -> Option<Self> {
219        Self::new(mod_pow(self.value, exponent, self.modulus)?, self.modulus)
220    }
221
222    /// Computes the multiplicative inverse when one exists.
223    #[must_use]
224    pub fn inverse(self) -> Option<Self> {
225        Self::new(mod_inverse(self.value, self.modulus)?, self.modulus)
226    }
227
228    const fn same_modulus(self, other: Self) -> Option<i64> {
229        if self.modulus == other.modulus {
230            Some(self.modulus)
231        } else {
232            None
233        }
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::{
        Modular, is_congruent, mod_add, mod_inverse, mod_mul, mod_normalize, mod_pow, mod_sub,
    };

    /// Asserts that every public helper refuses to operate with `modulus`.
    fn assert_modulus_rejected(modulus: i64) {
        assert!(mod_normalize(3, modulus).is_none());
        assert!(mod_add(1, 2, modulus).is_none());
        assert!(mod_sub(1, 2, modulus).is_none());
        assert!(mod_mul(1, 2, modulus).is_none());
        assert!(mod_pow(2, 3, modulus).is_none());
        assert!(mod_inverse(3, modulus).is_none());
        assert!(!is_congruent(1, 1, modulus));
    }

    #[test]
    fn accepts_positive_modulus() {
        for (value, modulus, expected) in [(0, 1, 0), (7, 5, 2)] {
            assert_eq!(mod_normalize(value, modulus), Some(expected));
        }
    }

    #[test]
    fn rejects_zero_modulus() {
        assert_modulus_rejected(0);
    }

    #[test]
    fn rejects_negative_modulus() {
        assert_modulus_rejected(-5);
    }

    #[test]
    fn normalizes_positive_values() {
        let residue = mod_normalize(17, 5);
        assert_eq!(residue, Some(2));
    }

    #[test]
    fn normalizes_negative_values() {
        for (value, expected) in [(-1, 4), (-13, 2)] {
            assert_eq!(mod_normalize(value, 5), Some(expected));
        }
    }

    #[test]
    fn adds_residues() {
        let sum = mod_add(4, 3, 5);
        assert_eq!(sum, Some(2));
    }

    #[test]
    fn subtracts_residues() {
        let difference = mod_sub(2, 4, 5);
        assert_eq!(difference, Some(3));
    }

    #[test]
    fn multiplies_residues() {
        let product = mod_mul(4, 4, 5);
        assert_eq!(product, Some(1));
    }

    #[test]
    fn computes_modular_powers() {
        let power = mod_pow(2, 10, 1_000);
        assert_eq!(power, Some(24));
    }

    #[test]
    fn handles_zero_exponent() {
        // Anything to the power zero is 1, which reduces to 0 under modulus 1.
        for (modulus, expected) in [(5, 1), (1, 0)] {
            assert_eq!(mod_pow(9, 0, modulus), Some(expected));
        }
    }

    #[test]
    fn computes_existing_inverse() {
        let inverse = mod_inverse(3, 11);
        assert_eq!(inverse, Some(4));
    }

    #[test]
    fn reports_missing_inverse() {
        assert!(mod_inverse(2, 4).is_none());
    }

    #[test]
    fn checks_congruence() {
        let congruent = is_congruent(17, 5, 12);
        assert!(congruent);
    }

    #[test]
    fn checks_non_congruence() {
        let congruent = is_congruent(17, 6, 12);
        assert!(!congruent);
    }

    #[test]
    fn multiplies_large_values_with_i128_intermediate() {
        // The square of this operand does not fit in an `i64`.
        let operand = 3_037_000_500_i64;
        let modulus = 97_i64;
        let wide_product = i128::from(operand) * i128::from(operand);
        let expected = i64::try_from(wide_product.rem_euclid(i128::from(modulus))).ok();

        assert_eq!(mod_mul(operand, operand, modulus), expected);
    }

    #[test]
    fn constructs_and_operates_on_modular_values() {
        let minus_one = Modular::new(-1, 5).expect("valid modular value");
        let three = Modular::new(3, 5).expect("valid modular value");
        let other_modulus = Modular::new(1, 7).expect("valid modular value");
        let value_of = |result: Option<Modular>| result.map(Modular::value);

        assert_eq!((minus_one.value(), minus_one.modulus()), (4, 5));
        assert_eq!(value_of(minus_one.add(three)), Some(2));
        assert_eq!(value_of(minus_one.sub(three)), Some(1));
        assert_eq!(value_of(minus_one.mul(three)), Some(2));
        assert_eq!(value_of(three.pow(4)), Some(1));
        assert_eq!(value_of(three.inverse()), Some(2));
        assert!(minus_one.add(other_modulus).is_none());
    }
}