From 1b65d8e4d4a1945621e4200745eaf22fa4cc7030 Mon Sep 17 00:00:00 2001 From: Soledad Galli Date: Mon, 13 Jun 2022 14:59:10 +0200 Subject: [PATCH] add check for duplicated variable names --- feature_engine/variable_manipulation.py | 7 ++- tests/test_variable_manipulation.py | 68 +++++++++++-------------- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/feature_engine/variable_manipulation.py b/feature_engine/variable_manipulation.py index 160846944..25ac9f114 100644 --- a/feature_engine/variable_manipulation.py +++ b/feature_engine/variable_manipulation.py @@ -27,11 +27,14 @@ def _check_input_parameter_variables(variables: Variables) -> Any: """ msg = "variables should be a string, an int or a list of strings or integers." + msg_dupes = "the list contains duplicated variable names" if variables: if isinstance(variables, list): if not all(isinstance(i, (str, int)) for i in variables): raise ValueError(msg) + if len(variables) != len(set(variables)): + raise ValueError(msg_dupes) else: if not isinstance(variables, (str, int)): raise ValueError(msg) @@ -250,7 +253,9 @@ def _find_or_check_datetime_variables( def _find_all_variables( - X: pd.DataFrame, variables: Variables = None, exclude_datetime: bool = False, + X: pd.DataFrame, + variables: Variables = None, + exclude_datetime: bool = False, ) -> List[Union[str, int]]: """ If variables are None, captures all variables in the dataframe in a list. diff --git a/tests/test_variable_manipulation.py b/tests/test_variable_manipulation.py index 98606b63e..c2294e2e2 100644 --- a/tests/test_variable_manipulation.py +++ b/tests/test_variable_manipulation.py @@ -12,30 +12,30 @@ ) -def test_check_input_parameter_variables(): - vars_ls = ["var1", "var2", "var1"] - vars_int_ls = [0, 1, 2, 3] - vars_none = None - vars_str = "var1" - vars_int = 0 - vars_tuple = ("var1", "var2") - vars_set = {"var1", "var2"} - vars_dict = {"var1": 1, "var2": 2} - - assert _check_input_parameter_variables(vars_ls) == ["var1", "var2", "var1"] - assert _check_input_parameter_variables(vars_int_ls) == [0, 1, 2, 3] - assert _check_input_parameter_variables(vars_none) is None - assert _check_input_parameter_variables(vars_str) == "var1" - assert _check_input_parameter_variables(vars_int) == 0 - +@pytest.mark.parametrize( + "_input_vars", + [ + ("var1", "var2"), + {"var1": 1, "var2": 2}, + ["var1", "var2", "var2", "var3"], + [0, 1, 1, 2], + ], +) +def test_check_input_parameter_variables_raises_errors(_input_vars): with pytest.raises(ValueError): - assert _check_input_parameter_variables(vars_tuple) + assert _check_input_parameter_variables(_input_vars) - with pytest.raises(ValueError): - assert _check_input_parameter_variables(vars_set) - with pytest.raises(ValueError): - assert _check_input_parameter_variables(vars_dict) +@pytest.mark.parametrize( + "_input_vars", + [["var1", "var2", "var3"], [0, 1, 2, 3], "var1", ["var1"], 0, [0]], +) +def test_check_input_parameter_variables(_input_vars): + assert _check_input_parameter_variables(_input_vars) == _input_vars + + +def test_check_input_parameter_variables_is_none(): + assert _check_input_parameter_variables(None) is None def test_find_or_check_numerical_variables(df_vartypes, df_numeric_columns): @@ -206,15 +206,12 @@ def test_find_or_check_datetime_variables(df_datetime): _find_or_check_datetime_variables(df_datetime, variables=None) == vars_convertible_to_dt ) - assert ( - _find_or_check_datetime_variables( - df_datetime[vars_convertible_to_dt].reindex( - columns=["date_obj1", "datetime_range", "date_obj2"] - ), - variables=None, - ) - == ["date_obj1", "datetime_range", "date_obj2"] - ) + assert _find_or_check_datetime_variables( + df_datetime[vars_convertible_to_dt].reindex( + columns=["date_obj1", "datetime_range", "date_obj2"] + ), + variables=None, + ) == ["date_obj1", "datetime_range", "date_obj2"] # when variables are specified assert _find_or_check_datetime_variables(df_datetime, var_dt_str) == [var_dt_str] @@ -226,13 +223,10 @@ def test_find_or_check_datetime_variables(df_datetime): _find_or_check_datetime_variables(df_datetime, variables=vars_convertible_to_dt) == vars_convertible_to_dt ) - assert ( - _find_or_check_datetime_variables( - df_datetime.join(tz_time), - variables=None, - ) - == vars_convertible_to_dt + ["time_objTZ"] - ) + assert _find_or_check_datetime_variables( + df_datetime.join(tz_time), + variables=None, + ) == vars_convertible_to_dt + ["time_objTZ"] # datetime var cast as categorical df_datetime["date_obj1"] = df_datetime["date_obj1"].astype("category")