1use std::collections::{HashMap, hash_map::Entry};
14
15use super::{ExtensionFile, SimpleExtensions, SimpleExtensionsError, types::CustomType};
16use crate::urn::Urn;
17
18#[derive(Debug)]
23pub struct Registry {
24 extensions: HashMap<Urn, SimpleExtensions>,
26}
27
28impl Registry {
29 pub fn new<I: IntoIterator<Item = ExtensionFile>>(
33 extensions: I,
34 ) -> Result<Self, SimpleExtensionsError> {
35 let mut map = HashMap::new();
36 for ExtensionFile { urn, extension } in extensions {
37 match map.entry(urn.clone()) {
38 Entry::Occupied(_) => return Err(SimpleExtensionsError::DuplicateUrn(urn)),
39 Entry::Vacant(entry) => {
40 entry.insert(extension);
41 }
42 }
43 }
44 Ok(Self { extensions: map })
45 }
46
47 pub fn extensions(&self) -> impl Iterator<Item = (&Urn, &SimpleExtensions)> {
49 self.extensions.iter()
50 }
51
52 #[cfg(feature = "extensions")]
57 pub fn from_core_extensions() -> Self {
58 use crate::extensions::EXTENSIONS;
59
60 let extensions: HashMap<Urn, SimpleExtensions> = EXTENSIONS
62 .iter()
63 .filter_map(|(orig_urn, simple_extensions)| {
64 let urn_str = orig_urn.to_string();
68 if urn_str == "extension:io.substrait:extension_types" ||
69 urn_str == "extension:io.substrait:unknown" {
70 return None;
71 }
72
73 let ExtensionFile { urn, extension } = ExtensionFile::create(simple_extensions.clone())
74 .unwrap_or_else(|err| panic!("Core extensions should be valid, but failed to create extension file for {orig_urn}: {err}"));
75 debug_assert_eq!(orig_urn, &urn);
76 Some((urn, extension))
77 })
78 .collect();
79
80 Self { extensions }
81 }
82
83 fn get_extension(&self, urn: &Urn) -> Option<&SimpleExtensions> {
84 self.extensions.get(urn)
85 }
86
87 pub fn get_type(&self, urn: &Urn, name: &str) -> Option<&CustomType> {
89 self.get_extension(urn)?.get_type(name)
90 }
91
92 pub fn get_scalar_function(&self, urn: &Urn, name: &str) -> Option<&super::ScalarFunction> {
97 self.get_extension(urn)?.get_scalar_function(name)
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::{ExtensionFile, Registry};
    use crate::parse::text::simple_extensions::{
        SimpleExtensionsError, scalar_functions::ScalarFunctionError, types::ExtensionTypeError,
    };
    use crate::text::simple_extensions::{SimpleExtensions, SimpleExtensionsTypesItem};
    use crate::urn::Urn;
    use std::str::FromStr;

    /// Builds bare type declarations (no description, parameters, structure,
    /// or variadic info), one per name.
    fn type_items(names: &[&str]) -> Vec<SimpleExtensionsTypesItem> {
        names
            .iter()
            .map(|name| SimpleExtensionsTypesItem {
                name: (*name).to_string(),
                description: None,
                parameters: None,
                structure: None,
                variadic: None,
            })
            .collect()
    }

    /// Builds a valid [`ExtensionFile`] at `urn` that declares only the given
    /// type names and no functions.
    fn extension_file(urn: &str, type_names: &[&str]) -> ExtensionFile {
        let raw = SimpleExtensions {
            scalar_functions: vec![],
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            type_variations: vec![],
            types: type_items(type_names),
            urn: urn.to_string(),
        };

        ExtensionFile::create(raw).expect("valid extension file")
    }

    /// Builds a raw (unvalidated) extension at `urn` containing one scalar
    /// function `function_name` with a single impl whose return type is the
    /// type string `return_type`, plus bare declarations for `defined_types`.
    fn extension_with_custom_type_reference(
        urn: &str,
        function_name: &str,
        return_type: &str,
        defined_types: &[&str],
    ) -> SimpleExtensions {
        use crate::text::simple_extensions;

        SimpleExtensions {
            scalar_functions: vec![simple_extensions::ScalarFunction {
                name: function_name.to_string(),
                description: None,
                impls: vec![simple_extensions::ScalarFunctionImplsItem {
                    args: None,
                    options: None,
                    variadic: None,
                    session_dependent: None,
                    deterministic: None,
                    nullability: None,
                    return_: simple_extensions::ReturnValue(simple_extensions::Type::String(
                        return_type.to_string(),
                    )),
                    implementation: None,
                }],
            }],
            aggregate_functions: vec![],
            window_functions: vec![],
            dependencies: Default::default(),
            type_variations: vec![],
            types: type_items(defined_types),
            urn: urn.to_string(),
        }
    }

    #[test]
    fn test_registry_iteration() {
        let urns = ["extension:example.com:first", "extension:example.com:second"];
        let registry =
            Registry::new(urns.iter().map(|&urn| extension_file(urn, &["type"]))).unwrap();

        let collected: Vec<&Urn> = registry.extensions().map(|(urn, _)| urn).collect();
        assert_eq!(collected.len(), urns.len());
        for urn in urns {
            assert!(
                collected
                    .iter()
                    .any(|candidate| candidate.to_string() == urn),
                "{urn} missing from registry iteration"
            );
        }
    }

    #[test]
    fn test_type_lookup() {
        let urn = Urn::from_str("extension:example.com:test").unwrap();
        let registry =
            Registry::new(vec![extension_file(&urn.to_string(), &["test_type"])]).unwrap();
        let other_urn = Urn::from_str("extension:example.com:other").unwrap();

        // (URN queried, type name queried, whether a match is expected)
        let cases = [
            (&urn, "test_type", true),
            (&urn, "missing", false),
            (&other_urn, "test_type", false),
        ];

        for (query_urn, type_name, expected) in cases {
            assert_eq!(
                registry.get_type(query_urn, type_name).is_some(),
                expected,
                "unexpected lookup result for {query_urn}:{type_name}"
            );
        }
    }

    #[cfg(feature = "extensions")]
    #[test]
    fn test_from_core_extensions() {
        let registry = Registry::from_core_extensions();
        assert!(registry.extensions().count() > 0);

        let urn = Urn::from_str("extension:io.substrait:functions_geometry").unwrap();
        let core_extension = registry
            .get_extension(&urn)
            .expect("Should find functions_geometry extension");

        let geometry_type = core_extension.get_type("geometry");
        assert!(
            geometry_type.is_some(),
            "Should find 'geometry' type in functions_geometry extension"
        );

        // The same type must also be reachable through the registry-level lookup.
        let type_via_registry = registry.get_type(&urn, "geometry");
        assert!(type_via_registry.is_some());

        let extension_types_urn = Urn::from_str("extension:io.substrait:extension_types").unwrap();
        assert!(
            registry.get_extension(&extension_types_urn).is_none(),
            "extension_types should be skipped due to missing u! prefix bug"
        );
    }

    #[test]
    fn test_unknown_type_without_prefix_fails() {
        // `point` lacks the `u!` prefix, so it is treated as a (nonexistent)
        // builtin type name rather than a reference to a custom type.
        let invalid_extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "bad_function",
            "point",
            &[],
        );

        let result = ExtensionFile::create(invalid_extension);
        assert!(
            result.is_err(),
            "Should fail when type is missing u! prefix"
        );

        match result {
            Err(SimpleExtensionsError::ScalarFunctionError(ScalarFunctionError::TypeError(
                ExtensionTypeError::UnknownTypeName { name },
            ))) => {
                assert_eq!(name, "point");
            }
            other => panic!("Expected UnknownTypeName error, got {other:?}"),
        }
    }

    #[test]
    fn test_custom_type_reference_valid() {
        let extension = extension_with_custom_type_reference(
            "extension:example.com:valid",
            "get_point",
            "u!point",
            &["point"],
        );

        let result = ExtensionFile::create(extension);
        assert!(
            result.is_ok(),
            "Should succeed when referenced type exists with u! prefix"
        );
    }

    #[test]
    fn test_custom_type_reference_missing() {
        // `u!rectangle` is a custom-type reference, but no `rectangle` type is declared.
        let extension = extension_with_custom_type_reference(
            "extension:example.com:invalid",
            "get_rectangle",
            "u!rectangle",
            &[],
        );

        let result = ExtensionFile::create(extension);
        assert!(
            result.is_err(),
            "Should fail when referenced type doesn't exist"
        );

        match result {
            Err(SimpleExtensionsError::UnresolvedTypeReference { type_name }) => {
                assert_eq!(type_name, "rectangle");
            }
            other => panic!("Expected UnresolvedTypeReference error, got {other:?}"),
        }
    }

    #[cfg(feature = "extensions")]
    #[test]
    fn test_scalar_function_parses_completely() {
        use super::super::{
            argument::ArgumentsItem,
            scalar_functions::{Impl, NullabilityHandling, Options},
            types::*,
        };
        use crate::parse::Parse;
        use crate::text::simple_extensions;
        use std::collections::HashMap;

        let registry = Registry::from_core_extensions();
        let functions_arithmetic_urn =
            Urn::from_str("extension:io.substrait:functions_arithmetic").unwrap();

        let add = registry
            .get_scalar_function(&functions_arithmetic_urn, "add")
            .expect("add function should exist");

        assert_eq!(add.name, "add");
        assert_eq!(add.description, Some("Add two values.".to_string()));
        assert!(
            !add.impls.is_empty(),
            "add should have at least one implementation"
        );

        // Expected first impl of `add`: (x: i8, y: i8) -> i8 with an `overflow` option.
        let mut ctx = super::super::extensions::TypeContext::default();
        let expected_impl = Impl {
            args: vec![
                ArgumentsItem::ValueArgument(
                    simple_extensions::ValueArg {
                        name: Some("x".to_string()),
                        description: None,
                        value: simple_extensions::Type::String("i8".to_string()),
                        constant: None,
                    }
                    .parse(&mut ctx)
                    .unwrap(),
                ),
                ArgumentsItem::ValueArgument(
                    simple_extensions::ValueArg {
                        name: Some("y".to_string()),
                        description: None,
                        value: simple_extensions::Type::String("i8".to_string()),
                        constant: None,
                    }
                    .parse(&mut ctx)
                    .unwrap(),
                ),
            ],
            options: Options(HashMap::from([(
                "overflow".to_string(),
                vec![
                    "SILENT".to_string(),
                    "SATURATE".to_string(),
                    "ERROR".to_string(),
                ],
            )])),
            variadic: None,
            session_dependent: false,
            deterministic: true,
            nullability: NullabilityHandling::Mirror,
            return_type: ConcreteType {
                kind: ConcreteTypeKind::Builtin(BasicBuiltinType::I8),
                nullable: false,
            },
            implementation: HashMap::new(),
        };

        assert_eq!(&add.impls[0], &expected_impl);
    }
}