training_test_sets.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. """ Generate train and test set. """
  2. __author__ = "Kristian Berg"
  3. __copyright__ = "Copyright (c) 2018 Axis Communications AB"
  4. __license__ = "MIT"
  5. import subprocess
  6. import re
  7. import json
  8. from datetime import datetime, timedelta
  9. # TODO: give update parameter as fraction
  10. def build_sets(path, sgap=timedelta(days=200), gap=timedelta(days=150),
  11. egap=timedelta(days=150), update=timedelta(days=400),
  12. testdur=timedelta(days=70), traindur=timedelta(days=2000)):
  13. # Determine date of oldest commit in repository
  14. command = ['git', 'log', '--reverse', '--date=iso']
  15. startdate = datetime_of_commit(path, command=command)
  16. # Determine date of newest commit in repository
  17. command = ['git', 'log', '--date=iso']
  18. enddate = datetime_of_commit(path, command=command)
  19. # Add start and end gaps
  20. startdate += sgap
  21. enddate -= egap
  22. # Print stuff
  23. print('Start: ' + str(startdate))
  24. print('End: ' + str(enddate))
  25. print('Duration: ' + str(enddate - startdate))
  26. print('len(training) len(testing)')
  27. # Build list of commit hashes from oldest to newest
  28. command = ['git', 'rev-list', '--reverse', 'HEAD']
  29. res = subprocess.run(command, cwd=path, stdout=subprocess.PIPE)
  30. gitrevlist = res.stdout.decode('utf-8')
  31. hashes = gitrevlist.split()
  32. # Initiate loop variables
  33. trainsets = []
  34. testsets = []
  35. training = []
  36. testing = []
  37. train_index = 0
  38. test_index = 0
  39. tsplit = startdate + traindur
  40. # Adjust start index to correspond to start date
  41. commitdate = datetime_of_commit(path, hash=hashes[train_index])
  42. while commitdate < startdate:
  43. train_index += 1
  44. commitdate = datetime_of_commit(path, hash=hashes[train_index])
  45. # TODO: Last few commits are not used
  46. while tsplit + gap + testdur < enddate:
  47. # Set test index to correspond to appropriate date
  48. test_index = train_index
  49. commitdate = datetime_of_commit(path, hash=hashes[test_index])
  50. while commitdate < tsplit + gap:
  51. test_index += 1
  52. commitdate = datetime_of_commit(path, hash=hashes[test_index])
  53. # Build training set
  54. commitdate = datetime_of_commit(path, hash=hashes[train_index])
  55. while commitdate < tsplit:
  56. training.append(hashes[train_index])
  57. train_index += 1
  58. commitdate = datetime_of_commit(path, hash=hashes[train_index])
  59. trainsets.append(list(training))
  60. # Build test set
  61. testing = []
  62. commitdate = datetime_of_commit(path, hash=hashes[test_index])
  63. while commitdate < tsplit + gap + testdur:
  64. testing.append(hashes[test_index])
  65. test_index += 1
  66. commitdate = datetime_of_commit(path, hash=hashes[test_index])
  67. testsets.append(list(testing))
  68. # Print stuff
  69. print(str(len(training)) + ' ' + str(len(testing)))
  70. # Loop update
  71. tsplit += update
  72. # Write results to file
  73. with open('trainsets.json', 'w') as f:
  74. f.write(json.dumps(trainsets))
  75. with open('testsets.json', 'w') as f:
  76. f.write(json.dumps(testsets))
  77. # Returns date of specific commit given a hash
  78. # OR date of first commit result given a command
  79. def datetime_of_commit(path, hash=None, command=None):
  80. # Check that either hash or command parameter has a value
  81. if hash:
  82. command = ['git', 'show', '--quiet', '--date=iso', hash]
  83. elif command:
  84. if command[0] != 'git':
  85. raise ValueError('Not a git command')
  86. elif '--date=iso' not in command:
  87. raise ValueError('Command needs to specify --date=iso')
  88. else:
  89. raise ValueError('Either hash or command parameter is needed')
  90. # Get date of commit
  91. res = subprocess.run(command, cwd=path, stdout=subprocess.PIPE)
  92. gitlog = res.stdout.decode('utf-8', errors='ignore')
  93. match = re.search('(?<=\nDate: )[0-9-+: ]+(?=\n)', gitlog).group(0)
  94. date = datetime.strptime(match, '%Y-%m-%d %H:%M:%S %z')
  95. return date
  96. if __name__ == '__main__':
  97. build_sets('/home/kristiab/Git/jenkins')