From 556bf18dbf3bf7cd8ee6110207961545de806587 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 23 Aug 2020 17:03:15 +0700 Subject: [PATCH] inject a random source into the token protector --- internal/handshake/token_generator.go | 5 ++-- internal/handshake/token_generator_test.go | 3 ++- internal/handshake/token_protector.go | 11 +++++--- internal/handshake/token_protector_test.go | 30 +++++++++++++++++++++- server.go | 3 ++- session_test.go | 2 +- 6 files changed, 44 insertions(+), 10 deletions(-) diff --git a/internal/handshake/token_generator.go b/internal/handshake/token_generator.go index af8c3f1e..2df5fcd8 100644 --- a/internal/handshake/token_generator.go +++ b/internal/handshake/token_generator.go @@ -3,6 +3,7 @@ package handshake import ( "encoding/asn1" "fmt" + "io" "net" "time" @@ -39,8 +40,8 @@ type TokenGenerator struct { } // NewTokenGenerator initializes a new TookenGenerator -func NewTokenGenerator() (*TokenGenerator, error) { - tokenProtector, err := newTokenProtector() +func NewTokenGenerator(rand io.Reader) (*TokenGenerator, error) { + tokenProtector, err := newTokenProtector(rand) if err != nil { return nil, err } diff --git a/internal/handshake/token_generator_test.go b/internal/handshake/token_generator_test.go index 91cba354..3aef6a3d 100644 --- a/internal/handshake/token_generator_test.go +++ b/internal/handshake/token_generator_test.go @@ -1,6 +1,7 @@ package handshake import ( + "crypto/rand" "encoding/asn1" "net" "time" @@ -16,7 +17,7 @@ var _ = Describe("Token Generator", func() { BeforeEach(func() { var err error - tokenGen, err = NewTokenGenerator() + tokenGen, err = NewTokenGenerator(rand.Reader) Expect(err).ToNot(HaveOccurred()) }) diff --git a/internal/handshake/token_protector.go b/internal/handshake/token_protector.go index 33d18e60..650f230b 100644 --- a/internal/handshake/token_protector.go +++ b/internal/handshake/token_protector.go @@ -3,7 +3,6 @@ package handshake import ( "crypto/aes" "crypto/cipher" - "crypto/rand" "crypto/sha256" "fmt" "io" @@ -26,22 +25,26 @@ const ( // tokenProtector is used to create and verify a token type tokenProtectorImpl struct { + rand io.Reader secret []byte } // newTokenProtector creates a source for source address tokens -func newTokenProtector() (tokenProtector, error) { +func newTokenProtector(rand io.Reader) (tokenProtector, error) { secret := make([]byte, tokenSecretSize) if _, err := rand.Read(secret); err != nil { return nil, err } - return &tokenProtectorImpl{secret: secret}, nil + return &tokenProtectorImpl{ + rand: rand, + secret: secret, + }, nil } // NewToken encodes data into a new token. func (s *tokenProtectorImpl) NewToken(data []byte) ([]byte, error) { nonce := make([]byte, tokenNonceSize) - if _, err := rand.Read(nonce); err != nil { + if _, err := s.rand.Read(nonce); err != nil { return nil, err } aead, aeadNonce, err := s.createAEAD(nonce) diff --git a/internal/handshake/token_protector_test.go b/internal/handshake/token_protector_test.go index 53bae6ed..7171e865 100644 --- a/internal/handshake/token_protector_test.go +++ b/internal/handshake/token_protector_test.go @@ -1,19 +1,47 @@ package handshake import ( + "crypto/rand" + . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) +type zeroReader struct{} + +func (r *zeroReader) Read(b []byte) (int, error) { + for i := range b { + b[i] = 0 + } + return len(b), nil +} + var _ = Describe("Token Protector", func() { var tp tokenProtector BeforeEach(func() { var err error - tp, err = newTokenProtector() + tp, err = newTokenProtector(rand.Reader) Expect(err).ToNot(HaveOccurred()) }) + It("uses the random source", func() { + tp1, err := newTokenProtector(&zeroReader{}) + Expect(err).ToNot(HaveOccurred()) + tp2, err := newTokenProtector(&zeroReader{}) + Expect(err).ToNot(HaveOccurred()) + t1, err := tp1.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + t2, err := tp2.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(t1).To(Equal(t2)) + tp3, err := newTokenProtector(rand.Reader) + Expect(err).ToNot(HaveOccurred()) + t3, err := tp3.NewToken([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(t3).ToNot(Equal(t1)) + }) + It("encodes and decodes tokens", func() { token, err := tp.NewToken([]byte("foobar")) Expect(err).ToNot(HaveOccurred()) diff --git a/server.go b/server.go index 3403e22c..4c2c5525 100644 --- a/server.go +++ b/server.go @@ -3,6 +3,7 @@ package quic import ( "bytes" "context" + "crypto/rand" "crypto/tls" "errors" "fmt" @@ -185,7 +186,7 @@ func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config, acceptEarl if err != nil { return nil, err } - tokenGenerator, err := handshake.NewTokenGenerator() + tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) if err != nil { return nil, err } diff --git a/session_test.go b/session_test.go index 2f7cf408..e7901cce 100644 --- a/session_test.go +++ b/session_test.go @@ -86,7 +86,7 @@ var _ = Describe("Session", func() { mconn = NewMockSendConn(mockCtrl) mconn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes() mconn.EXPECT().LocalAddr().Return(localAddr).AnyTimes() - tokenGenerator, err := handshake.NewTokenGenerator() + tokenGenerator, err := handshake.NewTokenGenerator(rand.Reader) Expect(err).ToNot(HaveOccurred()) tracer = mocks.NewMockConnectionTracer(mockCtrl) tracer.EXPECT().SentTransportParameters(gomock.Any())