Skip to main content

use_ai_provider/
lib.rs

1#![forbid(unsafe_code)]
2#![doc = include_str!("../README.md")]
3
4use core::{fmt, str::FromStr};
5use std::error::Error;
6
7pub mod prelude {
8    pub use crate::{
9        AiApiMode, AiBillingUnit, AiEndpointKind, AiEndpointName, AiProviderError, AiProviderId,
10        AiProviderKind, AiProviderName, AiQuotaKind, AiRateLimitKind, AiRegionKind,
11    };
12}
13
/// Generates a validated text newtype named `$name`.
///
/// The generated type stores its text with surrounding whitespace trimmed and
/// rejects empty or whitespace-only input with [`AiProviderError::Empty`].
/// Every constructor (`new`, `FromStr`, `TryFrom<&str>`, `TryFrom<String>`)
/// goes through the same `non_empty_text` check, so all of them agree.
macro_rules! provider_text_newtype {
    ($name:ident) => {
        #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub struct $name(String);

        impl $name {
            /// Builds a value from `value` with surrounding whitespace trimmed.
            ///
            /// # Errors
            ///
            /// Returns [`AiProviderError::Empty`] if `value` is empty or
            /// contains only whitespace.
            pub fn new(value: impl AsRef<str>) -> Result<Self, AiProviderError> {
                non_empty_text(value).map(Self)
            }

            /// Borrows the stored (already trimmed) text.
            #[must_use]
            pub fn as_str(&self) -> &str {
                &self.0
            }

            /// Alias of [`Self::as_str`].
            #[must_use]
            pub fn value(&self) -> &str {
                self.as_str()
            }

            /// Consumes the value and returns the owned text.
            #[must_use]
            pub fn into_string(self) -> String {
                self.0
            }
        }

        impl AsRef<str> for $name {
            fn as_ref(&self) -> &str {
                self.as_str()
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = AiProviderError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                Self::new(value)
            }
        }

        impl TryFrom<&str> for $name {
            type Error = AiProviderError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Lets callers holding an owned `String` convert without first borrowing it.
        impl TryFrom<String> for $name {
            type Error = AiProviderError;

            fn try_from(value: String) -> Result<Self, Self::Error> {
                Self::new(value)
            }
        }

        // Standard-trait counterpart of `into_string`.
        impl From<$name> for String {
            fn from(value: $name) -> Self {
                value.into_string()
            }
        }
    };
}
66
/// Generates a fieldless, `Copy` label enum named `$name`.
///
/// Each `Variant => "label"` pair declares a variant and its canonical text
/// label. Declaration order is significant: it fixes the order of `ALL` and
/// the derived `Ord`. Parsing normalizes input with `normalized_label`, then
/// fails with [`AiProviderError::Empty`] for blank text and
/// [`AiProviderError::UnknownLabel`] for text matching no label.
macro_rules! provider_enum {
    ($name:ident { $($variant:ident => $label:literal),+ $(,)? }) => {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        pub enum $name {
            $($variant),+
        }

        impl $name {
            /// Every variant, in declaration order.
            pub const ALL: &'static [Self] = &[$(Self::$variant),+];

            /// Returns the canonical label of this variant.
            #[must_use]
            pub const fn as_str(self) -> &'static str {
                match self {
                    $(Self::$variant => $label),+
                }
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str(self.as_str())
            }
        }

        impl FromStr for $name {
            type Err = AiProviderError;

            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match normalized_label(value)?.as_str() {
                    $($label => Ok(Self::$variant),)+
                    _ => Err(AiProviderError::UnknownLabel),
                }
            }
        }

        // Mirrors the `TryFrom<&str>` offered by the text newtypes so every
        // metadata type can be built from text through the same trait.
        impl TryFrom<&str> for $name {
            type Error = AiProviderError;

            fn try_from(value: &str) -> Result<Self, Self::Error> {
                value.parse()
            }
        }
    };
}
102
103provider_text_newtype!(AiProviderName);
104provider_text_newtype!(AiProviderId);
105provider_text_newtype!(AiEndpointName);
106
// Category of AI provider.
// Declaration order is significant: it fixes the order of `AiProviderKind::ALL`
// and the derived `Ord`. Each label is the text written by `Display` and
// matched by `FromStr`.
provider_enum!(AiProviderKind {
    HostedApi => "hosted-api",
    CloudPlatform => "cloud-platform",
    LocalRuntime => "local-runtime",
    OpenSourceHost => "open-source-host",
    InternalGateway => "internal-gateway",
    Custom => "custom",
});
115
// Kind of AI provider endpoint.
// Declaration order is significant: it fixes the order of `AiEndpointKind::ALL`
// and the derived `Ord`. Each label is the text written by `Display` and
// matched by `FromStr`.
provider_enum!(AiEndpointKind {
    Chat => "chat",
    Completion => "completion",
    Responses => "responses",
    Embedding => "embedding",
    Rerank => "rerank",
    Image => "image",
    Audio => "audio",
    Realtime => "realtime",
    Batch => "batch",
    Moderation => "moderation",
    Custom => "custom",
});
129
// Interaction mode of an AI API.
// Declaration order is significant: it fixes the order of `AiApiMode::ALL` and
// the derived `Ord`. Each label is the text written by `Display` and matched by
// `FromStr`.
provider_enum!(AiApiMode {
    Sync => "sync",
    Async => "async",
    Streaming => "streaming",
    Batch => "batch",
    Realtime => "realtime",
});
137
// Kind of rate limit. Most labels name the counted quantity and its window
// (e.g. `requests-per-minute`).
// Declaration order is significant: it fixes the order of
// `AiRateLimitKind::ALL` and the derived `Ord`.
provider_enum!(AiRateLimitKind {
    RequestsPerMinute => "requests-per-minute",
    TokensPerMinute => "tokens-per-minute",
    ImagesPerMinute => "images-per-minute",
    ConcurrentRequests => "concurrent-requests",
    Custom => "custom",
});
145
// Kind of usage quota.
// `Unknown` is an ordinary variant parsed only from the label "unknown"; text
// matching no label still fails with `AiProviderError::UnknownLabel`.
// Declaration order fixes `AiQuotaKind::ALL` and the derived `Ord`.
provider_enum!(AiQuotaKind {
    HardLimit => "hard-limit",
    SoftLimit => "soft-limit",
    Burst => "burst",
    Trial => "trial",
    Unknown => "unknown",
});
153
// Unit in which AI usage is billed.
// Declaration order is significant: it fixes the order of `AiBillingUnit::ALL`
// and the derived `Ord`. Each label is the text written by `Display` and
// matched by `FromStr`.
provider_enum!(AiBillingUnit {
    InputToken => "input-token",
    OutputToken => "output-token",
    CachedToken => "cached-token",
    Request => "request",
    Image => "image",
    AudioSecond => "audio-second",
    ComputeSecond => "compute-second",
    Custom => "custom",
});
164
// Kind of deployment region.
// `Unknown` is an ordinary variant parsed only from the label "unknown"; text
// matching no label still fails with `AiProviderError::UnknownLabel`.
// Declaration order fixes `AiRegionKind::ALL` and the derived `Ord`.
provider_enum!(AiRegionKind {
    Global => "global",
    Regional => "regional",
    Local => "local",
    Edge => "edge",
    Unknown => "unknown",
});
172
/// Error returned when constructing or parsing AI provider metadata fails.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AiProviderError {
    /// The input was empty or contained only whitespace.
    Empty,
    /// The normalized input matched none of the target enum's labels.
    UnknownLabel,
}

