// use_linear/lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4//! Linear-algebra utilities for `RustUse`.
5
6use core::fmt;
7use core::ops::{Add, Mul, Neg, Sub};
8
/// Errors returned by linear helpers when a system cannot be solved.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LinearError {
    /// The matrix is singular and does not have an inverse.
    SingularMatrix {
        /// The determinant of the matrix that was rejected.
        determinant: f64,
    },
}
15
16impl fmt::Display for LinearError {
17    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
18        match self {
19            Self::SingularMatrix { determinant } => {
20                write!(
21                    formatter,
22                    "matrix is singular with determinant {determinant}"
23                )
24            },
25        }
26    }
27}
28
29impl std::error::Error for LinearError {}
30
/// A 2D column vector.
///
/// The derived [`Default`] value has both components set to `0.0`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Vector2 {
    /// The first component.
    pub x: f64,
    /// The second component.
    pub y: f64,
}
39
40impl Vector2 {
41    /// Creates a vector from two components.
42    #[must_use]
43    pub const fn new(x: f64, y: f64) -> Self {
44        Self { x, y }
45    }
46
47    /// Returns the squared Euclidean norm.
48    #[must_use]
49    pub const fn magnitude_squared(self) -> f64 {
50        dot(self, self)
51    }
52
53    /// Returns the Euclidean norm.
54    #[must_use]
55    pub fn magnitude(self) -> f64 {
56        self.magnitude_squared().sqrt()
57    }
58
59    /// Returns the dot product with `other`.
60    #[must_use]
61    pub const fn dot(self, other: Self) -> f64 {
62        dot(self, other)
63    }
64}
65
66impl Add for Vector2 {
67    type Output = Self;
68
69    fn add(self, rhs: Self) -> Self::Output {
70        Self::new(self.x + rhs.x, self.y + rhs.y)
71    }
72}
73
74impl Sub for Vector2 {
75    type Output = Self;
76
77    fn sub(self, rhs: Self) -> Self::Output {
78        Self::new(self.x - rhs.x, self.y - rhs.y)
79    }
80}
81
82impl Neg for Vector2 {
83    type Output = Self;
84
85    fn neg(self) -> Self::Output {
86        Self::new(-self.x, -self.y)
87    }
88}
89
90impl Mul<f64> for Vector2 {
91    type Output = Self;
92
93    fn mul(self, rhs: f64) -> Self::Output {
94        Self::new(self.x * rhs, self.y * rhs)
95    }
96}
97
/// A 2×2 matrix stored in row-major order.
///
/// The derived [`Default`] value is the zero matrix (every entry `0.0`).
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Matrix2 {
    /// Row 1, column 1.
    pub m11: f64,
    /// Row 1, column 2.
    pub m12: f64,
    /// Row 2, column 1.
    pub m21: f64,
    /// Row 2, column 2.
    pub m22: f64,
}
110
111impl Matrix2 {
112    /// Creates a 2×2 matrix from row-major entries.
113    #[must_use]
114    pub const fn new(m11: f64, m12: f64, m21: f64, m22: f64) -> Self {
115        Self { m11, m12, m21, m22 }
116    }
117
118    /// Returns the identity matrix.
119    #[must_use]
120    pub const fn identity() -> Self {
121        Self::new(1.0, 0.0, 0.0, 1.0)
122    }
123
124    /// Returns the transpose.
125    #[must_use]
126    pub const fn transpose(self) -> Self {
127        Self::new(self.m11, self.m21, self.m12, self.m22)
128    }
129
130    /// Returns the determinant.
131    #[must_use]
132    pub const fn determinant(self) -> f64 {
133        (self.m11 * self.m22) - (self.m12 * self.m21)
134    }
135
136    /// Returns the trace.
137    #[must_use]
138    pub const fn trace(self) -> f64 {
139        self.m11 + self.m22
140    }
141
142    /// Returns the matrix-vector product.
143    #[must_use]
144    pub const fn mul_vector(self, vector: Vector2) -> Vector2 {
145        Vector2::new(
146            (self.m11 * vector.x) + (self.m12 * vector.y),
147            (self.m21 * vector.x) + (self.m22 * vector.y),
148        )
149    }
150
151    /// Returns the matrix-matrix product.
152    #[must_use]
153    pub const fn mul_matrix(self, rhs: Self) -> Self {
154        let first_row = Vector2::new(self.m11, self.m12);
155        let second_row = Vector2::new(self.m21, self.m22);
156        let first_column = Vector2::new(rhs.m11, rhs.m21);
157        let second_column = Vector2::new(rhs.m12, rhs.m22);
158
159        Self::new(
160            dot(first_row, first_column),
161            dot(first_row, second_column),
162            dot(second_row, first_column),
163            dot(second_row, second_column),
164        )
165    }
166
167    /// Solves `self * x = rhs` for `x`.
168    ///
169    /// # Errors
170    ///
171    /// Returns [`LinearError::SingularMatrix`] when the determinant is zero.
172    pub fn solve(self, rhs: Vector2) -> Result<Vector2, LinearError> {
173        let determinant = self.determinant();
174
175        if determinant == 0.0 {
176            return Err(LinearError::SingularMatrix { determinant });
177        }
178
179        Ok(Vector2::new(
180            self.m22.mul_add(rhs.x, -(self.m12 * rhs.y)) / determinant,
181            self.m11.mul_add(rhs.y, -(self.m21 * rhs.x)) / determinant,
182        ))
183    }
184}
185
186impl Mul<Vector2> for Matrix2 {
187    type Output = Vector2;
188
189    fn mul(self, rhs: Vector2) -> Self::Output {
190        self.mul_vector(rhs)
191    }
192}
193
194impl Mul for Matrix2 {
195    type Output = Self;
196
197    fn mul(self, rhs: Self) -> Self::Output {
198        self.mul_matrix(rhs)
199    }
200}
201
202/// Returns the dot product of two vectors.
203#[must_use]
204pub const fn dot(left: Vector2, right: Vector2) -> f64 {
205    (left.x * right.x) + (left.y * right.y)
206}
207
208/// Solves `matrix * x = rhs` for `x`.
209///
210/// # Errors
211///
212/// Returns [`LinearError::SingularMatrix`] when the determinant is zero.
213pub fn solve_2x2(matrix: Matrix2, rhs: Vector2) -> Result<Vector2, LinearError> {
214    matrix.solve(rhs)
215}
216
217pub mod prelude;
218
#[cfg(test)]
mod tests {
    use super::{LinearError, Matrix2, Vector2, dot, solve_2x2};

