diff --git a/conda.recipe/meta.yaml b/conda.recipe/meta.yaml index a69167a..88a554f 100755 --- a/conda.recipe/meta.yaml +++ b/conda.recipe/meta.yaml @@ -5,13 +5,13 @@ package: requirements: build: - python - - "marshmallow>=3.0.0rc4" + - "marshmallow>=4.0.0" - "numpy>=1.13" - "python-dateutil>=2.8.0" run: - python - - "marshmallow>=3.0.0rc4" + - "marshmallow>=4.0.0" - "numpy>=1.13" - "python-dateutil>=2.8.0" diff --git a/environment.yml b/environment.yml index 7903a13..5c09981 100644 --- a/environment.yml +++ b/environment.yml @@ -2,7 +2,7 @@ name: paramtools-dev channels: - conda-forge dependencies: - - "marshmallow>=3.22.0" + - "marshmallow>=4.0.0" - "numpy>=2.1.0" - "python-dateutil>=2.8.0" - "pytest>=6.0.0" diff --git a/paramtools/parameters.py b/paramtools/parameters.py index fe3e908..0830049 100644 --- a/paramtools/parameters.py +++ b/paramtools/parameters.py @@ -101,7 +101,7 @@ def __init__( else: self._stateless_label_grid[name] = [] self.label_grid = copy.deepcopy(self._stateless_label_grid) - self._validator_schema.context["spec"] = self + self._validator_schema.pt_context["spec"] = self self._warnings = {} self._errors = {} self._defer_validation = False @@ -364,7 +364,7 @@ def _adjust( for param, value in parsed_params.items(): self._update_param(param, value) - self._validator_schema.context["spec"] = self + self._validator_schema.pt_context["spec"] = self has_errors = bool(self._errors.get("messages")) has_warnings = bool(self._warnings.get("messages")) @@ -525,7 +525,7 @@ def _delete( if self.label_to_extend is not None and extend_adj: self.extend() - self._validator_schema.context["spec"] = self + self._validator_schema.pt_context["spec"] = self has_errors = bool(self._errors.get("messages")) has_warnings = bool(self._warnings.get("messages")) @@ -1414,4 +1414,4 @@ def get_defaults(self): - `params`: String if URL or file path. Dict if this is the loaded params dict. """ - return utils.read_json(self.defaults) \ No newline at end of file + return utils.read_json(self.defaults) diff --git a/paramtools/schema.py b/paramtools/schema.py index f67b8f3..4d6c3f7 100644 --- a/paramtools/schema.py +++ b/paramtools/schema.py @@ -7,6 +7,7 @@ validates_schema, ValidationError as MarshmallowValidationError, decorators, + RAISE as RAISEUNKNOWNOPTION, ) from marshmallow.error_store import ErrorStore @@ -177,8 +178,9 @@ class BaseValidatorSchema(Schema): "when": "_get_when_validator", } - def __init__(self, data): - self.context = {} + def __init__(self, *args, **kwargs): + self.pt_context = {} + super().__init__(*args, **kwargs) def validate_only(self, data): """ @@ -186,7 +188,7 @@ def validate_only(self, data): from the marshmallow _do_load function: https://github.com/marshmallow-code/marshmallow/blob/3.5.2/src/marshmallow/schema.py#L807 """ - self.fields = data.keys() + # self.fields = data.keys() error_store = ErrorStore() # Run field-level validation self._invoke_field_validators( @@ -197,21 +199,23 @@ def validate_only(self, data): field_errors = bool(error_store.errors) self._invoke_schema_validators( error_store=error_store, - pass_many=True, + pass_collection=True, data=data, original_data=data, many=None, partial=None, field_errors=field_errors, + unknown=RAISEUNKNOWNOPTION, ) self._invoke_schema_validators( error_store=error_store, - pass_many=False, + pass_collection=False, data=data, original_data=data, many=None, partial=None, field_errors=field_errors, + unknown=RAISEUNKNOWNOPTION, ) errors = error_store.errors if errors: @@ -260,7 +264,7 @@ def validate_param(self, param_name, param_spec, raw_data): Do range validation for a parameter. """ validate_schema = not getattr( - self.context["spec"], "_defer_validation", False + self.pt_context["spec"], "_defer_validation", False ) validators = self.validators( param_name, param_spec, raw_data, validate_schema=validate_schema @@ -279,7 +283,7 @@ def validate_param(self, param_name, param_spec, raw_data): return warnings, errors def field_keyfunc(self, param_name): - data = self.context["spec"]._data[param_name] + data = self.pt_context["spec"]._data[param_name] field = get_type(data, self.validators(param_name)) try: return field.cmp_funcs()["key"] @@ -287,7 +291,7 @@ def field_keyfunc(self, param_name): return None def field(self, param_name): - data = self.context["spec"]._data[param_name] + data = self.pt_context["spec"]._data[param_name] return get_type(data, self.validators(param_name)) def validators( @@ -298,7 +302,7 @@ def validators( if raw_data is None: raw_data = {} - param_info = self.context["spec"]._data[param_name] + param_info = self.pt_context["spec"]._data[param_name] # sort keys to guarantee order. validator_spec = param_info.get("validators", {}) validators = [] @@ -336,7 +340,7 @@ def _get_when_validator( when_param = when_dict["param"] if ( - when_param not in self.context["spec"]._data.keys() + when_param not in self.pt_context["spec"]._data.keys() and when_param != "default" ): raise MarshmallowValidationError( @@ -371,8 +375,8 @@ def _get_when_validator( ) ) - _type = self.context["spec"]._data[oth_param]["type"] - number_dims = self.context["spec"]._data[oth_param]["number_dims"] + _type = self.pt_context["spec"]._data[oth_param]["type"] + number_dims = self.pt_context["spec"]._data[oth_param]["number_dims"] error_then = ( f"When {oth_param}{{when_labels}}{{ix}} is {{is_val}}, " @@ -458,9 +462,9 @@ def _get_range_validator( ) def _sort_by_label_to_extend(self, vos): - label_to_extend = self.context["spec"].label_to_extend + label_to_extend = self.pt_context["spec"].label_to_extend if label_to_extend is not None: - label_grid = self.context["spec"]._stateless_label_grid + label_grid = self.pt_context["spec"]._stateless_label_grid extend_vals = label_grid[label_to_extend] return sorted( vos, @@ -522,9 +526,9 @@ def _get_related_value( # If comparing against the "default" value then get the current # value of the parameter being updated. if oth_param_name == "default": - oth_param = self.context["spec"]._data[param_name] + oth_param = self.pt_context["spec"]._data[param_name] else: - oth_param = self.context["spec"]._data[oth_param_name] + oth_param = self.pt_context["spec"]._data[oth_param_name] vals = oth_param["value"] labs_to_check = {k for k in param_spec if k not in ("value", "_auto")} if labs_to_check: @@ -549,11 +553,11 @@ def _check_ndim_restriction( if other_param is None: continue if other_param == "default": - ndims = self.context["spec"]._data[param_name][ + ndims = self.pt_context["spec"]._data[param_name][ "number_dims" ] else: - ndims = self.context["spec"]._data[other_param][ + ndims = self.pt_context["spec"]._data[other_param][ "number_dims" ] if ndims > 0: diff --git a/setup.py b/setup.py index 4cbb390..468a191 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ url="https://github.com/hdoupe/ParamTools", packages=setuptools.find_packages(), install_requires=[ - "marshmallow>=3.0.0", + "marshmallow>=4.0.0", "numpy", "python-dateutil>=2.8.0", "fsspec",