1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8 pub use crate::{
9 MlModelCard, MlModelCardAudience, MlModelCardDatasetRef, MlModelCardError,
10 MlModelCardEvaluationSummary, MlModelCardIntendedUse, MlModelCardLimitation,
11 MlModelCardName, MlModelCardOwner, MlModelCardRisk, MlModelCardSection,
12 };
13}
14
15#[derive(Clone, Debug, Eq, PartialEq)]
16pub struct MlModelCard {
17 name: MlModelCardName,
18 sections: Vec<MlModelCardSection>,
19}
20
21impl MlModelCard {
22 pub fn new(name: MlModelCardName) -> Self {
23 Self {
24 name,
25 sections: Vec::new(),
26 }
27 }
28
29 pub fn name(&self) -> &MlModelCardName {
30 &self.name
31 }
32
33 pub fn sections(&self) -> &[MlModelCardSection] {
34 &self.sections
35 }
36
37 pub fn with_section(mut self, section: MlModelCardSection) -> Self {
38 self.sections.push(section);
39 self
40 }
41}
42
/// Generates a `String` newtype holding non-empty, trimmed model-card text.
///
/// Optional leading attributes (e.g. `///` docs) are forwarded to the generated
/// struct. The expansion relies on `fmt`, `FromStr`, `MlModelCardError`, and
/// `non_empty_text` being in scope at the invocation site.
macro_rules! model_card_text_newtype {
    ($(#[$meta:meta])* $name:ident) => {
        $(#[$meta])*
        ///
        /// Holds non-empty text. Surrounding whitespace is trimmed on construction,
        /// and empty or whitespace-only input is rejected with `MlModelCardError::Empty`.
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Validates `value` and stores it with surrounding whitespace trimmed.
            ///
            /// # Errors
            ///
            /// Returns `MlModelCardError::Empty` if `value` is empty or whitespace-only.
            pub fn new(value: impl AsRef<str>) -> Result<Self, MlModelCardError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the stored (trimmed) text.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Consumes the wrapper and returns the stored (trimmed) text.
            #[must_use]
            pub fn into_inner(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlModelCardError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = MlModelCardError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl TryFrom<String> for $name {
            type Error = MlModelCardError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.0
            }
        }
    };
}
87
/// Generates a `Copy` label enum whose variants map one-to-one to string labels.
///
/// Optional attributes (e.g. `///` docs) may precede the type name and each
/// variant; they are forwarded to the generated enum and its variants.
/// `FromStr` compares the output of `normalized_label` against the labels.
/// That function lowercases its input and uses `-` as the separator, so labels
/// should be written in that form or they can never be parsed. The expansion
/// relies on `fmt`, `FromStr`, `MlModelCardError`, and `normalized_label` being
/// in scope at the invocation site.
macro_rules! model_card_enum {
    (
        $(#[$meta:meta])*
        $name:ident {
            $( $(#[$variant_meta:meta])* $variant:ident => $label:literal ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        ///
        /// Each variant has a fixed string label, exposed through `as_str`, `Display`,
        /// `FromStr`, and `TryFrom<&str>`.
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $( $(#[$variant_meta])* $variant ),+
        }

        impl $name {
            /// Every variant, in declaration order (also the derived `Ord` order).
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label for this variant.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = MlModelCardError;

            /// Parses a label after normalizing it with `normalized_label`.
            ///
            /// Returns `MlModelCardError::Empty` for blank input and
            /// `MlModelCardError::UnknownLabel` when no variant's label matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(MlModelCardError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` impl of the text newtypes so every
        // metadata type can be built the same way.
        impl TryFrom<&str> for $name {
            type Error = MlModelCardError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
121
122model_card_text_newtype!(MlModelCardName);
123model_card_text_newtype!(MlModelCardLimitation);
124model_card_text_newtype!(MlModelCardEvaluationSummary);
125model_card_text_newtype!(MlModelCardDatasetRef);
126model_card_text_newtype!(MlModelCardOwner);
127
// Section kinds that an `MlModelCard` can list (see `MlModelCard::sections`).
// The string after each `=>` is the variant's canonical label. `as_str` and
// `Display` return it, and `FromStr` compares it against the normalized input.
// NOTE: variant order drives the derived `Ord`/`PartialOrd`, so append new
// variants rather than reordering existing ones.
model_card_enum!(MlModelCardSection {
    Overview => "overview",
    IntendedUse => "intended-use",
    Factors => "factors",
    Metrics => "metrics",
    EvaluationData => "evaluation-data",
    TrainingData => "training-data",
    EthicalConsiderations => "ethical-considerations",
    CaveatsAndRecommendations => "caveats-and-recommendations",
    Limitations => "limitations",
    Contact => "contact",
    License => "license",
    Other => "other",
});
142
// Audience label set. `MlModelCard` itself stores only a name and sections,
// so these labels are carried separately by callers.
// The string after each `=>` is the canonical label used by `as_str`,
// `Display`, and `FromStr`. Variant order drives the derived `Ord`, so
// append new variants instead of reordering.
model_card_enum!(MlModelCardAudience {
    Developer => "developer",
    Researcher => "researcher",
    Operator => "operator",
    EndUser => "end-user",
    Auditor => "auditor",
    Regulator => "regulator",
    Public => "public",
    Internal => "internal",
});
153
// Intended-use label set. The string after each `=>` is the canonical label
// used by `as_str`, `Display`, and `FromStr`. Variant order drives the derived
// `Ord`, so append new variants instead of reordering.
model_card_enum!(MlModelCardIntendedUse {
    Research => "research",
    Production => "production",
    Education => "education",
    Evaluation => "evaluation",
    Demo => "demo",
    InternalTooling => "internal-tooling",
    Other => "other",
});
163
// Risk-category label set. The string after each `=>` is the canonical label
// used by `as_str`, `Display`, and `FromStr`. Variant order drives the derived
// `Ord`, so append new variants instead of reordering.
model_card_enum!(MlModelCardRisk {
    Bias => "bias",
    Privacy => "privacy",
    Security => "security",
    Safety => "safety",
    Misuse => "misuse",
    Performance => "performance",
    DataQuality => "data-quality",
    DistributionShift => "distribution-shift",
    Other => "other",
});
175
/// Error produced when model-card metadata fails validation or parsing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MlModelCardError {
    /// The supplied text was empty or contained only whitespace.
    Empty,
    /// The supplied label did not match any variant of the target label enum.
    UnknownLabel,
}

impl fmt::Display for MlModelCardError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            MlModelCardError::Empty => "ML model-card metadata text cannot be empty",
            MlModelCardError::UnknownLabel => "unknown ML model-card metadata label",
        };
        f.write_str(message)
    }
}

// Both variants are leaf errors, so the default `source()` (None) is correct.
impl Error for MlModelCardError {}
192
193fn non_empty_text(value: impl AsRef<str>) -> Result<String, MlModelCardError> {
194 let trimmed = value.as_ref().trim();
195 if trimmed.is_empty() {
196 Err(MlModelCardError::Empty)
197 } else {
198 Ok(trimmed.to_string())
199 }
200}
201
202fn normalized_label(value: &str) -> Result<String, MlModelCardError> {
203 let trimmed = value.trim();
204 if trimmed.is_empty() {
205 Err(MlModelCardError::Empty)
206 } else {
207 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::{
        MlModelCard, MlModelCardAudience, MlModelCardError, MlModelCardIntendedUse,
        MlModelCardName, MlModelCardRisk, MlModelCardSection,
    };

    #[test]
    fn validates_model_card_names_and_builds_cards() -> Result<(), MlModelCardError> {
        let card = MlModelCard::new(MlModelCardName::new(" baseline-card ")?)
            .with_section(MlModelCardSection::Overview);

        assert_eq!(card.name().as_str(), "baseline-card");
        assert_eq!(card.sections(), &[MlModelCardSection::Overview]);
        assert_eq!(MlModelCardName::new(" "), Err(MlModelCardError::Empty));
        Ok(())
    }

    #[test]
    fn displays_and_parses_model_card_enums() -> Result<(), MlModelCardError> {
        assert_eq!(
            "ethical considerations".parse::<MlModelCardSection>()?,
            MlModelCardSection::EthicalConsiderations
        );
        assert_eq!(
            "end user".parse::<MlModelCardAudience>()?,
            MlModelCardAudience::EndUser
        );
        assert_eq!(
            "data quality".parse::<MlModelCardRisk>()?,
            MlModelCardRisk::DataQuality
        );
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_labels() {
        assert_eq!(
            "   ".parse::<MlModelCardSection>(),
            Err(MlModelCardError::Empty)
        );
        assert_eq!("".parse::<MlModelCardRisk>(), Err(MlModelCardError::Empty));
        assert_eq!(
            "not-a-section".parse::<MlModelCardSection>(),
            Err(MlModelCardError::UnknownLabel)
        );
    }

    #[test]
    fn enum_labels_round_trip_through_display() -> Result<(), MlModelCardError> {
        let section = MlModelCardSection::CaveatsAndRecommendations;
        assert_eq!(section.to_string(), "caveats-and-recommendations");
        assert_eq!(section.to_string().parse::<MlModelCardSection>()?, section);
        assert_eq!(
            "INTERNAL_TOOLING".parse::<MlModelCardIntendedUse>()?,
            MlModelCardIntendedUse::InternalTooling
        );
        Ok(())
    }

    #[test]
    fn text_newtypes_convert_and_display_trimmed_text() -> Result<(), MlModelCardError> {
        let name: MlModelCardName = "  card-v2 ".parse()?;
        assert_eq!(name.to_string(), "card-v2");
        assert_eq!(MlModelCardName::try_from("card-v2")?, name);
        assert_eq!(MlModelCardName::new("\t\n"), Err(MlModelCardError::Empty));
        Ok(())
    }

    #[test]
    fn model_card_errors_have_readable_messages() {
        assert_eq!(
            MlModelCardError::Empty.to_string(),
            "ML model-card metadata text cannot be empty"
        );
        assert_eq!(
            MlModelCardError::UnknownLabel.to_string(),
            "unknown ML model-card metadata label"
        );
    }
}