1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::fmt;
7use core::ops::{Add, Mul, Neg, Sub};
8
/// Errors reported by the linear-system routines in this file.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LinearError {
    /// The coefficient matrix was rejected as singular.
    ///
    /// `determinant` is the determinant value that was computed for the
    /// rejected matrix.
    SingularMatrix { determinant: f64 },
}

impl fmt::Display for LinearError {
    /// Writes a human-readable, single-line description of the error.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive `match` without a wildcard arm: adding a variant later
        // becomes a compile error here instead of a silently missing message.
        match self {
            Self::SingularMatrix { determinant } => {
                write!(formatter, "matrix is singular with determinant {determinant}")
            }
        }
    }
}

// `std::error::Error` is a re-export of `core::error::Error`, so this is the
// same trait; the `core` path matches the `core::fmt` / `core::ops` imports
// used by the rest of the file.
impl core::error::Error for LinearError {}
30
/// A two-component vector of `f64` values.
///
/// `Default` yields the zero vector.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Vector2 {
    /// First component.
    pub x: f64,
    /// Second component.
    pub y: f64,
}

impl Vector2 {
    /// Creates a vector from its two components.
    #[must_use]
    pub const fn new(x: f64, y: f64) -> Self {
        Self { x, y }
    }

    /// Returns the squared Euclidean length, `x * x + y * y`.
    ///
    /// The intermediate squares can overflow to infinity or underflow to zero
    /// for components of extreme magnitude; use [`Self::magnitude`] when the
    /// length itself is what is needed.
    #[must_use]
    pub const fn magnitude_squared(self) -> f64 {
        self.dot(self)
    }

    /// Returns the Euclidean length of the vector.
    #[must_use]
    pub fn magnitude(self) -> f64 {
        // `hypot` avoids the intermediate `x * x + y * y`, which overflows or
        // underflows for extreme components even when the length itself is
        // representable (e.g. components around 1e200 or 1e-200).
        self.x.hypot(self.y)
    }

    /// Returns the dot product of `self` and `other`.
    #[must_use]
    pub const fn dot(self, other: Self) -> f64 {
        (self.x * other.x) + (self.y * other.y)
    }
}
65
66impl Add for Vector2 {
67 type Output = Self;
68
69 fn add(self, rhs: Self) -> Self::Output {
70 Self::new(self.x + rhs.x, self.y + rhs.y)
71 }
72}
73
74impl Sub for Vector2 {
75 type Output = Self;
76
77 fn sub(self, rhs: Self) -> Self::Output {
78 Self::new(self.x - rhs.x, self.y - rhs.y)
79 }
80}
81
82impl Neg for Vector2 {
83 type Output = Self;
84
85 fn neg(self) -> Self::Output {
86 Self::new(-self.x, -self.y)
87 }
88}
89
90impl Mul<f64> for Vector2 {
91 type Output = Self;
92
93 fn mul(self, rhs: f64) -> Self::Output {
94 Self::new(self.x * rhs, self.y * rhs)
95 }
96}
97
/// A 2×2 matrix of `f64` entries.
///
/// Field `mRC` holds the entry in row `R`, column `C`. `Default` yields the
/// all-zero matrix.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct Matrix2 {
    /// Entry in row 1, column 1.
    pub m11: f64,
    /// Entry in row 1, column 2.
    pub m12: f64,
    /// Entry in row 2, column 1.
    pub m21: f64,
    /// Entry in row 2, column 2.
    pub m22: f64,
}
110
111impl Matrix2 {
112 #[must_use]
114 pub const fn new(m11: f64, m12: f64, m21: f64, m22: f64) -> Self {
115 Self { m11, m12, m21, m22 }
116 }
117
118 #[must_use]
120 pub const fn identity() -> Self {
121 Self::new(1.0, 0.0, 0.0, 1.0)
122 }
123
124 #[must_use]
126 pub const fn transpose(self) -> Self {
127 Self::new(self.m11, self.m21, self.m12, self.m22)
128 }
129
130 #[must_use]
132 pub const fn determinant(self) -> f64 {
133 (self.m11 * self.m22) - (self.m12 * self.m21)
134 }
135
136 #[must_use]
138 pub const fn trace(self) -> f64 {
139 self.m11 + self.m22
140 }
141
142 #[must_use]
144 pub const fn mul_vector(self, vector: Vector2) -> Vector2 {
145 Vector2::new(
146 (self.m11 * vector.x) + (self.m12 * vector.y),
147 (self.m21 * vector.x) + (self.m22 * vector.y),
148 )
149 }
150
151 #[must_use]
153 pub const fn mul_matrix(self, rhs: Self) -> Self {
154 let first_row = Vector2::new(self.m11, self.m12);
155 let second_row = Vector2::new(self.m21, self.m22);
156 let first_column = Vector2::new(rhs.m11, rhs.m21);
157 let second_column = Vector2::new(rhs.m12, rhs.m22);
158
159 Self::new(
160 dot(first_row, first_column),
161 dot(first_row, second_column),
162 dot(second_row, first_column),
163 dot(second_row, second_column),
164 )
165 }
166
167 pub fn solve(self, rhs: Vector2) -> Result<Vector2, LinearError> {
173 let determinant = self.determinant();
174
175 if determinant == 0.0 {
176 return Err(LinearError::SingularMatrix { determinant });
177 }
178
179 Ok(Vector2::new(
180 self.m22.mul_add(rhs.x, -(self.m12 * rhs.y)) / determinant,
181 self.m11.mul_add(rhs.y, -(self.m21 * rhs.x)) / determinant,
182 ))
183 }
184}
185
186impl Mul<Vector2> for Matrix2 {
187 type Output = Vector2;
188
189 fn mul(self, rhs: Vector2) -> Self::Output {
190 self.mul_vector(rhs)
191 }
192}
193
194impl Mul for Matrix2 {
195 type Output = Self;
196
197 fn mul(self, rhs: Self) -> Self::Output {
198 self.mul_matrix(rhs)
199 }
200}
201
202#[must_use]
204pub const fn dot(left: Vector2, right: Vector2) -> f64 {
205 (left.x * right.x) + (left.y * right.y)
206}
207
208pub fn solve_2x2(matrix: Matrix2, rhs: Vector2) -> Result<Vector2, LinearError> {
214 matrix.solve(rhs)
215}
216
217pub mod prelude;
218
#[cfg(test)]
mod tests {
    use super::{LinearError, Matrix2, Vector2, dot, solve_2x2};

