diff --git a/src/analyzer/protocol/ssl/SSL.cc b/src/analyzer/protocol/ssl/SSL.cc index d46c92b120..2cc072f0cd 100644 --- a/src/analyzer/protocol/ssl/SSL.cc +++ b/src/analyzer/protocol/ssl/SSL.cc @@ -20,20 +20,6 @@ #include #endif -static void print_hex(std::string name, u_char* data, int len) - { - int i = 0; - printf("%s (%d): ", name.c_str(), len); - if ( len > 0 ) - printf("0x%02x", data[0]); - - for ( i = 1; i < len; i++ ) - { - printf(" 0x%02x", data[i]); - } - printf("\n"); - } - namespace zeek::analyzer::ssl { @@ -149,21 +135,21 @@ void SSL_Analyzer::SetSecret(size_t len, const u_char* data) secret.append((const char*)data, len); } -void SSL_Analyzer::SetKeys(zeek::StringVal* keys) - { - SetKeys(keys->Len(), keys->Bytes()); - } - -void SSL_Analyzer::SetKeys(size_t len, const u_char* data) +void SSL_Analyzer::SetKeys(zeek::StringVal* nkeys) { keys.clear(); - keys.reserve(len); - std::copy(data, data + len, std::back_inserter(keys)); + keys.reserve(nkeys->Len()); + std::copy(nkeys->Bytes(), nkeys->Bytes() + nkeys->Len(), std::back_inserter(keys)); } -bool SSL_Analyzer::TLS12_PRF(const std::string& secret, const std::string& label, const char* rnd1, - size_t rnd1_len, const char* rnd2, size_t rnd2_len, u_char* out, - size_t out_len) +void SSL_Analyzer::SetKeys(const std::vector newkeys) + { + keys = newkeys; + } + +std::optional> +SSL_Analyzer::TLS12_PRF(const std::string& secret, const std::string& label, + const std::string& rnd1, const std::string& rnd2, size_t requested_len) { #ifdef OPENSSL_HAVE_KDF_H #if defined(OPENSSL_VERSION_MAJOR) && (OPENSSL_VERSION_MAJOR >= 3) @@ -178,13 +164,12 @@ bool SSL_Analyzer::TLS12_PRF(const std::string& secret, const std::string& label #endif /* OSSL 3 */ // prepare seed: seed = label + rnd1 + rnd2 - size_t seed_len = label.size() + rnd1_len + rnd2_len; std::string seed{}; - seed.reserve(seed_len); + seed.reserve(label.size() + rnd1.size() + rnd2.size()); seed.append(label); - seed.append(rnd1, rnd1_len); - seed.append(rnd2, rnd2_len); + seed.append(rnd1); + seed.append(rnd2); #if defined(OPENSSL_VERSION_MAJOR) && (OPENSSL_VERSION_MAJOR >= 3) // setup OSSL_PARAM array: digest, secret, seed @@ -196,20 +181,23 @@ bool SSL_Analyzer::TLS12_PRF(const std::string& secret, const std::string& label *p++ = OSSL_PARAM_construct_octet_string(OSSL_KDF_PARAM_SEED, (void*)seed.data(), seed.size()); *p = OSSL_PARAM_construct_end(); + auto keybuf = std::vector(requested_len); + // set OSSL params if ( EVP_KDF_CTX_set_params(kctx, params) <= 0 ) goto abort; // derive key material - if ( EVP_KDF_derive(kctx, out, out_len, NULL) <= 0 ) + if ( EVP_KDF_derive(kctx, keybuf.data(), requested_len, nullptr) <= 0 ) goto abort; EVP_KDF_CTX_free(kctx); - return true; + return keybuf; abort: EVP_KDF_CTX_free(kctx); - return false; + return {}; #else /* OSSL 3 */ + auto keybuf = std::vector(requested_len); if ( EVP_PKEY_derive_init(pctx) <= 0 ) goto abort; /* Error */ // setup PKEY params: digest, secret, seed @@ -220,18 +208,18 @@ abort: goto abort; /* Error */ if ( EVP_PKEY_CTX_add1_tls1_prf_seed(pctx, seed.data(), seed.size()) <= 0 ) goto abort; /* Error */ - if ( EVP_PKEY_derive(pctx, out, &out_len) <= 0 ) + if ( EVP_PKEY_derive(pctx, keybuf.data(), &requested_len) <= 0 ) goto abort; /* Error */ EVP_PKEY_CTX_free(pctx); - return true; + return keubuf; abort: EVP_PKEY_CTX_free(pctx); #endif /* OSSL 3 */ #endif /* HAVE_KDF */ - return false; + return {}; } bool SSL_Analyzer::TryDecryptApplicationData(int len, const u_char* data, bool is_orig, @@ -260,19 +248,19 @@ bool SSL_Analyzer::TryDecryptApplicationData(int len, const u_char* data, bool i if ( secret.size() != 0 && keys.size() == 0 ) { #ifdef OPENSSL_HAVE_KDF_H - DBG_LOG(DBG_ANALYZER, "Deriving TLS keys for connection foo"); + DBG_LOG(DBG_ANALYZER, "Deriving TLS keys for connection"); uint32_t ts = htonl((uint32_t)handshake_interp->gmt_unix_time()); - char crand[32] = {0x00}; - u_char keybuf[72]; - auto c_rnd = handshake_interp->client_random(); auto s_rnd = handshake_interp->server_random(); - memcpy(crand, &(ts), 4); - memcpy(crand + 4, c_rnd.data(), c_rnd.length()); - auto res = TLS12_PRF(secret, "key expansion", (char*)s_rnd.data(), s_rnd.length(), crand, - sizeof(crand), keybuf, sizeof(keybuf)); + std::string crand; + crand.append(reinterpret_cast(&(ts)), 4); + crand.append(reinterpret_cast(c_rnd.data()), c_rnd.length()); + std::string srand(reinterpret_cast(s_rnd.data()), s_rnd.length()); + + // fixme - 72 should not be hardcoded + auto res = TLS12_PRF(secret, "key expansion", srand, crand, 72); if ( ! res ) { DBG_LOG(DBG_ANALYZER, "TLS PRF failed. Aborting.\n"); @@ -280,10 +268,11 @@ bool SSL_Analyzer::TryDecryptApplicationData(int len, const u_char* data, bool i } // save derived keys - SetKeys(sizeof(keybuf), keybuf); + SetKeys(res.value()); #else DBG_LOG(DBG_ANALYZER, "Cannot derive TLS keys as Zeek was compiled without "); + return false; #endif } diff --git a/src/analyzer/protocol/ssl/SSL.h b/src/analyzer/protocol/ssl/SSL.h index b96172fc1b..b19bd0b6c0 100644 --- a/src/analyzer/protocol/ssl/SSL.h +++ b/src/analyzer/protocol/ssl/SSL.h @@ -82,7 +82,7 @@ public: * * @param data Pointer to the key buffer as derived via TLS PRF */ - void SetKeys(size_t len, const u_char* data); + void SetKeys(const std::vector newkeys); /** * Try to decrypt TLS application data from a packet. Requires secret or keys to be set prior @@ -122,8 +122,9 @@ public: * * @return True, if the operation completed successfully, false otherwise */ - bool TLS12_PRF(const std::string& secret, const std::string& label, const char* rnd1, - size_t rnd1_len, const char* rnd2, size_t rnd2_len, u_char* out, size_t out_len); + std::optional> TLS12_PRF(const std::string& secret, + const std::string& label, const std::string& rnd1, + const std::string& rnd2, size_t requested_len); /** * Forward decrypted TLS application data to child analyzers