pass the quic.Config to the session

This commit is contained in:
Marten Seemann
2017-05-09 22:38:22 +08:00
parent 22a9a8221c
commit 650af86c70
5 changed files with 30 additions and 13 deletions

View File

@@ -257,8 +257,8 @@ func (c *client) createNewSession(negotiatedVersions []protocol.VersionNumber) e
c.hostname, c.hostname,
c.version, c.version,
c.connectionID, c.connectionID,
c.config.TLSConfig,
c.cryptoChangeCallback, c.cryptoChangeCallback,
c.config,
negotiatedVersions, negotiatedVersions,
) )
if err != nil { if err != nil {

View File

@@ -34,7 +34,7 @@ type server struct {
sessionsMutex sync.RWMutex sessionsMutex sync.RWMutex
deleteClosedSessionsAfter time.Duration deleteClosedSessionsAfter time.Duration
newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) newSession func(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, config *Config) (packetHandler, error)
} }
var _ Listener = &server{} var _ Listener = &server{}
@@ -197,7 +197,7 @@ func (s *server) handlePacket(pconn net.PacketConn, remoteAddr net.Addr, packet
hdr.ConnectionID, hdr.ConnectionID,
s.scfg, s.scfg,
s.cryptoChangeCallback, s.cryptoChangeCallback,
s.config.Versions, s.config,
) )
if err != nil { if err != nil {
return err return err

View File

@@ -56,7 +56,7 @@ func (s *mockSession) RemoteAddr() net.Addr {
var _ Session = &mockSession{} var _ Session = &mockSession{}
func newMockSession(_ connection, _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, _ cryptoChangeCallback, _ []protocol.VersionNumber) (packetHandler, error) { func newMockSession(_ connection, _ protocol.VersionNumber, connectionID protocol.ConnectionID, _ *handshake.ServerConfig, _ cryptoChangeCallback, _ *Config) (packetHandler, error) {
return &mockSession{ return &mockSession{
connectionID: connectionID, connectionID: connectionID,
stopRunLoop: make(chan struct{}), stopRunLoop: make(chan struct{}),

View File

@@ -1,7 +1,6 @@
package quic package quic
import ( import (
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -49,6 +48,7 @@ type session struct {
connectionID protocol.ConnectionID connectionID protocol.ConnectionID
perspective protocol.Perspective perspective protocol.Perspective
version protocol.VersionNumber version protocol.VersionNumber
config *Config
cryptoChangeCallback cryptoChangeCallback cryptoChangeCallback cryptoChangeCallback
@@ -106,12 +106,20 @@ type session struct {
var _ Session = &session{} var _ Session = &session{}
// newSession makes a new session // newSession makes a new session
func newSession(conn connection, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *handshake.ServerConfig, cryptoChangeCallback cryptoChangeCallback, supportedVersions []protocol.VersionNumber) (packetHandler, error) { func newSession(
conn connection,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
sCfg *handshake.ServerConfig,
cryptoChangeCallback cryptoChangeCallback,
config *Config,
) (packetHandler, error) {
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveServer, perspective: protocol.PerspectiveServer,
version: v, version: v,
config: config,
cryptoChangeCallback: cryptoChangeCallback, cryptoChangeCallback: cryptoChangeCallback,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v), connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveServer, v),
@@ -129,7 +137,7 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
aeadChanged := make(chan protocol.EncryptionLevel, 2) aeadChanged := make(chan protocol.EncryptionLevel, 2)
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
var err error var err error
s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, supportedVersions, aeadChanged) s.cryptoSetup, err = handshake.NewCryptoSetup(connectionID, sourceAddr, v, sCfg, cryptoStream, s.connectionParameters, config.Versions, aeadChanged)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -140,12 +148,21 @@ func newSession(conn connection, v protocol.VersionNumber, connectionID protocol
return s, err return s, err
} }
func newClientSession(conn connection, hostname string, v protocol.VersionNumber, connectionID protocol.ConnectionID, tlsConfig *tls.Config, cryptoChangeCallback cryptoChangeCallback, negotiatedVersions []protocol.VersionNumber) (*session, error) { func newClientSession(
conn connection,
hostname string,
v protocol.VersionNumber,
connectionID protocol.ConnectionID,
cryptoChangeCallback cryptoChangeCallback,
config *Config,
negotiatedVersions []protocol.VersionNumber,
) (*session, error) {
s := &session{ s := &session{
conn: conn, conn: conn,
connectionID: connectionID, connectionID: connectionID,
perspective: protocol.PerspectiveClient, perspective: protocol.PerspectiveClient,
version: v, version: v,
config: config,
cryptoChangeCallback: cryptoChangeCallback, cryptoChangeCallback: cryptoChangeCallback,
connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v), connectionParameters: handshake.NewConnectionParamatersManager(protocol.PerspectiveClient, v),
@@ -158,7 +175,7 @@ func newClientSession(conn connection, hostname string, v protocol.VersionNumber
s.aeadChanged = aeadChanged s.aeadChanged = aeadChanged
cryptoStream, _ := s.OpenStream() cryptoStream, _ := s.OpenStream()
var err error var err error
s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, tlsConfig, s.connectionParameters, aeadChanged, negotiatedVersions) s.cryptoSetup, err = handshake.NewCryptoSetupClient(hostname, connectionID, v, cryptoStream, config.TLSConfig, s.connectionParameters, aeadChanged, negotiatedVersions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -150,7 +150,7 @@ var _ = Describe("Session", func() {
0, 0,
scfg, scfg,
func(Session, bool) {}, func(Session, bool) {},
nil, populateServerConfig(&Config{}),
) )
Expect(err).NotTo(HaveOccurred()) Expect(err).NotTo(HaveOccurred())
sess = pSess.(*session) sess = pSess.(*session)
@@ -167,8 +167,8 @@ var _ = Describe("Session", func() {
"hostname", "hostname",
protocol.Version35, protocol.Version35,
0, 0,
nil,
func(Session, bool) {}, func(Session, bool) {},
populateClientConfig(&Config{}),
nil, nil,
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
@@ -188,7 +188,7 @@ var _ = Describe("Session", func() {
0, 0,
scfg, scfg,
func(Session, bool) {}, func(Session, bool) {},
nil, populateServerConfig(&Config{}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200})) Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte{192, 168, 100, 200}))
@@ -204,7 +204,7 @@ var _ = Describe("Session", func() {
0, 0,
scfg, scfg,
func(Session, bool) {}, func(Session, bool) {},
nil, populateServerConfig(&Config{}),
) )
Expect(err).ToNot(HaveOccurred()) Expect(err).ToNot(HaveOccurred())
Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337"))) Expect(*(*[]byte)(unsafe.Pointer(reflect.ValueOf(sess.(*session).cryptoSetup).Elem().FieldByName("sourceAddr").UnsafeAddr()))).To(Equal([]byte("192.168.100.200:1337")))