diff --git a/client_test.go b/client_test.go index 77cf8cc6a..9a767429d 100644 --- a/client_test.go +++ b/client_test.go @@ -449,6 +449,7 @@ var _ = Describe("Client", func() { Context("quic.Config", func() { It("setups with the right values", func() { + srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'} tokenStore := NewLRUTokenStore(10, 4) config := &Config{ HandshakeIdleTimeout: 1337 * time.Minute, @@ -456,7 +457,7 @@ var _ = Describe("Client", func() { MaxIncomingStreams: 1234, MaxIncomingUniStreams: 4321, ConnectionIDLength: 13, - StatelessResetKey: []byte("foobar"), + StatelessResetKey: srk, TokenStore: tokenStore, EnableDatagrams: true, } @@ -466,7 +467,7 @@ var _ = Describe("Client", func() { Expect(c.MaxIncomingStreams).To(BeEquivalentTo(1234)) Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(4321)) Expect(c.ConnectionIDLength).To(Equal(13)) - Expect(c.StatelessResetKey).To(Equal([]byte("foobar"))) + Expect(c.StatelessResetKey).To(Equal(srk)) Expect(c.TokenStore).To(Equal(tokenStore)) Expect(c.EnableDatagrams).To(BeTrue()) }) diff --git a/config_test.go b/config_test.go index bc7b5a2ed..e2bc81530 100644 --- a/config_test.go +++ b/config_test.go @@ -76,7 +76,7 @@ var _ = Describe("Config", func() { case "MaxIncomingUniStreams": f.Set(reflect.ValueOf(int64(12))) case "StatelessResetKey": - f.Set(reflect.ValueOf([]byte{1, 2, 3, 4})) + f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4})) case "KeepAlivePeriod": f.Set(reflect.ValueOf(time.Second)) case "EnableDatagrams": diff --git a/integrationtests/self/stateless_reset_test.go b/integrationtests/self/stateless_reset_test.go index 82718b47b..9f0525ef3 100644 --- a/integrationtests/self/stateless_reset_test.go +++ b/integrationtests/self/stateless_reset_test.go @@ -23,9 +23,9 @@ var _ = Describe("Stateless Resets", func() { connIDLen := connIDLens[i] It(fmt.Sprintf("sends and recognizes stateless resets, for %d byte connection IDs", connIDLen), func() { - statelessResetKey := make([]byte, 32) - rand.Read(statelessResetKey) - serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: statelessResetKey}) + var statelessResetKey quic.StatelessResetKey + rand.Read(statelessResetKey[:]) + serverConfig := getQuicConfig(&quic.Config{StatelessResetKey: &statelessResetKey}) ln, err := quic.ListenAddr("localhost:0", getTLSConfig(), serverConfig) Expect(err).ToNot(HaveOccurred()) diff --git a/interface.go b/interface.go index 611c3f56a..214afcf11 100644 --- a/interface.go +++ b/interface.go @@ -201,6 +201,9 @@ type EarlyConnection interface { NextConnection() Connection } +// StatelessResetKey is a key used to derive stateless reset tokens. +type StatelessResetKey [32]byte + // A ConnectionID is a QUIC Connection ID, as defined in RFC 9000. // It is not able to handle QUIC Connection IDs longer than 20 bytes, // as they are allowed by RFC 8999. @@ -309,7 +312,7 @@ type Config struct { MaxIncomingUniStreams int64 // The StatelessResetKey is used to generate stateless reset tokens. // If no key is configured, sending of stateless resets is disabled. - StatelessResetKey []byte + StatelessResetKey *StatelessResetKey // KeepAlivePeriod defines whether this peer will periodically send a packet to keep the connection alive. // If set to 0, then no keep alive is sent. Otherwise, the keep alive is sent on that period (or at most // every half of MaxIdleTimeout, whichever is smaller). diff --git a/mock_multiplexer_test.go b/mock_multiplexer_test.go index ccc750c11..e589d3605 100644 --- a/mock_multiplexer_test.go +++ b/mock_multiplexer_test.go @@ -36,7 +36,7 @@ func (m *MockMultiplexer) EXPECT() *MockMultiplexerMockRecorder { } // AddConn mocks base method. -func (m *MockMultiplexer) AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) { +func (m *MockMultiplexer) AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "AddConn", c, connIDLen, statelessResetKey, tracer) ret0, _ := ret[0].(packetHandlerManager) diff --git a/multiplexer.go b/multiplexer.go index 2271b5517..d1005039f 100644 --- a/multiplexer.go +++ b/multiplexer.go @@ -1,7 +1,6 @@ package quic import ( - "bytes" "fmt" "net" "sync" @@ -20,13 +19,13 @@ type indexableConn interface { } type multiplexer interface { - AddConn(c net.PacketConn, connIDLen int, statelessResetKey []byte, tracer logging.Tracer) (packetHandlerManager, error) + AddConn(c net.PacketConn, connIDLen int, statelessResetKey *StatelessResetKey, tracer logging.Tracer) (packetHandlerManager, error) RemoveConn(indexableConn) error } type connManager struct { connIDLen int - statelessResetKey []byte + statelessResetKey *StatelessResetKey tracer logging.Tracer manager packetHandlerManager } @@ -37,7 +36,7 @@ type connMultiplexer struct { mutex sync.Mutex conns map[string] /* LocalAddr().String() */ connManager - newPacketHandlerManager func(net.PacketConn, int, []byte, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests + newPacketHandlerManager func(net.PacketConn, int, *StatelessResetKey, logging.Tracer, utils.Logger) (packetHandlerManager, error) // so it can be replaced in the tests logger utils.Logger } @@ -58,7 +57,7 @@ func getMultiplexer() multiplexer { func (m *connMultiplexer) AddConn( c net.PacketConn, connIDLen int, - statelessResetKey []byte, + statelessResetKey *StatelessResetKey, tracer logging.Tracer, ) (packetHandlerManager, error) { m.mutex.Lock() @@ -83,7 +82,7 @@ func (m *connMultiplexer) AddConn( if p.connIDLen != connIDLen { return nil, fmt.Errorf("cannot use %d byte connection IDs on a connection that is already using %d byte connction IDs", connIDLen, p.connIDLen) } - if statelessResetKey != nil && !bytes.Equal(p.statelessResetKey, statelessResetKey) { + if statelessResetKey != nil && p.statelessResetKey != statelessResetKey { return nil, fmt.Errorf("cannot use different stateless reset keys on the same packet conn") } if tracer != p.tracer { diff --git a/multiplexer_test.go b/multiplexer_test.go index 06ab65904..843613785 100644 --- a/multiplexer_test.go +++ b/multiplexer_test.go @@ -25,15 +25,16 @@ var _ = Describe("Multiplexer", func() { }) It("recognizes when the same connection is added twice", func() { + srk := &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'} pconn := NewMockPacketConn(mockCtrl) pconn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}).Times(2) pconn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) conn := testConn{PacketConn: pconn} tracer := mocklogging.NewMockTracer(mockCtrl) - _, err := getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) + _, err := getMultiplexer().AddConn(conn, 8, srk, tracer) Expect(err).ToNot(HaveOccurred()) conn.counter++ - _, err = getMultiplexer().AddConn(conn, 8, []byte("foobar"), tracer) + _, err = getMultiplexer().AddConn(conn, 8, srk, tracer) Expect(err).ToNot(HaveOccurred()) Expect(getMultiplexer().(*connMultiplexer).conns).To(HaveLen(1)) }) @@ -49,12 +50,14 @@ var _ = Describe("Multiplexer", func() { }) It("errors when adding an existing conn with a different stateless rest key", func() { + srk1 := &StatelessResetKey{'f', 'o', 'o'} + srk2 := &StatelessResetKey{'b', 'a', 'r'} conn := NewMockPacketConn(mockCtrl) conn.EXPECT().ReadFrom(gomock.Any()).Do(func([]byte) { <-(make(chan struct{})) }).MaxTimes(1) conn.EXPECT().LocalAddr().Return(&net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 1234}).Times(2) - _, err := getMultiplexer().AddConn(conn, 7, []byte("foobar"), nil) + _, err := getMultiplexer().AddConn(conn, 7, srk1, nil) Expect(err).ToNot(HaveOccurred()) - _, err = getMultiplexer().AddConn(conn, 7, []byte("raboof"), nil) + _, err = getMultiplexer().AddConn(conn, 7, srk2, nil) Expect(err).To(MatchError("cannot use different stateless reset keys on the same packet conn")) }) diff --git a/packet_handler_map.go b/packet_handler_map.go index 6018765a4..55f35b7cf 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -105,7 +105,7 @@ var receiveBufferWarningOnce sync.Once func newPacketHandlerMap( c net.PacketConn, connIDLen int, - statelessResetKey []byte, + statelessResetKey *StatelessResetKey, tracer logging.Tracer, logger utils.Logger, ) (packetHandlerManager, error) { @@ -132,11 +132,13 @@ func newPacketHandlerMap( deleteRetiredConnsAfter: protocol.RetiredConnectionIDDeleteTimeout, zeroRTTQueueDuration: protocol.Max0RTTQueueingDuration, closeQueue: make(chan closePacket, 4), - statelessResetEnabled: len(statelessResetKey) > 0, - statelessResetHasher: hmac.New(sha256.New, statelessResetKey), + statelessResetEnabled: statelessResetKey != nil, tracer: tracer, logger: logger, } + if m.statelessResetEnabled { + m.statelessResetHasher = hmac.New(sha256.New, statelessResetKey[:]) + } go m.listen() go m.runCloseQueue() diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 8ed9f15cc..aa5599838 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -33,7 +33,7 @@ var _ = Describe("Packet Handler Map", func() { packetChan chan packetToRead connIDLen int - statelessResetKey []byte + statelessResetKey *StatelessResetKey ) getPacketWithPacketType := func(connID protocol.ConnectionID, t protocol.PacketType, length protocol.ByteCount) []byte { @@ -440,9 +440,9 @@ var _ = Describe("Packet Handler Map", func() { Context("generating", func() { BeforeEach(func() { - key := make([]byte, 32) - rand.Read(key) - statelessResetKey = key + var key StatelessResetKey + rand.Read(key[:]) + statelessResetKey = &key }) It("generates stateless reset tokens", func() { diff --git a/server_test.go b/server_test.go index cf3d40753..0da7caab4 100644 --- a/server_test.go +++ b/server_test.go @@ -140,7 +140,7 @@ var _ = Describe("Server", func() { HandshakeIdleTimeout: 1337 * time.Hour, MaxIdleTimeout: 42 * time.Minute, KeepAlivePeriod: 5 * time.Second, - StatelessResetKey: []byte("foobar"), + StatelessResetKey: &StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'}, RequireAddressValidation: requireAddrVal, } ln, err := Listen(conn, tlsConf, &config) @@ -152,7 +152,7 @@ var _ = Describe("Server", func() { Expect(server.config.MaxIdleTimeout).To(Equal(42 * time.Minute)) Expect(reflect.ValueOf(server.config.RequireAddressValidation)).To(Equal(reflect.ValueOf(requireAddrVal))) Expect(server.config.KeepAlivePeriod).To(Equal(5 * time.Second)) - Expect(server.config.StatelessResetKey).To(Equal([]byte("foobar"))) + Expect(server.config.StatelessResetKey).To(Equal(&StatelessResetKey{'f', 'o', 'o', 'b', 'a', 'r'})) // stop the listener Expect(ln.Close()).To(Succeed()) })