123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209 |
- #include <linux/crypto.h>
- #include <linux/init.h>
- #include <linux/interrupt.h>
- #include <linux/mm.h>
- #include <linux/module.h>
- #include <linux/net.h>
- #include <linux/vmalloc.h>
- #include <linux/zstd.h>
/* Compression level that zstd_params() passes to ZSTD_getParams(). */
#define ZSTD_DEF_LEVEL	3
- struct zstd_ctx {
- ZSTD_CCtx *cctx;
- ZSTD_DCtx *dctx;
- void *cwksp;
- void *dwksp;
- };
- static ZSTD_parameters zstd_params(void)
- {
- return ZSTD_getParams(ZSTD_DEF_LEVEL, 0, 0);
- }
- static int zstd_comp_init(struct zstd_ctx *ctx)
- {
- int ret = 0;
- const ZSTD_parameters params = zstd_params();
- const size_t wksp_size = ZSTD_CCtxWorkspaceBound(params.cParams);
- ctx->cwksp = vzalloc(wksp_size);
- if (!ctx->cwksp) {
- ret = -ENOMEM;
- goto out;
- }
- ctx->cctx = ZSTD_initCCtx(ctx->cwksp, wksp_size);
- if (!ctx->cctx) {
- ret = -EINVAL;
- goto out_free;
- }
- out:
- return ret;
- out_free:
- vfree(ctx->cwksp);
- goto out;
- }
- static int zstd_decomp_init(struct zstd_ctx *ctx)
- {
- int ret = 0;
- const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
- ctx->dwksp = vzalloc(wksp_size);
- if (!ctx->dwksp) {
- ret = -ENOMEM;
- goto out;
- }
- ctx->dctx = ZSTD_initDCtx(ctx->dwksp, wksp_size);
- if (!ctx->dctx) {
- ret = -EINVAL;
- goto out_free;
- }
- out:
- return ret;
- out_free:
- vfree(ctx->dwksp);
- goto out;
- }
- static void zstd_comp_exit(struct zstd_ctx *ctx)
- {
- vfree(ctx->cwksp);
- ctx->cwksp = NULL;
- ctx->cctx = NULL;
- }
- static void zstd_decomp_exit(struct zstd_ctx *ctx)
- {
- vfree(ctx->dwksp);
- ctx->dwksp = NULL;
- ctx->dctx = NULL;
- }
/*
 * Set up both halves of a zstd context: compression first, then
 * decompression.
 *
 * Returns 0 on success or the error code of whichever init step failed.
 * If the decompression half fails, the already-initialised compression
 * half is torn down again before returning.
 */
static int __zstd_init(void *ctx)
{
	int err = zstd_comp_init(ctx);

	if (!err) {
		err = zstd_decomp_init(ctx);
		if (err)
			zstd_comp_exit(ctx);
	}

	return err;
}
/* cra_init hook: initialise the zstd context stored in the transform. */
static int zstd_init(struct crypto_tfm *tfm)
{
	return __zstd_init(crypto_tfm_ctx(tfm));
}
/* Tear down both halves of a zstd context: compression, then decompression. */
static void __zstd_exit(void *ctx)
{
	struct zstd_ctx *zctx = ctx;

	zstd_comp_exit(zctx);
	zstd_decomp_exit(zctx);
}
/* cra_exit hook: release the zstd context stored in the transform. */
static void zstd_exit(struct crypto_tfm *tfm)
{
	__zstd_exit(crypto_tfm_ctx(tfm));
}
- static int __zstd_compress(const u8 *src, unsigned int slen,
- u8 *dst, unsigned int *dlen, void *ctx)
- {
- size_t out_len;
- struct zstd_ctx *zctx = ctx;
- const ZSTD_parameters params = zstd_params();
- out_len = ZSTD_compressCCtx(zctx->cctx, dst, *dlen, src, slen, params);
- if (ZSTD_isError(out_len))
- return -EINVAL;
- *dlen = out_len;
- return 0;
- }
- static int zstd_compress(struct crypto_tfm *tfm, const u8 *src,
- unsigned int slen, u8 *dst, unsigned int *dlen)
- {
- struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
- return __zstd_compress(src, slen, dst, dlen, ctx);
- }
- static int __zstd_decompress(const u8 *src, unsigned int slen,
- u8 *dst, unsigned int *dlen, void *ctx)
- {
- size_t out_len;
- struct zstd_ctx *zctx = ctx;
- out_len = ZSTD_decompressDCtx(zctx->dctx, dst, *dlen, src, slen);
- if (ZSTD_isError(out_len))
- return -EINVAL;
- *dlen = out_len;
- return 0;
- }
- static int zstd_decompress(struct crypto_tfm *tfm, const u8 *src,
- unsigned int slen, u8 *dst, unsigned int *dlen)
- {
- struct zstd_ctx *ctx = crypto_tfm_ctx(tfm);
- return __zstd_decompress(src, slen, dst, dlen, ctx);
- }
- static struct crypto_alg alg = {
- .cra_name = "zstd",
- .cra_flags = CRYPTO_ALG_TYPE_COMPRESS,
- .cra_ctxsize = sizeof(struct zstd_ctx),
- .cra_module = THIS_MODULE,
- .cra_init = zstd_init,
- .cra_exit = zstd_exit,
- .cra_u = { .compress = {
- .coa_compress = zstd_compress,
- .coa_decompress = zstd_decompress } }
- };
- static int __init zstd_mod_init(void)
- {
- int ret;
- ret = crypto_register_alg(&alg);
- if (ret)
- return ret;
- return ret;
- }
- static void __exit zstd_mod_fini(void)
- {
- crypto_unregister_alg(&alg);
- }
- module_init(zstd_mod_init);
- module_exit(zstd_mod_fini);
- MODULE_LICENSE("GPL");
- MODULE_DESCRIPTION("Zstd Compression Algorithm");
- MODULE_ALIAS_CRYPTO("zstd");
|