time_sensitive_split.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. """ Time sensitive split for Git repository data based on Tan et al.'s Online
  2. Change Classification """
  3. __author__ = "Kristian Berg"
  4. __copyright__ = "Copyright (c) 2018 Axis Communications AB"
  5. __license__ = "MIT"
  6. import subprocess
  7. from datetime import datetime, timedelta
  8. from utils import datetime_of_commit
  9. class GitTimeSensitiveSplit:
  10. """ Time sensitive split for Git repository data based on Tan et al.'s Online
  11. Change Classification """
  12. def __init__(self, path, sgap=timedelta(days=331), gap=timedelta(days=73),
  13. egap=timedelta(days=781), update=timedelta(days=200),
  14. traindur=timedelta(days=1700), testdur=timedelta(days=400),
  15. lastcommit=None, debug=False):
  16. self.path = path
  17. self.gap = gap
  18. self.update = update
  19. self.testdur = testdur
  20. self.traindur = traindur
  21. self.debug = debug
  22. # Determine date of oldest commit in repository
  23. command = ['git', 'log', '--reverse', '--date=iso']
  24. self.startdate = datetime_of_commit(path, command=command)
  25. # Determine date of newest commit in repository
  26. if lastcommit:
  27. self.enddate = datetime_of_commit(path, lastcommit)
  28. else:
  29. command = ['git', 'log', '--date=iso']
  30. self.enddate = datetime_of_commit(path, command=command)
  31. # Add start and end gaps
  32. self.startdate += sgap
  33. self.enddate -= egap
  34. if self.debug:
  35. print('Start: ' + str(self.startdate))
  36. print('End: ' + str(self.enddate))
  37. print('Duration: ' + str(self.enddate - self.startdate))
  38. # Build list of commit dates from oldest to newest
  39. command = ['git', 'rev-list', '--pretty=%ai', '--reverse', 'HEAD']
  40. res = subprocess.run(command, cwd=path, stdout=subprocess.PIPE)
  41. gitrevlist = res.stdout.decode('utf-8').split('\n')
  42. self.dates = [datetime.strptime(x, '%Y-%m-%d %H:%M:%S %z') for x in gitrevlist[1::2]]
  43. def split(self, X, y=None, group=None):
  44. """ Split method used by scikit-learn's cross_validate and cross_val_score
  45. methods """
  46. # Initiate loop variables
  47. trainset = []
  48. testset = []
  49. train_index = 0
  50. test_index = 0
  51. tsplit = self.startdate + self.traindur
  52. # Adjust start index to correspond to start date
  53. while self.dates[train_index] < self.startdate:
  54. train_index += 1
  55. n_pos = 0
  56. while tsplit + self.gap + self.testdur < self.enddate:
  57. # Set test index to correspond to appropriate date
  58. test_index = train_index
  59. while self.dates[test_index] < tsplit + self.gap:
  60. test_index += 1
  61. # Build training set
  62. while self.dates[train_index] < tsplit:
  63. trainset.append(train_index)
  64. train_index += 1
  65. # Build test set
  66. testset = []
  67. while self.dates[test_index] < tsplit + self.gap + self.testdur:
  68. testset.append(test_index)
  69. test_index += 1
  70. if y[test_index] == 1:
  71. n_pos += 1
  72. if self.debug:
  73. print(str(len(trainset)) + ' ' + str(len(testset)) + ' ' \
  74. + str(n_pos) + ' ' + str(self.dates[test_index]))
  75. n_pos = 0
  76. # Loop update
  77. tsplit += self.update
  78. yield trainset, testset