compress_zstd.c 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. /* Copyright (c) 2004, Roger Dingledine.
  2. * Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson.
  3. * Copyright (c) 2007-2017, The Tor Project, Inc. */
  4. /* See LICENSE for licensing information */
  5. /**
  6. * \file compress_zstd.c
  7. * \brief Compression backend for Zstandard.
  8. *
  9. * This module should never be invoked directly. Use the compress module
  10. * instead.
  11. **/
  12. #include "orconfig.h"
  13. #include "util.h"
  14. #include "torlog.h"
  15. #include "compress.h"
  16. #include "compress_zstd.h"
  17. #ifdef HAVE_ZSTD
  18. #include <zstd.h>
  19. #include <zstd_errors.h>
  20. #endif
  /** Running total, in bytes, of the approximate memory held by all
 * tor_zstd_compress_state_t objects. Adjusted by the state constructor and
 * destructor; reported by tor_zstd_get_total_allocation(). Static storage,
 * so it starts out as zero. */
static size_t total_zstd_allocation;
  #ifdef HAVE_ZSTD
/** Map a Tor compression_level_t onto the numeric Zstandard compression
 * level that we hand to ZSTD_initCStream(). Unknown levels are treated like
 * HIGH_COMPRESSION. */
static int
memory_level(compression_level_t level)
{
  if (level == LOW_COMPRESSION)
    return 7;
  if (level == MEDIUM_COMPRESSION)
    return 8;
  /* HIGH_COMPRESSION, and anything we don't recognize. */
  return 9;
}
#endif // HAVE_ZSTD.
  /** Report whether this build can use the Zstandard method: 1 when Tor was
 * compiled with libzstd (HAVE_ZSTD), 0 otherwise. */
int
tor_zstd_method_supported(void)
{
  int supported = 0;
#ifdef HAVE_ZSTD
  supported = 1;
#endif
  return supported;
}
  /** Return a string representation ("MAJOR.MINOR.RELEASE") of the version
 * of the libzstd we are running against, as reported by
 * ZSTD_versionNumber(). Returns NULL if Zstandard is unsupported.
 *
 * The result lives in a static buffer that is overwritten on every call. */
const char *
tor_zstd_get_version_str(void)
{
#ifdef HAVE_ZSTD
  static char version_str[16];
  /* ZSTD_versionNumber() returns an unsigned encoded as
   * MAJOR*10000 + MINOR*100 + RELEASE. Keep it unsigned so that the %u
   * conversions below match the argument types exactly (the old size_t/%lu
   * pairing was undefined behavior where size_t != unsigned long). */
  const unsigned version_number = ZSTD_versionNumber();

  tor_snprintf(version_str, sizeof(version_str),
               "%u.%u.%u",
               version_number / 10000 % 100,
               version_number / 100 % 100,
               version_number % 100);

  return version_str;
#else
  return NULL;
#endif
}
  /** Return the libzstd version string from the headers this file was
 * compiled against (ZSTD_VERSION_STRING), or NULL if Tor was built without
 * Zstandard support. */
