From acbd14f9406182b2b876159dc232d1fa21e5b79f Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Wed, 10 May 2017 16:26:57 +0800 Subject: [PATCH] implement a HandshakeMessage struct MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This makes passing handshake messages around easier, since it’s now one struct instead of one message tag and one data map. --- handshake/crypto_setup_client.go | 36 +++++------- handshake/crypto_setup_client_test.go | 16 ++--- handshake/crypto_setup_server.go | 26 ++++++--- handshake/crypto_setup_server_test.go | 81 +++++++++++++++----------- handshake/handshake_message.go | 40 ++++++++----- handshake/handshake_message_test.go | 38 ++++++++++-- handshake/server_config.go | 20 ++++--- handshake/server_config_client.go | 6 +- handshake/server_config_client_test.go | 12 ++-- public_reset.go | 8 +-- public_reset_test.go | 10 ++-- 11 files changed, 176 insertions(+), 117 deletions(-) diff --git a/handshake/crypto_setup_client.go b/handshake/crypto_setup_client.go index ec1a942e..9eb4039e 100644 --- a/handshake/crypto_setup_client.go +++ b/handshake/crypto_setup_client.go @@ -104,35 +104,27 @@ func (h *cryptoSetupClient) HandleCryptoStream() error { } } - messageTag, cryptoData, err := ParseHandshakeMessage(h.cryptoStream) + message, err := ParseHandshakeMessage(h.cryptoStream) if err != nil { return qerr.HandshakeFailed } - if messageTag != TagSHLO && messageTag != TagREJ { + utils.Debugf("Got %s", message) + switch message.Tag { + case TagREJ: + err = h.handleREJMessage(message.Data) + case TagSHLO: + err = h.handleSHLOMessage(message.Data) + default: return qerr.InvalidCryptoMessageType } - - if messageTag == TagSHLO { - utils.Debugf("Got SHLO:\n%s", printHandshakeMessage(cryptoData)) - err = h.handleSHLOMessage(cryptoData) - if err != nil { - return err - } - } - - if messageTag == TagREJ { - err = h.handleREJMessage(cryptoData) - if err != nil { - return err - } + if err != nil { + return err } } } func (h *cryptoSetupClient) handleREJMessage(cryptoData map[Tag][]byte) error { - utils.Debugf("Got REJ:\n%s", printHandshakeMessage(cryptoData)) - var err error if stk, ok := cryptoData[TagSTK]; ok { @@ -382,9 +374,13 @@ func (h *cryptoSetupClient) sendCHLO() error { return err } h.addPadding(tags) + message := HandshakeMessage{ + Tag: TagCHLO, + Data: tags, + } - utils.Debugf("Sending CHLO:\n%s", printHandshakeMessage(tags)) - WriteHandshakeMessage(b, TagCHLO, tags) + utils.Debugf("Sending %s", message) + message.Write(b) _, err = h.cryptoStream.Write(b.Bytes()) if err != nil { diff --git a/handshake/crypto_setup_client_test.go b/handshake/crypto_setup_client_test.go index 33e0de86..4005418f 100644 --- a/handshake/crypto_setup_client_test.go +++ b/handshake/crypto_setup_client_test.go @@ -121,14 +121,14 @@ var _ = Describe("Client Crypto Setup", func() { }) It("rejects handshake messages with the wrong message tag", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, tagMap) + HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) }) It("errors on invalid handshake messages", func() { b := &bytes.Buffer{} - WriteHandshakeMessage(b, TagCHLO, tagMap) + HandshakeMessage{Tag: TagCHLO, Data: tagMap}.Write(b) stream.dataToRead.Write(b.Bytes()[:b.Len()-2]) // cut the handshake message err := cs.HandleCryptoStream() // note that if this was a complete handshake message, HandleCryptoStream would fail with a qerr.InvalidCryptoMessageType @@ -138,7 +138,7 @@ var _ = Describe("Client Crypto Setup", func() { It("passes the message on for parsing, and reads the source address token", func() { stk := []byte("foobar") tagMap[TagSTK] = stk - WriteHandshakeMessage(&stream.dataToRead, TagREJ, tagMap) + HandshakeMessage{Tag: TagREJ, Data: tagMap}.Write(&stream.dataToRead) // this will throw a qerr.HandshakeFailed due to an EOF in WriteHandshakeMessage // this is because the mockStream doesn't block if there's no data to read err := cs.HandleCryptoStream() @@ -301,7 +301,7 @@ var _ = Describe("Client Crypto Setup", func() { It("reads a server config", func() { b := &bytes.Buffer{} scfg := getDefaultServerConfigClient() - WriteHandshakeMessage(b, TagSCFG, scfg) + HandshakeMessage{Tag: TagSCFG, Data: scfg}.Write(b) tagMap[TagSCFG] = b.Bytes() err := cs.handleREJMessage(tagMap) Expect(err).ToNot(HaveOccurred()) @@ -313,7 +313,7 @@ var _ = Describe("Client Crypto Setup", func() { b := &bytes.Buffer{} scfg := getDefaultServerConfigClient() scfg[TagEXPY] = []byte{0x80, 0x54, 0x72, 0x4F, 0, 0, 0, 0} // 2012-03-28 - WriteHandshakeMessage(b, TagSCFG, scfg) + HandshakeMessage{Tag: TagSCFG, Data: scfg}.Write(b) tagMap[TagSCFG] = b.Bytes() // make sure we actually set TagEXPY correct serverConfig, err := parseServerConfig(b.Bytes()) @@ -326,7 +326,7 @@ var _ = Describe("Client Crypto Setup", func() { It("generates a client nonce after reading a server config", func() { b := &bytes.Buffer{} - WriteHandshakeMessage(b, TagSCFG, getDefaultServerConfigClient()) + HandshakeMessage{Tag: TagSCFG, Data: getDefaultServerConfigClient()}.Write(b) tagMap[TagSCFG] = b.Bytes() err := cs.handleREJMessage(tagMap) Expect(err).ToNot(HaveOccurred()) @@ -335,7 +335,7 @@ var _ = Describe("Client Crypto Setup", func() { It("only generates a client nonce once, when reading multiple server configs", func() { b := &bytes.Buffer{} - WriteHandshakeMessage(b, TagSCFG, getDefaultServerConfigClient()) + HandshakeMessage{Tag: TagSCFG, Data: getDefaultServerConfigClient()}.Write(b) tagMap[TagSCFG] = b.Bytes() err := cs.handleREJMessage(tagMap) Expect(err).ToNot(HaveOccurred()) @@ -348,7 +348,7 @@ var _ = Describe("Client Crypto Setup", func() { It("passes on errors from reading the server config", func() { b := &bytes.Buffer{} - WriteHandshakeMessage(b, TagSHLO, make(map[Tag][]byte)) + HandshakeMessage{Tag: TagSHLO, Data: make(map[Tag][]byte)}.Write(b) tagMap[TagSCFG] = b.Bytes() _, origErr := parseServerConfig(b.Bytes()) err := cs.handleREJMessage(tagMap) diff --git a/handshake/crypto_setup_server.go b/handshake/crypto_setup_server.go index 7f5ef075..19774bfb 100644 --- a/handshake/crypto_setup_server.go +++ b/handshake/crypto_setup_server.go @@ -85,17 +85,16 @@ func NewCryptoSetup( func (h *cryptoSetupServer) HandleCryptoStream() error { for { var chloData bytes.Buffer - messageTag, cryptoData, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData)) + message, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData)) if err != nil { return qerr.HandshakeFailed } - if messageTag != TagCHLO { + if message.Tag != TagCHLO { return qerr.InvalidCryptoMessageType } - utils.Debugf("Got CHLO:\n%s", printHandshakeMessage(cryptoData)) - - done, err := h.handleMessage(chloData.Bytes(), cryptoData) + utils.Debugf("Got %s", message) + done, err := h.handleMessage(chloData.Bytes(), message.Data) if err != nil { return err } @@ -305,9 +304,14 @@ func (h *cryptoSetupServer) handleInchoateCHLO(sni string, chlo []byte, cryptoDa replyMap[TagCERT] = certCompressed } + message := HandshakeMessage{ + Tag: TagREJ, + Data: replyMap, + } + var serverReply bytes.Buffer - WriteHandshakeMessage(&serverReply, TagREJ, replyMap) - utils.Debugf("Sending REJ:\n%s", printHandshakeMessage(replyMap)) + message.Write(&serverReply) + utils.Debugf("Sending %s", message) return serverReply.Bytes(), nil } @@ -413,9 +417,13 @@ func (h *cryptoSetupServer) handleCHLO(sni string, data []byte, cryptoData map[T replyMap[TagVER] = verTag.Bytes() // note that the SHLO *has* to fit into one packet + message := HandshakeMessage{ + Tag: TagSHLO, + Data: replyMap, + } var reply bytes.Buffer - WriteHandshakeMessage(&reply, TagSHLO, replyMap) - utils.Debugf("Sending SHLO:\n%s", printHandshakeMessage(replyMap)) + message.Write(&reply) + utils.Debugf("Sending %s", message) h.aeadChanged <- protocol.EncryptionForwardSecure diff --git a/handshake/crypto_setup_server_test.go b/handshake/crypto_setup_server_test.go index 98d33226..d1b1513b 100644 --- a/handshake/crypto_setup_server_test.go +++ b/handshake/crypto_setup_server_test.go @@ -231,9 +231,12 @@ var _ = Describe("Server Crypto Setup", func() { }) It("doesn't support Chrome's head-of-line blocking experiment", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{ - TagFHL2: []byte("foobar"), - }) + HandshakeMessage{ + Tag: TagCHLO, + Data: map[Tag][]byte{ + TagFHL2: []byte("foobar"), + }, + }.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(ErrHOLExperiment)) }) @@ -292,13 +295,16 @@ var _ = Describe("Server Crypto Setup", func() { }) It("handles long handshake", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{ - TagSNI: []byte("quic.clemente.io"), - TagSTK: validSTK, - TagPAD: bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), - TagVER: versionTag, - }) - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{ + Tag: TagCHLO, + Data: map[Tag][]byte{ + TagSNI: []byte("quic.clemente.io"), + TagSTK: validSTK, + TagPAD: bytes.Repeat([]byte{'a'}, protocol.ClientHelloMinimumSize), + TagVER: versionTag, + }, + }.Write(&stream.dataToRead) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("REJ")) @@ -311,14 +317,14 @@ var _ = Describe("Server Crypto Setup", func() { It("rejects client nonces that have the wrong length", func() { fullCHLO[TagNONC] = []byte("too short client nonce") - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "invalid client nonce length"))) }) It("rejects client nonces that have the wrong OBIT value", func() { fullCHLO[TagNONC] = make([]byte, 32) // the OBIT value is nonce[4:12] and here just initialized to 0 - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "OBIT not matching"))) }) @@ -326,13 +332,13 @@ var _ = Describe("Server Crypto Setup", func() { It("errors if it can't calculate a shared key", func() { testErr := errors.New("test error") kex.sharedKeyError = testErr - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(testErr)) }) It("handles 0-RTT handshake", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).NotTo(HaveOccurred()) Expect(stream.dataWritten.Bytes()).To(HavePrefix("SHLO")) @@ -382,17 +388,20 @@ var _ = Describe("Server Crypto Setup", func() { }) It("rejects CHLOs without the version tag", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{ - TagSCID: scfg.ID, - TagSNI: []byte("quic.clemente.io"), - }) + HandshakeMessage{ + Tag: TagCHLO, + Data: map[Tag][]byte{ + TagSCID: scfg.ID, + TagSNI: []byte("quic.clemente.io"), + }, + }.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "client hello missing version tag"))) }) It("rejects CHLOs with a version tag that has the wrong length", func() { fullCHLO[TagVER] = []byte{0x13, 0x37} // should be 4 bytes - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.InvalidCryptoMessageParameter, "incorrect version tag"))) }) @@ -405,7 +414,7 @@ var _ = Describe("Server Crypto Setup", func() { b := make([]byte, 4) binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(lowestSupportedVersion)) fullCHLO[TagVER] = b - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.VersionNegotiationMismatch, "Downgrade attack detected"))) }) @@ -418,53 +427,59 @@ var _ = Describe("Server Crypto Setup", func() { b := make([]byte, 4) binary.LittleEndian.PutUint32(b, protocol.VersionNumberToTag(unsupportedVersion)) fullCHLO[TagVER] = b - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).ToNot(HaveOccurred()) }) It("errors if the AEAD tag is missing", func() { delete(fullCHLO, TagAEAD) - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the AEAD tag has the wrong value", func() { fullCHLO[TagAEAD] = []byte("wrong") - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the KEXS tag is missing", func() { delete(fullCHLO, TagKEXS) - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) It("errors if the KEXS tag has the wrong value", func() { fullCHLO[TagKEXS] = []byte("wrong") - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, fullCHLO) + HandshakeMessage{Tag: TagCHLO, Data: fullCHLO}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.Error(qerr.CryptoNoSupport, "Unsupported AEAD or KEXS"))) }) }) It("errors without SNI", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{ - TagSTK: validSTK, - }) + HandshakeMessage{ + Tag: TagCHLO, + Data: map[Tag][]byte{ + TagSTK: validSTK, + }, + }.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) }) It("errors with empty SNI", func() { - WriteHandshakeMessage(&stream.dataToRead, TagCHLO, map[Tag][]byte{ - TagSTK: validSTK, - TagSNI: nil, - }) + HandshakeMessage{ + Tag: TagCHLO, + Data: map[Tag][]byte{ + TagSTK: validSTK, + TagSNI: nil, + }, + }.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError("CryptoMessageParameterNotFound: SNI required")) }) @@ -475,7 +490,7 @@ var _ = Describe("Server Crypto Setup", func() { }) It("errors with non-CHLO message", func() { - WriteHandshakeMessage(&stream.dataToRead, TagPAD, nil) + HandshakeMessage{Tag: TagPAD, Data: nil}.Write(&stream.dataToRead) err := cs.HandleCryptoStream() Expect(err).To(MatchError(qerr.InvalidCryptoMessageType)) }) diff --git a/handshake/handshake_message.go b/handshake/handshake_message.go index 32f02651..1636aa84 100644 --- a/handshake/handshake_message.go +++ b/handshake/handshake_message.go @@ -12,27 +12,35 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) +// A HandshakeMessage is a handshake message +type HandshakeMessage struct { + Tag Tag + Data map[Tag][]byte +} + +var _ fmt.Stringer = &HandshakeMessage{} + // ParseHandshakeMessage reads a crypto message -func ParseHandshakeMessage(r io.Reader) (Tag, map[Tag][]byte, error) { +func ParseHandshakeMessage(r io.Reader) (HandshakeMessage, error) { slice4 := make([]byte, 4) if _, err := io.ReadFull(r, slice4); err != nil { - return 0, nil, err + return HandshakeMessage{}, err } messageTag := Tag(binary.LittleEndian.Uint32(slice4)) if _, err := io.ReadFull(r, slice4); err != nil { - return 0, nil, err + return HandshakeMessage{}, err } nPairs := binary.LittleEndian.Uint32(slice4) if nPairs > protocol.CryptoMaxParams { - return 0, nil, qerr.CryptoTooManyEntries + return HandshakeMessage{}, qerr.CryptoTooManyEntries } index := make([]byte, nPairs*8) if _, err := io.ReadFull(r, index); err != nil { - return 0, nil, err + return HandshakeMessage{}, err } resultMap := map[Tag][]byte{} @@ -44,24 +52,27 @@ func ParseHandshakeMessage(r io.Reader) (Tag, map[Tag][]byte, error) { dataLen := dataEnd - dataStart if dataLen > protocol.CryptoParameterMaxLength { - return 0, nil, qerr.Error(qerr.CryptoInvalidValueLength, "value too long") + return HandshakeMessage{}, qerr.Error(qerr.CryptoInvalidValueLength, "value too long") } data := make([]byte, dataLen) if _, err := io.ReadFull(r, data); err != nil { - return 0, nil, err + return HandshakeMessage{}, err } resultMap[tag] = data dataStart = dataEnd } - return messageTag, resultMap, nil + return HandshakeMessage{ + Tag: messageTag, + Data: resultMap}, nil } -// WriteHandshakeMessage writes a crypto message -func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte) { - utils.WriteUint32(b, uint32(messageTag)) +// Write writes a crypto message +func (h HandshakeMessage) Write(b *bytes.Buffer) { + data := h.Data + utils.WriteUint32(b, uint32(h.Tag)) utils.WriteUint16(b, uint16(len(data))) utils.WriteUint16(b, 0) @@ -93,10 +104,10 @@ func WriteHandshakeMessage(b *bytes.Buffer, messageTag Tag, data map[Tag][]byte) copy(b.Bytes()[indexStart:], indexData) } -func printHandshakeMessage(data map[Tag][]byte) string { - var res string +func (h HandshakeMessage) String() string { var pad string - for k, v := range data { + res := tagToString(h.Tag) + ":\n" + for k, v := range h.Data { if k == TagPAD { pad = fmt.Sprintf("\t%s: (%d bytes)\n", tagToString(k), len(v)) } else { @@ -107,7 +118,6 @@ func printHandshakeMessage(data map[Tag][]byte) string { if len(pad) > 0 { res += pad } - return res } diff --git a/handshake/handshake_message_test.go b/handshake/handshake_message_test.go index 81eac317..fc96b120 100644 --- a/handshake/handshake_message_test.go +++ b/handshake/handshake_message_test.go @@ -11,15 +11,15 @@ import ( var _ = Describe("Handshake Message", func() { Context("when parsing", func() { It("parses sample CHLO message", func() { - tag, msg, err := ParseHandshakeMessage(bytes.NewReader(sampleCHLO)) + msg, err := ParseHandshakeMessage(bytes.NewReader(sampleCHLO)) Expect(err).ToNot(HaveOccurred()) - Expect(tag).To(Equal(TagCHLO)) - Expect(msg).To(Equal(sampleCHLOMap)) + Expect(msg.Tag).To(Equal(TagCHLO)) + Expect(msg.Data).To(Equal(sampleCHLOMap)) }) It("rejects large numbers of pairs", func() { r := bytes.NewReader([]byte("CHLO\xff\xff\xff\xff")) - _, _, err := ParseHandshakeMessage(r) + _, err := ParseHandshakeMessage(r) Expect(err).To(MatchError(qerr.CryptoTooManyEntries)) }) @@ -30,7 +30,7 @@ var _ = Describe("Handshake Message", func() { 0, 0, 0, 0, 0xff, 0xff, 0xff, 0xff, }) - _, _, err := ParseHandshakeMessage(r) + _, err := ParseHandshakeMessage(r) Expect(err).To(MatchError(qerr.Error(qerr.CryptoInvalidValueLength, "value too long"))) }) }) @@ -38,8 +38,34 @@ var _ = Describe("Handshake Message", func() { Context("when writing", func() { It("writes sample message", func() { b := &bytes.Buffer{} - WriteHandshakeMessage(b, TagCHLO, sampleCHLOMap) + HandshakeMessage{Tag: TagCHLO, Data: sampleCHLOMap}.Write(b) Expect(b.Bytes()).To(Equal(sampleCHLO)) }) }) + + Context("string representation", func() { + It("has a string representation", func() { + str := HandshakeMessage{ + Tag: TagSHLO, + Data: map[Tag][]byte{ + TagAEAD: []byte("foobar"), + TagEXPY: []byte("raboof"), + }, + }.String() + Expect(str[:4]).To(Equal("SHLO")) + Expect(str).To(ContainSubstring("AEAD: \"foobar\"")) + Expect(str).To(ContainSubstring("EXPY: \"raboof\"")) + }) + + It("lists padding separately", func() { + str := HandshakeMessage{ + Tag: TagSHLO, + Data: map[Tag][]byte{ + TagPAD: bytes.Repeat([]byte{0}, 1337), + }, + }.String() + Expect(str).To(ContainSubstring("PAD")) + Expect(str).To(ContainSubstring("1337 bytes")) + }) + }) }) diff --git a/handshake/server_config.go b/handshake/server_config.go index cd15b205..24195608 100644 --- a/handshake/server_config.go +++ b/handshake/server_config.go @@ -51,14 +51,18 @@ func NewServerConfig(kex crypto.KeyExchange, certChain crypto.CertChain) (*Serve // Get the server config binary representation func (s *ServerConfig) Get() []byte { var serverConfig bytes.Buffer - WriteHandshakeMessage(&serverConfig, TagSCFG, map[Tag][]byte{ - TagSCID: s.ID, - TagKEXS: []byte("C255"), - TagAEAD: []byte("AESG"), - TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...), - TagOBIT: s.obit, - TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, - }) + msg := HandshakeMessage{ + Tag: TagSCFG, + Data: map[Tag][]byte{ + TagSCID: s.ID, + TagKEXS: []byte("C255"), + TagAEAD: []byte("AESG"), + TagPUBS: append([]byte{0x20, 0x00, 0x00}, s.kex.PublicKey()...), + TagOBIT: s.obit, + TagEXPY: {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, + } + msg.Write(&serverConfig) return serverConfig.Bytes() } diff --git a/handshake/server_config_client.go b/handshake/server_config_client.go index 1da65513..1c99d0be 100644 --- a/handshake/server_config_client.go +++ b/handshake/server_config_client.go @@ -28,16 +28,16 @@ var ( // parseServerConfig parses a server config func parseServerConfig(data []byte) (*serverConfigClient, error) { - tag, tagMap, err := ParseHandshakeMessage(bytes.NewReader(data)) + message, err := ParseHandshakeMessage(bytes.NewReader(data)) if err != nil { return nil, err } - if tag != TagSCFG { + if message.Tag != TagSCFG { return nil, errMessageNotServerConfig } scfg := &serverConfigClient{raw: data} - err = scfg.parseValues(tagMap) + err = scfg.parseValues(message.Data) if err != nil { return nil, err } diff --git a/handshake/server_config_client_test.go b/handshake/server_config_client_test.go index 07bdb627..34d314ab 100644 --- a/handshake/server_config_client_test.go +++ b/handshake/server_config_client_test.go @@ -31,7 +31,7 @@ var _ = Describe("Server Config", func() { It("returns the parsed server config", func() { tagMap[TagSCID] = []byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xa, 0xb, 0xc, 0xd, 0xe, 0xf} b := &bytes.Buffer{} - WriteHandshakeMessage(b, TagSCFG, tagMap) + HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b) scfg, err := parseServerConfig(b.Bytes()) Expect(err).ToNot(HaveOccurred()) Expect(scfg.ID).To(Equal(tagMap[TagSCID])) @@ -39,7 +39,7 @@ var _ = Describe("Server Config", func() { It("saves the raw server config", func() { b := &bytes.Buffer{} - WriteHandshakeMessage(b, TagSCFG, tagMap) + HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(b) scfg, err := parseServerConfig(b.Bytes()) Expect(err).ToNot(HaveOccurred()) Expect(scfg.raw).To(Equal(b.Bytes())) @@ -56,28 +56,28 @@ var _ = Describe("Server Config", func() { Context("parsing the server config", func() { It("rejects a handshake message with the wrong message tag", func() { var serverConfig bytes.Buffer - WriteHandshakeMessage(&serverConfig, TagCHLO, make(map[Tag][]byte)) + HandshakeMessage{Tag: TagCHLO, Data: make(map[Tag][]byte)}.Write(&serverConfig) _, err := parseServerConfig(serverConfig.Bytes()) Expect(err).To(MatchError(errMessageNotServerConfig)) }) It("errors on invalid handshake messages", func() { var serverConfig bytes.Buffer - WriteHandshakeMessage(&serverConfig, TagSCFG, make(map[Tag][]byte)) + HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig) _, err := parseServerConfig(serverConfig.Bytes()[:serverConfig.Len()-2]) Expect(err).To(MatchError("unexpected EOF")) }) It("passes on errors encountered when reading the TagMap", func() { var serverConfig bytes.Buffer - WriteHandshakeMessage(&serverConfig, TagSCFG, make(map[Tag][]byte)) + HandshakeMessage{Tag: TagSCFG, Data: make(map[Tag][]byte)}.Write(&serverConfig) _, err := parseServerConfig(serverConfig.Bytes()) Expect(err).To(MatchError("CryptoMessageParameterNotFound: SCID")) }) It("reads an example Handshake Message", func() { var serverConfig bytes.Buffer - WriteHandshakeMessage(&serverConfig, TagSCFG, tagMap) + HandshakeMessage{Tag: TagSCFG, Data: tagMap}.Write(&serverConfig) scfg, err := parseServerConfig(serverConfig.Bytes()) Expect(err).ToNot(HaveOccurred()) Expect(scfg.ID).To(Equal(tagMap[TagSCID])) diff --git a/public_reset.go b/public_reset.go index b1f60d41..2cbd9888 100644 --- a/public_reset.go +++ b/public_reset.go @@ -32,15 +32,15 @@ func writePublicReset(connectionID protocol.ConnectionID, rejectedPacketNumber p func parsePublicReset(r *bytes.Reader) (*publicReset, error) { pr := publicReset{} - tag, tagMap, err := handshake.ParseHandshakeMessage(r) + msg, err := handshake.ParseHandshakeMessage(r) if err != nil { return nil, err } - if tag != handshake.TagPRST { + if msg.Tag != handshake.TagPRST { return nil, errors.New("wrong public reset tag") } - rseq, ok := tagMap[handshake.TagRSEQ] + rseq, ok := msg.Data[handshake.TagRSEQ] if !ok { return nil, errors.New("RSEQ missing") } @@ -49,7 +49,7 @@ func parsePublicReset(r *bytes.Reader) (*publicReset, error) { } pr.rejectedPacketNumber = protocol.PacketNumber(binary.LittleEndian.Uint64(rseq)) - rnon, ok := tagMap[handshake.TagRNON] + rnon, ok := msg.Data[handshake.TagRNON] if !ok { return nil, errors.New("RNON missing") } diff --git a/public_reset_test.go b/public_reset_test.go index 0df859b5..f1344ee8 100644 --- a/public_reset_test.go +++ b/public_reset_test.go @@ -49,7 +49,7 @@ var _ = Describe("public reset", func() { }) It("rejects packets with the wrong tag", func() { - handshake.WriteHandshakeMessage(b, handshake.TagREJ, nil) + handshake.HandshakeMessage{Tag: handshake.TagREJ, Data: nil}.Write(b) _, err := parsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("wrong public reset tag")) }) @@ -58,7 +58,7 @@ var _ = Describe("public reset", func() { data := map[handshake.Tag][]byte{ handshake.TagRSEQ: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, } - handshake.WriteHandshakeMessage(b, handshake.TagPRST, data) + handshake.HandshakeMessage{Tag: handshake.TagPRST, Data: data}.Write(b) _, err := parsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("RNON missing")) }) @@ -68,7 +68,7 @@ var _ = Describe("public reset", func() { handshake.TagRSEQ: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13}, } - handshake.WriteHandshakeMessage(b, handshake.TagPRST, data) + handshake.HandshakeMessage{Tag: handshake.TagPRST, Data: data}.Write(b) _, err := parsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("invalid RNON tag")) }) @@ -77,7 +77,7 @@ var _ = Describe("public reset", func() { data := map[handshake.Tag][]byte{ handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, } - handshake.WriteHandshakeMessage(b, handshake.TagPRST, data) + handshake.HandshakeMessage{Tag: handshake.TagPRST, Data: data}.Write(b) _, err := parsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("RSEQ missing")) }) @@ -87,7 +87,7 @@ var _ = Describe("public reset", func() { handshake.TagRSEQ: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13}, handshake.TagRNON: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe, 0x13, 0x37}, } - handshake.WriteHandshakeMessage(b, handshake.TagPRST, data) + handshake.HandshakeMessage{Tag: handshake.TagPRST, Data: data}.Write(b) _, err := parsePublicReset(bytes.NewReader(b.Bytes())) Expect(err).To(MatchError("invalid RSEQ tag")) })