/*
 * This demo runs the same data through CryptoDev and OpenSSL encryption
 * so that each implementation verifies the other.
 *
 * Author: Michal Ludvig <michal@logix.cz>
 *         http://www.logix.cz/michal
 *
 */
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>

#include <sys/ioctl.h>
#include <linux/cryptodev.h>

#include <openssl/evp.h>

/* Bytes encrypted and decrypted per pass.  EVP_Cipher() applies no padding,
 * so this must be a whole number of cipher blocks. */
#define DATA_SIZE       14096
/* Cipher block size in bytes; also the length of the IV buffer. */
#define BLOCK_SIZE      16
#if DATA_SIZE % BLOCK_SIZE != 0
#error "DATA_SIZE must be a multiple of BLOCK_SIZE"
#endif
/* Key length in bytes (16 bytes = 128-bit key). */
#define KEY_SIZE        16
/* Algorithm name handed to CryptoDev; also the prefix of the OpenSSL
 * cipher name built in the OpenSSL pass (e.g. "aes-128-cbc"). */
#define ALG_NAME        "aes"
/* CryptoDev mode flag; also selects "cbc" vs "ecb" for OpenSSL. */
#define MODE            CRYPTO_FLAG_CBC

/*
 * Print msg on its own line, then len bytes of buf as hex, eight bytes
 * per output line.  The final line is always newline-terminated, even when
 * len is not a multiple of eight, so whatever is printed next (e.g. a second
 * dump's label) starts on a fresh line.
 */
void
dump_mem(const char *msg, const unsigned char *buf, size_t len)
{
        size_t i;       /* size_t to match len: no signed/unsigned compare */

        printf("%s\n", msg);
        for (i = 0; i < len; i++)
                printf("%02x%s", buf[i], (i % 8 == 7) ? "\n" : "");
        /* Terminate a partial last line. */
        if (len % 8 != 0)
                putchar('\n');
}

/*
 * Compare two buffers of len bytes each.
 *
 * If they differ, print "FAIL: <msg> differ at offset N" to stderr, where N
 * is the first mismatching offset, then hex-dump (via dump_mem) at most 64
 * bytes of each buffer starting at that offset, each under its label.
 *
 * Returns 0 when the buffers are identical, 1 otherwise.
 */
int
compare_buffers(const char *msg,
                const char *buf1_label, const unsigned char *buf1,
                const char *buf2_label, const unsigned char *buf2,
                size_t len)
{
        size_t off = 0;
        size_t remaining, shown;

        if (memcmp(buf1, buf2, len) == 0)
                return 0;

        /* Walk forward to the first byte that differs. */
        while (off < len && buf1[off] == buf2[off])
                off++;

        fprintf(stderr, "FAIL: %s differ at offset %zu\n", msg, off);

        /* Show at most 64 bytes, and never run past the end of the buffers. */
        remaining = len - off;
        shown = (remaining > 64) ? 64 : remaining;
        dump_mem(buf1_label, buf1 + off, shown);
        dump_mem(buf2_label, buf2 + off, shown);

        return 1;
}