impl fmt::Display for AiProviderError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            AiProviderError::Empty => "AI provider metadata text cannot be empty",
            AiProviderError::UnknownLabel => "unknown AI provider metadata label",
        };
        formatter.write_str(message)
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl Error for AiProviderError {}
189
190fn non_empty_text(value: impl AsRef<str>) -> Result<String, AiProviderError> {
191    let trimmed = value.as_ref().trim();
192    if trimmed.is_empty() {
193        Err(AiProviderError::Empty)
194    } else {
195        Ok(trimmed.to_string())
196    }
197}
198
199fn normalized_label(value: &str) -> Result<String, AiProviderError> {
200    let trimmed = value.trim();
201    if trimmed.is_empty() {
202        Err(AiProviderError::Empty)
203    } else {
204        Ok(trimmed.to_ascii_lowercase().replace(['_', ' '], "-"))
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::{
        AiApiMode, AiBillingUnit, AiEndpointKind, AiEndpointName, AiProviderError, AiProviderId,
        AiProviderKind, AiProviderName, AiQuotaKind, AiRateLimitKind, AiRegionKind,
    };
    use core::{fmt, str::FromStr};

    /// Exercises every accessor and conversion of a text newtype for one value.
    ///
    /// The constructor input is padded with spaces so trimming is checked too.
    /// Must be expanded inside a function returning `Result<_, AiProviderError>`.
    macro_rules! assert_text_newtype {
        ($type:ty, $value:literal) => {{
            let value = <$type>::new(concat!(" ", $value, " "))?;
            assert_eq!(value.as_str(), $value);
            assert_eq!(value.value(), $value);
            assert_eq!(value.as_ref(), $value);
            assert_eq!(value.to_string(), $value);
            assert_eq!($value.parse::<$type>()?, value);
            assert_eq!(<$type as TryFrom<&str>>::try_from($value)?, value);
            assert_eq!(<$type>::new("  \t "), Err(AiProviderError::Empty));
            assert_eq!(value.into_string(), $value.to_string());
        }};
    }

    /// Checks that every variant round-trips through `Display`/`FromStr`,
    /// including the separator, case and padding variations `FromStr` accepts.
    /// Also checks the two error paths of `FromStr`.
    fn assert_enum_family<T>(variants: &[T]) -> Result<(), AiProviderError>
    where
        T: Copy + Eq + fmt::Debug + fmt::Display + FromStr<Err = AiProviderError>,
    {
        for variant in variants {
            let label = variant.to_string();
            assert_eq!(label.parse::<T>()?, *variant);
            assert_eq!(label.replace('-', "_").parse::<T>()?, *variant);
            assert_eq!(label.replace('-', " ").parse::<T>()?, *variant);
            assert_eq!(label.to_ascii_uppercase().parse::<T>()?, *variant);
            assert_eq!(format!(" {label}\t").parse::<T>()?, *variant);
        }
        assert_eq!("".parse::<T>(), Err(AiProviderError::Empty));
        assert_eq!("   ".parse::<T>(), Err(AiProviderError::Empty));
        assert_eq!(
            "no-such-label".parse::<T>(),
            Err(AiProviderError::UnknownLabel)
        );
        Ok(())
    }

    #[test]
    fn validates_provider_text_newtypes() -> Result<(), AiProviderError> {
        assert_text_newtype!(AiProviderName, "local-runtime");
        assert_text_newtype!(AiProviderId, "provider-001");
        assert_text_newtype!(AiEndpointName, "chat-primary");
        assert_eq!(AiProviderName::new("  "), Err(AiProviderError::Empty));
        assert_eq!(AiProviderId::new(""), Err(AiProviderError::Empty));
        Ok(())
    }

    #[test]
    fn displays_and_parses_provider_enums() -> Result<(), AiProviderError> {
        assert_enum_family(AiProviderKind::ALL)?;
        assert_enum_family(AiEndpointKind::ALL)?;
        assert_enum_family(AiApiMode::ALL)?;
        assert_enum_family(AiRateLimitKind::ALL)?;
        assert_enum_family(AiQuotaKind::ALL)?;
        assert_enum_family(AiBillingUnit::ALL)?;
        assert_enum_family(AiRegionKind::ALL)?;
        assert_eq!(
            "hosted api".parse::<AiProviderKind>()?,
            AiProviderKind::HostedApi
        );
        Ok(())
    }

    #[test]
    fn rejects_labels_from_other_families() {
        // "chat" is a valid `AiEndpointKind` label but not an `AiProviderKind` one.
        assert_eq!(
            "chat".parse::<AiProviderKind>(),
            Err(AiProviderError::UnknownLabel)
        );
        assert_eq!(
            "hosted-api".parse::<AiEndpointKind>(),
            Err(AiProviderError::UnknownLabel)
        );
    }

    #[test]
    fn describes_errors() {
        assert_eq!(
            AiProviderError::Empty.to_string(),
            "AI provider metadata text cannot be empty"
        );
        assert_eq!(
            AiProviderError::UnknownLabel.to_string(),
            "unknown AI provider metadata label"
        );
    }
}