compress_zstd.c 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. /* Copyright (c) 2004, Roger Dingledine.
  2. * Copyright (c) 2004-2006, Roger Dingledine, Nick Mathewson.
  3. * Copyright (c) 2007-2017, The Tor Project, Inc. */
  4. /* See LICENSE for licensing information */
  5. /**
  6. * \file compress_zstd.c
  7. * \brief Compression backend for Zstandard.
  8. *
  9. * This module should never be invoked directly. Use the compress module
  10. * instead.
  11. **/
  12. #include "orconfig.h"
  13. #include "util.h"
  14. #include "torlog.h"
  15. #include "compress.h"
  16. #include "compress_zstd.h"
  17. #ifdef HAVE_ZSTD
  18. #include <zstd.h>
  19. #include <zstd_errors.h>
  20. #endif
  /** Running total, in bytes, of the allocation recorded for Zstandard
 * compression/decompression states.  It has static storage duration, so it
 * starts out as zero. */
static size_t total_zstd_allocation;
  /** Return a string representation of the version of libzstd that we are
 * running with (as opposed to the headers we were compiled against), or
 * "N/A" when built without Zstandard support.
 *
 * The result lives in a static buffer that is rewritten on every call. */
const char *
tor_zstd_get_version_str(void)
{
#ifdef HAVE_ZSTD
  static char version_str[16];
  // ZSTD_versionNumber() returns an unsigned int encoded as
  // MAJOR*10000 + MINOR*100 + RELEASE.  Keep it as unsigned and format it
  // with %u: passing a size_t to %lu is undefined behavior wherever size_t
  // and unsigned long differ in width.
  unsigned version_number = ZSTD_versionNumber();

  tor_snprintf(version_str, sizeof(version_str),
               "%u.%u.%u",
               version_number / 10000 % 100,
               version_number / 100 % 100,
               version_number % 100);
  return version_str;
#else
  return "N/A";
#endif
}
  /** Return the libzstd version string taken from the zstd headers used at
 * compilation time, or "N/A" when built without Zstandard support. */
