Skip to main content

use_ml_metric/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        MlClassificationMetric, MlClusteringMetric, MlGenerationMetric, MlMetricAggregation,
10        MlMetricDirection, MlMetricError, MlMetricKind, MlMetricName, MlMetricValue,
11        MlRankingMetric, MlRegressionMetric,
12    };
13}
14
15#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
16pub struct MlMetricName(String);
17
18impl MlMetricName {
19    pub fn new(value: impl AsRef<str>) -> Result<Self, MlMetricError> {
20        non_empty_text(value).map(Self)
21    }
22
23    pub fn as_str(&self) -> &str {
24        &self.0
25    }
26}
27
28impl AsRef<str> for MlMetricName {
29    fn as_ref(&self) -> &str {
30        self.as_str()
31    }
32}
33
34impl fmt::Display for MlMetricName {
35    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
36        formatter.write_str(self.as_str())
37    }
38}
39
40impl FromStr for MlMetricName {
41    type Err = MlMetricError;
42
43    fn from_str(value: &str) -> Result<Self, Self::Err> {
44        Self::new(value)
45    }
46}
47
48impl TryFrom<&str> for MlMetricName {
49    type Error = MlMetricError;
50
51    fn try_from(value: &str) -> Result<Self, Self::Error> {
52        Self::new(value)
53    }
54}
55
56#[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
57pub struct MlMetricValue(f64);
58
59impl MlMetricValue {
60    pub fn new(value: f64) -> Result<Self, MlMetricError> {
61        if value.is_finite() {
62            Ok(Self(value))
63        } else {
64            Err(MlMetricError::NonFinite)
65        }
66    }
67
68    pub const fn value(self) -> f64 {
69        self.0
70    }
71}
72
// Declares a fieldless `Copy` enum in which every variant is paired with one
// canonical text label, and generates the label plumbing for it:
//
// * `as_str`  - variant to label (usable in `const` contexts),
// * `Display` - writes that same label,
// * `FromStr` - label to variant, after the input has been passed through
//   `normalized_label`.
//
// Labels must therefore be spelled the way `normalized_label` emits them,
// otherwise the variant can never be parsed back.
macro_rules! metric_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant,)+
        }

        impl $name {
            /// Returns the canonical label of this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label,)+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(Self::as_str(*self))
            }
        }

        impl FromStr for $name {
            type Err = MlMetricError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                // Normalization failures (empty input) propagate unchanged.
                let label = normalized_label(value)?;
                $(
                    if label == $label {
                        return Ok(Self::$variant);
                    }
                )+
                Err(MlMetricError::UnknownLabel)
            }
        }
    };
}
106
107metric_enum!(MlMetricKind {
108    Classification => "classification",
109    Regression => "regression",
110    Ranking => "ranking",
111    Clustering => "clustering",
112    Forecasting => "forecasting",
113    Generation => "generation",
114    Retrieval => "retrieval",
115    Calibration => "calibration",
116    Fairness => "fairness",
117    Performance => "performance",
118    Resource => "resource",
119    Other => "other",
120});
121
122metric_enum!(MlMetricDirection {
123    HigherIsBetter => "higher-is-better",
124    LowerIsBetter => "lower-is-better",
125    TargetIsBest => "target-is-best",
126    Unknown => "unknown",
127});
128
129metric_enum!(MlMetricAggregation {
130    Mean => "mean",
131    Median => "median",
132    Min => "min",
133    Max => "max",
134    Sum => "sum",
135    WeightedMean => "weighted-mean",
136    Macro => "macro",
137    Micro => "micro",
138    Samples => "samples",
139    None => "none",
140});
141
142metric_enum!(MlClassificationMetric {
143    Accuracy => "accuracy",
144    Precision => "precision",
145    Recall => "recall",
146    F1 => "f1",
147    RocAuc => "roc-auc",
148    PrAuc => "pr-auc",
149    LogLoss => "log-loss",
150    MatthewsCorrelationCoefficient => "matthews-correlation-coefficient",
151    BalancedAccuracy => "balanced-accuracy",
152});
153
154impl MlClassificationMetric {
155    pub const fn direction(self) -> MlMetricDirection {
156        match self {
157            Self::LogLoss => MlMetricDirection::LowerIsBetter,
158            Self::Accuracy
159            | Self::Precision
160            | Self::Recall
161            | Self::F1
162            | Self::RocAuc
163            | Self::PrAuc
164            | Self::MatthewsCorrelationCoefficient
165            | Self::BalancedAccuracy => MlMetricDirection::HigherIsBetter,
166        }
167    }
168}
169
170metric_enum!(MlRegressionMetric {
171    Mae => "mae",
172    Mse => "mse",
173    Rmse => "rmse",
174    R2 => "r2",
175    Mape => "mape",
176    Smape => "smape",
177    MedianAbsoluteError => "median-absolute-error",
178});
179
180impl MlRegressionMetric {
181    pub const fn direction(self) -> MlMetricDirection {
182        match self {
183            Self::R2 => MlMetricDirection::HigherIsBetter,
184            Self::Mae
185            | Self::Mse
186            | Self::Rmse
187            | Self::Mape
188            | Self::Smape
189            | Self::MedianAbsoluteError => MlMetricDirection::LowerIsBetter,
190        }
191    }
192}
193
194metric_enum!(MlRankingMetric {
195    Ndcg => "ndcg",
196    Map => "map",
197    Mrr => "mrr",
198    HitRate => "hit-rate",
199    RecallAtK => "recall-at-k",
200    PrecisionAtK => "precision-at-k",
201});
202
203impl MlRankingMetric {
204    pub const fn direction(self) -> MlMetricDirection {
205        MlMetricDirection::HigherIsBetter
206    }
207}
208
209metric_enum!(MlClusteringMetric {
210    Silhouette => "silhouette",
211    AdjustedRandIndex => "adjusted-rand-index",
212    NormalizedMutualInfo => "normalized-mutual-info",
213    DaviesBouldin => "davies-bouldin",
214});
215
216impl MlClusteringMetric {
217    pub const fn direction(self) -> MlMetricDirection {
218        match self {
219            Self::DaviesBouldin => MlMetricDirection::LowerIsBetter,
220            Self::Silhouette | Self::AdjustedRandIndex | Self::NormalizedMutualInfo => {
221                MlMetricDirection::HigherIsBetter
222            },
223        }
224    }
225}
226
227metric_enum!(MlGenerationMetric {
228    Bleu => "bleu",
229    Rouge => "rouge",
230    Meteor => "meteor",
231    BertScore => "bert-score",
232    ExactMatch => "exact-match",
233    Perplexity => "perplexity",
234});
235
236impl MlGenerationMetric {
237    pub const fn direction(self) -> MlMetricDirection {
238        match self {
239            Self::Perplexity => MlMetricDirection::LowerIsBetter,
240            Self::Bleu | Self::Rouge | Self::Meteor | Self::BertScore | Self::ExactMatch => {
241                MlMetricDirection::HigherIsBetter
242            },
243        }
244    }
245}
246
/// Reasons a piece of metric metadata can be rejected.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlMetricError {
    /// Text was empty once surrounding whitespace had been trimmed.
    Empty,
    /// A numeric value was NaN or infinite.
    NonFinite,
    /// Text did not match any label of the enum being parsed.
    UnknownLabel,
}

