1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8 pub use crate::{
9 AiApiMode, AiBillingUnit, AiEndpointKind, AiEndpointName, AiProviderError, AiProviderId,
10 AiProviderKind, AiProviderName, AiQuotaKind, AiRateLimitKind, AiRegionKind,
11 };
12}
13
/// Generates a validated, owned text newtype for provider metadata.
///
/// The generated type wraps a `String` whose surrounding whitespace has been
/// trimmed and which is guaranteed to be non-empty (validation is delegated to
/// `non_empty_text`). It exposes borrowed access (`as_str`, `value`,
/// `AsRef<str>`), `Display`, parsing through `FromStr` and `TryFrom<&str>`,
/// and conversion back into `String` through `into_string` or `From`.
macro_rules! provider_text_newtype {
    ($name:ident) => {
        #[doc = concat!(
            "Non-empty, whitespace-trimmed provider metadata text (`",
            stringify!($name),
            "`)."
        )]
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Creates a value from `value`, trimming surrounding whitespace.
            ///
            /// # Errors
            ///
            /// Returns [`AiProviderError::Empty`] if `value` is empty or
            /// consists only of whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, AiProviderError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the validated text.
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Borrows the validated text; an alias of `as_str`.
            pub fn value(&self) -> &str {
                self.as_str()
            }

            /// Consumes the value and returns the owned text.
            pub fn into_string(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = AiProviderError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = AiProviderError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Lets callers write `String::from(name)` / `name.into()` instead of
        // reaching for the inherent `into_string`.
        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.into_string()
            }
        }
    };
}
66
/// Generates a closed, `Copy` classification enum whose variants map
/// one-to-one onto kebab-case labels.
///
/// For each `Variant => "label"` pair the generated type gets:
/// * an `ALL` slice listing every variant in declaration order,
/// * `as_str` / `Display` producing the label,
/// * `FromStr` / `TryFrom<&str>` accepting the label after it has been run
///   through `normalized_label` (trimmed, ASCII-lowercased, with `_` or spaces
///   accepted in place of `-`).
///
/// Declaration order is significant: it fixes both the order of `ALL` and the
/// derived `Ord`.
macro_rules! provider_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[doc = concat!(
            "Provider metadata classification (`",
            stringify!($name),
            "`), displayed and parsed as kebab-case labels."
        )]
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $(
                #[doc = concat!("Label: `", $label, "`.")]
                $variant
            ),+
        }

        impl $name {
            /// Every variant, in declaration order (which is also `Ord` order).
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical kebab-case label of this variant.
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = AiProviderError;

            /// Parses a label; fails with `AiProviderError::Empty` for blank
            /// input and `AiProviderError::UnknownLabel` when no variant matches.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(AiProviderError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` impl that `provider_text_newtype!`
        // generates, so every metadata type converts from `&str` the same way.
        impl TryFrom<&str> for $name {
            type Error = AiProviderError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
102
103provider_text_newtype!(AiProviderName);
104provider_text_newtype!(AiProviderId);
105provider_text_newtype!(AiEndpointName);
106
// Provider categories. Each variant displays as, and parses from, the
// kebab-case label on its right. Variant order is significant: it fixes the
// order of `AiProviderKind::ALL` and the derived `Ord`.
provider_enum!(AiProviderKind {
    HostedApi => "hosted-api",
    CloudPlatform => "cloud-platform",
    LocalRuntime => "local-runtime",
    OpenSourceHost => "open-source-host",
    InternalGateway => "internal-gateway",
    Custom => "custom",
});
115
// Endpoint categories. Labels such as "batch" and "realtime" also appear in
// `AiApiMode`; that is fine because each enum parses its own labels
// independently. Variant order fixes `AiEndpointKind::ALL` and the derived `Ord`.
provider_enum!(AiEndpointKind {
    Chat => "chat",
    Completion => "completion",
    Responses => "responses",
    Embedding => "embedding",
    Rerank => "rerank",
    Image => "image",
    Audio => "audio",
    Realtime => "realtime",
    Batch => "batch",
    Moderation => "moderation",
    Custom => "custom",
});
129
// API invocation modes, one label per variant. Variant order fixes
// `AiApiMode::ALL` and the derived `Ord`.
provider_enum!(AiApiMode {
    Sync => "sync",
    Async => "async",
    Streaming => "streaming",
    Batch => "batch",
    Realtime => "realtime",
});
137
// Rate-limit categories, named after what each limit counts. `Custom` is the
// only variant without a fixed meaning here. Variant order fixes
// `AiRateLimitKind::ALL` and the derived `Ord`.
provider_enum!(AiRateLimitKind {
    RequestsPerMinute => "requests-per-minute",
    TokensPerMinute => "tokens-per-minute",
    ImagesPerMinute => "images-per-minute",
    ConcurrentRequests => "concurrent-requests",
    Custom => "custom",
});
145
// Quota categories. `Unknown` is an ordinary, parseable variant (label
// "unknown"), not a parse-failure sentinel: unrecognised labels still fail with
// `AiProviderError::UnknownLabel`. Variant order fixes `AiQuotaKind::ALL` and
// the derived `Ord`.
provider_enum!(AiQuotaKind {
    HardLimit => "hard-limit",
    SoftLimit => "soft-limit",
    Burst => "burst",
    Trial => "trial",
    Unknown => "unknown",
});
153
// Units a provider may bill by. Variant order fixes `AiBillingUnit::ALL` and
// the derived `Ord`.
provider_enum!(AiBillingUnit {
    InputToken => "input-token",
    OutputToken => "output-token",
    CachedToken => "cached-token",
    Request => "request",
    Image => "image",
    AudioSecond => "audio-second",
    ComputeSecond => "compute-second",
    Custom => "custom",
});
164
// Deployment-region scopes. As with `AiQuotaKind`, `Unknown` is a real,
// parseable variant (label "unknown"). Variant order fixes `AiRegionKind::ALL`
// and the derived `Ord`.
provider_enum!(AiRegionKind {
    Global => "global",
    Regional => "regional",
    Local => "local",
    Edge => "edge",
    Unknown => "unknown",
});
172
/// Error returned when building or parsing provider metadata fails.
///
/// `Hash` is derived (like every other public type in this crate) so errors
/// can be collected into hash-based sets and maps.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum AiProviderError {
    /// The input was empty or contained only whitespace.
    Empty,
    /// The input was non-blank but matched none of the target enum's labels.
    UnknownLabel,
}

impl fmt::Display for AiProviderError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Empty => "AI provider metadata text cannot be empty",
            Self::UnknownLabel => "unknown AI provider metadata label",
        };
        formatter.write_str(message)
    }
}