const char *
tor_zstd_get_header_version_str(void)
{
  const char *header_version = "N/A";
#ifdef HAVE_ZSTD
  header_version = ZSTD_VERSION_STRING;
#endif
  return header_version;
}
  53. /** Given <b>in_len</b> bytes at <b>in</b>, compress them into a newly
  54. * allocated buffer, using the Zstandard method. Store the compressed string
  55. * in *<b>out</b>, and its length in *<b>out_len</b>. Return 0 on success, -1
  56. * on failure.
  57. */
  58. int
  59. tor_zstd_compress(char **out, size_t *out_len,
  60. const char *in, size_t in_len,
  61. compress_method_t method)
  62. {
  63. #ifdef HAVE_ZSTD
  64. ZSTD_CStream *stream = NULL;
  65. size_t out_size, old_size;
  66. size_t retval;
  67. tor_assert(out);
  68. tor_assert(out_len);
  69. tor_assert(in);
  70. tor_assert(in_len < UINT_MAX);
  71. tor_assert(method == ZSTD_METHOD);
  72. *out = NULL;
  73. stream = ZSTD_createCStream();
  74. if (stream == NULL) {
  75. // Zstandard does not give us any useful error message to why this
  76. // happened. See https://github.com/facebook/zstd/issues/398
  77. log_warn(LD_GENERAL, "Error while creating Zstandard stream");
  78. goto err;
  79. }
  80. retval = ZSTD_initCStream(stream,
  81. tor_compress_memory_level(HIGH_COMPRESSION));
  82. if (ZSTD_isError(retval)) {
  83. log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
  84. ZSTD_getErrorName(retval));
  85. goto err;
  86. }
  87. // Assume 50% compression and update our buffer in case we need to.
  88. out_size = in_len / 2;
  89. if (out_size < 1024)
  90. out_size = 1024;
  91. *out = tor_malloc(out_size);
  92. *out_len = 0;
  93. ZSTD_inBuffer input = { in, in_len, 0 };
  94. ZSTD_outBuffer output = { *out, out_size, 0 };
  95. while (input.pos < input.size) {
  96. retval = ZSTD_compressStream(stream, &output, &input);
  97. if (ZSTD_isError(retval)) {
  98. log_warn(LD_GENERAL, "Zstandard stream compression error: %s",
  99. ZSTD_getErrorName(retval));
  100. goto err;
  101. }
  102. if (input.pos < input.size && output.pos == output.size) {
  103. old_size = out_size;
  104. out_size *= 2;
  105. if (out_size < old_size) {
  106. log_warn(LD_GENERAL, "Size overflow in Zstandard compression.");
  107. goto err;
  108. }
  109. if (out_size - output.pos > UINT_MAX) {
  110. log_warn(LD_BUG, "Ran over unsigned int limit of Zstandard while "
  111. "compressing.");
  112. goto err;
  113. }
  114. output.dst = *out = tor_realloc(*out, out_size);
  115. output.size = out_size;
  116. }
  117. }
  118. while (1) {
  119. retval = ZSTD_endStream(stream, &output);
  120. if (retval == 0)
  121. break;
  122. if (ZSTD_isError(retval)) {
  123. log_warn(LD_GENERAL, "Zstandard stream error: %s",
  124. ZSTD_getErrorName(retval));
  125. goto err;
  126. }
  127. if (output.pos == output.size) {
  128. old_size = out_size;
  129. out_size *= 2;
  130. if (out_size < old_size) {
  131. log_warn(LD_GENERAL, "Size overflow in Zstandard compression.");
  132. goto err;
  133. }
  134. if (out_size - output.pos > UINT_MAX) {
  135. log_warn(LD_BUG, "Ran over unsigned int limit of Zstandard while "
  136. "compressing.");
  137. goto err;
  138. }
  139. output.dst = *out = tor_realloc(*out, out_size);
  140. output.size = out_size;
  141. }
  142. }
  143. *out_len = output.pos;
  144. if (tor_compress_is_compression_bomb(*out_len, in_len)) {
  145. log_warn(LD_BUG, "We compressed something and got an insanely high "
  146. "compression factor; other Tor instances would think "
  147. "this is a compression bomb.");
  148. goto err;
  149. }
  150. if (stream != NULL) {
  151. ZSTD_freeCStream(stream);
  152. }
  153. return 0;
  154. err:
  155. if (stream != NULL) {
  156. ZSTD_freeCStream(stream);
  157. }
  158. tor_free(*out);
  159. return -1;
  160. #else // HAVE_ZSTD.
  161. (void)out;
  162. (void)out_len;
  163. (void)in;
  164. (void)in_len;
  165. (void)method;
  166. return -1;
  167. #endif // HAVE_ZSTD.
  168. }
  169. /** Given a Zstandard compressed string of total length <b>in_len</b> bytes at
  170. * <b>in</b>, uncompress them into a newly allocated buffer. Store the
  171. * uncompressed string in *<b>out</b>, and its length in *<b>out_len</b>.
  172. * Return 0 on success, -1 on failure.
  173. *
  174. * If <b>complete_only</b> is true, we consider a truncated input as a failure;
  175. * otherwise we decompress as much as we can. Warn about truncated or corrupt
  176. * inputs at <b>protocol_warn_level</b>.
  177. */
  178. int
  179. tor_zstd_uncompress(char **out, size_t *out_len,
  180. const char *in, size_t in_len,
  181. compress_method_t method,
  182. int complete_only,
  183. int protocol_warn_level)
  184. {
  185. #ifdef HAVE_ZSTD
  186. ZSTD_DStream *stream = NULL;
  187. size_t retval;
  188. size_t out_size, old_size;
  189. tor_assert(out);
  190. tor_assert(out_len);
  191. tor_assert(in);
  192. tor_assert(in_len < UINT_MAX);
  193. tor_assert(method == ZSTD_METHOD);
  194. // FIXME(ahf): Handle this?
  195. (void)complete_only;
  196. (void)protocol_warn_level;
  197. *out = NULL;
  198. stream = ZSTD_createDStream();
  199. if (stream == NULL) {
  200. // Zstandard does not give us any useful error message to why this
  201. // happened. See https://github.com/facebook/zstd/issues/398
  202. log_warn(LD_GENERAL, "Error while creating Zstandard stream");
  203. goto err;
  204. }
  205. retval = ZSTD_initDStream(stream);
  206. if (ZSTD_isError(retval)) {
  207. log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
  208. ZSTD_getErrorName(retval));
  209. goto err;
  210. }
  211. out_size = in_len * 2;
  212. if (out_size < 1024)
  213. out_size = 1024;
  214. if (out_size >= SIZE_T_CEILING || out_size > UINT_MAX)
  215. goto err;
  216. *out = tor_malloc(out_size);
  217. *out_len = 0;
  218. ZSTD_inBuffer input = { in, in_len, 0 };
  219. ZSTD_outBuffer output = { *out, out_size, 0 };
  220. while (input.pos < input.size) {
  221. retval = ZSTD_decompressStream(stream, &output, &input);
  222. if (ZSTD_isError(retval)) {
  223. log_warn(LD_GENERAL, "Zstandard stream decompression error: %s",
  224. ZSTD_getErrorName(retval));
  225. goto err;
  226. }
  227. if (input.pos < input.size && output.pos == output.size) {
  228. old_size = out_size;
  229. out_size *= 2;
  230. if (out_size < old_size) {
  231. log_warn(LD_GENERAL, "Size overflow in Zstandard compression.");
  232. goto err;
  233. }
  234. if (tor_compress_is_compression_bomb(in_len, out_size)) {
  235. log_warn(LD_GENERAL, "Input looks like a possible Zstandard "
  236. "compression bomb. Not proceeding.");
  237. goto err;
  238. }
  239. if (out_size >= SIZE_T_CEILING) {
  240. log_warn(LD_BUG, "Hit SIZE_T_CEILING limit while uncompressing "
  241. "Zstandard data.");
  242. goto err;
  243. }
  244. if (out_size - output.pos > UINT_MAX) {
  245. log_warn(LD_BUG, "Ran over unsigned int limit of Zstandard while "
  246. "decompressing.");
  247. goto err;
  248. }
  249. output.dst = *out = tor_realloc(*out, out_size);
  250. output.size = out_size;
  251. }
  252. }
  253. *out_len = output.pos;
  254. if (stream != NULL) {
  255. ZSTD_freeDStream(stream);
  256. }
  257. // NUL-terminate our output.
  258. if (out_size == *out_len)
  259. *out = tor_realloc(*out, out_size + 1);
  260. (*out)[*out_len] = '\0';
  261. return 0;
  262. err:
  263. if (stream != NULL) {
  264. ZSTD_freeDStream(stream);
  265. }
  266. tor_free(*out);
  267. return -1;
  268. #else // HAVE_ZSTD.
  269. (void)out;
  270. (void)out_len;
  271. (void)in;
  272. (void)in_len;
  273. (void)method;
  274. (void)complete_only;
  275. (void)protocol_warn_level;
  276. return -1;
  277. #endif // HAVE_ZSTD.
  278. }
  /** Internal Zstandard state for incremental compression/decompression.
 * Only this file sees the body of this struct; elsewhere it is an opaque
 * handle. */
