diff --git a/feature_engine/wrappers/wrappers.py b/feature_engine/wrappers/wrappers.py
index 37e593a1b..28d336e81 100644
--- a/feature_engine/wrappers/wrappers.py
+++ b/feature_engine/wrappers/wrappers.py
@@ -215,6 +215,7 @@ def fit(self, X: pd.DataFrame, y: Optional[str] = None):
             "OneHotEncoder",
             "OrdinalEncoder",
             "SimpleImputer",
+            "FunctionTransformer",
         ]:
             self.variables_ = _find_all_variables(X, self.variables)
 
diff --git a/tests/test_wrappers/test_sklearn_wrapper.py b/tests/test_wrappers/test_sklearn_wrapper.py
index 30b884f7f..b0db6d9a4 100644
--- a/tests/test_wrappers/test_sklearn_wrapper.py
+++ b/tests/test_wrappers/test_sklearn_wrapper.py
@@ -549,3 +549,31 @@ def test_get_feature_names_out_ohe(varlist, df_vartypes):
     ]
 
     assert output_feat == transformer.get_feature_names_out(varlist)
+
+
+def test_function_transformer_works_with_categoricals():
+    X = pd.DataFrame({"col1": ["1", "2", "3"], "col2": ["a", "b", "c"]})
+
+    X_expected = pd.DataFrame({"col1": [1.0, 2.0, 3.0], "col2": ["a", "b", "c"]})
+
+    transformer = SklearnTransformerWrapper(
+        FunctionTransformer(lambda x: x.astype(np.float64)), variables=["col1"]
+    )
+
+    X_tf = transformer.fit_transform(X)
+
+    pd.testing.assert_frame_equal(X_expected, X_tf)
+
+
+def test_function_transformer_works_with_numericals():
+    X = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
+
+    X_expected = pd.DataFrame({"col1": [2, 3, 4], "col2": ["a", "b", "c"]})
+
+    transformer = SklearnTransformerWrapper(
+        FunctionTransformer(lambda x: x+1), variables=["col1"]
+    )
+
+    X_tf = transformer.fit_transform(X)
+
+    pd.testing.assert_frame_equal(X_expected, X_tf)