diff --git a/benchmark_test.go b/benchmark_test.go index c4ce9485a..8c3962a86 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "math/rand" + "net" "time" "github.com/lucas-clemente/quic-go/protocol" @@ -27,33 +28,32 @@ var _ = Describe("Benchmarks", func() { Measure("transferring a file", func(b Benchmarker) { rand.Read(data) // no need to check for an error. math.Rand.Read never errors - // start the server - sconf := &Config{ - TLSConfig: testdata.GetTLSConfig(), - ConnState: func(sess Session, cs ConnState) { - if cs != ConnStateForwardSecure { - return - } + var ln Listener - defer GinkgoRecover() - str, err := sess.OpenStream() - Expect(err).ToNot(HaveOccurred()) - _, err = str.Write(data) - Expect(err).ToNot(HaveOccurred()) - err = str.Close() - Expect(err).ToNot(HaveOccurred()) - }, - } - ln, err := ListenAddr("localhost:0", sconf) - Expect(err).ToNot(HaveOccurred()) - // Serve will error as soon as ln is closed. Ignore all errors here - go ln.Serve() + serverAddr := make(chan net.Addr) + // start the server + go func() { + defer GinkgoRecover() + var err error + ln, err = ListenAddr("localhost:0", &Config{TLSConfig: testdata.GetTLSConfig()}) + Expect(err).ToNot(HaveOccurred()) + serverAddr <- ln.Addr() + sess, err := ln.Accept() + Expect(err).ToNot(HaveOccurred()) + str, err := sess.OpenStream() + Expect(err).ToNot(HaveOccurred()) + _, err = str.Write(data) + Expect(err).ToNot(HaveOccurred()) + err = str.Close() + Expect(err).ToNot(HaveOccurred()) + }() // start the client - cconf := &Config{ + conf := &Config{ TLSConfig: &tls.Config{InsecureSkipVerify: true}, } - sess, err := DialAddr(ln.Addr().String(), cconf) + addr := <-serverAddr + sess, err := DialAddr(addr.String(), conf) Expect(err).ToNot(HaveOccurred()) str, err := sess.AcceptStream() Expect(err).ToNot(HaveOccurred()) diff --git a/example/echo/echo.go b/example/echo/echo.go index d8215ad5e..520130dd9 100644 --- a/example/echo/echo.go +++ b/example/echo/echo.go @@ -33,26 +33,22 @@ func main() { func echoServer() error { cfgServer := &quic.Config{ TLSConfig: generateTLSConfig(), - ConnState: func(sess quic.Session, cs quic.ConnState) { - // Ignore unless the handshake is finished - if cs != quic.ConnStateForwardSecure { - return - } - go func() { - stream, err := sess.AcceptStream() - if err != nil { - panic(err) - } - // Echo through the loggingWriter - go io.Copy(loggingWriter{stream}, stream) - }() - }, } listener, err := quic.ListenAddr(addr, cfgServer) if err != nil { return err } - return listener.Serve() + sess, err := listener.Accept() + if err != nil { + return err + } + stream, err := sess.AcceptStream() + if err != nil { + panic(err) + } + // Echo through the loggingWriter + _, err = io.Copy(loggingWriter{stream}, stream) + return err } func clientMain() error { diff --git a/h2quic/server.go b/h2quic/server.go index a2e0f1970..f5b1e1b60 100644 --- a/h2quic/server.go +++ b/h2quic/server.go @@ -85,13 +85,7 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { config := quic.Config{ TLSConfig: tlsConfig, - ConnState: func(session quic.Session, connState quic.ConnState) { - sess := session.(streamCreator) - if connState == quic.ConnStateVersionNegotiated { - s.handleHeaderStream(sess) - } - }, - Versions: protocol.SupportedVersions, + Versions: protocol.SupportedVersions, } var ln quic.Listener @@ -107,7 +101,14 @@ func (s *Server) serveImpl(tlsConfig *tls.Config, conn *net.UDPConn) error { } s.listener = ln s.listenerMutex.Unlock() - return ln.Serve() + + for { + sess, err := ln.Accept() + if err != nil { + return err + } + go s.handleHeaderStream(sess.(streamCreator)) + } } func (s *Server) handleHeaderStream(session streamCreator) { diff --git a/interface.go b/interface.go index a1d351d17..92dd8c921 100644 --- a/interface.go +++ b/interface.go @@ -79,6 +79,6 @@ type Listener interface { Close() error // Addr returns the local network addr that the server is listening on. Addr() net.Addr - // Serve starts the main server loop, and blocks until a network error occurs or the server is closed. - Serve() error + // Accept returns new sessions. It should be called in a loop. + Accept() (Session, error) } diff --git a/server.go b/server.go index aa4ee2c77..0c2410dfd 100644 --- a/server.go +++ b/server.go @@ -34,6 +34,10 @@ type server struct { sessionsMutex sync.RWMutex deleteClosedSessionsAfter time.Duration + serverError error + sessionQueue chan Session + errorChan chan struct{} + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, config *Config) (packetHandler, error) } @@ -66,7 +70,7 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) { return nil, err } - return &server{ + s := &server{ conn: conn, config: populateServerConfig(config), certChain: certChain, @@ -74,7 +78,11 @@ func Listen(conn net.PacketConn, config *Config) (Listener, error) { sessions: map[protocol.ConnectionID]packetHandler{}, newSession: newSession, deleteClosedSessionsAfter: protocol.ClosedSessionDeleteTimeout, - }, nil + sessionQueue: make(chan Session, 5), + errorChan: make(chan struct{}), + } + go s.serve() + return s, nil } func populateServerConfig(config *Config) *Config { @@ -85,13 +93,12 @@ func populateServerConfig(config *Config) *Config { return &Config{ TLSConfig: config.TLSConfig, - ConnState: config.ConnState, Versions: versions, } } -// Listen listens on an existing PacketConn -func (s *server) Serve() error { +// serve listens on an existing PacketConn +func (s *server) serve() { for { data := getPacketBuffer() data = data[:protocol.MaxReceivePacketSize] @@ -99,14 +106,27 @@ func (s *server) Serve() error { // If it does, we only read a truncated packet, which will then end up undecryptable n, remoteAddr, err := s.conn.ReadFrom(data) if err != nil { + s.serverError = err + close(s.errorChan) _ = s.Close() - return err + return } data = data[:n] if err := s.handlePacket(s.conn, remoteAddr, data); err != nil { utils.Errorf("error handling packet: %s", err.Error()) } } + } + +// Accept returns newly openend sessions +func (s *server) Accept() (Session, error) { + var sess Session + select { + case sess = <-s.sessionQueue: + return sess, nil + case <-s.errorChan: + return nil, s.serverError + } } // Close the server @@ -212,10 +232,6 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet s.removeConnection(hdr.ConnectionID) }() - - if s.config.ConnState != nil { - go s.config.ConnState(session, ConnStateVersionNegotiated) - } } if session == nil { // Late packet for closed session @@ -231,14 +247,8 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet } func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) { - var state ConnState if isForwardSecure { - state = ConnStateForwardSecure - } else { - state = ConnStateSecure - } - if s.config.ConnState != nil { - go s.config.ConnState(session, state) + s.sessionQueue <- session } } diff --git a/server_test.go b/server_test.go index 0ef386735..4afec7c40 100644 --- a/server_test.go +++ b/server_test.go @@ -87,10 +87,12 @@ var _ = Describe("Server", func() { BeforeEach(func() { serv = &server{ - sessions: make(map[protocol.ConnectionID]packetHandler), - newSession: newMockSession, - conn: conn, - config: config, + sessions: make(map[protocol.ConnectionID]packetHandler), + newSession: newMockSession, + conn: conn, + config: config, + sessionQueue: make(chan Session, 5), + errorChan: make(chan struct{}), } b := &bytes.Buffer{} utils.WriteUint32(b, protocol.VersionNumberToTag(protocol.SupportedVersions[0])) @@ -115,56 +117,29 @@ var _ = Describe("Server", func() { }) It("creates new sessions", func() { - var connStateCalled bool - var connStateStatus ConnState - var connStateSession Session - config.ConnState = func(s Session, state ConnState) { - connStateStatus = state - connStateSession = s - connStateCalled = true - } err := serv.handlePacket(nil, nil, firstPacket) Expect(err).ToNot(HaveOccurred()) Expect(serv.sessions).To(HaveLen(1)) sess := serv.sessions[connID].(*mockSession) Expect(sess.connectionID).To(Equal(connID)) Expect(sess.packetCount).To(Equal(1)) - Eventually(func() bool { return connStateCalled }).Should(BeTrue()) - Expect(connStateSession).To(Equal(sess)) - Expect(connStateStatus).To(Equal(ConnStateVersionNegotiated)) }) - It("calls the ConnState callback when the connection is secure", func() { - var connStateCalled bool - var connStateStatus ConnState - var connStateSession Session - config.ConnState = func(s Session, state ConnState) { - connStateStatus = state - connStateSession = s - connStateCalled = true - } - sess := &mockSession{} - serv.cryptoChangeCallback(sess, false) - Eventually(func() bool { return connStateCalled }).Should(BeTrue()) - Expect(connStateSession).To(Equal(sess)) - Expect(connStateStatus).To(Equal(ConnStateSecure)) - }) - - It("calls the ConnState callback when the connection is forward-secure", func() { - var connStateCalled bool - var connStateStatus ConnState - var connStateSession Session - config.ConnState = func(s Session, state ConnState) { - connStateStatus = state - connStateSession = s - connStateCalled = true - } + It("accepts a session once the connection it is forward secure", func(done Done) { + var acceptedSess Session + go func() { + defer GinkgoRecover() + var err error + acceptedSess, err = serv.Accept() + Expect(err).ToNot(HaveOccurred()) + }() sess := &mockSession{} + // serv.cryptoChangeCallback(sess, false) + // Consistently(func() Session { return acceptedSess }).Should(BeNil()) serv.cryptoChangeCallback(sess, true) - Eventually(func() bool { return connStateCalled }).Should(BeTrue()) - Expect(connStateStatus).To(Equal(ConnStateForwardSecure)) - Expect(connStateSession).To(Equal(sess)) - }) + Eventually(func() Session { return acceptedSess }).Should(Equal(sess)) + close(done) + }, 0.5) It("assigns packets to existing sessions", func() { err := serv.handlePacket(nil, nil, firstPacket) @@ -231,7 +206,7 @@ var _ = Describe("Server", func() { var returned bool go func() { defer GinkgoRecover() - err := ln.Serve() + _, err := ln.Accept() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("use of closed network connection")) returned = true @@ -240,12 +215,15 @@ var _ = Describe("Server", func() { Eventually(func() bool { return returned }).Should(BeTrue()) }) - It("errors when encountering a connection error", func() { + It("errors when encountering a connection error", func(done Done) { testErr := errors.New("connection error") conn.readErr = testErr - err := serv.Serve() + go serv.serve() + _, err := serv.Accept() Expect(err).To(MatchError(testErr)) - }) + Expect(serv.Close()).To(Succeed()) + close(done) + }, 0.5) It("closes all sessions when encountering a connection error", func() { err := serv.handlePacket(nil, nil, firstPacket) @@ -254,8 +232,9 @@ var _ = Describe("Server", func() { Expect(serv.sessions[connID].(*mockSession).closed).To(BeFalse()) testErr := errors.New("connection error") conn.readErr = testErr - _ = serv.Serve() - Expect(serv.sessions[connID].(*mockSession).closed).To(BeTrue()) + go serv.serve() + Eventually(func() bool { return serv.sessions[connID].(*mockSession).closed }).Should(BeTrue()) + Expect(serv.Close()).To(Succeed()) }) It("ignores delayed packets with mismatching versions", func() { @@ -324,20 +303,17 @@ var _ = Describe("Server", func() { }) It("setups with the right values", func() { - var connStateCallback ConnStateCallback = func(_ Session, _ ConnState) {} supportedVersions := []protocol.VersionNumber{1, 3, 5} config := Config{ TLSConfig: &tls.Config{}, - ConnState: connStateCallback, Versions: supportedVersions, } ln, err := Listen(conn, &config) - Expect(err).ToNot(HaveOccurred()) server := ln.(*server) + Expect(err).ToNot(HaveOccurred()) Expect(server.deleteClosedSessionsAfter).To(Equal(protocol.ClosedSessionDeleteTimeout)) Expect(server.sessions).ToNot(BeNil()) Expect(server.scfg).ToNot(BeNil()) - Expect(server.config.ConnState).ToNot(BeNil()) Expect(server.config.Versions).To(Equal(supportedVersions)) }) @@ -387,7 +363,7 @@ var _ = Describe("Server", func() { var returned bool go func() { - ln.Serve() + ln.Accept() returned = true }() @@ -400,7 +376,7 @@ var _ = Describe("Server", func() { b.Bytes()..., ) Expect(conn.dataWritten.Bytes()).To(Equal(expected)) - Expect(returned).To(BeFalse()) + Consistently(func() bool { return returned }).Should(BeFalse()) }) It("sends a PublicReset for new connections that don't have the VersionFlag set", func() { @@ -410,7 +386,7 @@ var _ = Describe("Server", func() { Expect(err).ToNot(HaveOccurred()) go func() { defer GinkgoRecover() - err := ln.Serve() + _, err := ln.Accept() Expect(err).ToNot(HaveOccurred()) }()