struct tor_zstd_compress_state_t {
#ifdef HAVE_ZSTD
  /** The libzstd stream object.  Which member is live depends on
   * <b>compress</b>. */
  union {
    ZSTD_CStream *compress_stream; /**< Live when <b>compress</b> is true. */
    ZSTD_DStream *decompress_stream; /**< Live when <b>compress</b> is
                                      * false. */
  } u;
#endif // HAVE_ZSTD.
  /** Nonzero when compressing, zero when decompressing. */
  int compress;
  /** Input bytes consumed so far; used for compression-bomb detection. */
  size_t input_so_far;
  /** Output bytes produced so far; used for compression-bomb detection. */
  size_t output_so_far;
  /** Approximate number of bytes allocated for this object. */
  size_t allocation;
};
  298. /** Construct and return a tor_zstd_compress_state_t object using
  299. * <b>method</b>. If <b>compress</b>, it's for compression; otherwise it's for
  300. * decompression. */
  301. tor_zstd_compress_state_t *
  302. tor_zstd_compress_new(int compress,
  303. compress_method_t method,
  304. compression_level_t compression_level)
  305. {
  306. tor_assert(method == ZSTD_METHOD);
  307. #ifdef HAVE_ZSTD
  308. tor_zstd_compress_state_t *result;
  309. size_t retval;
  310. result = tor_malloc_zero(sizeof(tor_zstd_compress_state_t));
  311. result->compress = compress;
  312. // FIXME(ahf): We should either try to do the pre-calculation that is done
  313. // with the zlib backend or use a custom allocator here where we pass our
  314. // tor_zstd_compress_state_t as the opaque value.
  315. result->allocation = 0;
  316. if (compress) {
  317. result->u.compress_stream = ZSTD_createCStream();
  318. if (result->u.compress_stream == NULL) {
  319. log_warn(LD_GENERAL, "Error while creating Zstandard stream");
  320. goto err;
  321. }
  322. retval = ZSTD_initCStream(result->u.compress_stream,
  323. tor_compress_memory_level(compression_level));
  324. if (ZSTD_isError(retval)) {
  325. log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
  326. ZSTD_getErrorName(retval));
  327. goto err;
  328. }
  329. } else {
  330. result->u.decompress_stream = ZSTD_createDStream();
  331. if (result->u.decompress_stream == NULL) {
  332. log_warn(LD_GENERAL, "Error while creating Zstandard stream");
  333. goto err;
  334. }
  335. retval = ZSTD_initDStream(result->u.decompress_stream);
  336. if (ZSTD_isError(retval)) {
  337. log_warn(LD_GENERAL, "Zstandard stream initialization error: %s",
  338. ZSTD_getErrorName(retval));
  339. goto err;
  340. }
  341. }
  342. return result;
  343. err:
  344. if (compress) {
  345. ZSTD_freeCStream(result->u.compress_stream);
  346. } else {
  347. ZSTD_freeDStream(result->u.decompress_stream);
  348. }
  349. tor_free(result);
  350. return NULL;
  351. #else // HAVE_ZSTD.
  352. (void)compress;
  353. (void)method;
  354. (void)compression_level;
  355. return NULL;
  356. #endif // HAVE_ZSTD.
  357. }
  358. /** Compress/decompress some bytes using <b>state</b>. Read up to
  359. * *<b>in_len</b> bytes from *<b>in</b>, and write up to *<b>out_len</b> bytes
  360. * to *<b>out</b>, adjusting the values as we go. If <b>finish</b> is true,
  361. * we've reached the end of the input.
  362. *
  363. * Return TOR_COMPRESS_DONE if we've finished the entire
  364. * compression/decompression.
  365. * Return TOR_COMPRESS_OK if we're processed everything from the input.
  366. * Return TOR_COMPRESS_BUFFER_FULL if we're out of space on <b>out</b>.
  367. * Return TOR_COMPRESS_ERROR if the stream is corrupt.
  368. */
  369. tor_compress_output_t
  370. tor_zstd_compress_process(tor_zstd_compress_state_t *state,
  371. char **out, size_t *out_len,
  372. const char **in, size_t *in_len,
  373. int finish)
  374. {
  375. #ifdef HAVE_ZSTD
  376. size_t retval;
  377. tor_assert(state != NULL);
  378. tor_assert(*in_len <= UINT_MAX);
  379. tor_assert(*out_len <= UINT_MAX);
  380. ZSTD_inBuffer input = { *in, *in_len, 0 };
  381. ZSTD_outBuffer output = { *out, *out_len, 0 };
  382. if (state->compress) {
  383. retval = ZSTD_compressStream(state->u.compress_stream,
  384. &output, &input);
  385. } else {
  386. retval = ZSTD_decompressStream(state->u.decompress_stream,
  387. &output, &input);
  388. }
  389. state->input_so_far += input.pos;
  390. state->output_so_far += output.pos;
  391. *out = (char *)output.dst + output.pos;
  392. *out_len = output.size - output.pos;
  393. *in = (char *)input.src + input.pos;
  394. *in_len = input.size - input.pos;
  395. if (! state->compress &&
  396. tor_compress_is_compression_bomb(state->input_so_far,
  397. state->output_so_far)) {
  398. log_warn(LD_DIR, "Possible compression bomb; abandoning stream.");
  399. return TOR_COMPRESS_ERROR;
  400. }
  401. if (ZSTD_isError(retval)) {
  402. log_warn(LD_GENERAL, "Zstandard %s didn't finish: %s.",
  403. state->compress ? "compression" : "decompression",
  404. ZSTD_getErrorName(retval));
  405. return TOR_COMPRESS_ERROR;
  406. }
  407. if (state->compress && !finish) {
  408. retval = ZSTD_flushStream(state->u.compress_stream, &output);
  409. *out = (char *)output.dst + output.pos;
  410. *out_len = output.size - output.pos;
  411. if (ZSTD_isError(retval)) {
  412. log_warn(LD_GENERAL, "Zstandard compression unable to flush: %s.",
  413. ZSTD_getErrorName(retval));
  414. return TOR_COMPRESS_ERROR;
  415. }
  416. if (retval > 0)
  417. return TOR_COMPRESS_BUFFER_FULL;
  418. }
  419. if (state->compress && finish) {
  420. retval = ZSTD_endStream(state->u.compress_stream, &output);
  421. *out = (char *)output.dst + output.pos;
  422. *out_len = output.size - output.pos;
  423. if (ZSTD_isError(retval)) {
  424. log_warn(LD_GENERAL, "Zstandard compression unable to write "
  425. "epilogue: %s.",
  426. ZSTD_getErrorName(retval));
  427. return TOR_COMPRESS_ERROR;
  428. }
  429. // endStream returns the number of bytes that is needed to write the
  430. // epilogue.
  431. if (retval > 0)
  432. return TOR_COMPRESS_BUFFER_FULL;
  433. }
  434. return finish ? TOR_COMPRESS_DONE : TOR_COMPRESS_OK;
  435. #else // HAVE_ZSTD.
  436. (void)state;
  437. (void)out;
  438. (void)out_len;
  439. (void)in;
  440. (void)in_len;
  441. (void)finish;
  442. return TOR_COMPRESS_ERROR;
  443. #endif // HAVE_ZSTD.
  444. }
  445. /** Deallocate <b>state</b>. */
  446. void
  447. tor_zstd_compress_free(tor_zstd_compress_state_t *state)
  448. {
  449. if (state == NULL)
  450. return;
  451. total_zstd_allocation -= state->allocation;
  452. #ifdef HAVE_ZSTD
  453. if (state->compress) {
  454. ZSTD_freeCStream(state->u.compress_stream);
  455. } else {
  456. ZSTD_freeDStream(state->u.decompress_stream);
  457. }
  458. #endif // HAVE_ZSTD.
  459. tor_free(state);
  460. }
  461. /** Return the approximate number of bytes allocated for <b>state</b>. */
  462. size_t
  463. tor_zstd_compress_state_size(const tor_zstd_compress_state_t *state)
  464. {
  465. tor_assert(state != NULL);
  466. return state->allocation;
  467. }
  468. /** Return the approximate number of bytes allocated for all Zstandard
  469. * states. */
  470. size_t
  471. tor_zstd_get_total_allocation(void)
  472. {
  473. return total_zstd_allocation;
  474. }