1use std::collections::HashMap;
9
10use crate::text::simple_extensions::{
11 NullabilityHandling as RawNullabilityHandling, Options as RawOptions,
12 ScalarFunction as RawScalarFunction, ScalarFunctionImplsItem as RawImpl, Type as RawType,
13 VariadicBehavior as RawVariadicBehavior, VariadicBehaviorParameterConsistency,
14};
15
16use super::argument::{ArgumentsItem, ArgumentsItemError};
17use super::extensions::TypeContext;
18use super::type_ast::{TypeExpr, TypeParseError};
19use super::types::{ConcreteType, ExtensionTypeError};
20use crate::parse::Parse;
21use thiserror::Error;
22
/// Errors produced while converting a raw (deserialized) scalar function
/// definition into the validated [`ScalarFunction`] model.
#[derive(Debug, Error)]
pub enum ScalarFunctionError {
    /// The function declared an empty list of implementations.
    #[error("Scalar function '{name}' must have at least one implementation")]
    NoImplementations {
        /// Name of the function that had no implementations.
        name: String,
    },
    /// A variadic bound was negative or not a whole number.
    #[error("Variadic behavior {field} must be a non-negative integer, got {value}")]
    InvalidVariadicBehavior {
        /// Which bound was rejected (`"min"` or `"max"`).
        field: String,
        /// The raw value that was rejected.
        value: f64,
    },
    /// The variadic lower bound exceeded the upper bound.
    #[error("Variadic min ({min}) must be less than or equal to max ({max})")]
    VariadicMinGreaterThanMax {
        /// Parsed lower bound.
        min: u32,
        /// Parsed upper bound.
        max: u32,
    },
    /// An argument declaration could not be converted.
    #[error("Argument error: {0}")]
    ArgumentError(#[from] ArgumentsItemError),
    /// A parsed type could not be converted into a concrete type.
    #[error("Type error: {0}")]
    TypeError(#[from] ExtensionTypeError),
    /// A type string could not be parsed.
    #[error("Type parse error: {0}")]
    TypeParseError(#[from] TypeParseError),
    /// The definition uses a feature this crate does not support yet; the
    /// payload names the feature and its tracking issue.
    #[error("Not yet implemented: {0}")]
    NotYetImplemented(String),
}
61
/// A validated scalar function from a simple extension file, together with
/// all of its implementations.
#[derive(Clone, Debug, PartialEq)]
pub struct ScalarFunction {
    /// Function name as declared in the extension.
    pub name: String,
    /// Optional human-readable description.
    pub description: Option<String>,
    /// Implementations of this function. Non-empty when built via `from_raw`,
    /// which rejects definitions without implementations.
    pub impls: Vec<Impl>,
}
72
73impl ScalarFunction {
74 pub(super) fn from_raw(
76 raw: RawScalarFunction,
77 ctx: &mut TypeContext,
78 ) -> Result<Self, ScalarFunctionError> {
79 if raw.impls.is_empty() {
80 return Err(ScalarFunctionError::NoImplementations { name: raw.name });
81 }
82
83 let impls = raw
84 .impls
85 .into_iter()
86 .map(|impl_| Impl::from_raw(impl_, ctx))
87 .collect::<Result<Vec<_>, _>>()?;
88
89 Ok(ScalarFunction {
90 name: raw.name,
91 description: raw.description,
92 impls,
93 })
94 }
95}
96
/// One validated implementation (signature) of a scalar function.
#[derive(Clone, Debug, PartialEq)]
pub struct Impl {
    /// Argument declarations in order; empty when the raw definition has none.
    pub args: Vec<ArgumentsItem>,
    /// Declared options (name -> allowed values); empty when none are declared.
    pub options: Options,
    /// Validated variadic bounds, if the raw definition declares any.
    pub variadic: Option<VariadicBehavior>,
    /// Session-dependence flag; `false` when not specified in the raw definition.
    pub session_dependent: bool,
    /// Determinism flag; `true` when not specified in the raw definition.
    pub deterministic: bool,
    /// Nullability handling mode; [`NullabilityHandling::Mirror`] when not
    /// specified in the raw definition.
    pub nullability: NullabilityHandling,
    /// Concrete return type. Only single-line, non-struct return types without
    /// type variables are currently accepted.
    pub return_type: ConcreteType,
    /// Implementation entries copied from the raw definition as string
    /// key/value pairs; empty when absent.
    /// NOTE(review): key semantics (e.g. engine/language names) are not
    /// established in this file — confirm against the extension schema.
    pub implementation: HashMap<String, String>,
}
131
132impl Impl {
133 pub(super) fn from_raw(
135 raw: RawImpl,
136 ctx: &mut TypeContext,
137 ) -> Result<Self, ScalarFunctionError> {
138 let return_type = match raw.return_.0 {
140 RawType::String(s) => {
141 if s.contains('\n') {
144 return Err(ScalarFunctionError::NotYetImplemented(
145 "Type derivation expressions - issue #449".to_string(),
146 ));
147 }
148 let type_expr = TypeExpr::parse(&s)?;
149 type_expr.visit_references(&mut |name| ctx.linked(name));
150 match ConcreteType::try_from(type_expr) {
151 Ok(concrete) => concrete,
152 Err(ExtensionTypeError::InvalidAnyTypeVariable { .. })
153 | Err(ExtensionTypeError::InvalidParameter(_))
154 | Err(ExtensionTypeError::InvalidParameterKind { .. }) => {
155 return Err(ScalarFunctionError::NotYetImplemented(
158 "Type variables in function signatures - issue #452".to_string(),
159 ));
160 }
161 Err(ExtensionTypeError::UnknownTypeName { name }) => {
162 return Err(ScalarFunctionError::TypeError(
163 ExtensionTypeError::UnknownTypeName { name },
164 ));
165 }
166 Err(e) => return Err(ScalarFunctionError::TypeError(e)),
167 }
168 }
169 RawType::Object(_) => {
170 return Err(ScalarFunctionError::NotYetImplemented(
173 "Struct return types - issue #450".to_string(),
174 ));
175 }
176 };
177
178 let variadic = raw.variadic.map(|v| v.try_into()).transpose()?;
179
180 let args = match raw.args {
181 Some(a) => {
182 a.0.into_iter()
183 .map(|raw_arg| raw_arg.parse(ctx))
184 .collect::<Result<Vec<_>, _>>()?
185 }
186 None => Vec::new(),
187 };
188
189 Ok(Impl {
190 args,
191 options: raw.options.as_ref().map(Options::from).unwrap_or_default(),
192 variadic,
193 session_dependent: raw.session_dependent.map(|b| b.0).unwrap_or(false),
194 deterministic: raw.deterministic.map(|b| b.0).unwrap_or(true),
195 nullability: raw
196 .nullability
197 .map(Into::into)
198 .unwrap_or(NullabilityHandling::Mirror),
199 return_type,
200 implementation: raw
201 .implementation
202 .map(|i| i.0.into_iter().collect())
203 .unwrap_or_default(),
204 })
205 }
206}
207
/// Validated variadic bounds for a function implementation.
///
/// When produced by the `TryFrom<RawVariadicBehavior>` conversion, `min <= max`
/// holds whenever `max` is present.
#[derive(Clone, Debug, PartialEq)]
pub struct VariadicBehavior {
    /// Lower bound; `0` when the raw definition omits it.
    pub min: u32,
    /// Upper bound; `None` when the raw definition omits it.
    pub max: Option<u32>,
    /// Parameter consistency setting, if the raw definition specifies one.
    pub parameter_consistency: Option<ParameterConsistency>,
}
225
/// Parameter consistency mode for variadic arguments.
///
/// This is a fieldless value type, so it is `Copy`, and it is `Eq`/`Hash` so
/// it can be compared exactly and used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ParameterConsistency {
    /// Variadic arguments must all resolve to the same concrete type
    /// (per the Substrait extension spec — confirm).
    Consistent,
    /// Variadic arguments may each resolve to a different concrete type
    /// (per the Substrait extension spec — confirm).
    Inconsistent,
}
244
245impl From<VariadicBehaviorParameterConsistency> for ParameterConsistency {
246 fn from(raw: VariadicBehaviorParameterConsistency) -> Self {
247 match raw {
248 VariadicBehaviorParameterConsistency::Consistent => ParameterConsistency::Consistent,
249 VariadicBehaviorParameterConsistency::Inconsistent => {
250 ParameterConsistency::Inconsistent
251 }
252 }
253 }
254}
255
256impl TryFrom<RawVariadicBehavior> for VariadicBehavior {
257 type Error = ScalarFunctionError;
258
259 fn try_from(raw: RawVariadicBehavior) -> Result<Self, Self::Error> {
260 fn parse_bound(value: f64, field: &str) -> Result<u32, ScalarFunctionError> {
261 if value < 0.0 || value.fract() != 0.0 {
262 return Err(ScalarFunctionError::InvalidVariadicBehavior {
263 field: field.to_string(),
264 value,
265 });
266 }
267 Ok(value as u32)
268 }
269
270 let min = raw
271 .min
272 .map(|v| parse_bound(v, "min"))
273 .transpose()?
274 .unwrap_or(0);
275 let max = raw.max.map(|v| parse_bound(v, "max")).transpose()?;
276
277 if let Some(max_val) = max {
278 if min > max_val {
279 return Err(ScalarFunctionError::VariadicMinGreaterThanMax { min, max: max_val });
280 }
281 }
282
283 Ok(VariadicBehavior {
284 min,
285 max,
286 parameter_consistency: raw.parameter_consistency.map(Into::into),
287 })
288 }
289}
290
/// Nullability handling mode declared by a function implementation.
///
/// This is a fieldless value type, so it is `Copy`, and it is `Eq`/`Hash` so
/// it can be compared exactly and used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum NullabilityHandling {
    /// Output nullability mirrors the inputs' nullability
    /// (per the Substrait extension spec — confirm).
    Mirror,
    /// Output nullability is taken from the declared return type
    /// (per the Substrait extension spec — confirm).
    DeclaredOutput,
    /// Argument and return nullability are matched exactly as declared
    /// (per the Substrait extension spec — confirm).
    Discrete,
}
301
302impl From<RawNullabilityHandling> for NullabilityHandling {
303 fn from(raw: RawNullabilityHandling) -> Self {
304 match raw {
305 RawNullabilityHandling::Mirror => NullabilityHandling::Mirror,
306 RawNullabilityHandling::DeclaredOutput => NullabilityHandling::DeclaredOutput,
307 RawNullabilityHandling::Discrete => NullabilityHandling::Discrete,
308 }
309 }
310}
311
/// Options declared by a function implementation: option name mapped to the
/// list of allowed values, in declaration order.
///
/// `Eq` is derived because the inner map's equality is total.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct Options(pub HashMap<String, Vec<String>>);
315
316impl From<&RawOptions> for Options {
317 fn from(raw: &RawOptions) -> Self {
318 Options(
319 raw.0
320 .iter()
321 .map(|(k, v)| (k.clone(), v.values.clone()))
322 .collect(),
323 )
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variadic_invalid_values() {
        // Each malformed bound must be reported against the field it came from,
        // not just produce some error.
        let invalid_bounds = [
            (Some(-1.0), None, "min", "negative min"),
            (None, Some(-2.5), "max", "negative max"),
            (Some(7.2), None, "min", "non-integer min"),
            (None, Some(3.5), "max", "non-integer max"),
            (Some(f64::NAN), None, "min", "NaN min"),
            (None, Some(f64::INFINITY), "max", "infinite max"),
        ];

        for (min, max, expected_field, description) in invalid_bounds {
            let raw = RawVariadicBehavior {
                min,
                max,
                parameter_consistency: None,
            };
            match VariadicBehavior::try_from(raw) {
                Err(ScalarFunctionError::InvalidVariadicBehavior { field, .. }) => {
                    assert_eq!(
                        field, expected_field,
                        "wrong field reported for {}",
                        description
                    );
                }
                other => panic!(
                    "expected InvalidVariadicBehavior for {}, got {:?}",
                    description, other
                ),
            }
        }

        // Bounds that are individually valid but out of order get their own error.
        let raw = RawVariadicBehavior {
            min: Some(5.0),
            max: Some(3.0),
            parameter_consistency: None,
        };
        assert!(matches!(
            VariadicBehavior::try_from(raw),
            Err(ScalarFunctionError::VariadicMinGreaterThanMax { min: 5, max: 3 })
        ));
    }

    #[test]
    fn test_variadic_valid() {
        let raw = RawVariadicBehavior {
            min: Some(1.0),
            max: Some(5.0),
            parameter_consistency: None,
        };
        let result = VariadicBehavior::try_from(raw).unwrap();
        assert_eq!(result.min, 1);
        assert_eq!(result.max, Some(5));
    }

    #[test]
    fn test_variadic_none_values() {
        let raw = RawVariadicBehavior {
            min: None,
            max: None,
            parameter_consistency: None,
        };
        let result = VariadicBehavior::try_from(raw).unwrap();
        assert_eq!(result.min, 0);
        assert_eq!(result.max, None);
    }

    #[test]
    fn test_no_implementations_error() {
        use crate::text::simple_extensions::ScalarFunction as RawScalarFunction;

        let raw = RawScalarFunction {
            name: "empty_function".to_string(),
            description: None,
            impls: vec![],
        };

        let mut ctx = super::super::extensions::TypeContext::default();
        let result = ScalarFunction::from_raw(raw, &mut ctx);

        assert!(matches!(
            result,
            Err(ScalarFunctionError::NoImplementations { name })
            if name == "empty_function"
        ));
    }

    #[test]
    fn test_scalar_function_with_single_impl() {
        use crate::text::simple_extensions::{
            ReturnValue, ScalarFunction as RawScalarFunction, ScalarFunctionImplsItem, Type,
        };

        let raw = RawScalarFunction {
            name: "add".to_string(),
            description: Some("Addition function".to_string()),
            impls: vec![ScalarFunctionImplsItem {
                args: None,
                options: None,
                variadic: None,
                session_dependent: None,
                deterministic: None,
                nullability: None,
                return_: ReturnValue(Type::String("i32".to_string())),
                implementation: None,
            }],
        };

        let mut ctx = super::super::extensions::TypeContext::default();
        let result = ScalarFunction::from_raw(raw, &mut ctx).unwrap();

        assert_eq!(result.name, "add");
        assert_eq!(result.description, Some("Addition function".to_string()));
        assert_eq!(result.impls.len(), 1);

        use super::super::types::{BasicBuiltinType, ConcreteTypeKind};
        let return_type = &result.impls[0].return_type;
        assert!(!return_type.nullable, "i32 should not be nullable");
        assert!(matches!(
            &return_type.kind,
            ConcreteTypeKind::Builtin(BasicBuiltinType::I32)
        ));
    }

    #[test]
    fn test_options_conversion() {
        use crate::text::simple_extensions::{Options as RawOptions, OptionsValue};
        use indexmap::IndexMap;

        let mut raw_map = IndexMap::new();
        raw_map.insert(
            "overflow".to_string(),
            OptionsValue {
                values: vec!["SILENT".to_string(), "ERROR".to_string()],
                description: None,
            },
        );

        let raw = RawOptions(raw_map);
        let options = Options::from(&raw);

        assert_eq!(options.0.len(), 1);
        assert_eq!(
            options.0.get("overflow").unwrap(),
            &vec!["SILENT".to_string(), "ERROR".to_string()]
        );
    }
}