diff --git a/connection.go b/connection.go index fbaddf835..38d152749 100644 --- a/connection.go +++ b/connection.go @@ -1839,8 +1839,8 @@ func (s *connection) maybeSendAckOnlyPacket() error { } return err } - s.sendPackedShortHeaderPacket(buffer, p.Packet, now) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) + s.sendPackedShortHeaderPacket(buffer, p.Packet, now) return nil } diff --git a/internal/wire/extended_header.go b/internal/wire/extended_header.go index 9d9edab25..74f2fc5f4 100644 --- a/internal/wire/extended_header.go +++ b/internal/wire/extended_header.go @@ -177,7 +177,7 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.Versi b.Write(h.Token) } quicvarint.WriteWithLen(b, uint64(h.Length), 2) - return h.writePacketNumber(b) + return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen) } func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { @@ -188,23 +188,7 @@ func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNum b.WriteByte(typeByte) b.Write(h.DestConnectionID.Bytes()) - return h.writePacketNumber(b) -} - -func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error { - switch h.PacketNumberLen { - case protocol.PacketNumberLen1: - b.WriteByte(uint8(h.PacketNumber)) - case protocol.PacketNumberLen2: - utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber)) - case protocol.PacketNumberLen3: - utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber)) - case protocol.PacketNumberLen4: - utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber)) - default: - return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen) - } - return nil + return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen) } // ParsedLen returns the number of bytes that were consumed when parsing the header @@ -247,3 +231,19 @@ func (h *ExtendedHeader) Log(logger utils.Logger) { logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) } } + +func writePacketNumber(b *bytes.Buffer, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) error { + switch pnLen { + case protocol.PacketNumberLen1: + b.WriteByte(uint8(pn)) + case protocol.PacketNumberLen2: + utils.BigEndian.WriteUint16(b, uint16(pn)) + case protocol.PacketNumberLen3: + utils.BigEndian.WriteUint24(b, uint32(pn)) + case protocol.PacketNumberLen4: + utils.BigEndian.WriteUint32(b, uint32(pn)) + default: + return fmt.Errorf("invalid packet number length: %d", pnLen) + } + return nil +} diff --git a/internal/wire/short_header.go b/internal/wire/short_header.go index 1ddb9be9c..57913aafc 100644 --- a/internal/wire/short_header.go +++ b/internal/wire/short_header.go @@ -1,6 +1,7 @@ package wire import ( + "bytes" "errors" "fmt" "io" @@ -50,6 +51,17 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err } +// WriteShortHeader writes a short header. +func WriteShortHeader(b *bytes.Buffer, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) error { + typeByte := 0x40 | uint8(pnLen-1) + if kp == protocol.KeyPhaseOne { + typeByte |= byte(1 << 2) + } + b.WriteByte(typeByte) + b.Write(connID.Bytes()) + return writePacketNumber(b, pn, pnLen) +} + func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount { return 1 + protocol.ByteCount(dest.Len()) + protocol.ByteCount(pnLen) } diff --git a/internal/wire/short_header_test.go b/internal/wire/short_header_test.go index a475a288b..74dd55dab 100644 --- a/internal/wire/short_header_test.go +++ b/internal/wire/short_header_test.go @@ -75,6 +75,20 @@ var _ = Describe("Short Header", func() { Expect(ShortHeaderLen(protocol.ParseConnectionID([]byte{}), protocol.PacketNumberLen1)).To(BeEquivalentTo(2)) }) + Context("writing", func() { + It("writes a short header packet", func() { + b := &bytes.Buffer{} + connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4}) + Expect(WriteShortHeader(b, connID, 1337, 4, protocol.KeyPhaseOne)).To(Succeed()) + l, pn, pnLen, kp, err := ParseShortHeader(b.Bytes(), 4) + Expect(err).ToNot(HaveOccurred()) + Expect(pn).To(Equal(protocol.PacketNumber(1337))) + Expect(pnLen).To(Equal(protocol.PacketNumberLen4)) + Expect(kp).To(Equal(protocol.KeyPhaseOne)) + Expect(l).To(Equal(b.Len())) + }) + }) + Context("logging", func() { var ( buf *bytes.Buffer