diff --git a/session.go b/session.go index b2ca2f3fe..874b0faa4 100644 --- a/session.go +++ b/session.go @@ -286,6 +286,13 @@ func (s *Session) idleTimeout() time.Duration { } func (s *Session) handlePacketImpl(p *receivedPacket) error { + if s.perspective == protocol.PerspectiveClient { + diversificationNonce := p.publicHeader.DiversificationNonce + if len(diversificationNonce) > 0 { + s.cryptoSetup.SetDiversificationNonce(diversificationNonce) + } + } + if p.rcvTime.IsZero() { // To simplify testing p.rcvTime = time.Now() diff --git a/session_test.go b/session_test.go index 6d9fdee4c..15eed151c 100644 --- a/session_test.go +++ b/session_test.go @@ -117,6 +117,7 @@ var _ ackhandler.ReceivedPacketHandler = &mockReceivedPacketHandler{} var _ = Describe("Session", func() { var ( session *Session + clientSession *Session streamCallbackCalled bool closeCallbackCalled bool conn *mockConnection @@ -148,6 +149,18 @@ var _ = Describe("Session", func() { cpm = &mockConnectionParametersManager{idleTime: 60 * time.Second} session.connectionParameters = cpm + + clientSession, err = newClientSession( + &net.UDPConn{}, + &net.UDPAddr{}, + protocol.Version35, + 0, + func(*Session, utils.Stream) { streamCallbackCalled = true }, + func(protocol.ConnectionID) { closeCallbackCalled = true }, + ) + Expect(err).ToNot(HaveOccurred()) + Expect(clientSession.streamsMap.openStreams).To(HaveLen(1)) + }) Context("when handling stream frames", func() { @@ -584,6 +597,7 @@ var _ = Describe("Session", func() { BeforeEach(func() { session.unpacker = &mockUnpacker{} + clientSession.unpacker = &mockUnpacker{} hdr = &PublicHeader{PacketNumberLen: protocol.PacketNumberLen6} }) @@ -624,6 +638,14 @@ var _ = Describe("Session", func() { Expect(err).ToNot(HaveOccurred()) }) + It("passes the diversification nonce to the cryptoSetup, if it is a client", func() { + hdr.PacketNumber = 5 + hdr.DiversificationNonce = []byte("foobar") + err := clientSession.handlePacketImpl(&receivedPacket{publicHeader: hdr}) + Expect(err).ToNot(HaveOccurred()) + Expect((*[]byte)(unsafe.Pointer(reflect.ValueOf(clientSession.cryptoSetup).Elem().FieldByName("diversificationNonce").UnsafeAddr()))).To(Equal(&hdr.DiversificationNonce)) + }) + Context("updating the remote address", func() { It("sets the remote address", func() { remoteIP := net.IPv4(192, 168, 0, 100)