random_forest_wrapper.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. """ A wrapper to enable RandomForestClassifier to be used in conjunction
  2. with imblearn sampling methods when using cross_validate or cross_val_score
  3. methods from scikit-learn """
  4. __author__ = "Kristian Berg"
  5. __copyright__ = "Copyright (c) 2018 Axis Communications AB"
  6. __license__ = "MIT"
  7. from sklearn.ensemble import RandomForestClassifier
  8. class RandomForestWrapper(RandomForestClassifier):
  9. """ A wrapper to enable RandomForestClassifier to be used in conjunction
  10. with imblearn sampling methods when using cross_validate or cross_val_score
  11. methods from scikit-learn """
  12. # pylint: disable = too-many-ancestors
  13. def __init__(self,
  14. sampler=None,
  15. n_estimators=10,
  16. criterion="gini",
  17. max_depth=None,
  18. min_samples_split=2,
  19. min_samples_leaf=1,
  20. min_weight_fraction_leaf=0.,
  21. max_features="auto",
  22. max_leaf_nodes=None,
  23. min_impurity_decrease=0.,
  24. min_impurity_split=None,
  25. bootstrap=True,
  26. oob_score=False,
  27. n_jobs=1,
  28. random_state=None,
  29. verbose=0,
  30. warm_start=False,
  31. class_weight=None):
  32. # pylint: disable = too-many-arguments, too-many-locals
  33. super().__init__(n_estimators,
  34. criterion,
  35. max_depth,
  36. min_samples_split,
  37. min_samples_leaf,
  38. min_weight_fraction_leaf,
  39. max_features,
  40. max_leaf_nodes,
  41. min_impurity_decrease,
  42. min_impurity_split,
  43. bootstrap,
  44. oob_score,
  45. n_jobs,
  46. random_state,
  47. verbose,
  48. warm_start,
  49. class_weight)
  50. self.sampler = sampler
  51. def fit(self, X, y, sample_weight=None):
  52. if self.sampler:
  53. X, y = self.sampler.fit_sample(X, y)
  54. return super().fit(X, y, sample_weight=None)