From 3f62ea8673d89b0a21ab9f2a8c5142ab494f14c0 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Fri, 27 Oct 2017 08:39:06 +0700 Subject: [PATCH] set the Long Header packet type based on the state of the handshake --- internal/handshake/crypto_setup_client.go | 4 ++ internal/handshake/crypto_setup_server.go | 4 ++ internal/handshake/crypto_setup_tls.go | 56 ++++++++++++++++---- internal/handshake/crypto_setup_tls_test.go | 57 +++++++++++++++++++++ internal/handshake/interface.go | 5 +- internal/handshake/mint_utils.go | 25 +++++++-- internal/handshake/mint_utils_test.go | 23 ++++++++- internal/mocks/handshake/mint_tls.go | 12 +++++ internal/protocol/protocol.go | 18 +++++++ internal/wire/header.go | 2 +- internal/wire/ietf_header.go | 2 +- internal/wire/ietf_header_test.go | 3 +- packet_packer.go | 7 ++- packet_packer_test.go | 9 ++++ 14 files changed, 205 insertions(+), 22 deletions(-) diff --git a/internal/handshake/crypto_setup_client.go b/internal/handshake/crypto_setup_client.go index 6cf114f6..01330933 100644 --- a/internal/handshake/crypto_setup_client.go +++ b/internal/handshake/crypto_setup_client.go @@ -381,6 +381,10 @@ func (h *cryptoSetupClient) SetDiversificationNonce(data []byte) { h.divNonceChan <- data } +func (h *cryptoSetupClient) GetNextPacketType() protocol.PacketType { + panic("not needed for cryptoSetupServer") +} + func (h *cryptoSetupClient) sendCHLO() error { h.clientHelloCounter++ if h.clientHelloCounter > protocol.MaxClientHellos { diff --git a/internal/handshake/crypto_setup_server.go b/internal/handshake/crypto_setup_server.go index e8085a22..50e26183 100644 --- a/internal/handshake/crypto_setup_server.go +++ b/internal/handshake/crypto_setup_server.go @@ -458,6 +458,10 @@ func (h *cryptoSetupServer) SetDiversificationNonce(data []byte) { panic("not needed for cryptoSetupServer") } +func (h *cryptoSetupServer) GetNextPacketType() protocol.PacketType { + panic("not needed for cryptoSetupServer") +} + func (h *cryptoSetupServer) validateClientNonce(nonce []byte) error { if len(nonce) != 32 { return qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length") diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index d6b15ada..294f466d 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -19,13 +19,14 @@ type cryptoSetupTLS struct { perspective protocol.Perspective - keyDerivation KeyDerivationFunction - tls mintTLS conn *fakeConn - nullAEAD crypto.AEAD - aead crypto.AEAD + nextPacketType protocol.PacketType + + keyDerivation KeyDerivationFunction + nullAEAD crypto.AEAD + aead crypto.AEAD aeadChanged chan<- protocol.EncryptionLevel } @@ -98,12 +99,13 @@ func NewCryptoSetupTLSClient( } return &cryptoSetupTLS{ - conn: conn, - perspective: protocol.PerspectiveClient, - tls: &mintController{mintConn}, - nullAEAD: nullAEAD, - keyDerivation: crypto.DeriveAESKeys, - aeadChanged: aeadChanged, + conn: conn, + perspective: protocol.PerspectiveClient, + tls: &mintController{mintConn}, + nullAEAD: nullAEAD, + keyDerivation: crypto.DeriveAESKeys, + aeadChanged: aeadChanged, + nextPacketType: protocol.PacketTypeClientInitial, }, nil } @@ -114,7 +116,10 @@ handshakeLoop: case mint.AlertNoAlert: // handshake complete break handshakeLoop case mint.AlertWouldBlock: - h.conn.UnblockRead() + h.determineNextPacketType() + if err := h.conn.Continue(); err != nil { + return err + } default: return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) } @@ -184,6 +189,35 @@ func (h *cryptoSetupTLS) GetSealerForCryptoStream() (protocol.EncryptionLevel, S return protocol.EncryptionUnencrypted, h.nullAEAD } +func (h *cryptoSetupTLS) determineNextPacketType() error { + h.mutex.Lock() + defer h.mutex.Unlock() + state := h.tls.State().HandshakeState + if h.perspective == protocol.PerspectiveServer { + switch state { + case "ServerStateStart": // if we're still at ServerStateStart when writing the first packet, that means we've come back to that state by sending a HelloRetryRequest + h.nextPacketType = protocol.PacketTypeServerStatelessRetry + case "ServerStateWaitFinished": + h.nextPacketType = protocol.PacketTypeServerCleartext + default: + // TODO: accept 0-RTT data + return fmt.Errorf("Unexpected handshake state: %s", state) + } + return nil + } + // client + if state != "ClientStateWaitSH" { + h.nextPacketType = protocol.PacketTypeClientCleartext + } + return nil +} + +func (h *cryptoSetupTLS) GetNextPacketType() protocol.PacketType { + h.mutex.RLock() + defer h.mutex.RUnlock() + return h.nextPacketType +} + func (h *cryptoSetupTLS) DiversificationNonce() []byte { panic("diversification nonce not needed for TLS") } diff --git a/internal/handshake/crypto_setup_tls_test.go b/internal/handshake/crypto_setup_tls_test.go index 92f4c8c0..42933a0b 100644 --- a/internal/handshake/crypto_setup_tls_test.go +++ b/internal/handshake/crypto_setup_tls_test.go @@ -1,6 +1,7 @@ package handshake import ( + "bytes" "errors" "fmt" @@ -53,8 +54,10 @@ var _ = Describe("TLS Crypto Setup", func() { }) It("continues shaking hands when mint says that it would block", func() { + cs.conn.stream = &bytes.Buffer{} cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertWouldBlock) + cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{}) cs.tls.(*mockhandshake.MockmintTLS).EXPECT().Handshake().Return(mint.AlertNoAlert) cs.keyDerivation = mockKeyDerivation err := cs.HandleCryptoStream() @@ -71,6 +74,60 @@ var _ = Describe("TLS Crypto Setup", func() { Expect(aeadChanged).To(BeClosed()) }) + Context("determining the packet type", func() { + Context("for the client", func() { + var csClient *cryptoSetupTLS + + BeforeEach(func() { + csInt, err := NewCryptoSetupTLSClient( + nil, + 1, + "quic.clemente.io", + testdata.GetTLSConfig(), + &TransportParameters{}, + paramsChan, + aeadChanged, + protocol.VersionTLS, + []protocol.VersionNumber{protocol.VersionTLS}, + protocol.VersionTLS, + ) + Expect(err).ToNot(HaveOccurred()) + csClient = csInt.(*cryptoSetupTLS) + csClient.tls = mockhandshake.NewMockmintTLS(mockCtrl) + }) + + It("sends a Client Initial first", func() { + Expect(csClient.GetNextPacketType()).To(Equal(protocol.PacketTypeClientInitial)) + }) + + It("sends a Client Cleartext after the server sent a Server Hello", func() { + csClient.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ClientStateWaitEE"}) + err := csClient.determineNextPacketType() + Expect(err).ToNot(HaveOccurred()) + }) + }) + + Context("for the server", func() { + BeforeEach(func() { + cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) + }) + + It("sends a Stateless Retry packet", func() { + cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateStart"}) + err := cs.determineNextPacketType() + Expect(err).ToNot(HaveOccurred()) + Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeServerStatelessRetry)) + }) + + It("sends a Server Cleartext packet", func() { + cs.tls.(*mockhandshake.MockmintTLS).EXPECT().State().Return(mint.ConnectionState{HandshakeState: "ServerStateWaitFinished"}) + err := cs.determineNextPacketType() + Expect(err).ToNot(HaveOccurred()) + Expect(cs.GetNextPacketType()).To(Equal(protocol.PacketTypeServerCleartext)) + }) + }) + }) + Context("escalating crypto", func() { doHandshake := func() { cs.tls = mockhandshake.NewMockmintTLS(mockCtrl) diff --git a/internal/handshake/interface.go b/internal/handshake/interface.go index 90d3e583..c34c8f1f 100644 --- a/internal/handshake/interface.go +++ b/internal/handshake/interface.go @@ -15,8 +15,9 @@ type CryptoSetup interface { Open(dst, src []byte, packetNumber protocol.PacketNumber, associatedData []byte) ([]byte, protocol.EncryptionLevel, error) HandleCryptoStream() error // TODO: clean up this interface - DiversificationNonce() []byte // only needed for cryptoSetupServer - SetDiversificationNonce([]byte) // only needed for cryptoSetupClient + DiversificationNonce() []byte // only needed for cryptoSetupServer + SetDiversificationNonce([]byte) // only needed for cryptoSetupClient + GetNextPacketType() protocol.PacketType // only needed for cryptoSetupServer GetSealer() (protocol.EncryptionLevel, Sealer) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error) diff --git a/internal/handshake/mint_utils.go b/internal/handshake/mint_utils.go index 3792cfc4..a8bd2953 100644 --- a/internal/handshake/mint_utils.go +++ b/internal/handshake/mint_utils.go @@ -1,6 +1,7 @@ package handshake import ( + "bytes" gocrypto "crypto" "crypto/tls" "crypto/x509" @@ -50,6 +51,7 @@ type mintTLS interface { ComputeExporter(label string, context []byte, keyLength int) ([]byte, error) // additional methods Handshake() mint.Alert + State() mint.ConnectionState } var _ crypto.TLSExporter = (mintTLS)(nil) @@ -72,13 +74,18 @@ func (mc *mintController) Handshake() mint.Alert { return mc.conn.Handshake() } +func (mc *mintController) State() mint.ConnectionState { + return mc.conn.State() +} + // mint expects a net.Conn, but we're doing the handshake on a stream // so we wrap a stream such that implements a net.Conn type fakeConn struct { stream io.ReadWriter pers protocol.Perspective - blockRead bool + blockRead bool + writeBuffer bytes.Buffer } var _ net.Conn = &fakeConn{} @@ -92,11 +99,23 @@ func (c *fakeConn) Read(b []byte) (int, error) { } func (c *fakeConn) Write(p []byte) (int, error) { - return c.stream.Write(p) + if c.pers == protocol.PerspectiveClient { + return c.stream.Write(p) + } + // Buffer all writes by the server. + // Mint transitions to the next state *after* writing, so we need to let all the writes happen, only then we can determine the packet type to use to send out this data. + return c.writeBuffer.Write(p) } -func (c *fakeConn) UnblockRead() { +func (c *fakeConn) Continue() error { c.blockRead = false + if c.pers == protocol.PerspectiveClient { + return nil + } + // write all contents of the write buffer to the stream. + _, err := c.stream.Write(c.writeBuffer.Bytes()) + c.writeBuffer.Reset() + return err } func (c *fakeConn) Close() error { return nil } diff --git a/internal/handshake/mint_utils_test.go b/internal/handshake/mint_utils_test.go index 92525b83..d1ba7a95 100644 --- a/internal/handshake/mint_utils_test.go +++ b/internal/handshake/mint_utils_test.go @@ -3,6 +3,7 @@ package handshake import ( "bytes" + "github.com/lucas-clemente/quic-go/internal/protocol" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -35,10 +36,30 @@ var _ = Describe("Fake Conn", func() { b := make([]byte, 3) _, err := c.Read(b) Expect(err).ToNot(HaveOccurred()) - c.UnblockRead() + err = c.Continue() + Expect(err).ToNot(HaveOccurred()) _, err = c.Read(b) Expect(err).ToNot(HaveOccurred()) Expect(b).To(Equal([]byte("bar"))) }) }) + + Context("Writing", func() { + It("writes directly when acting as a client", func() { + c.pers = protocol.PerspectiveClient + _, err := c.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(stream.Bytes()).To(Equal([]byte("foobar"))) + }) + + It("only writes after flushing when acting as a server", func() { + c.pers = protocol.PerspectiveServer + _, err := c.Write([]byte("foobar")) + Expect(err).ToNot(HaveOccurred()) + Expect(stream.Bytes()).To(BeEmpty()) + err = c.Continue() + Expect(err).ToNot(HaveOccurred()) + Expect(stream.Bytes()).To(Equal([]byte("foobar"))) + }) + }) }) diff --git a/internal/mocks/handshake/mint_tls.go b/internal/mocks/handshake/mint_tls.go index 08697926..b03c3d4f 100644 --- a/internal/mocks/handshake/mint_tls.go +++ b/internal/mocks/handshake/mint_tls.go @@ -69,3 +69,15 @@ func (_m *MockmintTLS) Handshake() mint.Alert { func (_mr *MockmintTLSMockRecorder) Handshake() *gomock.Call { return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "Handshake", reflect.TypeOf((*MockmintTLS)(nil).Handshake)) } + +// State mocks base method +func (_m *MockmintTLS) State() mint.ConnectionState { + ret := _m.ctrl.Call(_m, "State") + ret0, _ := ret[0].(mint.ConnectionState) + return ret0 +} + +// State indicates an expected call of State +func (_mr *MockmintTLSMockRecorder) State() *gomock.Call { + return _mr.mock.ctrl.RecordCallWithMethodType(_mr.mock, "State", reflect.TypeOf((*MockmintTLS)(nil).State)) +} diff --git a/internal/protocol/protocol.go b/internal/protocol/protocol.go index 4459d24a..8acceff1 100644 --- a/internal/protocol/protocol.go +++ b/internal/protocol/protocol.go @@ -21,6 +21,24 @@ const ( PacketNumberLen6 PacketNumberLen = 6 ) +// The PacketType is the Long Header Type (only used for the IETF draft header format) +type PacketType uint8 + +const ( + // PacketTypeVersionNegotiation is the packet type of a Version Negotiation packet + PacketTypeVersionNegotiation PacketType = 1 + // PacketTypeClientInitial is the packet type of a Client Initial packet + PacketTypeClientInitial PacketType = 2 + // PacketTypeServerStatelessRetry is the packet type of a Server Stateless Retry packet + PacketTypeServerStatelessRetry PacketType = 3 + // PacketTypeServerCleartext is the packet type of a Server Cleartext packet + PacketTypeServerCleartext PacketType = 4 + // PacketTypeClientCleartext is the packet type of a Client Cleartext packet + PacketTypeClientCleartext PacketType = 5 + // PacketType0RTT is the packet type of a 0-RTT packet + PacketType0RTT PacketType = 6 +) + // A ConnectionID in QUIC type ConnectionID uint64 diff --git a/internal/wire/header.go b/internal/wire/header.go index c77b4cb6..05c9ba4b 100644 --- a/internal/wire/header.go +++ b/internal/wire/header.go @@ -24,7 +24,7 @@ type Header struct { DiversificationNonce []byte // only needed for the IETF Header - Type uint8 + Type protocol.PacketType IsLongHeader bool KeyPhase int diff --git a/internal/wire/ietf_header.go b/internal/wire/ietf_header.go index 239b58d4..6b419aa5 100644 --- a/internal/wire/ietf_header.go +++ b/internal/wire/ietf_header.go @@ -35,7 +35,7 @@ func parseLongHeader(b *bytes.Reader, packetSentBy protocol.Perspective, typeByt return nil, err } h := &Header{ - Type: typeByte & 0x7f, + Type: protocol.PacketType(typeByte & 0x7f), IsLongHeader: true, ConnectionID: protocol.ConnectionID(connID), PacketNumber: protocol.PacketNumber(pn), diff --git a/internal/wire/ietf_header_test.go b/internal/wire/ietf_header_test.go index 33a09557..b15b5287 100644 --- a/internal/wire/ietf_header_test.go +++ b/internal/wire/ietf_header_test.go @@ -32,7 +32,7 @@ var _ = Describe("IETF draft Header", func() { b := bytes.NewReader(data) h, err := parseHeader(b, protocol.PerspectiveClient) Expect(err).ToNot(HaveOccurred()) - Expect(h.Type).To(BeEquivalentTo(3)) + Expect(h.Type).To(Equal(protocol.PacketType(3))) Expect(h.IsLongHeader).To(BeTrue()) Expect(h.OmitConnectionID).To(BeFalse()) Expect(h.ConnectionID).To(Equal(protocol.ConnectionID(0xdeadbeefcafe1337))) @@ -62,6 +62,7 @@ var _ = Describe("IETF draft Header", func() { b := bytes.NewReader(data) h, err := parseHeader(b, protocol.PerspectiveServer) Expect(err).ToNot(HaveOccurred()) + Expect(h.Type).To(Equal(protocol.PacketTypeVersionNegotiation)) Expect(h.SupportedVersions).To(Equal([]protocol.VersionNumber{ 0x22334455, 0x33445566, diff --git a/packet_packer.go b/packet_packer.go index 98abe16c..1a637158 100644 --- a/packet_packer.go +++ b/packet_packer.go @@ -291,8 +291,11 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header header.VersionFlag = true header.Version = p.version } - } else if encLevel != protocol.EncryptionForwardSecure { - header.Version = p.version + } else { + header.Type = p.cryptoSetup.GetNextPacketType() + if encLevel != protocol.EncryptionForwardSecure { + header.Version = p.version + } } return header } diff --git a/packet_packer_test.go b/packet_packer_test.go index 610073dd..960f92c6 100644 --- a/packet_packer_test.go +++ b/packet_packer_test.go @@ -28,6 +28,7 @@ type mockCryptoSetup struct { divNonce []byte encLevelSeal protocol.EncryptionLevel encLevelSealCrypto protocol.EncryptionLevel + nextPacketType protocol.PacketType } var _ handshake.CryptoSetup = &mockCryptoSetup{} @@ -49,6 +50,7 @@ func (m *mockCryptoSetup) GetSealerWithEncryptionLevel(protocol.EncryptionLevel) } func (m *mockCryptoSetup) DiversificationNonce() []byte { return m.divNonce } func (m *mockCryptoSetup) SetDiversificationNonce(divNonce []byte) { m.divNonce = divNonce } +func (m *mockCryptoSetup) GetNextPacketType() protocol.PacketType { return m.nextPacketType } var _ = Describe("Packet packer", func() { var ( @@ -189,6 +191,13 @@ var _ = Describe("Packet packer", func() { Expect(h.Version).To(Equal(versionIETFHeader)) }) + It("sets the packet type based on the state of the handshake", func() { + packer.cryptoSetup.(*mockCryptoSetup).nextPacketType = 5 + h := packer.getHeader(protocol.EncryptionSecure) + Expect(h.IsLongHeader).To(BeTrue()) + Expect(h.Type).To(Equal(protocol.PacketType(5))) + }) + It("uses the Short Header format for forward-secure packets", func() { h := packer.getHeader(protocol.EncryptionForwardSecure) Expect(h.IsLongHeader).To(BeFalse())