const char *
tor_zstd_get_header_version_str(void)
{
  const char *header_version = NULL;
#ifdef HAVE_ZSTD
  header_version = ZSTD_VERSION_STRING;
#endif
  return header_version;
}
  76. /** Given <b>in_len</b> bytes at <b>in</b>, compress them into a newly
  77. * allocated buffer, using the Zstandard method. Store the compressed string
  78. * in *<b>out</b>, and its length in *<b>out_len</b>. Return 0 on success, -1
  79. * on failure.
  80. */
  81. int
  82. tor_zstd_compress(char **out, size_t *out_len,
  83. const char *in, size_t in_len,
  84. compress_method_t method)
  85. {
  86. #ifdef HAVE_ZSTD
  87. ZSTD_CStream *stream = NULL;
  88. size_t out_size, old_size;
  89. size_t retval;
  90. tor_assert(out);
  91. tor_assert(out_len);
  92. tor_assert(in);
  93. tor_assert(in_len < UINT_MAX);
  94. tor_assert(method == ZSTD_METHOD);
  95. *out = NULL;
  96. stream = ZSTD_createCStream();
  97. if (stream == NULL) {
  98. // Zstandard does not give us any useful error message to why this
  99. // happened. See https://github.com/facebook/zstd/issues/398
  100. log_warn(LD_GENERAL, "Error while creating Zstandard stream");
  101. goto err;
  102. }
  103. retval = ZSTD_initCStream(stream,
  104. memory_level(HIGH_COMPRESSION));
  105. if (ZSTD_isError(retval)) {
  106. log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
  107. ZSTD_getErrorName(retval));
  108. goto err;
  109. }
  110. // Assume 50% compression and update our buffer in case we need to.
  111. out_size = in_len / 2;
  112. if (out_size < 1024)
  113. out_size = 1024;
  114. *out = tor_malloc(out_size);
  115. *out_len = 0;
  116. ZSTD_inBuffer input = { in, in_len, 0 };
  117. ZSTD_outBuffer output = { *out, out_size, 0 };
  118. while (input.pos < input.size) {
  119. retval = ZSTD_compressStream(stream, &output, &input);
  120. if (ZSTD_isError(retval)) {
  121. log_warn(LD_GENERAL, "Zstandard stream compression error: %s",
  122. ZSTD_getErrorName(retval));
  123. goto err;
  124. }
  125. if (input.pos < input.size && output.pos == output.size) {
  126. old_size = out_size;
  127. out_size *= 2;
  128. if (out_size < old_size) {
  129. log_warn(LD_GENERAL, "Size overflow in Zstandard compression.");
  130. goto err;
  131. }
  132. if (out_size - output.pos > UINT_MAX) {
  133. log_warn(LD_BUG, "Ran over unsigned int limit of Zstandard while "
  134. "compressing.");
  135. goto err;
  136. }
  137. output.dst = *out = tor_realloc(*out, out_size);
  138. output.size = out_size;
  139. }
  140. }
  141. while (1) {
  142. retval = ZSTD_endStream(stream, &output);
  143. if (retval == 0)
  144. break;
  145. if (ZSTD_isError(retval)) {
  146. log_warn(LD_GENERAL, "Zstandard stream error: %s",
  147. ZSTD_getErrorName(retval));
  148. goto err;
  149. }
  150. if (output.pos == output.size) {
  151. old_size = out_size;
  152. out_size *= 2;
  153. if (out_size < old_size) {
  154. log_warn(LD_GENERAL, "Size overflow in Zstandard compression.");
  155. goto err;
  156. }
  157. if (out_size - output.pos > UINT_MAX) {
  158. log_warn(LD_BUG, "Ran over unsigned int limit of Zstandard while "
  159. "compressing.");
  160. goto err;
  161. }
  162. output.dst = *out = tor_realloc(*out, out_size);
  163. output.size = out_size;
  164. }
  165. }
  166. *out_len = output.pos;
  167. if (tor_compress_is_compression_bomb(*out_len, in_len)) {
  168. log_warn(LD_BUG, "We compressed something and got an insanely high "
  169. "compression factor; other Tor instances would think "
  170. "this is a compression bomb.");
  171. goto err;
  172. }
  173. if (stream != NULL) {
  174. ZSTD_freeCStream(stream);
  175. }
  176. return 0;
  177. err:
  178. if (stream != NULL) {
  179. ZSTD_freeCStream(stream);
  180. }
  181. tor_free(*out);
  182. return -1;
  183. #else // HAVE_ZSTD.
  184. (void)out;
  185. (void)out_len;
  186. (void)in;
  187. (void)in_len;
  188. (void)method;
  189. return -1;
  190. #endif // HAVE_ZSTD.
  191. }
  192. /** Given a Zstandard compressed string of total length <b>in_len</b> bytes at
  193. * <b>in</b>, uncompress them into a newly allocated buffer. Store the
  194. * uncompressed string in *<b>out</b>, and its length in *<b>out_len</b>.
  195. * Return 0 on success, -1 on failure.
  196. *
  197. * If <b>complete_only</b> is true, we consider a truncated input as a failure;
  198. * otherwise we decompress as much as we can. Warn about truncated or corrupt
  199. * inputs at <b>protocol_warn_level</b>.
  200. */
  201. int
  202. tor_zstd_uncompress(char **out, size_t *out_len,
  203. const char *in, size_t in_len,
  204. compress_method_t method,
  205. int complete_only,
  206. int protocol_warn_level)
  207. {
  208. #ifdef HAVE_ZSTD
  209. ZSTD_DStream *stream = NULL;
  210. size_t retval;
  211. size_t out_size, old_size;
  212. tor_assert(out);
  213. tor_assert(out_len);
  214. tor_assert(in);
  215. tor_assert(in_len < UINT_MAX);
  216. tor_assert(method == ZSTD_METHOD);
  217. // FIXME(ahf): Handle this?
  218. (void)complete_only;
  219. (void)protocol_warn_level;
  220. *out = NULL;
  221. stream = ZSTD_createDStream();
  222. if (stream == NULL) {
  223. // Zstandard does not give us any useful error message to why this
  224. // happened. See https://github.com/facebook/zstd/issues/398
  225. log_warn(LD_GENERAL, "Error while creating Zstandard stream");
  226. goto err;
  227. }
  228. retval = ZSTD_initDStream(stream);
  229. if (ZSTD_isError(retval)) {
  230. log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
  231. ZSTD_getErrorName(retval));
  232. goto err;
  233. }
  234. out_size = in_len * 2;
  235. if (out_size < 1024)
  236. out_size = 1024;
  237. if (out_size >= SIZE_T_CEILING || out_size > UINT_MAX)
  238. goto err;
  239. *out = tor_malloc(out_size);
  240. *out_len = 0;
  241. ZSTD_inBuffer input = { in, in_len, 0 };
  242. ZSTD_outBuffer output = { *out, out_size, 0 };
  243. while (input.pos < input.size) {
  244. retval = ZSTD_decompressStream(stream, &output, &input);
  245. if (ZSTD_isError(retval)) {
  246. log_warn(LD_GENERAL, "Zstandard stream decompression error: %s",
  247. ZSTD_getErrorName(retval));
  248. goto err;
  249. }
  250. if (input.pos < input.size && output.pos == output.size) {
  251. old_size = out_size;
  252. out_size *= 2;
  253. if (out_size < old_size) {
  254. log_warn(LD_GENERAL, "Size overflow in Zstandard compression.");
  255. goto err;
  256. }
  257. if (tor_compress_is_compression_bomb(in_len, out_size)) {
  258. log_warn(LD_GENERAL, "Input looks like a possible Zstandard "
  259. "compression bomb. Not proceeding.");
  260. goto err;
  261. }
  262. if (out_size >= SIZE_T_CEILING) {
  263. log_warn(LD_BUG, "Hit SIZE_T_CEILING limit while uncompressing "
  264. "Zstandard data.");
  265. goto err;
  266. }
  267. if (out_size - output.pos > UINT_MAX) {
  268. log_warn(LD_BUG, "Ran over unsigned int limit of Zstandard while "
  269. "decompressing.");
  270. goto err;
  271. }
  272. output.dst = *out = tor_realloc(*out, out_size);
  273. output.size = out_size;
  274. }
  275. }
  276. *out_len = output.pos;
  277. if (stream != NULL) {
  278. ZSTD_freeDStream(stream);
  279. }
  280. // NUL-terminate our output.
  281. if (out_size == *out_len)
  282. *out = tor_realloc(*out, out_size + 1);
  283. (*out)[*out_len] = '\0';
  284. return 0;
  285. err:
  286. if (stream != NULL) {
  287. ZSTD_freeDStream(stream);
  288. }
  289. tor_free(*out);
  290. return -1;
  291. #else // HAVE_ZSTD.
  292. (void)out;
  293. (void)out_len;
  294. (void)in;
  295. (void)in_len;
  296. (void)method;
  297. (void)complete_only;
  298. (void)protocol_warn_level;
  299. return -1;
  300. #endif // HAVE_ZSTD.
  301. }
  /** Internal Zstandard state for incremental compression/decompression.
 * Only this module sees the body of this struct. */
