Browse Source

[tls] Support stateful session resumption

Record the session ID (if any) provided by the server and attempt to
reuse it for any concurrent connections to the same server.

If multiple connections are initiated concurrently (e.g. when using
PeerDist) then defer sending the ClientHello for all but the first
connection, to allow time for the first connection to potentially
obtain a session ID (and thereby speed up the negotiation for all
remaining connections).

Signed-off-by: Michael Brown <mcb30@ipxe.org>
tags/v1.20.1
Michael Brown 5 years ago
parent
commit
272fe32529
2 changed files with 220 additions and 10 deletions
  1. 29
    2
      src/include/ipxe/tls.h
  2. 191
    8
      src/net/tls.c

+ 29
- 2
src/include/ipxe/tls.h View File

@@ -242,13 +242,40 @@ struct md5_sha1_digest {
242 242
 /** MD5+SHA1 digest size */
243 243
 #define MD5_SHA1_DIGEST_SIZE sizeof ( struct md5_sha1_digest )
244 244
 
245
-/** A TLS connection */
246
-struct tls_connection {
245
+/** A TLS session */
246
+struct tls_session {
247 247
 	/** Reference counter */
248 248
 	struct refcnt refcnt;
249
+	/** List of sessions */
250
+	struct list_head list;
249 251
 
250 252
 	/** Server name */
251 253
 	const char *name;
254
+	/** Session ID */
255
+	uint8_t id[32];
256
+	/** Length of session ID */
257
+	size_t id_len;
258
+	/** Master secret */
259
+	uint8_t master_secret[48];
260
+
261
+	/** List of connections */
262
+	struct list_head conn;
263
+};
264
+
265
+/** A TLS connection */
266
+struct tls_connection {
267
+	/** Reference counter */
268
+	struct refcnt refcnt;
269
+
270
+	/** Session */
271
+	struct tls_session *session;
272
+	/** List of connections within the same session */
273
+	struct list_head list;
274
+	/** Session ID */
275
+	uint8_t session_id[32];
276
+	/** Length of session ID */
277
+	size_t session_id_len;
278
+
252 279
 	/** Plaintext stream */
253 280
 	struct interface plainstream;
254 281
 	/** Ciphertext stream */

+ 191
- 8
src/net/tls.c View File

@@ -175,6 +175,10 @@ FILE_LICENCE ( GPL2_OR_LATER );
175 175
 	__einfo_uniqify ( EINFO_EPROTO, 0x01,				\
176 176
 			  "Illegal protocol version upgrade" )
177 177
 
178
+/** List of TLS session */
179
+static LIST_HEAD ( tls_sessions );
180
+
181
+static void tls_tx_resume_all ( struct tls_session *session );
178 182
 static int tls_send_plaintext ( struct tls_connection *tls, unsigned int type,
179 183
 				const void *data, size_t len );
180 184
 static void tls_clear_cipher ( struct tls_connection *tls,
@@ -307,6 +311,25 @@ struct rsa_digestinfo_prefix rsa_md5_sha1_prefix __rsa_digestinfo_prefix = {
307 311
  ******************************************************************************
308 312
  */
309 313
 
314
+/**
315
+ * Free TLS session
316
+ *
317
+ * @v refcnt		Reference counter
318
+ */
319
+static void free_tls_session ( struct refcnt *refcnt ) {
320
+	struct tls_session *session =
321
+		container_of ( refcnt, struct tls_session, refcnt );
322
+
323
+	/* Sanity check */
324
+	assert ( list_empty ( &session->conn ) );
325
+
326
+	/* Remove from list of sessions */
327
+	list_del ( &session->list );
328
+
329
+	/* Free session */
330
+	free ( session );
331
+}
332
+
310 333
 /**
311 334
  * Free TLS connection
312 335
  *
@@ -315,6 +338,7 @@ struct rsa_digestinfo_prefix rsa_md5_sha1_prefix __rsa_digestinfo_prefix = {
315 338
 static void free_tls ( struct refcnt *refcnt ) {
316 339
 	struct tls_connection *tls =
317 340
 		container_of ( refcnt, struct tls_connection, refcnt );
341
+	struct tls_session *session = tls->session;
318 342
 	struct io_buffer *iobuf;
319 343
 	struct io_buffer *tmp;
320 344
 
@@ -330,8 +354,12 @@ static void free_tls ( struct refcnt *refcnt ) {
330 354
 	x509_put ( tls->cert );
331 355
 	x509_chain_put ( tls->chain );
332 356
 
357
+	/* Drop reference to session */
358
+	assert ( list_empty ( &tls->list ) );
359
+	ref_put ( &session->refcnt );
360
+
333 361
 	/* Free TLS structure itself */
334
-	free ( tls );	
362
+	free ( tls );
335 363
 }
336 364
 
337 365
 /**
@@ -353,6 +381,13 @@ static void tls_close ( struct tls_connection *tls, int rc ) {
353 381
 	intf_shutdown ( &tls->cipherstream, rc );
354 382
 	intf_shutdown ( &tls->plainstream, rc );
355 383
 	intf_shutdown ( &tls->validator, rc );
384
+
385
+	/* Remove from session */
386
+	list_del ( &tls->list );
387
+	INIT_LIST_HEAD ( &tls->list );
388
+
389
+	/* Resume all other connections, in case we were the lead connection */
390
+	tls_tx_resume_all ( tls->session );
356 391
 }
357 392
 
358 393
 /******************************************************************************
@@ -928,6 +963,18 @@ static void tls_tx_resume ( struct tls_connection *tls ) {
928 963
 	process_add ( &tls->process );
929 964
 }
930 965
 
966
+/**
967
+ * Resume TX state machine for all connections within a session
968
+ *
969
+ * @v session		TLS session
970
+ */
971
+static void tls_tx_resume_all ( struct tls_session *session ) {
972
+	struct tls_connection *tls;
973
+
974
+	list_for_each_entry ( tls, &session->conn, list )
975
+		tls_tx_resume ( tls );
976
+}
977
+
931 978
 /**
932 979
  * Transmit Handshake record
933 980
  *
@@ -953,11 +1000,14 @@ static int tls_send_handshake ( struct tls_connection *tls,
953 1000
  * @ret rc		Return status code
954 1001
  */
955 1002
 static int tls_send_client_hello ( struct tls_connection *tls ) {
1003
+	struct tls_session *session = tls->session;
1004
+	size_t name_len = strlen ( session->name );
956 1005
 	struct {
957 1006
 		uint32_t type_length;
958 1007
 		uint16_t version;
959 1008
 		uint8_t random[32];
960 1009
 		uint8_t session_id_len;
1010
+		uint8_t session_id[session->id_len];
961 1011
 		uint16_t cipher_suite_len;
962 1012
 		uint16_t cipher_suites[TLS_NUM_CIPHER_SUITES];
963 1013
 		uint8_t compression_methods_len;
@@ -971,7 +1021,7 @@ static int tls_send_client_hello ( struct tls_connection *tls ) {
971 1021
 				struct {
972 1022
 					uint8_t type;
973 1023
 					uint16_t len;
974
-					uint8_t name[ strlen ( tls->name ) ];
1024
+					uint8_t name[name_len];
975 1025
 				} __attribute__ (( packed )) list[1];
976 1026
 			} __attribute__ (( packed )) server_name;
977 1027
 			uint16_t max_fragment_length_type;
@@ -999,12 +1049,22 @@ static int tls_send_client_hello ( struct tls_connection *tls ) {
999 1049
 	struct tls_signature_hash_algorithm *sighash;
1000 1050
 	unsigned int i;
1001 1051
 
1052
+	/* Record requested session ID and associated master secret */
1053
+	memcpy ( tls->session_id, session->id, sizeof ( tls->session_id ) );
1054
+	tls->session_id_len = session->id_len;
1055
+	memcpy ( tls->master_secret, session->master_secret,
1056
+		 sizeof ( tls->master_secret ) );
1057
+
1058
+	/* Construct record */
1002 1059
 	memset ( &hello, 0, sizeof ( hello ) );
1003 1060
 	hello.type_length = ( cpu_to_le32 ( TLS_CLIENT_HELLO ) |
1004 1061
 			      htonl ( sizeof ( hello ) -
1005 1062
 				      sizeof ( hello.type_length ) ) );
1006 1063
 	hello.version = htons ( tls->version );
1007 1064
 	memcpy ( &hello.random, &tls->client_random, sizeof ( hello.random ) );
1065
+	hello.session_id_len = tls->session_id_len;
1066
+	memcpy ( hello.session_id, tls->session_id,
1067
+		 sizeof ( hello.session_id ) );
1008 1068
 	hello.cipher_suite_len = htons ( sizeof ( hello.cipher_suites ) );
1009 1069
 	i = 0 ; for_each_table_entry ( suite, TLS_CIPHER_SUITES )
1010 1070
 		hello.cipher_suites[i++] = suite->code;
@@ -1018,7 +1078,7 @@ static int tls_send_client_hello ( struct tls_connection *tls ) {
1018 1078
 	hello.extensions.server_name.list[0].type = TLS_SERVER_NAME_HOST_NAME;
1019 1079
 	hello.extensions.server_name.list[0].len
1020 1080
 		= htons ( sizeof ( hello.extensions.server_name.list[0].name ));
1021
-	memcpy ( hello.extensions.server_name.list[0].name, tls->name,
1081
+	memcpy ( hello.extensions.server_name.list[0].name, session->name,
1022 1082
 		 sizeof ( hello.extensions.server_name.list[0].name ) );
1023 1083
 	hello.extensions.max_fragment_length_type
1024 1084
 		= htons ( TLS_MAX_FRAGMENT_LENGTH );
@@ -1513,8 +1573,34 @@ static int tls_new_server_hello ( struct tls_connection *tls,
1513 1573
 	if ( ( rc = tls_select_cipher ( tls, hello_b->cipher_suite ) ) != 0 )
1514 1574
 		return rc;
1515 1575
 
1516
-	/* Generate secrets */
1517
-	tls_generate_master_secret ( tls );
1576
+	/* Reuse or generate master secret */
1577
+	if ( hello_a->session_id_len &&
1578
+	     ( hello_a->session_id_len == tls->session_id_len ) &&
1579
+	     ( memcmp ( session_id, tls->session_id,
1580
+			tls->session_id_len ) == 0 ) ) {
1581
+
1582
+		/* Session ID match: reuse master secret */
1583
+		DBGC ( tls, "TLS %p resuming session ID:\n", tls );
1584
+		DBGC_HDA ( tls, 0, tls->session_id, tls->session_id_len );
1585
+
1586
+	} else {
1587
+
1588
+		/* Generate new master secret */
1589
+		tls_generate_master_secret ( tls );
1590
+
1591
+		/* Record new session ID, if present */
1592
+		if ( hello_a->session_id_len &&
1593
+		     ( hello_a->session_id_len <= sizeof ( tls->session_id ))){
1594
+			tls->session_id_len = hello_a->session_id_len;
1595
+			memcpy ( tls->session_id, session_id,
1596
+				 tls->session_id_len );
1597
+			DBGC ( tls, "TLS %p new session ID:\n", tls );
1598
+			DBGC_HDA ( tls, 0, tls->session_id,
1599
+				   tls->session_id_len );
1600
+		}
1601
+	}
1602
+
1603
+	/* Generate keys */
1518 1604
 	if ( ( rc = tls_generate_keys ( tls ) ) != 0 )
1519 1605
 		return rc;
1520 1606
 
@@ -1739,6 +1825,7 @@ static int tls_new_server_hello_done ( struct tls_connection *tls,
1739 1825
  */
1740 1826
 static int tls_new_finished ( struct tls_connection *tls,
1741 1827
 			      const void *data, size_t len ) {
1828
+	struct tls_session *session = tls->session;
1742 1829
 	struct digest_algorithm *digest = tls->handshake_digest;
1743 1830
 	const struct {
1744 1831
 		uint8_t verify_data[ sizeof ( tls->verify.server ) ];
@@ -1767,6 +1854,30 @@ static int tls_new_finished ( struct tls_connection *tls,
1767 1854
 	/* Mark server as finished */
1768 1855
 	pending_put ( &tls->server_negotiation );
1769 1856
 
1857
+	/* If we are resuming a session (i.e. if the server Finished
1858
+	 * arrives before the client Finished is sent), then schedule
1859
+	 * transmission of Change Cipher and Finished.
1860
+	 */
1861
+	if ( is_pending ( &tls->client_negotiation ) ) {
1862
+		tls->tx_pending |= ( TLS_TX_CHANGE_CIPHER | TLS_TX_FINISHED );
1863
+		tls_tx_resume ( tls );
1864
+	}
1865
+
1866
+	/* Record session ID and master secret, if applicable */
1867
+	if ( tls->session_id_len ) {
1868
+		session->id_len = tls->session_id_len;
1869
+		memcpy ( session->id, tls->session_id, sizeof ( session->id ) );
1870
+		memcpy ( session->master_secret, tls->master_secret,
1871
+			 sizeof ( session->master_secret ) );
1872
+	}
1873
+
1874
+	/* Move to end of session's connection list and allow other
1875
+	 * connections to start making progress.
1876
+	 */
1877
+	list_del ( &tls->list );
1878
+	list_add_tail ( &tls->list, &session->conn );
1879
+	tls_tx_resume_all ( session );
1880
+
1770 1881
 	/* Send notification of a window change */
1771 1882
 	xfer_window_changed ( &tls->plainstream );
1772 1883
 
@@ -2608,6 +2719,7 @@ static struct interface_descriptor tls_cipherstream_desc =
2608 2719
  * @v rc		Reason for completion
2609 2720
  */
2610 2721
 static void tls_validator_done ( struct tls_connection *tls, int rc ) {
2722
+	struct tls_session *session = tls->session;
2611 2723
 	struct tls_cipherspec *cipherspec = &tls->tx_cipherspec_pending;
2612 2724
 	struct pubkey_algorithm *pubkey = cipherspec->suite->pubkey;
2613 2725
 	struct x509_certificate *cert;
@@ -2628,9 +2740,9 @@ static void tls_validator_done ( struct tls_connection *tls, int rc ) {
2628 2740
 	assert ( cert != NULL );
2629 2741
 
2630 2742
 	/* Verify server name */
2631
-	if ( ( rc = x509_check_name ( cert, tls->name ) ) != 0 ) {
2743
+	if ( ( rc = x509_check_name ( cert, session->name ) ) != 0 ) {
2632 2744
 		DBGC ( tls, "TLS %p server certificate does not match %s: %s\n",
2633
-		       tls, tls->name, strerror ( rc ) );
2745
+		       tls, session->name, strerror ( rc ) );
2634 2746
 		goto err;
2635 2747
 	}
2636 2748
 
@@ -2682,6 +2794,8 @@ static struct interface_descriptor tls_validator_desc =
2682 2794
  * @v tls		TLS connection
2683 2795
  */
2684 2796
 static void tls_tx_step ( struct tls_connection *tls ) {
2797
+	struct tls_session *session = tls->session;
2798
+	struct tls_connection *conn;
2685 2799
 	int rc;
2686 2800
 
2687 2801
 	/* Wait for cipherstream to become ready */
@@ -2690,6 +2804,17 @@ static void tls_tx_step ( struct tls_connection *tls ) {
2690 2804
 
2691 2805
 	/* Send first pending transmission */
2692 2806
 	if ( tls->tx_pending & TLS_TX_CLIENT_HELLO ) {
2807
+		/* Wait for session ID to become available unless we
2808
+		 * are the lead connection within the session.
2809
+		 */
2810
+		if ( session->id_len == 0 ) {
2811
+			list_for_each_entry ( conn, &session->conn, list ) {
2812
+				if ( conn == tls )
2813
+					break;
2814
+				if ( is_pending ( &conn->server_negotiation ) )
2815
+					return;
2816
+			}
2817
+		}
2693 2818
 		/* Send Client Hello */
2694 2819
 		if ( ( rc = tls_send_client_hello ( tls ) ) != 0 ) {
2695 2820
 			DBGC ( tls, "TLS %p could not send Client Hello: %s\n",
@@ -2766,6 +2891,60 @@ static void tls_tx_step ( struct tls_connection *tls ) {
2766 2891
 static struct process_descriptor tls_process_desc =
2767 2892
 	PROC_DESC_ONCE ( struct tls_connection, process, tls_tx_step );
2768 2893
 
2894
+/******************************************************************************
2895
+ *
2896
+ * Session management
2897
+ *
2898
+ ******************************************************************************
2899
+ */
2900
+
2901
+/**
2902
+ * Find or create session for TLS connection
2903
+ *
2904
+ * @v tls		TLS connection
2905
+ * @v name		Server name
2906
+ * @ret rc		Return status code
2907
+ */
2908
+static int tls_session ( struct tls_connection *tls, const char *name ) {
2909
+	struct tls_session *session;
2910
+	char *name_copy;
2911
+	int rc;
2912
+
2913
+	/* Find existing matching session, if any */
2914
+	list_for_each_entry ( session, &tls_sessions, list ) {
2915
+		if ( strcmp ( name, session->name ) == 0 ) {
2916
+			ref_get ( &session->refcnt );
2917
+			tls->session = session;
2918
+			DBGC ( tls, "TLS %p joining session %s\n", tls, name );
2919
+			return 0;
2920
+		}
2921
+	}
2922
+
2923
+	/* Create new session */
2924
+	session = zalloc ( sizeof ( *session ) + strlen ( name )
2925
+			   + 1 /* NUL */ );
2926
+	if ( ! session ) {
2927
+		rc = -ENOMEM;
2928
+		goto err_alloc;
2929
+	}
2930
+	ref_init ( &session->refcnt, free_tls_session );
2931
+	name_copy = ( ( ( void * ) session ) + sizeof ( *session ) );
2932
+	strcpy ( name_copy, name );
2933
+	session->name = name_copy;
2934
+	INIT_LIST_HEAD ( &session->conn );
2935
+	list_add ( &session->list, &tls_sessions );
2936
+
2937
+	/* Record session */
2938
+	tls->session = session;
2939
+
2940
+	DBGC ( tls, "TLS %p created session %s\n", tls, name );
2941
+	return 0;
2942
+
2943
+	ref_put ( &session->refcnt );
2944
+ err_alloc:
2945
+	return rc;
2946
+}
2947
+
2769 2948
 /******************************************************************************
2770 2949
  *
2771 2950
  * Instantiator
@@ -2786,7 +2965,7 @@ int add_tls ( struct interface *xfer, const char *name,
2786 2965
 	}
2787 2966
 	memset ( tls, 0, sizeof ( *tls ) );
2788 2967
 	ref_init ( &tls->refcnt, free_tls );
2789
-	tls->name = name;
2968
+	INIT_LIST_HEAD ( &tls->list );
2790 2969
 	intf_init ( &tls->plainstream, &tls_plainstream_desc, &tls->refcnt );
2791 2970
 	intf_init ( &tls->cipherstream, &tls_cipherstream_desc, &tls->refcnt );
2792 2971
 	intf_init ( &tls->validator, &tls_validator_desc, &tls->refcnt );
@@ -2809,6 +2988,9 @@ int add_tls ( struct interface *xfer, const char *name,
2809 2988
 		      ( sizeof ( tls->pre_master_secret.random ) ) ) ) != 0 ) {
2810 2989
 		goto err_random;
2811 2990
 	}
2991
+	if ( ( rc = tls_session ( tls, name ) ) != 0 )
2992
+		goto err_session;
2993
+	list_add_tail ( &tls->list, &tls->session->conn );
2812 2994
 
2813 2995
 	/* Start negotiation */
2814 2996
 	tls_restart ( tls );
@@ -2819,6 +3001,7 @@ int add_tls ( struct interface *xfer, const char *name,
2819 3001
 	ref_put ( &tls->refcnt );
2820 3002
 	return 0;
2821 3003
 
3004
+ err_session:
2822 3005
  err_random:
2823 3006
 	ref_put ( &tls->refcnt );
2824 3007
  err_alloc:

Loading…
Cancel
Save