
Initial ML model commit with supporting scripts

Kristian Berg committed 5 years ago · commit df682274f6
6 files changed, 380 insertions and 0 deletions
  1. code/.gitignore (+2 -0)
  2. code/model/model.ini (+29 -0)
  3. code/model/model.py (+167 -0)
  4. code/model/random_forest_wrapper.py (+58 -0)
  5. code/model/time_sensitive_split.py (+94 -0)
  6. code/model/utils.py (+30 -0)

+ 2 - 0
code/.gitignore

@@ -28,3 +28,5 @@ szz/results
 szz/libs
 szz/gradle.properties
 .gradle
+
+model/data

+ 29 - 0
code/model/model.ini

@@ -0,0 +1,29 @@
+[DEFAULT]
+n_estimators = 200
+split = kfold
+nfolds = 10
+seed = None
+lastcommit = None
+sampler = None
+
+[args]
+# Override DEFAULT values here. Use '#' to comment and uncomment lines
+
+# The value of 'split' can be either 'kfold' for stratified k-fold cross validation
+# or 'occ' for Online Change Classification. DEFAULT value is 'kfold'
+#split = occ
+
+# The value of 'sampler' can be either 'smote', 'cluster' or 'smotetomek'
+#sampler = smote
+#sampler = cluster
+#sampler = smotetomek
+
+[occ]
+# These are the parameters (durations in days) used when applying Online Change Classification
+sgap = 331
+gap = 73
+egap = 781
+update = 200
+traindur = 1700
+testdur = 400
+lastcommit = None

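For reference, a minimal sketch of how model.py consumes this file via configparser: the [args] section inherits every key from [DEFAULT], and the literal string 'None' marks a value as unset.

    import configparser

    config = configparser.ConfigParser()
    config.read('model.ini')
    args = config['args']  # missing keys fall back to [DEFAULT]

    n_estimators = args.getint('n_estimators')  # 200 by default
    split = args['split']                       # 'kfold' unless overridden
    # the string 'None' means "leave the RNG unseeded"
    seed = None if args['seed'] == 'None' else args.getint('seed')
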
+ 167 - 0
code/model/model.py

@@ -0,0 +1,167 @@
+""" A collection of scripts for training and evaluating a RandomForestClassifier
+on a bug prediction dataset at commit level """
+__author__ = "Kristian Berg"
+__copyright__ = "Copyright (c) 2018 Axis Communications AB"
+__license__ = "MIT"
+
+import argparse
+import configparser
+from sklearn.model_selection import cross_validate
+from sklearn.externals import joblib
+from imblearn.over_sampling import SMOTE
+from imblearn.under_sampling import ClusterCentroids
+from imblearn.combine import SMOTETomek
+from treeinterpreter import treeinterpreter as ti
+import numpy as np
+from random_forest_wrapper import RandomForestWrapper
+from time_sensitive_split import GitTimeSensitiveSplit
+
+def evaluate(path, datapath, lastcommit, config, debug):
+    """ Evaluate model performance """
+
+    data, labels, _, _ = load_data(datapath)
+    args = config['args']
+
+    if args['seed'] != 'None':
+        np.random.seed(args.getint('seed'))
+
+    sampler = get_sampler(args['sampler'])
+
+    if args['split'] == 'kfold':
+        split = int(args['nfolds'])
+    elif args['split'] == 'occ':
+        split = GitTimeSensitiveSplit(path=path, lastcommit=lastcommit, debug=debug)
+
+    scoring = {'p': 'precision',
+               'r': 'recall',
+               'f1': 'f1',
+              }
+
+    data = data[::-1]      # reverse so rows run oldest commit first,
+    labels = labels[::-1]  # as required by the time-sensitive split
+    wrap = RandomForestWrapper(sampler, n_estimators=args.getint('n_estimators'))
+    scores = cross_validate(wrap, data, labels, scoring=scoring, cv=split, return_train_score=False)
+    for key in sorted(scores.keys()):
+        print(key + ': ' + str(scores[key]))
+        print(key + ': ' + str(np.average(scores[key])) + ' ± ' +
+              str(np.std(scores[key])))
+
+def train(datapath, sampler_arg=None, printfeats=False):
+    """ Train model and save in pkl file """
+    data, labels, _, names = load_data(datapath)
+    sampler = get_sampler(sampler_arg)
+    clf = RandomForestWrapper(sampler, n_estimators=200)
+    clf.fit(data, labels)
+
+    if printfeats:
+        feats = zip(names[1:], clf.feature_importances_)
+        feats = sorted(feats, key=lambda feat: feat[1])
+        for pair in feats:
+            print(pair)
+
+    joblib.dump(clf, 'model.pkl')
+
+def classify(datapath, commithash=None, index=None):
+    """ Load model and classify single data point. Also determines
+    most significant feature """
+    # pylint: disable = too-many-locals
+    clf = joblib.load('model.pkl')
+    data, _, hashes, names = load_data(datapath)
+
+    if commithash:
+        temp, = np.where(hashes == commithash)
+        sample = temp[0]
+    elif index is not None:
+        sample = index
+    else:
+        sample = 1
+
+    prediction, _, contributions = ti.predict(clf, data[[sample]])
+    label1 = np.array(contributions)[0, :, 0]
+    label2 = np.array(contributions)[0, :, 1]
+
+    if prediction[0][0] > prediction[0][1]:
+        res = label1
+        labeltext = 'clean'
+    else:
+        res = label2
+        labeltext = 'buggy'
+
+    top = max(res)
+    index, = np.where(res == top)
+    feature = names[index[0] + 1]
+
+    print('Predicted result: ' + labeltext)
+    print('Top factor: ' + feature)
+
+def get_sampler(arg):
+    """ Return sampler based on string argument """
+    if arg == 'smote':
+        # Oversampling
+        return SMOTE()
+    elif arg == 'cluster':
+        # Undersampling
+        return ClusterCentroids()
+    elif arg == 'smotetomek':
+        # Mixed over- and undersampling
+        return SMOTETomek()
+    return None
+
+def load_data(datapath):
+    """ Load data from label and feature .csv files """
+
+    with open(datapath + '/features.csv') as feats:
+        names = feats.readline().split(',')
+        num_cols = len(names)
+
+    data = np.genfromtxt(datapath + '/features.csv', delimiter=',', skip_header=1,
+                         usecols=tuple(range(1, num_cols)))
+    labels = np.genfromtxt(datapath + '/labels.csv', delimiter=',', dtype='int',
+                           skip_header=1, usecols=1)
+    hashes = np.genfromtxt(datapath + '/features.csv', delimiter=',', dtype='str',
+                           skip_header=1, usecols=0)
+
+    return data, labels, hashes, names
+
+def main():
+    """ Main method """
+    parser = argparse.ArgumentParser(description='Train or evaluate model for '
+                                     + 'defect prediction')
+    parser.add_argument('method', metavar='m', type=str,
+                        help='method to be executed, either "train", ' +
+                        '"classify" or "evaluate"')
+    parser.add_argument('config', metavar='c', type=str,
+                        help='specify .ini config file')
+    parser.add_argument('datapath', metavar='d', type=str,
+                        help='filepath of features.csv and labels.csv files')
+    parser.add_argument('--hash', type=str, default=None,
+                        help='when method is "classify", specify data point' +
+                        ' by hash')
+    parser.add_argument('--index', type=int, default=None,
+                        help='when method is "classify", specify data point' +
+                        ' by index')
+    parser.add_argument('--path', type=str, default=None,
+                        help='when method is "evaluate", specify path to git' +
+                        ' repository')
+    parser.add_argument('--lastcommit', type=str, default=None,
+                        help='when method is "evaluate", specify last commit' +
+                        ' to include')
+    parser.add_argument('--significance', action='store_true',
+                        help='when method is "train", print feature ' +
+                        'significances')
+    parser.add_argument('--debug', action='store_true',
+                        help='enables debug print output')
+    args = parser.parse_args()
+
+    config = configparser.ConfigParser()
+    config.read(args.config)
+
+    if args.method == 'evaluate':
+        evaluate(args.path, args.datapath, args.lastcommit, config, args.debug)
+    elif args.method == 'train':
+        train(args.datapath, config['args']['sampler'], args.significance)
+    elif args.method == 'classify':
+        classify(args.datapath, args.hash, args.index)
+
+if __name__ == '__main__':
+    main()

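For reference, a minimal usage sketch of the three entry points (hash, paths, and repository location are illustrative):

    python model.py train model.ini data --significance
    python model.py classify model.ini data --hash abc1234
    python model.py evaluate model.ini data --path /path/to/repo --lastcommit abc1234
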
+ 58 - 0
code/model/random_forest_wrapper.py

@@ -0,0 +1,58 @@
+""" A wrapper to enable RandomForestClassifier to be used in conjunction
+with imblearn sampling methods when using cross_validate or cross_val_score
+methods from scikit-learn """
+__author__ = "Kristian Berg"
+__copyright__ = "Copyright (c) 2018 Axis Communications AB"
+__license__ = "MIT"
+
+from sklearn.ensemble import RandomForestClassifier
+
+class RandomForestWrapper(RandomForestClassifier):
+    """ A wrapper to enable RandomForestClassifier to be used in conjunction
+    with imblearn sampling methods when using cross_validate or cross_val_score
+    methods from scikit-learn """
+    # pylint: disable = too-many-ancestors
+
+    def __init__(self,
+                 sampler=None,
+                 n_estimators=10,
+                 criterion="gini",
+                 max_depth=None,
+                 min_samples_split=2,
+                 min_samples_leaf=1,
+                 min_weight_fraction_leaf=0.,
+                 max_features="auto",
+                 max_leaf_nodes=None,
+                 min_impurity_decrease=0.,
+                 min_impurity_split=None,
+                 bootstrap=True,
+                 oob_score=False,
+                 n_jobs=1,
+                 random_state=None,
+                 verbose=0,
+                 warm_start=False,
+                 class_weight=None):
+        # pylint: disable = too-many-arguments, too-many-locals
+        super().__init__(n_estimators,
+                         criterion,
+                         max_depth,
+                         min_samples_split,
+                         min_samples_leaf,
+                         min_weight_fraction_leaf,
+                         max_features,
+                         max_leaf_nodes,
+                         min_impurity_decrease,
+                         min_impurity_split,
+                         bootstrap,
+                         oob_score,
+                         n_jobs,
+                         random_state,
+                         verbose,
+                         warm_start,
+                         class_weight)
+        self.sampler = sampler
+
+    def fit(self, X, y, sample_weight=None):
+        if self.sampler:  # resample the training fold before fitting
+            X, y = self.sampler.fit_sample(X, y)
+        return super().fit(X, y, sample_weight=sample_weight)

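A minimal sketch of the intended use, on synthetic data (illustrative): because resampling happens inside fit, cross_validate oversamples each training fold only and leaves every test fold at its original class distribution.

    import numpy as np
    from imblearn.over_sampling import SMOTE
    from sklearn.model_selection import cross_validate
    from random_forest_wrapper import RandomForestWrapper

    # Synthetic imbalanced dataset: 900 clean samples, 100 buggy (illustrative)
    X = np.random.rand(1000, 5)
    y = np.array([0] * 900 + [1] * 100)

    clf = RandomForestWrapper(sampler=SMOTE(), n_estimators=200)
    scores = cross_validate(clf, X, y, scoring='f1', cv=10)
    print(np.average(scores['test_score']))
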
+ 94 - 0
code/model/time_sensitive_split.py

@@ -0,0 +1,94 @@
+""" Time sensitive split for Git repository data based on Tan et al.'s Online
+Change Classification """
+__author__ = "Kristian Berg"
+__copyright__ = "Copyright (c) 2018 Axis Communications AB"
+__license__ = "MIT"
+
+import subprocess
+from datetime import datetime, timedelta
+from utils import datetime_of_commit
+
+class GitTimeSensitiveSplit:
+    """ Time sensitive split for Git repository data based on Tan et al.'s Online
+    Change Classification """
+    def __init__(self, path, sgap=timedelta(days=331), gap=timedelta(days=73),
+                 egap=timedelta(days=781), update=timedelta(days=200),
+                 traindur=timedelta(days=1700), testdur=timedelta(days=400),
+                 lastcommit=None, debug=False):
+        self.path = path
+        self.gap = gap
+        self.update = update
+        self.testdur = testdur
+        self.traindur = traindur
+        self.debug = debug
+
+        # Determine date of oldest commit in repository
+        command = ['git', 'log', '--reverse', '--date=iso']
+        self.startdate = datetime_of_commit(path, command=command)
+
+        # Determine date of newest commit in repository
+        if lastcommit:
+            self.enddate = datetime_of_commit(path, lastcommit)
+        else:
+            command = ['git', 'log', '--date=iso']
+            self.enddate = datetime_of_commit(path, command=command)
+
+        # Add start and end gaps
+        self.startdate += sgap
+        self.enddate -= egap
+
+        if self.debug:
+            print('Start: ' + str(self.startdate))
+            print('End: ' + str(self.enddate))
+            print('Duration: ' + str(self.enddate - self.startdate))
+
+        # Build list of commit dates from oldest to newest
+        command = ['git', 'rev-list', '--pretty=%ai', '--reverse', 'HEAD']
+        res = subprocess.run(command, cwd=path, stdout=subprocess.PIPE)
+        gitrevlist = res.stdout.decode('utf-8').split('\n')
+        self.dates = [datetime.strptime(x, '%Y-%m-%d %H:%M:%S %z') for x in gitrevlist[1::2]]
+
+    def split(self, X, y=None, groups=None):
+        """ Split method used by scikit-learn's cross_validate and cross_val_score
+        methods """
+
+        # Initiate loop variables
+        trainset = []
+        testset = []
+        train_index = 0
+        test_index = 0
+        tsplit = self.startdate + self.traindur
+
+        # Adjust start index to correspond to start date
+        while self.dates[train_index] < self.startdate:
+            train_index += 1
+
+        n_pos = 0
+        while tsplit + self.gap + self.testdur < self.enddate:
+            # Set test index to correspond to appropriate date
+            test_index = train_index
+            while self.dates[test_index] < tsplit + self.gap:
+                test_index += 1
+
+            # Build training set
+            while self.dates[train_index] < tsplit:
+                trainset.append(train_index)
+                train_index += 1
+
+            # Build test set, counting positive (buggy) samples as we go
+            testset = []
+            while self.dates[test_index] < tsplit + self.gap + self.testdur:
+                testset.append(test_index)
+                if y[test_index] == 1:
+                    n_pos += 1
+                test_index += 1
+
+            if self.debug:
+                print(str(len(trainset)) + ' ' + str(len(testset)) + ' ' \
+                      + str(n_pos) + ' ' + str(self.dates[test_index]))
+            n_pos = 0
+
+            # Loop update
+            tsplit += self.update
+
+            yield trainset, testset

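A minimal sketch of plugging the split into cross_validate, mirroring evaluate() in model.py (the repository path is illustrative; rows must run oldest commit first, hence the reversal):

    from sklearn.model_selection import cross_validate
    from model import load_data
    from random_forest_wrapper import RandomForestWrapper
    from time_sensitive_split import GitTimeSensitiveSplit

    data, labels, _, _ = load_data('data')
    data, labels = data[::-1], labels[::-1]   # oldest commit first

    split = GitTimeSensitiveSplit(path='/path/to/repo', debug=True)
    clf = RandomForestWrapper(n_estimators=200)
    scores = cross_validate(clf, data, labels, scoring='f1', cv=split)
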
+ 30 - 0
code/model/utils.py

@@ -0,0 +1,30 @@
+"""Returns date of specific commit given a hash
+OR date of first commit result given a command"""
+__author__ = "Kristian Berg"
+__copyright__ = "Copyright (c) 2018 Axis Communications AB"
+__license__ = "MIT"
+
+from datetime import datetime
+import subprocess
+import re
+
+def datetime_of_commit(path, hashval=None, command=None):
+    """Returns date of specific commit given a hash
+    OR date of first commit result given a command"""
+    # Check that either hash or command parameter has a value
+    if hashval:
+        command = ['git', 'show', '--quiet', '--date=iso', hashval]
+    elif command:
+        if command[0] != 'git':
+            raise ValueError('Not a git command')
+        elif '--date=iso' not in command:
+            raise ValueError('Command needs to specify --date=iso')
+    else:
+        raise ValueError('Either hash or command parameter is needed')
+
+    # Get date of commit
+    res = subprocess.run(command, cwd=path, stdout=subprocess.PIPE)
+    gitlog = res.stdout.decode('utf-8', errors='ignore')
+    match = re.search(r'(?<=\nDate:   )[0-9-+: ]+(?=\n)', gitlog).group(0)
+    date = datetime.strptime(match, '%Y-%m-%d %H:%M:%S %z')
+    return date
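
A minimal usage sketch (repository path and hash are illustrative):

    from utils import datetime_of_commit

    # Date of a specific commit, via 'git show'
    print(datetime_of_commit('/path/to/repo', hashval='abc1234'))

    # Date of the first commit printed by a custom git log command
    print(datetime_of_commit('/path/to/repo',
                             command=['git', 'log', '--reverse', '--date=iso']))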