diff --git a/packet_packer.go b/packet_packer.go index 25e3cdf1..48292eed 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -25,10 +25,23 @@ type packer interface { } type packedPacket struct { - header *wire.ExtendedHeader - raw []byte - frames []wire.Frame - encryptionLevel protocol.EncryptionLevel + header *wire.ExtendedHeader + raw []byte + frames []wire.Frame +} + +func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel { + if !p.header.IsLongHeader { + return protocol.Encryption1RTT + } + switch p.header.Type { + case protocol.PacketTypeInitial: + return protocol.EncryptionInitial + case protocol.PacketTypeHandshake: + return protocol.EncryptionHandshake + default: + return protocol.EncryptionUnspecified + } } func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { @@ -37,7 +50,7 @@ func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet { PacketType: p.header.Type, Frames: p.frames, Length: protocol.ByteCount(len(p.raw)), - EncryptionLevel: p.encryptionLevel, + EncryptionLevel: p.EncryptionLevel(), SendTime: time.Now(), } } @@ -138,10 +151,9 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac header := p.getHeader(encLevel) raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: header, + raw: raw, + frames: frames, }, err } @@ -156,10 +168,9 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { frames := []wire.Frame{ack} raw, err := p.writeAndSealPacket(header, frames, sealer) return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: header, + raw: raw, + frames: frames, }, err } @@ -232,10 +243,9 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP return nil, err } packets = append(packets, &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: header, + raw: raw, + frames: frames, }) } return packets, nil @@ -286,10 +296,9 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { return nil, err } return &packedPacket{ - header: header, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: header, + raw: raw, + frames: frames, }, nil } @@ -325,10 +334,9 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { return nil, err } return &packedPacket{ - header: hdr, - raw: raw, - frames: frames, - encryptionLevel: encLevel, + header: hdr, + raw: raw, + frames: frames, }, nil } diff --git a/packet_packer_test.go b/packet_packer_test.go index f262f703..1971700e 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -253,7 +253,7 @@ var _ = Describe("Packet packer", func() { }) p, err := packer.PackPacket() Expect(err).ToNot(HaveOccurred()) - Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) }) It("packs a single ACK", func() { @@ -494,7 +494,7 @@ var _ = Describe("Packet packer", func() { Expect(err).ToNot(HaveOccurred()) Expect(packets).To(HaveLen(1)) p := packets[0] - Expect(p.encryptionLevel).To(Equal(protocol.Encryption1RTT)) + Expect(p.EncryptionLevel()).To(Equal(protocol.Encryption1RTT)) Expect(p.frames).To(Equal(frames)) }) @@ -846,7 +846,7 @@ var _ = Describe("Packet packer", func() { Expect(p).To(HaveLen(1)) Expect(p[0].header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p[0].frames).To(Equal([]wire.Frame{f})) - Expect(p[0].encryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p[0].EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) }) It("packs a retransmission for an Initial packet", func() { @@ -864,7 +864,7 @@ var _ = Describe("Packet packer", func() { Expect(packets).To(HaveLen(1)) p := packets[0] Expect(p.frames).To(Equal([]wire.Frame{sf})) - Expect(p.encryptionLevel).To(Equal(protocol.EncryptionInitial)) + Expect(p.EncryptionLevel()).To(Equal(protocol.EncryptionInitial)) Expect(p.header.Type).To(Equal(protocol.PacketTypeInitial)) Expect(p.header.Token).To(Equal(token)) Expect(p.raw).To(HaveLen(protocol.MinInitialPacketSize)) diff --git a/session.go b/session.go index 2cdd6d8c..b16083b7 100644 --- a/session.go +++ b/session.go @@ -976,7 +976,7 @@ func (s *session) logPacket(packet *packedPacket) { // We don't need to allocate the slices for calling the format functions return } - s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.encryptionLevel) + s.logger.Debugf("-> Sending packet 0x%x (%d bytes) for connection %s, %s", packet.header.PacketNumber, len(packet.raw), s.srcConnID, packet.EncryptionLevel()) packet.header.Log(s.logger) for _, frame := range packet.frames { wire.LogFrame(s.logger, frame, true)