substrait/parse/text/simple_extensions/
types.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Concrete type system for function validation in the registry.
4//!
5//! This module provides a clean, type-safe wrapper around Substrait extension types,
6//! separating function signature patterns from concrete argument types.
7
8use super::TypeExpr;
9use super::argument::{
10    EnumOptions as ParsedEnumOptions, EnumOptionsError as ParsedEnumOptionsError,
11};
12use super::extensions::TypeContext;
13use super::type_ast::TypeExprParam;
14use crate::parse::Parse;
15use crate::parse::text::simple_extensions::type_ast::TypeParseError;
16use crate::text::simple_extensions::{
17    EnumOptions as RawEnumOptions, SimpleExtensionsTypesItem, Type as RawType, TypeParamDefs,
18    TypeParamDefsItem, TypeParamDefsItemType,
19};
20use indexmap::IndexMap;
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23use std::fmt;
24use std::ops::RangeInclusive;
25use thiserror::Error;
26
/// Write the items of `iter` to `f`, joined by `sep` and wrapped in `start` /
/// `end`.
///
/// When `iter` yields no items nothing at all is written, not even `start` or
/// `end`. That makes it suitable for optional parameter lists such as `name<..>`.
fn write_separated<I, T>(
    f: &mut fmt::Formatter<'_>,
    iter: I,
    start: &str,
    end: &str,
    sep: &str,
) -> fmt::Result
where
    I: IntoIterator<Item = T>,
    T: fmt::Display,
{
    let mut items = iter.into_iter().peekable();
    if items.peek().is_none() {
        // Empty sequence: emit no delimiters.
        return Ok(());
    }

    f.write_str(start)?;
    for (position, item) in items.enumerate() {
        if position > 0 {
            f.write_str(sep)?;
        }
        write!(f, "{item}")?;
    }
    f.write_str(end)
}
55
/// Display adapter that renders a key, then a separator, then a value.
///
/// The tuple fields are `(key, value, separator)`.
struct KeyValueDisplay<K, V>(K, V, &'static str);

impl<K: fmt::Display, V: fmt::Display> fmt::Display for KeyValueDisplay<K, V> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let KeyValueDisplay(key, value, separator) = self;
        write!(f, "{key}{separator}{value}")
    }
}
68
/// Non-recursive, built-in Substrait types: types with no parameters (primitive
/// types), or simple types whose parameters are only primitive / literal values.
///
/// When displayed, primitive types use their lowercase keyword (e.g. `i32`).
/// Parameterized types use an uppercase keyword followed by their parameters
/// (e.g. `DECIMAL<10, 2>`).
#[derive(Clone, Debug, PartialEq)]
pub enum BasicBuiltinType {
    /// Boolean type - `bool`
    Boolean,
    /// 8-bit signed integer - `i8`
    I8,
    /// 16-bit signed integer - `i16`
    I16,
    /// 32-bit signed integer - `i32`
    I32,
    /// 64-bit signed integer - `i64`
    I64,
    /// 32-bit floating point - `fp32`
    Fp32,
    /// 64-bit floating point - `fp64`
    Fp64,
    /// Variable-length string - `string`
    String,
    /// Variable-length binary data - `binary`
    Binary,
    /// Timestamp without a time zone (naive) - `timestamp`
    Timestamp,
    /// Timestamp with time zone - `timestamp_tz`
    TimestampTz,
    /// Calendar date - `date`
    Date,
    /// Time of day - `time`
    Time,
    /// Year-month interval - `interval_year`
    IntervalYear,
    /// 128-bit UUID - `uuid`
    Uuid,
    /// Fixed-length character string: `FIXEDCHAR<L>`
    FixedChar {
        /// Length (number of characters), must be >= 1
        length: i32,
    },
    /// Variable-length character string: `VARCHAR<L>`
    VarChar {
        /// Maximum length (number of characters), must be >= 1
        length: i32,
    },
    /// Fixed-length binary data: `FIXEDBINARY<L>`
    FixedBinary {
        /// Length (number of bytes), must be >= 1
        length: i32,
    },
    /// Fixed-point decimal: `DECIMAL<P, S>`
    Decimal {
        /// Precision (total digits), <= 38
        precision: i32,
        /// Scale (digits after decimal point), 0 <= S <= P
        scale: i32,
    },
    /// Time with sub-second precision: `PRECISIONTIME<P>`
    PrecisionTime {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Timestamp with sub-second precision: `PRECISIONTIMESTAMP<P>`
    PrecisionTimestamp {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Timezone-aware timestamp with precision: `PRECISIONTIMESTAMPTZ<P>`
    PrecisionTimestampTz {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Day-time interval: `INTERVAL_DAY<P>`
    IntervalDay {
        /// Sub-second precision digits (0-9: seconds to nanoseconds)
        precision: i32,
    },
    /// Compound interval: `INTERVAL_COMPOUND<P>`
    IntervalCompound {
        /// Sub-second precision digits
        precision: i32,
    },
}
151
152impl BasicBuiltinType {
153    /// Check if a string is a valid name for a builtin scalar type
154    pub fn is_name(name: &str) -> bool {
155        let lower = name.to_ascii_lowercase();
156        primitive_builtin(&lower).is_some()
157            || matches!(
158                lower.as_str(),
159                "fixedchar"
160                    | "varchar"
161                    | "fixedbinary"
162                    | "decimal"
163                    | "precisiontime"
164                    | "precisiontimestamp"
165                    | "precisiontimestamptz"
166                    | "interval_day"
167                    | "interval_compound"
168            )
169    }
170}
171
172impl fmt::Display for BasicBuiltinType {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        match self {
175            BasicBuiltinType::Boolean => f.write_str("bool"),
176            BasicBuiltinType::I8 => f.write_str("i8"),
177            BasicBuiltinType::I16 => f.write_str("i16"),
178            BasicBuiltinType::I32 => f.write_str("i32"),
179            BasicBuiltinType::I64 => f.write_str("i64"),
180            BasicBuiltinType::Fp32 => f.write_str("fp32"),
181            BasicBuiltinType::Fp64 => f.write_str("fp64"),
182            BasicBuiltinType::String => f.write_str("string"),
183            BasicBuiltinType::Binary => f.write_str("binary"),
184            BasicBuiltinType::Timestamp => f.write_str("timestamp"),
185            BasicBuiltinType::TimestampTz => f.write_str("timestamp_tz"),
186            BasicBuiltinType::Date => f.write_str("date"),
187            BasicBuiltinType::Time => f.write_str("time"),
188            BasicBuiltinType::IntervalYear => f.write_str("interval_year"),
189            BasicBuiltinType::Uuid => f.write_str("uuid"),
190            BasicBuiltinType::FixedChar { length } => write!(f, "FIXEDCHAR<{length}>"),
191            BasicBuiltinType::VarChar { length } => write!(f, "VARCHAR<{length}>"),
192            BasicBuiltinType::FixedBinary { length } => write!(f, "FIXEDBINARY<{length}>"),
193            BasicBuiltinType::Decimal { precision, scale } => {
194                write!(f, "DECIMAL<{precision}, {scale}>")
195            }
196            BasicBuiltinType::PrecisionTime { precision } => {
197                write!(f, "PRECISIONTIME<{precision}>")
198            }
199            BasicBuiltinType::PrecisionTimestamp { precision } => {
200                write!(f, "PRECISIONTIMESTAMP<{precision}>")
201            }
202            BasicBuiltinType::PrecisionTimestampTz { precision } => {
203                write!(f, "PRECISIONTIMESTAMPTZ<{precision}>")
204            }
205            BasicBuiltinType::IntervalDay { precision } => write!(f, "INTERVAL_DAY<{precision}>"),
206            BasicBuiltinType::IntervalCompound { precision } => {
207                write!(f, "INTERVAL_COMPOUND<{precision}>")
208            }
209        }
210    }
211}
212
/// A concrete parameter value supplied to a parameterized type, for example
/// the entries of [`ConcreteTypeKind::Extension`]'s `parameters` list.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeParameter {
    /// Integer parameter (e.g., precision, scale)
    Integer(i64),
    /// Type parameter (a nested, fully-resolved type)
    Type(ConcreteType),
    // TODO: Add support for other type parameters, as described in
    // https://github.com/substrait-io/substrait/blob/35101020d961eda48f8dd1aafbc794c9e5cac077/proto/substrait/type.proto#L250-L265
}
223
224impl fmt::Display for TypeParameter {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        match self {
227            TypeParameter::Integer(i) => write!(f, "{i}"),
228            TypeParameter::Type(t) => write!(f, "{t}"),
229        }
230    }
231}
232
233/// Check if a name corresponds to any built-in type (scalar or container)
234pub fn is_builtin_type_name(name: &str) -> bool {
235    let lower = name.to_ascii_lowercase();
236    BasicBuiltinType::is_name(&lower) || matches!(lower.as_str(), "list" | "map" | "struct")
237}
238
239/// Parse a primitive (no type parameters) builtin type name
240fn primitive_builtin(lower_name: &str) -> Option<BasicBuiltinType> {
241    match lower_name {
242        "bool" => Some(BasicBuiltinType::Boolean),
243        "i8" => Some(BasicBuiltinType::I8),
244        "i16" => Some(BasicBuiltinType::I16),
245        "i32" => Some(BasicBuiltinType::I32),
246        "i64" => Some(BasicBuiltinType::I64),
247        "fp32" => Some(BasicBuiltinType::Fp32),
248        "fp64" => Some(BasicBuiltinType::Fp64),
249        "string" => Some(BasicBuiltinType::String),
250        "binary" => Some(BasicBuiltinType::Binary),
251        "timestamp" => Some(BasicBuiltinType::Timestamp),
252        "timestamp_tz" => Some(BasicBuiltinType::TimestampTz),
253        "date" => Some(BasicBuiltinType::Date),
254        "time" => Some(BasicBuiltinType::Time),
255        "interval_year" => Some(BasicBuiltinType::IntervalYear),
256        "uuid" => Some(BasicBuiltinType::Uuid),
257        _ => None,
258    }
259}
260
/// The kind of value a type parameter accepts, plus any constraints on that
/// value, as declared in a type definition's parameter list.
#[derive(Clone, Debug, PartialEq)]
pub enum ParameterConstraint {
    /// Data type parameter (its value is given as a type string)
    DataType,
    /// Integer parameter with optional inclusive range constraints
    Integer {
        /// Minimum value (inclusive), if specified
        min: Option<i64>,
        /// Maximum value (inclusive), if specified
        max: Option<i64>,
    },
    /// Enumeration parameter
    Enumeration {
        /// Valid enumeration values (validated, deduplicated)
        options: ParsedEnumOptions,
    },
    /// Boolean parameter
    Boolean,
    /// String parameter
    String,
}
283
284impl ParameterConstraint {
285    /// Convert back to raw TypeParamDefsItemType
286    fn raw_type(&self) -> TypeParamDefsItemType {
287        match self {
288            ParameterConstraint::DataType => TypeParamDefsItemType::DataType,
289            ParameterConstraint::Boolean => TypeParamDefsItemType::Boolean,
290            ParameterConstraint::Integer { .. } => TypeParamDefsItemType::Integer,
291            ParameterConstraint::Enumeration { .. } => TypeParamDefsItemType::Enumeration,
292            ParameterConstraint::String => TypeParamDefsItemType::String,
293        }
294    }
295
296    /// Extract raw bounds for integer parameters (min, max)
297    fn raw_bounds(&self) -> (Option<f64>, Option<f64>) {
298        match self {
299            ParameterConstraint::Integer { min, max } => {
300                (min.map(|i| i as f64), max.map(|i| i as f64))
301            }
302            _ => (None, None),
303        }
304    }
305
306    /// Extract raw enum options for enumeration parameters
307    fn raw_options(&self) -> Option<RawEnumOptions> {
308        match self {
309            ParameterConstraint::Enumeration { options } => Some(options.clone().into()),
310            _ => None,
311        }
312    }
313
314    /// Check if a parameter value is valid for this parameter type
315    pub fn is_valid_value(&self, value: &Value) -> bool {
316        match (self, value) {
317            (ParameterConstraint::DataType, Value::String(_)) => true,
318            (ParameterConstraint::Integer { min, max }, Value::Number(n)) => {
319                if let Some(i) = n.as_i64() {
320                    min.is_none_or(|min_val| i >= min_val) && max.is_none_or(|max_val| i <= max_val)
321                } else {
322                    false
323                }
324            }
325            (ParameterConstraint::Enumeration { options }, Value::String(s)) => options.contains(s),
326            (ParameterConstraint::Boolean, Value::Bool(_)) => true,
327            (ParameterConstraint::String, Value::String(_)) => true,
328            _ => false,
329        }
330    }
331
332    fn from_raw(
333        t: TypeParamDefsItemType,
334        opts: Option<RawEnumOptions>,
335        min: Option<f64>,
336        max: Option<f64>,
337    ) -> Result<Self, TypeParamError> {
338        Ok(match t {
339            TypeParamDefsItemType::DataType => Self::DataType,
340            TypeParamDefsItemType::Boolean => Self::Boolean,
341            TypeParamDefsItemType::Integer => {
342                match (min, max) {
343                    (Some(min_f), _) if min_f.fract() != 0.0 => {
344                        return Err(TypeParamError::InvalidIntegerBounds { min, max });
345                    }
346                    (_, Some(max_f)) if max_f.fract() != 0.0 => {
347                        return Err(TypeParamError::InvalidIntegerBounds { min, max });
348                    }
349                    _ => (),
350                }
351
352                let min_i = min.map(|v| v as i64);
353                let max_i = max.map(|v| v as i64);
354                Self::Integer {
355                    min: min_i,
356                    max: max_i,
357                }
358            }
359            TypeParamDefsItemType::Enumeration => {
360                let options: ParsedEnumOptions =
361                    opts.ok_or(TypeParamError::MissingEnumOptions)?.try_into()?;
362                Self::Enumeration { options }
363            }
364            TypeParamDefsItemType::String => Self::String,
365        })
366    }
367}
368
/// A validated type parameter declaration: its name, the kind of value it
/// accepts, and an optional description.
#[derive(Clone, Debug, PartialEq)]
pub struct TypeParam {
    /// Parameter name (e.g., "K" for a type variable)
    pub name: String,
    /// Kind of value accepted, together with any constraints (bounds, enum options)
    pub param_type: ParameterConstraint,
    /// Human-readable description
    pub description: Option<String>,
}
379
380impl TypeParam {
381    /// Create a new type parameter
382    pub fn new(name: String, param_type: ParameterConstraint, description: Option<String>) -> Self {
383        Self {
384            name,
385            param_type,
386            description,
387        }
388    }
389
390    /// Check if a parameter value is valid
391    pub fn is_valid_value(&self, value: &Value) -> bool {
392        self.param_type.is_valid_value(value)
393    }
394}
395
396impl TryFrom<TypeParamDefsItem> for TypeParam {
397    type Error = TypeParamError;
398
399    fn try_from(item: TypeParamDefsItem) -> Result<Self, Self::Error> {
400        let name = item.name.ok_or(TypeParamError::MissingName)?;
401        let param_type =
402            ParameterConstraint::from_raw(item.type_, item.options, item.min, item.max)?;
403
404        Ok(Self {
405            name,
406            param_type,
407            description: item.description,
408        })
409    }
410}
411
/// Errors raised while validating extension type definitions and the concrete
/// types they reference.
#[derive(Debug, Error, PartialEq)]
pub enum ExtensionTypeError {
    /// Extension type name is not a valid identifier
    #[error("{0}")]
    InvalidName(#[from] InvalidTypeName),
    /// An `any` type variable (e.g. `any1`) appeared where a concrete type is required
    #[error("Any type variable is invalid for concrete types: any{}{}", id, nullability.then_some("?").unwrap_or(""))]
    InvalidAnyTypeVariable {
        /// The type variable's numeric suffix (the `1` in `any1`)
        id: u32,
        /// Whether the type variable was marked nullable (`?`)
        nullability: bool,
    },
    /// A type parameter definition failed validation
    #[error("Invalid parameter: {0}")]
    InvalidParameter(#[from] TypeParamError),
    /// A structure field's type is invalid (e.g. not given as a string)
    #[error("Invalid structure field type: {0}")]
    InvalidFieldType(String),
    /// Duplicate struct field name
    #[error("Duplicate struct field '{field_name}'")]
    DuplicateFieldName {
        /// The duplicated field name
        field_name: String,
    },
    /// Type parameter count is invalid for the given type name
    #[error("Type '{type_name}' expects {expected} parameters, got {actual}")]
    InvalidParameterCount {
        /// The type name being validated
        type_name: String,
        /// Expected number of parameters
        expected: usize,
        /// The actual number of parameters provided
        actual: usize,
    },
    /// Type parameter is of the wrong kind for the given position
    #[error("Type '{type_name}' parameter {index} must be {expected}")]
    InvalidParameterKind {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Expected parameter kind (e.g., integer, type)
        expected: &'static str,
    },
    /// Provided parameter value is not acceptable for its position
    #[error("Type '{type_name}' parameter {index} value {value} is not within {expected}")]
    InvalidParameterValue {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Provided parameter value
        value: i64,
        /// Human-readable description of the acceptable values
        expected: &'static str,
    },
    /// Provided parameter value falls outside the allowed inclusive range
    #[error("Type '{type_name}' parameter {index} value {value} is out of range {expected:?}")]
    InvalidParameterRange {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Provided parameter value
        value: i64,
        /// The allowed inclusive range
        expected: RangeInclusive<i32>,
    },
    /// Structure representation cannot be nullable
    #[error("Structure representation cannot be nullable: {type_string}")]
    StructureCannotBeNullable {
        /// The type string that was nullable
        type_string: String,
    },
    /// Error parsing a type string
    #[error("Error parsing type: {0}")]
    ParseType(#[from] TypeParseError),
}
492
/// Errors raised while validating a raw type parameter definition into a
/// [`TypeParam`].
#[derive(Debug, Error, PartialEq)]
pub enum TypeParamError {
    /// Parameter name is missing
    #[error("Parameter name is required")]
    MissingName,
    /// Integer parameter has invalid `min`/`max` bounds (e.g. non-integer values)
    #[error("Integer parameter has invalid min/max values: min={min:?}, max={max:?}")]
    InvalidIntegerBounds {
        /// The raw minimum bound, as given in the definition
        min: Option<f64>,
        /// The raw maximum bound, as given in the definition
        max: Option<f64>,
    },
    /// Enumeration parameter has no `options` list
    #[error("Enumeration parameter is missing options")]
    MissingEnumOptions,
    /// Enumeration parameter's options failed validation
    #[error("Enumeration parameter has invalid options: {0}")]
    InvalidEnumOptions(#[from] ParsedEnumOptionsError),
}
514
/// A validated Simple Extension type definition.
///
/// Obtained by parsing a [`SimpleExtensionsTypesItem`], and convertible back
/// into one.
#[derive(Clone, Debug, PartialEq)]
pub struct CustomType {
    /// Type name (checked by [`CustomType::validate_name`] when built via
    /// [`CustomType::new`] or parsing)
    pub name: String,
    /// Declared type parameters (e.g., for generic types); empty if there are none
    pub parameters: Vec<TypeParam>,
    /// Concrete structure definition, if any
    pub structure: Option<ConcreteType>,
    /// Whether this type can have variadic parameters (`None` if unspecified)
    pub variadic: Option<bool>,
    /// Human-readable description
    pub description: Option<String>,
}
529
530impl CustomType {
531    /// Check if this type name is valid according to Substrait naming rules
532    /// (see the `Identifier` rule in `substrait/grammar/SubstraitLexer.g4`).
533    /// Identifiers are case-insensitive and must start with a an ASCII letter,
534    /// `_`, or `$`, followed by ASCII letters, digits, `_`, or `$`.
535    //
536    // Note: I'm not sure if `$` is actually something we want to allow, or if
537    // `_` is, but it's in the grammar so I'm allowing it here.
538    pub fn validate_name(name: &str) -> Result<(), InvalidTypeName> {
539        let mut chars = name.chars();
540        let first = chars
541            .next()
542            .ok_or_else(|| InvalidTypeName(name.to_string()))?;
543        if !(first.is_ascii_alphabetic() || first == '_' || first == '$') {
544            return Err(InvalidTypeName(name.to_string()));
545        }
546
547        if !chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '$') {
548            return Err(InvalidTypeName(name.to_string()));
549        }
550
551        Ok(())
552    }
553
554    /// Create a new custom type with validation
555    pub fn new(
556        name: String,
557        parameters: Vec<TypeParam>,
558        structure: Option<ConcreteType>,
559        variadic: Option<bool>,
560        description: Option<String>,
561    ) -> Result<Self, ExtensionTypeError> {
562        Self::validate_name(&name)?;
563
564        Ok(Self {
565            name,
566            parameters,
567            structure,
568            variadic,
569            description,
570        })
571    }
572}
573
574impl From<CustomType> for SimpleExtensionsTypesItem {
575    fn from(value: CustomType) -> Self {
576        // Convert parameters back to TypeParamDefs if any
577        let parameters = if value.parameters.is_empty() {
578            None
579        } else {
580            Some(TypeParamDefs(
581                value
582                    .parameters
583                    .into_iter()
584                    .map(|param| {
585                        let (min, max) = param.param_type.raw_bounds();
586                        TypeParamDefsItem {
587                            name: Some(param.name),
588                            description: param.description,
589                            type_: param.param_type.raw_type(),
590                            min,
591                            max,
592                            options: param.param_type.raw_options(),
593                            // TODO: add this to TypeParamDefsItem parsing, and
594                            // follow it through here. I'm not entirely sure
595                            // when/if it is used.
596                            optional: None,
597                        }
598                    })
599                    .collect(),
600            ))
601        };
602
603        // Convert structure back to Type if any
604        let structure = value.structure.map(Into::into);
605
606        SimpleExtensionsTypesItem {
607            name: value.name,
608            description: value.description,
609            parameters,
610            structure,
611            variadic: value.variadic,
612        }
613    }
614}
615
616impl Parse<TypeContext> for SimpleExtensionsTypesItem {
617    type Parsed = CustomType;
618    type Error = ExtensionTypeError;
619
620    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
621        let name = self.name;
622        CustomType::validate_name(&name)?;
623
624        // Register this type as found
625        ctx.found(&name);
626
627        let parameters = if let Some(param_defs) = self.parameters {
628            param_defs
629                .0
630                .into_iter()
631                .map(TypeParam::try_from)
632                .collect::<Result<Vec<_>, _>>()?
633        } else {
634            Vec::new()
635        };
636
637        // Parse structure with context, so referenced extension types are recorded as linked
638        let structure = match self.structure {
639            Some(structure_data) => {
640                let parsed = Parse::parse(structure_data, ctx)?;
641                // TODO: check that the structure is valid. The `Type::Object`
642                // form of `structure_data` is by definition a non-nullable `NSTRUCT`; however,
643                // what types allowed under the `Type::String` form is less clear in the spec:
644                // See https://github.com/substrait-io/substrait/issues/920.
645                Some(parsed)
646            }
647            None => None,
648        };
649
650        Ok(CustomType {
651            name,
652            parameters,
653            structure,
654            variadic: self.variadic,
655            description: self.description,
656        })
657    }
658}
659
660impl Parse<TypeContext> for RawType {
661    type Parsed = ConcreteType;
662    type Error = ExtensionTypeError;
663
664    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
665        match self {
666            RawType::String(type_string) => {
667                let parsed_type = TypeExpr::parse(&type_string)?;
668                let mut link = |name: &str| ctx.linked(name);
669                parsed_type.visit_references(&mut link);
670                let concrete = ConcreteType::try_from(parsed_type)?;
671                Ok(concrete)
672            }
673            RawType::Object(field_map) => {
674                // Type structure in Substrait must preserve field order (see
675                // substrait-io/substrait#915). The typify generation uses
676                // IndexMap to retain the YAML order so that the order of the
677                // fields in the structure matches that of the extensions file.
678                let mut fields = IndexMap::new();
679
680                for (field_name, field_type_value) in field_map {
681                    let type_string = match field_type_value {
682                        serde_json::Value::String(s) => s,
683                        _ => {
684                            return Err(ExtensionTypeError::InvalidFieldType(
685                                "Struct field types must be strings".to_string(),
686                            ));
687                        }
688                    };
689
690                    let parsed_field_type = TypeExpr::parse(&type_string)?;
691                    let mut link = |name: &str| ctx.linked(name);
692                    parsed_field_type.visit_references(&mut link);
693                    let field_concrete_type = ConcreteType::try_from(parsed_field_type)?;
694
695                    if fields
696                        .insert(field_name.clone(), field_concrete_type)
697                        .is_some()
698                    {
699                        return Err(ExtensionTypeError::DuplicateFieldName { field_name });
700                    }
701                }
702
703                Ok(ConcreteType {
704                    kind: ConcreteTypeKind::NamedStruct { fields },
705                    nullable: false,
706                })
707            }
708        }
709    }
710}
711
/// Error for a type name that is not a valid Substrait identifier. It holds the
/// rejected name and is returned by [`CustomType::validate_name`].
#[derive(Debug, Error, PartialEq)]
#[error("invalid type name `{0}`")]
pub struct InvalidTypeName(String);
716
/// The structural kind of a Substrait type (builtin, list, map, etc).
///
/// This is almost a complete type, but is missing nullability information. It must be
/// wrapped in a [`ConcreteType`] to form a complete type with nullable/non-nullable annotation.
///
/// Note that this is a recursive type. Apart from [`Builtin`](ConcreteTypeKind::Builtin),
/// which wraps a [`BasicBuiltinType`], the variants contain [`ConcreteType`]s directly or
/// [`TypeParameter`]s, which may themselves contain [`ConcreteType`]s.
#[derive(Clone, Debug, PartialEq)]
pub enum ConcreteTypeKind {
    /// Built-in Substrait type (primitive or parameterized)
    Builtin(BasicBuiltinType),
    /// Extension type with optional parameters
    Extension {
        /// Extension type name
        name: String,
        /// Type parameters
        parameters: Vec<TypeParameter>,
    },
    /// List type with element type
    List(Box<ConcreteType>),
    /// Map type with key and value types
    Map {
        /// Key type
        key: Box<ConcreteType>,
        /// Value type
        value: Box<ConcreteType>,
    },
    /// Struct type (ordered fields without names)
    Struct(Vec<ConcreteType>),
    /// Named struct type (nstruct - ordered fields with names)
    NamedStruct {
        /// Ordered field names and types. They are in the order they should
        /// appear in the struct - hence the use of [`IndexMap`].
        fields: IndexMap<String, ConcreteType>,
    },
}
753
754impl fmt::Display for ConcreteTypeKind {
755    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
756        match self {
757            ConcreteTypeKind::Builtin(b) => write!(f, "{b}"),
758            ConcreteTypeKind::Extension { name, parameters } => {
759                write!(f, "{name}")?;
760                write_separated(f, parameters.iter(), "<", ">", ", ")
761            }
762            ConcreteTypeKind::List(elem) => write!(f, "list<{elem}>"),
763            ConcreteTypeKind::Map { key, value } => write!(f, "map<{key}, {value}>"),
764            ConcreteTypeKind::Struct(types) => {
765                write_separated(f, types.iter(), "struct<", ">", ", ")
766            }
767            ConcreteTypeKind::NamedStruct { fields } => {
768                let kvs = fields.iter().map(|(k, v)| KeyValueDisplay(k, v, ": "));
769
770                write_separated(f, kvs, "{", "}", ", ")
771            }
772        }
773    }
774}
775
/// A concrete, fully-resolved type instance with nullability.
///
/// Displays as its [`ConcreteTypeKind`] followed by `?` when nullable
/// (e.g. `list<i32>?`).
#[derive(Clone, Debug, PartialEq)]
pub struct ConcreteType {
    /// The resolved type shape
    pub kind: ConcreteTypeKind,
    /// Whether this type is nullable
    pub nullable: bool,
}
784
785impl ConcreteType {
786    /// Create a new builtin scalar type
787    pub fn builtin(builtin_type: BasicBuiltinType, nullable: bool) -> ConcreteType {
788        ConcreteType {
789            kind: ConcreteTypeKind::Builtin(builtin_type),
790            nullable,
791        }
792    }
793
794    /// Create a new extension type reference (without parameters)
795    pub fn extension(name: String, nullable: bool) -> ConcreteType {
796        ConcreteType {
797            kind: ConcreteTypeKind::Extension {
798                name,
799                parameters: Vec::new(),
800            },
801            nullable,
802        }
803    }
804
805    /// Create a new parameterized extension type
806    pub fn extension_with_params(
807        name: String,
808        parameters: Vec<TypeParameter>,
809        nullable: bool,
810    ) -> ConcreteType {
811        ConcreteType {
812            kind: ConcreteTypeKind::Extension { name, parameters },
813            nullable,
814        }
815    }
816
817    /// Create a new list type
818    pub fn list(element_type: ConcreteType, nullable: bool) -> ConcreteType {
819        ConcreteType {
820            kind: ConcreteTypeKind::List(Box::new(element_type)),
821            nullable,
822        }
823    }
824
825    /// Create a new struct type (ordered fields without names)
826    pub fn r#struct(field_types: Vec<ConcreteType>, nullable: bool) -> ConcreteType {
827        ConcreteType {
828            kind: ConcreteTypeKind::Struct(field_types),
829            nullable,
830        }
831    }
832
833    /// Create a new map type
834    pub fn map(key_type: ConcreteType, value_type: ConcreteType, nullable: bool) -> ConcreteType {
835        ConcreteType {
836            kind: ConcreteTypeKind::Map {
837                key: Box::new(key_type),
838                value: Box::new(value_type),
839            },
840            nullable,
841        }
842    }
843
844    /// Create a new named struct type (nstruct - ordered fields with names)
845    pub fn named_struct(fields: IndexMap<String, ConcreteType>, nullable: bool) -> ConcreteType {
846        ConcreteType {
847            kind: ConcreteTypeKind::NamedStruct { fields },
848            nullable,
849        }
850    }
851
852    /// Check if this type (as a function argument) is compatible with another
853    /// type (as an input).
854    ///
855    /// Mainly checks nullability:
856    ///   - `i64?` is compatible with `i64` and `i64?` - both can be passed as
857    ///     arguments
858    ///   - `i64` is compatible with `i64` but NOT `i64?` - you can't pass a
859    ///     nullable type to a function that only accepts non-nullable arguments
860    pub fn is_compatible_with(&self, other: &ConcreteType) -> bool {
861        // Types must match exactly, but nullable types can accept non-nullable values
862        self.kind == other.kind && (self.nullable || !other.nullable)
863    }
864}
865
866impl fmt::Display for ConcreteType {
867    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
868        write!(f, "{}", self.kind)?;
869        if self.nullable {
870            write!(f, "?")?;
871        }
872        Ok(())
873    }
874}
875
876impl From<ConcreteType> for RawType {
877    fn from(val: ConcreteType) -> Self {
878        match val.kind {
879            ConcreteTypeKind::NamedStruct { fields } => {
880                let map = Map::from_iter(
881                    fields
882                        .into_iter()
883                        .map(|(name, ty)| (name, serde_json::Value::String(ty.to_string()))),
884                );
885                RawType::Object(map)
886            }
887            _ => RawType::String(val.to_string()),
888        }
889    }
890}
891
892/// Extract and validate an integer parameter for a built-in type.
893///
894/// For `DECIMAL<10,2>`, this validates that `10` (index 0) and `2` (index 1)
895/// are integers within their required ranges (precision 1-38, scale
896/// 0-precision).
897///
898/// - `type_name`: Type being validated (for error messages, e.g., "DECIMAL")
899/// - `index`: Parameter position (0-based, e.g., 0 for precision, 1 for scale);
900///   needed for error messages
901/// - `param`: The parameter to validate
902/// - `range`: Optional bounds to enforce (e.g., `Some(1..=38)` for precision)
903fn expect_integer_param(
904    type_name: &str,
905    index: usize,
906    param: &TypeExprParam<'_>,
907    range: Option<RangeInclusive<i32>>,
908) -> Result<i32, ExtensionTypeError> {
909    let value = match param {
910        TypeExprParam::Integer(value) => {
911            i32::try_from(*value).map_err(|_| ExtensionTypeError::InvalidParameterValue {
912                type_name: type_name.to_string(),
913                index,
914                value: *value,
915                expected: "an i32",
916            })
917        }
918        _ => Err(ExtensionTypeError::InvalidParameterKind {
919            type_name: type_name.to_string(),
920            index,
921            expected: "an integer",
922        }),
923    }?;
924
925    if let Some(range) = range {
926        if range.contains(&value) {
927            return Ok(value);
928        }
929        return Err(ExtensionTypeError::InvalidParameterRange {
930            type_name: type_name.to_string(),
931            index,
932            value: i64::from(value),
933            expected: range,
934        });
935    }
936
937    Ok(value)
938}
939
940/// Helper function - checks that param length matches expectations, returns
941/// error if not. Assumes a fixed number of expected parameters.
942fn expect_param_len(
943    type_name: &str,
944    params: &[TypeExprParam<'_>],
945    expected: usize,
946) -> Result<(), ExtensionTypeError> {
947    if params.len() != expected {
948        return Err(ExtensionTypeError::InvalidParameterCount {
949            type_name: type_name.to_string(),
950            expected,
951            actual: params.len(),
952        });
953    }
954    Ok(())
955}
956
957/// Helper function - expect a type parameter, and return the [ConcreteType] if it is a [TypeExpr]
958/// or an error if not.
959fn expect_type_argument<'a>(
960    type_name: &str,
961    index: usize,
962    param: TypeExprParam<'a>,
963) -> Result<ConcreteType, ExtensionTypeError> {
964    match param {
965        TypeExprParam::Type(t) => ConcreteType::try_from(t),
966        TypeExprParam::Integer(_) => Err(ExtensionTypeError::InvalidParameterKind {
967            type_name: type_name.to_string(),
968            index,
969            expected: "a type",
970        }),
971    }
972}
973
974impl<'a> TryFrom<TypeExprParam<'a>> for TypeParameter {
975    type Error = ExtensionTypeError;
976
977    fn try_from(param: TypeExprParam<'a>) -> Result<Self, Self::Error> {
978        Ok(match param {
979            TypeExprParam::Integer(v) => TypeParameter::Integer(v),
980            TypeExprParam::Type(t) => TypeParameter::Type(ConcreteType::try_from(t)?),
981        })
982    }
983}
984
985/// Parse a builtin type. Returns an `ExtensionTypeError` if the type name is
986/// matched, but parameters are incorrect; returns `Some(None)` if the type is
987/// not known.
988fn parse_builtin<'a>(
989    display_name: &str,
990    lower_name: &str,
991    params: &[TypeExprParam<'a>],
992) -> Result<Option<BasicBuiltinType>, ExtensionTypeError> {
993    if let Some(builtin) = primitive_builtin(lower_name) {
994        expect_param_len(display_name, params, 0)?;
995        return Ok(Some(builtin));
996    }
997
998    match lower_name {
999        // Parameterized builtins
1000        "fixedchar" => {
1001            expect_param_len(display_name, params, 1)?;
1002            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1003            Ok(Some(BasicBuiltinType::FixedChar { length }))
1004        }
1005        "varchar" => {
1006            expect_param_len(display_name, params, 1)?;
1007            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1008            Ok(Some(BasicBuiltinType::VarChar { length }))
1009        }
1010        "fixedbinary" => {
1011            expect_param_len(display_name, params, 1)?;
1012            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1013            Ok(Some(BasicBuiltinType::FixedBinary { length }))
1014        }
1015        "decimal" => {
1016            expect_param_len(display_name, params, 2)?;
1017            let precision = expect_integer_param(display_name, 0, &params[0], Some(1..=38))?;
1018            let scale = expect_integer_param(display_name, 1, &params[1], Some(0..=precision))?;
1019            Ok(Some(BasicBuiltinType::Decimal { precision, scale }))
1020        }
1021        // Should we accept both "precision_time" and "precisiontime"? The
1022        // docs/spec say PRECISIONTIME. The protos use underscores, so it could
1023        // show up in generated code, although maybe that's out of spec.
1024        "precisiontime" => {
1025            expect_param_len(display_name, params, 1)?;
1026            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1027            Ok(Some(BasicBuiltinType::PrecisionTime { precision }))
1028        }
1029        "precisiontimestamp" => {
1030            expect_param_len(display_name, params, 1)?;
1031            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1032            Ok(Some(BasicBuiltinType::PrecisionTimestamp { precision }))
1033        }
1034        "precisiontimestamptz" => {
1035            expect_param_len(display_name, params, 1)?;
1036            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1037            Ok(Some(BasicBuiltinType::PrecisionTimestampTz { precision }))
1038        }
1039        "interval_day" => {
1040            expect_param_len(display_name, params, 1)?;
1041            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=9))?;
1042            Ok(Some(BasicBuiltinType::IntervalDay { precision }))
1043        }
1044        "interval_compound" => {
1045            expect_param_len(display_name, params, 1)?;
1046            let precision = expect_integer_param(display_name, 0, &params[0], None)?;
1047            Ok(Some(BasicBuiltinType::IntervalCompound { precision }))
1048        }
1049        _ => Ok(None),
1050    }
1051}
1052
1053impl<'a> TryFrom<TypeExpr<'a>> for ConcreteType {
1054    type Error = ExtensionTypeError;
1055
1056    fn try_from(parsed_type: TypeExpr<'a>) -> Result<Self, Self::Error> {
1057        match parsed_type {
1058            TypeExpr::Simple(name, params, nullable) => {
1059                let lower = name.to_ascii_lowercase();
1060
1061                match lower.as_str() {
1062                    "list" => {
1063                        expect_param_len(name, &params, 1)?;
1064                        let element =
1065                            expect_type_argument(name, 0, params.into_iter().next().unwrap())?;
1066                        return Ok(ConcreteType::list(element, nullable));
1067                    }
1068                    "map" => {
1069                        expect_param_len(name, &params, 2)?;
1070                        let mut iter = params.into_iter();
1071                        let key = expect_type_argument(name, 0, iter.next().unwrap())?;
1072                        let value = expect_type_argument(name, 1, iter.next().unwrap())?;
1073                        return Ok(ConcreteType::map(key, value, nullable));
1074                    }
1075                    "struct" => {
1076                        let field_types = params
1077                            .into_iter()
1078                            .enumerate()
1079                            .map(|(idx, param)| expect_type_argument(name, idx, param))
1080                            .collect::<Result<Vec<_>, _>>()?;
1081                        return Ok(ConcreteType::r#struct(field_types, nullable));
1082                    }
1083                    _ => {}
1084                }
1085
1086                if let Some(builtin) = parse_builtin(name, lower.as_str(), &params)? {
1087                    return Ok(ConcreteType::builtin(builtin, nullable));
1088                }
1089
1090                let parameters = params
1091                    .into_iter()
1092                    .map(TypeParameter::try_from)
1093                    .collect::<Result<Vec<_>, _>>()?;
1094                Ok(ConcreteType::extension_with_params(
1095                    name.to_string(),
1096                    parameters,
1097                    nullable,
1098                ))
1099            }
1100            TypeExpr::UserDefined(name, params, nullable) => {
1101                let parameters = params
1102                    .into_iter()
1103                    .map(TypeParameter::try_from)
1104                    .collect::<Result<Vec<_>, _>>()?;
1105                Ok(ConcreteType::extension_with_params(
1106                    name.to_string(),
1107                    parameters,
1108                    nullable,
1109                ))
1110            }
1111            TypeExpr::TypeVariable(id, nullability) => {
1112                Err(ExtensionTypeError::InvalidAnyTypeVariable { id, nullability })
1113            }
1114        }
1115    }
1116}
1117
1118#[cfg(test)]
1119mod tests {
1120    use super::super::extensions::TypeContext;
1121    use super::*;
1122    use crate::parse::text::simple_extensions::TypeExpr;
1123    use crate::parse::text::simple_extensions::argument::EnumOptions as ParsedEnumOptions;
1124    use crate::text::simple_extensions;
1125    use std::iter::FromIterator;
1126
1127    /// Create a [ConcreteType] from a [BuiltinType]
1128    fn concretize(builtin: BasicBuiltinType) -> ConcreteType {
1129        ConcreteType::builtin(builtin, false)
1130    }
1131
1132    /// Parse a string into a [ConcreteType]
1133    fn parse_type(expr: &str) -> ConcreteType {
1134        let parsed = TypeExpr::parse(expr).unwrap();
1135        ConcreteType::try_from(parsed).unwrap()
1136    }
1137
1138    /// Parse a string into a [ConcreteType], returning the result
1139    fn parse_type_result(expr: &str) -> Result<ConcreteType, ExtensionTypeError> {
1140        let parsed = TypeExpr::parse(expr).unwrap();
1141        ConcreteType::try_from(parsed)
1142    }
1143
1144    /// Parse a string into a builtin [ConcreteType], with no unresolved
1145    /// extension references
1146    fn parse_simple(s: &str) -> ConcreteType {
1147        let parsed = TypeExpr::parse(s).unwrap();
1148
1149        let mut refs = Vec::new();
1150        parsed.visit_references(&mut |name| refs.push(name.to_string()));
1151        assert!(refs.is_empty(), "{s} should not add an extension reference");
1152
1153        ConcreteType::try_from(parsed).unwrap()
1154    }
1155
1156    /// Create a type parameter from a type expression string
1157    fn type_param(expr: &str) -> TypeParameter {
1158        TypeParameter::Type(parse_type(expr))
1159    }
1160
1161    /// Create an extension type
1162    fn extension(name: &str, parameters: Vec<TypeParameter>, nullable: bool) -> ConcreteType {
1163        ConcreteType::extension_with_params(name.to_string(), parameters, nullable)
1164    }
1165
1166    /// Convert a custom type to raw and back, ensuring round-trip consistency
1167    fn round_trip(custom: &CustomType) {
1168        let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into();
1169        let mut ctx = TypeContext::default();
1170        let parsed = Parse::parse(item, &mut ctx).unwrap();
1171        assert_eq!(&parsed, custom);
1172    }
1173
1174    /// Create a raw named struct (e.g. straight from YAML) from field name and
1175    /// type pairs
1176    fn raw_named_struct(fields: &[(&str, &str)]) -> RawType {
1177        let map = Map::from_iter(
1178            fields
1179                .iter()
1180                .map(|(name, ty)| ((*name).into(), serde_json::Value::String((*ty).into()))),
1181        );
1182
1183        RawType::Object(map)
1184    }
1185
1186    #[test]
1187    fn test_builtin_scalar_parsing() {
1188        let cases = vec![
1189            ("bool", Some(BasicBuiltinType::Boolean)),
1190            ("i8", Some(BasicBuiltinType::I8)),
1191            ("i16", Some(BasicBuiltinType::I16)),
1192            ("i32", Some(BasicBuiltinType::I32)),
1193            ("i64", Some(BasicBuiltinType::I64)),
1194            ("fp32", Some(BasicBuiltinType::Fp32)),
1195            ("fp64", Some(BasicBuiltinType::Fp64)),
1196            ("STRING", Some(BasicBuiltinType::String)),
1197            ("binary", Some(BasicBuiltinType::Binary)),
1198            ("uuid", Some(BasicBuiltinType::Uuid)),
1199            ("date", Some(BasicBuiltinType::Date)),
1200            ("interval_year", Some(BasicBuiltinType::IntervalYear)),
1201            ("time", Some(BasicBuiltinType::Time)),
1202            ("timestamp", Some(BasicBuiltinType::Timestamp)),
1203            ("timestamp_tz", Some(BasicBuiltinType::TimestampTz)),
1204            ("invalid", None),
1205        ];
1206
1207        for (input, expected) in cases {
1208            let result = parse_builtin(input, input.to_ascii_lowercase().as_str(), &[]).unwrap();
1209            match expected {
1210                Some(expected_type) => {
1211                    assert_eq!(
1212                        result,
1213                        Some(expected_type),
1214                        "expected builtin type for {input}"
1215                    );
1216                }
1217                None => {
1218                    assert!(result.is_none(), "expected parsing {input} to fail");
1219                }
1220            }
1221        }
1222    }
1223
1224    #[test]
1225    fn test_parameterized_builtin_types() {
1226        let cases = vec![
1227            (
1228                "precisiontime<2>",
1229                concretize(BasicBuiltinType::PrecisionTime { precision: 2 }),
1230            ),
1231            (
1232                "precisiontimestamp<1>",
1233                concretize(BasicBuiltinType::PrecisionTimestamp { precision: 1 }),
1234            ),
1235            (
1236                "precisiontimestamptz<5>",
1237                concretize(BasicBuiltinType::PrecisionTimestampTz { precision: 5 }),
1238            ),
1239            (
1240                "DECIMAL<10,2>",
1241                concretize(BasicBuiltinType::Decimal {
1242                    precision: 10,
1243                    scale: 2,
1244                }),
1245            ),
1246            (
1247                "fixedchar<12>",
1248                concretize(BasicBuiltinType::FixedChar { length: 12 }),
1249            ),
1250            (
1251                "VarChar<42>",
1252                concretize(BasicBuiltinType::VarChar { length: 42 }),
1253            ),
1254            (
1255                "fixedbinary<8>",
1256                concretize(BasicBuiltinType::FixedBinary { length: 8 }),
1257            ),
1258            (
1259                "interval_day<7>",
1260                concretize(BasicBuiltinType::IntervalDay { precision: 7 }),
1261            ),
1262            (
1263                "interval_compound<6>",
1264                concretize(BasicBuiltinType::IntervalCompound { precision: 6 }),
1265            ),
1266        ];
1267
1268        for (expr, expected) in cases {
1269            let found = parse_simple(expr);
1270            assert_eq!(found, expected, "unexpected type for {expr}");
1271        }
1272    }
1273
1274    #[test]
1275    fn test_parameterized_builtin_range_errors() {
1276        use ExtensionTypeError::InvalidParameterRange;
1277
1278        let cases = vec![
1279            ("precisiontime<13>", "precisiontime", 0, 13, 0..=12),
1280            ("precisiontime<-1>", "precisiontime", 0, -1, 0..=12),
1281            (
1282                "precisiontimestamp<13>",
1283                "precisiontimestamp",
1284                0,
1285                13,
1286                0..=12,
1287            ),
1288            (
1289                "precisiontimestamp<-1>",
1290                "precisiontimestamp",
1291                0,
1292                -1,
1293                0..=12,
1294            ),
1295            (
1296                "precisiontimestamptz<20>",
1297                "precisiontimestamptz",
1298                0,
1299                20,
1300                0..=12,
1301            ),
1302            ("interval_day<10>", "interval_day", 0, 10, 0..=9),
1303            ("DECIMAL<39,0>", "DECIMAL", 0, 39, 1..=38),
1304            ("DECIMAL<0,0>", "DECIMAL", 0, 0, 1..=38),
1305            ("DECIMAL<10,-1>", "DECIMAL", 1, -1, 0..=10),
1306            ("DECIMAL<10,12>", "DECIMAL", 1, 12, 0..=10),
1307        ];
1308
1309        for (expr, expected_type, expected_index, expected_value, expected_range) in cases {
1310            match parse_type_result(expr) {
1311                Ok(value) => panic!("expected error parsing {expr}, got {value:?}"),
1312                Err(InvalidParameterRange {
1313                    type_name,
1314                    index,
1315                    value,
1316                    expected,
1317                }) => {
1318                    assert_eq!(type_name, expected_type, "unexpected type for {expr}");
1319                    assert_eq!(index, expected_index, "unexpected index for {expr}");
1320                    assert_eq!(
1321                        value,
1322                        i64::from(expected_value),
1323                        "unexpected value for {expr}"
1324                    );
1325                    assert_eq!(expected, expected_range, "unexpected range for {expr}");
1326                }
1327                Err(other) => panic!("expected InvalidParameterRange for {expr}, got {other:?}"),
1328            }
1329        }
1330    }
1331
1332    #[test]
1333    fn test_container_types() {
1334        let cases = vec![
1335            (
1336                "List<i32>",
1337                ConcreteType::list(ConcreteType::builtin(BasicBuiltinType::I32, false), false),
1338            ),
1339            (
1340                "List<fp64?>",
1341                ConcreteType::list(ConcreteType::builtin(BasicBuiltinType::Fp64, true), false),
1342            ),
1343            (
1344                "Map?<i64, string?>",
1345                ConcreteType::map(
1346                    ConcreteType::builtin(BasicBuiltinType::I64, false),
1347                    ConcreteType::builtin(BasicBuiltinType::String, true),
1348                    true,
1349                ),
1350            ),
1351            (
1352                "Struct?<i8, string?>",
1353                ConcreteType::r#struct(
1354                    vec![
1355                        ConcreteType::builtin(BasicBuiltinType::I8, false),
1356                        ConcreteType::builtin(BasicBuiltinType::String, true),
1357                    ],
1358                    true,
1359                ),
1360            ),
1361        ];
1362
1363        for (expr, expected) in cases {
1364            assert_eq!(parse_type(expr), expected, "unexpected parse for {expr}");
1365        }
1366    }
1367
1368    #[test]
1369    fn test_extension_types() {
1370        let cases = vec![
1371            (
1372                "u!geo<List<i32>, 10>",
1373                extension(
1374                    "geo",
1375                    vec![type_param("List<i32>"), TypeParameter::Integer(10)],
1376                    false,
1377                ),
1378            ),
1379            (
1380                "Geo?<List<i32?>>",
1381                extension("Geo", vec![type_param("List<i32?>")], true),
1382            ),
1383            (
1384                "Custom<string?, bool>",
1385                extension(
1386                    "Custom",
1387                    vec![
1388                        type_param("string?"),
1389                        TypeParameter::Type(ConcreteType::builtin(
1390                            BasicBuiltinType::Boolean,
1391                            false,
1392                        )),
1393                    ],
1394                    false,
1395                ),
1396            ),
1397        ];
1398
1399        for (expr, expected) in cases {
1400            assert_eq!(
1401                parse_type(expr),
1402                expected,
1403                "unexpected extension for {expr}"
1404            );
1405        }
1406    }
1407
1408    #[test]
1409    fn test_parameter_type_validation() {
1410        let int_param = ParameterConstraint::Integer {
1411            min: Some(1),
1412            max: Some(10),
1413        };
1414        let enum_param = ParameterConstraint::Enumeration {
1415            options: ParsedEnumOptions::try_from(simple_extensions::EnumOptions(vec![
1416                "OVERFLOW".to_string(),
1417                "ERROR".to_string(),
1418            ]))
1419            .unwrap(),
1420        };
1421
1422        let cases = vec![
1423            (&int_param, Value::Number(5.into()), true),
1424            (&int_param, Value::Number(0.into()), false),
1425            (&int_param, Value::Number(11.into()), false),
1426            (&int_param, Value::String("not a number".into()), false),
1427            (&enum_param, Value::String("OVERFLOW".into()), true),
1428            (&enum_param, Value::String("INVALID".into()), false),
1429        ];
1430
1431        for (param, value, expected) in cases {
1432            assert_eq!(
1433                param.is_valid_value(&value),
1434                expected,
1435                "unexpected validation result for {value:?}"
1436            );
1437        }
1438    }
1439
1440    #[test]
1441    fn test_type_round_trip_display() {
1442        // (example, canonical form)
1443        let cases = vec![
1444            ("i32", "i32"),
1445            ("I64?", "i64?"),
1446            ("list<string>", "list<string>"),
1447            ("List<STRING?>", "list<string?>"),
1448            ("map<i32, list<string>>", "map<i32, list<string>>"),
1449            ("struct<i8, string?>", "struct<i8, string?>"),
1450            (
1451                "Struct<List<i32>, Map<string, list<i64?>>>",
1452                "struct<list<i32>, map<string, list<i64?>>>",
1453            ),
1454            (
1455                "Map<List<I32?>, Struct<string, list<i64?>>>",
1456                "map<list<i32?>, struct<string, list<i64?>>>",
1457            ),
1458            ("u!custom<i32>", "custom<i32>"),
1459        ];
1460
1461        for (input, expected) in cases {
1462            let parsed = TypeExpr::parse(input).unwrap();
1463            let concrete = ConcreteType::try_from(parsed).unwrap();
1464            let actual = concrete.to_string();
1465
1466            assert_eq!(actual, expected, "unexpected display for {input}");
1467        }
1468    }
1469
1470    /// Test that named struct field order preserves the structure order when
1471    /// round-tripping through RawType (Substrait #915).
1472    #[test]
1473    fn test_named_struct_field_order_stability() -> Result<(), ExtensionTypeError> {
1474        let mut raw_fields = Map::new();
1475        raw_fields.insert("beta".to_string(), Value::String("i32".to_string()));
1476        raw_fields.insert("alpha".to_string(), Value::String("string?".to_string()));
1477
1478        let raw = RawType::Object(raw_fields);
1479        let mut ctx = TypeContext::default();
1480        let concrete = Parse::parse(raw, &mut ctx)?;
1481
1482        let round_tripped: RawType = concrete.into();
1483        match round_tripped {
1484            RawType::Object(result_map) => {
1485                let keys: Vec<_> = result_map.keys().collect();
1486                assert_eq!(
1487                    keys,
1488                    vec!["beta", "alpha"],
1489                    "field order should be preserved"
1490                );
1491            }
1492            other => panic!("expected Object, got {other:?}"),
1493        }
1494
1495        Ok(())
1496    }
1497
1498    #[test]
1499    fn test_integer_param_bounds_round_trip() {
1500        let cases = vec![
1501            (
1502                "bounded",
1503                simple_extensions::TypeParamDefsItem {
1504                    name: Some("K".to_string()),
1505                    description: None,
1506                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1507                    min: Some(1.0),
1508                    max: Some(10.0),
1509                    options: None,
1510                    optional: None,
1511                },
1512                Ok((Some(1), Some(10))),
1513            ),
1514            (
1515                "fractional_min",
1516                simple_extensions::TypeParamDefsItem {
1517                    name: Some("K".to_string()),
1518                    description: None,
1519                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1520                    min: Some(1.5),
1521                    max: None,
1522                    options: None,
1523                    optional: None,
1524                },
1525                Err(TypeParamError::InvalidIntegerBounds {
1526                    min: Some(1.5),
1527                    max: None,
1528                }),
1529            ),
1530            (
1531                "fractional_max",
1532                simple_extensions::TypeParamDefsItem {
1533                    name: Some("K".to_string()),
1534                    description: None,
1535                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1536                    min: None,
1537                    max: Some(9.9),
1538                    options: None,
1539                    optional: None,
1540                },
1541                Err(TypeParamError::InvalidIntegerBounds {
1542                    min: None,
1543                    max: Some(9.9),
1544                }),
1545            ),
1546        ];
1547
1548        for (label, item, expected) in cases {
1549            match (TypeParam::try_from(item), expected) {
1550                (Ok(tp), Ok((expected_min, expected_max))) => match tp.param_type {
1551                    ParameterConstraint::Integer { min, max } => {
1552                        assert_eq!(min, expected_min, "min mismatch for {label}");
1553                        assert_eq!(max, expected_max, "max mismatch for {label}");
1554                    }
1555                    _ => panic!("expected integer param type for {label}"),
1556                },
1557                (Err(actual_err), Err(expected_err)) => {
1558                    assert_eq!(actual_err, expected_err, "unexpected error for {label}");
1559                }
1560                (result, expected) => {
1561                    panic!("unexpected result for {label}: got {result:?}, expected {expected:?}")
1562                }
1563            }
1564        }
1565    }
1566
1567    #[test]
1568    fn test_custom_type_round_trip() -> Result<(), ExtensionTypeError> {
1569        let fields = IndexMap::from_iter([
1570            (
1571                "x".to_string(),
1572                ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1573            ),
1574            (
1575                "y".to_string(),
1576                ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1577            ),
1578        ]);
1579
1580        let cases = vec![
1581            CustomType::new(
1582                "AliasType".to_string(),
1583                vec![],
1584                Some(ConcreteType::builtin(BasicBuiltinType::I32, false)),
1585                None,
1586                Some("a test alias type".to_string()),
1587            )?,
1588            CustomType::new(
1589                "Point".to_string(),
1590                vec![TypeParam::new(
1591                    "T".to_string(),
1592                    ParameterConstraint::DataType,
1593                    Some("a numeric value".to_string()),
1594                )],
1595                Some(ConcreteType::named_struct(fields, false)),
1596                None,
1597                None,
1598            )?,
1599        ];
1600
1601        for custom in cases {
1602            round_trip(&custom);
1603        }
1604
1605        Ok(())
1606    }
1607
1608    #[test]
1609    fn test_invalid_type_names() {
1610        let cases = vec![
1611            ("", false),
1612            ("bad name", false),
1613            ("9bad", false),
1614            ("bad-name", false),
1615            ("bad.name", false),
1616            ("GoodName", true),
1617            ("also_good", true),
1618            ("_underscore", true),
1619            ("$dollar", true),
1620            ("CamelCase123", true),
1621        ];
1622
1623        for (name, expected_ok) in cases {
1624            let result = CustomType::validate_name(name);
1625            assert_eq!(
1626                result.is_ok(),
1627                expected_ok,
1628                "unexpected validation for {name}"
1629            );
1630        }
1631    }
1632
1633    #[test]
1634    fn test_ext_type_to_concrete_type() -> Result<(), ExtensionTypeError> {
1635        let cases = vec![
1636            (
1637                "alias",
1638                RawType::String("i32".to_string()),
1639                ConcreteType::builtin(BasicBuiltinType::I32, false),
1640            ),
1641            (
1642                "named_struct",
1643                raw_named_struct(&[("field1", "fp64"), ("field2", "i32?")]),
1644                ConcreteType::named_struct(
1645                    IndexMap::from_iter([
1646                        (
1647                            "field1".to_string(),
1648                            ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1649                        ),
1650                        (
1651                            "field2".to_string(),
1652                            ConcreteType::builtin(BasicBuiltinType::I32, true),
1653                        ),
1654                    ]),
1655                    false,
1656                ),
1657            ),
1658        ];
1659
1660        for (label, raw, expected) in cases {
1661            let mut ctx = TypeContext::default();
1662            let parsed = Parse::parse(raw, &mut ctx)?;
1663            assert_eq!(parsed, expected, "unexpected type for {label}");
1664        }
1665
1666        Ok(())
1667    }
1668
1669    #[test]
1670    fn test_custom_type_parsing() -> Result<(), ExtensionTypeError> {
1671        let cases = vec![
1672            (
1673                "alias",
1674                simple_extensions::SimpleExtensionsTypesItem {
1675                    name: "Alias".to_string(),
1676                    description: Some("Alias type".to_string()),
1677                    parameters: None,
1678                    structure: Some(RawType::String("BINARY".to_string())),
1679                    variadic: None,
1680                },
1681                "Alias",
1682                Some("Alias type"),
1683                Some(ConcreteType::builtin(BasicBuiltinType::Binary, false)),
1684            ),
1685            (
1686                "named_struct",
1687                simple_extensions::SimpleExtensionsTypesItem {
1688                    name: "Point".to_string(),
1689                    description: Some("A 2D point".to_string()),
1690                    parameters: None,
1691                    structure: Some(raw_named_struct(&[("x", "fp64"), ("y", "fp64?")])),
1692                    variadic: None,
1693                },
1694                "Point",
1695                Some("A 2D point"),
1696                Some(ConcreteType::named_struct(
1697                    IndexMap::from_iter([
1698                        (
1699                            "x".to_string(),
1700                            ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1701                        ),
1702                        (
1703                            "y".to_string(),
1704                            ConcreteType::builtin(BasicBuiltinType::Fp64, true),
1705                        ),
1706                    ]),
1707                    false,
1708                )),
1709            ),
1710            (
1711                "no_structure",
1712                simple_extensions::SimpleExtensionsTypesItem {
1713                    name: "Opaque".to_string(),
1714                    description: None,
1715                    parameters: None,
1716                    structure: None,
1717                    variadic: Some(true),
1718                },
1719                "Opaque",
1720                None,
1721                None,
1722            ),
1723        ];
1724
1725        for (label, item, expected_name, expected_description, expected_structure) in cases {
1726            let mut ctx = TypeContext::default();
1727            let parsed = Parse::parse(item, &mut ctx)?;
1728            assert_eq!(parsed.name, expected_name);
1729            assert_eq!(
1730                parsed.description.as_deref(),
1731                expected_description,
1732                "description mismatch for {label}"
1733            );
1734            assert_eq!(
1735                parsed.structure, expected_structure,
1736                "structure mismatch for {label}"
1737            );
1738        }
1739
1740        Ok(())
1741    }
1742}