struct tor_zstd_compress_state_t {
#ifdef HAVE_ZSTD
  /** The Zstandard stream object; which member is live depends on
   * <b>compress</b>. */
  union {
    /** Live when <b>compress</b> is true. */
    ZSTD_CStream *compress_stream;
    /** Live when <b>compress</b> is false. */
    ZSTD_DStream *decompress_stream;
  } u;
#endif // HAVE_ZSTD.
  /** Nonzero for a compressing state; zero for a decompressing one. */
  int compress;
  /** Bytes consumed so far; compared with output_so_far to spot
   * compression bombs. */
  size_t input_so_far;
  /** Bytes produced so far; compared with input_so_far to spot
   * compression bombs. */
  size_t output_so_far;
  /** Approximate number of bytes allocated for this object. */
  size_t allocation;
};
  321. /** Construct and return a tor_zstd_compress_state_t object using
  322. * <b>method</b>. If <b>compress</b>, it's for compression; otherwise it's for
  323. * decompression. */
  324. tor_zstd_compress_state_t *
  325. tor_zstd_compress_new(int compress,
  326. compress_method_t method,
  327. compression_level_t compression_level)
  328. {
  329. tor_assert(method == ZSTD_METHOD);
  330. #ifdef HAVE_ZSTD
  331. tor_zstd_compress_state_t *result;
  332. size_t retval;
  333. result = tor_malloc_zero(sizeof(tor_zstd_compress_state_t));
  334. result->compress = compress;
  335. // FIXME(ahf): We should either try to do the pre-calculation that is done
  336. // with the zlib backend or use a custom allocator here where we pass our
  337. // tor_zstd_compress_state_t as the opaque value.
  338. result->allocation = 0;
  339. if (compress) {
  340. result->u.compress_stream = ZSTD_createCStream();
  341. if (result->u.compress_stream == NULL) {
  342. log_warn(LD_GENERAL, "Error while creating Zstandard stream");
  343. goto err;
  344. }
  345. retval = ZSTD_initCStream(result->u.compress_stream,
  346. memory_level(compression_level));
  347. if (ZSTD_isError(retval)) {
  348. log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
  349. ZSTD_getErrorName(retval));
  350. goto err;
  351. }
  352. } else {
  353. result->u.decompress_stream = ZSTD_createDStream();
  354. if (result->u.decompress_stream == NULL) {
  355. log_warn(LD_GENERAL, "Error while creating Zstandard stream");
  356. goto err;
  357. }
  358. retval = ZSTD_initDStream(result->u.decompress_stream);
  359. if (ZSTD_isError(retval)) {
  360. log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
  361. ZSTD_getErrorName(retval));
  362. goto err;
  363. }
  364. }
  365. return result;
  366. err:
  367. if (compress) {
  368. ZSTD_freeCStream(result->u.compress_stream);
  369. } else {
  370. ZSTD_freeDStream(result->u.decompress_stream);
  371. }
  372. tor_free(result);
  373. return NULL;
  374. #else // HAVE_ZSTD.
  375. (void)compress;
  376. (void)method;
  377. (void)compression_level;
  378. return NULL;
  379. #endif // HAVE_ZSTD.
  380. }
  381. /** Compress/decompress some bytes using <b>state</b>. Read up to
  382. * *<b>in_len</b> bytes from *<b>in</b>, and write up to *<b>out_len</b> bytes
  383. * to *<b>out</b>, adjusting the values as we go. If <b>finish</b> is true,
  384. * we've reached the end of the input.
  385. *
  386. * Return TOR_COMPRESS_DONE if we've finished the entire
  387. * compression/decompression.
  388. * Return TOR_COMPRESS_OK if we're processed everything from the input.
  389. * Return TOR_COMPRESS_BUFFER_FULL if we're out of space on <b>out</b>.
  390. * Return TOR_COMPRESS_ERROR if the stream is corrupt.
  391. */
  392. tor_compress_output_t
  393. tor_zstd_compress_process(tor_zstd_compress_state_t *state,
  394. char **out, size_t *out_len,
  395. const char **in, size_t *in_len,
  396. int finish)
  397. {
  398. #ifdef HAVE_ZSTD
  399. size_t retval;
  400. tor_assert(state != NULL);
  401. tor_assert(*in_len <= UINT_MAX);
  402. tor_assert(*out_len <= UINT_MAX);
  403. ZSTD_inBuffer input = { *in, *in_len, 0 };
  404. ZSTD_outBuffer output = { *out, *out_len, 0 };
  405. if (state->compress) {
  406. retval = ZSTD_compressStream(state->u.compress_stream,
  407. &output, &input);
  408. } else {
  409. retval = ZSTD_decompressStream(state->u.decompress_stream,
  410. &output, &input);
  411. }
  412. state->input_so_far += input.pos;
  413. state->output_so_far += output.pos;
  414. *out = (char *)output.dst + output.pos;
  415. *out_len = output.size - output.pos;
  416. *in = (char *)input.src + input.pos;
  417. *in_len = input.size - input.pos;
  418. if (! state->compress &&
  419. tor_compress_is_compression_bomb(state->input_so_far,
  420. state->output_so_far)) {
  421. log_warn(LD_DIR, "Possible compression bomb; abandoning stream.");
  422. return TOR_COMPRESS_ERROR;
  423. }
  424. if (ZSTD_isError(retval)) {
  425. log_warn(LD_GENERAL, "Zstandard %s didn't finish: %s.",
  426. state->compress ? "compression" : "decompression",
  427. ZSTD_getErrorName(retval));
  428. return TOR_COMPRESS_ERROR;
  429. }
  430. if (state->compress && !finish) {
  431. retval = ZSTD_flushStream(state->u.compress_stream, &output);
  432. *out = (char *)output.dst + output.pos;
  433. *out_len = output.size - output.pos;
  434. if (ZSTD_isError(retval)) {
  435. log_warn(LD_GENERAL, "Zstandard compression unable to flush: %s.",
  436. ZSTD_getErrorName(retval));
  437. return TOR_COMPRESS_ERROR;
  438. }
  439. if (retval > 0)
  440. return TOR_COMPRESS_BUFFER_FULL;
  441. }
  442. if (state->compress && finish) {
  443. retval = ZSTD_endStream(state->u.compress_stream, &output);
  444. *out = (char *)output.dst + output.pos;
  445. *out_len = output.size - output.pos;
  446. if (ZSTD_isError(retval)) {
  447. log_warn(LD_GENERAL, "Zstandard compression unable to write "
  448. "epilogue: %s.",
  449. ZSTD_getErrorName(retval));
  450. return TOR_COMPRESS_ERROR;
  451. }
  452. // endStream returns the number of bytes that is needed to write the
  453. // epilogue.
  454. if (retval > 0)
  455. return TOR_COMPRESS_BUFFER_FULL;
  456. }
  457. return finish ? TOR_COMPRESS_DONE : TOR_COMPRESS_OK;
  458. #else // HAVE_ZSTD.
  459. (void)state;
  460. (void)out;
  461. (void)out_len;
  462. (void)in;
  463. (void)in_len;
  464. (void)finish;
  465. return TOR_COMPRESS_ERROR;
  466. #endif // HAVE_ZSTD.
  467. }
  468. /** Deallocate <b>state</b>. */
  469. void
  470. tor_zstd_compress_free(tor_zstd_compress_state_t *state)
  471. {
  472. if (state == NULL)
  473. return;
  474. total_zstd_allocation -= state->allocation;
  475. #ifdef HAVE_ZSTD
  476. if (state->compress) {
  477. ZSTD_freeCStream(state->u.compress_stream);
  478. } else {
  479. ZSTD_freeDStream(state->u.decompress_stream);
  480. }
  481. #endif // HAVE_ZSTD.
  482. tor_free(state);
  483. }
  484. /** Return the approximate number of bytes allocated for <b>state</b>. */
  485. size_t
  486. tor_zstd_compress_state_size(const tor_zstd_compress_state_t *state)
  487. {
  488. tor_assert(state != NULL);
  489. return state->allocation;
  490. }
  491. /** Return the approximate number of bytes allocated for all Zstandard
  492. * states. */
  493. size_t
  494. tor_zstd_get_total_allocation(void)
  495. {
  496. return total_zstd_allocation;
  497. }