diff --git a/server.go b/server.go index b839ffb3..475cb3fd 100644 --- a/server.go +++ b/server.go @@ -401,10 +401,10 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro } var ( - token *Token - retrySrcConnectionID *protocol.ConnectionID + token *Token + retrySrcConnID *protocol.ConnectionID ) - origDestConnectionID := hdr.DestConnectionID + origDestConnID := hdr.DestConnectionID if len(hdr.Token) > 0 { c, err := s.tokenGenerator.DecodeToken(hdr.Token) if err == nil { @@ -414,8 +414,8 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro SentTime: c.SentTime, } if token.IsRetryToken { - origDestConnectionID = c.OriginalDestConnectionID - retrySrcConnectionID = &c.RetrySrcConnectionID + origDestConnID = c.OriginalDestConnectionID + retrySrcConnID = &c.RetrySrcConnectionID } } } @@ -451,68 +451,47 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro return err } s.logger.Debugf("Changing connection ID to %s.", connID) - sess := s.createNewSession( - p.remoteAddr, - origDestConnectionID, - retrySrcConnectionID, - hdr.DestConnectionID, - hdr.SrcConnectionID, - connID, - hdr.Version, - ) - if sess == nil { - p.buffer.Release() - return nil - } - sess.handlePacket(p) - s.zeroRTTQueue.DequeueToSession(hdr.DestConnectionID, sess) - return nil -} - -func (s *baseServer) createNewSession( - remoteAddr net.Addr, - origDestConnID protocol.ConnectionID, - retrySrcConnID *protocol.ConnectionID, - clientDestConnID protocol.ConnectionID, - destConnID protocol.ConnectionID, - srcConnID protocol.ConnectionID, - version protocol.VersionNumber, -) quicSession { var sess quicSession - if added := s.sessionHandler.AddWithConnID(clientDestConnID, srcConnID, func() packetHandler { + if added := s.sessionHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler { var tracer logging.ConnectionTracer if s.config.Tracer != nil { // Use the same connection ID that is passed to the client's GetLogWriter callback. - connID := clientDestConnID + connID := hdr.DestConnectionID if origDestConnID.Len() > 0 { connID = origDestConnID } tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID) } sess = s.newSession( - newSendConn(s.conn, remoteAddr), + newSendConn(s.conn, p.remoteAddr), s.sessionHandler, origDestConnID, retrySrcConnID, - clientDestConnID, - destConnID, - srcConnID, - s.sessionHandler.GetStatelessResetToken(srcConnID), + hdr.DestConnectionID, + hdr.SrcConnectionID, + connID, + s.sessionHandler.GetStatelessResetToken(connID), s.config, s.tlsConf, s.tokenGenerator, s.acceptEarlySessions, tracer, s.logger, - version, + hdr.Version, ) + sess.handlePacket(p) return sess }); !added { return nil } go sess.run() go s.handleNewSession(sess) - return sess + if sess == nil { + p.buffer.Release() + return nil + } + s.zeroRTTQueue.DequeueToSession(hdr.DestConnectionID, sess) + return nil } func (s *baseServer) handleNewSession(sess quicSession) { diff --git a/server_test.go b/server_test.go index 600c56ff..6de2ee71 100644 --- a/server_test.go +++ b/server_test.go @@ -958,6 +958,7 @@ var _ = Describe("Server", func() { }() ctx, cancel := context.WithCancel(context.Background()) // handshake context + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.newSession = func( _ sendConn, runner sessionRunner, @@ -975,6 +976,7 @@ var _ = Describe("Server", func() { _ utils.Logger, _ protocol.VersionNumber, ) quicSession { + sess.EXPECT().handlePacket(gomock.Any()) sess.EXPECT().HandshakeComplete().Return(ctx) sess.EXPECT().run().Do(func() {}) sess.EXPECT().Context().Return(context.Background()) @@ -986,7 +988,10 @@ var _ = Describe("Server", func() { return true }) tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()) - serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, nil, protocol.VersionWhatever) + serv.handleInitialImpl( + &receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, + ) Consistently(done).ShouldNot(BeClosed()) cancel() // complete the handshake Eventually(done).Should(BeClosed()) @@ -1026,6 +1031,7 @@ var _ = Describe("Server", func() { }() ready := make(chan struct{}) + serv.config.AcceptToken = func(_ net.Addr, _ *Token) bool { return true } serv.newSession = func( _ sendConn, runner sessionRunner, @@ -1044,6 +1050,7 @@ var _ = Describe("Server", func() { _ protocol.VersionNumber, ) quicSession { Expect(enable0RTT).To(BeTrue()) + sess.EXPECT().handlePacket(gomock.Any()) sess.EXPECT().run().Do(func() {}) sess.EXPECT().earlySessionReady().Return(ready) sess.EXPECT().Context().Return(context.Background()) @@ -1054,7 +1061,10 @@ var _ = Describe("Server", func() { fn() return true }) - serv.createNewSession(&net.UDPAddr{}, nil, nil, nil, nil, nil, protocol.VersionWhatever) + serv.handleInitialImpl( + &receivedPacket{buffer: getPacketBuffer()}, + &wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}}, + ) Consistently(done).ShouldNot(BeClosed()) close(ready) Eventually(done).Should(BeClosed())