From b2d3ef691efe052461201f96d0bc4ace2184de59 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Mon, 10 Jun 2019 16:21:45 +0800 Subject: [PATCH] use separate sealer interfaces for long and short header packets --- internal/handshake/aead.go | 9 +- internal/handshake/aead_test.go | 4 +- internal/handshake/crypto_setup.go | 12 +-- internal/handshake/initial_aead.go | 2 +- internal/handshake/interface.go | 16 ++- internal/mocks/crypto_setup.go | 12 +-- internal/mocks/mockgen.go | 2 +- internal/mocks/sealer.go | 75 ------------- internal/mocks/short_header_sealer.go | 89 ++++++++++++++++ mock_sealing_manager_test.go | 12 +-- packet_packer.go | 145 ++++++++++++++++---------- packet_packer_test.go | 27 ++--- 12 files changed, 233 insertions(+), 172 deletions(-) delete mode 100644 internal/mocks/sealer.go create mode 100644 internal/mocks/short_header_sealer.go diff --git a/internal/handshake/aead.go b/internal/handshake/aead.go index 07ce74f74..53cf09166 100644 --- a/internal/handshake/aead.go +++ b/internal/handshake/aead.go @@ -19,9 +19,10 @@ type sealer struct { is1RTT bool } -var _ Sealer = &sealer{} +var _ LongHeaderSealer = &sealer{} +var _ ShortHeaderSealer = &sealer{} -func newSealer(aead cipher.AEAD, hpEncrypter cipher.Block, is1RTT bool) Sealer { +func newSealer(aead cipher.AEAD, hpEncrypter cipher.Block, is1RTT bool) ShortHeaderSealer { return &sealer{ aead: aead, nonceBuf: make([]byte, aead.NonceSize()), @@ -57,6 +58,10 @@ func (s *sealer) Overhead() int { return s.aead.Overhead() } +func (s *sealer) KeyPhase() protocol.KeyPhase { + return protocol.KeyPhaseZero +} + type opener struct { aead cipher.AEAD pnDecrypter cipher.Block diff --git a/internal/handshake/aead_test.go b/internal/handshake/aead_test.go index 2f7bfbac1..8559b3c40 100644 --- a/internal/handshake/aead_test.go +++ b/internal/handshake/aead_test.go @@ -10,7 +10,7 @@ import ( ) var _ = Describe("AEAD", func() { - getSealerAndOpener := func(is1RTT bool) (Sealer, Opener) { + getSealerAndOpener := func(is1RTT bool) (ShortHeaderSealer, Opener) { key := make([]byte, 16) hpKey := make([]byte, 16) rand.Read(key) @@ -29,7 +29,7 @@ var _ = Describe("AEAD", func() { Context("message encryption", func() { var ( - sealer Sealer + sealer ShortHeaderSealer opener Opener ) diff --git a/internal/handshake/crypto_setup.go b/internal/handshake/crypto_setup.go index 31a6f199e..81f5baa23 100644 --- a/internal/handshake/crypto_setup.go +++ b/internal/handshake/crypto_setup.go @@ -106,15 +106,15 @@ type cryptoSetup struct { initialStream io.Writer initialOpener Opener - initialSealer Sealer + initialSealer LongHeaderSealer handshakeStream io.Writer handshakeOpener Opener - handshakeSealer Sealer + handshakeSealer LongHeaderSealer oneRTTStream io.Writer opener Opener - sealer Sealer + sealer ShortHeaderSealer } var _ qtls.RecordLayer = &cryptoSetup{} @@ -564,14 +564,14 @@ func (h *cryptoSetup) SendAlert(alert uint8) { h.alertChan <- alert } -func (h *cryptoSetup) GetInitialSealer() (Sealer, error) { +func (h *cryptoSetup) GetInitialSealer() (LongHeaderSealer, error) { h.mutex.Lock() defer h.mutex.Unlock() return h.initialSealer, nil } -func (h *cryptoSetup) GetHandshakeSealer() (Sealer, error) { +func (h *cryptoSetup) GetHandshakeSealer() (LongHeaderSealer, error) { h.mutex.Lock() defer h.mutex.Unlock() @@ -581,7 +581,7 @@ func (h *cryptoSetup) GetHandshakeSealer() (Sealer, error) { return h.handshakeSealer, nil } -func (h *cryptoSetup) Get1RTTSealer() (Sealer, error) { +func (h *cryptoSetup) Get1RTTSealer() (ShortHeaderSealer, error) { h.mutex.Lock() defer h.mutex.Unlock() diff --git a/internal/handshake/initial_aead.go b/internal/handshake/initial_aead.go index 5d78bbe2c..c561ae619 100644 --- a/internal/handshake/initial_aead.go +++ b/internal/handshake/initial_aead.go @@ -11,7 +11,7 @@ import ( var quicVersion1Salt = []byte{0xef, 0x4f, 0xb0, 0xab, 0xb4, 0x74, 0x70, 0xc4, 0x1b, 0xef, 0xcf, 0x80, 0x31, 0x33, 0x4f, 0xae, 0x48, 0x5e, 0x09, 0xa0} // NewInitialAEAD creates a new AEAD for Initial encryption / decryption. -func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (Sealer, Opener, error) { +func NewInitialAEAD(connID protocol.ConnectionID, pers protocol.Perspective) (LongHeaderSealer, Opener, error) { clientSecret, serverSecret := computeSecrets(connID) var mySecret, otherSecret []byte if pers == protocol.PerspectiveClient { diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 91a433911..c16539358 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -14,13 +14,19 @@ type Opener interface { DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) } -// Sealer seals a packet -type Sealer interface { +// LongHeaderSealer seals a long header packet +type LongHeaderSealer interface { Seal(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) []byte EncryptHeader(sample []byte, firstByte *byte, pnBytes []byte) Overhead() int } +// ShortHeaderSealer seals a short header packet +type ShortHeaderSealer interface { + LongHeaderSealer + KeyPhase() protocol.KeyPhase +} + // A tlsExtensionHandler sends and received the QUIC TLS extension. type tlsExtensionHandler interface { GetExtensions(msgType uint8) []qtls.Extension @@ -49,7 +55,7 @@ type CryptoSetup interface { GetHandshakeOpener() (Opener, error) Get1RTTOpener() (Opener, error) - GetInitialSealer() (Sealer, error) - GetHandshakeSealer() (Sealer, error) - Get1RTTSealer() (Sealer, error) + GetInitialSealer() (LongHeaderSealer, error) + GetHandshakeSealer() (LongHeaderSealer, error) + Get1RTTSealer() (ShortHeaderSealer, error) } diff --git a/internal/mocks/crypto_setup.go b/internal/mocks/crypto_setup.go index bac50456a..13f0a5bfe 100644 --- a/internal/mocks/crypto_setup.go +++ b/internal/mocks/crypto_setup.go @@ -94,10 +94,10 @@ func (mr *MockCryptoSetupMockRecorder) Get1RTTOpener() *gomock.Call { } // Get1RTTSealer mocks base method -func (m *MockCryptoSetup) Get1RTTSealer() (handshake.Sealer, error) { +func (m *MockCryptoSetup) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get1RTTSealer") - ret0, _ := ret[0].(handshake.Sealer) + ret0, _ := ret[0].(handshake.ShortHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -124,10 +124,10 @@ func (mr *MockCryptoSetupMockRecorder) GetHandshakeOpener() *gomock.Call { } // GetHandshakeSealer mocks base method -func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.Sealer, error) { +func (m *MockCryptoSetup) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetHandshakeSealer") - ret0, _ := ret[0].(handshake.Sealer) + ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -154,10 +154,10 @@ func (mr *MockCryptoSetupMockRecorder) GetInitialOpener() *gomock.Call { } // GetInitialSealer mocks base method -func (m *MockCryptoSetup) GetInitialSealer() (handshake.Sealer, error) { +func (m *MockCryptoSetup) GetInitialSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetInitialSealer") - ret0, _ := ret[0].(handshake.Sealer) + ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/internal/mocks/mockgen.go b/internal/mocks/mockgen.go index 4fa465b1e..91ec4cb87 100644 --- a/internal/mocks/mockgen.go +++ b/internal/mocks/mockgen.go @@ -2,7 +2,7 @@ package mocks //go:generate sh -c "mockgen -package mockquic -destination quic/stream.go github.com/lucas-clemente/quic-go Stream && goimports -w quic/stream.go" //go:generate sh -c "mockgen -package mockquic -destination quic/session.go github.com/lucas-clemente/quic-go Session && goimports -w quic/session.go" -//go:generate sh -c "../mockgen_internal.sh mocks sealer.go github.com/lucas-clemente/quic-go/internal/handshake Sealer" +//go:generate sh -c "../mockgen_internal.sh mocks short_header_sealer.go github.com/lucas-clemente/quic-go/internal/handshake ShortHeaderSealer" //go:generate sh -c "../mockgen_internal.sh mocks opener.go github.com/lucas-clemente/quic-go/internal/handshake Opener" //go:generate sh -c "../mockgen_internal.sh mocks crypto_setup.go github.com/lucas-clemente/quic-go/internal/handshake CryptoSetup" //go:generate sh -c "../mockgen_internal.sh mocks stream_flow_controller.go github.com/lucas-clemente/quic-go/internal/flowcontrol StreamFlowController" diff --git a/internal/mocks/sealer.go b/internal/mocks/sealer.go deleted file mode 100644 index 63a9ceb6f..000000000 --- a/internal/mocks/sealer.go +++ /dev/null @@ -1,75 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: Sealer) - -// Package mocks is a generated GoMock package. -package mocks - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" - protocol "github.com/lucas-clemente/quic-go/internal/protocol" -) - -// MockSealer is a mock of Sealer interface -type MockSealer struct { - ctrl *gomock.Controller - recorder *MockSealerMockRecorder -} - -// MockSealerMockRecorder is the mock recorder for MockSealer -type MockSealerMockRecorder struct { - mock *MockSealer -} - -// NewMockSealer creates a new mock instance -func NewMockSealer(ctrl *gomock.Controller) *MockSealer { - mock := &MockSealer{ctrl: ctrl} - mock.recorder = &MockSealerMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use -func (m *MockSealer) EXPECT() *MockSealerMockRecorder { - return m.recorder -} - -// EncryptHeader mocks base method -func (m *MockSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2) -} - -// EncryptHeader indicates an expected call of EncryptHeader -func (mr *MockSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockSealer)(nil).EncryptHeader), arg0, arg1, arg2) -} - -// Overhead mocks base method -func (m *MockSealer) Overhead() int { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Overhead") - ret0, _ := ret[0].(int) - return ret0 -} - -// Overhead indicates an expected call of Overhead -func (mr *MockSealerMockRecorder) Overhead() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockSealer)(nil).Overhead)) -} - -// Seal mocks base method -func (m *MockSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]byte) - return ret0 -} - -// Seal indicates an expected call of Seal -func (mr *MockSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockSealer)(nil).Seal), arg0, arg1, arg2, arg3) -} diff --git a/internal/mocks/short_header_sealer.go b/internal/mocks/short_header_sealer.go new file mode 100644 index 000000000..1f269b51d --- /dev/null +++ b/internal/mocks/short_header_sealer.go @@ -0,0 +1,89 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/lucas-clemente/quic-go/internal/handshake (interfaces: ShortHeaderSealer) + +// Package mocks is a generated GoMock package. +package mocks + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + protocol "github.com/lucas-clemente/quic-go/internal/protocol" +) + +// MockShortHeaderSealer is a mock of ShortHeaderSealer interface +type MockShortHeaderSealer struct { + ctrl *gomock.Controller + recorder *MockShortHeaderSealerMockRecorder +} + +// MockShortHeaderSealerMockRecorder is the mock recorder for MockShortHeaderSealer +type MockShortHeaderSealerMockRecorder struct { + mock *MockShortHeaderSealer +} + +// NewMockShortHeaderSealer creates a new mock instance +func NewMockShortHeaderSealer(ctrl *gomock.Controller) *MockShortHeaderSealer { + mock := &MockShortHeaderSealer{ctrl: ctrl} + mock.recorder = &MockShortHeaderSealerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockShortHeaderSealer) EXPECT() *MockShortHeaderSealerMockRecorder { + return m.recorder +} + +// EncryptHeader mocks base method +func (m *MockShortHeaderSealer) EncryptHeader(arg0 []byte, arg1 *byte, arg2 []byte) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "EncryptHeader", arg0, arg1, arg2) +} + +// EncryptHeader indicates an expected call of EncryptHeader +func (mr *MockShortHeaderSealerMockRecorder) EncryptHeader(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EncryptHeader", reflect.TypeOf((*MockShortHeaderSealer)(nil).EncryptHeader), arg0, arg1, arg2) +} + +// KeyPhase mocks base method +func (m *MockShortHeaderSealer) KeyPhase() protocol.KeyPhase { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "KeyPhase") + ret0, _ := ret[0].(protocol.KeyPhase) + return ret0 +} + +// KeyPhase indicates an expected call of KeyPhase +func (mr *MockShortHeaderSealerMockRecorder) KeyPhase() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyPhase", reflect.TypeOf((*MockShortHeaderSealer)(nil).KeyPhase)) +} + +// Overhead mocks base method +func (m *MockShortHeaderSealer) Overhead() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Overhead") + ret0, _ := ret[0].(int) + return ret0 +} + +// Overhead indicates an expected call of Overhead +func (mr *MockShortHeaderSealerMockRecorder) Overhead() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Overhead", reflect.TypeOf((*MockShortHeaderSealer)(nil).Overhead)) +} + +// Seal mocks base method +func (m *MockShortHeaderSealer) Seal(arg0, arg1 []byte, arg2 protocol.PacketNumber, arg3 []byte) []byte { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Seal", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]byte) + return ret0 +} + +// Seal indicates an expected call of Seal +func (mr *MockShortHeaderSealerMockRecorder) Seal(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Seal", reflect.TypeOf((*MockShortHeaderSealer)(nil).Seal), arg0, arg1, arg2, arg3) +} diff --git a/mock_sealing_manager_test.go b/mock_sealing_manager_test.go index eef954baa..f8c661e8d 100644 --- a/mock_sealing_manager_test.go +++ b/mock_sealing_manager_test.go @@ -35,10 +35,10 @@ func (m *MockSealingManager) EXPECT() *MockSealingManagerMockRecorder { } // Get1RTTSealer mocks base method -func (m *MockSealingManager) Get1RTTSealer() (handshake.Sealer, error) { +func (m *MockSealingManager) Get1RTTSealer() (handshake.ShortHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "Get1RTTSealer") - ret0, _ := ret[0].(handshake.Sealer) + ret0, _ := ret[0].(handshake.ShortHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -50,10 +50,10 @@ func (mr *MockSealingManagerMockRecorder) Get1RTTSealer() *gomock.Call { } // GetHandshakeSealer mocks base method -func (m *MockSealingManager) GetHandshakeSealer() (handshake.Sealer, error) { +func (m *MockSealingManager) GetHandshakeSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetHandshakeSealer") - ret0, _ := ret[0].(handshake.Sealer) + ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } @@ -65,10 +65,10 @@ func (mr *MockSealingManagerMockRecorder) GetHandshakeSealer() *gomock.Call { } // GetInitialSealer mocks base method -func (m *MockSealingManager) GetInitialSealer() (handshake.Sealer, error) { +func (m *MockSealingManager) GetInitialSealer() (handshake.LongHeaderSealer, error) { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetInitialSealer") - ret0, _ := ret[0].(handshake.Sealer) + ret0, _ := ret[0].(handshake.LongHeaderSealer) ret1, _ := ret[1].(error) return ret0, ret1 } diff --git a/packet_packer.go b/packet_packer.go index 1c2f74616..846057a2c 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -25,6 +25,10 @@ type packer interface { ChangeDestConnectionID(protocol.ConnectionID) } +type sealer interface { + handshake.LongHeaderSealer +} + type payload struct { frames []wire.Frame ack *wire.AckFrame @@ -93,9 +97,9 @@ type packetNumberManager interface { } type sealingManager interface { - GetInitialSealer() (handshake.Sealer, error) - GetHandshakeSealer() (handshake.Sealer, error) - Get1RTTSealer() (handshake.Sealer, error) + GetInitialSealer() (handshake.LongHeaderSealer, error) + GetHandshakeSealer() (handshake.LongHeaderSealer, error) + Get1RTTSealer() (handshake.ShortHeaderSealer, error) } type frameSource interface { @@ -165,10 +169,11 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac length: ccf.Length(p.version), } // send the CONNECTION_CLOSE frame with the highest available encryption level - var sealer handshake.Sealer var err error + var hdr *wire.ExtendedHeader + var sealer sealer encLevel := protocol.Encryption1RTT - sealer, err = p.cryptoSetup.Get1RTTSealer() + s, err := p.cryptoSetup.Get1RTTSealer() if err != nil { encLevel = protocol.EncryptionHandshake sealer, err = p.cryptoSetup.GetHandshakeSealer() @@ -178,11 +183,16 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac if err != nil { return nil, err } + hdr = p.getLongHeader(protocol.EncryptionInitial) + } else { + hdr = p.getLongHeader(protocol.EncryptionHandshake) } + } else { + sealer = s + hdr = p.getShortHeader(s.KeyPhase()) } - header := p.getHeader(encLevel) - return p.writeAndSealPacket(header, payload, encLevel, sealer) + return p.writeAndSealPacket(hdr, payload, encLevel, sealer) } func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { @@ -199,7 +209,7 @@ func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) { if err != nil { return nil, err } - header := p.getHeader(protocol.Encryption1RTT) + header := p.getShortHeader(sealer.KeyPhase()) return p.writeAndSealPacket(header, payload, protocol.Encryption1RTT, sealer) } @@ -222,28 +232,42 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP } var packets []*packedPacket - var err error - var sealer handshake.Sealer - switch packet.EncryptionLevel { - case protocol.EncryptionInitial: - sealer, err = p.cryptoSetup.GetInitialSealer() - case protocol.EncryptionHandshake: - sealer, err = p.cryptoSetup.GetHandshakeSealer() - case protocol.Encryption1RTT: - sealer, err = p.cryptoSetup.Get1RTTSealer() - default: - return nil, fmt.Errorf("unexpected encryption level: %s", packet.EncryptionLevel) - } - if err != nil { - return nil, err - } for len(controlFrames) > 0 || len(streamFrames) > 0 { var frames []wire.Frame var length protocol.ByteCount - header := p.getHeader(packet.EncryptionLevel) - headerLen := header.GetLength(p.version) - maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen + var sealer sealer + var hdr *wire.ExtendedHeader + switch packet.EncryptionLevel { + case protocol.EncryptionInitial: + var err error + sealer, err = p.cryptoSetup.GetInitialSealer() + if err != nil { + return nil, err + } + hdr = p.getLongHeader(protocol.EncryptionInitial) + case protocol.EncryptionHandshake: + var err error + sealer, err = p.cryptoSetup.GetHandshakeSealer() + if err != nil { + return nil, err + } + hdr = p.getLongHeader(protocol.EncryptionHandshake) + case protocol.Encryption1RTT: + var s handshake.ShortHeaderSealer + var err error + s, err = p.cryptoSetup.Get1RTTSealer() + if err != nil { + return nil, err + } + sealer = s + hdr = p.getShortHeader(s.KeyPhase()) + default: + return nil, fmt.Errorf("unexpected encryption level: %s", packet.EncryptionLevel) + } + + hdrLen := hdr.GetLength(p.version) + maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - hdrLen for len(controlFrames) > 0 { frame := controlFrames[0] @@ -277,7 +301,7 @@ func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedP if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok { sf.DataLenPresent = false } - p, err := p.writeAndSealPacket(header, payload{frames: frames, length: length}, packet.EncryptionLevel, sealer) + p, err := p.writeAndSealPacket(hdr, payload{frames: frames, length: length}, packet.EncryptionLevel, sealer) if err != nil { return nil, err } @@ -302,7 +326,7 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) { // sealer not yet available return nil, nil } - header := p.getHeader(protocol.Encryption1RTT) + header := p.getShortHeader(sealer.KeyPhase()) headerLen := header.GetLength(p.version) if err != nil { return nil, err @@ -340,7 +364,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { hasData := p.initialStream.HasData() ack := p.acks.GetAckFrame(protocol.EncryptionInitial) - var sealer handshake.Sealer + var sealer handshake.LongHeaderSealer var err error if hasData || ack != nil { s = p.initialStream @@ -367,7 +391,7 @@ func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) { payload.ack = ack payload.length = ack.Length(p.version) } - hdr := p.getHeader(encLevel) + hdr := p.getLongHeader(encLevel) hdrLen := hdr.GetLength(p.version) if hasData { cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - payload.length) @@ -396,40 +420,49 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) (paylo return payload, nil } -func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { - pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) - header := &wire.ExtendedHeader{} - header.PacketNumber = pn - header.PacketNumberLen = pnLen - header.Version = p.version - header.DestConnectionID = p.destConnID +func (p *packetPacker) getShortHeader(kp protocol.KeyPhase) *wire.ExtendedHeader { + pn, pnLen := p.pnManager.PeekPacketNumber(protocol.Encryption1RTT) + hdr := &wire.ExtendedHeader{} + hdr.PacketNumber = pn + hdr.PacketNumberLen = pnLen + hdr.DestConnectionID = p.destConnID + hdr.KeyPhase = kp + return hdr +} - if encLevel != protocol.Encryption1RTT { - header.IsLongHeader = true - // Always send Initial and Handshake packets with the maximum packet number length. - // This simplifies retransmissions: Since the header can't get any larger, - // we don't need to split CRYPTO frames. - header.PacketNumberLen = protocol.PacketNumberLen4 - header.SrcConnectionID = p.srcConnID - // Set the length to the maximum packet size. - // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns. - header.Length = p.maxPacketSize - switch encLevel { - case protocol.EncryptionInitial: - header.Type = protocol.PacketTypeInitial - case protocol.EncryptionHandshake: - header.Type = protocol.PacketTypeHandshake - } +func (p *packetPacker) getLongHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader { + pn, pnLen := p.pnManager.PeekPacketNumber(encLevel) + hdr := &wire.ExtendedHeader{} + hdr.PacketNumber = pn + hdr.PacketNumberLen = pnLen + hdr.DestConnectionID = p.destConnID + + switch encLevel { + case protocol.EncryptionInitial: + hdr.Type = protocol.PacketTypeInitial + case protocol.EncryptionHandshake: + hdr.Type = protocol.PacketTypeHandshake } - return header + hdr.Version = p.version + hdr.IsLongHeader = true + // Always send Initial and Handshake packets with the maximum packet number length. + // This simplifies retransmissions: Since the header can't get any larger, + // we don't need to split CRYPTO frames. + hdr.PacketNumberLen = protocol.PacketNumberLen4 + hdr.SrcConnectionID = p.srcConnID + // Set the length to the maximum packet size. + // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns. + hdr.Length = p.maxPacketSize + + return hdr } func (p *packetPacker) writeAndSealPacket( header *wire.ExtendedHeader, payload payload, encLevel protocol.EncryptionLevel, - sealer handshake.Sealer, + sealer sealer, ) (*packedPacket, error) { var paddingLen protocol.ByteCount pnLen := protocol.ByteCount(header.PacketNumberLen) @@ -454,7 +487,7 @@ func (p *packetPacker) writeAndSealPacketWithPadding( payload payload, paddingLen protocol.ByteCount, encLevel protocol.EncryptionLevel, - sealer handshake.Sealer, + sealer sealer, ) (*packedPacket, error) { packetBuffer := getPacketBuffer() buffer := bytes.NewBuffer(packetBuffer.Slice[:0]) diff --git a/packet_packer_test.go b/packet_packer_test.go index cffc2d6b5..3904201f0 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -106,7 +106,7 @@ var _ = Describe("Packet packer", func() { Context("generating a packet header", func() { It("uses the Long Header format", func() { pnManager.EXPECT().PeekPacketNumber(protocol.EncryptionHandshake).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2) - h := packer.getHeader(protocol.EncryptionHandshake) + h := packer.getLongHeader(protocol.EncryptionHandshake) Expect(h.IsLongHeader).To(BeTrue()) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x42))) // long headers always use 4 byte packet numbers, no matter what the packet number generator says @@ -120,7 +120,7 @@ var _ = Describe("Packet packer", func() { destConnID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} packer.srcConnID = srcConnID packer.destConnID = destConnID - h := packer.getHeader(protocol.EncryptionHandshake) + h := packer.getLongHeader(protocol.EncryptionHandshake) Expect(h.SrcConnectionID).To(Equal(srcConnID)) Expect(h.DestConnectionID).To(Equal(destConnID)) }) @@ -132,21 +132,22 @@ var _ = Describe("Packet packer", func() { dest1 := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8} dest2 := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1} packer.ChangeDestConnectionID(dest1) - h := packer.getHeader(protocol.EncryptionInitial) + h := packer.getLongHeader(protocol.EncryptionInitial) Expect(h.SrcConnectionID).To(Equal(srcConnID)) Expect(h.DestConnectionID).To(Equal(dest1)) packer.ChangeDestConnectionID(dest2) - h = packer.getHeader(protocol.EncryptionInitial) + h = packer.getLongHeader(protocol.EncryptionInitial) Expect(h.SrcConnectionID).To(Equal(srcConnID)) Expect(h.DestConnectionID).To(Equal(dest2)) }) - It("uses the Short Header format for 1-RTT packets", func() { + It("gets a short header", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen4) - h := packer.getHeader(protocol.Encryption1RTT) + h := packer.getShortHeader(protocol.KeyPhaseOne) Expect(h.IsLongHeader).To(BeFalse()) Expect(h.PacketNumber).To(Equal(protocol.PacketNumber(0x1337))) Expect(h.PacketNumberLen).To(Equal(protocol.PacketNumberLen4)) + Expect(h.KeyPhase).To(Equal(protocol.KeyPhaseOne)) }) }) @@ -156,10 +157,11 @@ var _ = Describe("Packet packer", func() { handshakeStream.EXPECT().HasData() pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337), protocol.PacketNumberLen2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x1337)) - sealer := mocks.NewMockSealer(mockCtrl) + sealer := mocks.NewMockShortHeaderSealer(mockCtrl) sealer.EXPECT().Overhead().Return(4).AnyTimes() var hdrRaw []byte gomock.InOrder( + sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne), sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), protocol.PacketNumber(0x1337), gomock.Any()).DoAndReturn(func(_, src []byte, _ protocol.PacketNumber, aad []byte) []byte { hdrRaw = append([]byte{}, aad...) return append(src, []byte{0xde, 0xca, 0xfb, 0xad}...) @@ -195,10 +197,11 @@ var _ = Describe("Packet packer", func() { }) Context("packing packets", func() { - var sealer *mocks.MockSealer + var sealer *mocks.MockShortHeaderSealer BeforeEach(func() { - sealer = mocks.NewMockSealer(mockCtrl) + sealer = mocks.NewMockShortHeaderSealer(mockCtrl) + sealer.EXPECT().KeyPhase().Return(protocol.KeyPhaseOne).AnyTimes() sealer.EXPECT().Overhead().Return(7).AnyTimes() sealer.EXPECT().EncryptHeader(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() sealer.EXPECT().Seal(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(func(dst, src []byte, pn protocol.PacketNumber, associatedData []byte) []byte { @@ -505,7 +508,7 @@ var _ = Describe("Packet packer", func() { It("packs two packets for retransmission if the original packet contained many control frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) var frames []wire.Frame var totalLen protocol.ByteCount // pack a bunch of control frames, such that the packet is way bigger than a single packet @@ -533,7 +536,7 @@ var _ = Describe("Packet packer", func() { It("splits a STREAM frame that doesn't fit", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) packets, err := packer.PackRetransmission(&ackhandler.Packet{ EncryptionLevel: protocol.Encryption1RTT, Frames: []wire.Frame{&wire.StreamFrame{ @@ -598,7 +601,7 @@ var _ = Describe("Packet packer", func() { It("packs two packets for retransmission if the original packet contained many STREAM frames", func() { pnManager.EXPECT().PeekPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42), protocol.PacketNumberLen2).Times(2) pnManager.EXPECT().PopPacketNumber(protocol.Encryption1RTT).Return(protocol.PacketNumber(0x42)).Times(2) - sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil) + sealingManager.EXPECT().Get1RTTSealer().Return(sealer, nil).Times(2) var frames []wire.Frame var totalLen protocol.ByteCount // pack a bunch of control frames, such that the packet is way bigger than a single packet