    /// Asserts that two floats differ by less than `1.0e-12` in absolute value.
    fn assert_close(left: f64, right: f64) {
        let difference = (left - right).abs();
        assert!(difference < 1.0e-12, "left={left}, right={right}");
    }

    /// Exercises the vector operators, the scalar helpers and the matrix
    /// products on small hand-checked inputs.
    #[test]
    fn computes_vector_and_matrix_products() {
        let a = Vector2::new(3.0, 4.0);
        let b = Vector2::new(-2.0, 1.0);
        let m = Matrix2::new(2.0, 1.0, 5.0, 3.0);

        // Vector operators.
        assert_eq!(a + b, Vector2::new(1.0, 5.0));
        assert_eq!(a - b, Vector2::new(5.0, 3.0));
        assert_eq!(-b, Vector2::new(2.0, -1.0));
        assert_eq!(a * 2.0, Vector2::new(6.0, 8.0));

        // Scalar-valued vector helpers.
        assert_close(dot(a, b), -2.0);
        assert_close(a.dot(b), -2.0);
        assert_close(a.magnitude_squared(), 25.0);
        assert_close(a.magnitude(), 5.0);

        // Matrix operations, via both the named method and the operator.
        let image = Vector2::new(1.0, 2.0);
        assert_eq!(m.mul_vector(Vector2::new(1.0, -1.0)), image);
        assert_eq!(m * Vector2::new(1.0, -1.0), image);
        assert_eq!(m.transpose(), Matrix2::new(2.0, 5.0, 1.0, 3.0));
        assert_close(m.trace(), 5.0);
        assert_close(m.determinant(), 1.0);
        assert_eq!(m * Matrix2::identity(), m);
    }

    /// Checks that both solver entry points agree on a solvable system and
    /// that a zero-determinant matrix is reported as singular.
    #[test]
    fn solves_nonsingular_systems_and_rejects_singular_ones() {
        let matrix = Matrix2::new(2.0, 1.0, 5.0, 3.0);
        let rhs = Vector2::new(1.0, 2.0);
        let expected = Vector2::new(1.0, -1.0);

        let from_method = matrix.solve(rhs).expect("system should solve");
        assert_eq!(from_method, expected);

        let from_function = solve_2x2(matrix, rhs).expect("system should solve");
        assert_eq!(from_function, expected);

        // Second row is twice the first, so the determinant is exactly zero.
        let singular = Matrix2::new(1.0, 2.0, 2.0, 4.0);
        let outcome = singular.solve(Vector2::new(1.0, 2.0));
        assert_eq!(
            outcome,
            Err(LinearError::SingularMatrix { determinant: 0.0 })
        );
    }
}