regression.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import contextlib
  2. import os
  3. import pathlib
  4. import signal
  5. import subprocess
  6. import unittest
  # True only when the SGX environment variable is set to exactly '1'.
  # Used to pick SGX-specific manifest names and longer default timeouts.
  HAS_SGX = os.getenv('SGX') == '1'
  def expectedFailureIf(predicate):
      """Conditionally mark a test as an expected failure.

      Returns ``unittest.expectedFailure`` when *predicate* is truthy;
      otherwise returns a decorator that hands the decorated function back
      unchanged.
      """
      def _unchanged(func):
          return func

      return unittest.expectedFailure if predicate else _unchanged
  class RegressionTestCase(unittest.TestCase):
      """Base class for tests that run programs through an external loader.

      The loader executable is taken from the environment variable named by
      ``LOADER_ENV``. A test is skipped when that variable is unset or points
      to a path that does not exist.
      """

      # Environment variable that holds the path of the loader executable.
      LOADER_ENV = 'PAL_LOADER'
      # Default per-run timeout in seconds. It is also the lower bound for
      # caller-supplied timeouts; runs with HAS_SGX set get twice as long.
      DEFAULT_TIMEOUT = 20 if HAS_SGX else 10

      def get_manifest(self, filename):
          """Return the manifest file name that belongs to *filename*.

          This is ``filename + '.manifest'``, with ``'.sgx'`` appended when
          HAS_SGX is true.
          """
          manifest = filename + '.manifest'
          if HAS_SGX:
              manifest += '.sgx'
          return manifest

      def run_binary(self, args, *, timeout=None, **kwds):
          """Run the loader with *args* and return its decoded output.

          Args:
              args: arguments placed after the loader path on the command
                  line.
              timeout: seconds to wait for the process. None selects
                  DEFAULT_TIMEOUT; smaller values are raised to it.
              **kwds: additional keyword arguments for subprocess.Popen.

          Returns:
              A ``(stdout, stderr)`` pair of str, produced by bytes.decode()
              with its default settings (UTF-8, strict).

          Raises:
              unittest.SkipTest: via skipTest, when the loader variable is
                  unset or the loader path does not exist.
              self.failureException (AssertionError by default): via fail,
                  when the timeout expires. The child's process group is
                  sent SIGKILL first.
              subprocess.CalledProcessError: when the process exits with a
                  nonzero status. Its ``cmd`` is *args*, without the loader.
          """
          if timeout is None:
              timeout = self.DEFAULT_TIMEOUT
          else:
              timeout = max(self.DEFAULT_TIMEOUT, timeout)

          try:
              loader = os.environ[self.LOADER_ENV]
          except KeyError:
              self.skipTest(
                  'environment variable {} unset'.format(self.LOADER_ENV))

          if not pathlib.Path(loader).exists():
              self.skipTest('loader ({}) not found'.format(loader))

          command = [loader, *args]
          # os.setpgrp makes the child the leader of a new process group
          # (pgid == pid). The os.killpg call on timeout therefore also
          # reaches any descendants that stay in that group.
          with subprocess.Popen(command,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE,
                                preexec_fn=os.setpgrp,
                                **kwds) as proc:
              try:
                  out, err = proc.communicate(timeout=timeout)
              except subprocess.TimeoutExpired:
                  os.killpg(proc.pid, signal.SIGKILL)
                  self.fail('timeout ({} s) expired'.format(timeout))

              status = proc.returncode
              if status:
                  raise subprocess.CalledProcessError(status, args, out, err)

          return out.decode(), err.decode()

      @contextlib.contextmanager
      def expect_returncode(self, returncode):
          """Assert that the with-block raises CalledProcessError with *returncode*.

          A matching subprocess.CalledProcessError raised inside the block is
          suppressed. The test fails if the block completes without raising,
          or if the raised error carries a different return code.

          Raises:
              ValueError: *returncode* is 0. Because this is a generator-based
                  context manager, this is raised on entering the with-block.
          """
          if returncode == 0:
              raise ValueError('expected returncode should be nonzero')
          try:
              yield
          except subprocess.CalledProcessError as err:
              self.assertEqual(
                  err.returncode, returncode,
                  'failed with returncode {} (expected {})'.format(
                      err.returncode, returncode))
          else:
              self.fail('did not fail (expected {})'.format(returncode))