substrait/parse/text/simple_extensions/
scalar_functions.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Scalar function definitions with validated signatures and resolved types.
4//!
5//! This module provides typed wrappers around scalar functions parsed from extension
6//! YAML files, validating constraints and resolving type strings to concrete types.
7
8use std::collections::HashMap;
9
10use crate::text::simple_extensions::{
11    NullabilityHandling as RawNullabilityHandling, Options as RawOptions,
12    ScalarFunction as RawScalarFunction, ScalarFunctionImplsItem as RawImpl, Type as RawType,
13    VariadicBehavior as RawVariadicBehavior, VariadicBehaviorParameterConsistency,
14};
15
16use super::argument::{ArgumentsItem, ArgumentsItemError};
17use super::extensions::TypeContext;
18use super::type_ast::{TypeExpr, TypeParseError};
19use super::types::{ConcreteType, ExtensionTypeError};
20use crate::parse::Parse;
21use thiserror::Error;
22
23/// Errors that can occur when parsing scalar functions
24#[derive(Debug, Error)]
25pub enum ScalarFunctionError {
26    /// Scalar function has no implementations
27    #[error("Scalar function '{name}' must have at least one implementation")]
28    NoImplementations {
29        /// The function name
30        name: String,
31    },
32    /// Invalid variadic behavior
33    #[error("Variadic behavior {field} must be a non-negative integer, got {value}")]
34    InvalidVariadicBehavior {
35        /// The field that was invalid (min or max)
36        field: String,
37        /// The invalid value
38        value: f64,
39    },
40    /// Variadic min is greater than max
41    #[error("Variadic min ({min}) must be less than or equal to max ({max})")]
42    VariadicMinGreaterThanMax {
43        /// The minimum value
44        min: u32,
45        /// The maximum value
46        max: u32,
47    },
48    /// Error parsing function argument
49    #[error("Argument error: {0}")]
50    ArgumentError(#[from] ArgumentsItemError),
51    /// Error parsing type in function signature
52    #[error("Type error: {0}")]
53    TypeError(#[from] ExtensionTypeError),
54    /// Error parsing type expression
55    #[error("Type parse error: {0}")]
56    TypeParseError(#[from] TypeParseError),
57    /// Feature not yet implemented
58    #[error("Not yet implemented: {0}")]
59    NotYetImplemented(String),
60}
61
62/// A validated scalar function definition with one or more implementations
63#[derive(Clone, Debug, PartialEq)]
64pub struct ScalarFunction {
65    /// Function name
66    pub name: String,
67    /// Human-readable description
68    pub description: Option<String>,
69    /// Function implementations (overloads)
70    pub impls: Vec<Impl>,
71}
72
73impl ScalarFunction {
74    /// Parse a scalar function from raw YAML, resolving types with the provided context
75    pub(super) fn from_raw(
76        raw: RawScalarFunction,
77        ctx: &mut TypeContext,
78    ) -> Result<Self, ScalarFunctionError> {
79        if raw.impls.is_empty() {
80            return Err(ScalarFunctionError::NoImplementations { name: raw.name });
81        }
82
83        let impls = raw
84            .impls
85            .into_iter()
86            .map(|impl_| Impl::from_raw(impl_, ctx))
87            .collect::<Result<Vec<_>, _>>()?;
88
89        Ok(ScalarFunction {
90            name: raw.name,
91            description: raw.description,
92            impls,
93        })
94    }
95}
96
97/// A single function implementation (overload) with signature and resolved types
98#[derive(Clone, Debug, PartialEq)]
99pub struct Impl {
100    /// Function arguments with types and optional names/descriptions
101    pub args: Vec<ArgumentsItem>,
102    /// Configurable function options (e.g., overflow behavior, rounding modes)
103    pub options: Options,
104    /// Variadic argument behavior.
105    ///
106    /// `None` indicates the function is not variadic.
107    pub variadic: Option<VariadicBehavior>,
108    /// Whether the function output depends on session state (e.g., timezone, locale).
109    ///
110    /// Defaults to `false` per the Substrait spec.
111    pub session_dependent: bool,
112    /// Whether the function is deterministic (same inputs always produce same output).
113    ///
114    /// Defaults to `true` per the Substrait spec.
115    pub deterministic: bool,
116    /// How the function handles null inputs and produces nullable outputs.
117    ///
118    /// Defaults to [`NullabilityHandling::Mirror`] per the Substrait spec.
119    pub nullability: NullabilityHandling,
120    /// Return type resolved to a concrete type
121    ///
122    /// The raw YAML type string is parsed and validated. Only concrete types
123    /// (without type variables) are supported; functions with type variables
124    /// are skipped in this basic implementation.
125    pub return_type: ConcreteType,
126    /// Language-specific implementation code (e.g., SQL, C++, Python)
127    ///
128    /// Maps language identifiers to implementation source code snippets.
129    pub implementation: HashMap<String, String>,
130}
131
132impl Impl {
133    /// Parse an implementation from raw YAML, resolving types with the provided context
134    pub(super) fn from_raw(
135        raw: RawImpl,
136        ctx: &mut TypeContext,
137    ) -> Result<Self, ScalarFunctionError> {
138        // Parse and validate the return type
139        let return_type = match raw.return_.0 {
140            RawType::String(s) => {
141                // Multiline strings indicate type derivation expressions
142                // See: https://github.com/substrait-io/substrait-rs/issues/449
143                if s.contains('\n') {
144                    return Err(ScalarFunctionError::NotYetImplemented(
145                        "Type derivation expressions - issue #449".to_string(),
146                    ));
147                }
148                let type_expr = TypeExpr::parse(&s)?;
149                type_expr.visit_references(&mut |name| ctx.linked(name));
150                match ConcreteType::try_from(type_expr) {
151                    Ok(concrete) => concrete,
152                    Err(ExtensionTypeError::InvalidAnyTypeVariable { .. })
153                    | Err(ExtensionTypeError::InvalidParameter(_))
154                    | Err(ExtensionTypeError::InvalidParameterKind { .. }) => {
155                        // Type has type/parameter variables (any1, L1, P, etc.) - not yet supported
156                        // See: https://github.com/substrait-io/substrait-rs/issues/452
157                        return Err(ScalarFunctionError::NotYetImplemented(
158                            "Type variables in function signatures - issue #452".to_string(),
159                        ));
160                    }
161                    Err(ExtensionTypeError::UnknownTypeName { name }) => {
162                        return Err(ScalarFunctionError::TypeError(
163                            ExtensionTypeError::UnknownTypeName { name },
164                        ));
165                    }
166                    Err(e) => return Err(ScalarFunctionError::TypeError(e)),
167                }
168            }
169            RawType::Object(_) => {
170                // Struct return types (YAML syntactic sugar) are not yet supported
171                // See: https://github.com/substrait-io/substrait-rs/issues/450
172                return Err(ScalarFunctionError::NotYetImplemented(
173                    "Struct return types - issue #450".to_string(),
174                ));
175            }
176        };
177
178        let variadic = raw.variadic.map(|v| v.try_into()).transpose()?;
179
180        let args = match raw.args {
181            Some(a) => {
182                a.0.into_iter()
183                    .map(|raw_arg| raw_arg.parse(ctx))
184                    .collect::<Result<Vec<_>, _>>()?
185            }
186            None => Vec::new(),
187        };
188
189        Ok(Impl {
190            args,
191            options: raw.options.as_ref().map(Options::from).unwrap_or_default(),
192            variadic,
193            session_dependent: raw.session_dependent.map(|b| b.0).unwrap_or(false),
194            deterministic: raw.deterministic.map(|b| b.0).unwrap_or(true),
195            nullability: raw
196                .nullability
197                .map(Into::into)
198                .unwrap_or(NullabilityHandling::Mirror),
199            return_type,
200            implementation: raw
201                .implementation
202                .map(|i| i.0.into_iter().collect())
203                .unwrap_or_default(),
204        })
205    }
206}
207
208/// Validated variadic behavior with min/max constraints
209#[derive(Clone, Debug, PartialEq)]
210pub struct VariadicBehavior {
211    /// Minimum number of arguments
212    pub min: u32,
213    /// Maximum number of arguments (unlimited when None)
214    pub max: Option<u32>,
215    /// Whether all variadic parameters must have the same type.
216    ///
217    /// `None` when the parameter is not specified in the YAML. We cannot assume a default
218    /// because the Substrait spec does not define default behavior for missing values
219    /// (see <https://github.com/substrait-io/substrait/issues/928>).
220    ///
221    /// TODO: Once the spec defines default behavior, apply it and change this to a non-Option
222    /// type (see issue #454).
223    pub parameter_consistency: Option<ParameterConsistency>,
224}
225
/// Controls whether the arguments of a variadic parameter share one type.
///
/// This matters when a function's last argument is variadic and typed by a
/// type parameter (e.g., `fn(A, B, C...)`). It decides how that parameter
/// binds across the repeated arguments.
///
/// See: <https://github.com/substrait-io/substrait/issues/928>
#[derive(Clone, Debug, PartialEq)]
pub enum ParameterConsistency {
    /// Every variadic argument binds to one shared concrete type.
    ///
    /// For instance, once `C` is bound to `i32`, all remaining variadic
    /// arguments must also be `i32`.
    Consistent,
    /// Variadic arguments may bind to different types.
    ///
    /// Each occurrence of the variadic parameter is bound independently,
    /// subject only to the constraints of the type parameter.
    Inconsistent,
}
244
245impl From<VariadicBehaviorParameterConsistency> for ParameterConsistency {
246    fn from(raw: VariadicBehaviorParameterConsistency) -> Self {
247        match raw {
248            VariadicBehaviorParameterConsistency::Consistent => ParameterConsistency::Consistent,
249            VariadicBehaviorParameterConsistency::Inconsistent => {
250                ParameterConsistency::Inconsistent
251            }
252        }
253    }
254}
255
256impl TryFrom<RawVariadicBehavior> for VariadicBehavior {
257    type Error = ScalarFunctionError;
258
259    fn try_from(raw: RawVariadicBehavior) -> Result<Self, Self::Error> {
260        fn parse_bound(value: f64, field: &str) -> Result<u32, ScalarFunctionError> {
261            if value < 0.0 || value.fract() != 0.0 {
262                return Err(ScalarFunctionError::InvalidVariadicBehavior {
263                    field: field.to_string(),
264                    value,
265                });
266            }
267            Ok(value as u32)
268        }
269
270        let min = raw
271            .min
272            .map(|v| parse_bound(v, "min"))
273            .transpose()?
274            .unwrap_or(0);
275        let max = raw.max.map(|v| parse_bound(v, "max")).transpose()?;
276
277        if let Some(max_val) = max {
278            if min > max_val {
279                return Err(ScalarFunctionError::VariadicMinGreaterThanMax { min, max: max_val });
280            }
281        }
282
283        Ok(VariadicBehavior {
284            min,
285            max,
286            parameter_consistency: raw.parameter_consistency.map(Into::into),
287        })
288    }
289}
290
/// Describes how a function treats null inputs and the nullability of its output.
#[derive(Clone, Debug, PartialEq)]
pub enum NullabilityHandling {
    /// The output's nullability follows the nullability of the input(s).
    Mirror,
    /// The function states the nullability of its output explicitly.
    DeclaredOutput,
    /// Nulls are handled in a custom, implementation-specific way.
    Discrete,
}
301
302impl From<RawNullabilityHandling> for NullabilityHandling {
303    fn from(raw: RawNullabilityHandling) -> Self {
304        match raw {
305            RawNullabilityHandling::Mirror => NullabilityHandling::Mirror,
306            RawNullabilityHandling::DeclaredOutput => NullabilityHandling::DeclaredOutput,
307            RawNullabilityHandling::Discrete => NullabilityHandling::Discrete,
308        }
309    }
310}
311
/// Validated function options.
///
/// Each key is an option name. Its value is the list taken from that option's
/// YAML `values` entry.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct Options(pub HashMap<String, Vec<String>>);
315
316impl From<&RawOptions> for Options {
317    fn from(raw: &RawOptions) -> Self {
318        Options(
319            raw.0
320                .iter()
321                .map(|(k, v)| (k.clone(), v.values.clone()))
322                .collect(),
323        )
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a raw implementation returning `return_type` with every optional field unset.
    fn raw_impl_returning(return_type: &str) -> RawImpl {
        use crate::text::simple_extensions::ReturnValue;

        RawImpl {
            args: None,
            options: None,
            variadic: None,
            session_dependent: None,
            deterministic: None,
            nullability: None,
            return_: ReturnValue(RawType::String(return_type.to_string())),
            implementation: None,
        }
    }

    #[test]
    fn test_variadic_invalid_values() {
        let invalid_cases = vec![
            (Some(-1.0), None, "negative min"),
            (None, Some(-2.5), "negative max"),
            (Some(7.2), None, "non-integer min"),
            (None, Some(3.5), "non-integer max"),
            (Some(5.0), Some(3.0), "min greater than max"),
        ];

        for (min, max, description) in invalid_cases {
            let raw = RawVariadicBehavior {
                min,
                max,
                parameter_consistency: None,
            };
            assert!(
                VariadicBehavior::try_from(raw).is_err(),
                "expected error for {}",
                description
            );
        }
    }

    #[test]
    fn test_variadic_non_finite_values() {
        for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
            let as_min = RawVariadicBehavior {
                min: Some(value),
                max: None,
                parameter_consistency: None,
            };
            assert!(
                VariadicBehavior::try_from(as_min).is_err(),
                "expected error for min = {value}"
            );

            let as_max = RawVariadicBehavior {
                min: None,
                max: Some(value),
                parameter_consistency: None,
            };
            assert!(
                VariadicBehavior::try_from(as_max).is_err(),
                "expected error for max = {value}"
            );
        }
    }

    #[test]
    fn test_variadic_valid() {
        let raw = RawVariadicBehavior {
            min: Some(1.0),
            max: Some(5.0),
            parameter_consistency: None,
        };
        let result = VariadicBehavior::try_from(raw).unwrap();
        assert_eq!(result.min, 1);
        assert_eq!(result.max, Some(5));
    }

    #[test]
    fn test_variadic_none_values() {
        let raw = RawVariadicBehavior {
            min: None,
            max: None,
            parameter_consistency: None,
        };
        let result = VariadicBehavior::try_from(raw).unwrap();
        assert_eq!(result.min, 0);
        assert_eq!(result.max, None);
    }

    #[test]
    fn test_variadic_parameter_consistency_conversion() {
        let raw = RawVariadicBehavior {
            min: None,
            max: None,
            parameter_consistency: Some(VariadicBehaviorParameterConsistency::Inconsistent),
        };
        let result = VariadicBehavior::try_from(raw).unwrap();
        assert_eq!(
            result.parameter_consistency,
            Some(ParameterConsistency::Inconsistent)
        );
    }

    #[test]
    fn test_no_implementations_error() {
        use crate::text::simple_extensions::ScalarFunction as RawScalarFunction;

        let raw = RawScalarFunction {
            name: "empty_function".to_string(),
            description: None,
            impls: vec![],
        };

        let mut ctx = super::super::extensions::TypeContext::default();
        let result = ScalarFunction::from_raw(raw, &mut ctx);

        assert!(matches!(
            result,
            Err(ScalarFunctionError::NoImplementations { name })
            if name == "empty_function"
        ));
    }

    #[test]
    fn test_scalar_function_with_single_impl() {
        use crate::text::simple_extensions::{
            ReturnValue, ScalarFunction as RawScalarFunction, ScalarFunctionImplsItem, Type,
        };

        let raw = RawScalarFunction {
            name: "add".to_string(),
            description: Some("Addition function".to_string()),
            impls: vec![ScalarFunctionImplsItem {
                args: None,
                options: None,
                variadic: None,
                session_dependent: None,
                deterministic: None,
                nullability: None,
                return_: ReturnValue(Type::String("i32".to_string())),
                implementation: None,
            }],
        };

        let mut ctx = super::super::extensions::TypeContext::default();
        let result = ScalarFunction::from_raw(raw, &mut ctx).unwrap();

        assert_eq!(result.name, "add");
        assert_eq!(result.description, Some("Addition function".to_string()));
        assert_eq!(result.impls.len(), 1);

        // Verify return type is properly parsed to ConcreteType
        use super::super::types::{BasicBuiltinType, ConcreteTypeKind};
        let return_type = &result.impls[0].return_type;
        assert!(!return_type.nullable, "i32 should not be nullable");
        assert!(matches!(
            &return_type.kind,
            ConcreteTypeKind::Builtin(BasicBuiltinType::I32)
        ));
    }

    #[test]
    fn test_impl_defaults_when_fields_missing() {
        let mut ctx = super::super::extensions::TypeContext::default();
        let imp = Impl::from_raw(raw_impl_returning("i32"), &mut ctx).unwrap();

        assert!(imp.args.is_empty());
        assert_eq!(imp.options, Options::default());
        assert_eq!(imp.variadic, None);
        assert!(!imp.session_dependent, "session_dependent defaults to false");
        assert!(imp.deterministic, "deterministic defaults to true");
        assert_eq!(imp.nullability, NullabilityHandling::Mirror);
        assert!(imp.implementation.is_empty());
    }

    #[test]
    fn test_multiline_return_type_not_yet_implemented() {
        let mut ctx = super::super::extensions::TypeContext::default();
        let result = Impl::from_raw(raw_impl_returning("i32\ni64"), &mut ctx);

        assert!(matches!(
            result,
            Err(ScalarFunctionError::NotYetImplemented(_))
        ));
    }

    #[test]
    fn test_options_conversion() {
        use crate::text::simple_extensions::{Options as RawOptions, OptionsValue};
        use indexmap::IndexMap;

        let mut raw_map = IndexMap::new();
        raw_map.insert(
            "overflow".to_string(),
            OptionsValue {
                values: vec!["SILENT".to_string(), "ERROR".to_string()],
                description: None,
            },
        );

        let raw = RawOptions(raw_map);
        let options = Options::from(&raw);

        assert_eq!(options.0.len(), 1);
        assert_eq!(
            options.0.get("overflow").unwrap(),
            &vec!["SILENT".to_string(), "ERROR".to_string()]
        );
    }
}