impl fmt::Display for MlMetricError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the message first so there is a single write site.
        let message = match *self {
            Self::Empty => "ML metric metadata text cannot be empty",
            Self::NonFinite => "ML metric value must be finite",
            Self::UnknownLabel => "unknown ML metric metadata label",
        };
        formatter.write_str(message)
    }
}

// No source error to expose, so the trait's default methods suffice.
impl Error for MlMetricError {}
265
266fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlMetricError> {
267    let trimmed = value.as_ref().trim();
268    if trimmed.is_empty() {
269        Err(MlMetricError::Empty)
270    } else {
271        Ok(trimmed.to_string())
272    }
273}
274
275fn normalized_label(value: &str) -> Result<String, MlMetricError> {
276    let trimmed = value.trim();
277    if trimmed.is_empty() {
278        Err(MlMetricError::Empty)
279    } else {
280        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::{
        MlClassificationMetric, MlMetricDirection, MlMetricError, MlMetricName, MlMetricValue,
        MlRankingMetric, MlRegressionMetric,
    };

    // Names are trimmed, finite values pass through, and invalid input is rejected.
    #[test]
    fn validates_metric_names_and_values() -> Result<(), MlMetricError> {
        // Accepted input.
        assert_eq!(MlMetricName::new(" accuracy ")?.as_str(), "accuracy");
        assert_eq!(MlMetricValue::new(0.93)?.value(), 0.93);

        // Rejected input.
        assert_eq!(MlMetricName::new("  "), Err(MlMetricError::Empty));
        assert_eq!(MlMetricValue::new(f64::NAN), Err(MlMetricError::NonFinite));
        Ok(())
    }

    // Labels parse leniently (spaces for hyphens) and directions match the catalogue.
    #[test]
    fn displays_parses_and_labels_metric_directions() -> Result<(), MlMetricError> {
        let roc_auc: MlClassificationMetric = "roc auc".parse()?;
        let precision_at_k: MlRankingMetric = "precision at k".parse()?;
        let rmse: MlRegressionMetric = "rmse".parse()?;

        assert_eq!(roc_auc, MlClassificationMetric::RocAuc);
        assert_eq!(precision_at_k, MlRankingMetric::PrecisionAtK);
        assert_eq!(rmse, MlRegressionMetric::Rmse);

        let expectations = [
            (
                MlClassificationMetric::Accuracy.direction(),
                MlMetricDirection::HigherIsBetter,
            ),
            (
                MlClassificationMetric::LogLoss.direction(),
                MlMetricDirection::LowerIsBetter,
            ),
            (
                MlRegressionMetric::R2.direction(),
                MlMetricDirection::HigherIsBetter,
            ),
            (
                MlRegressionMetric::Rmse.direction(),
                MlMetricDirection::LowerIsBetter,
            ),
        ];
        for (actual, expected) in expectations {
            assert_eq!(actual, expected);
        }
        Ok(())
    }
}
335}