1use std::collections::{HashMap, hash_map::Entry};
14
15use super::{ExtensionFile, SimpleExtensions, SimpleExtensionsError, types::CustomType};
16use crate::urn::Urn;
17
18#[derive(Debug)]
23pub struct Registry {
24 extensions: HashMap<Urn, SimpleExtensions>,
26}
27
28impl Registry {
29 pub fn new<I: IntoIterator<Item = ExtensionFile>>(
33 extensions: I,
34 ) -> Result<Self, SimpleExtensionsError> {
35 let mut map = HashMap::new();
36 for ExtensionFile { urn, extension } in extensions {
37 match map.entry(urn.clone()) {
38 Entry::Occupied(_) => return Err(SimpleExtensionsError::DuplicateUrn(urn)),
39 Entry::Vacant(entry) => {
40 entry.insert(extension);
41 }
42 }
43 }
44 Ok(Self { extensions: map })
45 }
46
47 pub fn extensions(&self) -> impl Iterator<Item = (&Urn, &SimpleExtensions)> {
49 self.extensions.iter()
50 }
51
52 #[cfg(feature = "extensions")]
57 pub fn from_core_extensions() -> Self {
58 use crate::extensions::EXTENSIONS;
59
60 let extensions: HashMap<Urn, SimpleExtensions> = EXTENSIONS
62 .iter()
63 .filter_map(|(orig_urn, simple_extensions)| {
64 let urn_str = orig_urn.to_string();
68 if urn_str == "extension:io.substrait:extension_types" ||
69 urn_str == "extension:io.substrait:unknown" {
70 return None;
71 }
72
73 let ExtensionFile { urn, extension } = ExtensionFile::create(simple_extensions.clone())
74 .unwrap_or_else(|err| panic!("Core extensions should be valid, but failed to create extension file for {orig_urn}: {err}"));
75 debug_assert_eq!(orig_urn, &urn);
76 Some((urn, extension))
77 })
78 .collect();
79
80 Self { extensions }
81 }
82
83 fn get_extension(&self, urn: &Urn) -> Option<&SimpleExtensions> {
84 self.extensions.get(urn)
85 }
86
87 pub fn get_type(&self, urn: &Urn, name: &str) -> Option<&CustomType> {
89 self.get_extension(urn)?.get_type(name)
90 }
91
92 pub fn get_scalar_function(&self, urn: &Urn, name: &str) -> Option<&super::ScalarFunction> {
97 self.get_extension(urn)?.get_scalar_function(name)
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::{ExtensionFile, Registry};
    use crate::parse::text::simple_extensions::{
        SimpleExtensionsError, scalar_functions::ScalarFunctionError, types::ExtensionTypeError,
    };
    use crate::text::simple_extensions::{SimpleExtensions, SimpleExtensionsTypesItem};
    use crate::urn::Urn;
    use std::str::FromStr;

    /// Builds a raw custom-type declaration named `name` with every optional field unset.
    fn type_item(name: &str) -> SimpleExtensionsTypesItem {
        SimpleExtensionsTypesItem {
            name: name.to_string(),
            description: None,
            metadata: Default::default(),
            parameters: None,
            structure: None,
            variadic: None,
        }
    }

    /// Builds a valid `ExtensionFile` under `urn` that declares one custom type per entry in
    /// `type_names` and no functions.
    fn extension_file(urn: &str, type_names: &[&str]) -> ExtensionFile {
        let raw = SimpleExtensions {
            scalar_functions: vec![],
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            metadata: Default::default(),
            type_variations: vec![],
            types: type_names.iter().copied().map(type_item).collect(),
            urn: urn.to_string(),
        };

        ExtensionFile::create(raw).expect("valid extension file")
    }

    #[test]
    fn test_registry_iteration() {
        let urns = ["extension:example.com:first", "extension:example.com:second"];
        let registry =
            Registry::new(urns.iter().map(|&urn| extension_file(urn, &["type"]))).unwrap();

        let collected: Vec<&Urn> = registry.extensions().map(|(urn, _)| urn).collect();
        assert_eq!(collected.len(), urns.len());
        for urn in urns {
            assert!(
                collected
                    .iter()
                    .any(|candidate| candidate.to_string() == urn),
                "{urn} missing from registry iteration"
            );
        }
    }

    #[test]
    fn test_type_lookup() {
        let urn = Urn::from_str("extension:example.com:test").unwrap();
        let registry =
            Registry::new(vec![extension_file(&urn.to_string(), &["test_type"])]).unwrap();
        let other_urn = Urn::from_str("extension:example.com:other").unwrap();

        // (urn queried, type name queried, whether the lookup should succeed)
        let cases = [
            (&urn, "test_type", true),
            (&urn, "missing", false),
            (&other_urn, "test_type", false),
        ];

        for (query_urn, type_name, expected) in cases {
            assert_eq!(
                registry.get_type(query_urn, type_name).is_some(),
                expected,
                "unexpected lookup result for {query_urn}:{type_name}"
            );
        }
    }

    #[cfg(feature = "extensions")]
    #[test]
    fn test_from_core_extensions() {
        let registry = Registry::from_core_extensions();
        assert!(registry.extensions().count() > 0);

        let urn = Urn::from_str("extension:io.substrait:functions_geometry").unwrap();
        let core_extension = registry
            .get_extension(&urn)
            .expect("Should find functions_geometry extension");

        let geometry_type = core_extension.get_type("geometry");
        assert!(
            geometry_type.is_some(),
            "Should find 'geometry' type in functions_geometry extension"
        );

        // The same lookup through the registry's public API must agree.
        let type_via_registry = registry.get_type(&urn, "geometry");
        assert!(type_via_registry.is_some());

        let extension_types_urn = Urn::from_str("extension:io.substrait:extension_types").unwrap();
        assert!(
            registry.get_extension(&extension_types_urn).is_none(),
            "extension_types should be skipped due to missing u! prefix bug"
        );
    }

    #[test]
    fn test_unknown_type_without_prefix_fails() {
        // `point` is not declared with the `u!` prefix, so it must be rejected as an unknown
        // type name rather than resolved as a custom type.
        let invalid_extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "bad_function",
            "point",
            &[],
        );

        let result = ExtensionFile::create(invalid_extension);
        assert!(
            result.is_err(),
            "Should fail when type is missing u! prefix"
        );

        match result {
            Err(SimpleExtensionsError::ScalarFunctionError(ScalarFunctionError::TypeError(
                ExtensionTypeError::UnknownTypeName { name },
            ))) => {
                assert_eq!(name, "point");
            }
            other => panic!("Expected UnknownTypeName error, got {other:?}"),
        }
    }

    /// Builds a raw extension under `urn` with a single scalar function `function_name` whose
    /// only impl returns `return_type`. The extension declares the custom types `defined_types`.
    fn extension_with_custom_type_reference(
        urn: &str,
        function_name: &str,
        return_type: &str,
        defined_types: &[&str],
    ) -> SimpleExtensions {
        use crate::text::simple_extensions;

        SimpleExtensions {
            scalar_functions: vec![simple_extensions::ScalarFunction {
                name: function_name.to_string(),
                description: None,
                metadata: Default::default(),
                impls: vec![simple_extensions::ScalarFunctionImplsItem {
                    args: None,
                    options: None,
                    variadic: None,
                    session_dependent: None,
                    deterministic: None,
                    nullability: None,
                    return_: simple_extensions::ReturnValue(simple_extensions::Type::String(
                        return_type.to_string(),
                    )),
                    implementation: None,
                }],
            }],
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            metadata: Default::default(),
            type_variations: vec![],
            types: defined_types.iter().copied().map(type_item).collect(),
            urn: urn.to_string(),
        }
    }

    #[test]
    fn test_custom_type_reference_valid() {
        let extension = extension_with_custom_type_reference(
            "extension:example.com:valid",
            "get_point",
            "u!point",
            &["point"],
        );

        let result = ExtensionFile::create(extension);
        assert!(
            result.is_ok(),
            "Should succeed when referenced type exists with u! prefix"
        );
    }

    #[test]
    fn test_custom_type_reference_missing() {
        // `u!rectangle` is referenced but no `rectangle` type is declared.
        let extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "get_rectangle",
            "u!rectangle",
            &[],
        );

        let result = ExtensionFile::create(extension);
        assert!(
            result.is_err(),
            "Should fail when referenced type doesn't exist"
        );

        match result {
            Err(SimpleExtensionsError::UnresolvedTypeReference { type_name }) => {
                assert_eq!(type_name, "rectangle");
            }
            other => panic!("Expected UnresolvedTypeReference error, got {other:?}"),
        }
    }

    #[cfg(feature = "extensions")]
    #[test]
    fn test_scalar_function_parses_completely() {
        use super::super::{
            argument::ArgumentsItem,
            scalar_functions::{Impl, NullabilityHandling, Options},
            types::*,
        };
        use crate::parse::Parse;
        use crate::text::simple_extensions;
        use std::collections::HashMap;

        let registry = Registry::from_core_extensions();
        let functions_arithmetic_urn =
            Urn::from_str("extension:io.substrait:functions_arithmetic").unwrap();

        let add = registry
            .get_scalar_function(&functions_arithmetic_urn, "add")
            .expect("add function should exist");

        assert_eq!(add.name, "add");
        assert_eq!(add.description, Some("Add two values.".to_string()));
        assert!(
            !add.impls.is_empty(),
            "add should have at least one implementation"
        );

        // Build the expected first impl (i8 + i8 -> i8) through the same parse path.
        let mut ctx = super::super::extensions::TypeContext::default();
        let expected_impl = Impl {
            args: vec![
                ArgumentsItem::ValueArgument(
                    simple_extensions::ValueArg {
                        name: Some("x".to_string()),
                        description: None,
                        value: simple_extensions::Type::String("i8".to_string()),
                        constant: None,
                    }
                    .parse(&mut ctx)
                    .unwrap(),
                ),
                ArgumentsItem::ValueArgument(
                    simple_extensions::ValueArg {
                        name: Some("y".to_string()),
                        description: None,
                        value: simple_extensions::Type::String("i8".to_string()),
                        constant: None,
                    }
                    .parse(&mut ctx)
                    .unwrap(),
                ),
            ],
            options: Options({
                let mut map = HashMap::new();
                map.insert(
                    "overflow".to_string(),
                    vec![
                        "SILENT".to_string(),
                        "SATURATE".to_string(),
                        "ERROR".to_string(),
                    ],
                );
                map
            }),
            variadic: None,
            session_dependent: false,
            deterministic: true,
            nullability: NullabilityHandling::Mirror,
            return_type: ConcreteType {
                kind: ConcreteTypeKind::Builtin(BasicBuiltinType::I8),
                nullable: false,
            },
            implementation: HashMap::new(),
        };

        assert_eq!(&add.impls[0], &expected_impl);
    }
}