Skip to content

[FEA] Attempt to convert string dtyped columns to floats under the hood in Random Forest fit #6267

Open
@beckernick

Description

For a good UX, scikit-learn will attempt to convert string-dtyped columns of an input dataframe to floats by default during, e.g., RandomForestClassifier.fit (and possibly other estimators). If it can't, it will throw an error. We throw an error in both scenarios.

We should potentially do the same, pending performance impact.

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# Reproducer: fit a RandomForestClassifier on a DataFrame holding five numeric
# columns plus one string-dtyped column.
N = 1000  # number of rows
# Numeric-looking strings: scikit-learn coerces these to floats automatically
# (it treats them as numbers, i.e. with an implied ordering).
categories = ['0', '1', '2']
# Swap in the line below for the failing case: non-numeric strings cannot be
# coerced, so fit raises "ValueError: could not convert string to float".
# categories = ['a', 'b', 'c']

numeric_data = np.random.rand(N, 5)
# Drawn with shape (N, 1) so it can be assigned as a single column below.
string_data = np.random.choice(categories, size=(N,1))
y = np.random.choice([0, 1], size=N)  # binary labels

X = pd.DataFrame(numeric_data)
X["str_col"] = string_data
# Rename every column to x0, x1, ...; the string column ends up as the last one.
X.columns = [f"x{i}" for i in range(len(X.columns))]

clf = RandomForestClassifier()
clf.fit(X,y)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_3546391/1974482336.py in ?()
      1 clf = RandomForestClassifier()
----> 2 clf.fit(X,y)

[/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/base.py](http://10.176.1.125:8883/lab/tree/raid/nicholasb/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/base.py) in ?(estimator, *args, **kwargs)
   1385                 skip_parameter_validation=(
   1386                     prefer_skip_nested_validation or global_skip_validation
   1387                 )
   1388             ):
-> 1389                 return fit_method(estimator, *args, **kwargs)

[/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/ensemble/_forest.py](http://10.176.1.125:8883/lab/tree/raid/nicholasb/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/ensemble/_forest.py) in ?(self, X, y, sample_weight)
    356         # Validate or convert input data
    357         if issparse(y):
    358             raise ValueError("sparse multilabel-indicator for y is not supported.")
    359 
--> 360         X, y = validate_data(
    361             self,
    362             X,
    363             y,

[/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/utils/validation.py](http://10.176.1.125:8883/lab/tree/raid/nicholasb/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/utils/validation.py) in ?(_estimator, X, y, reset, validate_separately, skip_check_array, **check_params)
   2957             if "estimator" not in check_y_params:
   2958                 check_y_params = {**default_check_params, **check_y_params}
   2959             y = check_array(y, input_name="y", **check_y_params)
   2960         else:
-> 2961             X, y = check_X_y(X, y, **check_params)
   2962         out = X, y
   2963 
   2964     if not no_val_X and check_params.get("ensure_2d", True):

[/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/utils/validation.py](http://10.176.1.125:8883/lab/tree/raid/nicholasb/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/utils/validation.py) in ?(X, y, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_all_finite, ensure_2d, allow_nd, multi_output, ensure_min_samples, ensure_min_features, y_numeric, estimator)
   1366         )
   1367 
   1368     ensure_all_finite = _deprecate_force_all_finite(force_all_finite, ensure_all_finite)
   1369 
-> 1370     X = check_array(
   1371         X,
   1372         accept_sparse=accept_sparse,
   1373         accept_large_sparse=accept_large_sparse,

[/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/utils/validation.py](http://10.176.1.125:8883/lab/tree/raid/nicholasb/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/utils/validation.py) in ?(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_writeable, force_all_finite, ensure_all_finite, ensure_non_negative, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator, input_name)
   1052                         )
   1053                     array = xp.astype(array, dtype, copy=False)
   1054                 else:
   1055                     array = _asarray_with_order(array, order=order, dtype=dtype, xp=xp)
-> 1056             except ComplexWarning as complex_warning:
   1057                 raise ValueError(
   1058                     "Complex data not supported\n{}\n".format(array)
   1059                 ) from complex_warning

[/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/utils/_array_api.py](http://10.176.1.125:8883/lab/tree/raid/nicholasb/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/sklearn/utils/_array_api.py) in ?(array, dtype, order, copy, xp, device)
    835         # Use NumPy API to support order
    836         if copy is True:
    837             array = numpy.array(array, order=order, dtype=dtype)
    838         else:
--> 839             array = numpy.asarray(array, order=order, dtype=dtype)
    840 
    841         # At this point array is a NumPy ndarray. We convert it to an array
    842         # container that is consistent with the input's namespace.

[/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/pandas/core/generic.py](http://10.176.1.125:8883/lab/tree/raid/nicholasb/raid/nicholasb/miniforge3/envs/cuml-25.02/lib/python3.11/site-packages/pandas/core/generic.py) in ?(self, dtype, copy)
   2149     def __array__(
   2150         self, dtype: npt.DTypeLike | None = None, copy: bool_t | None = None
   2151     ) -> np.ndarray:
   2152         values = self._values
-> 2153         arr = np.asarray(values, dtype=dtype)
   2154         if (
   2155             astype_is_view(values.dtype, arr.dtype)
   2156             and using_copy_on_write()

ValueError: could not convert string to float: 'c'

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions