forked from quic-go/quic-go
expose the tls.ConnectionState
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"unsafe"
|
||||
|
||||
"github.com/lucas-clemente/quic-go/internal/protocol"
|
||||
"github.com/lucas-clemente/quic-go/internal/qerr"
|
||||
@@ -541,13 +542,14 @@ func (h *cryptoSetup) GetOpener(level protocol.EncryptionLevel) (Opener, error)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) ConnectionState() ConnectionState {
|
||||
connState := h.conn.ConnectionState()
|
||||
return ConnectionState{
|
||||
HandshakeComplete: connState.HandshakeComplete,
|
||||
ServerName: connState.ServerName,
|
||||
PeerCertificates: connState.PeerCertificates,
|
||||
}
|
||||
func (h *cryptoSetup) ConnectionState() tls.ConnectionState {
|
||||
cs := h.conn.ConnectionState()
|
||||
// h.conn is a qtls.Conn, which returns a qtls.ConnectionState.
|
||||
// qtls.ConnectionState is identical to the tls.ConnectionState.
|
||||
// It contains an unexported field which is used ExportKeyingMaterial().
|
||||
// The only way to return a tls.ConnectionState is to use unsafe.
|
||||
// In unsafe.go we check that the two objects are actually identical.
|
||||
return *(*tls.ConnectionState)(unsafe.Pointer(&cs))
|
||||
}
|
||||
|
||||
func (h *cryptoSetup) tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"io"
|
||||
|
||||
@@ -35,7 +36,7 @@ type CryptoSetup interface {
|
||||
ChangeConnectionID(protocol.ConnectionID) error
|
||||
|
||||
HandleMessage([]byte, protocol.EncryptionLevel) bool
|
||||
ConnectionState() ConnectionState
|
||||
ConnectionState() tls.ConnectionState
|
||||
|
||||
GetSealer() (protocol.EncryptionLevel, Sealer)
|
||||
GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (Sealer, error)
|
||||
|
||||
33
internal/handshake/unsafe.go
Normal file
33
internal/handshake/unsafe.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package handshake
|
||||
|
||||
// This package uses unsafe to convert between qtls.ConnectionState and tls.ConnectionState.
|
||||
// We check in init() that this conversion actually is safe.
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"reflect"
|
||||
|
||||
"github.com/marten-seemann/qtls"
|
||||
)
|
||||
|
||||
func init() {
|
||||
if !structsEqual(&tls.ConnectionState{}, &qtls.ConnectionState{}) {
|
||||
panic("qtls.ConnectionState not compatible with tls.ConnectionState")
|
||||
}
|
||||
}
|
||||
|
||||
func structsEqual(a, b interface{}) bool {
|
||||
sa := reflect.ValueOf(a).Elem()
|
||||
sb := reflect.ValueOf(b).Elem()
|
||||
if sa.NumField() != sb.NumField() {
|
||||
return false
|
||||
}
|
||||
for i := 0; i < sa.NumField(); i++ {
|
||||
fa := sa.Type().Field(i)
|
||||
fb := sb.Type().Field(i)
|
||||
if !reflect.DeepEqual(fa.Index, fb.Index) || fa.Name != fb.Name || fa.Anonymous != fb.Anonymous || fa.Offset != fb.Offset || !reflect.DeepEqual(fa.Type, fb.Type) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
60
internal/handshake/unsafe_test.go
Normal file
60
internal/handshake/unsafe_test.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
type target struct {
|
||||
Name string
|
||||
Version string
|
||||
|
||||
callback func(label string, length int) error
|
||||
}
|
||||
|
||||
type renamedField struct {
|
||||
NewName string
|
||||
Version string
|
||||
|
||||
callback func(label string, length int) error
|
||||
}
|
||||
|
||||
type renamedPrivateField struct {
|
||||
Name string
|
||||
Version string
|
||||
|
||||
cb func(label string, length int) error
|
||||
}
|
||||
|
||||
type additionalField struct {
|
||||
Name string
|
||||
Version string
|
||||
|
||||
callback func(label string, length int) error
|
||||
secret []byte
|
||||
}
|
||||
|
||||
type interchangedFields struct {
|
||||
Version string
|
||||
Name string
|
||||
|
||||
callback func(label string, length int) error
|
||||
}
|
||||
|
||||
type renamedCallbackFunctionParams struct { // should be equivalent
|
||||
Name string
|
||||
Version string
|
||||
|
||||
callback func(newLabel string, length int) error
|
||||
}
|
||||
|
||||
var _ = Describe("Unsafe checks", func() {
|
||||
It("detects if an unsafe conversion is safe", func() {
|
||||
Expect(structsEqual(&target{}, &target{})).To(BeTrue())
|
||||
Expect(structsEqual(&target{}, &renamedField{})).To(BeFalse())
|
||||
Expect(structsEqual(&target{}, &renamedPrivateField{})).To(BeFalse())
|
||||
Expect(structsEqual(&target{}, &additionalField{})).To(BeFalse())
|
||||
Expect(structsEqual(&target{}, &interchangedFields{})).To(BeFalse())
|
||||
Expect(structsEqual(&target{}, &renamedCallbackFunctionParams{})).To(BeTrue())
|
||||
})
|
||||
})
|
||||
@@ -5,6 +5,7 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
tls "crypto/tls"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -64,10 +65,10 @@ func (mr *MockCryptoSetupMockRecorder) Close() *gomock.Call {
|
||||
}
|
||||
|
||||
// ConnectionState mocks base method
|
||||
func (m *MockCryptoSetup) ConnectionState() handshake.ConnectionState {
|
||||
func (m *MockCryptoSetup) ConnectionState() tls.ConnectionState {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ConnectionState")
|
||||
ret0, _ := ret[0].(handshake.ConnectionState)
|
||||
ret0, _ := ret[0].(tls.ConnectionState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user