substrait/parse/text/simple_extensions/
extensions.rs1use indexmap::IndexMap;
9use std::collections::{HashMap, HashSet};
10use std::str::FromStr;
11
12use super::{SimpleExtensionsError, types::CustomType};
13use crate::{
14 parse::{Context, Parse},
15 text::simple_extensions::SimpleExtensions as RawExtensions,
16 urn::Urn,
17};
18
19#[derive(Clone, Debug, Default)]
28pub struct SimpleExtensions {
29 types: HashMap<String, CustomType>,
31}
32
33impl SimpleExtensions {
34 pub fn add_type(&mut self, custom_type: &CustomType) -> Result<(), SimpleExtensionsError> {
36 if self.types.contains_key(&custom_type.name) {
37 return Err(SimpleExtensionsError::DuplicateTypeName {
38 name: custom_type.name.clone(),
39 });
40 }
41
42 self.types
43 .insert(custom_type.name.clone(), custom_type.clone());
44 Ok(())
45 }
46
47 pub fn has_type(&self, name: &str) -> bool {
49 self.types.contains_key(name)
50 }
51
52 pub fn get_type(&self, name: &str) -> Option<&CustomType> {
54 self.types.get(name)
55 }
56
57 pub fn types(&self) -> impl Iterator<Item = &CustomType> {
59 self.types.values()
60 }
61
62 pub(crate) fn into_types(self) -> HashMap<String, CustomType> {
64 self.types
65 }
66}
67
/// Bookkeeping for type names that have been defined or referenced.
#[derive(Debug, Default)]
pub(crate) struct TypeContext {
    /// Names that have been reported through [`TypeContext::found`].
    known: HashSet<String>,
    /// Names reported through [`TypeContext::linked`] while not yet known, and not
    /// reported as found since.
    linked: HashSet<String>,
}

impl TypeContext {
    /// Records `name` as known and clears any pending reference to it.
    pub fn found(&mut self, name: &str) {
        self.known.insert(name.to_owned());
        self.linked.remove(name);
    }

    /// Records a reference to `name`; it is kept as pending only if `name` is not
    /// already known.
    pub fn linked(&mut self, name: &str) {
        if self.known.contains(name) {
            return;
        }
        self.linked.insert(name.to_owned());
    }
}
92
93impl Context for TypeContext {}
94
95impl Parse<TypeContext> for RawExtensions {
97 type Parsed = (Urn, SimpleExtensions);
98 type Error = super::SimpleExtensionsError;
99
100 fn parse(self, ctx: &mut TypeContext) -> Result<Self::Parsed, Self::Error> {
101 let RawExtensions { urn, types, .. } = self;
102 let urn = Urn::from_str(&urn)?;
103 let mut extension = SimpleExtensions::default();
104
105 for type_item in types {
106 let custom_type = Parse::parse(type_item, ctx)?;
107 extension.add_type(&custom_type)?;
108 }
109
110 if let Some(missing) = ctx.linked.iter().next() {
111 return Err(super::SimpleExtensionsError::UnresolvedTypeReference {
113 type_name: missing.clone(),
114 });
115 }
116
117 Ok((urn, extension))
118 }
119}
120
121impl From<(Urn, SimpleExtensions)> for RawExtensions {
122 fn from((urn, extension): (Urn, SimpleExtensions)) -> Self {
123 let types = extension
124 .into_types()
125 .into_values()
126 .map(Into::into)
127 .collect();
128
129 RawExtensions {
130 urn: urn.to_string(),
131 aggregate_functions: vec![],
132 dependencies: IndexMap::new(),
133 scalar_functions: vec![],
134 type_variations: vec![],
135 types,
136 window_functions: vec![],
137 }
138 }
139}