forked from quic-go/quic-go
call the ConnStateCallback in the server for encryption level changes
ref #439
This commit is contained in:
@@ -99,14 +99,14 @@ var _ = Describe("Benchmarks", func() {
|
||||
connID := protocol.ConnectionID(mrand.Uint32())
|
||||
|
||||
c1 := newLinkedConnection(nil)
|
||||
session1I, err := newSession(c1, version, connID, nil, func(id protocol.ConnectionID) {})
|
||||
session1I, err := newSession(c1, version, connID, nil, func(id protocol.ConnectionID) {}, func(Session, bool) {})
|
||||
if err != nil {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
session1 := session1I.(*session)
|
||||
|
||||
c2 := newLinkedConnection(session1)
|
||||
session2I, err := newSession(c2, version, connID, nil, func(id protocol.ConnectionID) {})
|
||||
session2I, err := newSession(c2, version, connID, nil, func(id protocol.ConnectionID) {}, func(Session, bool) {})
|
||||
if err != nil {
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ func (c *client) handlePacket(remoteAddr net.Addr, packet []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) cryptoChangeCallback(isForwardSecure bool) {
|
||||
func (c *client) cryptoChangeCallback(_ Session, isForwardSecure bool) {
|
||||
var state ConnState
|
||||
if isForwardSecure {
|
||||
state = ConnStateForwardSecure
|
||||
|
||||
15
server.go
15
server.go
@@ -36,7 +36,7 @@ type server struct {
|
||||
sessionsMutex sync.RWMutex
|
||||
deleteClosedSessionsAfter time.Duration
|
||||
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback) (packetHandler, error)
|
||||
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error)
|
||||
}
|
||||
|
||||
var _ Listener = &server{}
|
||||
@@ -185,6 +185,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
||||
hdr.ConnectionID,
|
||||
s.scfg,
|
||||
s.closeCallback,
|
||||
s.cryptoChangeCallback,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -210,6 +211,18 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *server) cryptoChangeCallback(session Session, isForwardSecure bool) {
|
||||
var state ConnState
|
||||
if isForwardSecure {
|
||||
state = ConnStateForwardSecure
|
||||
} else {
|
||||
state = ConnStateSecure
|
||||
}
|
||||
if s.config.ConnState != nil {
|
||||
go s.config.ConnState(session, state)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *server) closeCallback(id protocol.ConnectionID) {
|
||||
s.sessionsMutex.Lock()
|
||||
s.sessions[id] = nil
|
||||
|
||||
@@ -47,7 +47,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
|
||||
|
||||
var _ Session = &mockSession{}
|
||||
|
||||
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback) (packetHandler, error) {
|
||||
func newMockSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
|
||||
return &mockSession{
|
||||
connectionID: connectionID,
|
||||
}, nil
|
||||
@@ -57,15 +57,18 @@ var _ = Describe("Server", func() {
|
||||
var (
|
||||
conn *mockPacketConn
|
||||
config *Config
|
||||
connStateSession Session
|
||||
connStateStatus ConnState
|
||||
connStateCalled bool
|
||||
udpAddr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
connStateCalled = false
|
||||
conn = &mockPacketConn{}
|
||||
config = &Config{
|
||||
ConnState: func(_ Session, cs ConnState) {
|
||||
ConnState: func(s Session, cs ConnState) {
|
||||
connStateSession = s
|
||||
connStateStatus = cs
|
||||
connStateCalled = true
|
||||
},
|
||||
@@ -112,12 +115,30 @@ var _ = Describe("Server", func() {
|
||||
err := serv.handlePacket(nil, nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(serv.sessions).To(HaveLen(1))
|
||||
Expect(serv.sessions[connID].(*mockSession).connectionID).To(Equal(connID))
|
||||
Expect(serv.sessions[connID].(*mockSession).packetCount).To(Equal(1))
|
||||
sess := serv.sessions[connID].(*mockSession)
|
||||
Expect(sess.connectionID).To(Equal(connID))
|
||||
Expect(sess.packetCount).To(Equal(1))
|
||||
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
|
||||
Expect(connStateSession).To(Equal(sess))
|
||||
Expect(connStateStatus).To(Equal(ConnStateVersionNegotiated))
|
||||
})
|
||||
|
||||
It("calls the ConnState callback when the connection is secure", func() {
|
||||
sess := &mockSession{}
|
||||
serv.cryptoChangeCallback(sess, false)
|
||||
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
|
||||
Expect(connStateSession).To(Equal(sess))
|
||||
Expect(connStateStatus).To(Equal(ConnStateSecure))
|
||||
})
|
||||
|
||||
It("calls the ConnState callback when the connection is forward-secure", func() {
|
||||
sess := &mockSession{}
|
||||
serv.cryptoChangeCallback(sess, true)
|
||||
Eventually(func() bool { return connStateCalled }).Should(BeTrue())
|
||||
Expect(connStateStatus).To(Equal(ConnStateForwardSecure))
|
||||
Expect(connStateSession).To(Equal(sess))
|
||||
})
|
||||
|
||||
It("assigns packets to existing sessions", func() {
|
||||
err := serv.handlePacket(nil, nil, firstPacket)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
@@ -37,7 +37,7 @@ var (
|
||||
|
||||
// cryptoChangeCallback is called every time the encryption level changes
|
||||
// Once the callback has been called with isForwardSecure = true, it is guarantueed to not be called with isForwardSecure = false after that
|
||||
type cryptoChangeCallback func(isForwardSecure bool)
|
||||
type cryptoChangeCallback func(session Session, isForwardSecure bool)
|
||||
|
||||
// closeCallback is called when a session is closed
|
||||
type closeCallback func(id protocol.ConnectionID)
|
||||
@@ -99,7 +99,7 @@ type session struct {
|
||||
var _ Session = &session{}
|
||||
|
||||
// newSession makes a new session
|
||||
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback) (packetHandler, error) {
|
||||
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, closeCallback closeCallback, cryptoChangeCallback cryptoChangeCallback) (packetHandler, error) {
|
||||
s := &session{
|
||||
conn: conn,
|
||||
connectionID: connectionID,
|
||||
@@ -107,7 +107,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
|
||||
version: v,
|
||||
|
||||
closeCallback: closeCallback,
|
||||
cryptoChangeCallback: func(bool) {},
|
||||
cryptoChangeCallback: cryptoChangeCallback,
|
||||
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
|
||||
}
|
||||
|
||||
@@ -237,7 +237,7 @@ runLoop:
|
||||
putPacketBuffer(p.publicHeader.Raw)
|
||||
case <-s.aeadChanged:
|
||||
s.tryDecryptingQueuedPackets()
|
||||
s.cryptoChangeCallback(s.cryptoSetup.HandshakeComplete())
|
||||
s.cryptoChangeCallback(s, s.cryptoSetup.HandshakeComplete())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -138,6 +138,7 @@ var _ = Describe("Session", func() {
|
||||
0,
|
||||
scfg,
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(Session, bool) {},
|
||||
)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
sess = pSess.(*session)
|
||||
@@ -153,7 +154,7 @@ var _ = Describe("Session", func() {
|
||||
0,
|
||||
nil,
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(isForwardSecure bool) {},
|
||||
func(Session, bool) {},
|
||||
nil,
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
@@ -169,6 +170,7 @@ var _ = Describe("Session", func() {
|
||||
0,
|
||||
scfg,
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(Session, bool) {},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200}))
|
||||
@@ -184,6 +186,7 @@ var _ = Describe("Session", func() {
|
||||
0,
|
||||
scfg,
|
||||
func(protocol.ConnectionID) { closeCallbackCalled = true },
|
||||
func(Session, bool) {},
|
||||
)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337")))
|
||||
@@ -1100,9 +1103,11 @@ var _ = Describe("Session", func() {
|
||||
It("calls the cryptoChangeCallback when the AEAD changes", func(done Done) {
|
||||
var callbackCalled bool
|
||||
var callbackCalledWith bool
|
||||
cb := func(p bool) {
|
||||
var callbackSession Session
|
||||
cb := func(s Session, p bool) {
|
||||
callbackCalled = true
|
||||
callbackCalledWith = p
|
||||
callbackSession = s
|
||||
}
|
||||
sess.cryptoChangeCallback = cb
|
||||
sess.cryptoSetup = &mockCryptoSetup{handshakeComplete: false}
|
||||
@@ -1110,12 +1115,15 @@ var _ = Describe("Session", func() {
|
||||
go sess.run()
|
||||
Eventually(func() bool { return callbackCalled }).Should(BeTrue())
|
||||
Expect(callbackCalledWith).To(BeFalse())
|
||||
Expect(callbackSession).To(Equal(sess))
|
||||
close(done)
|
||||
})
|
||||
|
||||
It("calls the cryptoChangeCallback when the AEAD changes to forward secure encryption", func(done Done) {
|
||||
var callbackCalledWith bool
|
||||
cb := func(p bool) {
|
||||
var callbackSession Session
|
||||
cb := func(s Session, p bool) {
|
||||
callbackSession = s
|
||||
callbackCalledWith = p
|
||||
}
|
||||
sess.cryptoChangeCallback = cb
|
||||
@@ -1123,6 +1131,7 @@ var _ = Describe("Session", func() {
|
||||
sess.aeadChanged <- struct{}{}
|
||||
go sess.run()
|
||||
Eventually(func() bool { return callbackCalledWith }).Should(BeTrue())
|
||||
Expect(callbackSession).To(Equal(sess))
|
||||
close(done)
|
||||
})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user