diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index ae3317ad..8f4acf8c 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -55,7 +55,7 @@ func NewCryptoSetupTLSServer( nullAEAD: crypto.NewNullAEAD(protocol.PerspectiveServer, version), keyDerivation: crypto.DeriveAESKeys, aeadChanged: aeadChanged, - extensionHandler: newExtensionHandlerServer(params, supportedVersions), + extensionHandler: newExtensionHandlerServer(params, supportedVersions, version), }, params, nil } diff --git a/internal/handshake/tl_extension_handler_server.go b/internal/handshake/tl_extension_handler_server.go index 3687baaa..aeea7eb5 100644 --- a/internal/handshake/tl_extension_handler_server.go +++ b/internal/handshake/tl_extension_handler_server.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" + "github.com/lucas-clemente/quic-go/qerr" + "github.com/bifurcation/mint" "github.com/bifurcation/mint/syntax" "github.com/lucas-clemente/quic-go/internal/protocol" @@ -13,14 +15,16 @@ import ( type extensionHandlerServer struct { params *paramsNegotiator + version protocol.VersionNumber supportedVersions []protocol.VersionNumber } var _ mint.AppExtensionHandler = &extensionHandlerServer{} -func newExtensionHandlerServer(params *paramsNegotiator, supportedVersions []protocol.VersionNumber) *extensionHandlerServer { +func newExtensionHandlerServer(params *paramsNegotiator, supportedVersions []protocol.VersionNumber, version protocol.VersionNumber) *extensionHandlerServer { return &extensionHandlerServer{ params: params, + version: version, supportedVersions: supportedVersions, } } @@ -66,7 +70,19 @@ func (h *extensionHandlerServer) Receive(hType mint.HandshakeType, el *mint.Exte if _, err := syntax.Unmarshal(ext.data, chtp); err != nil { return err } - // TODO: check versions + initialVersion := protocol.VersionNumber(chtp.InitialVersion) + negotiatedVersion := protocol.VersionNumber(chtp.NegotiatedVersion) + // check that the negotiated version is the version we're currently using + if negotiatedVersion != h.version { + return qerr.Error(qerr.VersionNegotiationMismatch, "Inconsistent negotiated version") + } + // perform the stateless version negotiation validation: + // make sure that we would have sent a Version Negotiation Packet if the client offered the initial version + // this is the case when the initial version is not contained in the supported versions + if initialVersion != negotiatedVersion && protocol.IsSupportedVersion(h.supportedVersions, initialVersion) { + return qerr.Error(qerr.VersionNegotiationMismatch, "Client should have used the initial version") + } + for _, p := range chtp.Parameters { if p.Parameter == statelessResetTokenParameterID { // TODO: return the correct error type diff --git a/internal/handshake/tls_extension_handler_server_test.go b/internal/handshake/tls_extension_handler_server_test.go index c8bc463e..c5857ded 100644 --- a/internal/handshake/tls_extension_handler_server_test.go +++ b/internal/handshake/tls_extension_handler_server_test.go @@ -16,7 +16,7 @@ var _ = Describe("TLS Extension Handler, for the server", func() { BeforeEach(func() { pn := ¶msNegotiator{} - handler = newExtensionHandlerServer(pn, nil) + handler = newExtensionHandlerServer(pn, nil, protocol.VersionWhatever) el = make(mint.ExtensionList, 0) }) @@ -50,14 +50,19 @@ var _ = Describe("TLS Extension Handler, for the server", func() { var fakeBody *tlsExtensionBody var parameters map[transportParameterID][]byte - paramaterMapToExtensionBody := func(paramMap map[transportParameterID][]byte) *tlsExtensionBody { + parameterMapToList := func(paramMap map[transportParameterID][]byte) []transportParameter { var params []transportParameter for id, val := range paramMap { params = append(params, transportParameter{id, val}) } - body, err := syntax.Marshal(clientHelloTransportParameters{Parameters: params}) + return params + } + + addClientHelloWithParameters := func(paramMap map[transportParameterID][]byte) { + body, err := syntax.Marshal(clientHelloTransportParameters{Parameters: parameterMapToList(paramMap)}) + Expect(err).ToNot(HaveOccurred()) + err = el.Add(&tlsExtensionBody{data: body}) Expect(err).ToNot(HaveOccurred()) - return &tlsExtensionBody{data: body} } BeforeEach(func() { @@ -71,9 +76,8 @@ var _ = Describe("TLS Extension Handler, for the server", func() { }) It("accepts the TransportParameters on the EncryptedExtensions message", func() { - err := el.Add(paramaterMapToExtensionBody(parameters)) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeClientHello, &el) + addClientHelloWithParameters(parameters) + err := handler.Receive(mint.HandshakeTypeClientHello, &el) Expect(err).ToNot(HaveOccurred()) Expect(handler.params.GetSendStreamFlowControlWindow()).To(BeEquivalentTo(0x11223344)) }) @@ -96,18 +100,71 @@ var _ = Describe("TLS Extension Handler, for the server", func() { }) It("rejects messages other than the ClientHello that contain TransportParameters", func() { - err := el.Add(paramaterMapToExtensionBody(parameters)) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeCertificateRequest, &el) + addClientHelloWithParameters(parameters) + err := handler.Receive(mint.HandshakeTypeCertificateRequest, &el) Expect(err).To(MatchError(fmt.Sprintf("Unexpected QUIC extension in handshake message %d", mint.HandshakeTypeCertificateRequest))) }) It("rejects messages that contain a stateless reset token", func() { parameters[statelessResetTokenParameterID] = []byte("reset") - err := el.Add(paramaterMapToExtensionBody(parameters)) - Expect(err).ToNot(HaveOccurred()) - err = handler.Receive(mint.HandshakeTypeClientHello, &el) + addClientHelloWithParameters(parameters) + err := handler.Receive(mint.HandshakeTypeClientHello, &el) Expect(err).To(MatchError("client sent a stateless reset token")) }) + + Context("Version Negotiation", func() { + It("accepts a ClientHello, when no version negotiation was performed", func() { + handler.version = 42 + body, err := syntax.Marshal(clientHelloTransportParameters{ + NegotiatedVersion: 42, + InitialVersion: 42, + Parameters: parameterMapToList(parameters), + }) + Expect(err).ToNot(HaveOccurred()) + err = el.Add(&tlsExtensionBody{data: body}) + Expect(err).ToNot(HaveOccurred()) + err = handler.Receive(mint.HandshakeTypeClientHello, &el) + Expect(err).ToNot(HaveOccurred()) + }) + + It("accepts a valid version negotiation", func() { + handler.version = 42 + handler.supportedVersions = []protocol.VersionNumber{13, 37, 42} + body, err := syntax.Marshal(clientHelloTransportParameters{ + NegotiatedVersion: 42, + InitialVersion: 22, // this must be an unsupported version + Parameters: parameterMapToList(parameters), + }) + Expect(err).ToNot(HaveOccurred()) + err = el.Add(&tlsExtensionBody{data: body}) + Expect(err).ToNot(HaveOccurred()) + err = handler.Receive(mint.HandshakeTypeClientHello, &el) + Expect(err).ToNot(HaveOccurred()) + }) + + It("errors when the NegotiatedVersion field doesn't match the current version", func() { + handler.version = 42 + body, err := syntax.Marshal(clientHelloTransportParameters{NegotiatedVersion: 43}) + Expect(err).ToNot(HaveOccurred()) + err = el.Add(&tlsExtensionBody{data: body}) + Expect(err).ToNot(HaveOccurred()) + err = handler.Receive(mint.HandshakeTypeClientHello, &el) + Expect(err).To(MatchError("VersionNegotiationMismatch: Inconsistent negotiated version")) + }) + + It("errros when a version negotiation was performed, although we already support the inital version", func() { + handler.supportedVersions = []protocol.VersionNumber{11, 12, 13} + handler.version = 13 + body, err := syntax.Marshal(clientHelloTransportParameters{ + NegotiatedVersion: 13, + InitialVersion: 11, // this is an supported version + }) + Expect(err).ToNot(HaveOccurred()) + err = el.Add(&tlsExtensionBody{data: body}) + Expect(err).ToNot(HaveOccurred()) + err = handler.Receive(mint.HandshakeTypeClientHello, &el) + Expect(err).To(MatchError("VersionNegotiationMismatch: Client should have used the initial version")) + }) + }) }) })