// No variant wraps an underlying cause, so the default `source()` (`None`)
// is correct.
impl Error for AiProviderError {}
189
190fn non_empty_text(value: impl AsRef<str>) -> Result<String, AiProviderError> {
191 let trimmed = value.as_ref().trim();
192 if trimmed.is_empty() {
193 Err(AiProviderError::Empty)
194 } else {
195 Ok(trimmed.to_string())
196 }
197}
198
199fn normalized_label(value: &str) -> Result<String, AiProviderError> {
200 let trimmed = value.trim();
201 if trimmed.is_empty() {
202 Err(AiProviderError::Empty)
203 } else {
204 Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::{
        AiApiMode, AiBillingUnit, AiEndpointKind, AiEndpointName, AiProviderError, AiProviderId,
        AiProviderKind, AiProviderName, AiQuotaKind, AiRateLimitKind, AiRegionKind,
    };
    use core::{fmt, str::FromStr};

    // Exercises every accessor and conversion of a text newtype, starting from
    // padded input to confirm that surrounding whitespace is trimmed.
    macro_rules! assert_text_newtype {
        ($type:ty, $value:literal) => {{
            let value = <$type>::new(concat!(" ", $value, " "))?;
            assert_eq!(value.as_str(), $value);
            assert_eq!(value.value(), $value);
            assert_eq!(value.as_ref(), $value);
            assert_eq!(value.to_string(), $value);
            assert_eq!(<$type>::from_str($value)?, value);
            assert_eq!(<$type as TryFrom<&str>>::try_from($value)?, value);
            assert_eq!(value.into_string(), $value.to_string());
        }};
    }

    /// Checks that every variant round-trips through its label (including
    /// upper-case, underscore, and space spellings) and that `ALL` is a
    /// non-empty, strictly increasing list.
    fn assert_enum_family<T>(variants: &[T]) -> Result<(), AiProviderError>
    where
        T: Copy + Ord + fmt::Debug + fmt::Display + FromStr<Err = AiProviderError>,
    {
        assert!(!variants.is_empty());
        // `ALL` is generated in declaration order, which is also the derived
        // `Ord` order, so it must be strictly increasing (hence duplicate-free).
        assert!(variants.windows(2).all(|pair| pair[0] < pair[1]));
        for variant in variants {
            let label = variant.to_string();
            assert_eq!(label.parse::<T>()?, *variant);
            assert_eq!(label.to_ascii_uppercase().parse::<T>()?, *variant);
            assert_eq!(label.replace('-', "_").parse::<T>()?, *variant);
            assert_eq!(label.replace('-', " ").parse::<T>()?, *variant);
        }
        Ok(())
    }

    #[test]
    fn validates_provider_text_newtypes() -> Result<(), AiProviderError> {
        assert_text_newtype!(AiProviderName, "local-runtime");
        assert_text_newtype!(AiProviderId, "provider-001");
        assert_text_newtype!(AiEndpointName, "chat-primary");
        assert_eq!(AiProviderName::new(" "), Err(AiProviderError::Empty));
        Ok(())
    }

    #[test]
    fn displays_and_parses_provider_enums() -> Result<(), AiProviderError> {
        assert_enum_family(AiProviderKind::ALL)?;
        assert_enum_family(AiEndpointKind::ALL)?;
        assert_enum_family(AiApiMode::ALL)?;
        assert_enum_family(AiRateLimitKind::ALL)?;
        assert_enum_family(AiQuotaKind::ALL)?;
        assert_enum_family(AiBillingUnit::ALL)?;
        assert_enum_family(AiRegionKind::ALL)?;
        assert_eq!(
            "hosted api".parse::<AiProviderKind>()?,
            AiProviderKind::HostedApi
        );
        Ok(())
    }

    #[test]
    fn rejects_blank_and_unknown_input() {
        assert_eq!(AiProviderId::new(""), Err(AiProviderError::Empty));
        assert_eq!(AiEndpointName::from_str("\t\n"), Err(AiProviderError::Empty));
        assert_eq!("".parse::<AiApiMode>(), Err(AiProviderError::Empty));
        assert_eq!("   ".parse::<AiRegionKind>(), Err(AiProviderError::Empty));
        assert_eq!(
            "not-a-mode".parse::<AiApiMode>(),
            Err(AiProviderError::UnknownLabel)
        );
        assert_eq!(
            "hosted api v2".parse::<AiProviderKind>(),
            Err(AiProviderError::UnknownLabel)
        );
    }

    #[test]
    fn error_messages_are_stable() {
        assert_eq!(
            AiProviderError::Empty.to_string(),
            "AI provider metadata text cannot be empty"
        );
        assert_eq!(
            AiProviderError::UnknownLabel.to_string(),
            "unknown AI provider metadata label"
        );
    }
}