diff --git a/connection_test.go b/connection_test.go index 7f9468d95..dd3f93b09 100644 --- a/connection_test.go +++ b/connection_test.go @@ -54,14 +54,22 @@ var _ = Describe("Connection", func() { destConnID := protocol.ParseConnectionID([]byte{8, 7, 6, 5, 4, 3, 2, 1}) clientDestConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - getCoalescedPacket := func(pn protocol.PacketNumber, isLongHeader bool) *coalescedPacket { + getCoalescedPacket := func(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) *coalescedPacket { buffer := getPacketBuffer() buffer.Data = append(buffer.Data, []byte("foobar")...) packet := &coalescedPacket{buffer: buffer} - if isLongHeader { + if encLevel != protocol.Encryption1RTT { + var typ protocol.PacketType + //nolint:exhaustive + switch encLevel { + case protocol.EncryptionInitial: + typ = protocol.PacketTypeInitial + case protocol.EncryptionHandshake: + typ = protocol.PacketTypeHandshake + } packet.longHdrPackets = []*longHeaderPacket{{ header: &wire.ExtendedHeader{ - Header: wire.Header{}, + Header: wire.Header{Type: typ}, PacketNumber: pn, }, length: 6, // foobar @@ -1326,14 +1334,14 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().QueueProbePacket(encLevel) sph.EXPECT().ECNMode(gomock.Any()) - p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) + p := getCoalescedPacket(123, encLevel) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) conn.sentPacketHandler = sph runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) - if enc == protocol.Encryption1RTT { + if encLevel == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, gomock.Any(), gomock.Any(), gomock.Any()) } else { tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, gomock.Any(), gomock.Any(), gomock.Any()) @@ -1349,13 +1357,13 @@ var _ = Describe("Connection", func() { sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone) sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT0) sph.EXPECT().QueueProbePacket(encLevel).Return(false) - p := getCoalescedPacket(123, enc != protocol.Encryption1RTT) + p := getCoalescedPacket(123, encLevel) packer.EXPECT().MaybePackProbePacket(encLevel, gomock.Any(), conn.version).Return(p, nil) sph.EXPECT().SentPacket(gomock.Any(), protocol.PacketNumber(123), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) runConn() sent := make(chan struct{}) sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(*packetBuffer, uint16, protocol.ECN) { close(sent) }) - if enc == protocol.Encryption1RTT { + if encLevel == protocol.Encryption1RTT { tracer.EXPECT().SentShortHeaderPacket(gomock.Any(), p.shortHdrPacket.Length, logging.ECT0, gomock.Any(), gomock.Any()) } else { tracer.EXPECT().SentLongHeaderPacket(gomock.Any(), p.longHdrPackets[0].length, logging.ECT0, gomock.Any(), gomock.Any())