    /// Asserts that two floats differ by less than `1.0e-12`.
    fn assert_close(left: f64, right: f64) {
        assert!((left - right).abs() < 1.0e-12, "left={left}, right={right}");
    }

    #[test]
    fn computes_vector_and_matrix_products() {
        let left = Vector2::new(3.0, 4.0);
        let right = Vector2::new(-2.0, 1.0);
        let matrix = Matrix2::new(2.0, 1.0, 5.0, 3.0);

        // Vector operators.
        assert_eq!(left + right, Vector2::new(1.0, 5.0));
        assert_eq!(left - right, Vector2::new(5.0, 3.0));
        assert_eq!(-right, Vector2::new(2.0, -1.0));
        assert_eq!(left * 2.0, Vector2::new(6.0, 8.0));

        // Dot products and norms.
        assert_close(dot(left, right), -2.0);
        assert_close(left.dot(right), -2.0);
        assert_close(left.magnitude_squared(), 25.0);
        assert_close(left.magnitude(), 5.0);

        // Matrix helpers, both as methods and as operators.
        assert_eq!(
            matrix.mul_vector(Vector2::new(1.0, -1.0)),
            Vector2::new(1.0, 2.0)
        );
        assert_eq!(matrix * Vector2::new(1.0, -1.0), Vector2::new(1.0, 2.0));
        assert_eq!(matrix.transpose(), Matrix2::new(2.0, 5.0, 1.0, 3.0));
        assert_close(matrix.trace(), 5.0);
        assert_close(matrix.determinant(), 1.0);
        assert_eq!(matrix * Matrix2::identity(), matrix);
    }

    #[test]
    fn solves_nonsingular_systems_and_rejects_singular_ones() {
        let matrix = Matrix2::new(2.0, 1.0, 5.0, 3.0);
        let rhs = Vector2::new(1.0, 2.0);

        assert_eq!(
            matrix.solve(rhs).expect("system should solve"),
            Vector2::new(1.0, -1.0)
        );
        assert_eq!(
            solve_2x2(matrix, rhs).expect("system should solve"),
            Vector2::new(1.0, -1.0)
        );
        // The second row is twice the first, so the determinant is exactly zero.
        assert_eq!(
            Matrix2::new(1.0, 2.0, 2.0, 4.0).solve(Vector2::new(1.0, 2.0)),
            Err(LinearError::SingularMatrix { determinant: 0.0 })
        );
    }
}