substrait/parse/text/simple_extensions/file.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use super::{CustomType, SimpleExtensions, SimpleExtensionsError};
4use crate::parse::Parse;
5use crate::parse::text::simple_extensions::extensions::TypeContext;
6use crate::text::simple_extensions::SimpleExtensions as RawExtensions;
7use crate::urn::Urn;
8use std::io::Read;
9
10/// A parsed and validated [`RawExtensions`]: a simple extensions file.
11///
12/// An [`ExtensionFile`] has a canonical [`Urn`] and a parsed set of
13/// [`SimpleExtensions`] data. It represents the extensions file as a whole.
14#[derive(Debug)]
15pub struct ExtensionFile {
16    /// The URN this extension was loaded from
17    pub(crate) urn: Urn,
18    /// The extension data containing types and eventually functions
19    pub(crate) extension: SimpleExtensions,
20}
21
22impl ExtensionFile {
23    /// Create a new, empty [`ExtensionFile`] with an empty set of [`SimpleExtensions`].
24    pub fn empty(urn: Urn) -> Self {
25        let extension = SimpleExtensions::default();
26        Self { urn, extension }
27    }
28
29    /// Create an [`ExtensionFile`] from raw simple extension data.
30    pub fn create(extensions: RawExtensions) -> Result<Self, SimpleExtensionsError> {
31        // Parse all types (may contain unresolved Extension(String) references)
32        let mut ctx = TypeContext::default();
33        let (urn, extension) = Parse::parse(extensions, &mut ctx)?;
34
35        // TODO: Use ctx.known/ctx.linked to validate unresolved references and cross-file links.
36
37        Ok(Self { urn, extension })
38    }
39
40    /// Get a type by name
41    pub fn get_type(&self, name: &str) -> Option<&CustomType> {
42        self.extension.get_type(name)
43    }
44
45    /// Get an iterator over all types in this extension
46    pub fn types(&self) -> impl Iterator<Item = &CustomType> {
47        self.extension.types()
48    }
49
50    /// Returns the [`Urn`]` for this extension file.
51    pub fn urn(&self) -> &Urn {
52        &self.urn
53    }
54
55    /// Get a reference to the underlying [`SimpleExtensions`].
56    pub fn extension(&self) -> &SimpleExtensions {
57        &self.extension
58    }
59
60    /// Convert the parsed extension file back into the raw text representation
61    /// by value.
62    pub fn into_raw(self) -> RawExtensions {
63        let ExtensionFile { urn, extension } = self;
64        RawExtensions::from((urn, extension))
65    }
66
67    /// Convert the parsed extension file back into the raw text representation
68    /// by reference.
69    pub fn to_raw(&self) -> RawExtensions {
70        RawExtensions::from((self.urn.clone(), self.extension.clone()))
71    }
72
73    /// Read an extension file from a reader.
74    /// - `reader`: any [`Read`] instance with the YAML content
75    ///
76    /// Returns a parsed and validated [`ExtensionFile`] or an error.
77    pub fn read<R: Read>(reader: R) -> Result<Self, SimpleExtensionsError> {
78        let raw: RawExtensions = serde_yaml::from_reader(reader)?;
79        Self::create(raw)
80    }
81
82    /// Read an extension file from a string slice.
83    pub fn read_from_str<S: AsRef<str>>(s: S) -> Result<Self, SimpleExtensionsError> {
84        let raw: RawExtensions = serde_yaml::from_str(s.as_ref())?;
85        Self::create(raw)
86    }
87}
88
// The `Parse` implementation for `RawExtensions` and the `From<(Urn, SimpleExtensions)>`
// conversion back to `RawExtensions` are defined in `extensions.rs`.
90
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::text::simple_extensions::types::ParameterConstraint;

    /// One type (`ParamTest`) with a single integer parameter bounded to 1..=10.
    const YAML_PARAM_TEST: &str = r#"
%YAML 1.2
---
urn: extension:example.com:param_test
types:
  - name: "ParamTest"
    parameters:
      - name: "K"
        type: integer
        min: 1
        max: 10
"#;

    /// One type whose structure refers to `MissingType`, which is never defined.
    const YAML_UNRESOLVED_TYPE: &str = r#"
%YAML 1.2
---
urn: extension:example.com:unresolved
types:
  - name: "Alias"
    structure: List<Map<string, MissingType>>
"#;

    /// Integer parameter bounds survive parsing, and converting back to the raw
    /// representation reproduces the original deserialized data.
    #[test]
    fn yaml_round_trip_integer_param_bounds() {
        let raw: RawExtensions = serde_yaml::from_str(YAML_PARAM_TEST).expect("parse ok");
        let file = ExtensionFile::create(raw.clone()).expect("create ok");
        assert_eq!(file.urn().to_string(), "extension:example.com:param_test");

        let custom = file.get_type("ParamTest").expect("type exists");
        let only_param = match &custom.parameters[..] {
            [single] => single,
            other => panic!("unexpected parameters: {other:?}"),
        };
        match &only_param.param_type {
            ParameterConstraint::Integer { min, max } => {
                assert_eq!(min, &Some(1));
                assert_eq!(max, &Some(10));
            }
            other => panic!("unexpected param type: {other:?}"),
        }

        let round_tripped = file.to_raw();
        assert_eq!(raw, round_tripped);
    }

    /// A structure that names an undefined type is rejected, and the error
    /// reports the missing type's name.
    #[test]
    fn unresolved_type_reference_errors() {
        let outcome = ExtensionFile::read_from_str(YAML_UNRESOLVED_TYPE);
        match outcome.expect_err("expected unresolved type reference error") {
            SimpleExtensionsError::UnresolvedTypeReference { type_name } => {
                assert_eq!(type_name, "MissingType");
            }
            unexpected => panic!("unexpected error type: {unexpected:?}"),
        }
    }
}