Skip to main content

substrait/parse/text/simple_extensions/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2
//! Substrait Extension Registry
//!
//! This module provides registries for Substrait extensions:
//! - **Global Registry**: Immutable, reusable across plans, URN+name based lookup
//! - **Local Registry**: Per-plan, anchor-based, references Global Registry (TODO)
//!
//! The registry currently exposes lookups for type definitions and scalar functions.
//! Lookups for other kinds of functions are not provided yet.
//!
//! This module is only available when the `parse` feature is enabled.
12
13use std::collections::{HashMap, hash_map::Entry};
14
15use super::{ExtensionFile, SimpleExtensions, SimpleExtensionsError, types::CustomType};
16use crate::urn::Urn;
17
18/// Extension Registry that manages Substrait extensions
19///
20/// This registry is immutable and reusable across multiple plans.
21/// It provides URN + name based lookup for extension types. Function parsing will be added in a future update.
22#[derive(Debug)]
23pub struct Registry {
24    /// Pre-validated extension files
25    extensions: HashMap<Urn, SimpleExtensions>,
26}
27
28impl Registry {
29    /// Create a new Global Registry from validated extension files.
30    ///
31    /// Any duplicate URNs will raise an error.
32    pub fn new<I: IntoIterator<Item = ExtensionFile>>(
33        extensions: I,
34    ) -> Result<Self, SimpleExtensionsError> {
35        let mut map = HashMap::new();
36        for ExtensionFile { urn, extension } in extensions {
37            match map.entry(urn.clone()) {
38                Entry::Occupied(_) => return Err(SimpleExtensionsError::DuplicateUrn(urn)),
39                Entry::Vacant(entry) => {
40                    entry.insert(extension);
41                }
42            }
43        }
44        Ok(Self { extensions: map })
45    }
46
47    /// Get an iterator over all extension files in this registry
48    pub fn extensions(&self) -> impl Iterator<Item = (&Urn, &SimpleExtensions)> {
49        self.extensions.iter()
50    }
51
52    /// Create a Global Registry from the built-in core extensions.
53    ///
54    /// Most core extensions are included. Some are skipped due to bugs in the upstream
55    /// YAML files (see <https://github.com/substrait-io/substrait/issues/935>).
56    #[cfg(feature = "extensions")]
57    pub fn from_core_extensions() -> Self {
58        use crate::extensions::EXTENSIONS;
59
60        // Parse the core extensions from the raw extensions format to the parsed format
61        let extensions: HashMap<Urn, SimpleExtensions> = EXTENSIONS
62            .iter()
63            .filter_map(|(orig_urn, simple_extensions)| {
64                // Skip specific core extensions that have bugs (missing u! prefix on type references).
65                // Most core extensions are included; only these problematic ones are filtered out.
66                // See: https://github.com/substrait-io/substrait/issues/935
67                let urn_str = orig_urn.to_string();
68                if urn_str == "extension:io.substrait:extension_types" ||
69                   urn_str == "extension:io.substrait:unknown" {
70                    return None;
71                }
72
73                let ExtensionFile { urn, extension } = ExtensionFile::create(simple_extensions.clone())
74                    .unwrap_or_else(|err| panic!("Core extensions should be valid, but failed to create extension file for {orig_urn}: {err}"));
75                debug_assert_eq!(orig_urn, &urn);
76                Some((urn, extension))
77            })
78            .collect();
79
80        Self { extensions }
81    }
82
83    fn get_extension(&self, urn: &Urn) -> Option<&SimpleExtensions> {
84        self.extensions.get(urn)
85    }
86
87    /// Get a type by URN and name
88    pub fn get_type(&self, urn: &Urn, name: &str) -> Option<&CustomType> {
89        self.get_extension(urn)?.get_type(name)
90    }
91
92    /// Get a scalar function by URN and name.
93    ///
94    /// TODO: Add support for retrieving functions by their full signature shorthand
95    /// (e.g., "add:i32_i32").
96    pub fn get_scalar_function(&self, urn: &Urn, name: &str) -> Option<&super::ScalarFunction> {
97        self.get_extension(urn)?.get_scalar_function(name)
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::{ExtensionFile, Registry};
    use crate::parse::text::simple_extensions::{
        SimpleExtensionsError, scalar_functions::ScalarFunctionError, types::ExtensionTypeError,
    };
    use crate::text::simple_extensions::{SimpleExtensions, SimpleExtensionsTypesItem};
    use crate::urn::Urn;
    use std::str::FromStr;

    /// Build a raw type definition that carries nothing but its name.
    fn bare_type(name: &str) -> SimpleExtensionsTypesItem {
        SimpleExtensionsTypesItem {
            name: name.to_string(),
            description: None,
            metadata: Default::default(),
            parameters: None,
            structure: None,
            variadic: None,
        }
    }

    /// Build a validated extension file with the given URN that defines one bare type per
    /// entry of `type_names` and no functions.
    fn extension_file(urn: &str, type_names: &[&str]) -> ExtensionFile {
        let raw = SimpleExtensions {
            scalar_functions: vec![],
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            metadata: Default::default(),
            type_variations: vec![],
            types: type_names.iter().map(|name| bare_type(name)).collect(),
            urn: urn.to_string(),
        };

        ExtensionFile::create(raw).expect("valid extension file")
    }

    /// Helper to create a minimal raw extension with a single scalar function whose only
    /// implementation returns `return_type`, plus one bare type per entry of `defined_types`.
    fn extension_with_custom_type_reference(
        urn: &str,
        function_name: &str,
        return_type: &str,
        defined_types: &[&str],
    ) -> SimpleExtensions {
        use crate::text::simple_extensions::{
            ReturnValue, ScalarFunction, ScalarFunctionImplsItem, Type,
        };

        let only_impl = ScalarFunctionImplsItem {
            args: None,
            options: None,
            variadic: None,
            session_dependent: None,
            deterministic: None,
            nullability: None,
            return_: ReturnValue(Type::String(return_type.to_string())),
            implementation: None,
        };
        let function = ScalarFunction {
            name: function_name.to_string(),
            description: None,
            metadata: Default::default(),
            impls: vec![only_impl],
        };

        SimpleExtensions {
            scalar_functions: vec![function],
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            metadata: Default::default(),
            type_variations: vec![],
            types: defined_types.iter().map(|name| bare_type(name)).collect(),
            urn: urn.to_string(),
        }
    }

    #[test]
    fn test_registry_iteration() {
        let urns = [
            "extension:example.com:first",
            "extension:example.com:second",
        ];
        let files = urns.iter().map(|urn| extension_file(urn, &["type"]));
        let registry = Registry::new(files).unwrap();

        // Iteration order of the registry is unspecified, so only check membership.
        let seen: Vec<String> = registry
            .extensions()
            .map(|(urn, _)| urn.to_string())
            .collect();
        assert_eq!(seen.len(), 2);
        for expected in urns {
            assert!(seen.iter().any(|found| found.as_str() == expected));
        }
    }

    #[test]
    fn test_type_lookup() {
        let urn = Urn::from_str("extension:example.com:test").unwrap();
        let other_urn = Urn::from_str("extension:example.com:other").unwrap();
        let registry =
            Registry::new([extension_file(&urn.to_string(), &["test_type"])]).unwrap();

        // (URN to query, type name to query, whether a type is expected to be found)
        let cases = [
            (&urn, "test_type", true),
            (&urn, "missing", false),
            (&other_urn, "test_type", false),
        ];

        for (query_urn, type_name, expected) in cases {
            let found = registry.get_type(query_urn, type_name).is_some();
            assert_eq!(
                found, expected,
                "unexpected lookup result for {query_urn}:{type_name}"
            );
        }
    }

    #[cfg(feature = "extensions")]
    #[test]
    fn test_from_core_extensions() {
        let registry = Registry::from_core_extensions();
        assert!(registry.extensions().next().is_some());

        // functions_geometry.yaml should have loaded together with its geometry type
        let geometry_urn = Urn::from_str("extension:io.substrait:functions_geometry").unwrap();
        let geometry_extension = registry
            .get_extension(&geometry_urn)
            .expect("Should find functions_geometry extension");
        assert!(
            geometry_extension.get_type("geometry").is_some(),
            "Should find 'geometry' type in functions_geometry extension"
        );

        // The registry-level lookup must agree with the extension-level one
        assert!(registry.get_type(&geometry_urn, "geometry").is_some());

        // extension_types is skipped due to the u! prefix bug (substrait#935)
        let skipped_urn = Urn::from_str("extension:io.substrait:extension_types").unwrap();
        assert!(
            registry.get_extension(&skipped_urn).is_none(),
            "extension_types should be skipped due to missing u! prefix bug"
        );
    }

    #[test]
    fn test_unknown_type_without_prefix_fails() {
        // The function returns "point" without the u! prefix and no such type is defined:
        // this is an error (UnknownTypeName), not a not-yet-implemented case.
        let invalid_extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "bad_function",
            "point",
            &[],
        );

        let outcome = ExtensionFile::create(invalid_extension);
        assert!(
            outcome.is_err(),
            "Should fail when type is missing u! prefix"
        );

        match outcome {
            Err(SimpleExtensionsError::ScalarFunctionError(ScalarFunctionError::TypeError(
                ExtensionTypeError::UnknownTypeName { name },
            ))) => assert_eq!(name, "point"),
            other => panic!("Expected UnknownTypeName error, got {:?}", other),
        }
    }

    #[test]
    fn test_custom_type_reference_valid() {
        let extension = extension_with_custom_type_reference(
            "extension:example.com:valid",
            "get_point",
            "u!point",
            &["point"],
        );

        assert!(
            ExtensionFile::create(extension).is_ok(),
            "Should succeed when referenced type exists with u! prefix"
        );
    }

    #[test]
    fn test_custom_type_reference_missing() {
        let extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "get_rectangle",
            "u!rectangle",
            &[], // rectangle type not defined
        );

        let outcome = ExtensionFile::create(extension);
        assert!(
            outcome.is_err(),
            "Should fail when referenced type doesn't exist"
        );

        match outcome {
            Err(SimpleExtensionsError::UnresolvedTypeReference { type_name }) => {
                assert_eq!(type_name, "rectangle");
            }
            other => panic!("Expected UnresolvedTypeReference error, got {:?}", other),
        }
    }

    #[cfg(feature = "extensions")]
    #[test]
    fn test_scalar_function_parses_completely() {
        use super::super::{
            argument::ArgumentsItem,
            scalar_functions::{Impl, NullabilityHandling, Options},
            types::*,
        };
        use crate::parse::Parse;
        use crate::text::simple_extensions;
        use std::collections::HashMap;

        let registry = Registry::from_core_extensions();
        let arithmetic_urn =
            Urn::from_str("extension:io.substrait:functions_arithmetic").unwrap();

        let add = registry
            .get_scalar_function(&arithmetic_urn, "add")
            .expect("add function should exist");

        // Function-level metadata
        assert_eq!(add.name, "add");
        assert_eq!(add.description, Some("Add two values.".to_string()));
        assert!(
            !add.impls.is_empty(),
            "add should have at least one implementation"
        );

        // Parse a named `i8` value argument the same way the extension parser does.
        let mut ctx = super::super::extensions::TypeContext::default();
        let mut i8_value_arg = |arg_name: &str| {
            ArgumentsItem::ValueArgument(
                simple_extensions::ValueArg {
                    name: Some(arg_name.to_string()),
                    description: None,
                    value: simple_extensions::Type::String("i8".to_string()),
                    constant: None,
                }
                .parse(&mut ctx)
                .unwrap(),
            )
        };

        let overflow_choices = ["SILENT", "SATURATE", "ERROR"]
            .map(String::from)
            .to_vec();

        // The expected first implementation (i8 + i8 -> i8)
        let expected_impl = Impl {
            args: vec![i8_value_arg("x"), i8_value_arg("y")],
            options: Options(HashMap::from([(
                "overflow".to_string(),
                overflow_choices,
            )])),
            variadic: None,
            session_dependent: false,
            deterministic: true,
            nullability: NullabilityHandling::Mirror,
            return_type: ConcreteType {
                kind: ConcreteTypeKind::Builtin(BasicBuiltinType::I8),
                nullable: false,
            },
            implementation: HashMap::new(),
        };

        assert_eq!(&add.impls[0], &expected_impl);
    }
}