// substrait/parse/text/simple_extensions/extensions.rs
use indexmap::IndexMap;
use std::collections::{HashMap, HashSet};
use std::str::FromStr;

use super::{SimpleExtensionsError, scalar_functions::ScalarFunction, types::CustomType};
use crate::{
    parse::{Context, Parse},
    text::simple_extensions::SimpleExtensions as RawExtensions,
    urn::Urn,
};

19#[derive(Clone, Debug, Default)]
25pub struct SimpleExtensions {
26 types: HashMap<String, CustomType>,
28 scalar_functions: HashMap<String, ScalarFunction>,
32}
33
34impl SimpleExtensions {
35 pub fn add_type(&mut self, custom_type: &CustomType) -> Result<(), SimpleExtensionsError> {
37 if self.types.contains_key(&custom_type.name) {
38 return Err(SimpleExtensionsError::DuplicateTypeName {
39 name: custom_type.name.clone(),
40 });
41 }
42
43 self.types
44 .insert(custom_type.name.clone(), custom_type.clone());
45 Ok(())
46 }
47
48 pub fn get_type(&self, name: &str) -> Option<&CustomType> {
50 self.types.get(name)
51 }
52
53 pub fn types(&self) -> impl Iterator<Item = &CustomType> {
55 self.types.values()
56 }
57
58 pub(crate) fn into_types(self) -> HashMap<String, CustomType> {
60 self.types
61 }
62
63 pub(super) fn add_scalar_function(&mut self, scalar_function: ScalarFunction) {
70 use std::collections::hash_map::Entry;
71 match self.scalar_functions.entry(scalar_function.name.clone()) {
72 Entry::Vacant(e) => {
73 e.insert(scalar_function);
74 }
75 Entry::Occupied(mut e) => {
76 Self::merge_scalar_function(e.get_mut(), scalar_function);
77 }
78 }
79 }
80
81 fn merge_scalar_function(existing: &mut ScalarFunction, new: ScalarFunction) {
87 existing.impls.extend(new.impls);
88 existing.description = existing.description.take().or(new.description);
89 }
90
91 pub fn get_scalar_function(&self, name: &str) -> Option<&ScalarFunction> {
93 self.scalar_functions.get(name)
94 }
95
96 pub fn scalar_functions(&self) -> impl Iterator<Item = &ScalarFunction> {
98 self.scalar_functions.values()
99 }
100}
101
/// Bookkeeping for custom type names seen while parsing an extension document.
///
/// A name reported through [`TypeContext::found`] becomes known. A name reported through
/// [`TypeContext::linked`] stays pending until it is found.
#[derive(Default, Debug)]
pub(crate) struct TypeContext {
    /// Names reported via `found`.
    known: HashSet<String>,
    /// Names reported via `linked` that have not been reported via `found` since.
    linked: HashSet<String>,
}

impl TypeContext {
    /// Marks `name` as known and clears any pending reference to it.
    pub fn found(&mut self, name: &str) {
        let _ = self.linked.remove(name);
        let _ = self.known.insert(String::from(name));
    }

    /// Records a reference to `name`. It stays pending only if `name` is not already known.
    pub fn linked(&mut self, name: &str) {
        if self.known.contains(name) {
            return;
        }
        self.linked.insert(String::from(name));
    }
}

127impl Context for TypeContext {}
128
129impl Parse<TypeContext> for RawExtensions {
131 type Parsed = (Urn, SimpleExtensions);
132 type Error = super::SimpleExtensionsError;
133
134 fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
135 let RawExtensions {
136 urn,
137 types,
138 scalar_functions,
139 ..
140 } = self;
141 let urn = Urn::from_str(&urn)?;
142 let mut extension = SimpleExtensions::default();
143
144 for type_item in types {
145 let custom_type = Parse::parse(type_item, ctx)?;
146 extension.add_type(&custom_type)?;
147 }
148
149 for scalar_fn in scalar_functions {
150 match ScalarFunction::from_raw(scalar_fn, ctx) {
151 Ok(parsed_fn) => {
152 extension.add_scalar_function(parsed_fn);
153 }
154 Err(super::scalar_functions::ScalarFunctionError::NotYetImplemented(_)) => {
155 continue;
158 }
159 Err(e) => return Err(e.into()),
160 }
161 }
162
163 if let Some(missing) = ctx.linked.iter().next() {
164 return Err(super::SimpleExtensionsError::UnresolvedTypeReference {
166 type_name: missing.clone(),
167 });
168 }
169
170 Ok((urn, extension))
171 }
172}
173
174impl From<(Urn, SimpleExtensions)> for RawExtensions {
175 fn from((urn, extension): (Urn, SimpleExtensions)) -> Self {
176 let types = extension
177 .into_types()
178 .into_values()
179 .map(Into::into)
180 .collect();
181
182 RawExtensions {
183 urn: urn.to_string(),
184 aggregate_functions: vec![],
185 dependencies: IndexMap::new(),
186 scalar_functions: vec![],
187 type_variations: vec![],
188 types,
189 window_functions: vec![],
190 }
191 }
192}