diff --git a/tests/test_sql.py b/tests/test_sql.py index b4e2bfd..66abe41 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,6 +1,7 @@ """SQL functionality tests for xarray-sql using pytest.""" import numpy as np +import pandas as pd import pytest import xarray as xr @@ -146,3 +147,44 @@ def test_string_coordinates(): assert "student" in result.columns assert "subject" in result.columns assert "score" in result.columns + + +class TestNanAsNull: + """NaN in float columns should become Arrow nulls so SQL aggregates work.""" + + @pytest.fixture + def nan_ds(self): + data = np.array([[[1.0, 2.0], [np.nan, 4.0]], [[5.0, np.nan], [7.0, 8.0]]]) + return xr.Dataset( + {"temp": (["time", "x", "y"], data)}, + coords={ + "time": pd.date_range("2020-01-01", periods=2), + "x": [0, 1], + "y": [0, 1], + }, + ).chunk({"time": 1}) + + def test_nan_aggregates(self, nan_ds): + ctx = XarrayContext() + ctx.from_dataset("data", nan_ds) + + # Test multiple aggregates at once: + # MAX/MIN/AVG should ignore NaN, COUNT(col) should exclude NaN, + # and WHERE col IS NULL should match NaN. + query = """ + SELECT + MAX(temp) AS mx, + MIN(temp) AS mn, + AVG(temp) AS avg, + COUNT(temp) AS cnt, + COUNT(*) FILTER (WHERE temp IS NULL) AS null_cnt + FROM data + """ + result = ctx.sql(query).to_pandas().iloc[0] + + assert result["mx"] == 8.0 + assert result["mn"] == 1.0 + expected_avg = np.nanmean([1.0, 2.0, 4.0, 5.0, 7.0, 8.0]) + assert abs(result["avg"] - expected_avg) < 1e-6 + assert result["cnt"] == 6 + assert result["null_cnt"] == 2 diff --git a/xarray_sql/df.py b/xarray_sql/df.py index d08207d..353a3db 100644 --- a/xarray_sql/df.py +++ b/xarray_sql/df.py @@ -207,7 +207,11 @@ def dataset_to_record_batch( arrays.append(pa.array(arr, type=field.type)) else: # Data variable: ravel to 1-D (zero-copy for C-contiguous arrays). - arrays.append(pa.array(ds[name].values.ravel(), type=field.type)) + # from_pandas=True maps NaN → Arrow null inside the C++ copy kernel, + # so SQL aggregates (MAX, MIN, AVG) skip missing values correctly. + arrays.append( + pa.array(ds[name].values.ravel(), type=field.type, from_pandas=True) + ) return pa.RecordBatch.from_arrays(arrays, schema=schema) @@ -282,7 +286,11 @@ def iter_record_batches( arrays.append(pa.array(coord_values[name][coord_idx], type=field.type)) else: arrays.append( - pa.array(data_arrays[name][row_start:row_end], type=field.type) + pa.array( + data_arrays[name][row_start:row_end], + type=field.type, + from_pandas=True, + ) ) yield pa.RecordBatch.from_arrays(arrays, schema=schema)