diff --git a/server.go b/server.go index cab88dc2..65c78a9a 100644 --- a/server.go +++ b/server.go @@ -11,13 +11,12 @@ import ( "sync/atomic" "time" - "github.com/lucas-clemente/quic-go/qlog" - "github.com/lucas-clemente/quic-go/internal/handshake" "github.com/lucas-clemente/quic-go/internal/protocol" "github.com/lucas-clemente/quic-go/internal/qerr" "github.com/lucas-clemente/quic-go/internal/utils" "github.com/lucas-clemente/quic-go/internal/wire" + "github.com/lucas-clemente/quic-go/qlog" ) // packetHandler handles packets @@ -75,7 +74,23 @@ type baseServer struct { receivedPackets chan *receivedPacket // set as a member, so they can be set in the tests - newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* client dest connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, [16]byte, *Config, *tls.Config, *handshake.TokenGenerator, bool /* enable 0-RTT */, qlog.Tracer, utils.Logger, protocol.VersionNumber) quicSession + newSession func( + connection, + sessionRunner, + protocol.ConnectionID, /* original dest connection ID */ + *protocol.ConnectionID, /* retry src connection ID */ + protocol.ConnectionID, /* client dest connection ID */ + protocol.ConnectionID, /* destination connection ID */ + protocol.ConnectionID, /* source connection ID */ + [16]byte, + *Config, + *tls.Config, + *handshake.TokenGenerator, + bool, /* enable 0-RTT */ + qlog.Tracer, + utils.Logger, + protocol.VersionNumber, + ) quicSession serverError error errorChan chan struct{} @@ -347,7 +362,10 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return errors.New("too short connection ID") } - var token *Token + var ( + token *Token + retrySrcConnectionID *protocol.ConnectionID + ) origDestConnectionID := hdr.DestConnectionID if len(hdr.Token) > 0 { c, err := s.tokenGenerator.DecodeToken(hdr.Token) @@ -359,6 +377,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } if token.IsRetryToken { origDestConnectionID = c.OriginalDestConnectionID + retrySrcConnectionID = &c.RetrySrcConnectionID } } } @@ -396,6 +415,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro sess := s.createNewSession( p.remoteAddr, origDestConnectionID, + retrySrcConnectionID, hdr.DestConnectionID, hdr.SrcConnectionID, connID, @@ -419,6 +439,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro func (s *baseServer) createNewSession( remoteAddr net.Addr, origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, @@ -441,6 +462,7 @@ func (s *baseServer) createNewSession( &conn{pconn: s.conn, currentAddr: remoteAddr}, s.sessionHandler, origDestConnID, + retrySrcConnID, clientDestConnID, destConnID, srcConnID, diff --git a/server_test.go b/server_test.go index b466078b..675f3256 100644 --- a/server_test.go +++ b/server_test.go @@ -291,7 +291,11 @@ var _ = Describe("Server", func() { It("creates a session when the token is accepted", func() { serv.config.AcceptToken = func(_ net.Addr, token *Token) bool { return true } - retryToken, err := serv.tokenGenerator.NewRetryToken(&net.UDPAddr{}, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, nil) + retryToken, err := serv.tokenGenerator.NewRetryToken( + &net.UDPAddr{}, + protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad}, + ) Expect(err).ToNot(HaveOccurred()) hdr := &wire.Header{ IsLongHeader: true, @@ -305,16 +309,23 @@ var _ = Describe("Server", func() { run := make(chan struct{}) var token [16]byte rand.Read(token[:]) + var newConnID protocol.ConnectionID - phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte { + phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { newConnID = c - return token + phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte { + newConnID = c + return token + }) + fn() + return true }) sess := NewMockQuicSession(mockCtrl) serv.newSession = func( _ connection, _ sessionRunner, origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, @@ -329,6 +340,7 @@ var _ = Describe("Server", func() { ) quicSession { Expect(enable0RTT).To(BeFalse()) Expect(origDestConnID).To(Equal(protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})) + Expect(retrySrcConnID).To(Equal(&protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad})) Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID)) // make sure we're using a server-generated connection ID @@ -343,12 +355,6 @@ var _ = Describe("Server", func() { return sess } - phm.EXPECT().Add(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, sess).Return(true) - phm.EXPECT().Add(gomock.Any(), sess).DoAndReturn(func(c protocol.ConnectionID, _ packetHandler) bool { - Expect(c).To(Equal(newConnID)) - return true - }) - done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -475,13 +481,24 @@ var _ = Describe("Server", func() { run := make(chan struct{}) var token [16]byte rand.Read(token[:]) + var newConnID protocol.ConnectionID + phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { + newConnID = c + phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte { + newConnID = c + return token + }) + fn() + return true + }) sess := NewMockQuicSession(mockCtrl) serv.newSession = func( _ connection, _ sessionRunner, origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, @@ -496,6 +513,7 @@ var _ = Describe("Server", func() { ) quicSession { Expect(enable0RTT).To(BeFalse()) Expect(origDestConnID).To(Equal(hdr.DestConnectionID)) + Expect(retrySrcConnID).To(BeNil()) Expect(clientDestConnID).To(Equal(hdr.DestConnectionID)) Expect(destConnID).To(Equal(hdr.SrcConnectionID)) // make sure we're using a server-generated connection ID @@ -510,16 +528,6 @@ var _ = Describe("Server", func() { return sess } - phm.EXPECT().AddWithConnID(protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, gomock.Any(), gomock.Any()).DoAndReturn(func(_, c protocol.ConnectionID, fn func() packetHandler) bool { - newConnID = c - phm.EXPECT().GetStatelessResetToken(gomock.Any()).DoAndReturn(func(c protocol.ConnectionID) [16]byte { - newConnID = c - return token - }) - fn() - return true - }) - done := make(chan struct{}) go func() { defer GinkgoRecover() @@ -557,6 +565,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -599,6 +608,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -646,6 +656,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -675,6 +686,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -739,6 +751,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -846,6 +859,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -868,7 +882,7 @@ var _ = Describe("Server", func() { fn() return true }) - serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) + serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, nil, protocol.VersionWhatever) Consistently(done).ShouldNot(BeClosed()) cancel() // complete the handshake Eventually(done).Should(BeClosed()) @@ -912,6 +926,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -935,7 +950,7 @@ var _ = Describe("Server", func() { fn() return true }) - serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, protocol.VersionWhatever) + serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, nil, protocol.VersionWhatever) Consistently(done).ShouldNot(BeClosed()) close(ready) Eventually(done).Should(BeClosed()) @@ -949,6 +964,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, @@ -1007,6 +1023,7 @@ var _ = Describe("Server", func() { _ connection, runner sessionRunner, _ protocol.ConnectionID, + _ *protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, _ protocol.ConnectionID, diff --git a/session.go b/session.go index 4e6b5ddd..6e7e35ce 100644 --- a/session.go +++ b/session.go @@ -113,7 +113,9 @@ type session struct { handshakeDestConnID protocol.ConnectionID // Set for the client. Destination connection ID used on the first Initial sent. origDestConnID protocol.ConnectionID - srcConnIDLen int + retrySrcConnID *protocol.ConnectionID // only set for the client (and if a Retry was performed) + + srcConnIDLen int perspective protocol.Perspective initialVersion protocol.VersionNumber // if version negotiation is performed, this is the version we initially tried @@ -201,6 +203,7 @@ var newSession = func( conn connection, runner sessionRunner, origDestConnID protocol.ConnectionID, + retrySrcConnID *protocol.ConnectionID, clientDestConnID protocol.ConnectionID, destConnID protocol.ConnectionID, srcConnID protocol.ConnectionID, @@ -275,6 +278,7 @@ var newSession = func( OriginalDestinationConnectionID: origDestConnID, ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs, InitialSourceConnectionID: srcConnID, + RetrySourceConnectionID: retrySrcConnID, } if s.qlogger != nil { s.qlogger.SentTransportParameters(params) @@ -874,7 +878,7 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t return false } - s.logger.Debugf("<- Received Retry") + s.logger.Debugf("<- Received Retry: %#v", hdr) s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID) if s.qlogger != nil { s.qlogger.ReceivedRetry(hdr) @@ -886,6 +890,7 @@ func (s *session) handleRetryPacket(hdr *wire.Header, data []byte) bool /* was t return false } s.handshakeDestConnID = newDestConnID + s.retrySrcConnID = &newDestConnID s.cryptoStreamHandler.ChangeConnectionID(newDestConnID) s.packer.SetToken(hdr.Token) s.connIDManager.ChangeInitialConnID(newDestConnID) @@ -1284,6 +1289,12 @@ func (s *session) restoreTransportParameters(params *wire.TransportParameters) { } func (s *session) processTransportParameters(params *wire.TransportParameters) { + if err := s.processTransportParametersImpl(params); err != nil { + s.closeLocal(err) + } +} + +func (s *session) processTransportParametersImpl(params *wire.TransportParameters) error { if s.logger.Debug() { s.logger.Debugf("Processed Transport Parameters: %s", params) } @@ -1293,14 +1304,24 @@ func (s *session) processTransportParameters(params *wire.TransportParameters) { // check the initial_source_connection_id if !params.InitialSourceConnectionID.Equal(s.handshakeDestConnID) { - s.closeLocal(qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID))) - return + return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID)) } - // check the original_destination_connection_id - if s.perspective == protocol.PerspectiveClient && !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { - s.closeLocal(qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID))) - return + if s.perspective == protocol.PerspectiveClient { + // check the original_destination_connection_id + if !params.OriginalDestinationConnectionID.Equal(s.origDestConnID) { + return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID)) + } + if s.retrySrcConnID != nil { // a Retry was performed + if params.RetrySourceConnectionID == nil { + return qerr.NewError(qerr.TransportParameterError, "missing retry_source_connection_id") + } + if !(*params.RetrySourceConnectionID).Equal(*s.retrySrcConnID) { + return qerr.NewError(qerr.TransportParameterError, fmt.Sprintf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID)) + } + } else if params.RetrySourceConnectionID != nil { + return qerr.NewError(qerr.TransportParameterError, "received retry_source_connection_id, although no Retry was performed") + } } s.peerParams = params @@ -1308,8 +1329,7 @@ func (s *session) processTransportParameters(params *wire.TransportParameters) { s.idleTimeout = utils.MinNonZeroDuration(s.config.MaxIdleTimeout, params.MaxIdleTimeout) s.keepAliveInterval = utils.MinDuration(s.idleTimeout/2, protocol.MaxKeepAliveInterval) if err := s.streamsMap.UpdateLimits(params); err != nil { - s.closeLocal(err) - return + return err } s.packer.HandleTransportParameters(params) s.frameParser.SetAckDelayExponent(params.AckDelayExponent) @@ -1330,6 +1350,7 @@ func (s *session) processTransportParameters(params *wire.TransportParameters) { if s.perspective == protocol.PerspectiveServer { close(s.earlySessionReadyChan) } + return nil } func (s *session) sendPackets() error { diff --git a/session_test.go b/session_test.go index ca94d37c..dcd262b2 100644 --- a/session_test.go +++ b/session_test.go @@ -94,6 +94,7 @@ var _ = Describe("Session", func() { mconn, sessionRunner, nil, + nil, clientDestConnID, destConnID, srcConnID, @@ -2280,7 +2281,47 @@ var _ = Describe("Client Session", func() { Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected initial_source_connection_id to equal 0xdeadbeef, is 0xdecafbad"))) }) - It("errors if the TransportParameters contain a wrong original_destination_connection_id", func() { + It("errors if the transport parameters don't contain the retry_source_connection_id, if a Retry was performed", func() { + sess.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, + StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose() + qlogger.EXPECT().ReceivedTransportParameters(params) + sess.processTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: missing retry_source_connection_id"))) + }) + + It("errors if the transport parameters contain the wrong retry_source_connection_id, if a Retry was performed", func() { + sess.retrySrcConnID = &protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose() + qlogger.EXPECT().ReceivedTransportParameters(params) + sess.processTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: expected retry_source_connection_id to equal 0xdeadbeef, is 0xdeadc0de"))) + }) + + It("errors if the transport parameters contain the retry_source_connection_id, if no Retry was performed", func() { + params := &wire.TransportParameters{ + OriginalDestinationConnectionID: destConnID, + InitialSourceConnectionID: destConnID, + RetrySourceConnectionID: &protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde}, + StatelessResetToken: &[16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + } + expectClose() + qlogger.EXPECT().ReceivedTransportParameters(params) + sess.processTransportParameters(params) + Eventually(errChan).Should(Receive(MatchError("TRANSPORT_PARAMETER_ERROR: received retry_source_connection_id, although no Retry was performed"))) + }) + + It("errors if the transport parameters contain a wrong original_destination_connection_id", func() { sess.origDestConnID = protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} params := &wire.TransportParameters{ OriginalDestinationConnectionID: protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad},