diff --git a/mock_session_handler_test.go b/mock_session_handler_test.go new file mode 100644 index 00000000..522cee2a --- /dev/null +++ b/mock_session_handler_test.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: SessionHandler) + +// Package quic is a generated GoMock package. +package quic + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockSessionHandler is a mock of SessionHandler interface +type MockSessionHandler struct { + ctrl *gomock.Controller + recorder *MockSessionHandlerMockRecorder +} + +// MockSessionHandlerMockRecorder is the mock recorder for MockSessionHandler +type MockSessionHandlerMockRecorder struct { + mock *MockSessionHandler +} + +// NewMockSessionHandler creates a new mock instance +func NewMockSessionHandler(ctrl *gomock.Controller) *MockSessionHandler { + mock := &MockSessionHandler{ctrl: ctrl} + mock.recorder = &MockSessionHandlerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockSessionHandler) EXPECT() *MockSessionHandlerMockRecorder { + return m.recorder +} + +// Add mocks base method +func (m *MockSessionHandler) Add(arg0 protocol.ConnectionID, arg1 packetHandler) { + m.ctrl.Call(m, "Add", arg0, arg1) +} + +// Add indicates an expected call of Add +func (mr *MockSessionHandlerMockRecorder) Add(arg0, arg1 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockSessionHandler)(nil).Add), arg0, arg1) +} + +// Close mocks base method +func (m *MockSessionHandler) Close() { + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close +func (mr *MockSessionHandlerMockRecorder) Close() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSessionHandler)(nil).Close)) +} + +// Get mocks base method +func (m *MockSessionHandler) Get(arg0 protocol.ConnectionID) (packetHandler, bool) { + ret := m.ctrl.Call(m, "Get", arg0) + ret0, _ := ret[0].(packetHandler) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Get indicates an expected call of Get +func (mr *MockSessionHandlerMockRecorder) Get(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSessionHandler)(nil).Get), arg0) +} + +// Remove mocks base method +func (m *MockSessionHandler) Remove(arg0 protocol.ConnectionID) { + m.ctrl.Call(m, "Remove", arg0) +} + +// Remove indicates an expected call of Remove +func (mr *MockSessionHandlerMockRecorder) Remove(arg0 interface{}) *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockSessionHandler)(nil).Remove), arg0) +} diff --git a/mockgen.go b/mockgen.go index f2d83575..2ebef9cc 100644 --- a/mockgen.go +++ b/mockgen.go @@ -13,5 +13,6 @@ package quic //go:generate sh -c "./mockgen_private.sh quic mock_gquic_aead_test.go github.com/lucas-clemente/quic-go gQUICAEAD GQUICAEAD" //go:generate sh -c "./mockgen_private.sh quic mock_session_runner_test.go github.com/lucas-clemente/quic-go sessionRunner SessionRunner" //go:generate sh -c "./mockgen_private.sh quic mock_packet_handler_test.go github.com/lucas-clemente/quic-go packetHandler PacketHandler" +//go:generate sh -c "./mockgen_private.sh quic mock_session_handler_test.go github.com/lucas-clemente/quic-go sessionHandler SessionHandler" //go:generate sh -c "find . -type f -name 'mock_*_test.go' | xargs sed -i '' 's/quic_go.//g'" //go:generate sh -c "goimports -w mock*_test.go" diff --git a/server.go b/server.go index 105ce9ff..963801b7 100644 --- a/server.go +++ b/server.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "net" - "sync" "time" "github.com/lucas-clemente/quic-go/internal/crypto" @@ -42,6 +41,13 @@ func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectio var _ sessionRunner = &runner{} +type sessionHandler interface { + Add(protocol.ConnectionID, packetHandler) + Get(protocol.ConnectionID) (packetHandler, bool) + Remove(protocol.ConnectionID) + Close() +} + // A Listener of QUIC type server struct { tlsConf *tls.Config @@ -55,9 +61,7 @@ type server struct { certChain crypto.CertChain scfg *handshake.ServerConfig - sessionsMutex sync.RWMutex - sessions map[string] /* string(ConnectionID)*/ packetHandler - closed bool + sessionHandler sessionHandler serverError error @@ -65,9 +69,8 @@ type server struct { errorChan chan struct{} sessionRunner sessionRunner - // set as members, so they can be set in the tests - newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error) - deleteClosedSessionsAfter time.Duration + // set as a member, so they can be set in the tests + newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (packetHandler, error) logger utils.Logger } @@ -115,18 +118,17 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, } s := &server{ - conn: conn, - tlsConf: tlsConf, - config: config, - certChain: certChain, - scfg: scfg, - sessions: map[string]packetHandler{}, - newSession: newSession, - deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, - sessionQueue: make(chan Session, 5), - errorChan: make(chan struct{}), - supportsTLS: supportsTLS, - logger: utils.DefaultLogger.WithPrefix("server"), + conn: conn, + tlsConf: tlsConf, + config: config, + certChain: certChain, + scfg: scfg, + newSession: newSession, + sessionHandler: newSessionMap(), + sessionQueue: make(chan Session, 5), + errorChan: make(chan struct{}), + supportsTLS: supportsTLS, + logger: utils.DefaultLogger.WithPrefix("server"), } s.setup() if supportsTLS { @@ -142,7 +144,7 @@ func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, func (s *server) setup() { s.sessionRunner = &runner{ onHandshakeCompleteImpl: func(sess packetHandler) { s.sessionQueue <- sess }, - removeConnectionIDImpl: s.removeConnection, + removeConnectionIDImpl: s.sessionHandler.Remove, } } @@ -165,13 +167,13 @@ func (s *server) setupTLS() error { case tlsSession := <-sessionChan: connID := tlsSession.connID sess := tlsSession.sess - s.sessionsMutex.Lock() - if _, ok := s.sessions[string(connID)]; ok { // drop this session if it already exists - s.sessionsMutex.Unlock() + if _, ok := s.sessionHandler.Get(connID); ok { // drop this session if it already exists continue } - s.sessions[string(connID)] = sess - s.sessionsMutex.Unlock() + // TODO(#1003): There's a race condition here. + // If another connection with the same conn ID is added between Get() and Add(), it would be overwritten. + // We can avoid this be using server-chosen connection IDs. + s.sessionHandler.Add(connID, sess) go sess.run() } } @@ -288,27 +290,7 @@ func (s *server) Accept() (Session, error) { // Close the server func (s *server) Close() error { - s.sessionsMutex.Lock() - if s.closed { - s.sessionsMutex.Unlock() - return nil - } - s.closed = true - - var wg sync.WaitGroup - for _, session := range s.sessions { - if session != nil { - wg.Add(1) - go func(sess packetHandler) { - // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped - _ = sess.Close(nil) - wg.Done() - }(session) - } - } - s.sessionsMutex.Unlock() - wg.Wait() - + s.sessionHandler.Close() err := s.conn.Close() <-s.errorChan // wait for serve() to return return err @@ -359,10 +341,7 @@ func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot } } - s.sessionsMutex.RLock() - session, sessionKnown := s.sessions[string(hdr.DestConnectionID)] - s.sessionsMutex.RUnlock() - + session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID) if sessionKnown && session == nil { // Late packet for closed session return nil @@ -382,21 +361,18 @@ func (s *server) handleIETFQUICPacket(hdr *wire.Header, packetData []byte, remot } func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAddr net.Addr, rcvTime time.Time) error { - s.sessionsMutex.RLock() - session, sessionKnown := s.sessions[string(hdr.DestConnectionID)] - s.sessionsMutex.RUnlock() - - if sessionKnown && session == nil { - // Late packet for closed session - return nil - } - // ignore all Public Reset packets if hdr.ResetFlag { s.logger.Infof("Received unexpected Public Reset for connection %s.", hdr.DestConnectionID) return nil } + session, sessionKnown := s.sessionHandler.Get(hdr.DestConnectionID) + if sessionKnown && session == nil { + // Late packet for closed session + return nil + } + // If we don't have a session for this connection, and this packet cannot open a new connection, send a Public Reset // This should only happen after a server restart, when we still receive packets for connections that we lost the state for. if !sessionKnown && !hdr.VersionFlag { @@ -450,9 +426,7 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd if err != nil { return err } - s.sessionsMutex.Lock() - s.sessions[string(hdr.DestConnectionID)] = session - s.sessionsMutex.Unlock() + s.sessionHandler.Add(hdr.DestConnectionID, session) go session.run() } @@ -465,15 +439,3 @@ func (s *server) handleGQUICPacket(hdr *wire.Header, packetData []byte, remoteAd }) return nil } - -func (s *server) removeConnection(id protocol.ConnectionID) { - s.sessionsMutex.Lock() - s.sessions[string(id)] = nil - s.sessionsMutex.Unlock() - - time.AfterFunc(s.deleteClosedSessionsAfter, func() { - s.sessionsMutex.Lock() - delete(s.sessions, string(id)) - s.sessionsMutex.Unlock() - }) -} diff --git a/server_test.go b/server_test.go index 7b83ffa8..8e2df4e4 100644 --- a/server_test.go +++ b/server_test.go @@ -80,13 +80,15 @@ var _ = Describe("Server", func() { Context("with mock session", func() { var ( - serv *server - firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID) - connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} - sessions = make([]*MockPacketHandler, 0) + serv *server + firstPacket []byte // a valid first packet for a new connection with connectionID 0x4cfa9f9b668619f6 (= connID) + connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} + sessions = make([]*MockPacketHandler, 0) + sessionHandler *MockSessionHandler ) BeforeEach(func() { + sessionHandler = NewMockSessionHandler(mockCtrl) newMockSession := func( _ connection, runner sessionRunner, @@ -105,13 +107,13 @@ var _ = Describe("Server", func() { return s, nil } serv = &server{ - sessions: make(map[string]packetHandler), - newSession: newMockSession, - conn: conn, - config: config, - sessionQueue: make(chan Session, 5), - errorChan: make(chan struct{}), - logger: utils.DefaultLogger, + sessionHandler: sessionHandler, + newSession: newMockSession, + conn: conn, + config: config, + sessionQueue: make(chan Session, 5), + errorChan: make(chan struct{}), + logger: utils.DefaultLogger, } serv.setup() b := &bytes.Buffer{} @@ -139,11 +141,13 @@ var _ = Describe("Server", func() { run := make(chan struct{}) s.EXPECT().run().Do(func() { close(run) }) sessions = append(sessions, s) + + sessionHandler.EXPECT().Get(connID) + sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) { + Expect(sess.(*mockSession).connID).To(Equal(connID)) + }) err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) - sess := serv.sessions[string(connID)].(*mockSession) - Expect(sess.connID).To(Equal(connID)) Eventually(run).Should(BeClosed()) }) @@ -154,51 +158,38 @@ var _ = Describe("Server", func() { sess.EXPECT().run().Do(func() { close(run) }) err := serv.setupTLS() Expect(err).ToNot(HaveOccurred()) + sessionHandler.EXPECT().Get(connID) + sessionHandler.EXPECT().Add(connID, sess) serv.serverTLS.sessionChan <- tlsSession{ connID: connID, sess: sess, } - Eventually(func() packetHandler { - serv.sessionsMutex.Lock() - defer serv.sessionsMutex.Unlock() - return serv.sessions[string(connID)] - }).Should(Equal(sess)) Eventually(run).Should(BeClosed()) }) It("only accepts one new TLS sessions for one connection ID", func() { connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - run := make(chan struct{}) sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().run().Do(func() { close(run) }) - sess2 := NewMockPacketHandler(mockCtrl) err := serv.setupTLS() Expect(err).ToNot(HaveOccurred()) + + done := make(chan struct{}) + sessionHandler.EXPECT().Get(connID).Return(NewMockPacketHandler(mockCtrl), true).Do(func(protocol.ConnectionID) { + close(done) + }) + // don't EXPECT any calls to sessionHandler.Add serv.serverTLS.sessionChan <- tlsSession{ connID: connID, sess: sess, } - Eventually(func() packetHandler { - serv.sessionsMutex.Lock() - defer serv.sessionsMutex.Unlock() - return serv.sessions[string(connID)] - }).Should(Equal(sess)) - serv.serverTLS.sessionChan <- tlsSession{ - connID: connID, - sess: sess2, - } - Consistently(func() packetHandler { - serv.sessionsMutex.Lock() - defer serv.sessionsMutex.Unlock() - return serv.sessions[string(connID)] - }).Should(Equal(sess)) - Eventually(run).Should(BeClosed()) + Eventually(done).Should(BeClosed()) }) It("accepts a session once the connection it is forward secure", func() { s := NewMockPacketHandler(mockCtrl) s.EXPECT().handlePacket(gomock.Any()) - s.EXPECT().run() + run := make(chan struct{}) + s.EXPECT().run().Do(func() { close(run) }) sessions = append(sessions, s) done := make(chan struct{}) go func() { @@ -208,17 +199,19 @@ var _ = Describe("Server", func() { Expect(sess.(*mockSession).connID).To(Equal(connID)) close(done) }() + sessionHandler.EXPECT().Get(connID) + sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) { + Consistently(done).ShouldNot(BeClosed()) + sess.(*mockSession).runner.onHandshakeComplete(sess) + }) err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) - Consistently(done).ShouldNot(BeClosed()) - sess := serv.sessions[string(connID)].(*mockSession) - sess.runner.onHandshakeComplete(sess) Eventually(done).Should(BeClosed()) + Eventually(run).Should(BeClosed()) }) It("doesn't accept sessions that error during the handshake", func() { - run := make(chan error) + run := make(chan error, 1) sess := NewMockPacketHandler(mockCtrl) sess.EXPECT().handlePacket(gomock.Any()) sess.EXPECT().run().DoAndReturn(func() error { return <-run }) @@ -229,79 +222,44 @@ var _ = Describe("Server", func() { serv.Accept() close(done) }() + sessionHandler.EXPECT().Get(connID) + sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) { + run <- errors.New("handshake error") + }) err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) - run <- errors.New("handshake error") - serv.sessions[string(connID)].(*mockSession).runner.removeConnectionID(connID) Consistently(done).ShouldNot(BeClosed()) // make the go routine return + sessionHandler.EXPECT().Close() close(serv.errorChan) serv.Close() Eventually(done).Should(BeClosed()) }) It("assigns packets to existing sessions", func() { - run := make(chan struct{}) sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()).Times(2) - sess.EXPECT().run().Do(func() { close(run) }) - sessions = append(sessions, sess) - - err := serv.handlePacket(nil, firstPacket) - Expect(err).ToNot(HaveOccurred()) - err = serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}) - Expect(err).ToNot(HaveOccurred()) - Eventually(run).Should(BeClosed()) - }) - - It("deletes sessions", func() { - serv.deleteClosedSessionsAfter = time.Second // make sure that the nil value for the closed session doesn't get deleted in this test - serv.sessions[string(connID)] = &mockSession{} - serv.removeConnection(connID) - // The server should now have closed the session, leaving a nil value in the sessions map - Consistently(func() map[string]packetHandler { return serv.sessions }).Should(HaveLen(1)) - Expect(serv.sessions[string(connID)]).To(BeNil()) - }) - - It("deletes nil session entries after a wait time", func() { - serv.deleteClosedSessionsAfter = 25 * time.Millisecond - serv.sessions[string(connID)] = &mockSession{} - // make session.run() return - serv.removeConnection(connID) - Eventually(func() bool { - serv.sessionsMutex.Lock() - _, ok := serv.sessions[string(connID)] - serv.sessionsMutex.Unlock() - return ok - }).Should(BeFalse()) - }) - - It("closes sessions and the connection when Close is called", func() { - run := make(chan struct{}) - sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().Close(nil) sess.EXPECT().handlePacket(gomock.Any()) - sess.EXPECT().run().Do(func() { close(run) }) - sessions = append(sessions, sess) + + sessionHandler.EXPECT().Get(connID).Return(sess, true) + err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}) + Expect(err).ToNot(HaveOccurred()) + }) + + It("closes the sessionHandler and the connection when Close is called", func() { go func() { defer GinkgoRecover() serv.serve() }() - err := serv.handlePacket(nil, firstPacket) - Expect(err).ToNot(HaveOccurred()) - Eventually(run).Should(BeClosed()) // close the server + sessionHandler.EXPECT().Close().AnyTimes() Expect(serv.Close()).To(Succeed()) Expect(conn.closed).To(BeTrue()) }) It("ignores packets for closed sessions", func() { - serv.sessions[string(connID)] = nil - err := serv.handlePacket(nil, []byte{0x08, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6, 0x01}) + sessionHandler.EXPECT().Get(connID).Return(nil, true) + err := serv.handlePacket(nil, firstPacket) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) - Expect(serv.sessions[string(connID)]).To(BeNil()) }) It("works if no quic.Config is given", func(done Done) { @@ -327,49 +285,32 @@ var _ = Describe("Server", func() { Eventually(func() bool { return returned }).Should(BeTrue()) }) - It("errors when encountering a connection error", func(done Done) { + It("errors when encountering a connection error", func() { testErr := errors.New("connection error") conn.readErr = testErr - go serv.serve() - _, err := serv.Accept() - Expect(err).To(MatchError(testErr)) - Expect(serv.Close()).To(Succeed()) - close(done) - }, 0.5) - - It("closes all sessions when encountering a connection error", func() { - sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().Close(nil) - serv.sessions[string(connID)] = sess - - conn.readErr = errors.New("connection error") + sessionHandler.EXPECT().Close() done := make(chan struct{}) go func() { defer GinkgoRecover() serv.serve() close(done) }() - Expect(serv.Close()).To(Succeed()) + _, err := serv.Accept() + Expect(err).To(MatchError(testErr)) Eventually(done).Should(BeClosed()) }) It("ignores delayed packets with mismatching versions", func() { - run := make(chan struct{}) sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) // only called once - sess.EXPECT().run().Do(func() { close(run) }) - sessions = append(sessions, sess) - - err := serv.handlePacket(nil, firstPacket) - Expect(err).ToNot(HaveOccurred()) - Eventually(run).Should(BeClosed()) + // don't EXPECT any handlePacket() calls to this session + sessionHandler.EXPECT().Get(connID).Return(sess, true) b := &bytes.Buffer{} // add an unsupported version data := []byte{0x09, 0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6} utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]+1)) data = append(append(data, b.Bytes()...), 0x01) - err = serv.handlePacket(nil, data) + err := serv.handlePacket(nil, data) Expect(err).ToNot(HaveOccurred()) // if we didn't ignore the packet, the server would try to send a version negotiation packet, which would make the test panic because it doesn't have a udpConn Expect(conn.dataWritten.Bytes()).To(BeEmpty()) @@ -397,21 +338,12 @@ var _ = Describe("Server", func() { }) It("cuts packets at the payload length", func() { - run := make(chan struct{}) sess := NewMockPacketHandler(mockCtrl) - gomock.InOrder( - sess.EXPECT().handlePacket(gomock.Any()), // first packet - sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) { - Expect(packet.data).To(HaveLen(123)) - }), - ) - sess.EXPECT().run().Do(func() { close(run) }) - sessions = append(sessions, sess) + sess.EXPECT().handlePacket(gomock.Any()).Do(func(packet *receivedPacket) { + Expect(packet.data).To(HaveLen(123)) + }) serv.supportsTLS = true - err := serv.handlePacket(nil, firstPacket) - Expect(err).ToNot(HaveOccurred()) - Eventually(run).Should(BeClosed()) b := &bytes.Buffer{} hdr := &wire.Header{ IsLongHeader: true, @@ -422,7 +354,8 @@ var _ = Describe("Server", func() { Version: versionIETFFrames, } Expect(hdr.Write(b, protocol.PerspectiveClient, versionIETFFrames)).To(Succeed()) - err = serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...)) + sessionHandler.EXPECT().Get(connID).Return(sess, true) + err := serv.handlePacket(nil, append(b.Bytes(), make([]byte, 456)...)) Expect(err).ToNot(HaveOccurred()) }) @@ -443,18 +376,8 @@ var _ = Describe("Server", func() { }) It("ignores Public Resets", func() { - run := make(chan struct{}) - sess := NewMockPacketHandler(mockCtrl) - sess.EXPECT().handlePacket(gomock.Any()) // called only once - sess.EXPECT().run().Do(func() { close(run) }) - sessions = append(sessions, sess) - err := serv.handlePacket(nil, firstPacket) + err := serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337)) Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) - Eventually(run).Should(BeClosed()) - err = serv.handlePacket(nil, wire.WritePublicReset(connID, 1, 1337)) - Expect(err).ToNot(HaveOccurred()) - Expect(serv.sessions).To(HaveLen(1)) }) It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() { @@ -470,6 +393,7 @@ var _ = Describe("Server", func() { hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize)) // add a fake CHLO serv.conn = conn + sessionHandler.EXPECT().Get(connID) err := serv.handlePacket(nil, b.Bytes()) Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty()) Expect(err).ToNot(HaveOccurred()) @@ -487,6 +411,7 @@ var _ = Describe("Server", func() { hdr.Write(b, protocol.PerspectiveClient, 13 /* not a valid QUIC version */) b.Write(bytes.Repeat([]byte{0}, protocol.MinClientHelloSize-1)) // this packet is 1 byte too small serv.conn = conn + sessionHandler.EXPECT().Get(connID) err := serv.handlePacket(udpAddr, b.Bytes()) Expect(err).To(MatchError("dropping small packet with unknown version")) Expect(conn.dataWritten.Len()).Should(BeZero()) @@ -506,8 +431,7 @@ var _ = Describe("Server", func() { ln, err := Listen(conn, &tls.Config{}, &config) Expect(err).ToNot(HaveOccurred()) server := ln.(*server) - Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout)) - Expect(server.sessions).ToNot(BeNil()) + Expect(server.sessionHandler).ToNot(BeNil()) Expect(server.scfg).ToNot(BeNil()) Expect(server.config.Versions).To(Equal(supportedVersions)) Expect(server.config.HandshakeTimeout).To(Equal(1337 * time.Hour)) @@ -692,7 +616,6 @@ var _ = Describe("Server", func() { Eventually(func() int { return conn.dataWritten.Len() }).ShouldNot(BeZero()) Expect(conn.dataWrittenTo).To(Equal(udpAddr)) Expect(conn.dataWritten.Bytes()[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set - Expect(ln.(*server).sessions).To(BeEmpty()) }) }) diff --git a/session_map.go b/session_map.go new file mode 100644 index 00000000..561630b1 --- /dev/null +++ b/session_map.go @@ -0,0 +1,74 @@ +package quic + +import ( + "sync" + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" +) + +type sessionMap struct { + mutex sync.RWMutex + + sessions map[string] /* string(ConnectionID)*/ packetHandler + closed bool + + deleteClosedSessionsAfter time.Duration +} + +var _ sessionHandler = &sessionMap{} + +func newSessionMap() sessionHandler { + return &sessionMap{ + sessions: make(map[string]packetHandler), + deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, + } +} + +func (h *sessionMap) Get(id protocol.ConnectionID) (packetHandler, bool) { + h.mutex.RLock() + sess, ok := h.sessions[string(id)] + h.mutex.RUnlock() + return sess, ok +} + +func (h *sessionMap) Add(id protocol.ConnectionID, sess packetHandler) { + h.mutex.Lock() + h.sessions[string(id)] = sess + h.mutex.Unlock() +} + +func (h *sessionMap) Remove(id protocol.ConnectionID) { + h.mutex.Lock() + h.sessions[string(id)] = nil + h.mutex.Unlock() + + time.AfterFunc(h.deleteClosedSessionsAfter, func() { + h.mutex.Lock() + delete(h.sessions, string(id)) + h.mutex.Unlock() + }) +} + +func (h *sessionMap) Close() { + h.mutex.Lock() + if h.closed { + h.mutex.Unlock() + return + } + h.closed = true + + var wg sync.WaitGroup + for _, session := range h.sessions { + if session != nil { + wg.Add(1) + go func(sess packetHandler) { + // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped + _ = sess.Close(nil) + wg.Done() + }(session) + } + } + h.mutex.Unlock() + wg.Wait() +} diff --git a/session_map_test.go b/session_map_test.go new file mode 100644 index 00000000..ca931503 --- /dev/null +++ b/session_map_test.go @@ -0,0 +1,56 @@ +package quic + +import ( + "time" + + "github.com/lucas-clemente/quic-go/internal/protocol" + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Session Handler", func() { + var handler *sessionMap + + BeforeEach(func() { + handler = newSessionMap().(*sessionMap) + }) + + It("adds and gets", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5} + sess := &mockSession{} + handler.Add(connID, sess) + session, ok := handler.Get(connID) + Expect(ok).To(BeTrue()) + Expect(session).To(Equal(sess)) + }) + + It("deletes", func() { + connID := protocol.ConnectionID{1, 2, 3, 4, 5} + handler.Add(connID, &mockSession{}) + handler.Remove(connID) + session, ok := handler.Get(connID) + Expect(ok).To(BeTrue()) + Expect(session).To(BeNil()) + }) + + It("deletes nil session entries after a wait time", func() { + handler.deleteClosedSessionsAfter = 25 * time.Millisecond + connID := protocol.ConnectionID{1, 2, 3, 4, 5} + handler.Add(connID, &mockSession{}) + handler.Remove(connID) + Eventually(func() bool { + _, ok := handler.Get(connID) + return ok + }).Should(BeFalse()) + }) + + It("closes", func() { + sess1 := NewMockPacketHandler(mockCtrl) + sess1.EXPECT().Close(nil) + sess2 := NewMockPacketHandler(mockCtrl) + sess2.EXPECT().Close(nil) + handler.Add(protocol.ConnectionID{1, 1, 1, 1}, sess1) + handler.Add(protocol.ConnectionID{2, 2, 2, 2}, sess2) + handler.Close() + }) +})