substrait/parse/text/simple_extensions/
extensions.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Validated simple extensions: [`SimpleExtensions`].
4//!
5//! Currently only type definitions are supported; function parsing will be
6//! added in a future update.
7
8use indexmap::IndexMap;
9use std::collections::{HashMap, HashSet};
10use std::str::FromStr;
11
12use super::{SimpleExtensionsError, types::CustomType};
13use crate::{
14    parse::{Context, Parse},
15    text::simple_extensions::SimpleExtensions as RawExtensions,
16    urn::Urn,
17};
18
19/// The contents (types) in an [`ExtensionFile`](super::file::ExtensionFile).
20///
21/// This structure stores and provides access to the individual objects defined
22/// in an [`ExtensionFile`](super::file::ExtensionFile); [`SimpleExtensions`]
23/// represents the contents of an extensions file.
24///
25/// Currently, only the [`CustomType`]s are included; any scalar, window, or
26/// aggregate functions are not yet included.
27#[derive(Clone, Debug, Default)]
28pub struct SimpleExtensions {
29    /// Types defined in this extension file
30    types: HashMap<String, CustomType>,
31}
32
33impl SimpleExtensions {
34    /// Add a type to the context. Name must be unique.
35    pub fn add_type(&mut self, custom_type: &CustomType) -> Result<(), SimpleExtensionsError> {
36        if self.types.contains_key(&custom_type.name) {
37            return Err(SimpleExtensionsError::DuplicateTypeName {
38                name: custom_type.name.clone(),
39            });
40        }
41
42        self.types
43            .insert(custom_type.name.clone(), custom_type.clone());
44        Ok(())
45    }
46
47    /// Check if a type with the given name exists in the context
48    pub fn has_type(&self, name: &str) -> bool {
49        self.types.contains_key(name)
50    }
51
52    /// Get a type by name from the context
53    pub fn get_type(&self, name: &str) -> Option<&CustomType> {
54        self.types.get(name)
55    }
56
57    /// Get an iterator over all types in the context
58    pub fn types(&self) -> impl Iterator<Item = &CustomType> {
59        self.types.values()
60    }
61
62    /// Consume the parsed extension and return its types.
63    pub(crate) fn into_types(self) -> HashMap<String, CustomType> {
64        self.types
65    }
66}
67
68/// resolved or unresolved.
69#[derive(Debug, Default)]
70pub(crate) struct TypeContext {
71    /// Types that have been seen so far, now resolved.
72    known: HashSet<String>,
73    /// Types that have been linked to, not yet resolved.
74    linked: HashSet<String>,
75}
76
77impl TypeContext {
78    /// Mark a type as found
79    pub fn found(&mut self, name: &str) {
80        self.linked.remove(name);
81        self.known.insert(name.to_string());
82    }
83
84    /// Mark a type as linked to - some other type or function references it,
85    /// but we haven't seen it.
86    pub fn linked(&mut self, name: &str) {
87        if !self.known.contains(name) {
88            self.linked.insert(name.to_string());
89        }
90    }
91}
92
93impl Context for TypeContext {}
94
95// Implement parsing for the raw text representation to produce an `ExtensionFile`.
96impl Parse<TypeContext> for RawExtensions {
97    type Parsed = (Urn, SimpleExtensions);
98    type Error = super::SimpleExtensionsError;
99
100    fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
101        let RawExtensions { urn, types, .. } = self;
102        let urn = Urn::from_str(&urn)?;
103        let mut extension = SimpleExtensions::default();
104
105        for type_item in types {
106            let custom_type = Parse::parse(type_item, ctx)?;
107            extension.add_type(&custom_type)?;
108        }
109
110        if let Some(missing) = ctx.linked.iter().next() {
111            // TODO: Track originating type(s) to improve this error message.
112            return Err(super::SimpleExtensionsError::UnresolvedTypeReference {
113                type_name: missing.clone(),
114            });
115        }
116
117        Ok((urn, extension))
118    }
119}
120
121impl From<(Urn, SimpleExtensions)> for RawExtensions {
122    fn from((urn, extension): (Urn, SimpleExtensions)) -> Self {
123        let types = extension
124            .into_types()
125            .into_values()
126            .map(Into::into)
127            .collect();
128
129        RawExtensions {
130            urn: urn.to_string(),
131            aggregate_functions: vec![],
132            dependencies: IndexMap::new(),
133            scalar_functions: vec![],
134            type_variations: vec![],
135            types,
136            window_functions: vec![],
137        }
138    }
139}