# model.py
  1. """ A collection of scripts for training and evaluating a RandomForestClassifier
  2. on a bug prediction dataset at commit level """
  3. __author__ = "Kristian Berg"
  4. __copyright__ = "Copyright (c) 2018 Axis Communications AB"
  5. __license__ = "MIT"
  6. import argparse
  7. import configparser
  8. from sklearn.model_selection import cross_validate
  9. from sklearn.externals import joblib
  10. from imblearn.over_sampling import SMOTE
  11. from imblearn.under_sampling import ClusterCentroids
  12. from imblearn.combine import SMOTETomek
  13. from treeinterpreter import treeinterpreter as ti
  14. import numpy as np
  15. from random_forest_wrapper import RandomForestWrapper
  16. from time_sensitive_split import GitTimeSensitiveSplit
  17. def evaluate(path, datapath, lastcommit, config, debug):
  18. """ Evaluate model performance """
  19. data, labels, _, _ = load_data(datapath)
  20. args = config['args']
  21. if args['seed'] != 'None':
  22. np.random.seed(args.getint('seed'))
  23. sampler = get_sampler(args['sampler'])
  24. if args['split'] == 'kfold':
  25. split = int(args['nfolds'])
  26. elif args['split'] == 'occ':
  27. split = GitTimeSensitiveSplit(path=path, lastcommit=lastcommit, debug=debug)
  28. scoring = {'p': 'precision',
  29. 'r': 'recall',
  30. 'f1': 'f1',
  31. }
  32. data = data[::-1]
  33. labels = labels[::-1]
  34. wrap = RandomForestWrapper(sampler, n_estimators=args.getint('n_estimators'))
  35. scores = cross_validate(wrap, data, labels, scoring=scoring, cv=split, return_train_score=False)
  36. for key in sorted(scores.keys()):
  37. print(key + ': ' + str(scores[key]))
  38. print(key + ': ' + str(np.average(scores[key])) + ' ± ' +
  39. str(np.std(scores[key])))
  40. def train(datapath, sampler_arg=None, printfeats=False):
  41. """ Train model and save in pkl file """
  42. data, labels, _, names = load_data(datapath)
  43. sampler = get_sampler(sampler_arg)
  44. clf = RandomForestWrapper(sampler, n_estimators=200)
  45. clf.fit(data, labels)
  46. if printfeats:
  47. feats = zip(names[1:], clf.feature_importances_)
  48. feats = sorted(feats, key=lambda yo: yo[1])
  49. for pair in feats:
  50. print(pair)
  51. joblib.dump(clf, 'model.pkl')
  52. def classify(datapath, commithash=None, index=None):
  53. """ Load model and classify single data point. Also determines
  54. most significant feature """
  55. # pylint: disable = too-many-locals
  56. clf = joblib.load('model.pkl')
  57. data, _, hashes, names = load_data(datapath)
  58. if commithash:
  59. temp, = np.where(hashes == commithash)
  60. sample = temp[0]
  61. elif index:
  62. sample = index
  63. else:
  64. sample = 1
  65. prediction, _, contributions = ti.predict(clf, data[[sample]])
  66. label1 = np.array(contributions)[0, :, 0]
  67. label2 = np.array(contributions)[0, :, 1]
  68. if prediction[0][0] > prediction[0][1]:
  69. res = label1
  70. labeltext = 'clean'
  71. else:
  72. res = label2
  73. labeltext = 'buggy'
  74. top = max(res)
  75. index, = np.where(res == top)
  76. feature = names[index[0] + 1]
  77. print('Predicted result: ' + labeltext)
  78. print('Top factor: ' + feature)
  79. def get_sampler(arg):
  80. """ Return sampler based on string argument """
  81. if arg == 'smote':
  82. # Oversampling
  83. return SMOTE()
  84. elif arg == 'cluster':
  85. # Undersampling
  86. return ClusterCentroids()
  87. elif arg == 'smotetomek':
  88. # Mixed over- and undersampling
  89. return SMOTETomek()
  90. return None
  91. def load_data(datapath):
  92. """ Load data from label and feature .csv files """
  93. with open('data/features.csv') as feats:
  94. names = feats.readline().split(',')
  95. num_cols = len(names)
  96. data = np.genfromtxt(datapath + '/features.csv', delimiter=',', skip_header=1,
  97. usecols=tuple(range(1, num_cols)))
  98. labels = np.genfromtxt(datapath + '/labels.csv', delimiter=',', dtype='int',
  99. skip_header=1, usecols=(1))
  100. hashes = np.genfromtxt(datapath + '/features.csv', delimiter=',', dtype='str',
  101. skip_header=1, usecols=0)
  102. return data, labels, hashes, names
  103. def main():
  104. """ Main method """
  105. parser = argparse.ArgumentParser(description='Train or evaluate model for '
  106. + 'defect prediction')
  107. parser.add_argument('method', metavar='m', type=str,
  108. help='method to be executed, either "train", ' +
  109. '"classify" or "evaluate"')
  110. parser.add_argument('config', metavar='c', type=str,
  111. help='specify .ini config file')
  112. parser.add_argument('datapath', metavar='d', type=str,
  113. help='filepath of features.csv and label.csv files')
  114. parser.add_argument('--hash', type=str, default=None,
  115. help='when method is "classify", specify data point' +
  116. ' by hash')
  117. parser.add_argument('--index', type=int, default=None,
  118. help='when method is "classify", specify data point' +
  119. ' by index')
  120. parser.add_argument('--path', type=str, default=None,
  121. help='when method is "evaluate", specify path to git' +
  122. ' repository')
  123. parser.add_argument('--lastcommit', type=str, default=None,
  124. help='when method is "evaluate", specify last commit' +
  125. ' to include')
  126. parser.add_argument('--significance', type=bool, default=False,
  127. help='when method is "train", if True prints feature ' +
  128. 'significances')
  129. parser.add_argument('--debug', type=bool, default=False,
  130. help='enables debug print output')
  131. args = parser.parse_args()
  132. config = configparser.ConfigParser()
  133. config.read(args.config)
  134. if args.method == 'evaluate':
  135. evaluate(args.path, args.datapath, args.lastcommit, config, args.debug)
  136. elif args.method == 'train':
  137. train(args.datapath, args.significance)
  138. elif args.method == 'classify':
  139. classify(args.datapath, args.hash, args.index)
# Script entry point: run the CLI only when executed directly
if __name__ == '__main__':
    main()