substrait/parse/text/simple_extensions/
types.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Concrete type system for function validation in the registry.
4//!
5//! This module provides a clean, type-safe wrapper around Substrait extension types,
6//! separating function signature patterns from concrete argument types.
7
8use super::TypeExpr;
9use super::argument::{
10    EnumOptions as ParsedEnumOptions, EnumOptionsError as ParsedEnumOptionsError,
11};
12use super::extensions::TypeContext;
13use super::type_ast::TypeExprParam;
14use crate::parse::Parse;
15use crate::parse::text::simple_extensions::type_ast::TypeParseError;
16use crate::text::simple_extensions::{
17    EnumOptions as RawEnumOptions, SimpleExtensionsTypesItem, Type as RawType, TypeParamDefs,
18    TypeParamDefsItem, TypeParamDefsItemType,
19};
20use indexmap::IndexMap;
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23use std::fmt;
24use std::ops::RangeInclusive;
25use thiserror::Error;
26
/// Render `iter` as `start`, then the items joined by `sep`, then `end`.
///
/// An empty iterator writes nothing at all: the `start` and `end` delimiters
/// are emitted only when there is at least one item.
fn write_separated<I, T>(
    f: &mut fmt::Formatter<'_>,
    iter: I,
    start: &str,
    end: &str,
    sep: &str,
) -> fmt::Result
where
    I: IntoIterator<Item = T>,
    T: fmt::Display,
{
    let mut items = iter.into_iter();

    // No items: skip the delimiters entirely.
    let Some(head) = items.next() else {
        return Ok(());
    };

    f.write_str(start)?;
    write!(f, "{head}")?;
    // Every item after the first is preceded by the separator.
    items.try_for_each(|item| {
        f.write_str(sep)?;
        write!(f, "{item}")
    })?;
    f.write_str(end)
}
55
/// Display adapter that prints a key, a separator, and a value, in that order.
///
/// The tuple fields are (key, value, separator).
struct KeyValueDisplay<K, V>(K, V, &'static str);

impl<K: fmt::Display, V: fmt::Display> fmt::Display for KeyValueDisplay<K, V> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(key, value, separator) = self;
        write!(f, "{key}{separator}{value}")
    }
}
68
/// Non-recursive, built-in Substrait types: either types with no parameters
/// (primitive types), or simple types whose parameters are all integer
/// literals.
///
/// The `Display` implementation renders primitive types as their lowercase
/// name and parameterized types in the uppercase `NAME<...>` form shown on
/// each variant.
#[derive(Clone, Debug, PartialEq)]
pub enum BasicBuiltinType {
    /// Boolean type - `bool`
    Boolean,
    /// 8-bit signed integer - `i8`
    I8,
    /// 16-bit signed integer - `i16`
    I16,
    /// 32-bit signed integer - `i32`
    I32,
    /// 64-bit signed integer - `i64`
    I64,
    /// 32-bit floating point - `fp32`
    Fp32,
    /// 64-bit floating point - `fp64`
    Fp64,
    /// Variable-length string - `string`
    String,
    /// Variable-length binary data - `binary`
    Binary,
    /// Naive timestamp (no time zone) - `timestamp`
    Timestamp,
    /// Timestamp with time zone - `timestamp_tz`
    TimestampTz,
    /// Calendar date - `date`
    Date,
    /// Time of day - `time`
    Time,
    /// Year-month interval - `interval_year`
    IntervalYear,
    /// 128-bit UUID - `uuid`
    Uuid,
    /// Fixed-length character string: `FIXEDCHAR<L>`
    FixedChar {
        /// Length (number of characters), must be >= 1
        length: i32,
    },
    /// Variable-length character string: `VARCHAR<L>`
    VarChar {
        /// Maximum length (number of characters), must be >= 1
        length: i32,
    },
    /// Fixed-length binary data: `FIXEDBINARY<L>`
    FixedBinary {
        /// Length (number of bytes), must be >= 1
        length: i32,
    },
    /// Fixed-point decimal: `DECIMAL<P, S>`
    Decimal {
        /// Precision (total digits), <= 38
        precision: i32,
        /// Scale (digits after decimal point), 0 <= S <= P
        scale: i32,
    },
    /// Time with sub-second precision: `PRECISIONTIME<P>`
    PrecisionTime {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Timestamp with sub-second precision: `PRECISIONTIMESTAMP<P>`
    PrecisionTimestamp {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Timezone-aware timestamp with precision: `PRECISIONTIMESTAMPTZ<P>`
    PrecisionTimestampTz {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Day-time interval: `INTERVAL_DAY<P>`
    IntervalDay {
        /// Sub-second precision digits (0-9: seconds to nanoseconds)
        precision: i32,
    },
    /// Compound interval: `INTERVAL_COMPOUND<P>`
    IntervalCompound {
        /// Sub-second precision digits
        precision: i32,
    },
}
151
152impl BasicBuiltinType {
153    /// Check if a string is a valid name for a builtin scalar type
154    pub fn is_name(name: &str) -> bool {
155        let lower = name.to_ascii_lowercase();
156        primitive_builtin(&lower).is_some()
157            || matches!(
158                lower.as_str(),
159                "fixedchar"
160                    | "varchar"
161                    | "fixedbinary"
162                    | "decimal"
163                    | "precisiontime"
164                    | "precision_time"
165                    | "precision_timestamp"
166                    | "precision_timestamp_tz"
167                    | "interval_day"
168                    | "interval_compound"
169            )
170    }
171}
172
173impl fmt::Display for BasicBuiltinType {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        match self {
176            BasicBuiltinType::Boolean => f.write_str("bool"),
177            BasicBuiltinType::I8 => f.write_str("i8"),
178            BasicBuiltinType::I16 => f.write_str("i16"),
179            BasicBuiltinType::I32 => f.write_str("i32"),
180            BasicBuiltinType::I64 => f.write_str("i64"),
181            BasicBuiltinType::Fp32 => f.write_str("fp32"),
182            BasicBuiltinType::Fp64 => f.write_str("fp64"),
183            BasicBuiltinType::String => f.write_str("string"),
184            BasicBuiltinType::Binary => f.write_str("binary"),
185            BasicBuiltinType::Timestamp => f.write_str("timestamp"),
186            BasicBuiltinType::TimestampTz => f.write_str("timestamp_tz"),
187            BasicBuiltinType::Date => f.write_str("date"),
188            BasicBuiltinType::Time => f.write_str("time"),
189            BasicBuiltinType::IntervalYear => f.write_str("interval_year"),
190            BasicBuiltinType::Uuid => f.write_str("uuid"),
191            BasicBuiltinType::FixedChar { length } => write!(f, "FIXEDCHAR<{length}>"),
192            BasicBuiltinType::VarChar { length } => write!(f, "VARCHAR<{length}>"),
193            BasicBuiltinType::FixedBinary { length } => write!(f, "FIXEDBINARY<{length}>"),
194            BasicBuiltinType::Decimal { precision, scale } => {
195                write!(f, "DECIMAL<{precision}, {scale}>")
196            }
197            BasicBuiltinType::PrecisionTime { precision } => {
198                write!(f, "PRECISIONTIME<{precision}>")
199            }
200            BasicBuiltinType::PrecisionTimestamp { precision } => {
201                write!(f, "PRECISIONTIMESTAMP<{precision}>")
202            }
203            BasicBuiltinType::PrecisionTimestampTz { precision } => {
204                write!(f, "PRECISIONTIMESTAMPTZ<{precision}>")
205            }
206            BasicBuiltinType::IntervalDay { precision } => write!(f, "INTERVAL_DAY<{precision}>"),
207            BasicBuiltinType::IntervalCompound { precision } => {
208                write!(f, "INTERVAL_COMPOUND<{precision}>")
209            }
210        }
211    }
212}
213
/// A concrete parameter value supplied to a parameterized type.
///
/// Its `Display` implementation renders the integer or the nested type
/// directly, with no extra decoration.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeParameter {
    /// Integer parameter (e.g., precision, scale)
    Integer(i64),
    /// Type parameter (nested type)
    Type(ConcreteType),
    // TODO: Add support for other type parameters, as described in
    // https://github.com/substrait-io/substrait/blob/35101020d961eda48f8dd1aafbc794c9e5cac077/proto/substrait/type.proto#L250-L265
}
224
225impl fmt::Display for TypeParameter {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        match self {
228            TypeParameter::Integer(i) => write!(f, "{i}"),
229            TypeParameter::Type(t) => write!(f, "{t}"),
230        }
231    }
232}
233
234/// Parse a primitive (no type parameters) builtin type name
235fn primitive_builtin(lower_name: &str) -> Option<BasicBuiltinType> {
236    match lower_name {
237        "bool" | "boolean" => Some(BasicBuiltinType::Boolean),
238        "i8" => Some(BasicBuiltinType::I8),
239        "i16" => Some(BasicBuiltinType::I16),
240        "i32" => Some(BasicBuiltinType::I32),
241        "i64" => Some(BasicBuiltinType::I64),
242        "fp32" => Some(BasicBuiltinType::Fp32),
243        "fp64" => Some(BasicBuiltinType::Fp64),
244        "string" => Some(BasicBuiltinType::String),
245        "binary" => Some(BasicBuiltinType::Binary),
246        "timestamp" => Some(BasicBuiltinType::Timestamp),
247        "timestamp_tz" => Some(BasicBuiltinType::TimestampTz),
248        "date" => Some(BasicBuiltinType::Date),
249        "time" => Some(BasicBuiltinType::Time),
250        "interval_year" => Some(BasicBuiltinType::IntervalYear),
251        "uuid" => Some(BasicBuiltinType::Uuid),
252        _ => None,
253    }
254}
255
/// The kind of value a type parameter accepts, together with any constraints
/// on that value (integer bounds, enumeration options).
#[derive(Clone, Debug, PartialEq)]
pub enum ParameterConstraint {
    /// Data type parameter
    DataType,
    /// Integer parameter with range constraints
    Integer {
        /// Minimum value (inclusive), if specified
        min: Option<i64>,
        /// Maximum value (inclusive), if specified
        max: Option<i64>,
    },
    /// Enumeration parameter
    Enumeration {
        /// Valid enumeration values (validated, deduplicated)
        options: ParsedEnumOptions,
    },
    /// Boolean parameter
    Boolean,
    /// String parameter
    String,
}
278
279impl ParameterConstraint {
280    /// Convert back to raw TypeParamDefsItemType
281    fn raw_type(&self) -> TypeParamDefsItemType {
282        match self {
283            ParameterConstraint::DataType => TypeParamDefsItemType::DataType,
284            ParameterConstraint::Boolean => TypeParamDefsItemType::Boolean,
285            ParameterConstraint::Integer { .. } => TypeParamDefsItemType::Integer,
286            ParameterConstraint::Enumeration { .. } => TypeParamDefsItemType::Enumeration,
287            ParameterConstraint::String => TypeParamDefsItemType::String,
288        }
289    }
290
291    /// Extract raw bounds for integer parameters (min, max)
292    fn raw_bounds(&self) -> (Option<f64>, Option<f64>) {
293        match self {
294            ParameterConstraint::Integer { min, max } => {
295                (min.map(|i| i as f64), max.map(|i| i as f64))
296            }
297            _ => (None, None),
298        }
299    }
300
301    /// Extract raw enum options for enumeration parameters
302    fn raw_options(&self) -> Option<RawEnumOptions> {
303        match self {
304            ParameterConstraint::Enumeration { options } => Some(options.clone().into()),
305            _ => None,
306        }
307    }
308
309    /// Check if a parameter value is valid for this parameter type
310    pub fn is_valid_value(&self, value: &Value) -> bool {
311        match (self, value) {
312            (ParameterConstraint::DataType, Value::String(_)) => true,
313            (ParameterConstraint::Integer { min, max }, Value::Number(n)) => {
314                if let Some(i) = n.as_i64() {
315                    min.is_none_or(|min_val| i >= min_val) && max.is_none_or(|max_val| i <= max_val)
316                } else {
317                    false
318                }
319            }
320            (ParameterConstraint::Enumeration { options }, Value::String(s)) => options.contains(s),
321            (ParameterConstraint::Boolean, Value::Bool(_)) => true,
322            (ParameterConstraint::String, Value::String(_)) => true,
323            _ => false,
324        }
325    }
326
327    fn from_raw(
328        t: TypeParamDefsItemType,
329        opts: Option<RawEnumOptions>,
330        min: Option<f64>,
331        max: Option<f64>,
332    ) -> Result<Self, TypeParamError> {
333        Ok(match t {
334            TypeParamDefsItemType::DataType => Self::DataType,
335            TypeParamDefsItemType::Boolean => Self::Boolean,
336            TypeParamDefsItemType::Integer => {
337                match (min, max) {
338                    (Some(min_f), _) if min_f.fract() != 0.0 => {
339                        return Err(TypeParamError::InvalidIntegerBounds { min, max });
340                    }
341                    (_, Some(max_f)) if max_f.fract() != 0.0 => {
342                        return Err(TypeParamError::InvalidIntegerBounds { min, max });
343                    }
344                    _ => (),
345                }
346
347                let min_i = min.map(|v| v as i64);
348                let max_i = max.map(|v| v as i64);
349                Self::Integer {
350                    min: min_i,
351                    max: max_i,
352                }
353            }
354            TypeParamDefsItemType::Enumeration => {
355                let options: ParsedEnumOptions =
356                    opts.ok_or(TypeParamError::MissingEnumOptions)?.try_into()?;
357                Self::Enumeration { options }
358            }
359            TypeParamDefsItemType::String => Self::String,
360        })
361    }
362}
363
/// A validated type parameter definition: a name, the constraint on its
/// values, and an optional description.
#[derive(Clone, Debug, PartialEq)]
pub struct TypeParam {
    /// Parameter name (e.g., "K" for a type variable)
    pub name: String,
    /// Parameter type constraints
    pub param_type: ParameterConstraint,
    /// Human-readable description
    pub description: Option<String>,
}
374
375impl TypeParam {
376    /// Create a new type parameter
377    pub fn new(name: String, param_type: ParameterConstraint, description: Option<String>) -> Self {
378        Self {
379            name,
380            param_type,
381            description,
382        }
383    }
384
385    /// Check if a parameter value is valid
386    pub fn is_valid_value(&self, value: &Value) -> bool {
387        self.param_type.is_valid_value(value)
388    }
389}
390
391impl TryFrom<TypeParamDefsItem> for TypeParam {
392    type Error = TypeParamError;
393
394    fn try_from(item: TypeParamDefsItem) -> Result<Self, Self::Error> {
395        let name = item.name.ok_or(TypeParamError::MissingName)?;
396        let param_type =
397            ParameterConstraint::from_raw(item.type_, item.options, item.min, item.max)?;
398
399        Ok(Self {
400            name,
401            param_type,
402            description: item.description,
403        })
404    }
405}
406
/// Errors produced while validating extension type definitions and while
/// converting parsed type expressions into concrete types.
#[derive(Debug, Error, PartialEq)]
pub enum ExtensionTypeError {
    /// Extension type name is invalid
    #[error("{0}")]
    InvalidName(#[from] InvalidTypeName),
    /// Any type variable is invalid for concrete types
    #[error("Any type variable is invalid for concrete types: any{}{}", id, nullability.then_some("?").unwrap_or(""))]
    InvalidAnyTypeVariable {
        /// The numeric id of the type variable (the `N` in `anyN`)
        id: u32,
        /// Whether the type variable is nullable
        nullability: bool,
    },
    /// Unknown type name (not a builtin, missing u! prefix for extension types)
    #[error(
        "Unknown type name: '{}'. Extension types must use the u! prefix (e.g., u!{})",
        name,
        name
    )]
    UnknownTypeName {
        /// The unknown type name
        name: String,
    },
    /// Parameter validation failed
    #[error("Invalid parameter: {0}")]
    InvalidParameter(#[from] TypeParamError),
    /// Field type is invalid
    #[error("Invalid structure field type: {0}")]
    InvalidFieldType(String),
    /// Duplicate struct field name
    #[error("Duplicate struct field '{field_name}'")]
    DuplicateFieldName {
        /// The duplicated field name
        field_name: String,
    },
    /// Type parameter count is invalid for the given type name
    #[error("Type '{type_name}' expects {expected} parameters, got {actual}")]
    InvalidParameterCount {
        /// The type name being validated
        type_name: String,
        /// Expected number of parameters
        expected: usize,
        /// The actual number of parameters provided
        actual: usize,
    },
    /// Type parameter is of the wrong kind for the given position
    #[error("Type '{type_name}' parameter {index} must be {expected}")]
    InvalidParameterKind {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Expected parameter kind (e.g., integer, type)
        expected: &'static str,
    },
    /// Provided parameter value does not match a textual description of the
    /// accepted values
    #[error("Type '{type_name}' parameter {index} value {value} is not within {expected}")]
    InvalidParameterValue {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Provided parameter value
        value: i64,
        /// Description of the expected range or type
        expected: &'static str,
    },
    /// Provided parameter value lies outside an explicit inclusive range
    #[error("Type '{type_name}' parameter {index} value {value} is out of range {expected:?}")]
    InvalidParameterRange {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Provided parameter value
        value: i64,
        /// The inclusive range of accepted values
        expected: RangeInclusive<i32>,
    },
    /// Structure representation cannot be nullable
    #[error("Structure representation cannot be nullable: {type_string}")]
    StructureCannotBeNullable {
        /// The type string that was nullable
        type_string: String,
    },
    /// Error parsing type
    #[error("Error parsing type: {0}")]
    ParseType(#[from] TypeParseError),
}
497
/// Errors produced while validating a raw type parameter definition into a
/// [`TypeParam`].
#[derive(Debug, Error, PartialEq)]
pub enum TypeParamError {
    /// Parameter name is missing
    #[error("Parameter name is required")]
    MissingName,
    /// Integer parameter has min/max values that cannot be used as integer
    /// bounds (e.g. a non-integer value). Both raw bounds are reported, even
    /// if only one of them is invalid.
    #[error("Integer parameter has invalid min/max values: min={min:?}, max={max:?}")]
    InvalidIntegerBounds {
        /// The raw minimum value as given in the definition
        min: Option<f64>,
        /// The raw maximum value as given in the definition
        max: Option<f64>,
    },
    /// Enumeration parameter is missing options
    #[error("Enumeration parameter is missing options")]
    MissingEnumOptions,
    /// Enumeration parameter has invalid options
    #[error("Enumeration parameter has invalid options: {0}")]
    InvalidEnumOptions(#[from] ParsedEnumOptionsError),
}
519
/// A validated Simple Extension type definition.
///
/// This is the parsed counterpart of the raw `SimpleExtensionsTypesItem`.
#[derive(Clone, Debug, PartialEq)]
pub struct CustomType {
    /// Type name
    pub name: String,
    /// Type parameters (e.g., for generic types)
    pub parameters: Vec<TypeParam>,
    /// Concrete structure definition, if any
    pub structure: Option<ConcreteType>,
    /// Whether this type can have variadic parameters; `None` when the
    /// definition does not specify it
    pub variadic: Option<bool>,
    /// Human-readable description
    pub description: Option<String>,
}
534
535impl CustomType {
536    /// Check if this type name is valid according to Substrait naming rules
537    /// (see the `Identifier` rule in `substrait/grammar/SubstraitLexer.g4`).
538    /// Identifiers are case-insensitive and must start with a an ASCII letter,
539    /// `_`, or `$`, followed by ASCII letters, digits, `_`, or `$`.
540    //
541    // Note: I'm not sure if `$` is actually something we want to allow, or if
542    // `_` is, but it's in the grammar so I'm allowing it here.
543    pub fn validate_name(name: &str) -> Result<(), InvalidTypeName> {
544        let mut chars = name.chars();
545        let first = chars
546            .next()
547            .ok_or_else(|| InvalidTypeName(name.to_string()))?;
548        if !(first.is_ascii_alphabetic() || first == '_' || first == '$') {
549            return Err(InvalidTypeName(name.to_string()));
550        }
551
552        if !chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '$') {
553            return Err(InvalidTypeName(name.to_string()));
554        }
555
556        Ok(())
557    }
558
559    /// Create a new custom type with validation
560    pub fn new(
561        name: String,
562        parameters: Vec<TypeParam>,
563        structure: Option<ConcreteType>,
564        variadic: Option<bool>,
565        description: Option<String>,
566    ) -> Result<Self, ExtensionTypeError> {
567        Self::validate_name(&name)?;
568
569        Ok(Self {
570            name,
571            parameters,
572            structure,
573            variadic,
574            description,
575        })
576    }
577}
578
579impl From<CustomType> for SimpleExtensionsTypesItem {
580    fn from(value: CustomType) -> Self {
581        // Convert parameters back to TypeParamDefs if any
582        let parameters = if value.parameters.is_empty() {
583            None
584        } else {
585            Some(TypeParamDefs(
586                value
587                    .parameters
588                    .into_iter()
589                    .map(|param| {
590                        let (min, max) = param.param_type.raw_bounds();
591                        TypeParamDefsItem {
592                            name: Some(param.name),
593                            description: param.description,
594                            type_: param.param_type.raw_type(),
595                            min,
596                            max,
597                            options: param.param_type.raw_options(),
598                            // TODO: add this to TypeParamDefsItem parsing, and
599                            // follow it through here. I'm not entirely sure
600                            // when/if it is used.
601                            optional: None,
602                        }
603                    })
604                    .collect(),
605            ))
606        };
607
608        // Convert structure back to Type if any
609        let structure = value.structure.map(Into::into);
610
611        SimpleExtensionsTypesItem {
612            name: value.name,
613            description: value.description,
614            parameters,
615            structure,
616            variadic: value.variadic,
617        }
618    }
619}
620
621impl Parse<TypeContext> for SimpleExtensionsTypesItem {
622    type Parsed = CustomType;
623    type Error = ExtensionTypeError;
624
625    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
626        let name = self.name;
627        CustomType::validate_name(&name)?;
628
629        // Register this type as found
630        ctx.found(&name);
631
632        let parameters = if let Some(param_defs) = self.parameters {
633            param_defs
634                .0
635                .into_iter()
636                .map(TypeParam::try_from)
637                .collect::<Result<Vec<_>, _>>()?
638        } else {
639            Vec::new()
640        };
641
642        // Parse structure with context, so referenced extension types are recorded as linked
643        let structure = match self.structure {
644            Some(structure_data) => {
645                let parsed = Parse::parse(structure_data, ctx)?;
646                // TODO: check that the structure is valid. The `Type::Object`
647                // form of `structure_data` is by definition a non-nullable `NSTRUCT`; however,
648                // what types allowed under the `Type::String` form is less clear in the spec:
649                // See https://github.com/substrait-io/substrait/issues/920.
650                Some(parsed)
651            }
652            None => None,
653        };
654
655        Ok(CustomType {
656            name,
657            parameters,
658            structure,
659            variadic: self.variadic,
660            description: self.description,
661        })
662    }
663}
664
665impl Parse<TypeContext> for RawType {
666    type Parsed = ConcreteType;
667    type Error = ExtensionTypeError;
668
669    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
670        match self {
671            RawType::String(type_string) => {
672                let parsed_type = TypeExpr::parse(&type_string)?;
673                let mut link = |name: &str| ctx.linked(name);
674                parsed_type.visit_references(&mut link);
675                let concrete = ConcreteType::try_from(parsed_type)?;
676                Ok(concrete)
677            }
678            RawType::Object(field_map) => {
679                // Type structure in Substrait must preserve field order (see
680                // substrait-io/substrait#915). The typify generation uses
681                // IndexMap to retain the YAML order so that the order of the
682                // fields in the structure matches that of the extensions file.
683                let mut fields = IndexMap::new();
684
685                for (field_name, field_type_value) in field_map {
686                    let type_string = match field_type_value {
687                        serde_json::Value::String(s) => s,
688                        _ => {
689                            return Err(ExtensionTypeError::InvalidFieldType(
690                                "Struct field types must be strings".to_string(),
691                            ));
692                        }
693                    };
694
695                    let parsed_field_type = TypeExpr::parse(&type_string)?;
696                    let mut link = |name: &str| ctx.linked(name);
697                    parsed_field_type.visit_references(&mut link);
698                    let field_concrete_type = ConcreteType::try_from(parsed_field_type)?;
699
700                    if fields
701                        .insert(field_name.clone(), field_concrete_type)
702                        .is_some()
703                    {
704                        return Err(ExtensionTypeError::DuplicateFieldName { field_name });
705                    }
706                }
707
708                Ok(ConcreteType {
709                    kind: ConcreteTypeKind::NamedStruct { fields },
710                    nullable: false,
711                })
712            }
713        }
714    }
715}
716
/// Error returned when a type name does not satisfy the naming rules checked
/// by [`CustomType::validate_name`]. Carries the rejected name.
#[derive(Debug, Error, PartialEq)]
#[error("invalid type name `{0}`")]
pub struct InvalidTypeName(String);
721
/// The structural kind of a Substrait type (builtin, list, map, etc).
///
/// This is almost a complete type, but is missing nullability information. It must be
/// wrapped in a [`ConcreteType`] to form a complete type with nullable/non-nullable annotation.
///
/// Note that this is a recursive type - other than the [`BasicBuiltinType`]s, the other
/// variants can have type parameters that are themselves [`ConcreteType`]s.
#[derive(Clone, Debug, PartialEq)]
pub enum ConcreteTypeKind {
    /// Built-in Substrait type (primitive or parameterized)
    Builtin(BasicBuiltinType),
    /// Extension type with optional parameters
    Extension {
        /// Extension type name
        name: String,
        /// Type parameters
        parameters: Vec<TypeParameter>,
    },
    /// List type with element type
    List(Box<ConcreteType>),
    /// Map type with key and value types
    Map {
        /// Key type
        key: Box<ConcreteType>,
        /// Value type
        value: Box<ConcreteType>,
    },
    /// Struct type (ordered fields without names)
    Struct(Vec<ConcreteType>),
    /// Named struct type (nstruct - ordered fields with names)
    NamedStruct {
        /// Ordered field names and types. They are in the order they should
        /// appear in the struct - hence the use of [`IndexMap`].
        fields: IndexMap<String, ConcreteType>,
    },
}
758
759impl fmt::Display for ConcreteTypeKind {
760    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
761        match self {
762            ConcreteTypeKind::Builtin(b) => write!(f, "{b}"),
763            ConcreteTypeKind::Extension { name, parameters } => {
764                write!(f, "{name}")?;
765                write_separated(f, parameters.iter(), "<", ">", ", ")
766            }
767            ConcreteTypeKind::List(elem) => write!(f, "list<{elem}>"),
768            ConcreteTypeKind::Map { key, value } => write!(f, "map<{key}, {value}>"),
769            ConcreteTypeKind::Struct(types) => {
770                write_separated(f, types.iter(), "struct<", ">", ", ")
771            }
772            ConcreteTypeKind::NamedStruct { fields } => {
773                let kvs = fields.iter().map(|(k, v)| KeyValueDisplay(k, v, ": "));
774
775                write_separated(f, kvs, "{", "}", ", ")
776            }
777        }
778    }
779}
780
/// A concrete, fully-resolved type instance with nullability.
///
/// Displayed as the kind followed by `?` when the type is nullable.
#[derive(Clone, Debug, PartialEq)]
pub struct ConcreteType {
    /// The resolved type shape
    pub kind: ConcreteTypeKind,
    /// Whether this type is nullable
    pub nullable: bool,
}
789
790impl ConcreteType {
791    /// Create a new builtin scalar type
792    pub fn builtin(builtin_type: BasicBuiltinType, nullable: bool) -> ConcreteType {
793        ConcreteType {
794            kind: ConcreteTypeKind::Builtin(builtin_type),
795            nullable,
796        }
797    }
798
799    /// Create a new extension type reference (without parameters)
800    pub fn extension(name: String, nullable: bool) -> ConcreteType {
801        ConcreteType {
802            kind: ConcreteTypeKind::Extension {
803                name,
804                parameters: Vec::new(),
805            },
806            nullable,
807        }
808    }
809
810    /// Create a new parameterized extension type
811    pub fn extension_with_params(
812        name: String,
813        parameters: Vec<TypeParameter>,
814        nullable: bool,
815    ) -> ConcreteType {
816        ConcreteType {
817            kind: ConcreteTypeKind::Extension { name, parameters },
818            nullable,
819        }
820    }
821
822    /// Create a new list type
823    pub fn list(element_type: ConcreteType, nullable: bool) -> ConcreteType {
824        ConcreteType {
825            kind: ConcreteTypeKind::List(Box::new(element_type)),
826            nullable,
827        }
828    }
829
830    /// Create a new struct type (ordered fields without names)
831    pub fn r#struct(field_types: Vec<ConcreteType>, nullable: bool) -> ConcreteType {
832        ConcreteType {
833            kind: ConcreteTypeKind::Struct(field_types),
834            nullable,
835        }
836    }
837
838    /// Create a new map type
839    pub fn map(key_type: ConcreteType, value_type: ConcreteType, nullable: bool) -> ConcreteType {
840        ConcreteType {
841            kind: ConcreteTypeKind::Map {
842                key: Box::new(key_type),
843                value: Box::new(value_type),
844            },
845            nullable,
846        }
847    }
848
849    /// Create a new named struct type (nstruct - ordered fields with names)
850    pub fn named_struct(fields: IndexMap<String, ConcreteType>, nullable: bool) -> ConcreteType {
851        ConcreteType {
852            kind: ConcreteTypeKind::NamedStruct { fields },
853            nullable,
854        }
855    }
856
857    /// Check if this type (as a function argument) is compatible with another
858    /// type (as an input).
859    ///
860    /// Mainly checks nullability:
861    ///   - `i64?` is compatible with `i64` and `i64?` - both can be passed as
862    ///     arguments
863    ///   - `i64` is compatible with `i64` but NOT `i64?` - you can't pass a
864    ///     nullable type to a function that only accepts non-nullable arguments
865    pub fn is_compatible_with(&self, other: &ConcreteType) -> bool {
866        // Types must match exactly, but nullable types can accept non-nullable values
867        self.kind == other.kind && (self.nullable || !other.nullable)
868    }
869}
870
871impl fmt::Display for ConcreteType {
872    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
873        write!(f, "{}", self.kind)?;
874        if self.nullable {
875            write!(f, "?")?;
876        }
877        Ok(())
878    }
879}
880
881impl From<ConcreteType> for RawType {
882    fn from(val: ConcreteType) -> Self {
883        match val.kind {
884            ConcreteTypeKind::NamedStruct { fields } => {
885                let map = Map::from_iter(
886                    fields
887                        .into_iter()
888                        .map(|(name, ty)| (name, serde_json::Value::String(ty.to_string()))),
889                );
890                RawType::Object(map)
891            }
892            _ => RawType::String(val.to_string()),
893        }
894    }
895}
896
897/// Extract and validate an integer parameter for a built-in type.
898///
899/// For `DECIMAL<10,2>`, this validates that `10` (index 0) and `2` (index 1)
900/// are integers within their required ranges (precision 1-38, scale
901/// 0-precision).
902///
903/// - `type_name`: Type being validated (for error messages, e.g., "DECIMAL")
904/// - `index`: Parameter position (0-based, e.g., 0 for precision, 1 for scale);
905///   needed for error messages
906/// - `param`: The parameter to validate
907/// - `range`: Optional bounds to enforce (e.g., `Some(1..=38)` for precision)
908fn expect_integer_param(
909    type_name: &str,
910    index: usize,
911    param: &TypeExprParam<'_>,
912    range: Option<RangeInclusive<i32>>,
913) -> Result<i32, ExtensionTypeError> {
914    let value = match param {
915        TypeExprParam::Integer(value) => {
916            i32::try_from(*value).map_err(|_| ExtensionTypeError::InvalidParameterValue {
917                type_name: type_name.to_string(),
918                index,
919                value: *value,
920                expected: "an i32",
921            })
922        }
923        _ => Err(ExtensionTypeError::InvalidParameterKind {
924            type_name: type_name.to_string(),
925            index,
926            expected: "an integer",
927        }),
928    }?;
929
930    if let Some(range) = range {
931        if range.contains(&value) {
932            return Ok(value);
933        }
934        return Err(ExtensionTypeError::InvalidParameterRange {
935            type_name: type_name.to_string(),
936            index,
937            value: i64::from(value),
938            expected: range,
939        });
940    }
941
942    Ok(value)
943}
944
945/// Helper function - checks that param length matches expectations, returns
946/// error if not. Assumes a fixed number of expected parameters.
947fn expect_param_len(
948    type_name: &str,
949    params: &[TypeExprParam<'_>],
950    expected: usize,
951) -> Result<(), ExtensionTypeError> {
952    if params.len() != expected {
953        return Err(ExtensionTypeError::InvalidParameterCount {
954            type_name: type_name.to_string(),
955            expected,
956            actual: params.len(),
957        });
958    }
959    Ok(())
960}
961
962/// Helper function - expect a type parameter, and return the [ConcreteType] if it is a [TypeExpr]
963/// or an error if not.
964fn expect_type_argument<'a>(
965    type_name: &str,
966    index: usize,
967    param: TypeExprParam<'a>,
968) -> Result<ConcreteType, ExtensionTypeError> {
969    match param {
970        TypeExprParam::Type(t) => ConcreteType::try_from(t),
971        TypeExprParam::Integer(_) => Err(ExtensionTypeError::InvalidParameterKind {
972            type_name: type_name.to_string(),
973            index,
974            expected: "a type",
975        }),
976    }
977}
978
979impl<'a> TryFrom<TypeExprParam<'a>> for TypeParameter {
980    type Error = ExtensionTypeError;
981
982    fn try_from(param: TypeExprParam<'a>) -> Result<Self, Self::Error> {
983        Ok(match param {
984            TypeExprParam::Integer(v) => TypeParameter::Integer(v),
985            TypeExprParam::Type(t) => TypeParameter::Type(ConcreteType::try_from(t)?),
986        })
987    }
988}
989
990/// Parse a builtin type. Returns an `ExtensionTypeError` if the type name is
991/// matched, but parameters are incorrect; returns `Some(None)` if the type is
992/// not known.
993fn parse_builtin<'a>(
994    display_name: &str,
995    lower_name: &str,
996    params: &[TypeExprParam<'a>],
997) -> Result<Option<BasicBuiltinType>, ExtensionTypeError> {
998    if let Some(builtin) = primitive_builtin(lower_name) {
999        expect_param_len(display_name, params, 0)?;
1000        return Ok(Some(builtin));
1001    }
1002
1003    match lower_name {
1004        // Parameterized builtins
1005        "fixedchar" => {
1006            expect_param_len(display_name, params, 1)?;
1007            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1008            Ok(Some(BasicBuiltinType::FixedChar { length }))
1009        }
1010        "varchar" => {
1011            expect_param_len(display_name, params, 1)?;
1012            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1013            Ok(Some(BasicBuiltinType::VarChar { length }))
1014        }
1015        "fixedbinary" => {
1016            expect_param_len(display_name, params, 1)?;
1017            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1018            Ok(Some(BasicBuiltinType::FixedBinary { length }))
1019        }
1020        "decimal" => {
1021            expect_param_len(display_name, params, 2)?;
1022            let precision = expect_integer_param(display_name, 0, &params[0], Some(1..=38))?;
1023            let scale = expect_integer_param(display_name, 1, &params[1], Some(0..=precision))?;
1024            Ok(Some(BasicBuiltinType::Decimal { precision, scale }))
1025        }
1026        "precisiontime" | "precision_time" => {
1027            expect_param_len(display_name, params, 1)?;
1028            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1029            Ok(Some(BasicBuiltinType::PrecisionTime { precision }))
1030        }
1031        "precision_timestamp" => {
1032            expect_param_len(display_name, params, 1)?;
1033            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1034            Ok(Some(BasicBuiltinType::PrecisionTimestamp { precision }))
1035        }
1036        "precision_timestamp_tz" => {
1037            expect_param_len(display_name, params, 1)?;
1038            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1039            Ok(Some(BasicBuiltinType::PrecisionTimestampTz { precision }))
1040        }
1041        "interval_day" => {
1042            expect_param_len(display_name, params, 1)?;
1043            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=9))?;
1044            Ok(Some(BasicBuiltinType::IntervalDay { precision }))
1045        }
1046        "interval_compound" => {
1047            expect_param_len(display_name, params, 1)?;
1048            let precision = expect_integer_param(display_name, 0, &params[0], None)?;
1049            Ok(Some(BasicBuiltinType::IntervalCompound { precision }))
1050        }
1051        _ => Ok(None),
1052    }
1053}
1054
1055impl<'a> TryFrom<TypeExpr<'a>> for ConcreteType {
1056    type Error = ExtensionTypeError;
1057
1058    fn try_from(parsed_type: TypeExpr<'a>) -> Result<Self, Self::Error> {
1059        match parsed_type {
1060            TypeExpr::Simple(name, params, nullable) => {
1061                let lower = name.to_ascii_lowercase();
1062
1063                match lower.as_str() {
1064                    "list" => {
1065                        expect_param_len(name, &params, 1)?;
1066                        let element =
1067                            expect_type_argument(name, 0, params.into_iter().next().unwrap())?;
1068                        return Ok(ConcreteType::list(element, nullable));
1069                    }
1070                    "map" => {
1071                        expect_param_len(name, &params, 2)?;
1072                        let mut iter = params.into_iter();
1073                        let key = expect_type_argument(name, 0, iter.next().unwrap())?;
1074                        let value = expect_type_argument(name, 1, iter.next().unwrap())?;
1075                        return Ok(ConcreteType::map(key, value, nullable));
1076                    }
1077                    "struct" => {
1078                        let field_types = params
1079                            .into_iter()
1080                            .enumerate()
1081                            .map(|(idx, param)| expect_type_argument(name, idx, param))
1082                            .collect::<Result<Vec<_>, _>>()?;
1083                        return Ok(ConcreteType::r#struct(field_types, nullable));
1084                    }
1085                    _ => {}
1086                }
1087
1088                if let Some(builtin) = parse_builtin(name, lower.as_str(), &params)? {
1089                    return Ok(ConcreteType::builtin(builtin, nullable));
1090                }
1091
1092                // Simple types that aren't builtins are unknown
1093                // Extension types MUST use the u! prefix
1094                Err(ExtensionTypeError::UnknownTypeName {
1095                    name: name.to_string(),
1096                })
1097            }
1098            TypeExpr::UserDefined(name, params, nullable) => {
1099                let parameters = params
1100                    .into_iter()
1101                    .map(TypeParameter::try_from)
1102                    .collect::<Result<Vec<_>, _>>()?;
1103                Ok(ConcreteType::extension_with_params(
1104                    name.to_string(),
1105                    parameters,
1106                    nullable,
1107                ))
1108            }
1109            TypeExpr::TypeVariable(id, nullability) => {
1110                Err(ExtensionTypeError::InvalidAnyTypeVariable { id, nullability })
1111            }
1112        }
1113    }
1114}
1115
1116#[cfg(test)]
1117mod tests {
1118    use super::super::extensions::TypeContext;
1119    use super::*;
1120    use crate::parse::text::simple_extensions::TypeExpr;
1121    use crate::parse::text::simple_extensions::argument::EnumOptions as ParsedEnumOptions;
1122    use crate::text::simple_extensions;
1123    use std::iter::FromIterator;
1124
1125    /// Create a [ConcreteType] from a [BuiltinType]
1126    fn concretize(builtin: BasicBuiltinType) -> ConcreteType {
1127        ConcreteType::builtin(builtin, false)
1128    }
1129
1130    /// Parse a string into a [ConcreteType]
1131    fn parse_type(expr: &str) -> ConcreteType {
1132        let parsed = TypeExpr::parse(expr).unwrap();
1133        ConcreteType::try_from(parsed).unwrap()
1134    }
1135
1136    /// Parse a string into a [ConcreteType], returning the result
1137    fn parse_type_result(expr: &str) -> Result<ConcreteType, ExtensionTypeError> {
1138        let parsed = TypeExpr::parse(expr).unwrap();
1139        ConcreteType::try_from(parsed)
1140    }
1141
1142    /// Parse a string into a builtin [ConcreteType], with no unresolved
1143    /// extension references
1144    fn parse_simple(s: &str) -> ConcreteType {
1145        let parsed = TypeExpr::parse(s).unwrap();
1146
1147        let mut refs = Vec::new();
1148        parsed.visit_references(&mut |name| refs.push(name.to_string()));
1149        assert!(refs.is_empty(), "{s} should not add an extension reference");
1150
1151        ConcreteType::try_from(parsed).unwrap()
1152    }
1153
1154    /// Create a type parameter from a type expression string
1155    fn type_param(expr: &str) -> TypeParameter {
1156        TypeParameter::Type(parse_type(expr))
1157    }
1158
1159    /// Create an extension type
1160    fn extension(name: &str, parameters: Vec<TypeParameter>, nullable: bool) -> ConcreteType {
1161        ConcreteType::extension_with_params(name.to_string(), parameters, nullable)
1162    }
1163
1164    /// Convert a custom type to raw and back, ensuring round-trip consistency
1165    fn round_trip(custom: &CustomType) {
1166        let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into();
1167        let mut ctx = TypeContext::default();
1168        let parsed = Parse::parse(item, &mut ctx).unwrap();
1169        assert_eq!(&parsed, custom);
1170    }
1171
1172    /// Create a raw named struct (e.g. straight from YAML) from field name and
1173    /// type pairs
1174    fn raw_named_struct(fields: &[(&str, &str)]) -> RawType {
1175        let map = Map::from_iter(
1176            fields
1177                .iter()
1178                .map(|(name, ty)| ((*name).into(), serde_json::Value::String((*ty).into()))),
1179        );
1180
1181        RawType::Object(map)
1182    }
1183
1184    #[test]
1185    fn test_builtin_scalar_parsing() {
1186        let cases = vec![
1187            ("bool", Some(BasicBuiltinType::Boolean)),
1188            ("i8", Some(BasicBuiltinType::I8)),
1189            ("i16", Some(BasicBuiltinType::I16)),
1190            ("i32", Some(BasicBuiltinType::I32)),
1191            ("i64", Some(BasicBuiltinType::I64)),
1192            ("fp32", Some(BasicBuiltinType::Fp32)),
1193            ("fp64", Some(BasicBuiltinType::Fp64)),
1194            ("STRING", Some(BasicBuiltinType::String)),
1195            ("binary", Some(BasicBuiltinType::Binary)),
1196            ("uuid", Some(BasicBuiltinType::Uuid)),
1197            ("date", Some(BasicBuiltinType::Date)),
1198            ("interval_year", Some(BasicBuiltinType::IntervalYear)),
1199            ("time", Some(BasicBuiltinType::Time)),
1200            ("timestamp", Some(BasicBuiltinType::Timestamp)),
1201            ("timestamp_tz", Some(BasicBuiltinType::TimestampTz)),
1202            ("invalid", None),
1203        ];
1204
1205        for (input, expected) in cases {
1206            let result = parse_builtin(input, input.to_ascii_lowercase().as_str(), &[]).unwrap();
1207            match expected {
1208                Some(expected_type) => {
1209                    assert_eq!(
1210                        result,
1211                        Some(expected_type),
1212                        "expected builtin type for {input}"
1213                    );
1214                }
1215                None => {
1216                    assert!(result.is_none(), "expected parsing {input} to fail");
1217                }
1218            }
1219        }
1220    }
1221
1222    #[test]
1223    fn test_parameterized_builtin_types() {
1224        let cases = vec![
1225            (
1226                "precisiontime<2>",
1227                concretize(BasicBuiltinType::PrecisionTime { precision: 2 }),
1228            ),
1229            (
1230                "precision_timestamp<1>",
1231                concretize(BasicBuiltinType::PrecisionTimestamp { precision: 1 }),
1232            ),
1233            (
1234                "precision_timestamp_tz<5>",
1235                concretize(BasicBuiltinType::PrecisionTimestampTz { precision: 5 }),
1236            ),
1237            (
1238                "DECIMAL<10,2>",
1239                concretize(BasicBuiltinType::Decimal {
1240                    precision: 10,
1241                    scale: 2,
1242                }),
1243            ),
1244            (
1245                "fixedchar<12>",
1246                concretize(BasicBuiltinType::FixedChar { length: 12 }),
1247            ),
1248            (
1249                "VarChar<42>",
1250                concretize(BasicBuiltinType::VarChar { length: 42 }),
1251            ),
1252            (
1253                "fixedbinary<8>",
1254                concretize(BasicBuiltinType::FixedBinary { length: 8 }),
1255            ),
1256            (
1257                "interval_day<7>",
1258                concretize(BasicBuiltinType::IntervalDay { precision: 7 }),
1259            ),
1260            (
1261                "interval_compound<6>",
1262                concretize(BasicBuiltinType::IntervalCompound { precision: 6 }),
1263            ),
1264        ];
1265
1266        for (expr, expected) in cases {
1267            let found = parse_simple(expr);
1268            assert_eq!(found, expected, "unexpected type for {expr}");
1269        }
1270    }
1271
1272    #[test]
1273    fn test_parameterized_builtin_range_errors() {
1274        use ExtensionTypeError::InvalidParameterRange;
1275
1276        let cases = vec![
1277            ("precisiontime<13>", "precisiontime", 0, 13, 0..=12),
1278            ("precisiontime<-1>", "precisiontime", 0, -1, 0..=12),
1279            (
1280                "precision_timestamp<13>",
1281                "precision_timestamp",
1282                0,
1283                13,
1284                0..=12,
1285            ),
1286            (
1287                "precision_timestamp<-1>",
1288                "precision_timestamp",
1289                0,
1290                -1,
1291                0..=12,
1292            ),
1293            (
1294                "precision_timestamp_tz<20>",
1295                "precision_timestamp_tz",
1296                0,
1297                20,
1298                0..=12,
1299            ),
1300            ("interval_day<10>", "interval_day", 0, 10, 0..=9),
1301            ("DECIMAL<39,0>", "DECIMAL", 0, 39, 1..=38),
1302            ("DECIMAL<0,0>", "DECIMAL", 0, 0, 1..=38),
1303            ("DECIMAL<10,-1>", "DECIMAL", 1, -1, 0..=10),
1304            ("DECIMAL<10,12>", "DECIMAL", 1, 12, 0..=10),
1305        ];
1306
1307        for (expr, expected_type, expected_index, expected_value, expected_range) in cases {
1308            match parse_type_result(expr) {
1309                Ok(value) => panic!("expected error parsing {expr}, got {value:?}"),
1310                Err(InvalidParameterRange {
1311                    type_name,
1312                    index,
1313                    value,
1314                    expected,
1315                }) => {
1316                    assert_eq!(type_name, expected_type, "unexpected type for {expr}");
1317                    assert_eq!(index, expected_index, "unexpected index for {expr}");
1318                    assert_eq!(
1319                        value,
1320                        i64::from(expected_value),
1321                        "unexpected value for {expr}"
1322                    );
1323                    assert_eq!(expected, expected_range, "unexpected range for {expr}");
1324                }
1325                Err(other) => panic!("expected InvalidParameterRange for {expr}, got {other:?}"),
1326            }
1327        }
1328    }
1329
1330    #[test]
1331    fn test_container_types() {
1332        let cases = vec![
1333            (
1334                "List<i32>",
1335                ConcreteType::list(ConcreteType::builtin(BasicBuiltinType::I32, false), false),
1336            ),
1337            (
1338                "List<fp64?>",
1339                ConcreteType::list(ConcreteType::builtin(BasicBuiltinType::Fp64, true), false),
1340            ),
1341            (
1342                "Map?<i64, string?>",
1343                ConcreteType::map(
1344                    ConcreteType::builtin(BasicBuiltinType::I64, false),
1345                    ConcreteType::builtin(BasicBuiltinType::String, true),
1346                    true,
1347                ),
1348            ),
1349            (
1350                "Struct?<i8, string?>",
1351                ConcreteType::r#struct(
1352                    vec![
1353                        ConcreteType::builtin(BasicBuiltinType::I8, false),
1354                        ConcreteType::builtin(BasicBuiltinType::String, true),
1355                    ],
1356                    true,
1357                ),
1358            ),
1359        ];
1360
1361        for (expr, expected) in cases {
1362            assert_eq!(parse_type(expr), expected, "unexpected parse for {expr}");
1363        }
1364    }
1365
1366    #[test]
1367    fn test_extension_types() {
1368        let cases = vec![
1369            (
1370                "u!geo<List<i32>, 10>",
1371                extension(
1372                    "geo",
1373                    vec![type_param("List<i32>"), TypeParameter::Integer(10)],
1374                    false,
1375                ),
1376            ),
1377            (
1378                "u!Geo?<List<i32?>>",
1379                extension("Geo", vec![type_param("List<i32?>")], true),
1380            ),
1381            (
1382                "u!Custom<string?, bool>",
1383                extension(
1384                    "Custom",
1385                    vec![
1386                        type_param("string?"),
1387                        TypeParameter::Type(ConcreteType::builtin(
1388                            BasicBuiltinType::Boolean,
1389                            false,
1390                        )),
1391                    ],
1392                    false,
1393                ),
1394            ),
1395        ];
1396
1397        for (expr, expected) in cases {
1398            assert_eq!(
1399                parse_type(expr),
1400                expected,
1401                "unexpected extension for {expr}"
1402            );
1403        }
1404    }
1405
1406    #[test]
1407    fn test_parameter_type_validation() {
1408        let int_param = ParameterConstraint::Integer {
1409            min: Some(1),
1410            max: Some(10),
1411        };
1412        let enum_param = ParameterConstraint::Enumeration {
1413            options: ParsedEnumOptions::try_from(simple_extensions::EnumOptions(vec![
1414                "OVERFLOW".to_string(),
1415                "ERROR".to_string(),
1416            ]))
1417            .unwrap(),
1418        };
1419
1420        let cases = vec![
1421            (&int_param, Value::Number(5.into()), true),
1422            (&int_param, Value::Number(0.into()), false),
1423            (&int_param, Value::Number(11.into()), false),
1424            (&int_param, Value::String("not a number".into()), false),
1425            (&enum_param, Value::String("OVERFLOW".into()), true),
1426            (&enum_param, Value::String("INVALID".into()), false),
1427        ];
1428
1429        for (param, value, expected) in cases {
1430            assert_eq!(
1431                param.is_valid_value(&value),
1432                expected,
1433                "unexpected validation result for {value:?}"
1434            );
1435        }
1436    }
1437
1438    #[test]
1439    fn test_type_round_trip_display() {
1440        // (example, canonical form)
1441        let cases = vec![
1442            ("i32", "i32"),
1443            ("I64?", "i64?"),
1444            ("list<string>", "list<string>"),
1445            ("List<STRING?>", "list<string?>"),
1446            ("map<i32, list<string>>", "map<i32, list<string>>"),
1447            ("struct<i8, string?>", "struct<i8, string?>"),
1448            (
1449                "Struct<List<i32>, Map<string, list<i64?>>>",
1450                "struct<list<i32>, map<string, list<i64?>>>",
1451            ),
1452            (
1453                "Map<List<I32?>, Struct<string, list<i64?>>>",
1454                "map<list<i32?>, struct<string, list<i64?>>>",
1455            ),
1456            ("u!custom<i32>", "custom<i32>"),
1457        ];
1458
1459        for (input, expected) in cases {
1460            let parsed = TypeExpr::parse(input).unwrap();
1461            let concrete = ConcreteType::try_from(parsed).unwrap();
1462            let actual = concrete.to_string();
1463
1464            assert_eq!(actual, expected, "unexpected display for {input}");
1465        }
1466    }
1467
1468    /// Test that named struct field order preserves the structure order when
1469    /// round-tripping through RawType (Substrait #915).
1470    #[test]
1471    fn test_named_struct_field_order_stability() -> Result<(), ExtensionTypeError> {
1472        let mut raw_fields = Map::new();
1473        raw_fields.insert("beta".to_string(), Value::String("i32".to_string()));
1474        raw_fields.insert("alpha".to_string(), Value::String("string?".to_string()));
1475
1476        let raw = RawType::Object(raw_fields);
1477        let mut ctx = TypeContext::default();
1478        let concrete = Parse::parse(raw, &mut ctx)?;
1479
1480        let round_tripped: RawType = concrete.into();
1481        match round_tripped {
1482            RawType::Object(result_map) => {
1483                let keys: Vec<_> = result_map.keys().collect();
1484                assert_eq!(
1485                    keys,
1486                    vec!["beta", "alpha"],
1487                    "field order should be preserved"
1488                );
1489            }
1490            other => panic!("expected Object, got {other:?}"),
1491        }
1492
1493        Ok(())
1494    }
1495
1496    #[test]
1497    fn test_integer_param_bounds_round_trip() {
1498        let cases = vec![
1499            (
1500                "bounded",
1501                simple_extensions::TypeParamDefsItem {
1502                    name: Some("K".to_string()),
1503                    description: None,
1504                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1505                    min: Some(1.0),
1506                    max: Some(10.0),
1507                    options: None,
1508                    optional: None,
1509                },
1510                Ok((Some(1), Some(10))),
1511            ),
1512            (
1513                "fractional_min",
1514                simple_extensions::TypeParamDefsItem {
1515                    name: Some("K".to_string()),
1516                    description: None,
1517                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1518                    min: Some(1.5),
1519                    max: None,
1520                    options: None,
1521                    optional: None,
1522                },
1523                Err(TypeParamError::InvalidIntegerBounds {
1524                    min: Some(1.5),
1525                    max: None,
1526                }),
1527            ),
1528            (
1529                "fractional_max",
1530                simple_extensions::TypeParamDefsItem {
1531                    name: Some("K".to_string()),
1532                    description: None,
1533                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1534                    min: None,
1535                    max: Some(9.9),
1536                    options: None,
1537                    optional: None,
1538                },
1539                Err(TypeParamError::InvalidIntegerBounds {
1540                    min: None,
1541                    max: Some(9.9),
1542                }),
1543            ),
1544        ];
1545
1546        for (label, item, expected) in cases {
1547            match (TypeParam::try_from(item), expected) {
1548                (Ok(tp), Ok((expected_min, expected_max))) => match tp.param_type {
1549                    ParameterConstraint::Integer { min, max } => {
1550                        assert_eq!(min, expected_min, "min mismatch for {label}");
1551                        assert_eq!(max, expected_max, "max mismatch for {label}");
1552                    }
1553                    _ => panic!("expected integer param type for {label}"),
1554                },
1555                (Err(actual_err), Err(expected_err)) => {
1556                    assert_eq!(actual_err, expected_err, "unexpected error for {label}");
1557                }
1558                (result, expected) => {
1559                    panic!("unexpected result for {label}: got {result:?}, expected {expected:?}")
1560                }
1561            }
1562        }
1563    }
1564
1565    #[test]
1566    fn test_custom_type_round_trip() -> Result<(), ExtensionTypeError> {
1567        let fields = IndexMap::from_iter([
1568            (
1569                "x".to_string(),
1570                ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1571            ),
1572            (
1573                "y".to_string(),
1574                ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1575            ),
1576        ]);
1577
1578        let cases = vec![
1579            CustomType::new(
1580                "AliasType".to_string(),
1581                vec![],
1582                Some(ConcreteType::builtin(BasicBuiltinType::I32, false)),
1583                None,
1584                Some("a test alias type".to_string()),
1585            )?,
1586            CustomType::new(
1587                "Point".to_string(),
1588                vec![TypeParam::new(
1589                    "T".to_string(),
1590                    ParameterConstraint::DataType,
1591                    Some("a numeric value".to_string()),
1592                )],
1593                Some(ConcreteType::named_struct(fields, false)),
1594                None,
1595                None,
1596            )?,
1597        ];
1598
1599        for custom in cases {
1600            round_trip(&custom);
1601        }
1602
1603        Ok(())
1604    }
1605
1606    #[test]
1607    fn test_invalid_type_names() {
1608        let cases = vec![
1609            ("", false),
1610            ("bad name", false),
1611            ("9bad", false),
1612            ("bad-name", false),
1613            ("bad.name", false),
1614            ("GoodName", true),
1615            ("also_good", true),
1616            ("_underscore", true),
1617            ("$dollar", true),
1618            ("CamelCase123", true),
1619        ];
1620
1621        for (name, expected_ok) in cases {
1622            let result = CustomType::validate_name(name);
1623            assert_eq!(
1624                result.is_ok(),
1625                expected_ok,
1626                "unexpected validation for {name}"
1627            );
1628        }
1629    }
1630
1631    #[test]
1632    fn test_ext_type_to_concrete_type() -> Result<(), ExtensionTypeError> {
1633        let cases = vec![
1634            (
1635                "alias",
1636                RawType::String("i32".to_string()),
1637                ConcreteType::builtin(BasicBuiltinType::I32, false),
1638            ),
1639            (
1640                "named_struct",
1641                raw_named_struct(&[("field1", "fp64"), ("field2", "i32?")]),
1642                ConcreteType::named_struct(
1643                    IndexMap::from_iter([
1644                        (
1645                            "field1".to_string(),
1646                            ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1647                        ),
1648                        (
1649                            "field2".to_string(),
1650                            ConcreteType::builtin(BasicBuiltinType::I32, true),
1651                        ),
1652                    ]),
1653                    false,
1654                ),
1655            ),
1656        ];
1657
1658        for (label, raw, expected) in cases {
1659            let mut ctx = TypeContext::default();
1660            let parsed = Parse::parse(raw, &mut ctx)?;
1661            assert_eq!(parsed, expected, "unexpected type for {label}");
1662        }
1663
1664        Ok(())
1665    }
1666
1667    #[test]
1668    fn test_custom_type_parsing() -> Result<(), ExtensionTypeError> {
1669        let cases = vec![
1670            (
1671                "alias",
1672                simple_extensions::SimpleExtensionsTypesItem {
1673                    name: "Alias".to_string(),
1674                    description: Some("Alias type".to_string()),
1675                    parameters: None,
1676                    structure: Some(RawType::String("BINARY".to_string())),
1677                    variadic: None,
1678                },
1679                "Alias",
1680                Some("Alias type"),
1681                Some(ConcreteType::builtin(BasicBuiltinType::Binary, false)),
1682            ),
1683            (
1684                "named_struct",
1685                simple_extensions::SimpleExtensionsTypesItem {
1686                    name: "Point".to_string(),
1687                    description: Some("A 2D point".to_string()),
1688                    parameters: None,
1689                    structure: Some(raw_named_struct(&[("x", "fp64"), ("y", "fp64?")])),
1690                    variadic: None,
1691                },
1692                "Point",
1693                Some("A 2D point"),
1694                Some(ConcreteType::named_struct(
1695                    IndexMap::from_iter([
1696                        (
1697                            "x".to_string(),
1698                            ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1699                        ),
1700                        (
1701                            "y".to_string(),
1702                            ConcreteType::builtin(BasicBuiltinType::Fp64, true),
1703                        ),
1704                    ]),
1705                    false,
1706                )),
1707            ),
1708            (
1709                "no_structure",
1710                simple_extensions::SimpleExtensionsTypesItem {
1711                    name: "Opaque".to_string(),
1712                    description: None,
1713                    parameters: None,
1714                    structure: None,
1715                    variadic: Some(true),
1716                },
1717                "Opaque",
1718                None,
1719                None,
1720            ),
1721        ];
1722
1723        for (label, item, expected_name, expected_description, expected_structure) in cases {
1724            let mut ctx = TypeContext::default();
1725            let parsed = Parse::parse(item, &mut ctx)?;
1726            assert_eq!(parsed.name, expected_name);
1727            assert_eq!(
1728                parsed.description.as_deref(),
1729                expected_description,
1730                "description mismatch for {label}"
1731            );
1732            assert_eq!(
1733                parsed.structure, expected_structure,
1734                "structure mismatch for {label}"
1735            );
1736        }
1737
1738        Ok(())
1739    }
1740}