diff --git a/server.go b/server.go index f5b04549..4043c58c 100644 --- a/server.go +++ b/server.go @@ -98,10 +98,12 @@ type baseServer struct { protocol.VersionNumber, ) quicConn - serverError error - errorChan chan struct{} - closed bool - running chan struct{} // closed as soon as run() returns + serverError error + errorChan chan struct{} + closed bool + running chan struct{} // closed as soon as run() returns + versionNegotiationQueue chan *receivedPacket + invalidTokenQueue chan *receivedPacket connQueue chan quicConn connQueueLen int32 // to be used as an atomic @@ -229,26 +231,29 @@ func newServer( return nil, err } s := &baseServer{ - conn: conn, - tlsConf: tlsConf, - config: config, - tokenGenerator: tokenGenerator, - connIDGenerator: connIDGenerator, - connHandler: connHandler, - connQueue: make(chan quicConn), - errorChan: make(chan struct{}), - running: make(chan struct{}), - receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), - newConn: newConnection, - tracer: tracer, - logger: utils.DefaultLogger.WithPrefix("server"), - acceptEarlyConns: acceptEarly, - onClose: onClose, + conn: conn, + tlsConf: tlsConf, + config: config, + tokenGenerator: tokenGenerator, + connIDGenerator: connIDGenerator, + connHandler: connHandler, + connQueue: make(chan quicConn), + errorChan: make(chan struct{}), + running: make(chan struct{}), + receivedPackets: make(chan *receivedPacket, protocol.MaxServerUnprocessedPackets), + versionNegotiationQueue: make(chan *receivedPacket, 4), + invalidTokenQueue: make(chan *receivedPacket, 4), + newConn: newConnection, + tracer: tracer, + logger: utils.DefaultLogger.WithPrefix("server"), + acceptEarlyConns: acceptEarly, + onClose: onClose, } if acceptEarly { s.zeroRTTQueues = map[protocol.ConnectionID]*zeroRTTQueue{} } go s.run() + go s.runSendQueue() s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String()) return s, nil } @@ -272,6 +277,19 @@ func (s *baseServer) run() { } } +func (s *baseServer) runSendQueue() { + for { + select { + case <-s.running: + return + case p := <-s.versionNegotiationQueue: + s.maybeSendVersionNegotiationPacket(p) + case p := <-s.invalidTokenQueue: + s.maybeSendInvalidToken(p) + } + } +} + // Accept returns connections that already completed the handshake. // It is only valid if acceptEarlyConns is false. func (s *baseServer) Accept(ctx context.Context) (Connection, error) { @@ -353,28 +371,29 @@ func (s *baseServer) handlePacketImpl(p *receivedPacket) bool /* is the buffer s panic(fmt.Sprintf("misrouted packet: %#v", p.data)) } v, err := wire.ParseVersion(p.data) - // send a Version Negotiation Packet if the client is speaking a different protocol version - if err != nil || !protocol.IsSupportedVersion(s.config.Versions, v) { - if err != nil || p.Size() < protocol.MinUnknownVersionPacketSize { - s.logger.Debugf("Dropping a packet with an unknown version that is too small (%d bytes)", p.Size()) - if s.tracer != nil { - s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) - if err != nil { // should never happen - s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") - if s.tracer != nil { - s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) - } - return false - } - if !s.config.DisableVersionNegotiationPackets { - go s.sendVersionNegotiationPacket(p.remoteAddr, src, dest, p.info.OOB(), v) + // drop the packet if we failed to parse the protocol version + if err != nil { + s.logger.Debugf("Dropping a packet with an unknown version") + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) } return false } + // send a Version Negotiation Packet if the client is speaking a different protocol version + if !protocol.IsSupportedVersion(s.config.Versions, v) { + if s.config.DisableVersionNegotiationPackets { + return false + } + + if p.Size() < protocol.MinUnknownVersionPacketSize { + s.logger.Debugf("Dropping a packet with an unsupported version number %d that is too small (%d bytes)", v, p.Size()) + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return false + } + return s.enqueueVersionNegotiationPacket(p) + } if wire.Is0RTTPacket(p.data) { if !s.acceptEarlyConns { @@ -557,12 +576,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro // For Retry tokens, we send an INVALID_ERROR if // * the token is too old, or // * the token is invalid, in case of a retry token. - go func() { - defer p.buffer.Release() - if err := s.maybeSendInvalidToken(p, hdr); err != nil { - s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) - } - }() + s.enqueueInvalidToken(p) return nil } } @@ -730,7 +744,27 @@ func (s *baseServer) sendRetry(remoteAddr net.Addr, hdr *wire.Header, info *pack return err } -func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) error { +func (s *baseServer) enqueueInvalidToken(p *receivedPacket) { + select { + case s.invalidTokenQueue <- p: + default: + // it's fine to drop INVALID_TOKEN packets when we are busy + p.buffer.Release() + } +} + +func (s *baseServer) maybeSendInvalidToken(p *receivedPacket) { + defer p.buffer.Release() + + hdr, _, _, err := wire.ParsePacket(p.data) + if err != nil { + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropHeaderParseError) + } + s.logger.Debugf("Error parsing packet: %s", err) + return + } + // Only send INVALID_TOKEN if we can unprotect the packet. // This makes sure that we won't send it for packets that were corrupted. sealer, opener := handshake.NewInitialAEAD(hdr.DestConnectionID, protocol.PerspectiveServer, hdr.Version) @@ -740,21 +774,21 @@ func (s *baseServer) maybeSendInvalidToken(p *receivedPacket, hdr *wire.Header) if s.tracer != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropHeaderParseError) } - // don't return the error here. Just drop the packet. - return nil + return } hdrLen := extHdr.ParsedLen() if _, err := opener.Open(data[hdrLen:hdrLen], data[hdrLen:], extHdr.PacketNumber, data[:hdrLen]); err != nil { - // don't return the error here. Just drop the packet. if s.tracer != nil { s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeInitial, p.Size(), logging.PacketDropPayloadDecryptError) } - return nil + return } if s.logger.Debug() { s.logger.Debugf("Client sent an invalid retry token. Sending INVALID_TOKEN to %s.", p.remoteAddr) } - return s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info) + if err := s.sendError(p.remoteAddr, hdr, sealer, qerr.InvalidToken, p.info); err != nil { + s.logger.Debugf("Error sending INVALID_TOKEN error: %s", err) + } } func (s *baseServer) sendConnectionRefused(remoteAddr net.Addr, hdr *wire.Header, info *packetInfo) error { @@ -807,14 +841,41 @@ func (s *baseServer) sendError(remoteAddr net.Addr, hdr *wire.Header, sealer han return err } -func (s *baseServer) sendVersionNegotiationPacket(remote net.Addr, src, dest protocol.ArbitraryLenConnectionID, oob []byte, v protocol.VersionNumber) { +func (s *baseServer) enqueueVersionNegotiationPacket(p *receivedPacket) (bufferInUse bool) { + select { + case s.versionNegotiationQueue <- p: + return true + default: + // it's fine to not send version negotiation packets when we are busy + } + return false +} + +func (s *baseServer) maybeSendVersionNegotiationPacket(p *receivedPacket) { + defer p.buffer.Release() + + v, err := wire.ParseVersion(p.data) + if err != nil { + s.logger.Debugf("failed to parse version for sending version negotiation packet: %s", err) + return + } + + _, src, dest, err := wire.ParseArbitraryLenConnectionIDs(p.data) + if err != nil { // should never happen + s.logger.Debugf("Dropping a packet with an unknown version for which we failed to parse connection IDs") + if s.tracer != nil { + s.tracer.DroppedPacket(p.remoteAddr, logging.PacketTypeNotDetermined, p.Size(), logging.PacketDropUnexpectedPacket) + } + return + } + s.logger.Debugf("Client offered version %s, sending Version Negotiation", v) data := wire.ComposeVersionNegotiation(dest, src, s.config.Versions) if s.tracer != nil { - s.tracer.SentVersionNegotiationPacket(remote, src, dest, s.config.Versions) + s.tracer.SentVersionNegotiationPacket(p.remoteAddr, src, dest, s.config.Versions) } - if _, err := s.conn.WritePacket(data, remote, oob); err != nil { + if _, err := s.conn.WritePacket(data, p.remoteAddr, p.info.OOB()); err != nil { s.logger.Debugf("Error sending Version Negotiation: %s", err) } }