diff --git a/crates/bindings-macro/src/sats.rs b/crates/bindings-macro/src/sats.rs index 8bad39da837..1902592fcbd 100644 --- a/crates/bindings-macro/src/sats.rs +++ b/crates/bindings-macro/src/sats.rs @@ -347,7 +347,8 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { de_generics.params.insert(0, de_lt_param.into()); let (de_impl_generics, _, de_where_clause) = de_generics.split_for_impl(); - let (iter_n, iter_n2, iter_n3, iter_n4) = (0usize.., 0usize.., 0usize.., 0usize..); + let (iter_n, iter_n2, iter_n3, iter_n4, iter_n5, iter_n6, iter_n7) = + (0usize.., 0usize.., 0usize.., 0usize.., 0usize.., 0usize.., 0usize..); match &ty.data { SatsTypeData::Product(fields) => { @@ -382,8 +383,10 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { let field_names = fields.iter().map(|f| f.ident.unwrap()).collect::>(); let field_strings = fields.iter().map(|f| f.name.as_deref().unwrap()).collect::>(); - let field_types = fields.iter().map(|f| &f.ty); + let field_types = fields.iter().map(|f| f.ty); let field_types2 = field_types.clone(); + let field_types3 = field_types.clone(); + let field_types4 = field_types.clone(); quote! { #[allow(non_camel_case_types)] #[allow(clippy::all)] @@ -396,6 +399,12 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { _marker: std::marker::PhantomData:: #name #ty_generics>, }) } + + fn validate>(deserializer: D) -> Result<(), D::Error> { + deserializer.validate_product(__ProductVisitor { + _marker: std::marker::PhantomData:: #name #ty_generics>, + }) + } } struct __ProductVisitor #impl_generics #where_clause { @@ -419,6 +428,13 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { .ok_or_else(|| #spacetimedb_lib::de::Error::invalid_product_length(#iter_n, &self))?,)* }) } + fn validate_seq_product>(self, mut tup: A) -> Result<(), A::Error> { + #( + tup.validate_next_element::<#field_types2>()? 
+ .ok_or_else(|| #spacetimedb_lib::de::Error::invalid_product_length(#iter_n2, &self))?; + )* + Ok(()) + } fn visit_named_product>(self, mut __prod: A) -> Result { #(let mut #field_names = None;)* while let Some(__field) = #spacetimedb_lib::de::NamedProductAccess::get_field_ident(&mut __prod, Self { @@ -427,17 +443,39 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { match __field { #(__ProductFieldIdent::#field_names => { if #field_names.is_some() { - return Err(#spacetimedb_lib::de::Error::duplicate_field(#iter_n2, Some(#field_strings), &self)) + return Err(#spacetimedb_lib::de::Error::duplicate_field(#iter_n3, Some(#field_strings), &self)) } - #field_names = Some(#spacetimedb_lib::de::NamedProductAccess::get_field_value::<#field_types2>(&mut __prod)?) + #field_names = Some(#spacetimedb_lib::de::NamedProductAccess::get_field_value::<#field_types3>(&mut __prod)?) })* } } Ok(#name { #(#field_names: - #field_names.ok_or_else(|| #spacetimedb_lib::de::Error::missing_field(#iter_n3, Some(#field_strings), &self))?,)* + #field_names.ok_or_else(|| #spacetimedb_lib::de::Error::missing_field(#iter_n4, Some(#field_strings), &self))?,)* }) } + fn validate_named_product>(self, mut __prod: A) -> Result<(), A::Error> { + #(let mut #field_names = false;)* + while let Some(__field) = #spacetimedb_lib::de::NamedProductAccess::get_field_ident(&mut __prod, Self { + _marker: std::marker::PhantomData, + })? 
{ + match __field { + #(__ProductFieldIdent::#field_names => { + if #field_names { + return Err(#spacetimedb_lib::de::Error::duplicate_field(#iter_n5, Some(#field_strings), &self)) + } + #spacetimedb_lib::de::NamedProductAccess::validate_field_value::<#field_types4>(&mut __prod)?; + #field_names = true; + })* + } + } + #( + if !#field_names { + return Err(#spacetimedb_lib::de::Error::missing_field(#iter_n6, Some(#field_strings), &self)); + } + )* + Ok(()) + } } impl #de_impl_generics #spacetimedb_lib::de::FieldNameVisitor<'de> for __ProductVisitor #ty_generics #de_where_clause { @@ -456,7 +494,7 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { fn visit_seq(self, index: usize) -> Self::Output { match index { - #(#iter_n4 => __ProductFieldIdent::#field_names,)* + #(#iter_n7 => __ProductFieldIdent::#field_names,)* _ => core::unreachable!(), } } @@ -488,6 +526,18 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { } } }); + let arms_validate = variants.iter().map(|var| { + let ident = var.ident; + if let Some(ty) = var.ty { + quote! { + __Variant::#ident => #spacetimedb_lib::de::VariantAccess::validate::<#ty>(__access)?, + } + } else { + quote! { + __Variant::#ident => #spacetimedb_lib::de::VariantAccess::validate::<()>(__access)?, + } + } + }); quote! 
{ #[allow(clippy::all)] const _: () = { @@ -497,6 +547,12 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { _marker: std::marker::PhantomData:: #name #ty_generics>, }) } + + fn validate>(deserializer: D) -> Result<(), D::Error> { + deserializer.validate_sum(__SumVisitor { + _marker: std::marker::PhantomData:: #name #ty_generics>, + }) + } } struct __SumVisitor #impl_generics #where_clause { @@ -516,6 +572,14 @@ pub(crate) fn derive_deserialize(ty: &SatsType<'_>) -> TokenStream { #(#arms)* } } + + fn validate_sum>(self, __data: A) -> Result<(), A::Error> { + let (__variant, __access) = __data.variant(self)?; + match __variant { + #(#arms_validate)* + } + Ok(()) + } } #[allow(non_camel_case_types)] diff --git a/crates/bindings/tests/ui/tables.stderr b/crates/bindings/tests/ui/tables.stderr index bbb002775dd..4fe52b766e7 100644 --- a/crates/bindings/tests/ui/tables.stderr +++ b/crates/bindings/tests/ui/tables.stderr @@ -63,6 +63,36 @@ note: required by a bound in `spacetimedb::spacetimedb_lib::de::SeqProductAccess | fn next_element>(&mut self) -> Result, Self::Error> { | ^^^^^^^^^^^^^^^^ required by this bound in `SeqProductAccess::next_element` +error[E0277]: the trait bound `Test: Deserialize<'de>` is not satisfied + --> tests/ui/tables.rs:5:8 + | +3 | #[spacetimedb::table(accessor = table)] + | --------------------------------------- required by a bound introduced by this call +4 | struct Table { +5 | x: Test, + | ^^^^ unsatisfied trait bound + | +help: the trait `Deserialize<'de>` is not implemented for `Test` + --> tests/ui/tables.rs:1:1 + | +1 | struct Test; + | ^^^^^^^^^^^ + = help: the following other types implement trait `Deserialize<'de>`: + &'de [u8] + &'de str + () + (T0, T1) + (T0, T1, T2) + (T0, T1, T2, T3) + (T0, T1, T2, T3, T4) + (T0, T1, T2, T3, T4, T5) + and $N others +note: required by a bound in `spacetimedb::spacetimedb_lib::de::SeqProductAccess::validate_next_element` + --> $WORKSPACE/crates/sats/src/de.rs + | + | fn 
validate_next_element>(&mut self) -> Result, Self::Error> { + | ^^^^^^^^^^^^^^^^ required by this bound in `SeqProductAccess::validate_next_element` + error[E0277]: the trait bound `Test: Deserialize<'_>` is not satisfied --> tests/ui/tables.rs:5:8 | @@ -90,6 +120,33 @@ note: required by a bound in `get_field_value` | fn get_field_value>(&mut self) -> Result { | ^^^^^^^^^^^^^^^^ required by this bound in `NamedProductAccess::get_field_value` +error[E0277]: the trait bound `Test: Deserialize<'_>` is not satisfied + --> tests/ui/tables.rs:5:8 + | +5 | x: Test, + | ^^^^ unsatisfied trait bound + | +help: the trait `Deserialize<'_>` is not implemented for `Test` + --> tests/ui/tables.rs:1:1 + | +1 | struct Test; + | ^^^^^^^^^^^ + = help: the following other types implement trait `Deserialize<'de>`: + &'de [u8] + &'de str + () + (T0, T1) + (T0, T1, T2) + (T0, T1, T2, T3) + (T0, T1, T2, T3, T4) + (T0, T1, T2, T3, T4, T5) + and $N others +note: required by a bound in `validate_field_value` + --> $WORKSPACE/crates/sats/src/de.rs + | + | fn validate_field_value>(&mut self) -> Result<(), Self::Error> { + | ^^^^^^^^^^^^^^^^ required by this bound in `NamedProductAccess::validate_field_value` + error[E0277]: the trait bound `Test: Serialize` is not satisfied --> tests/ui/tables.rs:5:8 | diff --git a/crates/core/src/host/v8/de.rs b/crates/core/src/host/v8/de.rs index a24987402a8..0445cdc9faf 100644 --- a/crates/core/src/host/v8/de.rs +++ b/crates/core/src/host/v8/de.rs @@ -152,6 +152,28 @@ impl<'de, 'this, 'scope: 'de> de::Deserializer<'de> for Deserializer<'this, 'sco }) } + fn validate_product>(self, visitor: V) -> Result<(), Self::Error> { + // In `ProductType.serializeValue()` in the TS SDK, null/undefined is accepted for the unit type. 
+ if visitor.product_len() == 0 && self.input.is_null_or_undefined() { + return visitor.validate_seq_product(de::UnitAccess::new()); + } + + let object = cast!( + self.common.scope, + self.input, + Object, + "object for product type `{}`", + visitor.product_name().unwrap_or("") + )?; + + visitor.validate_named_product(ProductAccess { + common: self.common, + object, + next_value: None, + index: 0, + }) + } + fn deserialize_sum>(self, visitor: V) -> Result { let scope = &*self.common.scope; @@ -302,6 +324,17 @@ impl<'de, 'scope: 'de> de::NamedProductAccess<'de> for ProductAccess<'_, 'scope, // Deserialize the field's value. seed.deserialize(Deserializer { common, input }) } + + fn validate_field_value_seed>(&mut self, seed: T) -> Result<(), Self::Error> { + let common = self.common.reborrow(); + // Extract the field's value. + let input = self + .next_value + .take() + .expect("Call next_key_seed before next_value_seed"); + // Validate the field's value. + seed.validate(Deserializer { common, input }) + } } /// Used in `Deserializer::deserialize_sum` to translate a `tag` property of a JS object @@ -367,6 +400,23 @@ where index: 0, } } + + fn next_elem<'a>(&'a mut self) -> Option), Error<'scope>>> { + self.seeds.next().map(move |seed| { + // Extract the array element. + let input = self + .arr + .get_index(self.common.scope, self.index) + .ok_or_else(exception_already_thrown)?; + + // Make the deserializer. + let common = self.common.reborrow(); + let de = Deserializer { common, input }; + + self.index += 1; + Ok((seed, de)) + }) + } } impl<'de, 'scope: 'de, T: DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for ArrayAccess<'_, 'scope, '_, T> { @@ -374,24 +424,14 @@ impl<'de, 'scope: 'de, T: DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for type Error = Error<'scope>; fn next_element(&mut self) -> Result, Self::Error> { - self.seeds - .next() - .map(|seed| { - // Extract the array element. 
- let val = self - .arr - .get_index(self.common.scope, self.index) - .ok_or_else(exception_already_thrown)?; - - // Deserialize the element. - let val = seed.deserialize(Deserializer { - common: self.common.reborrow(), - input: val, - })?; - - self.index += 1; - Ok(val) - }) + self.next_elem() + .map(|res| res.and_then(|(seed, de)| seed.deserialize(de))) + .transpose() + } + + fn validate_next_element(&mut self) -> Result, Self::Error> { + self.next_elem() + .map(|res| res.and_then(|(seed, de)| seed.validate(de))) .transpose() } diff --git a/crates/sats/src/algebraic_value/de.rs b/crates/sats/src/algebraic_value/de.rs index cd2de61c5ad..dcfeed14b51 100644 --- a/crates/sats/src/algebraic_value/de.rs +++ b/crates/sats/src/algebraic_value/de.rs @@ -58,11 +58,21 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer { visitor.visit_seq_product(ProductAccess { vals }) } + fn validate_product>(self, visitor: V) -> Result<(), Self::Error> { + let vals = map_err(self.val.into_product())?.into_iter(); + visitor.validate_seq_product(ProductAccess { vals }) + } + fn deserialize_sum>(self, visitor: V) -> Result { let sum = map_err(self.val.into_sum())?; visitor.visit_sum(SumAccess { sum }) } + fn validate_sum>(self, visitor: V) -> Result<(), Self::Error> { + let sum = map_err(self.val.into_sum())?; + visitor.validate_sum(SumAccess { sum }) + } + fn deserialize_bool(self) -> Result { map_err(self.val.into_bool()) } @@ -139,6 +149,15 @@ impl<'de> de::Deserializer<'de> for ValueDeserializer { let iter = map_err(self.val.into_array())?.into_iter(); visitor.visit(ArrayAccess { iter, seed }) } + + fn validate_array_seed, T: de::DeserializeSeed<'de> + Clone>( + self, + visitor: V, + seed: T, + ) -> Result<(), Self::Error> { + let iter = map_err(self.val.into_array())?.into_iter(); + visitor.validate(ArrayAccess { iter, seed }) + } } /// Defines deserialization for [`ValueDeserializer`] where product elements are in the input. 
@@ -156,6 +175,13 @@ impl<'de> de::SeqProductAccess<'de> for ProductAccess { .map(|val| seed.deserialize(ValueDeserializer { val })) .transpose() } + + fn validate_next_element_seed>(&mut self, seed: T) -> Result, Self::Error> { + self.vals + .next() + .map(|val| seed.validate(ValueDeserializer { val })) + .transpose() + } } /// Defines deserialization for [`ValueDeserializer`] where a sum value is in the input. @@ -191,6 +217,10 @@ impl<'de> de::VariantAccess<'de> for ValueDeserializer { fn deserialize_seed>(self, seed: T) -> Result { seed.deserialize(self) } + + fn validate_seed>(self, seed: T) -> Result<(), Self::Error> { + seed.validate(self) + } } /// Defines deserialization for [`ValueDeserializer`] where an array value is in the input. @@ -212,6 +242,13 @@ impl<'de, T: de::DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for ArrayAcc .map(|val| self.seed.clone().deserialize(ValueDeserializer { val })) .transpose() } + + fn validate_next_element(&mut self) -> Result, Self::Error> { + self.iter + .next() + .map(|val| self.seed.clone().validate(ValueDeserializer { val })) + .transpose() + } } impl<'de> de::Deserializer<'de> for &'de ValueDeserializer { @@ -222,11 +259,21 @@ impl<'de> de::Deserializer<'de> for &'de ValueDeserializer { visitor.visit_seq_product(RefProductAccess { vals }) } + fn validate_product>(self, visitor: V) -> Result<(), Self::Error> { + let vals = ok_or(self.val.as_product())?.elements.iter(); + visitor.validate_seq_product(RefProductAccess { vals }) + } + fn deserialize_sum>(self, visitor: V) -> Result { let sum = ok_or(self.val.as_sum())?; visitor.visit_sum(SumAccess::from_ref(sum)) } + fn validate_sum>(self, visitor: V) -> Result<(), Self::Error> { + let sum = ok_or(self.val.as_sum())?; + visitor.validate_sum(SumAccess::from_ref(sum)) + } + fn deserialize_bool(self) -> Result { ok_or(self.val.as_bool().copied()) } @@ -289,6 +336,15 @@ impl<'de> de::Deserializer<'de> for &'de ValueDeserializer { let iter = 
ok_or(self.val.as_array())?.iter_cloned(); visitor.visit(RefArrayAccess { iter, seed }) } + + fn validate_array_seed, T: de::DeserializeSeed<'de> + Clone>( + self, + visitor: V, + seed: T, + ) -> Result<(), Self::Error> { + let iter = ok_or(self.val.as_array())?.iter_cloned(); + visitor.validate(RefArrayAccess { iter, seed }) + } } /// Defines deserialization for [`&'de ValueDeserializer`] where product elements are in the input. @@ -306,6 +362,13 @@ impl<'de> de::SeqProductAccess<'de> for RefProductAccess<'de> { .map(|val| seed.deserialize(ValueDeserializer::from_ref(val))) .transpose() } + + fn validate_next_element_seed>(&mut self, seed: T) -> Result, Self::Error> { + self.vals + .next() + .map(|val| seed.validate(ValueDeserializer::from_ref(val))) + .transpose() + } } impl<'de> de::SumAccess<'de> for &'de SumAccess { @@ -325,6 +388,10 @@ impl<'de> de::VariantAccess<'de> for &'de ValueDeserializer { fn deserialize_seed>(self, seed: T) -> Result { seed.deserialize(self) } + + fn validate_seed>(self, seed: T) -> Result<(), Self::Error> { + seed.validate(self) + } } /// Defines deserialization for [`&'de ValueDeserializer`] where an array value is in the input. 
@@ -347,4 +414,11 @@ impl<'de, T: de::DeserializeSeed<'de> + Clone> de::ArrayAccess<'de> for RefArray .map(|val| self.seed.clone().deserialize(ValueDeserializer { val })) .transpose() } + + fn validate_next_element(&mut self) -> Result, Self::Error> { + self.iter + .next() + .map(|val| self.seed.clone().validate(ValueDeserializer { val })) + .transpose() + } } diff --git a/crates/sats/src/bsatn.rs b/crates/sats/src/bsatn.rs index 8f33f5b2e42..7cc35d27aea 100644 --- a/crates/sats/src/bsatn.rs +++ b/crates/sats/src/bsatn.rs @@ -228,11 +228,12 @@ pub const fn assert_is_primitive_type() {} #[cfg(test)] mod tests { - use super::{to_vec, DecodeError}; - use crate::proptest::generate_typed_value; - use crate::{meta_type::MetaType, AlgebraicType, AlgebraicValue}; + use super::{to_vec, DecodeError, Deserializer}; + use crate::de::DeserializeSeed; + use crate::proptest::{generate_algebraic_type, generate_typed_value}; + use crate::{meta_type::MetaType, AlgebraicType, AlgebraicValue, WithTypespace}; use proptest::prelude::*; - use proptest::proptest; + use proptest::{collection::vec, proptest}; #[test] fn type_to_binary_equivalent() { @@ -248,14 +249,32 @@ mod tests { assert_eq!(direct, through_value); } + fn type_non_empty(ty: &AlgebraicType) -> bool { + match ty { + AlgebraicType::Ref(_) => unreachable!(), + AlgebraicType::Array(elem_ty) => type_non_empty(&elem_ty.elem_ty), + AlgebraicType::Product(elems) => elems.iter().any(|e| type_non_empty(&e.algebraic_type)), + AlgebraicType::Sum(vars) => !vars.is_empty(), + _ => true, + } + } + proptest! 
{ #[test] fn bsatn_enc_de_roundtrips((ty, val) in generate_typed_value()) { let bytes = to_vec(&val).unwrap(); + prop_assert_eq!(WithTypespace::empty(&ty).validate(Deserializer::new(&mut &bytes[..])), Ok(())); let val_decoded = AlgebraicValue::decode(&ty, &mut &bytes[..]).unwrap(); prop_assert_eq!(val, val_decoded); } + #[test] + fn bsatn_invalid_wont_decode(ty in generate_algebraic_type(), bytes in vec(any::(), 0..4096)) { + prop_assume!(type_non_empty(&ty)); + prop_assume!(WithTypespace::empty(&ty).validate(Deserializer::new(&mut &bytes[..])).is_err()); + prop_assert!(AlgebraicValue::decode(&ty, &mut &bytes[..]).is_err()); + } + #[test] fn bsatn_non_zero_one_u8_aint_bool(val in 2u8..) { let bytes = [val]; diff --git a/crates/sats/src/bsatn/de.rs b/crates/sats/src/bsatn/de.rs index 4fdfad9950c..777e1041b39 100644 --- a/crates/sats/src/bsatn/de.rs +++ b/crates/sats/src/bsatn/de.rs @@ -57,10 +57,18 @@ impl<'de, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'_, R> { visitor.visit_seq_product(self) } + fn validate_product>(self, visitor: V) -> Result<(), Self::Error> { + visitor.validate_seq_product(self) + } + fn deserialize_sum>(self, visitor: V) -> Result { visitor.visit_sum(self) } + fn validate_sum>(self, visitor: V) -> Result<(), Self::Error> { + visitor.validate_sum(self) + } + fn deserialize_bool(self) -> Result { let byte = self.reader.get_u8()?; match byte { @@ -132,6 +140,16 @@ impl<'de, R: BufReader<'de>> de::Deserializer<'de> for Deserializer<'_, R> { let seeds = itertools::repeat_n(seed, len); visitor.visit(ArrayAccess { de: self, seeds }) } + + fn validate_array_seed, T: de::DeserializeSeed<'de> + Clone>( + mut self, + visitor: V, + seed: T, + ) -> Result<(), Self::Error> { + let len = self.reborrow().deserialize_len()?; + let seeds = itertools::repeat_n(seed, len); + visitor.validate(ArrayAccess { de: self, seeds }) + } } impl<'de, R: BufReader<'de>> SeqProductAccess<'de> for Deserializer<'_, R> { @@ -140,6 +158,10 @@ impl<'de, R: 
BufReader<'de>> SeqProductAccess<'de> for Deserializer<'_, R> { fn next_element_seed>(&mut self, seed: T) -> Result, DecodeError> { seed.deserialize(self.reborrow()).map(Some) } + + fn validate_next_element_seed>(&mut self, seed: T) -> Result, Self::Error> { + seed.validate(self.reborrow()).map(Some) + } } impl<'de, R: BufReader<'de>> SumAccess<'de> for Deserializer<'_, R> { @@ -157,6 +179,9 @@ impl<'de, R: BufReader<'de>> VariantAccess<'de> for Deserializer<'_, R> { fn deserialize_seed>(self, seed: T) -> Result { seed.deserialize(self) } + fn validate_seed>(self, seed: T) -> Result<(), Self::Error> { + seed.validate(self) + } } /// Deserializer for array elements. @@ -176,6 +201,13 @@ impl<'de, R: BufReader<'de>, T: de::DeserializeSeed<'de> + Clone> de::ArrayAcces .transpose() } + fn validate_next_element(&mut self) -> Result, Self::Error> { + self.seeds + .next() + .map(|seed| seed.validate(self.de.reborrow())) + .transpose() + } + fn size_hint(&self) -> Option { Some(self.seeds.len()) } diff --git a/crates/sats/src/buffer.rs b/crates/sats/src/buffer.rs index 17b58ec488c..2a60f4e1b76 100644 --- a/crates/sats/src/buffer.rs +++ b/crates/sats/src/buffer.rs @@ -2,9 +2,8 @@ //! without relying on types in third party libraries like `bytes::Bytes`, etc. //! Meant to be kept slim and trim for use across both native and WASM. -use bytes::{BufMut, BytesMut}; - use crate::{i256, u256}; +use bytes::{BufMut, BytesMut}; use core::cell::Cell; use core::fmt; use core::str::Utf8Error; @@ -30,6 +29,8 @@ pub enum DecodeError { Other(String), } +pub type DecodeResult = Result; + impl fmt::Display for DecodeError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { diff --git a/crates/sats/src/de.rs b/crates/sats/src/de.rs index 8ef760f5a13..e20315415fe 100644 --- a/crates/sats/src/de.rs +++ b/crates/sats/src/de.rs @@ -36,6 +36,11 @@ pub trait Deserializer<'de>: Sized { /// Deserializes a product value from the input. 
fn deserialize_product>(self, visitor: V) -> Result; + /// Validates a product value from the input. + fn validate_product>(self, visitor: V) -> Result<(), Self::Error> { + self.deserialize_product(visitor).map(|_| ()) + } + /// Deserializes a sum value from the input. /// /// The entire process of deserializing a sum, starting from `deserialize(args...)`, is roughly: @@ -69,6 +74,40 @@ pub trait Deserializer<'de>: Sized { /// that can deserialize the contents of the variant. fn deserialize_sum>(self, visitor: V) -> Result; + /// Validates a sum value from the input. + /// + /// The entire process of validating a sum, starting from `validate(args...)`, is roughly: + /// + /// - [`validate`][Deserialize::validate] calls this method, + /// [`validate_sum(sum_visitor)`](Deserializer::validate_sum), + /// providing us with a [`sum_visitor`](SumVisitor). + /// + /// - This method calls [`sum_visitor.validate_sum(sum_access)`](SumVisitor::validate_sum), + /// where [`sum_access`](SumAccess) deals with extracting the tag and the variant data, + /// with the latter provided as [`VariantAccess`]). + /// The `SumVisitor` will then assemble these into the representation of a sum value + /// that the [`Deserialize`] implementation wants. + /// + /// - [`validate_sum`](SumVisitor::validate_sum) then calls + /// [`sum_access.variant(variant_visitor)`](SumAccess::variant), + /// and uses the provided `variant_visitor` to translate extracted variant names / tags + /// into something that is meaningful for `validate_sum`, e.g., an index. + /// + /// The call to `variant` will also return [`variant_access`](VariantAccess) + /// that can validate the contents of the variant. + /// + /// - Finally, after `variant` returns, + /// `validate_sum` validates the variant data using + /// [`variant_access.validate_seed(seed)`](VariantAccess::validate_seed) + /// or [`variant_access.validate()`](VariantAccess::validate). 
+ /// This part may require some conditional logic depending on the identified variant. + /// + /// The data format will also return an object ([`VariantAccess`]) + /// that can validate the contents of the variant. + fn validate_sum>(self, visitor: V) -> Result<(), Self::Error> { + self.deserialize_sum(visitor).map(|_| ()) + } + /// Deserializes a `bool` value from the input. fn deserialize_bool(self) -> Result; @@ -144,6 +183,17 @@ pub trait Deserializer<'de>: Sized { visitor: V, seed: T, ) -> Result; + + /// Validates an array value. + /// + /// The validation is provided with a `seed` value. + fn validate_array_seed, T: DeserializeSeed<'de> + Clone>( + self, + visitor: V, + seed: T, + ) -> Result<(), Self::Error> { + self.deserialize_array_seed(visitor, seed).map(|_| ()) + } } /// The `Error` trait allows [`Deserialize`] implementations to create descriptive error messages @@ -266,7 +316,7 @@ fn fmt_invalid_len<'de>( } /// A visitor walking through a [`Deserializer`] for products. -pub trait ProductVisitor<'de> { +pub trait ProductVisitor<'de>: Sized { /// The resulting product. type Output; @@ -286,6 +336,16 @@ pub trait ProductVisitor<'de> { /// The input contains a named product. fn visit_named_product>(self, prod: A) -> Result; + + /// The input contains an unnamed product; validate it without keeping the output. + fn validate_seq_product>(self, prod: A) -> Result<(), A::Error> { + self.visit_seq_product(prod).map(|_| ()) + } + + /// The input contains a named product; validate it without keeping the output. + fn validate_named_product>(self, prod: A) -> Result<(), A::Error> { + self.visit_named_product(prod).map(|_| ()) + } } /// What kind of product is this? @@ -315,6 +375,14 @@ pub trait SeqProductAccess<'de> { self.next_element_seed(PhantomData) } + /// Statefully validates the next element of type `T` from the input. 
+ /// + /// Returns `Ok(Some(()))` for the next element in the unnamed product, + /// or `Ok(None)` if there are no more remaining items. 
+ fn validate_next_element>(&mut self) -> Result, Self::Error> { + self.validate_next_element_seed(PhantomData::) + } + /// Statefully deserializes `T::Output` from the input provided a `seed` value. /// /// Returns `Ok(Some(value))` for the next element in the unnamed product, @@ -323,6 +391,14 @@ pub trait SeqProductAccess<'de> { /// [`Deserialize`] implementations should typically use /// [`next_element`](SeqProductAccess::next_element) instead. fn next_element_seed>(&mut self, seed: T) -> Result, Self::Error>; + + /// Statefully validates `T::Output` from the input provided a `seed` value. + /// + /// Returns `Ok(Some(()))` for the next element in the unnamed product, + /// or `Ok(None)` if there are no more remaining items. + fn validate_next_element_seed>(&mut self, seed: T) -> Result, Self::Error> { + self.next_element_seed(seed).map(|opt| opt.map(|_| ())) + } } /// Provides a [`ProductVisitor`] with access to each element of the named product in the input. @@ -344,11 +420,24 @@ pub trait NamedProductAccess<'de> { self.get_field_value_seed(PhantomData) } + /// Validates a field value of type `T` from the input. + /// + /// This method exists as a convenience for [`Deserialize`] implementations. + /// [`NamedProductAccess`] implementations should not override the default behavior. + fn validate_field_value>(&mut self) -> Result<(), Self::Error> { + self.validate_field_value_seed(PhantomData::) + } + /// Statefully deserializes the field value `T::Output` from the input provided a `seed` value. /// /// [`Deserialize`] implementations should typically use /// [`next_element`](NamedProductAccess::get_field_value) instead. fn get_field_value_seed>(&mut self, seed: T) -> Result; + + /// Statefully validates the field value `T::Output` from the input provided a `seed` value. + fn validate_field_value_seed>(&mut self, seed: T) -> Result<(), Self::Error> { + self.get_field_value_seed(seed).map(|_| ()) + } } /// Visitor used to deserialize the name of a field. 
@@ -404,6 +493,17 @@ pub trait SumVisitor<'de> { /// The data format will also return an object ([`VariantAccess`]) /// that can deserialize the contents of the variant. fn visit_sum>(self, data: A) -> Result; + + /// Drives the validation of a sum value. + /// + /// This method will ask the data format ([`A: SumAccess`][SumAccess]) + /// which variant of the sum to select in terms of a variant name / tag. + /// `A` will use a [`VariantVisitor`], that `SumVisitor` has provided, + /// to translate into something that is meaningful for `validate_sum`, e.g., an index. + /// + /// The data format will also return an object ([`VariantAccess`]) + /// that can validate the contents of the variant. + fn validate_sum>(self, data: A) -> Result<(), A::Error>; } /// Provides a [`SumVisitor`] access to the data of a sum in the input. @@ -458,6 +558,18 @@ pub trait VariantAccess<'de>: Sized { /// Called when deserializing the contents of a sum variant, and provided with a `seed` value. fn deserialize_seed>(self, seed: T) -> Result; + + /// Called when validating the contents of a sum variant. + /// + /// This method exists as a convenience for [`Deserialize`] implementations. + fn validate>(self) -> Result<(), Self::Error> { + self.validate_seed(PhantomData::) + } + + /// Called when validating the contents of a sum variant, and provided with a `seed` value. + fn validate_seed>(self, seed: T) -> Result<(), Self::Error> { + self.deserialize_seed(seed).map(|_| ()) + } } /// A `SliceVisitor` is provided a slice `T` of some elements by a [`Deserializer`] @@ -484,12 +596,18 @@ pub trait SliceVisitor<'de, T: ToOwned + ?Sized>: Sized { } /// A visitor walking through a [`Deserializer`] for arrays. -pub trait ArrayVisitor<'de, T> { +pub trait ArrayVisitor<'de, T>: Sized { /// The output produced by this visitor. type Output; - /// The input contains an array. + /// The input contains an array, deserialize it. 
fn visit>(self, vec: A) -> Result; + + /// The input contains an array, but just validate it, don't deserialize. + fn validate>(self, vec: A) -> Result<(), A::Error> { + let _ = self.visit(vec)?; + Ok(()) + } } /// Provides an [`ArrayVisitor`] with access to each element of the array in the input. @@ -506,6 +624,13 @@ pub trait ArrayAccess<'de> { /// or `Ok(None)` if there are no more remaining elements. fn next_element(&mut self) -> Result, Self::Error>; + /// This returns `Ok(Some(()))` for the next element in the array, + /// or `Ok(None)` if there are no more remaining elements. + fn validate_next_element(&mut self) -> Result, Self::Error> { + let opt = self.next_element()?; + Ok(opt.map(|_| ())) + } + /// Returns the number of elements remaining in the array, if known. fn size_hint(&self) -> Option { None @@ -513,13 +638,23 @@ pub trait ArrayAccess<'de> { } /// `DeserializeSeed` is the stateful form of the [`Deserialize`] trait. -pub trait DeserializeSeed<'de> { +pub trait DeserializeSeed<'de>: Sized { /// The type produced by using this seed. type Output; /// Equivalent to the more common [`Deserialize::deserialize`] associated function, /// except with some initial piece of data (the seed `self`) passed in. fn deserialize>(self, deserializer: D) -> Result; + + /// Validate that the input is of the correct form for this seed. + /// + /// The default implementation simply deserializes the input and discards the result, + /// but implementations can override this to perform more efficient validation + /// without fully deserializing the input. 
+ fn validate>(self, deserializer: D) -> Result<(), D::Error> { + let _ = self.deserialize(deserializer)?; + Ok(()) + } } use crate::de::impls::BorrowedSliceVisitor; @@ -563,6 +698,18 @@ pub trait Deserialize<'de>: Sized { fn __deserialize_array, const N: usize>(deserializer: D) -> Result<[Self; N], D::Error> { deserializer.deserialize_array(BasicArrayVisitor) } + + #[doc(hidden)] + #[inline(always)] + /// Validate that the input is of the correct form for this type. + /// + /// The default implementation simply deserializes the input and discards the result, + /// but implementations can override this to perform more efficient validation + /// without fully deserializing the input. + fn validate>(deserializer: D) -> Result<(), D::Error> { + let _ = Self::deserialize(deserializer)?; + Ok(()) + } } /// A data structure that can be deserialized in SATS @@ -614,6 +761,12 @@ pub fn array_visit<'de, A: ArrayAccess<'de>, V: GrowingVec>(mut acce Ok(v) } +/// A basic implementation of `ArrayVisitor::validate`. +pub fn array_validate<'de, A: ArrayAccess<'de>>(mut access: A) -> Result<(), A::Error> { + while access.next_element()?.is_some() {} + Ok(()) +} + /// An implementation of [`ArrayVisitor<'de, T>`] where the output is a `Vec`. pub struct BasicVecVisitor; @@ -623,6 +776,10 @@ impl<'de, T> ArrayVisitor<'de, T> for BasicVecVisitor { fn visit>(self, vec: A) -> Result { array_visit(vec) } + + fn validate>(self, vec: A) -> Result<(), A::Error> { + array_validate(vec) + } } /// An implementation of [`ArrayVisitor<'de, T>`] where the output is a `SmallVec<[T; N]>`. @@ -634,6 +791,10 @@ impl<'de, T, const N: usize> ArrayVisitor<'de, T> for BasicSmallVecVisitor { fn visit>(self, vec: A) -> Result { array_visit(vec) } + + fn validate>(self, vec: A) -> Result<(), A::Error> { + array_validate(vec) + } } /// An implementation of [`ArrayVisitor<'de, T>`] where the output is a `[T; N]`. 
@@ -650,6 +811,23 @@ impl<'de, T, const N: usize> ArrayVisitor<'de, T> for BasicArrayVisitor { } v.into_inner().map_err(|_| Error::custom("too few elements for array")) } + + fn validate>(self, mut vec: A) -> Result<(), A::Error> { + // Validate each element and count. + let mut count = 0; + while vec.next_element()?.is_some() { + count += 1; + } + // Don't do this in the loop, + // as we bias towards there not being any errors. + if count > N { + return Err(Error::custom("too many elements for array")); + } + if count < N { + return Err(Error::custom("too few elements for array")); + } + Ok(()) + } } /// Provided a list of names, @@ -745,6 +923,9 @@ impl<'de, D: Deserializer<'de>> VariantAccess<'de> for SomeAccess { fn deserialize_seed>(self, seed: T) -> Result { seed.deserialize(self.0) } + fn validate_seed>(self, seed: T) -> Result<(), Self::Error> { + seed.validate(self.0) + } } /// A `Deserializer` that represents a unit value. diff --git a/crates/sats/src/de/impls.rs b/crates/sats/src/de/impls.rs index e0b77001224..922bd1866f6 100644 --- a/crates/sats/src/de/impls.rs +++ b/crates/sats/src/de/impls.rs @@ -3,7 +3,7 @@ use super::{ ProductKind, ProductVisitor, SeqProductAccess, SliceVisitor, SumAccess, SumVisitor, VariantAccess, VariantVisitor, }; use crate::{ - de::{array_visit, ArrayAccess, ArrayVisitor, GrowingVec}, + de::{array_validate, array_visit, ArrayAccess, ArrayVisitor, BasicArrayVisitor, GrowingVec}, AlgebraicType, AlgebraicValue, ArrayType, ArrayValue, ProductType, ProductTypeElement, ProductValue, SumType, SumValue, WithTypespace, F32, F64, }; @@ -30,9 +30,16 @@ use std::{borrow::Cow, rc::Rc, sync::Arc}; /// ``` #[macro_export] macro_rules! impl_deserialize { - ([$($generics:tt)*] $(where [$($wc:tt)*])? $typ:ty, $de:ident => $body:expr) => { + ( + [$($generics:tt)*] $(where [$($wc:tt)*])? $typ:ty, + $de:ident => $body:expr + $(, $validate_de:ident => $validate:expr)? 
+ ) => { impl<'de, $($generics)*> $crate::de::Deserialize<'de> for $typ { fn deserialize>($de: D) -> Result { $body } + $( + fn validate>($validate_de: D) -> Result<(), D::Error> { $validate } + )? } }; } @@ -106,6 +113,16 @@ macro_rules! impl_deserialize_tuple { Ok(($($ty_name,)*)) } + fn validate_seq_product>(self, mut _prod: A) -> Result<(), A::Error> { + $( + #[allow(non_snake_case)] + _prod + .validate_next_element_seed(PhantomData::<$ty_name>)? + .ok_or_else(|| Error::invalid_product_length($const_val, &self))?; + )* + + Ok(()) + } fn visit_named_product>(self, mut prod: A) -> Result { $( #[allow(non_snake_case)] @@ -128,6 +145,34 @@ macro_rules! impl_deserialize_tuple { $ty_name.ok_or_else(|| A::Error::missing_field($const_val, None, &self))?, )*)) } + fn validate_named_product>(self, mut prod: A) -> Result<(), A::Error> { + $( + #[allow(non_snake_case)] + let mut $ty_name = false; + )* + + let visit = TupleNameVisitorMax(self.product_len()); + while let Some(index) = prod.get_field_ident(visit)? 
{ + match index { + $($const_val => { + if $ty_name { + return Err(A::Error::duplicate_field($const_val, None, &self)) + } + prod.validate_field_value::<$ty_name>()?; + $ty_name = true; + })* + index => return Err(Error::invalid_product_length(index, &self)), + } + } + + $( + if !$ty_name { + return Err(A::Error::missing_field($const_val, None, &self)) + } + )* + + Ok(()) + } } impl_deserialize!([$($ty_name: Deserialize<'de>),*] ($($ty_name,)*), de => { @@ -168,17 +213,51 @@ impl<'de> Deserialize<'de> for u8 { impl_deserialize!([] F32, de => f32::deserialize(de).map(Into::into)); impl_deserialize!([] F64, de => f64::deserialize(de).map(Into::into)); -impl_deserialize!([] String, de => de.deserialize_str(OwnedSliceVisitor)); -impl_deserialize!([] LeanString, de => >::deserialize(de).map(|s| (&*s).into())); -impl_deserialize!([T: Deserialize<'de>] Vec, de => T::__deserialize_vec(de)); -impl_deserialize!([T: Deserialize<'de>, const N: usize] SmallVec<[T; N]>, de => { - de.deserialize_array(BasicSmallVecVisitor) -}); -impl_deserialize!([T: Deserialize<'de>, const N: usize] [T; N], de => T::__deserialize_array(de)); -impl_deserialize!([] Box, de => String::deserialize(de).map(|s| s.into_boxed_str())); -impl_deserialize!([T: Deserialize<'de>] Box<[T]>, de => Vec::deserialize(de).map(|s| s.into_boxed_slice())); -impl_deserialize!([T: Deserialize<'de>] Rc<[T]>, de => Vec::deserialize(de).map(|s| s.into())); -impl_deserialize!([T: Deserialize<'de>] Arc<[T]>, de => Vec::deserialize(de).map(|s| s.into())); +impl_deserialize!( + [] String, + de => de.deserialize_str(OwnedSliceVisitor), + de => <&str>::validate(de) +); +impl_deserialize!( + [] LeanString, + de => >::deserialize(de).map(|s| (&*s).into()), + de => <&str>::validate(de) +); +impl_deserialize!( + [T: Deserialize<'de>] Vec, + de => T::__deserialize_vec(de), + de => de.validate_array_seed(BasicVecVisitor, PhantomData::) +); +impl_deserialize!( + [T: Deserialize<'de>, const N: usize] SmallVec<[T; N]>, + de => 
de.deserialize_array(BasicSmallVecVisitor), + de => de.validate_array_seed(BasicVecVisitor, PhantomData::) +); +impl_deserialize!( + [T: Deserialize<'de>, const N: usize] [T; N], + de => T::__deserialize_array(de), + de => de.validate_array_seed(BasicArrayVisitor::, PhantomData::) +); +impl_deserialize!( + [] Box, + de => String::deserialize(de).map(|s| s.into_boxed_str()), + de => String::validate(de) +); +impl_deserialize!( + [T: Deserialize<'de>] Box<[T]>, + de => Vec::deserialize(de).map(|s| s.into_boxed_slice()), + de => Vec::::validate(de) +); +impl_deserialize!( + [T: Deserialize<'de>] Rc<[T]>, + de => Vec::deserialize(de).map(|s| s.into()), + de => Vec::::validate(de) +); +impl_deserialize!( + [T: Deserialize<'de>] Arc<[T]>, + de => Vec::deserialize(de).map(|s| s.into()), + de => Vec::::validate(de) +); /// The visitor converts the slice to its owned version. struct OwnedSliceVisitor; @@ -232,8 +311,16 @@ impl<'de, T: ToOwned + ?Sized + 'de> SliceVisitor<'de, T> for BorrowedSliceVisit } } -impl_deserialize!([] Cow<'de, str>, de => de.deserialize_str(CowSliceVisitor)); -impl_deserialize!([] Cow<'de, [u8]>, de => de.deserialize_bytes(CowSliceVisitor)); +impl_deserialize!( + [] Cow<'de, str>, + de => de.deserialize_str(CowSliceVisitor), + de => <&str>::validate(de) +); +impl_deserialize!( + [] Cow<'de, [u8]>, + de => de.deserialize_bytes(CowSliceVisitor), + de => <&[u8]>::validate(de) +); /// The visitor works with either owned or borrowed versions to produce `Cow<'de, T>`. struct CowSliceVisitor; @@ -254,7 +341,11 @@ impl<'de, T: ToOwned + ?Sized + 'de> SliceVisitor<'de, T> for CowSliceVisitor { } } -impl_deserialize!([T: Deserialize<'de>] Box, de => T::deserialize(de).map(Box::new)); +impl_deserialize!( + [T: Deserialize<'de>] Box, + de => T::deserialize(de).map(Box::new), + de => T::validate(de) +); impl_deserialize!([T: Deserialize<'de>] Option, de => de.deserialize_sum(OptionVisitor(PhantomData))); /// The visitor deserializes an `Option`. 
@@ -283,6 +374,18 @@ impl<'de, T: Deserialize<'de>> SumVisitor<'de> for OptionVisitor { None }) } + + fn validate_sum>(self, data: A) -> Result<(), A::Error> { + // Determine the variant. + let (some, data) = data.variant(self)?; + + // Validate contents for it. + if some { + data.validate::() + } else { + data.validate::<()>() + } + } } impl<'de, T: Deserialize<'de>> VariantVisitor<'de> for OptionVisitor { @@ -340,6 +443,14 @@ impl<'de, T: Deserialize<'de>, E: Deserialize<'de>> SumVisitor<'de> for ResultVi ResultVariant::Err => Err(data.deserialize()?), }) } + + fn validate_sum>(self, data: A) -> Result<(), A::Error> { + let (variant, data) = data.variant(self)?; + match variant { + ResultVariant::Ok => data.validate::(), + ResultVariant::Err => data.validate::(), + } + } } impl<'de, T: Deserialize<'de>, U: Deserialize<'de>> VariantVisitor<'de> for ResultVisitor { @@ -407,6 +518,18 @@ impl<'de, S: Copy + DeserializeSeed<'de>> SumVisitor<'de> for BoundVisitor { BoundVariant::Unbounded => data.deserialize::<()>().map(|_| Bound::Unbounded), } } + + fn validate_sum>(self, data: A) -> Result<(), A::Error> { + // Determine the variant. + let this = self.0; + let (variant, data) = data.variant(self)?; + + // Validate contents for it. 
+ match variant { + BoundVariant::Included | BoundVariant::Excluded => data.validate_seed(this), + BoundVariant::Unbounded => data.validate::<()>(), + } + } } impl<'de, T: Copy + DeserializeSeed<'de>> VariantVisitor<'de> for BoundVisitor { @@ -463,6 +586,31 @@ impl<'de> DeserializeSeed<'de> for WithTypespace<'_, AlgebraicType> { AlgebraicType::String => >::deserialize(de).map(Into::into), } } + + fn validate>(self, de: D) -> Result<(), D::Error> { + match self.ty() { + AlgebraicType::Ref(r) => self.resolve(*r).validate(de), + AlgebraicType::Sum(sum) => self.with(sum).validate(de), + AlgebraicType::Product(prod) => self.with(prod).validate(de), + AlgebraicType::Array(ty) => self.with(ty).validate(de), + AlgebraicType::Bool => bool::validate(de), + AlgebraicType::I8 => i8::validate(de), + AlgebraicType::U8 => u8::validate(de), + AlgebraicType::I16 => i16::validate(de), + AlgebraicType::U16 => u16::validate(de), + AlgebraicType::I32 => i32::validate(de), + AlgebraicType::U32 => u32::validate(de), + AlgebraicType::I64 => i64::validate(de), + AlgebraicType::U64 => u64::validate(de), + AlgebraicType::I128 => i128::validate(de), + AlgebraicType::U128 => u128::validate(de), + AlgebraicType::I256 => i256::validate(de), + AlgebraicType::U256 => u256::validate(de), + AlgebraicType::F32 => f32::validate(de), + AlgebraicType::F64 => f64::validate(de), + AlgebraicType::String => >::validate(de), + } + } } impl<'de> DeserializeSeed<'de> for WithTypespace<'_, SumType> { @@ -471,6 +619,10 @@ impl<'de> DeserializeSeed<'de> for WithTypespace<'_, SumType> { fn deserialize>(self, deserializer: D) -> Result { deserializer.deserialize_sum(self) } + + fn validate>(self, deserializer: D) -> Result<(), D::Error> { + deserializer.validate_sum(self) + } } impl<'de> SumVisitor<'de> for WithTypespace<'_, SumType> { @@ -492,6 +644,14 @@ impl<'de> SumVisitor<'de> for WithTypespace<'_, SumType> { let value = Box::new(data.deserialize_seed(variant_ty)?); Ok(SumValue { tag, value }) } + + fn 
validate_sum>(self, data: A) -> Result<(), A::Error> { + let (tag, data) = data.variant(self)?; + // Find the variant type by `tag`. + let variant_ty = self.map(|ty| &ty.variants[tag as usize].algebraic_type); + + data.validate_seed(variant_ty) + } } impl VariantVisitor<'_> for WithTypespace<'_, SumType> { @@ -529,6 +689,10 @@ impl<'de> DeserializeSeed<'de> for WithTypespace<'_, ProductType> { fn deserialize>(self, deserializer: D) -> Result { deserializer.deserialize_product(self.map(|pt| &*pt.elements)) } + + fn validate>(self, deserializer: D) -> Result<(), D::Error> { + deserializer.validate_product(self.map(|pt| &*pt.elements)) + } } impl<'de> DeserializeSeed<'de> for WithTypespace<'_, [ProductTypeElement]> { @@ -537,6 +701,10 @@ impl<'de> DeserializeSeed<'de> for WithTypespace<'_, [ProductTypeElement]> { fn deserialize>(self, deserializer: D) -> Result { deserializer.deserialize_product(self) } + + fn validate>(self, deserializer: D) -> Result<(), D::Error> { + deserializer.validate_product(self) + } } impl<'de> ProductVisitor<'de> for WithTypespace<'_, [ProductTypeElement]> { @@ -553,9 +721,17 @@ impl<'de> ProductVisitor<'de> for WithTypespace<'_, [ProductTypeElement]> { visit_seq_product(self, &self, tup) } + fn validate_seq_product>(self, prod: A) -> Result<(), A::Error> { + validate_seq_product(self, &self, prod) + } + fn visit_named_product>(self, tup: A) -> Result { visit_named_product(self, &self, tup) } + + fn validate_named_product>(self, prod: A) -> Result<(), A::Error> { + validate_named_product(self, &self, prod) + } } impl<'de> DeserializeSeed<'de> for WithTypespace<'_, ArrayType> { @@ -614,37 +790,46 @@ impl<'de> DeserializeSeed<'de> for WithTypespace<'_, ArrayType> { }; } } -} - -// impl<'de> DeserializeSeed<'de> for &ReducerDef { -// type Output = ProductValue; - -// fn deserialize>(self, deserializer: D) -> Result { -// deserializer.deserialize_product(self) -// } -// } - -// impl<'de> ProductVisitor<'de> for &ReducerDef { -// type Output = 
ProductValue; -// fn product_name(&self) -> Option<&str> { -// self.name.as_deref() -// } -// fn product_len(&self) -> usize { -// self.args.len() -// } -// fn product_kind(&self) -> ProductKind { -// ProductKind::ReducerArgs -// } + fn validate>(self, deserializer: D) -> Result<(), D::Error> { + /// Validate a vector for the appropriate `ArrayValue` variant. + fn val_array<'de, D: Deserializer<'de>, T: Deserialize<'de>>(de: D) -> Result<(), D::Error> { + de.validate_array_seed(BasicVecVisitor, PhantomData::) + } -// fn visit_seq_product>(self, tup: A) -> Result { -// visit_seq_product(&self.args, &self, tup) -// } + let mut ty = &*self.ty().elem_ty; -// fn visit_named_product>(self, tup: A) -> Result { -// visit_named_product(&self.args, &self, tup) -// } -// } + // Loop, resolving `Ref`s, until we reach a non-`Ref` type. + loop { + break match ty { + AlgebraicType::Ref(r) => { + // The only arm that will loop. + ty = self.resolve(*r).ty(); + continue; + } + AlgebraicType::Sum(ty) => deserializer.validate_array_seed(BasicVecVisitor, self.with(ty)), + AlgebraicType::Product(ty) => deserializer.validate_array_seed(BasicVecVisitor, self.with(ty)), + AlgebraicType::Array(ty) => deserializer.validate_array_seed(BasicVecVisitor, self.with(ty)), + &AlgebraicType::Bool => val_array::<_, bool>(deserializer), + &AlgebraicType::I8 => val_array::<_, i8>(deserializer), + &AlgebraicType::U8 => val_array::<_, u8>(deserializer), + &AlgebraicType::I16 => val_array::<_, i16>(deserializer), + &AlgebraicType::U16 => val_array::<_, u16>(deserializer), + &AlgebraicType::I32 => val_array::<_, i32>(deserializer), + &AlgebraicType::U32 => val_array::<_, u32>(deserializer), + &AlgebraicType::I64 => val_array::<_, i64>(deserializer), + &AlgebraicType::U64 => val_array::<_, u64>(deserializer), + &AlgebraicType::I128 => val_array::<_, i128>(deserializer), + &AlgebraicType::U128 => val_array::<_, u128>(deserializer), + &AlgebraicType::I256 => val_array::<_, i256>(deserializer), + 
&AlgebraicType::U256 => val_array::<_, u256>(deserializer), + &AlgebraicType::F32 => val_array::<_, f32>(deserializer), + &AlgebraicType::F64 => val_array::<_, f64>(deserializer), + &AlgebraicType::String => val_array::<_, String>(deserializer), + }; + } + } +} /// Deserialize, provided the fields' types, a product value with unnamed fields. pub fn visit_seq_product<'de, A: SeqProductAccess<'de>>( @@ -660,6 +845,19 @@ pub fn visit_seq_product<'de, A: SeqProductAccess<'de>>( Ok(ProductValue { elements }) } +/// Validate, provided the fields' types, a product value with unnamed fields. +pub fn validate_seq_product<'de, A: SeqProductAccess<'de>>( + elems: WithTypespace<[ProductTypeElement]>, + visitor: &impl ProductVisitor<'de>, + mut tup: A, +) -> Result<(), A::Error> { + for (i, el) in elems.ty().iter().enumerate() { + tup.validate_next_element_seed(elems.with(&el.algebraic_type))? + .ok_or_else(|| Error::invalid_product_length(i, visitor))?; + } + Ok(()) +} + /// Deserialize, provided the fields' types, a product value with named fields. pub fn visit_named_product<'de, A: super::NamedProductAccess<'de>>( elems_tys: WithTypespace<[ProductTypeElement]>, @@ -705,6 +903,46 @@ pub fn visit_named_product<'de, A: super::NamedProductAccess<'de>>( Ok(ProductValue { elements }) } +/// Validate, provided the fields' types, a product value with named fields. +pub fn validate_named_product<'de, A: super::NamedProductAccess<'de>>( + elems_tys: WithTypespace<[ProductTypeElement]>, + visitor: &impl ProductVisitor<'de>, + mut tup: A, +) -> Result<(), A::Error> { + let elems = elems_tys.ty(); + // TODO(perf): replace with bitset. + let mut elements = vec![false; elems.len()]; + let kind = visitor.product_kind(); + + // Deserialize a product value corresponding to each product type field. + // This is worst case quadratic in complexity + // as fields can be specified out of order (value side) compared to `elems` (type side). 
+ for _ in 0..elems.len() { + // Deserialize a field name, match against the element types. + let index = tup.get_field_ident(TupleNameVisitor { elems, kind })?.ok_or_else(|| { + // Couldn't deserialize a field name. + // Find the first field name we haven't filled an element for. + let missing = elements.iter().position(|&field| !field).unwrap(); + let field_name = elems[missing].name().map(|n| &**n); + Error::missing_field(missing, field_name, visitor) + })?; + + let element = &elems[index]; + + // By index we can select which element to deserialize a value for. + let slot = &mut elements[index]; + if *slot { + return Err(Error::duplicate_field(index, element.name().map(|n| &**n), visitor)); + } + + // Deserialize the value for this field's type. + tup.validate_field_value_seed(elems_tys.with(&element.algebraic_type))?; + *slot = true; + } + + Ok(()) +} + /// A visitor for extracting indices of field names in the elements of a [`ProductType`]. struct TupleNameVisitor<'a> { /// The elements of a product type, in order. 
@@ -768,19 +1006,35 @@ impl_deserialize!([] spacetimedb_primitives::ColList, de => { fn visit>(self, vec: A) -> Result { array_visit(vec) } + + fn validate>(self, vec: A) -> Result<(), A::Error> { + array_validate(vec) + } } de.deserialize_array(ColListVisitor) }); -impl_deserialize!([] spacetimedb_primitives::ColSet, de => ColList::deserialize(de).map(Into::into)); +impl_deserialize!( + [] spacetimedb_primitives::ColSet, + de => ColList::deserialize(de).map(Into::into), + de => ColList::validate(de) +); #[cfg(feature = "blake3")] impl_deserialize!([] blake3::Hash, de => <[u8; blake3::OUT_LEN]>::deserialize(de).map(blake3::Hash::from_bytes)); // TODO(perf): integrate Bytes with Deserializer to reduce copying -impl_deserialize!([] bytes::Bytes, de => >::deserialize(de).map(Into::into)); +impl_deserialize!( + [] bytes::Bytes, + de => >::deserialize(de).map(Into::into), + de => <&[u8]>::validate(de) +); #[cfg(feature = "bytestring")] -impl_deserialize!([] bytestring::ByteString, de => ::deserialize(de).map(Into::into)); +impl_deserialize!( + [] bytestring::ByteString, + de => ::deserialize(de).map(Into::into), + de => <&str>::validate(de) +); #[cfg(test)] mod test { diff --git a/crates/sats/src/layout.rs b/crates/sats/src/layout.rs index ac123de363d..b9e8f919899 100644 --- a/crates/sats/src/layout.rs +++ b/crates/sats/src/layout.rs @@ -959,9 +959,22 @@ impl<'de> ProductVisitor<'de> for ProductTypeLayoutView<'_> { Ok(elems.into()) } + fn validate_seq_product>(self, mut tup: A) -> Result<(), A::Error> { + for (i, elem_ty) in self.elements.iter().enumerate() { + if tup.validate_next_element_seed(&elem_ty.ty)?.is_none() { + return Err(A::Error::invalid_product_length(i, &self)); + } + } + Ok(()) + } + fn visit_named_product>(self, _: A) -> Result { unreachable!() } + + fn validate_named_product>(self, _: A) -> Result<(), A::Error> { + unreachable!() + } } impl<'de> DeserializeSeed<'de> for &SumTypeLayout { @@ -1000,6 +1013,14 @@ impl<'de> SumVisitor<'de> for 
&SumTypeLayout { let value = data.deserialize_seed(variant_ty)?; Ok(SumValue::new(tag, value)) } + + fn validate_sum>(self, data: A) -> Result<(), A::Error> { + let (tag, data) = data.variant(self)?; + // Find the variant type by `tag`. + let variant_ty = &self.variants[tag as usize].ty; + + data.validate_seed(variant_ty) + } } impl VariantVisitor<'_> for &SumTypeLayout {