# get_w.py (2.3 KB)
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from hmmlearn import hmm
from scipy.stats import poisson
  6. def get_states(counts, lens):
  7. if len(counts) == 0 or len(lens) == 0:
  8. return
  9. counts = np.array([c for c in counts])
  10. lens = np.array([l for l in lens])
  11. scores = list()
  12. models = list()
  13. for idx in range(10): # ten different random starting states
  14. # define our hidden Markov model
  15. # (because we always prepend an hour of 0 messages,
  16. # and because it helps to ensure what the first state represents,
  17. # we set the probability of starting in the first state to 1,
  18. # and don't include start probability as a parameter to update)
  19. model = hmm.PoissonHMM(n_components=2, random_state=idx,
  20. n_iter=10, params='tl', init_params='tl',
  21. startprob_prior=np.array([1.0,0.0]),
  22. lambdas_prior=np.array([[0.01], [0.1]]))
  23. model.startprob_ = np.array([1.0,0.0])
  24. model.fit(counts[:, None], lens)
  25. models.append(model)
  26. try:
  27. scores.append(model.score(counts[:, None], lens))
  28. except:
  29. print("igoring failed model scoring")
  30. # get the best model
  31. model = models[np.argmax(scores)]
  32. try:
  33. states = model.predict(counts[:, None], lens)
  34. except:
  35. print("failed to predict")
  36. return None, None
  37. if model.lambdas_[0] > model.lambdas_[1]:
  38. states = [int(not(s)) for s in states]
  39. return ','.join([str(s) for s in states]), ','.join([str(l) for l in model.lambdas_])
  40. target_dir = sys.argv[1]
  41. for i in range(2, len(sys.argv)):
  42. file_path = sys.argv[i]
  43. with open(file_path) as f:
  44. lines = f.readlines()
  45. counts = [int(n) for n in lines[0].strip().split(',')]
  46. lens = [int(n) for n in lines[1].strip().split(',')]
  47. states, lambdas = get_states(counts, lens)
  48. if states is None:
  49. continue
  50. file_out = target_dir + '/' + file_path.split('/')[-1]
  51. with open(file_out, 'w') as f:
  52. print(lines[0].strip(), file=f)
  53. print(lines[1].strip(), file=f)
  54. print(lines[2].strip(), file=f)
  55. print(lines[3].strip(), file=f)
  56. print(states, file=f)
  57. print(lambdas, file=f)