diff --git a/src/libtls/tls_hkdf.c b/src/libtls/tls_hkdf.c index b03efa0bd..002d481f1 100644 --- a/src/libtls/tls_hkdf.c +++ b/src/libtls/tls_hkdf.c @@ -22,6 +22,11 @@ typedef struct private_tls_hkdf_t private_tls_hkdf_t; +typedef struct cached_secrets_t { + chunk_t client; + chunk_t server; +} cached_secrets_t; + typedef enum hkdf_phase { HKDF_PHASE_0, HKDF_PHASE_1, @@ -67,11 +72,14 @@ struct private_tls_hkdf_t { chunk_t prk; /** - * Current implementation needs a copy of derived secrets to calculate the - * proper finished key. + * Handshake traffic secrets. */ - chunk_t client_traffic_secret; - chunk_t server_traffic_secret; + cached_secrets_t handshake_traffic_secrets; + + /** + * Current traffic secrets. + */ + cached_secrets_t traffic_secrets; }; static char *hkdf_labels[] = { @@ -418,11 +426,11 @@ METHOD(tls_hkdf_t, generate_secret, bool, if (label == TLS_HKDF_UPD_C_TRAFFIC || label == TLS_HKDF_UPD_S_TRAFFIC) { - chunk_t previous = this->client_traffic_secret; + chunk_t previous = this->traffic_secrets.client; if (label == TLS_HKDF_UPD_S_TRAFFIC) { - previous = this->server_traffic_secret; + previous = this->traffic_secrets.server; } if (!expand_label(this, previous, chunk_from_str("traffic upd"), @@ -446,16 +454,22 @@ METHOD(tls_hkdf_t, generate_secret, bool, switch (label) { case TLS_HKDF_C_HS_TRAFFIC: + chunk_clear(&this->handshake_traffic_secrets.client); + this->handshake_traffic_secrets.client = chunk_clone(okm); + /* fall-through */ case TLS_HKDF_C_AP_TRAFFIC: case TLS_HKDF_UPD_C_TRAFFIC: - chunk_clear(&this->client_traffic_secret); - this->client_traffic_secret = chunk_clone(okm); + chunk_clear(&this->traffic_secrets.client); + this->traffic_secrets.client = chunk_clone(okm); break; case TLS_HKDF_S_HS_TRAFFIC: + chunk_clear(&this->handshake_traffic_secrets.server); + this->handshake_traffic_secrets.server = chunk_clone(okm); + /* fall-through */ case TLS_HKDF_S_AP_TRAFFIC: case TLS_HKDF_UPD_S_TRAFFIC: - chunk_clear(&this->server_traffic_secret); - this->server_traffic_secret = chunk_clone(okm); + chunk_clear(&this->traffic_secrets.server); + this->traffic_secrets.server = chunk_clone(okm); break; default: break; @@ -476,12 +490,12 @@ METHOD(tls_hkdf_t, generate_secret, bool, * Derive keys/IVs from the current traffic secrets. */ static bool get_shared_label_keys(private_tls_hkdf_t *this, chunk_t label, + cached_secrets_t *secrets, bool server, size_t length, chunk_t *key) { chunk_t result = chunk_empty, secret; - secret = server ? this->server_traffic_secret - : this->client_traffic_secret; + secret = server ? secrets->server : secrets->client; if (!expand_label(this, secret, label, chunk_empty, length, &result)) { @@ -504,22 +518,22 @@ static bool get_shared_label_keys(private_tls_hkdf_t *this, chunk_t label, METHOD(tls_hkdf_t, derive_key, bool, private_tls_hkdf_t *this, bool is_server, size_t length, chunk_t *key) { - return get_shared_label_keys(this, chunk_from_str("key"), is_server, - length, key); + return get_shared_label_keys(this, chunk_from_str("key"), + &this->traffic_secrets, is_server, length, key); } METHOD(tls_hkdf_t, derive_iv, bool, private_tls_hkdf_t *this, bool is_server, size_t length, chunk_t *iv) { - return get_shared_label_keys(this, chunk_from_str("iv"), is_server, - length, iv); + return get_shared_label_keys(this, chunk_from_str("iv"), + &this->traffic_secrets, is_server, length, iv); } METHOD(tls_hkdf_t, derive_finished, bool, private_tls_hkdf_t *this, bool server, chunk_t *finished) { return get_shared_label_keys(this, chunk_from_str("finished"), - server, + &this->handshake_traffic_secrets, server, this->hasher->get_hash_size(this->hasher), finished); } @@ -580,14 +594,23 @@ METHOD(tls_hkdf_t, allocate_bytes, bool, this->prf->allocate_bytes(this->prf, seed, out); } +/** + * Clean up secrets + */ +static void destroy_secrets(cached_secrets_t *secrets) +{ + chunk_clear(&secrets->client); + chunk_clear(&secrets->server); +} + METHOD(tls_hkdf_t, destroy, void, private_tls_hkdf_t *this) { chunk_clear(&this->psk); chunk_clear(&this->prk); chunk_clear(&this->shared_secret); - chunk_clear(&this->client_traffic_secret); - chunk_clear(&this->server_traffic_secret); + destroy_secrets(&this->handshake_traffic_secrets); + destroy_secrets(&this->traffic_secrets); DESTROY_IF(this->prf); DESTROY_IF(this->hasher); free(this);