diff --git a/client.go b/client.go index ada245c56..ed1253b12 100644 --- a/client.go +++ b/client.go @@ -257,8 +257,8 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e c.hostname, c.version, c.connectionID, - c.config.TLSConfig, c.cryptoChangeCallback, + c.config, negotiatedVersions, ) if err != nil { diff --git a/server.go b/server.go index ca5192ec3..aa4ee2c77 100644 --- a/server.go +++ b/server.go @@ -34,7 +34,7 @@ type server struct { sessionsMutex sync.RWMutex deleteClosedSessionsAfter time.Duration - newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) + newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, config *Config) (packetHandler, error) } var _ Listener = &server{} @@ -197,7 +197,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet hdr.ConnectionID, s.scfg, s.cryptoChangeCallback, - s.config.Versions, + s.config, ) if err != nil { return err diff --git a/server_test.go b/server_test.go index 4c829aaae..0ef386735 100644 --- a/server_test.go +++ b/server_test.go @@ -56,7 +56,7 @@ func (s *mockSession) RemoteAddr() net.Addr { var _ Session = &mockSession{} -func newMockSession(_ connection, _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, _ cryptoChangeCallback, _ []protocol.VersionNumber) (packetHandler, error) { +func newMockSession(_ connection, _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, _ cryptoChangeCallback, _ *Config) (packetHandler, error) { return &mockSession{ connectionID: connectionID, stopRunLoop: make(chan struct{}), diff --git a/session.go b/session.go index 3e4d3353d..cc39c6ff2 100644 --- a/session.go +++ b/session.go @@ -1,7 +1,6 @@ package quic import ( - "crypto/tls" "errors" "fmt" "net" @@ -49,6 +48,7 @@ type session struct { connectionID protocol.ConnectionID perspective protocol.Perspective version protocol.VersionNumber + config *Config cryptoChangeCallback cryptoChangeCallback @@ -106,12 +106,20 @@ type session struct { var _ Session = &session{} // newSession makes a new session -func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) { +func newSession( + conn connection, + v protocol.VersionNumber, + connectionID protocol.ConnectionID, + sCfg *handshake.ServerConfig, + cryptoChangeCallback cryptoChangeCallback, + config *Config, +) (packetHandler, error) { s := &session{ conn: conn, connectionID: connectionID, perspective: protocol.PerspectiveServer, version: v, + config: config, cryptoChangeCallback: cryptoChangeCallback, connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v), @@ -129,7 +137,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol aeadChanged := make(chan protocol.EncryptionLevel, 2) s.aeadChanged = aeadChanged var err error - s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, supportedVersions, aeadChanged) + s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, config.Versions, aeadChanged) if err != nil { return nil, err } @@ -140,12 +148,21 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol return s, err } -func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) { +func newClientSession( + conn connection, + hostname string, + v protocol.VersionNumber, + connectionID protocol.ConnectionID, + cryptoChangeCallback cryptoChangeCallback, + config *Config, + negotiatedVersions []protocol.VersionNumber, +) (*session, error) { s := &session{ conn: conn, connectionID: connectionID, perspective: protocol.PerspectiveClient, version: v, + config: config, cryptoChangeCallback: cryptoChangeCallback, connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), @@ -158,7 +175,7 @@ func newClientSession(conn connection, hostname string, v protocol.VersionNumber s.aeadChanged = aeadChanged cryptoStream, _ := s.OpenStream() var err error - s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, s.connectionParameters, aeadChanged, negotiatedVersions) + s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, config.TLSConfig, s.connectionParameters, aeadChanged, negotiatedVersions) if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index c00802c8a..113aab085 100644 --- a/session_test.go +++ b/session_test.go @@ -150,7 +150,7 @@ var _ = Describe("Session", func() { 0, scfg, func(Session, bool) {}, - nil, + populateServerConfig(&Config{}), ) Expect(err).NotTo(HaveOccurred()) sess = pSess.(*session) @@ -167,8 +167,8 @@ var _ = Describe("Session", func() { "hostname", protocol.Version35, 0, - nil, func(Session, bool) {}, + populateClientConfig(&Config{}), nil, ) Expect(err).ToNot(HaveOccurred()) @@ -188,7 +188,7 @@ var _ = Describe("Session", func() { 0, scfg, func(Session, bool) {}, - nil, + populateServerConfig(&Config{}), ) Expect(err).ToNot(HaveOccurred()) Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200})) @@ -204,7 +204,7 @@ var _ = Describe("Session", func() { 0, scfg, func(Session, bool) {}, - nil, + populateServerConfig(&Config{}), ) Expect(err).ToNot(HaveOccurred()) Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337")))