static int
test_crypto(int cfd)
{
        struct {
                char    in[DATA_SIZE],
                        encrypted[DATA_SIZE],
                        decrypted[DATA_SIZE],
                        iv[BLOCK_SIZE],
                        key[KEY_SIZE];
        } data, data_openssl;

        struct session_op sess;
        struct crypt_op cryp;
        char algbuf[100];
        EVP_CIPHER_CTX ctx;
        const EVP_CIPHER *c;

        /* ==== Prepare the input data ==== */

        /* Use the garbage that is on the stack :-) */
        memset(&data, 0, sizeof(data));

        memcpy (&data_openssl, &data, sizeof(data));

        /* ==== First the CryptoDev pass ==== */
        memset(&sess, 0, sizeof(sess));
        memset(&cryp, 0, sizeof(cryp));

        /* Get crypto session for AES128 */
        sess.cipher = CRYPTO_CIPHER_NAME | MODE;
        sess.alg_name = ALG_NAME;
        sess.alg_namelen = strlen(sess.alg_name);
        sess.keylen = KEY_SIZE;
        sess.key = data.key;
        if (ioctl(cfd, CIOCGSESSION, &sess)) {
                perror("ioctl(CIOCGSESSION)");
                return 1;
        }

        /* Encrypt data.in to data.encrypted */
        cryp.ses = sess.ses;
        cryp.len = sizeof(data.in);
        cryp.src = data.in;
        cryp.dst = data.encrypted;
        cryp.iv = data.iv;
        cryp.op = COP_ENCRYPT;
        if (ioctl(cfd, CIOCCRYPT, &cryp)) {
                perror("ioctl(CIOCCRYPT)");
                return 1;
        }

        /* Decrypt data.encrypted to data.decrypted */
        cryp.src = data.encrypted;
        cryp.dst = data.decrypted;
        cryp.op = COP_DECRYPT;
        if (ioctl(cfd, CIOCCRYPT, &cryp)) {
                perror("ioctl(CIOCCRYPT)");
                return 1;
        }

        /* Finish crypto session */
        if (ioctl(cfd, CIOCFSESSION, &sess.ses)) {
                perror("ioctl(CIOCFSESSION)");
                return 1;
        }

        /* ==== Now the OpenSSL pass ==== */

        OpenSSL_add_all_algorithms();
        EVP_CIPHER_CTX_init(&ctx);

        snprintf(algbuf, sizeof(algbuf), "%s-%d-%s", ALG_NAME, KEY_SIZE * 8,
                 MODE & CRYPTO_FLAG_CBC ? "cbc" : "ecb");

        c = EVP_get_cipherbyname(algbuf);
        if (!c) {
                perror("EVP_get_cipherbyname()");
                return 1;
        }

        if (!EVP_CipherInit(&ctx, c, data_openssl.key, data_openssl.iv, 1)) {
                perror("EVP_CipherInit()");
                return 1;
        }

        if (!EVP_Cipher(&ctx, data_openssl.encrypted, data_openssl.in, sizeof(data_openssl.in))) {
                perror("EVP_Cipher()");
                return 1;
        }

        EVP_CIPHER_CTX_cleanup(&ctx);
        EVP_CIPHER_CTX_init(&ctx);

        if (!EVP_CipherInit(&ctx, c, data_openssl.key, data_openssl.iv, 0)) {
                perror("EVP_CipherInit()");
                return 1;
        }

        if (!EVP_Cipher(&ctx, data_openssl.decrypted, data_openssl.encrypted, sizeof(data_openssl.in))) {
                perror("EVP_Cipher()");
                return 1;
        }

        /* ==== Verify the results ==== */

        if (compare_buffers("Encrypted data", "CryptoDev:", data.encrypted,
                            "OpenSSL:", data_openssl.encrypted, sizeof(data.encrypted)))
                return 1;

        if (compare_buffers("Decrypted data", "CryptoDev:", data.decrypted,
                            "OpenSSL:", data_openssl.decrypted, sizeof(data.decrypted)))
                return 1;

        printf("Test passed\n");

        return 0;
}

/*
 * Open /dev/crypto, obtain a cloned descriptor with CRIOGET, run the
 * CryptoDev-vs-OpenSSL comparison on it, and close both descriptors.
 * Descriptors that were opened are closed on every path, including a failed
 * test.  Returns 0 if every step and the test succeed, 1 otherwise.
 */
int
main(void)
{
        int fd, cfd = -1;
        int ret = 1;

        /* Open the crypto device */
        fd = open("/dev/crypto", O_RDWR, 0);
        if (fd < 0) {
                perror("open(/dev/crypto)");
                return 1;
        }

        /* Clone file descriptor */
        if (ioctl(fd, CRIOGET, &cfd)) {
                perror("ioctl(CRIOGET)");
                goto out_fd;
        }

        /* Set close-on-exec (not really needed here: nothing is exec'd) */
        if (fcntl(cfd, F_SETFD, FD_CLOEXEC) == -1) {
                perror("fcntl(F_SETFD)");
                goto out_cfd;
        }

        /* Run the test itself */
        if (test_crypto(cfd))
                goto out_cfd;

        ret = 0;

out_cfd:
        /* Close cloned descriptor */
        if (close(cfd)) {
                perror("close(cfd)");
                ret = 1;
        }
out_fd:
        /* Close the original descriptor */
        if (close(fd)) {
                perror("close(fd)");
                ret = 1;
        }
        return ret;
}