substrait/parse/text/simple_extensions/
registry.rs

// SPDX-License-Identifier: Apache-2.0

//! Substrait Extension Registry
//!
//! This module provides registries for Substrait extensions:
//! - **Global Registry**: Immutable, reusable across plans, URN + name based lookup
//! - **Local Registry**: Per-plan, anchor-based, references Global Registry (TODO)
//!
//! The registry currently supports lookup of type definitions and scalar functions.
//!
//! This module is only available when the `parse` feature is enabled.
12
13use std::collections::{HashMap, hash_map::Entry};
14
15use super::{ExtensionFile, SimpleExtensions, SimpleExtensionsError, types::CustomType};
16use crate::urn::Urn;
17
18/// Extension Registry that manages Substrait extensions
19///
20/// This registry is immutable and reusable across multiple plans.
21/// It provides URN + name based lookup for extension types. Function parsing will be added in a future update.
22#[derive(Debug)]
23pub struct Registry {
24    /// Pre-validated extension files
25    extensions: HashMap<Urn, SimpleExtensions>,
26}
27
28impl Registry {
29    /// Create a new Global Registry from validated extension files.
30    ///
31    /// Any duplicate URNs will raise an error.
32    pub fn new<I: IntoIterator<Item = ExtensionFile>>(
33        extensions: I,
34    ) -> Result<Self, SimpleExtensionsError> {
35        let mut map = HashMap::new();
36        for ExtensionFile { urn, extension } in extensions {
37            match map.entry(urn.clone()) {
38                Entry::Occupied(_) => return Err(SimpleExtensionsError::DuplicateUrn(urn)),
39                Entry::Vacant(entry) => {
40                    entry.insert(extension);
41                }
42            }
43        }
44        Ok(Self { extensions: map })
45    }
46
47    /// Get an iterator over all extension files in this registry
48    pub fn extensions(&self) -> impl Iterator<Item = (&Urn, &SimpleExtensions)> {
49        self.extensions.iter()
50    }
51
52    /// Create a Global Registry from the built-in core extensions.
53    ///
54    /// Most core extensions are included. Some are skipped due to bugs in the upstream
55    /// YAML files (see <https://github.com/substrait-io/substrait/issues/935>).
56    #[cfg(feature = "extensions")]
57    pub fn from_core_extensions() -> Self {
58        use crate::extensions::EXTENSIONS;
59
60        // Parse the core extensions from the raw extensions format to the parsed format
61        let extensions: HashMap<Urn, SimpleExtensions> = EXTENSIONS
62            .iter()
63            .filter_map(|(orig_urn, simple_extensions)| {
64                // Skip specific core extensions that have bugs (missing u! prefix on type references).
65                // Most core extensions are included; only these problematic ones are filtered out.
66                // See: https://github.com/substrait-io/substrait/issues/935
67                let urn_str = orig_urn.to_string();
68                if urn_str == "extension:io.substrait:extension_types" ||
69                   urn_str == "extension:io.substrait:unknown" {
70                    return None;
71                }
72
73                let ExtensionFile { urn, extension } = ExtensionFile::create(simple_extensions.clone())
74                    .unwrap_or_else(|err| panic!("Core extensions should be valid, but failed to create extension file for {orig_urn}: {err}"));
75                debug_assert_eq!(orig_urn, &urn);
76                Some((urn, extension))
77            })
78            .collect();
79
80        Self { extensions }
81    }
82
83    fn get_extension(&self, urn: &Urn) -> Option<&SimpleExtensions> {
84        self.extensions.get(urn)
85    }
86
87    /// Get a type by URN and name
88    pub fn get_type(&self, urn: &Urn, name: &str) -> Option<&CustomType> {
89        self.get_extension(urn)?.get_type(name)
90    }
91
92    /// Get a scalar function by URN and name.
93    ///
94    /// TODO: Add support for retrieving functions by their full signature shorthand
95    /// (e.g., "add:i32_i32").
96    pub fn get_scalar_function(&self, urn: &Urn, name: &str) -> Option<&super::ScalarFunction> {
97        self.get_extension(urn)?.get_scalar_function(name)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::{ExtensionFile, Registry};
    use crate::parse::text::simple_extensions::{
        SimpleExtensionsError, scalar_functions::ScalarFunctionError, types::ExtensionTypeError,
    };
    use crate::text::simple_extensions::{self, SimpleExtensions, SimpleExtensionsTypesItem};
    use crate::urn::Urn;
    use std::str::FromStr;

    /// Build a raw type definition that only has a name.
    fn type_item(name: &str) -> SimpleExtensionsTypesItem {
        SimpleExtensionsTypesItem {
            name: name.to_string(),
            description: None,
            parameters: None,
            structure: None,
            variadic: None,
        }
    }

    /// Build a raw extension with the given scalar functions and named types and no
    /// other content.
    fn raw_extension(
        urn: &str,
        scalar_functions: Vec<simple_extensions::ScalarFunction>,
        type_names: &[&str],
    ) -> SimpleExtensions {
        SimpleExtensions {
            scalar_functions,
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            type_variations: vec![],
            types: type_names.iter().copied().map(type_item).collect(),
            urn: urn.to_string(),
        }
    }

    /// Build a validated extension file that only defines the given types.
    fn extension_file(urn: &str, type_names: &[&str]) -> ExtensionFile {
        ExtensionFile::create(raw_extension(urn, vec![], type_names))
            .expect("valid extension file")
    }

    /// Helper to create a minimal extension with a scalar function returning a custom type
    fn extension_with_custom_type_reference(
        urn: &str,
        function_name: &str,
        return_type: &str,
        defined_types: &[&str],
    ) -> SimpleExtensions {
        let function = simple_extensions::ScalarFunction {
            name: function_name.to_string(),
            description: None,
            impls: vec![simple_extensions::ScalarFunctionImplsItem {
                args: None,
                options: None,
                variadic: None,
                session_dependent: None,
                deterministic: None,
                nullability: None,
                return_: simple_extensions::ReturnValue(simple_extensions::Type::String(
                    return_type.to_string(),
                )),
                implementation: None,
            }],
        };
        raw_extension(urn, vec![function], defined_types)
    }

    #[test]
    fn test_registry_iteration() {
        let urns = ["extension:example.com:first", "extension:example.com:second"];
        let registry =
            Registry::new(urns.iter().map(|&urn| extension_file(urn, &["type"]))).unwrap();

        let collected: Vec<&Urn> = registry.extensions().map(|(urn, _)| urn).collect();
        assert_eq!(collected.len(), 2);
        for urn in urns {
            assert!(
                collected.iter().any(|candidate| candidate.to_string() == urn),
                "registry iteration did not yield {urn}"
            );
        }
    }

    #[test]
    fn test_type_lookup() {
        let urn = Urn::from_str("extension:example.com:test").unwrap();
        let registry =
            Registry::new(vec![extension_file(&urn.to_string(), &["test_type"])]).unwrap();
        let other_urn = Urn::from_str("extension:example.com:other").unwrap();

        let cases = [
            (&urn, "test_type", true),
            (&urn, "missing", false),
            (&other_urn, "test_type", false),
        ];

        for (query_urn, type_name, expected) in cases {
            assert_eq!(
                registry.get_type(query_urn, type_name).is_some(),
                expected,
                "unexpected lookup result for {query_urn}:{type_name}"
            );
        }
    }

    #[cfg(feature = "extensions")]
    #[test]
    fn test_from_core_extensions() {
        let registry = Registry::from_core_extensions();
        assert!(registry.extensions().count() > 0);

        // Test that functions_geometry.yaml loaded correctly with its geometry type
        let urn = Urn::from_str("extension:io.substrait:functions_geometry").unwrap();
        let core_extension = registry
            .get_extension(&urn)
            .expect("Should find functions_geometry extension");

        assert!(
            core_extension.get_type("geometry").is_some(),
            "Should find 'geometry' type in functions_geometry extension"
        );

        // Also test the registry's get_type method with the actual URN
        assert!(registry.get_type(&urn, "geometry").is_some());

        // Verify extension_types is skipped due to u! prefix bug (substrait#935)
        let extension_types_urn = Urn::from_str("extension:io.substrait:extension_types").unwrap();
        assert!(
            registry.get_extension(&extension_types_urn).is_none(),
            "extension_types should be skipped due to missing u! prefix bug"
        );
    }

    #[test]
    fn test_unknown_type_without_prefix_fails() {
        // Function that references a type without u! prefix - should fail with
        // UnknownTypeName (this is an error, not NYI)
        let invalid_extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "bad_function",
            "point",
            &[],
        );

        match ExtensionFile::create(invalid_extension) {
            Err(SimpleExtensionsError::ScalarFunctionError(ScalarFunctionError::TypeError(
                ExtensionTypeError::UnknownTypeName { name },
            ))) => assert_eq!(name, "point"),
            other => panic!("Expected UnknownTypeName error, got {other:?}"),
        }
    }

    #[test]
    fn test_custom_type_reference_valid() {
        let extension = extension_with_custom_type_reference(
            "extension:example.com:valid",
            "get_point",
            "u!point",
            &["point"],
        );

        if let Err(err) = ExtensionFile::create(extension) {
            panic!("Should succeed when referenced type exists with u! prefix: {err}");
        }
    }

    #[test]
    fn test_custom_type_reference_missing() {
        let extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "get_rectangle",
            "u!rectangle",
            &[], // rectangle type not defined
        );

        match ExtensionFile::create(extension) {
            Err(SimpleExtensionsError::UnresolvedTypeReference { type_name }) => {
                assert_eq!(type_name, "rectangle");
            }
            other => panic!("Expected UnresolvedTypeReference error, got {other:?}"),
        }
    }

    #[cfg(feature = "extensions")]
    #[test]
    fn test_scalar_function_parses_completely() {
        use super::super::{
            argument::ArgumentsItem,
            scalar_functions::{Impl, NullabilityHandling, Options},
            types::*,
        };
        use crate::parse::Parse;
        use std::collections::HashMap;

        let registry = Registry::from_core_extensions();
        let functions_arithmetic_urn =
            Urn::from_str("extension:io.substrait:functions_arithmetic").unwrap();

        let add = registry
            .get_scalar_function(&functions_arithmetic_urn, "add")
            .expect("add function should exist");

        // Verify function-level metadata
        assert_eq!(add.name, "add");
        assert_eq!(add.description, Some("Add two values.".to_string()));
        assert!(
            !add.impls.is_empty(),
            "add should have at least one implementation"
        );

        // Create the expected first implementation (i8 + i8 -> i8)
        let mut ctx = super::super::extensions::TypeContext::default();
        let mut i8_arg = |name: &str| {
            ArgumentsItem::ValueArgument(
                simple_extensions::ValueArg {
                    name: Some(name.to_string()),
                    description: None,
                    value: simple_extensions::Type::String("i8".to_string()),
                    constant: None,
                }
                .parse(&mut ctx)
                .unwrap(),
            )
        };
        let args = vec![i8_arg("x"), i8_arg("y")];

        let expected_impl = Impl {
            args,
            options: Options(HashMap::from([(
                "overflow".to_string(),
                vec![
                    "SILENT".to_string(),
                    "SATURATE".to_string(),
                    "ERROR".to_string(),
                ],
            )])),
            variadic: None,
            session_dependent: false,
            deterministic: true,
            nullability: NullabilityHandling::Mirror,
            return_type: ConcreteType {
                kind: ConcreteTypeKind::Builtin(BasicBuiltinType::I8),
                nullable: false,
            },
            implementation: HashMap::new(),
        };

        assert_eq!(&add.impls[0], &expected_impl);
    }
}