add a function to write short headers

This commit is contained in:
Marten Seemann
2022-12-26 10:13:12 +13:00
parent 87f51d33dc
commit 124e597cf9
4 changed files with 45 additions and 19 deletions

View File

@@ -1839,8 +1839,8 @@ func (s *connection) maybeSendAckOnlyPacket() error {
} }
return err return err
} }
s.sendPackedShortHeaderPacket(buffer, p.Packet, now)
s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false) s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, buffer.Len(), false)
s.sendPackedShortHeaderPacket(buffer, p.Packet, now)
return nil return nil
} }

View File

@@ -177,7 +177,7 @@ func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.Versi
b.Write(h.Token) b.Write(h.Token)
} }
quicvarint.WriteWithLen(b, uint64(h.Length), 2) quicvarint.WriteWithLen(b, uint64(h.Length), 2)
return h.writePacketNumber(b) return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen)
} }
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error { func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
@@ -188,23 +188,7 @@ func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNum
b.WriteByte(typeByte) b.WriteByte(typeByte)
b.Write(h.DestConnectionID.Bytes()) b.Write(h.DestConnectionID.Bytes())
return h.writePacketNumber(b) return writePacketNumber(b, h.PacketNumber, h.PacketNumberLen)
}
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
switch h.PacketNumberLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(h.PacketNumber))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(h.PacketNumber))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
default:
return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
}
return nil
} }
// ParsedLen returns the number of bytes that were consumed when parsing the header // ParsedLen returns the number of bytes that were consumed when parsing the header
@@ -247,3 +231,19 @@ func (h *ExtendedHeader) Log(logger utils.Logger) {
logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase) logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
} }
} }
func writePacketNumber(b *bytes.Buffer, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen) error {
switch pnLen {
case protocol.PacketNumberLen1:
b.WriteByte(uint8(pn))
case protocol.PacketNumberLen2:
utils.BigEndian.WriteUint16(b, uint16(pn))
case protocol.PacketNumberLen3:
utils.BigEndian.WriteUint24(b, uint32(pn))
case protocol.PacketNumberLen4:
utils.BigEndian.WriteUint32(b, uint32(pn))
default:
return fmt.Errorf("invalid packet number length: %d", pnLen)
}
return nil
}

View File

@@ -1,6 +1,7 @@
package wire package wire
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -50,6 +51,17 @@ func ParseShortHeader(data []byte, connIDLen int) (length int, _ protocol.Packet
return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err return 1 + connIDLen + int(pnLen), pn, pnLen, kp, err
} }
// WriteShortHeader writes a short header.
func WriteShortHeader(b *bytes.Buffer, connID protocol.ConnectionID, pn protocol.PacketNumber, pnLen protocol.PacketNumberLen, kp protocol.KeyPhaseBit) error {
typeByte := 0x40 | uint8(pnLen-1)
if kp == protocol.KeyPhaseOne {
typeByte |= byte(1 << 2)
}
b.WriteByte(typeByte)
b.Write(connID.Bytes())
return writePacketNumber(b, pn, pnLen)
}
func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount { func ShortHeaderLen(dest protocol.ConnectionID, pnLen protocol.PacketNumberLen) protocol.ByteCount {
return 1 + protocol.ByteCount(dest.Len()) + protocol.ByteCount(pnLen) return 1 + protocol.ByteCount(dest.Len()) + protocol.ByteCount(pnLen)
} }

View File

@@ -75,6 +75,20 @@ var _ = Describe("Short Header", func() {
Expect(ShortHeaderLen(protocol.ParseConnectionID([]byte{}), protocol.PacketNumberLen1)).To(BeEquivalentTo(2)) Expect(ShortHeaderLen(protocol.ParseConnectionID([]byte{}), protocol.PacketNumberLen1)).To(BeEquivalentTo(2))
}) })
Context("writing", func() {
It("writes a short header packet", func() {
b := &bytes.Buffer{}
connID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
Expect(WriteShortHeader(b, connID, 1337, 4, protocol.KeyPhaseOne)).To(Succeed())
l, pn, pnLen, kp, err := ParseShortHeader(b.Bytes(), 4)
Expect(err).ToNot(HaveOccurred())
Expect(pn).To(Equal(protocol.PacketNumber(1337)))
Expect(pnLen).To(Equal(protocol.PacketNumberLen4))
Expect(kp).To(Equal(protocol.KeyPhaseOne))
Expect(l).To(Equal(b.Len()))
})
})
Context("logging", func() { Context("logging", func() {
var ( var (
buf *bytes.Buffer buf *bytes.Buffer