Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions conda.recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ package:
requirements:
build:
- python
- "marshmallow>=3.0.0rc4"
- "marshmallow>=4.0.0"
- "numpy>=1.13"
- "python-dateutil>=2.8.0"

run:
- python
- "marshmallow>=3.0.0rc4"
- "marshmallow>=4.0.0"
- "numpy>=1.13"
- "python-dateutil>=2.8.0"

Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: paramtools-dev
channels:
- conda-forge
dependencies:
- "marshmallow>=3.22.0"
- "marshmallow>=4.0.0"
- "numpy>=2.1.0"
- "python-dateutil>=2.8.0"
- "pytest>=6.0.0"
Expand Down
8 changes: 4 additions & 4 deletions paramtools/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
else:
self._stateless_label_grid[name] = []
self.label_grid = copy.deepcopy(self._stateless_label_grid)
self._validator_schema.context["spec"] = self
self._validator_schema.pt_context["spec"] = self
self._warnings = {}
self._errors = {}
self._defer_validation = False
Expand Down Expand Up @@ -364,7 +364,7 @@ def _adjust(
for param, value in parsed_params.items():
self._update_param(param, value)

self._validator_schema.context["spec"] = self
self._validator_schema.pt_context["spec"] = self

has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
Expand Down Expand Up @@ -525,7 +525,7 @@ def _delete(
if self.label_to_extend is not None and extend_adj:
self.extend()

self._validator_schema.context["spec"] = self
self._validator_schema.pt_context["spec"] = self

has_errors = bool(self._errors.get("messages"))
has_warnings = bool(self._warnings.get("messages"))
Expand Down Expand Up @@ -1414,4 +1414,4 @@ def get_defaults(self):
- `params`: String if URL or file path. Dict if this is the loaded params
dict.
"""
return utils.read_json(self.defaults)
return utils.read_json(self.defaults)
40 changes: 22 additions & 18 deletions paramtools/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
validates_schema,
ValidationError as MarshmallowValidationError,
decorators,
RAISE as RAISEUNKNOWNOPTION,
)
from marshmallow.error_store import ErrorStore

Expand Down Expand Up @@ -177,16 +178,17 @@ class BaseValidatorSchema(Schema):
"when": "_get_when_validator",
}

def __init__(self, data):
self.context = {}
def __init__(self, *args, **kwargs):
self.pt_context = {}
super().__init__(*args, **kwargs)

def validate_only(self, data):
"""
Bypass deserialization and just run field validators. This is taken
from the marshmallow _do_load function:
https://github.com/marshmallow-code/marshmallow/blob/3.5.2/src/marshmallow/schema.py#L807
"""
self.fields = data.keys()
# self.fields = data.keys()
error_store = ErrorStore()
# Run field-level validation
self._invoke_field_validators(
Expand All @@ -197,21 +199,23 @@ def validate_only(self, data):
field_errors = bool(error_store.errors)
self._invoke_schema_validators(
error_store=error_store,
pass_many=True,
pass_collection=True,
data=data,
original_data=data,
many=None,
partial=None,
field_errors=field_errors,
unknown=RAISEUNKNOWNOPTION,
)
self._invoke_schema_validators(
error_store=error_store,
pass_many=False,
pass_collection=False,
data=data,
original_data=data,
many=None,
partial=None,
field_errors=field_errors,
unknown=RAISEUNKNOWNOPTION,
)
errors = error_store.errors
if errors:
Expand Down Expand Up @@ -260,7 +264,7 @@ def validate_param(self, param_name, param_spec, raw_data):
Do range validation for a parameter.
"""
validate_schema = not getattr(
self.context["spec"], "_defer_validation", False
self.pt_context["spec"], "_defer_validation", False
)
validators = self.validators(
param_name, param_spec, raw_data, validate_schema=validate_schema
Expand All @@ -279,15 +283,15 @@ def validate_param(self, param_name, param_spec, raw_data):
return warnings, errors

def field_keyfunc(self, param_name):
data = self.context["spec"]._data[param_name]
data = self.pt_context["spec"]._data[param_name]
field = get_type(data, self.validators(param_name))
try:
return field.cmp_funcs()["key"]
except AttributeError:
return None

def field(self, param_name):
data = self.context["spec"]._data[param_name]
data = self.pt_context["spec"]._data[param_name]
return get_type(data, self.validators(param_name))

def validators(
Expand All @@ -298,7 +302,7 @@ def validators(
if raw_data is None:
raw_data = {}

param_info = self.context["spec"]._data[param_name]
param_info = self.pt_context["spec"]._data[param_name]
# sort keys to guarantee order.
validator_spec = param_info.get("validators", {})
validators = []
Expand Down Expand Up @@ -336,7 +340,7 @@ def _get_when_validator(
when_param = when_dict["param"]

if (
when_param not in self.context["spec"]._data.keys()
when_param not in self.pt_context["spec"]._data.keys()
and when_param != "default"
):
raise MarshmallowValidationError(
Expand Down Expand Up @@ -371,8 +375,8 @@ def _get_when_validator(
)
)

_type = self.context["spec"]._data[oth_param]["type"]
number_dims = self.context["spec"]._data[oth_param]["number_dims"]
_type = self.pt_context["spec"]._data[oth_param]["type"]
number_dims = self.pt_context["spec"]._data[oth_param]["number_dims"]

error_then = (
f"When {oth_param}{{when_labels}}{{ix}} is {{is_val}}, "
Expand Down Expand Up @@ -458,9 +462,9 @@ def _get_range_validator(
)

def _sort_by_label_to_extend(self, vos):
label_to_extend = self.context["spec"].label_to_extend
label_to_extend = self.pt_context["spec"].label_to_extend
if label_to_extend is not None:
label_grid = self.context["spec"]._stateless_label_grid
label_grid = self.pt_context["spec"]._stateless_label_grid
extend_vals = label_grid[label_to_extend]
return sorted(
vos,
Expand Down Expand Up @@ -522,9 +526,9 @@ def _get_related_value(
# If comparing against the "default" value then get the current
# value of the parameter being updated.
if oth_param_name == "default":
oth_param = self.context["spec"]._data[param_name]
oth_param = self.pt_context["spec"]._data[param_name]
else:
oth_param = self.context["spec"]._data[oth_param_name]
oth_param = self.pt_context["spec"]._data[oth_param_name]
vals = oth_param["value"]
labs_to_check = {k for k in param_spec if k not in ("value", "_auto")}
if labs_to_check:
Expand All @@ -549,11 +553,11 @@ def _check_ndim_restriction(
if other_param is None:
continue
if other_param == "default":
ndims = self.context["spec"]._data[param_name][
ndims = self.pt_context["spec"]._data[param_name][
"number_dims"
]
else:
ndims = self.context["spec"]._data[other_param][
ndims = self.pt_context["spec"]._data[other_param][
"number_dims"
]
if ndims > 0:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
url="https://github.com/hdoupe/ParamTools",
packages=setuptools.find_packages(),
install_requires=[
"marshmallow>=3.0.0",
"marshmallow>=4.0.0",
"numpy",
"python-dateutil>=2.8.0",
"fsspec",
Expand Down
Loading