call the ConnStateCallback in the server for encryption level changes

ref #439
This commit is contained in:
Marten Seemann
2017-02-24 08:42:49 +07:00
parent d0777fe7fa
commit 820eb2e6ee
6 changed files with 63 additions and 20 deletions

View File

@@ -99,14 +99,14 @@ var _ = Describe("Benchmarks", func() {
connID := protocol.ConnectionID(mrand.Uint32())
c1 := newLinkedConnection(nil)
session1I, err := newSession(c1, version, connID, nil, func(id protocol.ConnectionID) {})
session1I, err := newSession(c1, version, connID, nil, func(id protocol.ConnectionID) {}, func(Session, bool) {})
if err != nil {
Expect(err).NotTo(HaveOccurred())
}
session1 := session1I.(*session)
c2 := newLinkedConnection(session1)
session2I, err := newSession(c2, version, connID, nil, func(id protocol.ConnectionID) {})
session2I, err := newSession(c2, version, connID, nil, func(id protocol.ConnectionID) {}, func(Session, bool) {})
if err != nil {
Expect(err).NotTo(HaveOccurred())
}

View File

@@ -220,7 +220,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
return nil
}
func (c *client) cryptoChangeCallback(isForwardSecure bool) {
func (c *client) cryptoChangeCallback(_ Session, isForwardSecure bool) {
var state ConnState
if isForwardSecure {
state = ConnStateForwardSecure

View File

@@ -36,7 +36,7 @@ type server struct {
sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback) (packetHandler, error)
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error)
}
var _ Listener = &server{}
@@ -185,6 +185,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
hdr.ConnectionID,
s.scfg,
s.closeCallback,
s.cryptoChangeCallback,
)
if err != nil {
return err
@@ -210,6 +211,18 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
return nil
}
func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) {
var state ConnState
if isForwardSecure {
state = ConnStateForwardSecure
} else {
state = ConnStateSecure
}
if s.config.ConnState != nil {
go s.config.ConnState(session, state)
}
}
func (s *server) closeCallback(id protocol.ConnectionID) {
s.sessionsMutex.Lock()
s.sessions[id] = nil

View File

@@ -47,7 +47,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
var _ Session = &mockSession{}
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback) (packetHandler, error) {
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
return &mockSession{
connectionID: connectionID,
}, nil
@@ -57,15 +57,18 @@ var _ = Describe("Server", func() {
var (
conn *mockPacketConn
config *Config
connStateSession Session
connStateStatus ConnState
connStateCalled bool
udpAddr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
)
BeforeEach(func() {
connStateCalled = false
conn = &mockPacketConn{}
config = &Config{
ConnState: func(_ Session, cs ConnState) {
ConnState: func(s Session, cs ConnState) {
connStateSession = s
connStateStatus = cs
connStateCalled = true
},
@@ -112,12 +115,30 @@ var _ = Describe("Server", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())
Expect(serv.sessions).To(HaveLen(1))
Expect(serv.sessions[connID].(*mockSession).connectionID).To(Equal(connID))
Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))
sess := serv.sessions[connID].(*mockSession)
Expect(sess.connectionID).To(Equal(connID))
Expect(sess.packetCount).To(Equal(1))
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
Expect(connStateSession).To(Equal(sess))
Expect(connStateStatus).To(Equal(ConnStateVersionNegotiated))
})
It("calls the ConnState callback when the connection is secure", func() {
sess := &mockSession{}
serv.cryptoChangeCallback(sess, false)
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
Expect(connStateSession).To(Equal(sess))
Expect(connStateStatus).To(Equal(ConnStateSecure))
})
It("calls the ConnState callback when the connection is forward-secure", func() {
sess := &mockSession{}
serv.cryptoChangeCallback(sess, true)
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
Expect(connStateStatus).To(Equal(ConnStateForwardSecure))
Expect(connStateSession).To(Equal(sess))
})
It("assigns packets to existing sessions", func() {
err := serv.handlePacket(nil, nil, firstPacket)
Expect(err).ToNot(HaveOccurred())

View File

@@ -37,7 +37,7 @@ var (
// cryptoChangeCallback is called every time the encryption level changes
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
type cryptoChangeCallback func(isForwardSecure bool)
type cryptoChangeCallback func(session Session, isForwardSecure bool)
// closeCallback is called when a session is closed
type closeCallback func(id protocol.ConnectionID)
@@ -99,7 +99,7 @@ type session struct {
var _ Session = &session{}
// newSession makes a new session
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback) (packetHandler, error) {
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
s := &session{
conn: conn,
connectionID: connectionID,
@@ -107,7 +107,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
version: v,
closeCallback: closeCallback,
cryptoChangeCallback: func(bool) {},
cryptoChangeCallback: cryptoChangeCallback,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
}
@@ -237,7 +237,7 @@ runLoop:
putPacketBuffer(p.publicHeader.Raw)
case <-s.aeadChanged:
s.tryDecryptingQueuedPackets()
s.cryptoChangeCallback(s.cryptoSetup.HandshakeComplete())
s.cryptoChangeCallback(s, s.cryptoSetup.HandshakeComplete())
}
if err != nil {

View File

@@ -138,6 +138,7 @@ var _ = Describe("Session", func() {
0,
scfg,
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(Session, bool) {},
)
Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session)
@@ -153,7 +154,7 @@ var _ = Describe("Session", func() {
0,
nil,
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(isForwardSecure bool) {},
func(Session, bool) {},
nil,
)
Expect(err).ToNot(HaveOccurred())
@@ -169,6 +170,7 @@ var _ = Describe("Session", func() {
0,
scfg,
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(Session, bool) {},
)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200}))
@@ -184,6 +186,7 @@ var _ = Describe("Session", func() {
0,
scfg,
func(protocol.ConnectionID) { closeCallbackCalled = true },
func(Session, bool) {},
)
Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337")))
@@ -1100,9 +1103,11 @@ var _ = Describe("Session", func() {
It("calls the cryptoChangeCallback when the AEAD changes", func(done Done) {
var callbackCalled bool
var callbackCalledWith bool
cb := func(p bool) {
var callbackSession Session
cb := func(s Session, p bool) {
callbackCalled = true
callbackCalledWith = p
callbackSession = s
}
sess.cryptoChangeCallback = cb
sess.cryptoSetup = &mockCryptoSetup{handshakeComplete: false}
@@ -1110,12 +1115,15 @@ var _ = Describe("Session", func() {
go sess.run()
Eventually(func() bool { return callbackCalled }).Should(BeTrue())
Expect(callbackCalledWith).To(BeFalse())
Expect(callbackSession).To(Equal(sess))
close(done)
})
It("calls the cryptoChangeCallback when the AEAD changes to forward secure encryption", func(done Done) {
var callbackCalledWith bool
cb := func(p bool) {
var callbackSession Session
cb := func(s Session, p bool) {
callbackSession = s
callbackCalledWith = p
}
sess.cryptoChangeCallback = cb
@@ -1123,6 +1131,7 @@ var _ = Describe("Session", func() {
sess.aeadChanged <- struct{}{}
go sess.run()
Eventually(func() bool { return callbackCalledWith }).Should(BeTrue())
Expect(callbackSession).To(Equal(sess))
close(done)
})