substrait/parse/text/simple_extensions/
extensions.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Validated simple extensions: [`SimpleExtensions`].
4//!
5//! Both type definitions and scalar function definitions are supported.
6//! Aggregate functions (see #447) and window functions (see #446) are not yet supported.
7
8use indexmap::IndexMap;
9use std::collections::{HashMap, HashSet};
10use std::str::FromStr;
11
12use super::{SimpleExtensionsError, scalar_functions::ScalarFunction, types::CustomType};
13use crate::{
14    parse::{Context, Parse},
15    text::simple_extensions::SimpleExtensions as RawExtensions,
16    urn::Urn,
17};
18
19/// The contents (types and functions) in an [`ExtensionFile`](super::file::ExtensionFile).
20///
21/// This structure stores and provides access to the individual objects defined
22/// in an [`ExtensionFile`](super::file::ExtensionFile); [`SimpleExtensions`]
23/// represents the contents of an extensions file.
24#[derive(Clone, Debug, Default)]
25pub struct SimpleExtensions {
26    /// Types defined in this extension file
27    types: HashMap<String, CustomType>,
28    /// Scalar functions defined in this extension file
29    ///
30    /// TODO: Add support for window functions (issue #446) and aggregate functions (issue #447)
31    scalar_functions: HashMap<String, ScalarFunction>,
32}
33
34impl SimpleExtensions {
35    /// Add a type to the context. Name must be unique.
36    pub fn add_type(&mut self, custom_type: &CustomType) -> Result<(), SimpleExtensionsError> {
37        if self.types.contains_key(&custom_type.name) {
38            return Err(SimpleExtensionsError::DuplicateTypeName {
39                name: custom_type.name.clone(),
40            });
41        }
42
43        self.types
44            .insert(custom_type.name.clone(), custom_type.clone());
45        Ok(())
46    }
47
48    /// Get a type by name from the context
49    pub fn get_type(&self, name: &str) -> Option<&CustomType> {
50        self.types.get(name)
51    }
52
53    /// Get an iterator over all types in the context
54    pub fn types(&self) -> impl Iterator<Item = &CustomType> {
55        self.types.values()
56    }
57
58    /// Consume the parsed extension and return its types.
59    pub(crate) fn into_types(self) -> HashMap<String, CustomType> {
60        self.types
61    }
62
63    /// Add a scalar function to the context, merging with existing functions of the same name.
64    ///
65    /// When duplicate function names are encountered, implementations are merged (unioned).
66    /// The existing description is kept if present, otherwise the new description is used.
67    ///
68    /// See: https://github.com/substrait-io/substrait/issues/931
69    pub(super) fn add_scalar_function(&mut self, scalar_function: ScalarFunction) {
70        use std::collections::hash_map::Entry;
71        match self.scalar_functions.entry(scalar_function.name.clone()) {
72            Entry::Vacant(e) => {
73                e.insert(scalar_function);
74            }
75            Entry::Occupied(mut e) => {
76                Self::merge_scalar_function(e.get_mut(), scalar_function);
77            }
78        }
79    }
80
81    /// Merge a new scalar function into an existing one.
82    ///
83    /// Unions the implementations. Keeps the existing description if present,
84    /// otherwise uses the new description.
85    // TODO: Reject conflicting implementations instead of blindly merging.
86    fn merge_scalar_function(existing: &mut ScalarFunction, new: ScalarFunction) {
87        existing.impls.extend(new.impls);
88        existing.description = existing.description.take().or(new.description);
89    }
90
91    /// Get a scalar function by name
92    pub fn get_scalar_function(&self, name: &str) -> Option<&ScalarFunction> {
93        self.scalar_functions.get(name)
94    }
95
96    /// Get an iterator over all scalar functions
97    pub fn scalar_functions(&self) -> impl Iterator<Item = &ScalarFunction> {
98        self.scalar_functions.values()
99    }
100}
101
/// Tracks the custom type names encountered while parsing an extension file,
/// split into those that are resolved (defined) and those that are still
/// unresolved (referenced but not yet defined).
#[derive(Debug, Default)]
pub(crate) struct TypeContext {
    /// Types that have been defined so far, and are therefore resolved.
    known: HashSet<String>,
    /// Types that have been referenced but not yet defined (unresolved).
    linked: HashSet<String>,
}
110
111impl TypeContext {
112    /// Mark a type as found
113    pub fn found(&mut self, name: &str) {
114        self.linked.remove(name);
115        self.known.insert(name.to_string());
116    }
117
118    /// Mark a type as linked to - some other type or function references it,
119    /// but we haven't seen it.
120    pub fn linked(&mut self, name: &str) {
121        if !self.known.contains(name) {
122            self.linked.insert(name.to_string());
123        }
124    }
125}
126
127impl Context for TypeContext {}
128
129// Implement parsing for the raw text representation to produce an `ExtensionFile`.
130impl Parse<TypeContext> for RawExtensions {
131    type Parsed = (Urn, SimpleExtensions);
132    type Error = super::SimpleExtensionsError;
133
134    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
135        let RawExtensions {
136            urn,
137            types,
138            scalar_functions,
139            ..
140        } = self;
141        let urn = Urn::from_str(&urn)?;
142        let mut extension = SimpleExtensions::default();
143
144        for type_item in types {
145            let custom_type = Parse::parse(type_item, ctx)?;
146            extension.add_type(&custom_type)?;
147        }
148
149        for scalar_fn in scalar_functions {
150            match ScalarFunction::from_raw(scalar_fn, ctx) {
151                Ok(parsed_fn) => {
152                    extension.add_scalar_function(parsed_fn);
153                }
154                Err(super::scalar_functions::ScalarFunctionError::NotYetImplemented(_)) => {
155                    // Skip functions with unimplemented features (e.g., type derivations)
156                    // These will be supported in a future update
157                    continue;
158                }
159                Err(e) => return Err(e.into()),
160            }
161        }
162
163        if let Some(missing) = ctx.linked.iter().next() {
164            // TODO: Track originating type(s) to improve this error message.
165            return Err(super::SimpleExtensionsError::UnresolvedTypeReference {
166                type_name: missing.clone(),
167            });
168        }
169
170        Ok((urn, extension))
171    }
172}
173
174impl From<(Urn, SimpleExtensions)> for RawExtensions {
175    fn from((urn, extension): (Urn, SimpleExtensions)) -> Self {
176        let types = extension
177            .into_types()
178            .into_values()
179            .map(Into::into)
180            .collect();
181
182        RawExtensions {
183            urn: urn.to_string(),
184            aggregate_functions: vec![],
185            dependencies: IndexMap::new(),
186            scalar_functions: vec![],
187            type_variations: vec![],
188            types,
189            window_functions: vec![],
190        }
191    }
192}