12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- """ A wrapper to enable RandomForestClassifier to be used in conjunction
- with imblearn sampling methods when using cross_validate or cross_val_score
- methods from scikit-learn """
- __author__ = "Kristian Berg"
- __copyright__ = "Copyright (c) 2018 Axis Communications AB"
- __license__ = "MIT"
- from sklearn.ensemble import RandomForestClassifier
class RandomForestWrapper(RandomForestClassifier):
    """A RandomForestClassifier that optionally resamples its training data.

    Only ``fit`` is overridden: when a ``sampler`` is configured, ``(X, y)``
    is resampled before the forest is trained, while prediction methods are
    inherited unchanged. Putting the sampler inside the estimator lets an
    imblearn sampler be combined with scikit-learn's ``cross_validate`` or
    ``cross_val_score`` functions, which only call ``fit`` on training data.

    Parameters
    ----------
    sampler : object or None, default=None
        An imblearn-style sampler exposing ``fit_resample(X, y)`` (or the
        older ``fit_sample(X, y)``) that returns the resampled ``(X, y)``.
        When ``None``, no resampling is performed.

    All remaining parameters are forwarded unchanged to
    ``RandomForestClassifier``; see its documentation for their meaning.
    """
    # pylint: disable = too-many-ancestors

    def __init__(self,
                 sampler=None,
                 n_estimators=10,
                 criterion="gini",
                 max_depth=None,
                 min_samples_split=2,
                 min_samples_leaf=1,
                 min_weight_fraction_leaf=0.,
                 max_features="auto",
                 max_leaf_nodes=None,
                 min_impurity_decrease=0.,
                 min_impurity_split=None,
                 bootstrap=True,
                 oob_score=False,
                 n_jobs=1,
                 random_state=None,
                 verbose=0,
                 warm_start=False,
                 class_weight=None):
        # pylint: disable = too-many-arguments, too-many-locals
        # The explicit signature (rather than **kwargs) is required so that
        # scikit-learn's get_params/clone can discover every parameter,
        # including ``sampler``.
        #
        # Pass by keyword: positional passing silently depends on the
        # parent's parameter order, and newer scikit-learn releases make
        # these arguments keyword-only.
        # NOTE(review): ``min_impurity_split`` and ``max_features="auto"``
        # were removed in later scikit-learn versions; this wrapper targets
        # a release that still accepts them — confirm the pinned version.
        super().__init__(
            n_estimators=n_estimators,
            criterion=criterion,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_weight_fraction_leaf=min_weight_fraction_leaf,
            max_features=max_features,
            max_leaf_nodes=max_leaf_nodes,
            min_impurity_decrease=min_impurity_decrease,
            min_impurity_split=min_impurity_split,
            bootstrap=bootstrap,
            oob_score=oob_score,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose,
            warm_start=warm_start,
            class_weight=class_weight,
        )
        self.sampler = sampler

    def fit(self, X, y, sample_weight=None):
        """Resample ``(X, y)`` with ``self.sampler`` (if set), then fit.

        Parameters
        ----------
        X : array-like
            Training samples.
        y : array-like
            Training targets.
        sample_weight : array-like or None, default=None
            Per-sample weights. Forwarded to the forest when no sampler is
            configured. When a sampler is configured the weights are dropped
            with a warning, because resampling changes the set of rows and
            the weights would no longer line up with them.

        Returns
        -------
        object
            The result of ``RandomForestClassifier.fit`` (the fitted
            estimator).
        """
        # pylint: disable = invalid-name, import-outside-toplevel
        import warnings

        if self.sampler is not None:
            if sample_weight is not None:
                warnings.warn(
                    "sample_weight is ignored when a sampler is set, because "
                    "resampling changes the training rows.",
                    UserWarning,
                    stacklevel=2,
                )
                sample_weight = None
            # imblearn renamed ``fit_sample`` to ``fit_resample``; prefer the
            # new name and fall back to the old one for older releases.
            resample = getattr(self.sampler, "fit_resample", None)
            if resample is None:
                resample = self.sampler.fit_sample
            X, y = resample(X, y)
        return super().fit(X, y, sample_weight=sample_weight)
|