diff --git a/http3/server.go b/http3/server.go index 4b7aad64a..0a6d12ffc 100644 --- a/http3/server.go +++ b/http3/server.go @@ -39,9 +39,9 @@ type Server struct { port uint32 // used atomically - listenerMutex sync.Mutex - listener quic.Listener - closed bool + mutex sync.Mutex + listeners map[*quic.Listener]struct{} + closed utils.AtomicBool supportedVersionsAsString string @@ -80,19 +80,13 @@ func (s *Server) Serve(conn net.PacketConn) error { } func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { + if s.closed.Get() { + return http.ErrServerClosed + } if s.Server == nil { return errors.New("use of http3.Server without http.Server") } s.logger = utils.DefaultLogger.WithPrefix("server") - s.listenerMutex.Lock() - if s.closed { - s.listenerMutex.Unlock() - return errors.New("Server is already closed") - } - if s.listener != nil { - s.listenerMutex.Unlock() - return errors.New("ListenAndServe may only be called once") - } if tlsConf == nil { tlsConf = &tls.Config{} @@ -121,11 +115,10 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { ln, err = quicListen(conn, tlsConf, s.QuicConfig) } if err != nil { - s.listenerMutex.Unlock() return err } - s.listener = ln - s.listenerMutex.Unlock() + s.addListener(&ln) + defer s.removeListener(&ln) for { sess, err := ln.Accept(context.Background()) @@ -136,6 +129,24 @@ func (s *Server) serveImpl(tlsConf *tls.Config, conn net.PacketConn) error { } } +// We store a pointer to interface in the map set. This is safe because we only +// call trackListener via Serve and can track+defer untrack the same pointer to +// local variable there. We never need to compare a Listener from another caller. +func (s *Server) addListener(l *quic.Listener) { + s.mutex.Lock() + if s.listeners == nil { + s.listeners = make(map[*quic.Listener]struct{}) + } + s.listeners[l] = struct{}{} + s.mutex.Unlock() +} + +func (s *Server) removeListener(l *quic.Listener) { + s.mutex.Lock() + delete(s.listeners, l) + s.mutex.Unlock() +} + func (s *Server) handleConn(sess quic.Session) { // TODO: accept control streams decoder := qpack.NewDecoder(nil) @@ -256,15 +267,18 @@ func (s *Server) handleRequest(str quic.Stream, decoder *qpack.Decoder) error { // Close the server immediately, aborting requests and sending CONNECTION_CLOSE frames to connected clients. // Close in combination with ListenAndServe() (instead of Serve()) may race if it is called before a UDP socket is established. func (s *Server) Close() error { - s.listenerMutex.Lock() - defer s.listenerMutex.Unlock() - s.closed = true - if s.listener != nil { - err := s.listener.Close() - s.listener = nil - return err + s.closed.Set(true) + + s.mutex.Lock() + defer s.mutex.Unlock() + + var err error + for ln := range s.listeners { + if cerr := (*ln).Close(); cerr != nil && err == nil { + err = cerr + } } - return nil + return err } // CloseGracefully shuts down the server gracefully. The server sends a GOAWAY frame first, then waits for either timeout to trigger, or for all running requests to complete. diff --git a/http3/server_test.go b/http3/server_test.go index 8ef958f59..3270ea337 100644 --- a/http3/server_test.go +++ b/http3/server_test.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net" "net/http" "strings" "time" @@ -25,8 +26,7 @@ import ( var _ = Describe("Server", func() { var ( - s *Server - // session *mockquic.MockSession + s *Server origQuicListenAddr = quicListenAddr ) @@ -342,10 +342,99 @@ var _ = Describe("Server", func() { Expect((&Server{}).Close()).To(Succeed()) }) - It("errors when ListenAndServer is called after Close", func() { + It("errors when ListenAndServe is called after Close", func() { serv := &Server{Server: &http.Server{}} Expect(serv.Close()).To(Succeed()) - Expect(serv.ListenAndServe()).To(MatchError("Server is already closed")) + Expect(serv.ListenAndServe()).To(MatchError(http.ErrServerClosed)) + }) + + Context("Serve", func() { + origQuicListen := quicListen + + AfterEach(func() { + quicListen = origQuicListen + }) + + It("serves a packet conn", func() { + ln := mockquic.NewMockListener(mockCtrl) + conn := &net.UDPConn{} + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) { + Expect(c).To(Equal(conn)) + return ln, nil + } + + s := &Server{Server: &http.Server{}} + s.TLSConfig = &tls.Config{} + + stopAccept := make(chan struct{}) + ln.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept + return nil, errors.New("closed") + }) + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done) + s.Serve(conn) + }() + + Consistently(done).ShouldNot(BeClosed()) + ln.EXPECT().Close().Do(func() { close(stopAccept) }) + Expect(s.Close()).To(Succeed()) + Eventually(done).Should(BeClosed()) + }) + + It("serves two packet conns", func() { + ln1 := mockquic.NewMockListener(mockCtrl) + ln2 := mockquic.NewMockListener(mockCtrl) + lns := []quic.Listener{ln1, ln2} + conn1 := &net.UDPConn{} + conn2 := &net.UDPConn{} + conns := []net.PacketConn{conn1, conn2} + quicListen = func(c net.PacketConn, tlsConf *tls.Config, config *quic.Config) (quic.Listener, error) { + conn := conns[0] + conns = conns[1:] + ln := lns[0] + lns = lns[1:] + Expect(c).To(Equal(conn)) + return ln, nil + } + + s := &Server{Server: &http.Server{}} + s.TLSConfig = &tls.Config{} + + stopAccept1 := make(chan struct{}) + ln1.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept1 + return nil, errors.New("closed") + }) + stopAccept2 := make(chan struct{}) + ln2.EXPECT().Accept(gomock.Any()).DoAndReturn(func(context.Context) (quic.Session, error) { + <-stopAccept2 + return nil, errors.New("closed") + }) + + done1 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done1) + s.Serve(conn1) + }() + done2 := make(chan struct{}) + go func() { + defer GinkgoRecover() + defer close(done2) + s.Serve(conn2) + }() + + Consistently(done1).ShouldNot(BeClosed()) + Expect(done2).ToNot(BeClosed()) + ln1.EXPECT().Close().Do(func() { close(stopAccept1) }) + ln2.EXPECT().Close().Do(func() { close(stopAccept2) }) + Expect(s.Close()).To(Succeed()) + Eventually(done1).Should(BeClosed()) + Eventually(done2).Should(BeClosed()) + }) }) Context("ListenAndServe", func() { @@ -357,20 +446,6 @@ var _ = Describe("Server", func() { Expect(s.Close()).To(Succeed()) }) - It("may only be called once", func() { - cErr := make(chan error) - for i := 0; i < 2; i++ { - go func() { - defer GinkgoRecover() - if err := s.ListenAndServe(); err != nil { - cErr <- err - } - }() - } - Eventually(cErr).Should(Receive(MatchError("ListenAndServe may only be called once"))) - Expect(s.Close()).To(Succeed()) - }) - It("uses the quic.Config to start the QUIC server", func() { conf := &quic.Config{HandshakeTimeout: time.Nanosecond} var receivedConf *quic.Config @@ -432,30 +507,6 @@ var _ = Describe("Server", func() { }) }) - Context("ListenAndServeTLS", func() { - BeforeEach(func() { - s.Server.Addr = "localhost:0" - }) - - AfterEach(func() { - Expect(s.Close()).To(Succeed()) - }) - - It("may only be called once", func() { - cErr := make(chan error) - for i := 0; i < 2; i++ { - go func() { - defer GinkgoRecover() - if err := s.ListenAndServeTLS(testdata.GetCertificatePaths()); err != nil { - cErr <- err - } - }() - } - Eventually(cErr).Should(Receive(MatchError("ListenAndServe may only be called once"))) - Expect(s.Close()).To(Succeed()) - }) - }) - It("closes gracefully", func() { Expect(s.CloseGracefully(0)).To(Succeed()) }) diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 8b443f293..cfc1863e3 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -2,6 +2,7 @@ package mocks //go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go" //go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go" +//go:generate sh -c "mockgen -package mockquic -destination quic/listener.go github.com/lucas-clemente/quic-go Listener && goimports -w quic/listener.go" //go:generate sh -c "../mockgen_internal.sh mocks short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer" //go:generate sh -c "../mockgen_internal.sh mocks short_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderOpener" //go:generate sh -c "../mockgen_internal.sh mocks long_header_opener.go github.com/lucas-clemente/quic-go/internal/handshake LongHeaderOpener" diff --git a/internal/mocks/quic/listener.go b/internal/mocks/quic/listener.go new file mode 100644 index 000000000..555f4764e --- /dev/null +++ b/internal/mocks/quic/listener.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go (interfaces: Listener) + +// Package mockquic is a generated GoMock package. +package mockquic + +import ( + context "context" + net "net" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + quic_go "github.com/lucas-clemente/quic-go" +) + +// MockListener is a mock of Listener interface +type MockListener struct { + ctrl *gomock.Controller + recorder *MockListenerMockRecorder +} + +// MockListenerMockRecorder is the mock recorder for MockListener +type MockListenerMockRecorder struct { + mock *MockListener +} + +// NewMockListener creates a new mock instance +func NewMockListener(ctrl *gomock.Controller) *MockListener { + mock := &MockListener{ctrl: ctrl} + mock.recorder = &MockListenerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockListener) EXPECT() *MockListenerMockRecorder { + return m.recorder +} + +// Accept mocks base method +func (m *MockListener) Accept(arg0 context.Context) (quic_go.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept", arg0) + ret0, _ := ret[0].(quic_go.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept +func (mr *MockListenerMockRecorder) Accept(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockListener)(nil).Accept), arg0) +} + +// Addr mocks base method +func (m *MockListener) Addr() net.Addr { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(net.Addr) + return ret0 +} + +// Addr indicates an expected call of Addr +func (mr *MockListenerMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockListener)(nil).Addr)) +} + +// Close mocks base method +func (m *MockListener) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close +func (mr *MockListenerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockListener)(nil).Close)) +}