Skip to main content

substrait/parse/text/simple_extensions/
types.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Concrete type system for function validation in the registry.
4//!
5//! This module provides a clean, type-safe wrapper around Substrait extension types,
6//! separating function signature patterns from concrete argument types.
7
8use super::TypeExpr;
9use super::argument::{
10    EnumOptions as ParsedEnumOptions, EnumOptionsError as ParsedEnumOptionsError,
11};
12use super::extensions::TypeContext;
13use super::type_ast::TypeExprParam;
14use crate::parse::Parse;
15use crate::parse::text::simple_extensions::type_ast::TypeParseError;
16use crate::text::simple_extensions::{
17    EnumOptions as RawEnumOptions, SimpleExtensionsTypesItem, Type as RawType, TypeParamDefs,
18    TypeParamDefsItem, TypeParamDefsItemType,
19};
20use indexmap::IndexMap;
21use serde_json::{Map, Value};
22use std::convert::TryFrom;
23use std::fmt;
24use std::ops::RangeInclusive;
25use thiserror::Error;
26
/// Render every item of `iter` into `f`, joined by `sep` and wrapped in the
/// `start` / `end` delimiters.
///
/// An empty sequence produces no output at all: the delimiters are emitted
/// only when at least one item is written.
fn write_separated<I, T>(
    f: &mut fmt::Formatter<'_>,
    iter: I,
    start: &str,
    end: &str,
    sep: &str,
) -> fmt::Result
where
    I: IntoIterator<Item = T>,
    T: fmt::Display,
{
    let mut items = iter.into_iter();
    // Peel off the head so an empty sequence can bail out before anything is
    // written to the formatter.
    let Some(head) = items.next() else {
        return Ok(());
    };

    write!(f, "{start}{head}")?;
    // Every remaining item is preceded by the separator.
    items.try_for_each(|item| write!(f, "{sep}{item}"))?;
    f.write_str(end)
}
55
/// Display adapter that renders a key/value pair as the key, then the
/// separator, then the value.
///
/// Field order is `(key, value, separator)`.
struct KeyValueDisplay<K, V>(K, V, &'static str);

impl<K: fmt::Display, V: fmt::Display> fmt::Display for KeyValueDisplay<K, V> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(key, value, separator) = self;
        write!(f, "{key}{separator}{value}")
    }
}
68
/// Non-recursive, built-in Substrait types: types with no parameters (primitive
/// types), or simple types with only primitive / literal parameters.
///
/// None of the variants contain another type, so a value of this enum is
/// always a leaf of a type tree. The [`fmt::Display`] implementation renders
/// the parameterless variants in lowercase (e.g. `i32`) and the parameterized
/// variants in uppercase with their parameters in angle brackets (e.g.
/// `DECIMAL<10, 2>`).
#[derive(Clone, Debug, PartialEq)]
pub enum BasicBuiltinType {
    /// Boolean type - `bool`
    Boolean,
    /// 8-bit signed integer - `i8`
    I8,
    /// 16-bit signed integer - `i16`
    I16,
    /// 32-bit signed integer - `i32`
    I32,
    /// 64-bit signed integer - `i64`
    I64,
    /// 32-bit floating point - `fp32`
    Fp32,
    /// 64-bit floating point - `fp64`
    Fp64,
    /// Variable-length string - `string`
    String,
    /// Variable-length binary data - `binary`
    Binary,
    /// Naive timestamp (no time zone attached) - `timestamp`
    Timestamp,
    /// Timestamp with time zone - `timestamp_tz`
    TimestampTz,
    /// Calendar date - `date`
    Date,
    /// Time of day - `time`
    Time,
    /// Year-month interval - `interval_year`
    IntervalYear,
    /// 128-bit UUID - `uuid`
    Uuid,
    /// Fixed-length character string: `FIXEDCHAR<L>`
    FixedChar {
        /// Length (number of characters), must be >= 1
        length: i32,
    },
    /// Variable-length character string: `VARCHAR<L>`
    VarChar {
        /// Maximum length (number of characters), must be >= 1
        length: i32,
    },
    /// Fixed-length binary data: `FIXEDBINARY<L>`
    FixedBinary {
        /// Length (number of bytes), must be >= 1
        length: i32,
    },
    /// Fixed-point decimal: `DECIMAL<P, S>`
    Decimal {
        /// Precision (total digits), <= 38
        precision: i32,
        /// Scale (digits after decimal point), 0 <= S <= P
        scale: i32,
    },
    /// Time with sub-second precision: `PRECISIONTIME<P>`
    PrecisionTime {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Timestamp with sub-second precision: `PRECISIONTIMESTAMP<P>`
    PrecisionTimestamp {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Timezone-aware timestamp with precision: `PRECISIONTIMESTAMPTZ<P>`
    PrecisionTimestampTz {
        /// Sub-second precision digits (0-12: seconds to picoseconds)
        precision: i32,
    },
    /// Day-time interval: `INTERVAL_DAY<P>`
    IntervalDay {
        /// Sub-second precision digits (0-9: seconds to nanoseconds)
        precision: i32,
    },
    /// Compound interval: `INTERVAL_COMPOUND<P>`
    IntervalCompound {
        /// Sub-second precision digits
        // NOTE(review): no valid range is documented here, unlike the other
        // precision parameters - confirm against the Substrait spec.
        precision: i32,
    },
}
151
152impl BasicBuiltinType {
153    /// Check if a string is a valid name for a builtin scalar type
154    pub fn is_name(name: &str) -> bool {
155        let lower = name.to_ascii_lowercase();
156        primitive_builtin(&lower).is_some()
157            || matches!(
158                lower.as_str(),
159                "fixedchar"
160                    | "varchar"
161                    | "fixedbinary"
162                    | "decimal"
163                    | "precisiontime"
164                    | "precision_time"
165                    | "precision_timestamp"
166                    | "precision_timestamp_tz"
167                    | "interval_day"
168                    | "interval_compound"
169            )
170    }
171}
172
173impl fmt::Display for BasicBuiltinType {
174    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
175        match self {
176            BasicBuiltinType::Boolean => f.write_str("bool"),
177            BasicBuiltinType::I8 => f.write_str("i8"),
178            BasicBuiltinType::I16 => f.write_str("i16"),
179            BasicBuiltinType::I32 => f.write_str("i32"),
180            BasicBuiltinType::I64 => f.write_str("i64"),
181            BasicBuiltinType::Fp32 => f.write_str("fp32"),
182            BasicBuiltinType::Fp64 => f.write_str("fp64"),
183            BasicBuiltinType::String => f.write_str("string"),
184            BasicBuiltinType::Binary => f.write_str("binary"),
185            BasicBuiltinType::Timestamp => f.write_str("timestamp"),
186            BasicBuiltinType::TimestampTz => f.write_str("timestamp_tz"),
187            BasicBuiltinType::Date => f.write_str("date"),
188            BasicBuiltinType::Time => f.write_str("time"),
189            BasicBuiltinType::IntervalYear => f.write_str("interval_year"),
190            BasicBuiltinType::Uuid => f.write_str("uuid"),
191            BasicBuiltinType::FixedChar { length } => write!(f, "FIXEDCHAR<{length}>"),
192            BasicBuiltinType::VarChar { length } => write!(f, "VARCHAR<{length}>"),
193            BasicBuiltinType::FixedBinary { length } => write!(f, "FIXEDBINARY<{length}>"),
194            BasicBuiltinType::Decimal { precision, scale } => {
195                write!(f, "DECIMAL<{precision}, {scale}>")
196            }
197            BasicBuiltinType::PrecisionTime { precision } => {
198                write!(f, "PRECISIONTIME<{precision}>")
199            }
200            BasicBuiltinType::PrecisionTimestamp { precision } => {
201                write!(f, "PRECISIONTIMESTAMP<{precision}>")
202            }
203            BasicBuiltinType::PrecisionTimestampTz { precision } => {
204                write!(f, "PRECISIONTIMESTAMPTZ<{precision}>")
205            }
206            BasicBuiltinType::IntervalDay { precision } => write!(f, "INTERVAL_DAY<{precision}>"),
207            BasicBuiltinType::IntervalCompound { precision } => {
208                write!(f, "INTERVAL_COMPOUND<{precision}>")
209            }
210        }
211    }
212}
213
/// A concrete parameter value supplied to a parameterized type (for example
/// the entries of an extension type's parameter list).
///
/// Displays as the bare integer or as the nested type's own rendering.
#[derive(Clone, Debug, PartialEq)]
pub enum TypeParameter {
    /// Integer parameter (e.g., precision, scale)
    Integer(i64),
    /// Type parameter (nested type)
    Type(ConcreteType),
    // TODO: Add support for other type parameters, as described in
    // https://github.com/substrait-io/substrait/blob/35101020d961eda48f8dd1aafbc794c9e5cac077/proto/substrait/type.proto#L250-L265
}
224
225impl fmt::Display for TypeParameter {
226    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
227        match self {
228            TypeParameter::Integer(i) => write!(f, "{i}"),
229            TypeParameter::Type(t) => write!(f, "{t}"),
230        }
231    }
232}
233
234/// Parse a primitive (no type parameters) builtin type name
235fn primitive_builtin(lower_name: &str) -> Option<BasicBuiltinType> {
236    match lower_name {
237        "bool" | "boolean" => Some(BasicBuiltinType::Boolean),
238        "i8" => Some(BasicBuiltinType::I8),
239        "i16" => Some(BasicBuiltinType::I16),
240        "i32" => Some(BasicBuiltinType::I32),
241        "i64" => Some(BasicBuiltinType::I64),
242        "fp32" => Some(BasicBuiltinType::Fp32),
243        "fp64" => Some(BasicBuiltinType::Fp64),
244        "string" => Some(BasicBuiltinType::String),
245        "binary" => Some(BasicBuiltinType::Binary),
246        "timestamp" => Some(BasicBuiltinType::Timestamp),
247        "timestamp_tz" => Some(BasicBuiltinType::TimestampTz),
248        "date" => Some(BasicBuiltinType::Date),
249        "time" => Some(BasicBuiltinType::Time),
250        "interval_year" => Some(BasicBuiltinType::IntervalYear),
251        "uuid" => Some(BasicBuiltinType::Uuid),
252        _ => None,
253    }
254}
255
/// The kind of value a declared type parameter accepts, together with any
/// constraints on that value.
///
/// This is the validated counterpart of the raw `TypeParamDefsItemType` tag
/// plus the raw `min` / `max` / `options` fields of a parameter definition.
#[derive(Clone, Debug, PartialEq)]
pub enum ParameterConstraint {
    /// Data type parameter
    DataType,
    /// Integer parameter with range constraints
    Integer {
        /// Minimum value (inclusive), if specified
        min: Option<i64>,
        /// Maximum value (inclusive), if specified
        max: Option<i64>,
    },
    /// Enumeration parameter
    Enumeration {
        /// Valid enumeration values (validated, deduplicated)
        options: ParsedEnumOptions,
    },
    /// Boolean parameter
    Boolean,
    /// String parameter
    String,
}
278
279impl ParameterConstraint {
280    /// Convert back to raw TypeParamDefsItemType
281    fn raw_type(&self) -> TypeParamDefsItemType {
282        match self {
283            ParameterConstraint::DataType => TypeParamDefsItemType::DataType,
284            ParameterConstraint::Boolean => TypeParamDefsItemType::Boolean,
285            ParameterConstraint::Integer { .. } => TypeParamDefsItemType::Integer,
286            ParameterConstraint::Enumeration { .. } => TypeParamDefsItemType::Enumeration,
287            ParameterConstraint::String => TypeParamDefsItemType::String,
288        }
289    }
290
291    /// Extract raw bounds for integer parameters (min, max)
292    fn raw_bounds(&self) -> (Option<f64>, Option<f64>) {
293        match self {
294            ParameterConstraint::Integer { min, max } => {
295                (min.map(|i| i as f64), max.map(|i| i as f64))
296            }
297            _ => (None, None),
298        }
299    }
300
301    /// Extract raw enum options for enumeration parameters
302    fn raw_options(&self) -> Option<RawEnumOptions> {
303        match self {
304            ParameterConstraint::Enumeration { options } => Some(options.clone().into()),
305            _ => None,
306        }
307    }
308
309    /// Check if a parameter value is valid for this parameter type
310    pub fn is_valid_value(&self, value: &Value) -> bool {
311        match (self, value) {
312            (ParameterConstraint::DataType, Value::String(_)) => true,
313            (ParameterConstraint::Integer { min, max }, Value::Number(n)) => {
314                if let Some(i) = n.as_i64() {
315                    min.is_none_or(|min_val| i >= min_val) && max.is_none_or(|max_val| i <= max_val)
316                } else {
317                    false
318                }
319            }
320            (ParameterConstraint::Enumeration { options }, Value::String(s)) => options.contains(s),
321            (ParameterConstraint::Boolean, Value::Bool(_)) => true,
322            (ParameterConstraint::String, Value::String(_)) => true,
323            _ => false,
324        }
325    }
326
327    fn from_raw(
328        t: TypeParamDefsItemType,
329        opts: Option<RawEnumOptions>,
330        min: Option<f64>,
331        max: Option<f64>,
332    ) -> Result<Self, TypeParamError> {
333        Ok(match t {
334            TypeParamDefsItemType::DataType => Self::DataType,
335            TypeParamDefsItemType::Boolean => Self::Boolean,
336            TypeParamDefsItemType::Integer => {
337                match (min, max) {
338                    (Some(min_f), _) if min_f.fract() != 0.0 => {
339                        return Err(TypeParamError::InvalidIntegerBounds { min, max });
340                    }
341                    (_, Some(max_f)) if max_f.fract() != 0.0 => {
342                        return Err(TypeParamError::InvalidIntegerBounds { min, max });
343                    }
344                    _ => (),
345                }
346
347                let min_i = min.map(|v| v as i64);
348                let max_i = max.map(|v| v as i64);
349                Self::Integer {
350                    min: min_i,
351                    max: max_i,
352                }
353            }
354            TypeParamDefsItemType::Enumeration => {
355                let options: ParsedEnumOptions =
356                    opts.ok_or(TypeParamError::MissingEnumOptions)?.try_into()?;
357                Self::Enumeration { options }
358            }
359            TypeParamDefsItemType::String => Self::String,
360        })
361    }
362}
363
/// A validated type parameter with name and constraints.
///
/// Produced from a raw `TypeParamDefsItem` via `TryFrom`, which rejects
/// definitions without a name or with invalid constraints.
#[derive(Clone, Debug, PartialEq)]
pub struct TypeParam {
    /// Parameter name (e.g., "K" for a type variable)
    pub name: String,
    /// Parameter type constraints
    pub param_type: ParameterConstraint,
    /// Human-readable description
    pub description: Option<String>,
}
374
375impl TypeParam {
376    /// Create a new type parameter
377    pub fn new(name: String, param_type: ParameterConstraint, description: Option<String>) -> Self {
378        Self {
379            name,
380            param_type,
381            description,
382        }
383    }
384
385    /// Check if a parameter value is valid
386    pub fn is_valid_value(&self, value: &Value) -> bool {
387        self.param_type.is_valid_value(value)
388    }
389}
390
391impl TryFrom<TypeParamDefsItem> for TypeParam {
392    type Error = TypeParamError;
393
394    fn try_from(item: TypeParamDefsItem) -> Result<Self, Self::Error> {
395        let name = item.name.ok_or(TypeParamError::MissingName)?;
396        let param_type =
397            ParameterConstraint::from_raw(item.type_, item.options, item.min, item.max)?;
398
399        Ok(Self {
400            name,
401            param_type,
402            description: item.description,
403        })
404    }
405}
406
/// Errors raised while validating extension type definitions or while
/// converting parsed type expressions into concrete types.
#[derive(Debug, Error, PartialEq)]
pub enum ExtensionTypeError {
    /// Extension type name is invalid; displays the wrapped error as-is
    #[error("{0}")]
    InvalidName(#[from] InvalidTypeName),
    /// Any type variable is invalid for concrete types
    #[error("Any type variable is invalid for concrete types: any{}{}", id, nullability.then_some("?").unwrap_or(""))]
    InvalidAnyTypeVariable {
        /// Numeric id of the type variable (rendered as `any{id}`)
        id: u32,
        /// Whether the type variable is nullable (rendered as a trailing `?`)
        nullability: bool,
    },
    /// Unknown type name (not a builtin, missing u! prefix for extension types)
    #[error(
        "Unknown type name: '{}'. Extension types must use the u! prefix (e.g., u!{})",
        name,
        name
    )]
    UnknownTypeName {
        /// The unknown type name
        name: String,
    },
    /// Parameter validation failed
    #[error("Invalid parameter: {0}")]
    InvalidParameter(#[from] TypeParamError),
    /// Field type is invalid; the payload is a human-readable explanation
    #[error("Invalid structure field type: {0}")]
    InvalidFieldType(String),
    /// Duplicate struct field name
    #[error("Duplicate struct field '{field_name}'")]
    DuplicateFieldName {
        /// The duplicated field name
        field_name: String,
    },
    /// Type parameter count is invalid for the given type name
    #[error("Type '{type_name}' expects {expected} parameters, got {actual}")]
    InvalidParameterCount {
        /// The type name being validated
        type_name: String,
        /// Expected number of parameters
        expected: usize,
        /// The actual number of parameters provided
        actual: usize,
    },
    /// Type parameter is of the wrong kind for the given position
    #[error("Type '{type_name}' parameter {index} must be {expected}")]
    InvalidParameterKind {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Expected parameter kind (e.g., integer, type)
        expected: &'static str,
    },
    /// Provided parameter value does not satisfy a textually described
    /// expectation
    #[error("Type '{type_name}' parameter {index} value {value} is not within {expected}")]
    InvalidParameterValue {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Provided parameter value
        value: i64,
        /// Description of the expected range or type
        expected: &'static str,
    },
    /// Provided parameter value lies outside an explicit inclusive range
    #[error("Type '{type_name}' parameter {index} value {value} is out of range {expected:?}")]
    InvalidParameterRange {
        /// The type name being validated
        type_name: String,
        /// Zero-based index of the offending parameter
        index: usize,
        /// Provided parameter value
        value: i64,
        /// Inclusive range of accepted values
        expected: RangeInclusive<i32>,
    },
    /// Structure representation cannot be nullable
    #[error("Structure representation cannot be nullable: {type_string}")]
    StructureCannotBeNullable {
        /// The type string that was nullable
        type_string: String,
    },
    /// Error parsing type
    #[error("Error parsing type: {0}")]
    ParseType(#[from] TypeParseError),
}
497
/// Errors raised while validating a single type parameter definition.
#[derive(Debug, Error, PartialEq)]
pub enum TypeParamError {
    /// Parameter name is missing
    #[error("Parameter name is required")]
    MissingName,
    /// Integer parameter has non-integer min/max values
    #[error("Integer parameter has invalid min/max values: min={min:?}, max={max:?}")]
    InvalidIntegerBounds {
        /// The declared minimum, as given. Both bounds are reported even if
        /// only one of them is the offending value.
        min: Option<f64>,
        /// The declared maximum, as given. Both bounds are reported even if
        /// only one of them is the offending value.
        max: Option<f64>,
    },
    /// Enumeration parameter is missing options
    #[error("Enumeration parameter is missing options")]
    MissingEnumOptions,
    /// Enumeration parameter has invalid options
    #[error("Enumeration parameter has invalid options: {0}")]
    InvalidEnumOptions(#[from] ParsedEnumOptionsError),
}
519
/// A validated Simple Extension type definition.
///
/// This is the parsed counterpart of a raw `SimpleExtensionsTypesItem`; it can
/// be converted back into the raw form with `From`.
#[derive(Clone, Debug, PartialEq)]
pub struct CustomType {
    /// Type name
    pub name: String,
    /// Type parameters (e.g., for generic types); empty if none are declared
    pub parameters: Vec<TypeParam>,
    /// Concrete structure definition, if any
    pub structure: Option<ConcreteType>,
    /// Whether this type can have variadic parameters (`None` if unspecified)
    pub variadic: Option<bool>,
    /// Human-readable description
    pub description: Option<String>,
}
534
535impl CustomType {
536    /// Check if this type name is valid according to Substrait naming rules
537    /// (see the `Identifier` rule in `substrait/grammar/SubstraitLexer.g4`).
538    /// Identifiers are case-insensitive and must start with a an ASCII letter,
539    /// `_`, or `$`, followed by ASCII letters, digits, `_`, or `$`.
540    //
541    // Note: I'm not sure if `$` is actually something we want to allow, or if
542    // `_` is, but it's in the grammar so I'm allowing it here.
543    pub fn validate_name(name: &str) -> Result<(), InvalidTypeName> {
544        let mut chars = name.chars();
545        let first = chars
546            .next()
547            .ok_or_else(|| InvalidTypeName(name.to_string()))?;
548        if !(first.is_ascii_alphabetic() || first == '_' || first == '$') {
549            return Err(InvalidTypeName(name.to_string()));
550        }
551
552        if !chars.all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '$') {
553            return Err(InvalidTypeName(name.to_string()));
554        }
555
556        Ok(())
557    }
558
559    /// Create a new custom type with validation
560    pub fn new(
561        name: String,
562        parameters: Vec<TypeParam>,
563        structure: Option<ConcreteType>,
564        variadic: Option<bool>,
565        description: Option<String>,
566    ) -> Result<Self, ExtensionTypeError> {
567        Self::validate_name(&name)?;
568
569        Ok(Self {
570            name,
571            parameters,
572            structure,
573            variadic,
574            description,
575        })
576    }
577}
578
579impl From<CustomType> for SimpleExtensionsTypesItem {
580    fn from(value: CustomType) -> Self {
581        // Convert parameters back to TypeParamDefs if any
582        let parameters = if value.parameters.is_empty() {
583            None
584        } else {
585            Some(TypeParamDefs(
586                value
587                    .parameters
588                    .into_iter()
589                    .map(|param| {
590                        let (min, max) = param.param_type.raw_bounds();
591                        TypeParamDefsItem {
592                            name: Some(param.name),
593                            description: param.description,
594                            type_: param.param_type.raw_type(),
595                            min,
596                            max,
597                            options: param.param_type.raw_options(),
598                            // TODO: add this to TypeParamDefsItem parsing, and
599                            // follow it through here. I'm not entirely sure
600                            // when/if it is used.
601                            optional: None,
602                        }
603                    })
604                    .collect(),
605            ))
606        };
607
608        // Convert structure back to Type if any
609        let structure = value.structure.map(Into::into);
610
611        SimpleExtensionsTypesItem {
612            name: value.name,
613            description: value.description,
614            metadata: Default::default(),
615            parameters,
616            structure,
617            variadic: value.variadic,
618        }
619    }
620}
621
622impl Parse<TypeContext> for SimpleExtensionsTypesItem {
623    type Parsed = CustomType;
624    type Error = ExtensionTypeError;
625
626    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
627        let name = self.name;
628        CustomType::validate_name(&name)?;
629
630        // Register this type as found
631        ctx.found(&name);
632
633        let parameters = if let Some(param_defs) = self.parameters {
634            param_defs
635                .0
636                .into_iter()
637                .map(TypeParam::try_from)
638                .collect::<Result<Vec<_>, _>>()?
639        } else {
640            Vec::new()
641        };
642
643        // Parse structure with context, so referenced extension types are recorded as linked
644        let structure = match self.structure {
645            Some(structure_data) => {
646                let parsed = Parse::parse(structure_data, ctx)?;
647                // TODO: check that the structure is valid. The `Type::Object`
648                // form of `structure_data` is by definition a non-nullable `NSTRUCT`; however,
649                // what types allowed under the `Type::String` form is less clear in the spec:
650                // See https://github.com/substrait-io/substrait/issues/920.
651                Some(parsed)
652            }
653            None => None,
654        };
655
656        Ok(CustomType {
657            name,
658            parameters,
659            structure,
660            variadic: self.variadic,
661            description: self.description,
662        })
663    }
664}
665
666impl Parse<TypeContext> for RawType {
667    type Parsed = ConcreteType;
668    type Error = ExtensionTypeError;
669
670    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
671        match self {
672            RawType::String(type_string) => {
673                let parsed_type = TypeExpr::parse(&type_string)?;
674                let mut link = |name: &str| ctx.linked(name);
675                parsed_type.visit_references(&mut link);
676                let concrete = ConcreteType::try_from(parsed_type)?;
677                Ok(concrete)
678            }
679            RawType::Object(field_map) => {
680                // Type structure in Substrait must preserve field order (see
681                // substrait-io/substrait#915). The typify generation uses
682                // IndexMap to retain the YAML order so that the order of the
683                // fields in the structure matches that of the extensions file.
684                let mut fields = IndexMap::new();
685
686                for (field_name, field_type_value) in field_map {
687                    let type_string = match field_type_value {
688                        serde_json::Value::String(s) => s,
689                        _ => {
690                            return Err(ExtensionTypeError::InvalidFieldType(
691                                "Struct field types must be strings".to_string(),
692                            ));
693                        }
694                    };
695
696                    let parsed_field_type = TypeExpr::parse(&type_string)?;
697                    let mut link = |name: &str| ctx.linked(name);
698                    parsed_field_type.visit_references(&mut link);
699                    let field_concrete_type = ConcreteType::try_from(parsed_field_type)?;
700
701                    if fields
702                        .insert(field_name.clone(), field_concrete_type)
703                        .is_some()
704                    {
705                        return Err(ExtensionTypeError::DuplicateFieldName { field_name });
706                    }
707                }
708
709                Ok(ConcreteType {
710                    kind: ConcreteTypeKind::NamedStruct { fields },
711                    nullable: false,
712                })
713            }
714        }
715    }
716}
717
/// Invalid type name error.
///
/// Wraps the rejected name, which is included in the error message.
#[derive(Debug, Error, PartialEq)]
#[error("invalid type name `{0}`")]
pub struct InvalidTypeName(String);
722
/// The structural kind of a Substrait type (builtin, list, map, etc).
///
/// This is almost a complete type, but is missing nullability information. It must be
/// wrapped in a [`ConcreteType`] to form a complete type with nullable/non-nullable annotation.
///
/// Note that this is a recursive type - other than the [`BasicBuiltinType`]s, the other
/// variants can have type parameters that are themselves [`ConcreteType`]s.
#[derive(Clone, Debug, PartialEq)]
pub enum ConcreteTypeKind {
    /// Built-in Substrait type (primitive or parameterized)
    Builtin(BasicBuiltinType),
    /// Extension type with optional parameters
    Extension {
        /// Extension type name
        name: String,
        /// Type parameters; empty for an unparameterized extension type
        parameters: Vec<TypeParameter>,
    },
    /// List type with element type
    List(Box<ConcreteType>),
    /// Map type with key and value types
    Map {
        /// Key type
        key: Box<ConcreteType>,
        /// Value type
        value: Box<ConcreteType>,
    },
    /// Struct type (ordered fields without names)
    Struct(Vec<ConcreteType>),
    /// Named struct type (nstruct - ordered fields with names)
    NamedStruct {
        /// Ordered field names and types. They are in the order they should
        /// appear in the struct - hence the use of [`IndexMap`].
        fields: IndexMap<String, ConcreteType>,
    },
}
759
760impl fmt::Display for ConcreteTypeKind {
761    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
762        match self {
763            ConcreteTypeKind::Builtin(b) => write!(f, "{b}"),
764            ConcreteTypeKind::Extension { name, parameters } => {
765                write!(f, "{name}")?;
766                write_separated(f, parameters.iter(), "<", ">", ", ")
767            }
768            ConcreteTypeKind::List(elem) => write!(f, "list<{elem}>"),
769            ConcreteTypeKind::Map { key, value } => write!(f, "map<{key}, {value}>"),
770            ConcreteTypeKind::Struct(types) => {
771                write_separated(f, types.iter(), "struct<", ">", ", ")
772            }
773            ConcreteTypeKind::NamedStruct { fields } => {
774                let kvs = fields.iter().map(|(k, v)| KeyValueDisplay(k, v, ": "));
775
776                write_separated(f, kvs, "{", "}", ", ")
777            }
778        }
779    }
780}
781
/// A concrete, fully-resolved type instance with nullability.
///
/// Displays as the rendering of `kind`, followed by `?` when `nullable` is
/// set.
#[derive(Clone, Debug, PartialEq)]
pub struct ConcreteType {
    /// The resolved type shape
    pub kind: ConcreteTypeKind,
    /// Whether this type is nullable
    pub nullable: bool,
}
790
791impl ConcreteType {
792    /// Create a new builtin scalar type
793    pub fn builtin(builtin_type: BasicBuiltinType, nullable: bool) -> ConcreteType {
794        ConcreteType {
795            kind: ConcreteTypeKind::Builtin(builtin_type),
796            nullable,
797        }
798    }
799
800    /// Create a new extension type reference (without parameters)
801    pub fn extension(name: String, nullable: bool) -> ConcreteType {
802        ConcreteType {
803            kind: ConcreteTypeKind::Extension {
804                name,
805                parameters: Vec::new(),
806            },
807            nullable,
808        }
809    }
810
811    /// Create a new parameterized extension type
812    pub fn extension_with_params(
813        name: String,
814        parameters: Vec<TypeParameter>,
815        nullable: bool,
816    ) -> ConcreteType {
817        ConcreteType {
818            kind: ConcreteTypeKind::Extension { name, parameters },
819            nullable,
820        }
821    }
822
823    /// Create a new list type
824    pub fn list(element_type: ConcreteType, nullable: bool) -> ConcreteType {
825        ConcreteType {
826            kind: ConcreteTypeKind::List(Box::new(element_type)),
827            nullable,
828        }
829    }
830
831    /// Create a new struct type (ordered fields without names)
832    pub fn r#struct(field_types: Vec<ConcreteType>, nullable: bool) -> ConcreteType {
833        ConcreteType {
834            kind: ConcreteTypeKind::Struct(field_types),
835            nullable,
836        }
837    }
838
839    /// Create a new map type
840    pub fn map(key_type: ConcreteType, value_type: ConcreteType, nullable: bool) -> ConcreteType {
841        ConcreteType {
842            kind: ConcreteTypeKind::Map {
843                key: Box::new(key_type),
844                value: Box::new(value_type),
845            },
846            nullable,
847        }
848    }
849
850    /// Create a new named struct type (nstruct - ordered fields with names)
851    pub fn named_struct(fields: IndexMap<String, ConcreteType>, nullable: bool) -> ConcreteType {
852        ConcreteType {
853            kind: ConcreteTypeKind::NamedStruct { fields },
854            nullable,
855        }
856    }
857
858    /// Check if this type (as a function argument) is compatible with another
859    /// type (as an input).
860    ///
861    /// Mainly checks nullability:
862    ///   - `i64?` is compatible with `i64` and `i64?` - both can be passed as
863    ///     arguments
864    ///   - `i64` is compatible with `i64` but NOT `i64?` - you can't pass a
865    ///     nullable type to a function that only accepts non-nullable arguments
866    pub fn is_compatible_with(&self, other: &ConcreteType) -> bool {
867        // Types must match exactly, but nullable types can accept non-nullable values
868        self.kind == other.kind && (self.nullable || !other.nullable)
869    }
870}
871
872impl fmt::Display for ConcreteType {
873    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
874        write!(f, "{}", self.kind)?;
875        if self.nullable {
876            write!(f, "?")?;
877        }
878        Ok(())
879    }
880}
881
882impl From<ConcreteType> for RawType {
883    fn from(val: ConcreteType) -> Self {
884        match val.kind {
885            ConcreteTypeKind::NamedStruct { fields } => {
886                let map = Map::from_iter(
887                    fields
888                        .into_iter()
889                        .map(|(name, ty)| (name, serde_json::Value::String(ty.to_string()))),
890                );
891                RawType::Object(map)
892            }
893            _ => RawType::String(val.to_string()),
894        }
895    }
896}
897
898/// Extract and validate an integer parameter for a built-in type.
899///
900/// For `DECIMAL<10,2>`, this validates that `10` (index 0) and `2` (index 1)
901/// are integers within their required ranges (precision 1-38, scale
902/// 0-precision).
903///
904/// - `type_name`: Type being validated (for error messages, e.g., "DECIMAL")
905/// - `index`: Parameter position (0-based, e.g., 0 for precision, 1 for scale);
906///   needed for error messages
907/// - `param`: The parameter to validate
908/// - `range`: Optional bounds to enforce (e.g., `Some(1..=38)` for precision)
909fn expect_integer_param(
910    type_name: &str,
911    index: usize,
912    param: &TypeExprParam<'_>,
913    range: Option<RangeInclusive<i32>>,
914) -> Result<i32, ExtensionTypeError> {
915    let value = match param {
916        TypeExprParam::Integer(value) => {
917            i32::try_from(*value).map_err(|_| ExtensionTypeError::InvalidParameterValue {
918                type_name: type_name.to_string(),
919                index,
920                value: *value,
921                expected: "an i32",
922            })
923        }
924        _ => Err(ExtensionTypeError::InvalidParameterKind {
925            type_name: type_name.to_string(),
926            index,
927            expected: "an integer",
928        }),
929    }?;
930
931    if let Some(range) = range {
932        if range.contains(&value) {
933            return Ok(value);
934        }
935        return Err(ExtensionTypeError::InvalidParameterRange {
936            type_name: type_name.to_string(),
937            index,
938            value: i64::from(value),
939            expected: range,
940        });
941    }
942
943    Ok(value)
944}
945
946/// Helper function - checks that param length matches expectations, returns
947/// error if not. Assumes a fixed number of expected parameters.
948fn expect_param_len(
949    type_name: &str,
950    params: &[TypeExprParam<'_>],
951    expected: usize,
952) -> Result<(), ExtensionTypeError> {
953    if params.len() != expected {
954        return Err(ExtensionTypeError::InvalidParameterCount {
955            type_name: type_name.to_string(),
956            expected,
957            actual: params.len(),
958        });
959    }
960    Ok(())
961}
962
963/// Helper function - expect a type parameter, and return the [ConcreteType] if it is a [TypeExpr]
964/// or an error if not.
965fn expect_type_argument<'a>(
966    type_name: &str,
967    index: usize,
968    param: TypeExprParam<'a>,
969) -> Result<ConcreteType, ExtensionTypeError> {
970    match param {
971        TypeExprParam::Type(t) => ConcreteType::try_from(t),
972        TypeExprParam::Integer(_) => Err(ExtensionTypeError::InvalidParameterKind {
973            type_name: type_name.to_string(),
974            index,
975            expected: "a type",
976        }),
977    }
978}
979
980impl<'a> TryFrom<TypeExprParam<'a>> for TypeParameter {
981    type Error = ExtensionTypeError;
982
983    fn try_from(param: TypeExprParam<'a>) -> Result<Self, Self::Error> {
984        Ok(match param {
985            TypeExprParam::Integer(v) => TypeParameter::Integer(v),
986            TypeExprParam::Type(t) => TypeParameter::Type(ConcreteType::try_from(t)?),
987        })
988    }
989}
990
991/// Parse a builtin type. Returns an `ExtensionTypeError` if the type name is
992/// matched, but parameters are incorrect; returns `Some(None)` if the type is
993/// not known.
994fn parse_builtin<'a>(
995    display_name: &str,
996    lower_name: &str,
997    params: &[TypeExprParam<'a>],
998) -> Result<Option<BasicBuiltinType>, ExtensionTypeError> {
999    if let Some(builtin) = primitive_builtin(lower_name) {
1000        expect_param_len(display_name, params, 0)?;
1001        return Ok(Some(builtin));
1002    }
1003
1004    match lower_name {
1005        // Parameterized builtins
1006        "fixedchar" => {
1007            expect_param_len(display_name, params, 1)?;
1008            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1009            Ok(Some(BasicBuiltinType::FixedChar { length }))
1010        }
1011        "varchar" => {
1012            expect_param_len(display_name, params, 1)?;
1013            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1014            Ok(Some(BasicBuiltinType::VarChar { length }))
1015        }
1016        "fixedbinary" => {
1017            expect_param_len(display_name, params, 1)?;
1018            let length = expect_integer_param(display_name, 0, &params[0], None)?;
1019            Ok(Some(BasicBuiltinType::FixedBinary { length }))
1020        }
1021        "decimal" => {
1022            expect_param_len(display_name, params, 2)?;
1023            let precision = expect_integer_param(display_name, 0, &params[0], Some(1..=38))?;
1024            let scale = expect_integer_param(display_name, 1, &params[1], Some(0..=precision))?;
1025            Ok(Some(BasicBuiltinType::Decimal { precision, scale }))
1026        }
1027        "precisiontime" | "precision_time" => {
1028            expect_param_len(display_name, params, 1)?;
1029            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1030            Ok(Some(BasicBuiltinType::PrecisionTime { precision }))
1031        }
1032        "precision_timestamp" => {
1033            expect_param_len(display_name, params, 1)?;
1034            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1035            Ok(Some(BasicBuiltinType::PrecisionTimestamp { precision }))
1036        }
1037        "precision_timestamp_tz" => {
1038            expect_param_len(display_name, params, 1)?;
1039            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=12))?;
1040            Ok(Some(BasicBuiltinType::PrecisionTimestampTz { precision }))
1041        }
1042        "interval_day" => {
1043            expect_param_len(display_name, params, 1)?;
1044            let precision = expect_integer_param(display_name, 0, &params[0], Some(0..=9))?;
1045            Ok(Some(BasicBuiltinType::IntervalDay { precision }))
1046        }
1047        "interval_compound" => {
1048            expect_param_len(display_name, params, 1)?;
1049            let precision = expect_integer_param(display_name, 0, &params[0], None)?;
1050            Ok(Some(BasicBuiltinType::IntervalCompound { precision }))
1051        }
1052        _ => Ok(None),
1053    }
1054}
1055
1056impl<'a> TryFrom<TypeExpr<'a>> for ConcreteType {
1057    type Error = ExtensionTypeError;
1058
1059    fn try_from(parsed_type: TypeExpr<'a>) -> Result<Self, Self::Error> {
1060        match parsed_type {
1061            TypeExpr::Simple(name, params, nullable) => {
1062                let lower = name.to_ascii_lowercase();
1063
1064                match lower.as_str() {
1065                    "list" => {
1066                        expect_param_len(name, &params, 1)?;
1067                        let element =
1068                            expect_type_argument(name, 0, params.into_iter().next().unwrap())?;
1069                        return Ok(ConcreteType::list(element, nullable));
1070                    }
1071                    "map" => {
1072                        expect_param_len(name, &params, 2)?;
1073                        let mut iter = params.into_iter();
1074                        let key = expect_type_argument(name, 0, iter.next().unwrap())?;
1075                        let value = expect_type_argument(name, 1, iter.next().unwrap())?;
1076                        return Ok(ConcreteType::map(key, value, nullable));
1077                    }
1078                    "struct" => {
1079                        let field_types = params
1080                            .into_iter()
1081                            .enumerate()
1082                            .map(|(idx, param)| expect_type_argument(name, idx, param))
1083                            .collect::<Result<Vec<_>, _>>()?;
1084                        return Ok(ConcreteType::r#struct(field_types, nullable));
1085                    }
1086                    _ => {}
1087                }
1088
1089                if let Some(builtin) = parse_builtin(name, lower.as_str(), &params)? {
1090                    return Ok(ConcreteType::builtin(builtin, nullable));
1091                }
1092
1093                // Simple types that aren't builtins are unknown
1094                // Extension types MUST use the u! prefix
1095                Err(ExtensionTypeError::UnknownTypeName {
1096                    name: name.to_string(),
1097                })
1098            }
1099            TypeExpr::UserDefined(name, params, nullable) => {
1100                let parameters = params
1101                    .into_iter()
1102                    .map(TypeParameter::try_from)
1103                    .collect::<Result<Vec<_>, _>>()?;
1104                Ok(ConcreteType::extension_with_params(
1105                    name.to_string(),
1106                    parameters,
1107                    nullable,
1108                ))
1109            }
1110            TypeExpr::TypeVariable(id, nullability) => {
1111                Err(ExtensionTypeError::InvalidAnyTypeVariable { id, nullability })
1112            }
1113        }
1114    }
1115}
1116
1117#[cfg(test)]
1118mod tests {
1119    use super::super::extensions::TypeContext;
1120    use super::*;
1121    use crate::parse::text::simple_extensions::TypeExpr;
1122    use crate::parse::text::simple_extensions::argument::EnumOptions as ParsedEnumOptions;
1123    use crate::text::simple_extensions;
1124    use std::iter::FromIterator;
1125
1126    /// Create a [ConcreteType] from a [BuiltinType]
1127    fn concretize(builtin: BasicBuiltinType) -> ConcreteType {
1128        ConcreteType::builtin(builtin, false)
1129    }
1130
1131    /// Parse a string into a [ConcreteType]
1132    fn parse_type(expr: &str) -> ConcreteType {
1133        let parsed = TypeExpr::parse(expr).unwrap();
1134        ConcreteType::try_from(parsed).unwrap()
1135    }
1136
1137    /// Parse a string into a [ConcreteType], returning the result
1138    fn parse_type_result(expr: &str) -> Result<ConcreteType, ExtensionTypeError> {
1139        let parsed = TypeExpr::parse(expr).unwrap();
1140        ConcreteType::try_from(parsed)
1141    }
1142
1143    /// Parse a string into a builtin [ConcreteType], with no unresolved
1144    /// extension references
1145    fn parse_simple(s: &str) -> ConcreteType {
1146        let parsed = TypeExpr::parse(s).unwrap();
1147
1148        let mut refs = Vec::new();
1149        parsed.visit_references(&mut |name| refs.push(name.to_string()));
1150        assert!(refs.is_empty(), "{s} should not add an extension reference");
1151
1152        ConcreteType::try_from(parsed).unwrap()
1153    }
1154
1155    /// Create a type parameter from a type expression string
1156    fn type_param(expr: &str) -> TypeParameter {
1157        TypeParameter::Type(parse_type(expr))
1158    }
1159
1160    /// Create an extension type
1161    fn extension(name: &str, parameters: Vec<TypeParameter>, nullable: bool) -> ConcreteType {
1162        ConcreteType::extension_with_params(name.to_string(), parameters, nullable)
1163    }
1164
1165    /// Convert a custom type to raw and back, ensuring round-trip consistency
1166    fn round_trip(custom: &CustomType) {
1167        let item: simple_extensions::SimpleExtensionsTypesItem = custom.clone().into();
1168        let mut ctx = TypeContext::default();
1169        let parsed = Parse::parse(item, &mut ctx).unwrap();
1170        assert_eq!(&parsed, custom);
1171    }
1172
1173    /// Create a raw named struct (e.g. straight from YAML) from field name and
1174    /// type pairs
1175    fn raw_named_struct(fields: &[(&str, &str)]) -> RawType {
1176        let map = Map::from_iter(
1177            fields
1178                .iter()
1179                .map(|(name, ty)| ((*name).into(), serde_json::Value::String((*ty).into()))),
1180        );
1181
1182        RawType::Object(map)
1183    }
1184
1185    #[test]
1186    fn test_builtin_scalar_parsing() {
1187        let cases = vec![
1188            ("bool", Some(BasicBuiltinType::Boolean)),
1189            ("i8", Some(BasicBuiltinType::I8)),
1190            ("i16", Some(BasicBuiltinType::I16)),
1191            ("i32", Some(BasicBuiltinType::I32)),
1192            ("i64", Some(BasicBuiltinType::I64)),
1193            ("fp32", Some(BasicBuiltinType::Fp32)),
1194            ("fp64", Some(BasicBuiltinType::Fp64)),
1195            ("STRING", Some(BasicBuiltinType::String)),
1196            ("binary", Some(BasicBuiltinType::Binary)),
1197            ("uuid", Some(BasicBuiltinType::Uuid)),
1198            ("date", Some(BasicBuiltinType::Date)),
1199            ("interval_year", Some(BasicBuiltinType::IntervalYear)),
1200            ("time", Some(BasicBuiltinType::Time)),
1201            ("timestamp", Some(BasicBuiltinType::Timestamp)),
1202            ("timestamp_tz", Some(BasicBuiltinType::TimestampTz)),
1203            ("invalid", None),
1204        ];
1205
1206        for (input, expected) in cases {
1207            let result = parse_builtin(input, input.to_ascii_lowercase().as_str(), &[]).unwrap();
1208            match expected {
1209                Some(expected_type) => {
1210                    assert_eq!(
1211                        result,
1212                        Some(expected_type),
1213                        "expected builtin type for {input}"
1214                    );
1215                }
1216                None => {
1217                    assert!(result.is_none(), "expected parsing {input} to fail");
1218                }
1219            }
1220        }
1221    }
1222
1223    #[test]
1224    fn test_parameterized_builtin_types() {
1225        let cases = vec![
1226            (
1227                "precisiontime<2>",
1228                concretize(BasicBuiltinType::PrecisionTime { precision: 2 }),
1229            ),
1230            (
1231                "precision_timestamp<1>",
1232                concretize(BasicBuiltinType::PrecisionTimestamp { precision: 1 }),
1233            ),
1234            (
1235                "precision_timestamp_tz<5>",
1236                concretize(BasicBuiltinType::PrecisionTimestampTz { precision: 5 }),
1237            ),
1238            (
1239                "DECIMAL<10,2>",
1240                concretize(BasicBuiltinType::Decimal {
1241                    precision: 10,
1242                    scale: 2,
1243                }),
1244            ),
1245            (
1246                "fixedchar<12>",
1247                concretize(BasicBuiltinType::FixedChar { length: 12 }),
1248            ),
1249            (
1250                "VarChar<42>",
1251                concretize(BasicBuiltinType::VarChar { length: 42 }),
1252            ),
1253            (
1254                "fixedbinary<8>",
1255                concretize(BasicBuiltinType::FixedBinary { length: 8 }),
1256            ),
1257            (
1258                "interval_day<7>",
1259                concretize(BasicBuiltinType::IntervalDay { precision: 7 }),
1260            ),
1261            (
1262                "interval_compound<6>",
1263                concretize(BasicBuiltinType::IntervalCompound { precision: 6 }),
1264            ),
1265        ];
1266
1267        for (expr, expected) in cases {
1268            let found = parse_simple(expr);
1269            assert_eq!(found, expected, "unexpected type for {expr}");
1270        }
1271    }
1272
1273    #[test]
1274    fn test_parameterized_builtin_range_errors() {
1275        use ExtensionTypeError::InvalidParameterRange;
1276
1277        let cases = vec![
1278            ("precisiontime<13>", "precisiontime", 0, 13, 0..=12),
1279            ("precisiontime<-1>", "precisiontime", 0, -1, 0..=12),
1280            (
1281                "precision_timestamp<13>",
1282                "precision_timestamp",
1283                0,
1284                13,
1285                0..=12,
1286            ),
1287            (
1288                "precision_timestamp<-1>",
1289                "precision_timestamp",
1290                0,
1291                -1,
1292                0..=12,
1293            ),
1294            (
1295                "precision_timestamp_tz<20>",
1296                "precision_timestamp_tz",
1297                0,
1298                20,
1299                0..=12,
1300            ),
1301            ("interval_day<10>", "interval_day", 0, 10, 0..=9),
1302            ("DECIMAL<39,0>", "DECIMAL", 0, 39, 1..=38),
1303            ("DECIMAL<0,0>", "DECIMAL", 0, 0, 1..=38),
1304            ("DECIMAL<10,-1>", "DECIMAL", 1, -1, 0..=10),
1305            ("DECIMAL<10,12>", "DECIMAL", 1, 12, 0..=10),
1306        ];
1307
1308        for (expr, expected_type, expected_index, expected_value, expected_range) in cases {
1309            match parse_type_result(expr) {
1310                Ok(value) => panic!("expected error parsing {expr}, got {value:?}"),
1311                Err(InvalidParameterRange {
1312                    type_name,
1313                    index,
1314                    value,
1315                    expected,
1316                }) => {
1317                    assert_eq!(type_name, expected_type, "unexpected type for {expr}");
1318                    assert_eq!(index, expected_index, "unexpected index for {expr}");
1319                    assert_eq!(
1320                        value,
1321                        i64::from(expected_value),
1322                        "unexpected value for {expr}"
1323                    );
1324                    assert_eq!(expected, expected_range, "unexpected range for {expr}");
1325                }
1326                Err(other) => panic!("expected InvalidParameterRange for {expr}, got {other:?}"),
1327            }
1328        }
1329    }
1330
1331    #[test]
1332    fn test_container_types() {
1333        let cases = vec![
1334            (
1335                "List<i32>",
1336                ConcreteType::list(ConcreteType::builtin(BasicBuiltinType::I32, false), false),
1337            ),
1338            (
1339                "List<fp64?>",
1340                ConcreteType::list(ConcreteType::builtin(BasicBuiltinType::Fp64, true), false),
1341            ),
1342            (
1343                "Map?<i64, string?>",
1344                ConcreteType::map(
1345                    ConcreteType::builtin(BasicBuiltinType::I64, false),
1346                    ConcreteType::builtin(BasicBuiltinType::String, true),
1347                    true,
1348                ),
1349            ),
1350            (
1351                "Struct?<i8, string?>",
1352                ConcreteType::r#struct(
1353                    vec![
1354                        ConcreteType::builtin(BasicBuiltinType::I8, false),
1355                        ConcreteType::builtin(BasicBuiltinType::String, true),
1356                    ],
1357                    true,
1358                ),
1359            ),
1360        ];
1361
1362        for (expr, expected) in cases {
1363            assert_eq!(parse_type(expr), expected, "unexpected parse for {expr}");
1364        }
1365    }
1366
1367    #[test]
1368    fn test_extension_types() {
1369        let cases = vec![
1370            (
1371                "u!geo<List<i32>, 10>",
1372                extension(
1373                    "geo",
1374                    vec![type_param("List<i32>"), TypeParameter::Integer(10)],
1375                    false,
1376                ),
1377            ),
1378            (
1379                "u!Geo?<List<i32?>>",
1380                extension("Geo", vec![type_param("List<i32?>")], true),
1381            ),
1382            (
1383                "u!Custom<string?, bool>",
1384                extension(
1385                    "Custom",
1386                    vec![
1387                        type_param("string?"),
1388                        TypeParameter::Type(ConcreteType::builtin(
1389                            BasicBuiltinType::Boolean,
1390                            false,
1391                        )),
1392                    ],
1393                    false,
1394                ),
1395            ),
1396        ];
1397
1398        for (expr, expected) in cases {
1399            assert_eq!(
1400                parse_type(expr),
1401                expected,
1402                "unexpected extension for {expr}"
1403            );
1404        }
1405    }
1406
1407    #[test]
1408    fn test_parameter_type_validation() {
1409        let int_param = ParameterConstraint::Integer {
1410            min: Some(1),
1411            max: Some(10),
1412        };
1413        let enum_param = ParameterConstraint::Enumeration {
1414            options: ParsedEnumOptions::try_from(simple_extensions::EnumOptions(vec![
1415                "OVERFLOW".to_string(),
1416                "ERROR".to_string(),
1417            ]))
1418            .unwrap(),
1419        };
1420
1421        let cases = vec![
1422            (&int_param, Value::Number(5.into()), true),
1423            (&int_param, Value::Number(0.into()), false),
1424            (&int_param, Value::Number(11.into()), false),
1425            (&int_param, Value::String("not a number".into()), false),
1426            (&enum_param, Value::String("OVERFLOW".into()), true),
1427            (&enum_param, Value::String("INVALID".into()), false),
1428        ];
1429
1430        for (param, value, expected) in cases {
1431            assert_eq!(
1432                param.is_valid_value(&value),
1433                expected,
1434                "unexpected validation result for {value:?}"
1435            );
1436        }
1437    }
1438
1439    #[test]
1440    fn test_type_round_trip_display() {
1441        // (example, canonical form)
1442        let cases = vec![
1443            ("i32", "i32"),
1444            ("I64?", "i64?"),
1445            ("list<string>", "list<string>"),
1446            ("List<STRING?>", "list<string?>"),
1447            ("map<i32, list<string>>", "map<i32, list<string>>"),
1448            ("struct<i8, string?>", "struct<i8, string?>"),
1449            (
1450                "Struct<List<i32>, Map<string, list<i64?>>>",
1451                "struct<list<i32>, map<string, list<i64?>>>",
1452            ),
1453            (
1454                "Map<List<I32?>, Struct<string, list<i64?>>>",
1455                "map<list<i32?>, struct<string, list<i64?>>>",
1456            ),
1457            ("u!custom<i32>", "custom<i32>"),
1458        ];
1459
1460        for (input, expected) in cases {
1461            let parsed = TypeExpr::parse(input).unwrap();
1462            let concrete = ConcreteType::try_from(parsed).unwrap();
1463            let actual = concrete.to_string();
1464
1465            assert_eq!(actual, expected, "unexpected display for {input}");
1466        }
1467    }
1468
1469    /// Test that named struct field order preserves the structure order when
1470    /// round-tripping through RawType (Substrait #915).
1471    #[test]
1472    fn test_named_struct_field_order_stability() -> Result<(), ExtensionTypeError> {
1473        let mut raw_fields = Map::new();
1474        raw_fields.insert("beta".to_string(), Value::String("i32".to_string()));
1475        raw_fields.insert("alpha".to_string(), Value::String("string?".to_string()));
1476
1477        let raw = RawType::Object(raw_fields);
1478        let mut ctx = TypeContext::default();
1479        let concrete = Parse::parse(raw, &mut ctx)?;
1480
1481        let round_tripped: RawType = concrete.into();
1482        match round_tripped {
1483            RawType::Object(result_map) => {
1484                let keys: Vec<_> = result_map.keys().collect();
1485                assert_eq!(
1486                    keys,
1487                    vec!["beta", "alpha"],
1488                    "field order should be preserved"
1489                );
1490            }
1491            other => panic!("expected Object, got {other:?}"),
1492        }
1493
1494        Ok(())
1495    }
1496
1497    #[test]
1498    fn test_integer_param_bounds_round_trip() {
1499        let cases = vec![
1500            (
1501                "bounded",
1502                simple_extensions::TypeParamDefsItem {
1503                    name: Some("K".to_string()),
1504                    description: None,
1505                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1506                    min: Some(1.0),
1507                    max: Some(10.0),
1508                    options: None,
1509                    optional: None,
1510                },
1511                Ok((Some(1), Some(10))),
1512            ),
1513            (
1514                "fractional_min",
1515                simple_extensions::TypeParamDefsItem {
1516                    name: Some("K".to_string()),
1517                    description: None,
1518                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1519                    min: Some(1.5),
1520                    max: None,
1521                    options: None,
1522                    optional: None,
1523                },
1524                Err(TypeParamError::InvalidIntegerBounds {
1525                    min: Some(1.5),
1526                    max: None,
1527                }),
1528            ),
1529            (
1530                "fractional_max",
1531                simple_extensions::TypeParamDefsItem {
1532                    name: Some("K".to_string()),
1533                    description: None,
1534                    type_: simple_extensions::TypeParamDefsItemType::Integer,
1535                    min: None,
1536                    max: Some(9.9),
1537                    options: None,
1538                    optional: None,
1539                },
1540                Err(TypeParamError::InvalidIntegerBounds {
1541                    min: None,
1542                    max: Some(9.9),
1543                }),
1544            ),
1545        ];
1546
1547        for (label, item, expected) in cases {
1548            match (TypeParam::try_from(item), expected) {
1549                (Ok(tp), Ok((expected_min, expected_max))) => match tp.param_type {
1550                    ParameterConstraint::Integer { min, max } => {
1551                        assert_eq!(min, expected_min, "min mismatch for {label}");
1552                        assert_eq!(max, expected_max, "max mismatch for {label}");
1553                    }
1554                    _ => panic!("expected integer param type for {label}"),
1555                },
1556                (Err(actual_err), Err(expected_err)) => {
1557                    assert_eq!(actual_err, expected_err, "unexpected error for {label}");
1558                }
1559                (result, expected) => {
1560                    panic!("unexpected result for {label}: got {result:?}, expected {expected:?}")
1561                }
1562            }
1563        }
1564    }
1565
1566    #[test]
1567    fn test_custom_type_round_trip() -> Result<(), ExtensionTypeError> {
1568        let fields = IndexMap::from_iter([
1569            (
1570                "x".to_string(),
1571                ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1572            ),
1573            (
1574                "y".to_string(),
1575                ConcreteType::builtin(BasicBuiltinType::Fp64, false),
1576            ),
1577        ]);
1578
1579        let cases = vec![
1580            CustomType::new(
1581                "AliasType".to_string(),
1582                vec![],
1583                Some(ConcreteType::builtin(BasicBuiltinType::I32, false)),
1584                None,
1585                Some("a test alias type".to_string()),
1586            )?,
1587            CustomType::new(
1588                "Point".to_string(),
1589                vec![TypeParam::new(
1590                    "T".to_string(),
1591                    ParameterConstraint::DataType,
1592                    Some("a numeric value".to_string()),
1593                )],
1594                Some(ConcreteType::named_struct(fields, false)),
1595                None,
1596                None,
1597            )?,
1598        ];
1599
1600        for custom in cases {
1601            round_trip(&custom);
1602        }
1603
1604        Ok(())
1605    }
1606
1607    #[test]
1608    fn test_invalid_type_names() {
1609        let cases = vec![
1610            ("", false),
1611            ("bad name", false),
1612            ("9bad", false),
1613            ("bad-name", false),
1614            ("bad.name", false),
1615            ("GoodName", true),
1616            ("also_good", true),
1617            ("_underscore", true),
1618            ("$dollar", true),
1619            ("CamelCase123", true),
1620        ];
1621
1622        for (name, expected_ok) in cases {
1623            let result = CustomType::validate_name(name);
1624            assert_eq!(
1625                result.is_ok(),
1626                expected_ok,
1627                "unexpected validation for {name}"
1628            );
1629        }
1630    }
1631
/// Parses raw extension type descriptions and compares each result with the
/// `ConcreteType` the test expects for it.
///
/// Returns an error (failing the test) if any parse step reports an
/// `ExtensionTypeError`.
#[test]
fn test_ext_type_to_concrete_type() -> Result<(), ExtensionTypeError> {
    // Case 1: a bare type-name string, expected to map to the `I32` builtin.
    let alias_raw = RawType::String("i32".to_string());
    let alias_expected = ConcreteType::builtin(BasicBuiltinType::I32, false);

    // Case 2: a two-field named struct. The expectation passes `true` as the
    // second `builtin` argument only for the field spelled with a `?` suffix.
    let struct_raw = raw_named_struct(&[("field1", "fp64"), ("field2", "i32?")]);
    let struct_fields = IndexMap::from_iter([
        (
            String::from("field1"),
            ConcreteType::builtin(BasicBuiltinType::Fp64, false),
        ),
        (
            String::from("field2"),
            ConcreteType::builtin(BasicBuiltinType::I32, true),
        ),
    ]);
    let struct_expected = ConcreteType::named_struct(struct_fields, false);

    let table = [
        ("alias", alias_raw, alias_expected),
        ("named_struct", struct_raw, struct_expected),
    ];

    for (label, raw, expected) in table {
        // Each case gets its own default context so cases cannot affect one another.
        let mut ctx = TypeContext::default();
        let parsed = Parse::parse(raw, &mut ctx)?;
        assert_eq!(parsed, expected, "unexpected type for {label}");
    }

    Ok(())
}
1667
/// Parses raw `SimpleExtensionsTypesItem` values and checks the name,
/// description, and structure of each parsed result.
///
/// Each case is a tuple of `(label, raw item, expected name, expected
/// description, expected structure)`. The label is included in every
/// assertion message so a failure identifies the case that broke.
///
/// Returns an error (failing the test) if parsing any item reports an
/// `ExtensionTypeError`.
#[test]
fn test_custom_type_parsing() -> Result<(), ExtensionTypeError> {
    let cases = vec![
        (
            // Structure given as an upper-case type-name string; the expected
            // structure is the non-nullable `Binary` builtin.
            "alias",
            simple_extensions::SimpleExtensionsTypesItem {
                name: "Alias".to_string(),
                description: Some("Alias type".to_string()),
                metadata: Default::default(),
                parameters: None,
                structure: Some(RawType::String("BINARY".to_string())),
                variadic: None,
            },
            "Alias",
            Some("Alias type"),
            Some(ConcreteType::builtin(BasicBuiltinType::Binary, false)),
        ),
        (
            // Structure given as a named struct with one plain and one
            // `?`-suffixed field.
            "named_struct",
            simple_extensions::SimpleExtensionsTypesItem {
                name: "Point".to_string(),
                description: Some("A 2D point".to_string()),
                metadata: Default::default(),
                parameters: None,
                structure: Some(raw_named_struct(&[("x", "fp64"), ("y", "fp64?")])),
                variadic: None,
            },
            "Point",
            Some("A 2D point"),
            Some(ConcreteType::named_struct(
                IndexMap::from_iter([
                    (
                        "x".to_string(),
                        ConcreteType::builtin(BasicBuiltinType::Fp64, false),
                    ),
                    (
                        "y".to_string(),
                        ConcreteType::builtin(BasicBuiltinType::Fp64, true),
                    ),
                ]),
                false,
            )),
        ),
        (
            // No description and no structure: both are expected to stay `None`.
            // NOTE(review): `variadic: Some(true)` is supplied here, but no
            // assertion below inspects it. Consider asserting it if the
            // parsed type exposes that value.
            "no_structure",
            simple_extensions::SimpleExtensionsTypesItem {
                name: "Opaque".to_string(),
                description: None,
                metadata: Default::default(),
                parameters: None,
                structure: None,
                variadic: Some(true),
            },
            "Opaque",
            None,
            None,
        ),
    ];

    for (label, item, expected_name, expected_description, expected_structure) in cases {
        // Each case gets its own default context so cases cannot affect one another.
        let mut ctx = TypeContext::default();
        let parsed = Parse::parse(item, &mut ctx)?;
        // Label this assertion like its siblings so a name failure is attributable.
        assert_eq!(parsed.name, expected_name, "name mismatch for {label}");
        assert_eq!(
            parsed.description.as_deref(),
            expected_description,
            "description mismatch for {label}"
        );
        assert_eq!(
            parsed.structure, expected_structure,
            "structure mismatch for {label}"
        );
    }

    Ok(())
}
1744}