diff --git a/client.go b/client.go index 84ad19ac..54a2231e 100644 --- a/client.go +++ b/client.go @@ -288,7 +288,7 @@ func (c *client) establishSecureConnection(ctx context.Context) error { func (c *client) handlePacket(p *receivedPacket) { if wire.IsVersionNegotiationPacket(p.data) { - go c.handleVersionNegotiationPacket(p.hdr) + go c.handleVersionNegotiationPacket(p) return } @@ -301,10 +301,16 @@ func (c *client) handlePacket(p *receivedPacket) { c.session.handlePacket(p) } -func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) { +func (c *client) handleVersionNegotiationPacket(p *receivedPacket) { c.mutex.Lock() defer c.mutex.Unlock() + hdr, _, _, err := wire.ParsePacket(p.data, 0) + if err != nil { + c.logger.Debugf("Error parsing Version Negotiation packet: %s", err) + return + } + // ignore delayed / duplicated version negotiation packets if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() { c.logger.Debugf("Received a delayed Version Negotiation packet.") @@ -403,6 +409,6 @@ func (c *client) GetVersion() protocol.VersionNumber { return v } -func (c *client) GetPerspective() protocol.Perspective { +func (c *client) getPerspective() protocol.Perspective { return protocol.PerspectiveClient } diff --git a/client_test.go b/client_test.go index b3581753..89358c85 100644 --- a/client_test.go +++ b/client_test.go @@ -58,12 +58,9 @@ var _ = Describe("Client", func() { composeVersionNegotiationPacket := func(connID protocol.ConnectionID, versions []protocol.VersionNumber) *receivedPacket { data, err := wire.ComposeVersionNegotiation(connID, nil, versions) Expect(err).ToNot(HaveOccurred()) - hdr, _, _, err := wire.ParsePacket(data, 0) - Expect(err).ToNot(HaveOccurred()) Expect(wire.IsVersionNegotiationPacket(data)).To(BeTrue()) return &receivedPacket{ rcvTime: time.Now(), - hdr: hdr, data: data, } } @@ -543,19 +540,22 @@ var _ = Describe("Client", func() { Expect(err).To(MatchError(testErr)) }) - It("recognizes that a non version negotiation packet means that the server accepted the suggested version", func() { + It("recognizes that a non Version Negotiation packet means that the server accepted the suggested version", func() { sess := NewMockQuicSession(mockCtrl) sess.EXPECT().handlePacket(gomock.Any()) cl.session = sess cl.config = &Config{} - cl.handlePacket(&receivedPacket{ - hdr: &wire.Header{ + buf := &bytes.Buffer{} + Expect((&wire.ExtendedHeader{ + Header: wire.Header{ DestConnectionID: connID, SrcConnectionID: connID, Version: cl.version, }, - }) - Eventually(cl.versionNegotiated.Get()).Should(BeTrue()) + PacketNumberLen: protocol.PacketNumberLen3, + }).Write(buf, protocol.VersionTLS)).To(Succeed()) + cl.handlePacket(&receivedPacket{data: buf.Bytes()}) + Eventually(cl.versionNegotiated.Get).Should(BeTrue()) }) It("errors if no matching version is found", func() { diff --git a/mock_packet_handler_test.go b/mock_packet_handler_test.go index 4425bcb9..812e7dcc 100644 --- a/mock_packet_handler_test.go +++ b/mock_packet_handler_test.go @@ -46,18 +46,6 @@ func (mr *MockPacketHandlerMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPacketHandler)(nil).Close)) } -// GetPerspective mocks base method -func (m *MockPacketHandler) GetPerspective() protocol.Perspective { - ret := m.ctrl.Call(m, "GetPerspective") - ret0, _ := ret[0].(protocol.Perspective) - return ret0 -} - -// GetPerspective indicates an expected call of GetPerspective -func (mr *MockPacketHandlerMockRecorder) GetPerspective() *gomock.Call { - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPerspective", reflect.TypeOf((*MockPacketHandler)(nil).GetPerspective)) -} - // destroy mocks base method func (m *MockPacketHandler) destroy(arg0 error) { m.ctrl.Call(m, "destroy", arg0) @@ -68,6 +56,18 @@ func (mr *MockPacketHandlerMockRecorder) destroy(arg0 interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockPacketHandler)(nil).destroy), arg0) } +// getPerspective mocks base method +func (m *MockPacketHandler) getPerspective() protocol.Perspective { + ret := m.ctrl.Call(m, "getPerspective") + ret0, _ := ret[0].(protocol.Perspective) + return ret0 +} + +// getPerspective indicates an expected call of getPerspective +func (mr *MockPacketHandlerMockRecorder) getPerspective() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockPacketHandler)(nil).getPerspective)) +} + // handlePacket mocks base method func (m *MockPacketHandler) handlePacket(arg0 *receivedPacket) { m.ctrl.Call(m, "handlePacket", arg0) diff --git a/mock_quic_session_test.go b/mock_quic_session_test.go index 226d3beb..201b1495 100644 --- a/mock_quic_session_test.go +++ b/mock_quic_session_test.go @@ -231,6 +231,18 @@ func (mr *MockQuicSessionMockRecorder) destroy(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "destroy", reflect.TypeOf((*MockQuicSession)(nil).destroy), arg0) } +// getPerspective mocks base method +func (m *MockQuicSession) getPerspective() protocol.Perspective { + ret := m.ctrl.Call(m, "getPerspective") + ret0, _ := ret[0].(protocol.Perspective) + return ret0 +} + +// getPerspective indicates an expected call of getPerspective +func (mr *MockQuicSessionMockRecorder) getPerspective() *gomock.Call { + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "getPerspective", reflect.TypeOf((*MockQuicSession)(nil).getPerspective)) +} + // handlePacket mocks base method func (m *MockQuicSession) handlePacket(arg0 *receivedPacket) { m.ctrl.Call(m, "handlePacket", arg0) diff --git a/packet_handler_map.go b/packet_handler_map.go index eb79316c..f3e1d019 100644 --- a/packet_handler_map.go +++ b/packet_handler_map.go @@ -2,7 +2,6 @@ package quic import ( "errors" - "fmt" "net" "sync" "time" @@ -105,7 +104,7 @@ func (h *packetHandlerMap) CloseServer() { var wg sync.WaitGroup for id, handlerEntry := range h.handlers { handler := handlerEntry.handler - if handler.GetPerspective() == protocol.PerspectiveServer { + if handler.getPerspective() == protocol.PerspectiveServer { wg.Add(1) go func(id string, handler packetHandler) { // session.Close() blocks until the CONNECTION_CLOSE has been sent and the run-loop has stopped @@ -174,93 +173,46 @@ func (h *packetHandlerMap) handlePacket( buffer *packetBuffer, data []byte, ) { - packets, err := h.parsePacket(addr, buffer, data) + connID, err := wire.ParseConnectionID(data, h.connIDLen) if err != nil { - h.logger.Debugf("error parsing packets from %s: %s", addr, err) - // This is just the error from parsing the last packet. - // We still need to process the packets that were successfully parsed before. - } - if len(packets) == 0 { - buffer.Release() + h.logger.Debugf("error parsing connection ID on packet from %s: %s", addr, err) return } - h.handleParsedPackets(packets) -} - -func (h *packetHandlerMap) parsePacket( - addr net.Addr, - buffer *packetBuffer, - data []byte, -) ([]*receivedPacket, error) { rcvTime := time.Now() - packets := make([]*receivedPacket, 0, 1) - var counter int - var lastConnID protocol.ConnectionID - for len(data) > 0 { - hdr, packetData, rest, err := wire.ParsePacket(data, h.connIDLen) - if err != nil { - return packets, fmt.Errorf("error parsing packet: %s", err) - } - - if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { - return packets, fmt.Errorf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) - } - lastConnID = hdr.DestConnectionID - - if counter > 0 { - buffer.Split() - } - counter++ - packets = append(packets, &receivedPacket{ - remoteAddr: addr, - hdr: hdr, - rcvTime: rcvTime, - data: packetData, - buffer: buffer, - }) - - // only log if this actually a coalesced packet - if h.logger.Debug() && (counter > 1 || len(rest) > 0) { - h.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packets[counter-1].data), len(rest)) - } - - data = rest - } - return packets, nil -} - -func (h *packetHandlerMap) handleParsedPackets(packets []*receivedPacket) { h.mutex.RLock() defer h.mutex.RUnlock() - // coalesced packets all have the same destination connection ID - handlerEntry, handlerFound := h.handlers[string(packets[0].hdr.DestConnectionID)] + handlerEntry, handlerFound := h.handlers[string(connID)] - for _, p := range packets { - if handlerFound { // existing session - handlerEntry.handler.handlePacket(p) - continue - } - // No session found. - // This might be a stateless reset. - if !p.hdr.IsLongHeader { - if len(p.data) >= protocol.MinStatelessResetSize { - var token [16]byte - copy(token[:], p.data[len(p.data)-16:]) - if sess, ok := h.resetTokens[token]; ok { - sess.destroy(errors.New("received a stateless reset")) - continue - } - } - // TODO(#943): send a stateless reset - h.logger.Debugf("received a short header packet with an unexpected connection ID %s", p.hdr.DestConnectionID) - break // a short header packet is always the last in a coalesced packet - } - if h.server == nil { // no server set - h.logger.Debugf("received a packet with an unexpected connection ID %s", p.hdr.DestConnectionID) - continue - } - h.server.handlePacket(p) + p := &receivedPacket{ + remoteAddr: addr, + rcvTime: rcvTime, + buffer: buffer, + data: data, } + if handlerFound { // existing session + handlerEntry.handler.handlePacket(p) + return + } + // No session found. + // This might be a stateless reset. + if data[0]&0x80 == 0 { // stateless resets are always short header packets + if len(p.data) >= protocol.MinStatelessResetSize { + var token [16]byte + copy(token[:], p.data[len(p.data)-16:]) + if sess, ok := h.resetTokens[token]; ok { + sess.destroy(errors.New("received a stateless reset")) + return + } + } + // TODO(#943): send a stateless reset + h.logger.Debugf("received a short header packet with an unexpected connection ID %s", connID) + return + } + if h.server == nil { // no server set + h.logger.Debugf("received a packet with an unexpected connection ID %s", connID) + return + } + h.server.handlePacket(p) } diff --git a/packet_handler_map_test.go b/packet_handler_map_test.go index 8f0b4e8c..2f1e03a8 100644 --- a/packet_handler_map_test.go +++ b/packet_handler_map_test.go @@ -3,7 +3,6 @@ package quic import ( "bytes" "errors" - "net" "time" "github.com/golang/mock/gomock" @@ -88,11 +87,15 @@ var _ = Describe("Packet Handler Map", func() { handledPacket1 := make(chan struct{}) handledPacket2 := make(chan struct{}) packetHandler1.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.hdr.DestConnectionID).To(Equal(connID1)) + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID1)) close(handledPacket1) }) packetHandler2.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.hdr.DestConnectionID).To(Equal(connID2)) + connID, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(connID).To(Equal(connID2)) close(handledPacket2) }) handler.Add(connID1, packetHandler1) @@ -105,12 +108,10 @@ var _ = Describe("Packet Handler Map", func() { }) It("drops unparseable packets", func() { - _, err := handler.parsePacket(nil, nil, []byte{0, 1, 2, 3}) - Expect(err).To(HaveOccurred()) - Expect(err.Error()).To(ContainSubstring("error parsing packet:")) + handler.handlePacket(nil, nil, []byte{0, 1, 2, 3}) }) - It("deletes removed session immediately", func() { + It("deletes removed sessions immediately", func() { handler.deleteRetiredSessionsAfter = time.Hour connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} handler.Add(connID, NewMockPacketHandler(mockCtrl)) @@ -159,64 +160,6 @@ var _ = Describe("Packet Handler Map", func() { conn.Close() Eventually(done).Should(BeClosed()) }) - - Context("coalesced packets", func() { - It("cuts packets to the right length", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - data := append(getPacketWithLength(connID, 456), make([]byte, 1000)...) - packetHandler := NewMockPacketHandler(mockCtrl) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.data).To(HaveLen(456 + int(p.hdr.ParsedLen()))) - }) - handler.Add(connID, packetHandler) - handler.handlePacket(nil, nil, data) - }) - - It("handles coalesced packets", func() { - connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - packetHandler := NewMockPacketHandler(mockCtrl) - handledPackets := make(chan *receivedPacket, 3) - packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - handledPackets <- p - }).Times(3) - handler.Add(connID, packetHandler) - - buffer := getPacketBuffer() - packet := buffer.Slice[:0] - packet = append(packet, append(getPacketWithLength(connID, 10), make([]byte, 10-2 /* packet number len */)...)...) - packet = append(packet, append(getPacketWithLength(connID, 20), make([]byte, 20-2 /* packet number len */)...)...) - packet = append(packet, append(getPacketWithLength(connID, 30), make([]byte, 30-2 /* packet number len */)...)...) - conn.dataToRead <- packet - - now := time.Now() - for i := 1; i <= 3; i++ { - var p *receivedPacket - Eventually(handledPackets).Should(Receive(&p)) - Expect(p.hdr.DestConnectionID).To(Equal(connID)) - Expect(p.hdr.Length).To(BeEquivalentTo(10 * i)) - Expect(p.data).To(HaveLen(int(p.hdr.ParsedLen() + p.hdr.Length))) - Expect(p.rcvTime).To(BeTemporally("~", now, scaleDuration(20*time.Millisecond))) - Expect(p.buffer.refCount).To(Equal(3)) - } - }) - - It("ignores coalesced packet parts if the connection IDs don't match", func() { - connID1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} - connID2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} - - buffer := getPacketBuffer() - packet := buffer.Slice[:0] - // var packet []byte - packet = append(packet, getPacket(connID1)...) - packet = append(packet, getPacket(connID2)...) - - packets, err := handler.parsePacket(&net.UDPAddr{}, buffer, packet) - Expect(err).To(MatchError("coalesced packet has different destination connection ID: 0x0807060504030201, expected 0x0102030405060708")) - Expect(packets).To(HaveLen(1)) - Expect(packets[0].hdr.DestConnectionID).To(Equal(connID1)) - Expect(packets[0].buffer.refCount).To(Equal(1)) - }) - }) }) Context("stateless reset handling", func() { @@ -228,7 +171,9 @@ var _ = Describe("Packet Handler Map", func() { // first send a normal packet handledPacket := make(chan struct{}) packetHandler.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.hdr.DestConnectionID).To(Equal(connID)) + cid, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(cid).To(Equal(connID)) close(handledPacket) }) conn.dataToRead <- getPacket(connID) @@ -250,24 +195,6 @@ var _ = Describe("Packet Handler Map", func() { Eventually(destroyed).Should(BeClosed()) }) - It("detects a stateless that is coalesced with another packet", func() { - packetHandler := NewMockPacketHandler(mockCtrl) - connID := protocol.ConnectionID{0xde, 0xca, 0xfb, 0xad} - token := [16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} - handler.AddWithResetToken(connID, packetHandler, token) - fakeConnID := protocol.ConnectionID{1, 2, 3, 4, 5} - packet := getPacket(fakeConnID) - reset := append([]byte{0x40} /* short header packet */, fakeConnID...) - reset = append(reset, make([]byte, 50)...) // add some "random" data - reset = append(reset, token[:]...) - destroyed := make(chan struct{}) - packetHandler.EXPECT().destroy(errors.New("received a stateless reset")).Do(func(error) { - close(destroyed) - }) - conn.dataToRead <- append(packet, reset...) - Eventually(destroyed).Should(BeClosed()) - }) - It("deletes reset tokens when the session is retired", func() { handler.deleteRetiredSessionsAfter = scaleDuration(10 * time.Millisecond) connID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef, 0x42} @@ -291,7 +218,9 @@ var _ = Describe("Packet Handler Map", func() { p := getPacket(connID) server := NewMockUnknownPacketHandler(mockCtrl) server.EXPECT().handlePacket(gomock.Any()).Do(func(p *receivedPacket) { - Expect(p.hdr.DestConnectionID).To(Equal(connID)) + cid, err := wire.ParseConnectionID(p.data, 0) + Expect(err).ToNot(HaveOccurred()) + Expect(cid).To(Equal(connID)) }) handler.SetServer(server) handler.handlePacket(nil, nil, p) @@ -299,9 +228,9 @@ var _ = Describe("Packet Handler Map", func() { It("closes all server sessions", func() { clientSess := NewMockPacketHandler(mockCtrl) - clientSess.EXPECT().GetPerspective().Return(protocol.PerspectiveClient) + clientSess.EXPECT().getPerspective().Return(protocol.PerspectiveClient) serverSess := NewMockPacketHandler(mockCtrl) - serverSess.EXPECT().GetPerspective().Return(protocol.PerspectiveServer) + serverSess.EXPECT().getPerspective().Return(protocol.PerspectiveServer) serverSess.EXPECT().Close() handler.Add(protocol.ConnectionID{1, 1, 1, 1}, clientSess) diff --git a/server.go b/server.go index de6fc86b..eba17ac1 100644 --- a/server.go +++ b/server.go @@ -23,7 +23,7 @@ type packetHandler interface { handlePacket(*receivedPacket) io.Closer destroy(error) - GetPerspective() protocol.Perspective + getPerspective() protocol.Perspective } type unknownPacketHandler interface { @@ -44,6 +44,7 @@ type quicSession interface { Session handlePacket(*receivedPacket) GetVersion() protocol.VersionNumber + getPerspective() protocol.Perspective run() error destroy(error) closeForRecreating() protocol.PacketNumber @@ -324,53 +325,60 @@ func (s *server) Addr() net.Addr { } func (s *server) handlePacket(p *receivedPacket) { - hdr := p.hdr - - // send a Version Negotiation Packet if the client is speaking a different protocol version - if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { - go s.sendVersionNegotiationPacket(p) - return - } - if hdr.Type == protocol.PacketTypeInitial { - go s.handleInitial(p) - return - } - - defer p.buffer.Release() - // Drop long header packets. - // There's litte point in sending a Stateless Reset, since the client - // might not have received the token yet. - if hdr.IsLongHeader { - return - } - + go func() { + if shouldReleaseBuffer := s.handlePacketImpl(p); !shouldReleaseBuffer { + p.buffer.Release() + } + }() } -func (s *server) handleInitial(p *receivedPacket) { - s.logger.Debugf("<- Received Initial packet.") - sess, connID, err := s.handleInitialImpl(p) +func (s *server) handlePacketImpl(p *receivedPacket) bool /* was the packet passed on to a session */ { + if len(p.data) < protocol.MinInitialPacketSize { + s.logger.Debugf("Dropping a packet that is too small to be a valid Initial (%d bytes)", len(p.data)) + return false + } + // If we're creating a new session, the packet will be passed to the session. + // The header will then be parsed again. + hdr, _, _, err := wire.ParsePacket(p.data, s.config.ConnectionIDLength) + if err != nil { + s.logger.Debugf("Error parsing packet: %s", err) + return false + } + if !hdr.IsLongHeader { + // TODO: send a stateless reset + return false + } + // send a Version Negotiation Packet if the client is speaking a different protocol version + if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) { + s.sendVersionNegotiationPacket(p, hdr) + return false + } + if hdr.IsLongHeader && hdr.Type != protocol.PacketTypeInitial { + // Drop long header packets. + // There's litte point in sending a Stateless Reset, since the client + // might not have received the token yet. + return false + } + + s.logger.Debugf("<- Received Initial packet.") + + sess, connID, err := s.handleInitialImpl(p, hdr) if err != nil { - p.buffer.Release() s.logger.Errorf("Error occurred handling initial packet: %s", err) - return + return false } if sess == nil { // a retry was done, or the connection attempt was rejected - p.buffer.Release() - return + return false } // Don't put the packet buffer back if a new session was created. // The session will handle the packet and take of that. - serverSession := newServerSession(sess, s.config, s.logger) - s.sessionHandler.Add(connID, serverSession) + s.sessionHandler.Add(connID, sess) + return true } -func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) { - hdr := p.hdr +func (s *server) handleInitialImpl(p *receivedPacket, hdr *wire.Header) (quicSession, protocol.ConnectionID, error) { if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial { - return nil, nil, errors.New("dropping Initial packet with too short connection ID") - } - if len(p.data) < protocol.MinInitialPacketSize { - return nil, nil, errors.New("dropping too small Initial packet") + return nil, nil, errors.New("too short connection ID") } var cookie *Cookie @@ -388,7 +396,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con if !s.config.AcceptCookie(p.remoteAddr, cookie) { // Log the Initial packet now. // If no Retry is sent, the packet will be logged by the session. - (&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger) + (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) return nil, nil, s.sendRetry(p.remoteAddr, hdr) } @@ -535,9 +543,7 @@ func (s *server) sendServerBusy(remoteAddr net.Addr, hdr *wire.Header) error { return nil } -func (s *server) sendVersionNegotiationPacket(p *receivedPacket) { - defer p.buffer.Release() - hdr := p.hdr +func (s *server) sendVersionNegotiationPacket(p *receivedPacket, hdr *wire.Header) { s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version) data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions) if err != nil { diff --git a/server_session.go b/server_session.go deleted file mode 100644 index d1ab73a4..00000000 --- a/server_session.go +++ /dev/null @@ -1,59 +0,0 @@ -package quic - -import ( - "fmt" - - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" -) - -type serverSession struct { - quicSession - - config *Config - - logger utils.Logger -} - -var _ packetHandler = &serverSession{} - -func newServerSession(sess quicSession, config *Config, logger utils.Logger) packetHandler { - return &serverSession{ - quicSession: sess, - config: config, - logger: logger, - } -} - -func (s *serverSession) handlePacket(p *receivedPacket) { - if err := s.handlePacketImpl(p); err != nil { - s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err) - } -} - -func (s *serverSession) handlePacketImpl(p *receivedPacket) error { - hdr := p.hdr - - // Probably an old packet that was sent by the client before the version was negotiated. - // It is safe to drop it. - if hdr.IsLongHeader && hdr.Version != s.quicSession.GetVersion() { - return nil - } - - if hdr.IsLongHeader { - switch hdr.Type { - case protocol.PacketTypeInitial, protocol.PacketTypeHandshake: - // nothing to do here. Packet will be passed to the session. - default: - // Note that this also drops 0-RTT packets. - return fmt.Errorf("Received unsupported packet type: %s", hdr.Type) - } - } - - s.quicSession.handlePacket(p) - return nil -} - -func (s *serverSession) GetPerspective() protocol.Perspective { - return protocol.PerspectiveServer -} diff --git a/server_session_test.go b/server_session_test.go deleted file mode 100644 index b350eb51..00000000 --- a/server_session_test.go +++ /dev/null @@ -1,78 +0,0 @@ -package quic - -import ( - "github.com/lucas-clemente/quic-go/internal/protocol" - "github.com/lucas-clemente/quic-go/internal/utils" - "github.com/lucas-clemente/quic-go/internal/wire" - - . "github.com/onsi/ginkgo" - . "github.com/onsi/gomega" -) - -var _ = Describe("Server Session", func() { - var ( - qsess *MockQuicSession - sess *serverSession - ) - - BeforeEach(func() { - qsess = NewMockQuicSession(mockCtrl) - sess = newServerSession(qsess, &Config{}, utils.DefaultLogger).(*serverSession) - }) - - It("handles packets", func() { - p := &receivedPacket{ - hdr: &wire.Header{ - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5}, - }, - } - qsess.EXPECT().handlePacket(p) - sess.handlePacket(p) - }) - - It("ignores delayed packets with mismatching versions", func() { - qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) - // don't EXPECT any calls to handlePacket() - p := &receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Version: protocol.VersionNumber(123), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }, - } - err := sess.handlePacketImpl(p) - Expect(err).ToNot(HaveOccurred()) - }) - - It("ignores packets with the wrong Long Header type", func() { - qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) - p := &receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - Version: protocol.VersionNumber(100), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }, - } - err := sess.handlePacketImpl(p) - Expect(err).To(MatchError("Received unsupported packet type: Retry")) - }) - - It("passes on Handshake packets", func() { - p := &receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeHandshake, - Version: protocol.VersionNumber(100), - DestConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - }, - } - qsess.EXPECT().GetVersion().Return(protocol.VersionNumber(100)) - qsess.EXPECT().handlePacket(p) - Expect(sess.handlePacketImpl(p)).To(Succeed()) - }) - - It("has the right perspective", func() { - Expect(sess.GetPerspective()).To(Equal(protocol.PerspectiveServer)) - }) -}) diff --git a/server_test.go b/server_test.go index b5525fff..ca8a8801 100644 --- a/server_test.go +++ b/server_test.go @@ -26,6 +26,18 @@ var _ = Describe("Server", func() { tlsConf *tls.Config ) + getPacket := func(hdr *wire.Header, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect((&wire.ExtendedHeader{ + Header: *hdr, + PacketNumberLen: protocol.PacketNumberLen3, + }).Write(buf, protocol.VersionTLS)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } + } + BeforeEach(func() { conn = newMockPacketConn() conn.addr = &net.UDPAddr{} @@ -124,53 +136,45 @@ var _ = Describe("Server", func() { } It("drops Initial packets with a too short connection ID", func() { - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Version: serv.config.Versions[0], - }, - })) + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Version: serv.config.Versions[0], + }, nil)) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops too small Initial", func() { - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - Version: serv.config.Versions[0], - }, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize-100), - })) + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize-100), + )) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops packets with a too short connection ID", func() { - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Version: serv.config.Versions[0], - }, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - })) + serv.handlePacket(getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize))) Consistently(conn.dataWritten).ShouldNot(Receive()) }) It("drops non-Initial packets", func() { - serv.logger.SetLogLevel(utils.LogLevelDebug) - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ + serv.handlePacket(getPacket( + &wire.Header{ Type: protocol.PacketTypeHandshake, Version: serv.config.Versions[0], }, - data: []byte("invalid"), - })) + []byte("invalid"), + )) }) It("decodes the cookie from the Token field", func() { @@ -187,15 +191,14 @@ var _ = Describe("Server", func() { } token, err := serv.cookieGenerator.NewToken(raddr, nil) Expect(err).ToNot(HaveOccurred()) - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - remoteAddr: raddr, - hdr: &wire.Header{ - Type: protocol.PacketTypeInitial, - Token: token, - Version: serv.config.Versions[0], - }, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - })) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: token, + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) @@ -211,31 +214,29 @@ var _ = Describe("Server", func() { close(done) return false } - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - remoteAddr: raddr, - hdr: &wire.Header{ - Type: protocol.PacketTypeInitial, - Token: []byte("foobar"), - Version: serv.config.Versions[0], - }, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - })) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + Token: []byte("foobar"), + Version: serv.config.Versions[0], + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = raddr + serv.handlePacket(packet) Eventually(done).Should(BeClosed()) }) It("sends a Version Negotiation Packet for unsupported versions", func() { srcConnID := protocol.ConnectionID{1, 2, 3, 4, 5} destConnID := protocol.ConnectionID{1, 2, 3, 4, 5, 6} - serv.handlePacket(insertPacketBuffer(&receivedPacket{ - remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, - hdr: &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeInitial, - SrcConnectionID: srcConnID, - DestConnectionID: destConnID, - Version: 0x42, - }, - })) + packet := getPacket(&wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeInitial, + SrcConnectionID: srcConnID, + DestConnectionID: destConnID, + Version: 0x42, + }, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + serv.handlePacket(packet) var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) @@ -249,16 +250,15 @@ var _ = Describe("Server", func() { It("replies with a Retry packet, if a Cookie is required", func() { serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return false } hdr := &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - serv.handleInitial(insertPacketBuffer(&receivedPacket{ - remoteAddr: &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337}, - hdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - })) + packet := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + packet.remoteAddr = &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1337} + serv.handlePacket(packet) var write mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&write)) Expect(write.to.String()).To(Equal("127.0.0.1:1337")) @@ -273,15 +273,13 @@ var _ = Describe("Server", func() { It("creates a session, if no Cookie is required", func() { serv.config.AcceptCookie = func(_ net.Addr, _ *Cookie) bool { return true } hdr := &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - p := &receivedPacket{ - hdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) run := make(chan struct{}) serv.newSession = func( _ connection, @@ -309,7 +307,7 @@ var _ = Describe("Server", func() { done := make(chan struct{}) go func() { defer GinkgoRecover() - serv.handlePacket(insertPacketBuffer(p)) + serv.handlePacket(p) // the Handshake packet is written by the session Consistently(conn.dataWritten).ShouldNot(Receive()) close(done) @@ -324,16 +322,14 @@ var _ = Describe("Server", func() { senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} hdr := &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - p := &receivedPacket{ - remoteAddr: senderAddr, - hdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.remoteAddr = senderAddr serv.newSession = func( _ connection, runner sessionRunner, @@ -360,12 +356,12 @@ var _ = Describe("Server", func() { go func() { defer GinkgoRecover() defer wg.Done() - serv.handlePacket(insertPacketBuffer(p)) + serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) }() } wg.Wait() - serv.handlePacket(insertPacketBuffer(p)) + serv.handlePacket(p) var reject mockPacketConnWrite Eventually(conn.dataWritten).Should(Receive(&reject)) Expect(reject.to).To(Equal(senderAddr)) @@ -381,16 +377,14 @@ var _ = Describe("Server", func() { senderAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 42} hdr := &wire.Header{ + IsLongHeader: true, Type: protocol.PacketTypeInitial, SrcConnectionID: protocol.ConnectionID{5, 4, 3, 2, 1}, DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, Version: protocol.VersionTLS, } - p := &receivedPacket{ - remoteAddr: senderAddr, - hdr: hdr, - data: bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize), - } + p := getPacket(hdr, make([]byte, protocol.MinInitialPacketSize)) + p.remoteAddr = senderAddr ctx, cancel := context.WithCancel(context.Background()) sessionCreated := make(chan struct{}) sess := NewMockQuicSession(mockCtrl) @@ -414,7 +408,7 @@ var _ = Describe("Server", func() { return sess, nil } - serv.handlePacket(insertPacketBuffer(p)) + serv.handlePacket(p) Consistently(conn.dataWritten).ShouldNot(Receive()) Eventually(sessionCreated).Should(BeClosed()) cancel() @@ -429,7 +423,7 @@ var _ = Describe("Server", func() { Consistently(done).ShouldNot(BeClosed()) // make the go routine return - sess.EXPECT().Close() + sess.EXPECT().getPerspective() Expect(serv.Close()).To(Succeed()) Eventually(done).Should(BeClosed()) }) diff --git a/session.go b/session.go index 80448a6f..feaadc29 100644 --- a/session.go +++ b/session.go @@ -55,7 +55,6 @@ type cryptoStreamHandler interface { type receivedPacket struct { remoteAddr net.Addr - hdr *wire.Header rcvTime time.Time data []byte @@ -483,7 +482,43 @@ func (s *session) handleHandshakeComplete() { } } -func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet successfully processed */ { +func (s *session) handlePacketImpl(p *receivedPacket) bool { + var counter uint8 + var lastConnID protocol.ConnectionID + var processed bool + for len(p.data) > 0 { + hdr, packetData, rest, err := wire.ParsePacket(p.data, s.srcConnID.Len()) + if err != nil { + s.logger.Debugf("error parsing packet: %s", err) + break + } + + if counter > 0 && !hdr.DestConnectionID.Equal(lastConnID) { + s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", hdr.DestConnectionID, lastConnID) + break + } + lastConnID = hdr.DestConnectionID + + if counter > 0 { + p.buffer.Split() + } + counter++ + + // only log if this actually a coalesced packet + if s.logger.Debug() && (counter > 1 || len(rest) > 0) { + s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest)) + } + p.data = packetData + pr := s.handleSinglePacket(p, hdr) + if pr { + processed = pr + } + p.data = rest + } + return processed +} + +func (s *session) handleSinglePacket(p *receivedPacket, hdr *wire.Header) bool /* was the packet successfully processed */ { var wasQueued bool defer func() { @@ -493,22 +528,22 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc } }() - if p.hdr.Type == protocol.PacketTypeRetry { - return s.handleRetryPacket(p) + if hdr.Type == protocol.PacketTypeRetry { + return s.handleRetryPacket(p, hdr) } // The server can change the source connection ID with the first Handshake packet. // After this, all packets with a different source connection have to be ignored. - if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) { - s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID) + if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) { + s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", hdr.SrcConnectionID, s.destConnID) return false } // drop 0-RTT packets - if p.hdr.Type == protocol.PacketType0RTT { + if hdr.Type == protocol.PacketType0RTT { return false } - packet, err := s.unpacker.Unpack(p.hdr, p.data) + packet, err := s.unpacker.Unpack(hdr, p.data) if err != nil { if err == handshake.ErrOpenerNotYetAvailable { // Sealer for this encryption level not yet available. @@ -524,7 +559,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc } if s.logger.Debug() { - s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), p.hdr.DestConnectionID, packet.encryptionLevel) + s.logger.Debugf("<- Reading packet %#x (%d bytes) for connection %s, %s", packet.packetNumber, len(p.data), hdr.DestConnectionID, packet.encryptionLevel) packet.hdr.Log(s.logger) } @@ -535,7 +570,7 @@ func (s *session) handlePacketImpl(p *receivedPacket) bool /* was the packet suc return true } -func (s *session) handleRetryPacket(p *receivedPacket) bool /* was this a valid Retry */ { +func (s *session) handleRetryPacket(p *receivedPacket, hdr *wire.Header) bool /* was this a valid Retry */ { if s.perspective == protocol.PerspectiveServer { s.logger.Debugf("Ignoring Retry.") return false @@ -544,7 +579,6 @@ func (s *session) handleRetryPacket(p *receivedPacket) bool /* was this a valid s.logger.Debugf("Ignoring Retry, since we already received a packet.") return false } - hdr := p.hdr (&wire.ExtendedHeader{Header: *hdr}).Log(s.logger) if !hdr.OrigDestConnectionID.Equal(s.destConnID) { s.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, s.destConnID) @@ -1246,6 +1280,10 @@ func (s *session) RemoteAddr() net.Addr { return s.conn.RemoteAddr() } +func (s *session) getPerspective() protocol.Perspective { + return s.perspective +} + func (s *session) GetVersion() protocol.VersionNumber { return s.version } diff --git a/session_test.go b/session_test.go index 0a78cca9..ac659f41 100644 --- a/session_test.go +++ b/session_test.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "context" + "crypto/rand" "errors" "net" "runtime/pprof" @@ -354,20 +355,6 @@ var _ = Describe("Session", func() { Expect(str).To(Equal(mstr)) }) - It("drops Retry packets", func() { - hdr := wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - } - buf := &bytes.Buffer{} - (&wire.ExtendedHeader{Header: hdr}).Write(buf, sess.version) - Expect(sess.handlePacketImpl(&receivedPacket{ - hdr: &hdr, - data: buf.Bytes(), - buffer: getPacketBuffer(), - })).To(BeFalse()) - }) - Context("closing", func() { var ( runErr error @@ -492,18 +479,26 @@ var _ = Describe("Session", func() { sess.unpacker = unpacker }) - getData := func(extHdr *wire.ExtendedHeader) []byte { + getPacket := func(extHdr *wire.ExtendedHeader, data []byte) *receivedPacket { buf := &bytes.Buffer{} Expect(extHdr.Write(buf, sess.version)).To(Succeed()) - // need to set extHdr.Header, since the wire.Header contains the parsed length - hdr, _, _, err := wire.ParsePacket(buf.Bytes(), 0) - Expect(err).ToNot(HaveOccurred()) - extHdr.Header = *hdr - return buf.Bytes() + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } } + It("drops Retry packets", func() { + hdr := wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + } + Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{Header: hdr}, nil))).To(BeFalse()) + }) + It("informs the ReceivedPacketHandler about non-retransmittable packets", func() { hdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumber: 0x37, PacketNumberLen: protocol.PacketNumberLen1, } @@ -517,15 +512,14 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionInitial, rcvTime, false) sess.receivedPacketHandler = rph - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - rcvTime: rcvTime, - hdr: &hdr.Header, - data: getData(hdr), - }))).To(BeTrue()) + packet := getPacket(hdr, nil) + packet.rcvTime = rcvTime + Expect(sess.handlePacketImpl(packet)).To(BeTrue()) }) It("informs the ReceivedPacketHandler about retransmittable packets", func() { hdr := &wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, PacketNumber: 0x37, PacketNumberLen: protocol.PacketNumberLen1, } @@ -541,11 +535,9 @@ var _ = Describe("Session", func() { rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl) rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.EncryptionHandshake, rcvTime, true) sess.receivedPacketHandler = rph - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - rcvTime: rcvTime, - hdr: &hdr.Header, - data: getData(hdr), - }))).To(BeTrue()) + packet := getPacket(hdr, nil) + packet.rcvTime = rcvTime + Expect(sess.handlePacketImpl(packet)).To(BeTrue()) }) It("drops a packet when unpacking fails", func() { @@ -559,10 +551,10 @@ var _ = Describe("Session", func() { sess.run() }() sessionRunner.EXPECT().retireConnectionID(gomock.Any()) - sess.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{}, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - })) + sess.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil)) Consistently(sess.Context().Done()).ShouldNot(BeClosed()) // make the go routine return sess.closeLocal(errors.New("close")) @@ -586,65 +578,61 @@ var _ = Describe("Session", func() { close(done) }() sessionRunner.EXPECT().retireConnectionID(gomock.Any()) - sess.handlePacket(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{}, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - })) + sess.handlePacket(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil)) Eventually(done).Should(BeClosed()) }) - It("handles duplicate packets", func() { - hdr := &wire.ExtendedHeader{ - PacketNumber: 5, - PacketNumberLen: protocol.PacketNumberLen1, - } - unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ - encryptionLevel: protocol.Encryption1RTT, - hdr: hdr, - data: []byte{0}, // one PADDING frame - }, nil).Times(2) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{hdr: &hdr.Header, data: getData(hdr)}))).To(BeTrue()) - }) - It("ignores 0-RTT packets", func() { - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketType0RTT, DestConnectionID: sess.srcConnID, }, - }))).To(BeFalse()) + PacketNumberLen: protocol.PacketNumberLen2, + } + Expect(sess.handlePacketImpl(getPacket(hdr, nil))).To(BeFalse()) }) It("ignores packets with a different source connection ID", func() { - hdr := &wire.Header{ - IsLongHeader: true, - DestConnectionID: sess.destConnID, - SrcConnectionID: sess.srcConnID, - Length: 1, + hdr1 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: sess.destConnID, + SrcConnectionID: sess.srcConnID, + Length: 1, + Version: sess.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 1, } + hdr2 := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: sess.destConnID, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + Length: 1, + Version: sess.version, + }, + PacketNumberLen: protocol.PacketNumberLen1, + PacketNumber: 2, + } + Expect(sess.srcConnID).ToNot(Equal(hdr2.SrcConnectionID)) // Send one packet, which might change the connection ID. // only EXPECT one call to the unpacker unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).Return(&unpackedPacket{ encryptionLevel: protocol.Encryption1RTT, - hdr: &wire.ExtendedHeader{Header: *hdr}, + hdr: hdr1, data: []byte{0}, // one PADDING frame }, nil) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - hdr: hdr, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - }))).To(BeTrue()) + Expect(sess.handlePacketImpl(getPacket(hdr1, nil))).To(BeTrue()) // The next packet has to be ignored, since the source connection ID doesn't match. - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ - IsLongHeader: true, - DestConnectionID: sess.destConnID, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - Length: 1, - }, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - }))).To(BeFalse()) + Expect(sess.handlePacketImpl(getPacket(hdr2, nil))).To(BeFalse()) }) Context("updating the remote address", func() { @@ -657,14 +645,86 @@ var _ = Describe("Session", func() { origAddr := sess.conn.(*mockConnection).remoteAddr remoteIP := &net.IPAddr{IP: net.IPv4(192, 168, 0, 100)} Expect(origAddr).ToNot(Equal(remoteIP)) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - remoteAddr: remoteIP, - hdr: &wire.Header{}, - data: getData(&wire.ExtendedHeader{PacketNumberLen: protocol.PacketNumberLen1}), - }))).To(BeTrue()) + packet := getPacket(&wire.ExtendedHeader{ + Header: wire.Header{DestConnectionID: sess.srcConnID}, + PacketNumberLen: protocol.PacketNumberLen1, + }, nil) + packet.remoteAddr = remoteIP + Expect(sess.handlePacketImpl(packet)).To(BeTrue()) Expect(sess.conn.(*mockConnection).remoteAddr).To(Equal(origAddr)) }) }) + + Context("coalesced packets", func() { + getPacketWithLength := func(connID protocol.ConnectionID, length protocol.ByteCount) (int /* header length */, *receivedPacket) { + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeHandshake, + DestConnectionID: connID, + SrcConnectionID: sess.destConnID, + Version: protocol.VersionTLS, + Length: length, + }, + PacketNumberLen: protocol.PacketNumberLen3, + } + hdrLen := hdr.GetLength(sess.version) + b := make([]byte, 1) + rand.Read(b) + packet := getPacket(hdr, bytes.Repeat(b, int(length)-3)) + return int(hdrLen), packet + } + + It("cuts packets to the right length", func() { + hdrLen, packet := getPacketWithLength(sess.srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(int(hdrLen + 456 - 3))) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }) + Expect(sess.handlePacketImpl(packet)).To(BeTrue()) + }) + + It("handles coalesced packets", func() { + hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(int(hdrLen1 + 456 - 3))) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }) + hdrLen2, packet2 := getPacketWithLength(sess.srcConnID, 123) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(int(hdrLen2 + 123 - 3))) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }) + packet1.data = append(packet1.data, packet2.data...) + Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) + }) + + It("ignores coalesced packet parts if the destination connection IDs don't match", func() { + wrongConnID := protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef} + Expect(sess.srcConnID).ToNot(Equal(wrongConnID)) + hdrLen1, packet1 := getPacketWithLength(sess.srcConnID, 456) + unpacker.EXPECT().Unpack(gomock.Any(), gomock.Any()).DoAndReturn(func(_ *wire.Header, data []byte) (*unpackedPacket, error) { + Expect(data).To(HaveLen(int(hdrLen1 + 456 - 3))) + return &unpackedPacket{ + encryptionLevel: protocol.EncryptionHandshake, + data: []byte{0}, + }, nil + }) + _, packet2 := getPacketWithLength(wrongConnID, 123) + // don't EXPECT any calls to unpacker.Unpack() + packet1.data = append(packet1.data, packet2.data...) + Expect(sess.handlePacketImpl(packet1)).To(BeTrue()) + }) + }) }) Context("sending packets", func() { @@ -1436,6 +1496,15 @@ var _ = Describe("Client Session", func() { cryptoSetup *mocks.MockCryptoSetup ) + getPacket := func(hdr *wire.ExtendedHeader, data []byte) *receivedPacket { + buf := &bytes.Buffer{} + Expect(hdr.Write(buf, sess.version)).To(Succeed()) + return &receivedPacket{ + data: append(buf.Bytes(), data...), + buffer: getPacketBuffer(), + } + } + BeforeEach(func() { Eventually(areSessionsRunning).Should(BeFalse()) @@ -1450,9 +1519,9 @@ var _ = Describe("Client Session", func() { nil, // tls.Config 42, // initial packet number &handshake.TransportParameters{}, - protocol.VersionWhatever, + protocol.VersionTLS, utils.DefaultLogger, - protocol.VersionWhatever, + protocol.VersionTLS, ) sess = sessP.(*session) Expect(err).ToNot(HaveOccurred()) @@ -1479,16 +1548,16 @@ var _ = Describe("Client Session", func() { }() newConnID := protocol.ConnectionID{1, 3, 3, 7, 1, 3, 3, 7} packer.EXPECT().ChangeDestConnectionID(newConnID) - Expect(sess.handlePacketImpl(insertPacketBuffer(&receivedPacket{ - hdr: &wire.Header{ + Expect(sess.handlePacketImpl(getPacket(&wire.ExtendedHeader{ + Header: wire.Header{ IsLongHeader: true, Type: protocol.PacketTypeHandshake, SrcConnectionID: newConnID, DestConnectionID: sess.srcConnID, Length: 1, }, - data: []byte{0}, - }))).To(BeTrue()) + PacketNumberLen: protocol.PacketNumberLen2, + }, []byte{0}))).To(BeTrue()) // make sure the go routine returns packer.EXPECT().PackConnectionClose(gomock.Any()).Return(&packedPacket{}, nil) sessionRunner.EXPECT().retireConnectionID(gomock.Any()) @@ -1498,56 +1567,52 @@ var _ = Describe("Client Session", func() { }) Context("handling Retry", func() { - var validRetryHdr *wire.Header + var validRetryHdr *wire.ExtendedHeader BeforeEach(func() { - validRetryHdr = &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - OrigDestConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, - Token: []byte("foobar"), + validRetryHdr = &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + OrigDestConnectionID: protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}, + Token: []byte("foobar"), + Version: sess.version, + }, } }) - getPacket := func(hdr *wire.Header) *receivedPacket { - buf := &bytes.Buffer{} - (&wire.ExtendedHeader{Header: *hdr}).Write(buf, sess.version) - return &receivedPacket{ - hdr: hdr, - data: buf.Bytes(), - buffer: getPacketBuffer(), - } - } - It("handles Retry packets", func() { cryptoSetup.EXPECT().ChangeConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) packer.EXPECT().SetToken([]byte("foobar")) packer.EXPECT().ChangeDestConnectionID(protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}) - Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeTrue()) + Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeTrue()) }) It("ignores Retry packets after receiving a regular packet", func() { sess.receivedFirstPacket = true - Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeFalse()) + Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse()) }) It("ignores Retry packets if the server didn't change the connection ID", func() { validRetryHdr.SrcConnectionID = sess.destConnID - Expect(sess.handlePacketImpl(getPacket(validRetryHdr))).To(BeFalse()) + Expect(sess.handlePacketImpl(getPacket(validRetryHdr, nil))).To(BeFalse()) }) It("ignores Retry packets with the wrong original destination connection ID", func() { - hdr := &wire.Header{ - IsLongHeader: true, - Type: protocol.PacketTypeRetry, - SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, - DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, - OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, - Token: []byte("foobar"), + hdr := &wire.ExtendedHeader{ + Header: wire.Header{ + IsLongHeader: true, + Type: protocol.PacketTypeRetry, + SrcConnectionID: protocol.ConnectionID{0xde, 0xad, 0xbe, 0xef}, + DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}, + OrigDestConnectionID: protocol.ConnectionID{1, 2, 3, 4}, + Token: []byte("foobar"), + }, + PacketNumberLen: protocol.PacketNumberLen3, } - Expect(sess.handlePacketImpl(getPacket(hdr))).To(BeFalse()) + Expect(sess.handlePacketImpl(getPacket(hdr, nil))).To(BeFalse()) }) })