diff --git a/benchmark_test.go b/benchmark_test.go index 979c622f..835692d6 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -43,7 +43,7 @@ func newLinkedConnection(other *Session) *linkedConnection { Expect(err).NotTo(HaveOccurred()) } hdr.Raw = packet[:len(packet)-r.Len()] - conn.other.handlePacket(nil, hdr, packet[len(packet)-r.Len():]) + conn.other.handlePacket(&receivedPacket{publicHeader: hdr, data: packet[len(packet)-r.Len():]}) } }() return conn diff --git a/server.go b/server.go index d2aeb784..9b52288f 100644 --- a/server.go +++ b/server.go @@ -16,7 +16,7 @@ import ( // packetHandler handles packets type packetHandler interface { - handlePacket(addr interface{}, hdr *PublicHeader, data []byte) + handlePacket(*receivedPacket) run() Close(error) error } @@ -171,7 +171,11 @@ func (s *Server) handlePacket(conn *net.UDPConn, remoteAddr *net.UDPAddr, packet // Late packet for closed session return nil } - session.handlePacket(remoteAddr, hdr, packet[len(packet)-r.Len():]) + session.handlePacket(&receivedPacket{ + remoteAddr: remoteAddr, + publicHeader: hdr, + data: packet[len(packet)-r.Len():], + }) return nil } diff --git a/server_test.go b/server_test.go index 138af02d..88d0c85a 100644 --- a/server_test.go +++ b/server_test.go @@ -20,7 +20,7 @@ type mockSession struct { closed bool } -func (s *mockSession) handlePacket(addr interface{}, hdr *PublicHeader, data []byte) { +func (s *mockSession) handlePacket(*receivedPacket) { s.packetCount++ } diff --git a/session.go b/session.go index 2f380e96..b1c7089c 100644 --- a/session.go +++ b/session.go @@ -61,14 +61,14 @@ type Session struct { cryptoSetup *handshake.CryptoSetup - receivedPackets chan receivedPacket + receivedPackets chan *receivedPacket sendingScheduled chan struct{} // closeChan is used to notify the run loop that it should terminate. // If the value is not nil, the error is sent as a CONNECTION_CLOSE. closeChan chan *qerr.QuicError closed uint32 // atomic bool - undecryptablePackets []receivedPacket + undecryptablePackets []*receivedPacket aeadChanged chan struct{} delayedAckOriginTime time.Time @@ -107,11 +107,11 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol sentPacketHandler: sentPacketHandler, receivedPacketHandler: receivedPacketHandler, flowControlManager: flowControlManager, - receivedPackets: make(chan receivedPacket, protocol.MaxSessionUnprocessedPackets), + receivedPackets: make(chan *receivedPacket, protocol.MaxSessionUnprocessedPackets), closeChan: make(chan *qerr.QuicError, 1), sendingScheduled: make(chan struct{}, 1), connectionParametersManager: connectionParametersManager, - undecryptablePackets: make([]receivedPacket, 0, protocol.MaxUndecryptablePackets), + undecryptablePackets: make([]*receivedPacket, 0, protocol.MaxUndecryptablePackets), aeadChanged: make(chan struct{}, 1), timer: time.NewTimer(0), lastNetworkActivityTime: time.Now(), @@ -170,7 +170,7 @@ func (s *Session) run() { // We do all the interesting stuff after the switch statement, so // nothing to see here. case p := <-s.receivedPackets: - err = s.handlePacketImpl(p.remoteAddr, p.publicHeader, p.data) + err = s.handlePacketImpl(p) if qErr, ok := err.(*qerr.QuicError); ok && qErr.ErrorCode == qerr.DecryptionFailure { s.tryQueueingUndecryptablePacket(p) continue @@ -225,8 +225,10 @@ func (s *Session) maybeResetTimer() { s.currentDeadline = nextDeadline } -func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *PublicHeader, data []byte) error { +func (s *Session) handlePacketImpl(p *receivedPacket) error { s.lastNetworkActivityTime = time.Now() + hdr := p.publicHeader + data := p.data // Calculate packet number hdr.PacketNumber = protocol.InferPacketNumber( @@ -239,7 +241,7 @@ func (s *Session) handlePacketImpl(remoteAddr interface{}, hdr *PublicHeader, da } // TODO: Only do this after authenticating - s.conn.setCurrentRemoteAddr(remoteAddr) + s.conn.setCurrentRemoteAddr(p.remoteAddr) packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data) if err != nil { @@ -312,12 +314,12 @@ func (s *Session) handleFrames(fs []frames.Frame) error { return nil } -// handlePacket handles a packet -func (s *Session) handlePacket(remoteAddr interface{}, hdr *PublicHeader, data []byte) { +// handlePacket is called by the server with a new packet +func (s *Session) handlePacket(p *receivedPacket) { // Discard packets once the amount of queued packets is larger than // the channel size, protocol.MaxSessionUnprocessedPackets select { - case s.receivedPackets <- receivedPacket{remoteAddr: remoteAddr, publicHeader: hdr, data: data}: + case s.receivedPackets <- p: default: } } @@ -611,7 +613,7 @@ func (s *Session) scheduleSending() { } } -func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) { +func (s *Session) tryQueueingUndecryptablePacket(p *receivedPacket) { if s.cryptoSetup.HandshakeComplete() { return } @@ -624,7 +626,7 @@ func (s *Session) tryQueueingUndecryptablePacket(p receivedPacket) { func (s *Session) tryDecryptingQueuedPackets() { for _, p := range s.undecryptablePackets { - s.handlePacket(p.remoteAddr, p.publicHeader, p.data) + s.handlePacket(p) } s.undecryptablePackets = s.undecryptablePackets[:0] } diff --git a/session_test.go b/session_test.go index c13ff0c9..5f8963f6 100644 --- a/session_test.go +++ b/session_test.go @@ -459,7 +459,7 @@ var _ = Describe("Session", func() { It("sets the {last,largest}RcvdPacketNumber", func() { hdr.PacketNumber = 5 - err := session.handlePacketImpl(nil, hdr, nil) + err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) @@ -467,12 +467,12 @@ var _ = Describe("Session", func() { It("sets the {last,largest}RcvdPacketNumber, for an out-of-order packet", func() { hdr.PacketNumber = 5 - err := session.handlePacketImpl(nil, hdr, nil) + err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) hdr.PacketNumber = 3 - err = session.handlePacketImpl(nil, hdr, nil) + err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) Expect(session.lastRcvdPacketNumber).To(Equal(protocol.PacketNumber(3))) Expect(session.largestRcvdPacketNumber).To(Equal(protocol.PacketNumber(5))) @@ -480,9 +480,9 @@ var _ = Describe("Session", func() { It("ignores duplicate packets", func() { hdr.PacketNumber = 5 - err := session.handlePacketImpl(nil, hdr, nil) + err := session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) - err = session.handlePacketImpl(nil, hdr, nil) + err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) }) @@ -490,7 +490,7 @@ var _ = Describe("Session", func() { err := session.receivedPacketHandler.ReceivedStopWaiting(&frames.StopWaitingFrame{LeastUnacked: 10}) Expect(err).ToNot(HaveOccurred()) hdr.PacketNumber = 5 - err = session.handlePacketImpl(nil, hdr, nil) + err = session.handlePacketImpl(&receivedPacket{publicHeader: hdr}) Expect(err).ToNot(HaveOccurred()) }) }) @@ -717,7 +717,7 @@ var _ = Describe("Session", func() { hdr := &PublicHeader{ PacketNumber: protocol.PacketNumber(i + 1), } - session.handlePacket(nil, hdr, []byte("foobar")) + session.handlePacket(&receivedPacket{publicHeader: hdr, data: []byte("foobar")}) } session.run() @@ -731,7 +731,7 @@ var _ = Describe("Session", func() { hdr := &PublicHeader{ PacketNumber: protocol.PacketNumber(i + 1), } - session.handlePacket(nil, hdr, []byte("foobar")) + session.handlePacket(&receivedPacket{publicHeader: hdr, data: []byte("foobar")}) } go session.run() Consistently(session.undecryptablePackets).Should(HaveLen(0)) @@ -739,10 +739,8 @@ var _ = Describe("Session", func() { }) It("unqueues undecryptable packets for later decryption", func() { - session.undecryptablePackets = []receivedPacket{{ - nil, - &PublicHeader{PacketNumber: protocol.PacketNumber(42)}, - nil, + session.undecryptablePackets = []*receivedPacket{{ + publicHeader: &PublicHeader{PacketNumber: protocol.PacketNumber(42)}, }} Expect(session.receivedPackets).NotTo(Receive()) session.tryDecryptingQueuedPackets() @@ -775,7 +773,7 @@ var _ = Describe("Session", func() { It("stores up to MaxSessionUnprocessedPackets packets", func(done Done) { // Nothing here should block for i := protocol.PacketNumber(0); i < protocol.MaxSessionUnprocessedPackets+10; i++ { - session.handlePacket(nil, nil, nil) + session.handlePacket(&receivedPacket{}) } close(done) }, 0.5)