diff --git a/crypto/AEAD.go b/crypto/AEAD.go index df7d7b4a..35772764 100644 --- a/crypto/AEAD.go +++ b/crypto/AEAD.go @@ -1,14 +1,9 @@ package crypto -import ( - "bytes" - "io" - - "github.com/lucas-clemente/quic-go/protocol" -) +import "github.com/lucas-clemente/quic-go/protocol" // An AEAD implements QUIC's authenticated encryption and associated data type AEAD interface { - Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext io.Reader) (*bytes.Reader, error) - Seal(packetNumber protocol.PacketNumber, b *bytes.Buffer, associatedData []byte, plaintext []byte) + Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext []byte) ([]byte, error) + Seal(packetNumber protocol.PacketNumber, associatedData []byte, plaintext []byte) []byte } diff --git a/crypto/chacha20poly1305_aead.go b/crypto/chacha20poly1305_aead.go index 3103e4d8..4098b5f5 100644 --- a/crypto/chacha20poly1305_aead.go +++ b/crypto/chacha20poly1305_aead.go @@ -1,12 +1,9 @@ package crypto import ( - "bytes" "crypto/cipher" "encoding/binary" "errors" - "io" - "io/ioutil" "github.com/lucas-clemente/quic-go/crypto/chacha20poly1305trunc12" "github.com/lucas-clemente/quic-go/protocol" @@ -40,21 +37,16 @@ func NewAEADChacha20Poly1305(otherKey []byte, myKey []byte, otherIV []byte, myIV }, nil } -func (aead *aeadChacha20Poly1305) Open(packetNumber protocol.PacketNumber, associatedData []byte, r io.Reader) (*bytes.Reader, error) { - ciphertext, err := ioutil.ReadAll(r) - if err != nil { - return nil, err - } +func (aead *aeadChacha20Poly1305) Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext []byte) ([]byte, error) { plaintext, err := aead.decrypter.Open(make([]byte, len(ciphertext)), makeNonce(aead.otherIV, packetNumber), ciphertext, associatedData) if err != nil { return nil, err } - return bytes.NewReader(plaintext), nil + return plaintext, nil } -func (aead *aeadChacha20Poly1305) Seal(packetNumber protocol.PacketNumber, b *bytes.Buffer, associatedData []byte, plaintext []byte) { - ciphertext := aead.encrypter.Seal(make([]byte, len(plaintext)+12), makeNonce(aead.myIV, packetNumber), plaintext, associatedData) - b.Write(ciphertext) +func (aead *aeadChacha20Poly1305) Seal(packetNumber protocol.PacketNumber, associatedData []byte, plaintext []byte) []byte { + return aead.encrypter.Seal(make([]byte, len(plaintext)+12), makeNonce(aead.myIV, packetNumber), plaintext, associatedData) } func makeNonce(iv []byte, packetNumber protocol.PacketNumber) []byte { diff --git a/crypto/chacha20poly1305_aead_test.go b/crypto/chacha20poly1305_aead_test.go index 739d62ce..8798d9b1 100644 --- a/crypto/chacha20poly1305_aead_test.go +++ b/crypto/chacha20poly1305_aead_test.go @@ -1,9 +1,7 @@ package crypto import ( - "bytes" "crypto/rand" - "io/ioutil" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -31,28 +29,21 @@ var _ = Describe("Chacha20poly1305", func() { }) It("seals and opens", func() { - b := &bytes.Buffer{} - alice.Seal(42, b, []byte("aad"), []byte("foobar")) - r, err := bob.Open(42, []byte("aad"), b) - Expect(err).ToNot(HaveOccurred()) - text, err := ioutil.ReadAll(r) + b := alice.Seal(42, []byte("aad"), []byte("foobar")) + text, err := bob.Open(42, []byte("aad"), b) Expect(err).ToNot(HaveOccurred()) Expect(text).To(Equal([]byte("foobar"))) }) It("seals and opens reverse", func() { - b := &bytes.Buffer{} - bob.Seal(42, b, []byte("aad"), []byte("foobar")) - r, err := alice.Open(42, []byte("aad"), b) - Expect(err).ToNot(HaveOccurred()) - text, err := ioutil.ReadAll(r) + b := bob.Seal(42, []byte("aad"), []byte("foobar")) + text, err := alice.Open(42, []byte("aad"), b) Expect(err).ToNot(HaveOccurred()) Expect(text).To(Equal([]byte("foobar"))) }) It("fails with wrong aad", func() { - b := &bytes.Buffer{} - alice.Seal(42, b, []byte("aad"), []byte("foobar")) + b := alice.Seal(42, []byte("aad"), []byte("foobar")) _, err := bob.Open(42, []byte("aad2"), b) Expect(err).To(HaveOccurred()) }) diff --git a/crypto/null_aead.go b/crypto/null_aead.go index 42608e59..34babfda 100644 --- a/crypto/null_aead.go +++ b/crypto/null_aead.go @@ -1,14 +1,10 @@ package crypto import ( - "bytes" "encoding/binary" "errors" - "io" - "io/ioutil" "github.com/lucas-clemente/quic-go/protocol" - "github.com/lucas-clemente/quic-go/utils" ) // NullAEAD handles not-yet encrypted packets @@ -17,11 +13,7 @@ type NullAEAD struct{} var _ AEAD = &NullAEAD{} // Open and verify the ciphertext -func (*NullAEAD) Open(packetNumber protocol.PacketNumber, associatedData []byte, r io.Reader) (*bytes.Reader, error) { - ciphertext, err := ioutil.ReadAll(r) - if err != nil { - return nil, err - } +func (*NullAEAD) Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext []byte) ([]byte, error) { if len(ciphertext) < 12 { return nil, errors.New("NullAEAD: ciphertext cannot be less than 12 bytes long") } @@ -37,17 +29,20 @@ func (*NullAEAD) Open(packetNumber protocol.PacketNumber, associatedData []byte, if uint32(testHigh&0xffffffff) != high || testLow != low { return nil, errors.New("NullAEAD: failed to authenticate received data") } - return bytes.NewReader(ciphertext[12:]), nil + return ciphertext[12:], nil } // Seal writes hash and ciphertext to the buffer -func (*NullAEAD) Seal(packetNumber protocol.PacketNumber, b *bytes.Buffer, associatedData []byte, plaintext []byte) { +func (*NullAEAD) Seal(packetNumber protocol.PacketNumber, associatedData []byte, plaintext []byte) []byte { + res := make([]byte, 12+len(plaintext)) + hash := New128a() hash.Write(associatedData) hash.Write(plaintext) high, low := hash.Sum128() - utils.WriteUint64(b, low) - utils.WriteUint32(b, uint32(high)) - b.Write(plaintext) + binary.LittleEndian.PutUint64(res, low) + binary.LittleEndian.PutUint32(res[8:], uint32(high)) + copy(res[12:], plaintext) + return res } diff --git a/crypto/null_aead_test.go b/crypto/null_aead_test.go index 39ec22d7..54affcc5 100644 --- a/crypto/null_aead_test.go +++ b/crypto/null_aead_test.go @@ -1,9 +1,6 @@ package crypto import ( - "bytes" - "io/ioutil" - . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" ) @@ -15,9 +12,7 @@ var _ = Describe("Crypto/NullAEAD", func() { hash := []byte{0x98, 0x9b, 0x33, 0x3f, 0xe8, 0xde, 0x32, 0x5c, 0xa6, 0x7f, 0x9c, 0xf7} cipherText := append(hash, plainText...) aead := &NullAEAD{} - r, err := aead.Open(0, aad, bytes.NewReader(cipherText)) - Expect(err).ToNot(HaveOccurred()) - res, err := ioutil.ReadAll(r) + res, err := aead.Open(0, aad, cipherText) Expect(err).ToNot(HaveOccurred()) Expect(res).To(Equal([]byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood."))) }) @@ -28,16 +23,14 @@ var _ = Describe("Crypto/NullAEAD", func() { hash := []byte{0x98, 0x9b, 0x33, 0x3f, 0xe8, 0xde, 0x32, 0x5c, 0xa6, 0x7f, 0x9c, 0xf7} cipherText := append(hash, plainText...) aead := &NullAEAD{} - _, err := aead.Open(0, aad, bytes.NewReader(cipherText)) + _, err := aead.Open(0, aad, cipherText) Expect(err).To(HaveOccurred()) }) It("seals", func() { aad := []byte("All human beings are born free and equal in dignity and rights.") plainText := []byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.") - b := &bytes.Buffer{} aead := &NullAEAD{} - aead.Seal(0, b, aad, plainText) - Expect(b.Bytes()).To(Equal(append([]byte{0x98, 0x9b, 0x33, 0x3f, 0xe8, 0xde, 0x32, 0x5c, 0xa6, 0x7f, 0x9c, 0xf7}, []byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")...))) + Expect(aead.Seal(0, aad, plainText)).To(Equal(append([]byte{0x98, 0x9b, 0x33, 0x3f, 0xe8, 0xde, 0x32, 0x5c, 0xa6, 0x7f, 0x9c, 0xf7}, []byte("They are endowed with reason and conscience and should act towards one another in a spirit of brotherhood.")...))) }) }) diff --git a/handshake/crypto_setup.go b/handshake/crypto_setup.go index 42a91558..d1ffad47 100644 --- a/handshake/crypto_setup.go +++ b/handshake/crypto_setup.go @@ -4,8 +4,6 @@ import ( "bytes" "crypto/rand" "fmt" - "io" - "io/ioutil" "sync" "github.com/lucas-clemente/quic-go/crypto" @@ -98,17 +96,12 @@ func (h *CryptoSetup) HandleCryptoStream() { } // Open a message -func (h *CryptoSetup) Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext io.Reader) (*bytes.Reader, error) { +func (h *CryptoSetup) Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext []byte) ([]byte, error) { h.mutex.RLock() defer h.mutex.RUnlock() - data, err := ioutil.ReadAll(ciphertext) - if err != nil { - return nil, err - } - if h.forwardSecureAEAD != nil { - res, err := h.forwardSecureAEAD.Open(packetNumber, associatedData, bytes.NewReader(data)) + res, err := h.forwardSecureAEAD.Open(packetNumber, associatedData, ciphertext) if err == nil { h.receivedForwardSecurePacket = true return res, nil @@ -118,22 +111,22 @@ func (h *CryptoSetup) Open(packetNumber protocol.PacketNumber, associatedData [] } } if h.secureAEAD != nil { - return h.secureAEAD.Open(packetNumber, associatedData, bytes.NewReader(data)) + return h.secureAEAD.Open(packetNumber, associatedData, ciphertext) } - return (&crypto.NullAEAD{}).Open(packetNumber, associatedData, bytes.NewReader(data)) + return (&crypto.NullAEAD{}).Open(packetNumber, associatedData, ciphertext) } // Seal a messageTag -func (h *CryptoSetup) Seal(packetNumber protocol.PacketNumber, b *bytes.Buffer, associatedData []byte, plaintext []byte) { +func (h *CryptoSetup) Seal(packetNumber protocol.PacketNumber, associatedData []byte, plaintext []byte) []byte { h.mutex.RLock() defer h.mutex.RUnlock() if h.receivedForwardSecurePacket { - h.forwardSecureAEAD.Seal(packetNumber, b, associatedData, plaintext) + return h.forwardSecureAEAD.Seal(packetNumber, associatedData, plaintext) } else if h.secureAEAD != nil { - h.secureAEAD.Seal(packetNumber, b, associatedData, plaintext) + return h.secureAEAD.Seal(packetNumber, associatedData, plaintext) } else { - (&crypto.NullAEAD{}).Seal(packetNumber, b, associatedData, plaintext) + return (&crypto.NullAEAD{}).Seal(packetNumber, associatedData, plaintext) } } diff --git a/handshake/crypto_setup_test.go b/handshake/crypto_setup_test.go index 7fc47dc5..ae75c481 100644 --- a/handshake/crypto_setup_test.go +++ b/handshake/crypto_setup_test.go @@ -3,8 +3,6 @@ package handshake import ( "bytes" "errors" - "io" - "io/ioutil" "github.com/lucas-clemente/quic-go/crypto" "github.com/lucas-clemente/quic-go/protocol" @@ -43,20 +41,18 @@ type mockAEAD struct { forwardSecure bool } -func (m *mockAEAD) Seal(packetNumber protocol.PacketNumber, b *bytes.Buffer, associatedData []byte, plaintext []byte) { +func (m *mockAEAD) Seal(packetNumber protocol.PacketNumber, associatedData []byte, plaintext []byte) []byte { if m.forwardSecure { - b.Write([]byte("forward secure encrypted")) - } else { - b.Write([]byte("encrypted")) + return []byte("forward secure encrypted") } + return []byte("encrypted") } -func (m *mockAEAD) Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext io.Reader) (*bytes.Reader, error) { - data, _ := ioutil.ReadAll(ciphertext) - if m.forwardSecure && string(data) == "forward secure encrypted" { - return bytes.NewReader([]byte("decrypted")), nil - } else if !m.forwardSecure && string(data) == "encrypted" { - return bytes.NewReader([]byte("decrypted")), nil +func (m *mockAEAD) Open(packetNumber protocol.PacketNumber, associatedData []byte, ciphertext []byte) ([]byte, error) { + if m.forwardSecure && string(ciphertext) == "forward secure encrypted" { + return []byte("decrypted"), nil + } else if !m.forwardSecure && string(ciphertext) == "encrypted" { + return []byte("decrypted"), nil } return nil, errors.New("authentication failed") } @@ -88,13 +84,11 @@ var _ = Describe("Crypto setup", func() { signer *mockSigner scfg *ServerConfig cs *CryptoSetup - buf *bytes.Buffer stream *mockStream ) BeforeEach(func() { stream = &mockStream{} - buf = &bytes.Buffer{} kex = &mockKEX{} signer = &mockSigner{} scfg = NewServerConfig(kex, signer) @@ -168,14 +162,11 @@ var _ = Describe("Crypto setup", func() { Context("null encryption", func() { It("is used initially", func() { - cs.Seal(0, buf, []byte{}, []byte("foobar")) - Expect(buf.Bytes()).To(Equal(foobarFNVSigned)) + Expect(cs.Seal(0, []byte{}, []byte("foobar"))).To(Equal(foobarFNVSigned)) }) It("is accepted initially", func() { - r, err := cs.Open(0, []byte{}, bytes.NewReader(foobarFNVSigned)) - Expect(err).ToNot(HaveOccurred()) - d, err := ioutil.ReadAll(r) + d, err := cs.Open(0, []byte{}, foobarFNVSigned) Expect(err).ToNot(HaveOccurred()) Expect(d).To(Equal([]byte("foobar"))) }) @@ -183,46 +174,44 @@ var _ = Describe("Crypto setup", func() { It("is not accepted after CHLO", func() { doCHLO() Expect(cs.secureAEAD).ToNot(BeNil()) - _, err := cs.Open(0, []byte{}, bytes.NewReader(foobarFNVSigned)) + _, err := cs.Open(0, []byte{}, foobarFNVSigned) Expect(err).To(MatchError("authentication failed")) }) It("is not used after CHLO", func() { doCHLO() - cs.Seal(0, buf, []byte{}, []byte("foobar")) - Expect(buf.Bytes()).ToNot(Equal(foobarFNVSigned)) + d := cs.Seal(0, []byte{}, []byte("foobar")) + Expect(d).ToNot(Equal(foobarFNVSigned)) }) }) Context("initial encryption", func() { It("is used after CHLO", func() { doCHLO() - cs.Seal(0, buf, []byte{}, []byte("foobar")) - Expect(buf.Bytes()).To(Equal([]byte("encrypted"))) + d := cs.Seal(0, []byte{}, []byte("foobar")) + Expect(d).To(Equal([]byte("encrypted"))) }) It("is accepted after CHLO", func() { doCHLO() - r, err := cs.Open(0, []byte{}, bytes.NewReader([]byte("encrypted"))) - Expect(err).ToNot(HaveOccurred()) - d, err := ioutil.ReadAll(r) + d, err := cs.Open(0, []byte{}, []byte("encrypted")) Expect(err).ToNot(HaveOccurred()) Expect(d).To(Equal([]byte("decrypted"))) }) It("is not used after receiving forward secure packet", func() { doCHLO() - _, err := cs.Open(0, []byte{}, bytes.NewReader([]byte("forward secure encrypted"))) + _, err := cs.Open(0, []byte{}, []byte("forward secure encrypted")) Expect(err).ToNot(HaveOccurred()) - cs.Seal(0, buf, []byte{}, []byte("foobar")) - Expect(buf.Bytes()).To(Equal([]byte("forward secure encrypted"))) + d := cs.Seal(0, []byte{}, []byte("foobar")) + Expect(d).To(Equal([]byte("forward secure encrypted"))) }) It("is not accepted after receiving forward secure packet", func() { doCHLO() - _, err := cs.Open(0, []byte{}, bytes.NewReader([]byte("forward secure encrypted"))) + _, err := cs.Open(0, []byte{}, []byte("forward secure encrypted")) Expect(err).ToNot(HaveOccurred()) - _, err = cs.Open(0, []byte{}, bytes.NewReader([]byte("encrypted"))) + _, err = cs.Open(0, []byte{}, []byte("encrypted")) Expect(err).To(MatchError("authentication failed")) }) }) @@ -230,10 +219,10 @@ var _ = Describe("Crypto setup", func() { Context("forward secure encryption", func() { It("is used after receiving forward secure packet", func() { doCHLO() - _, err := cs.Open(0, []byte{}, bytes.NewReader([]byte("forward secure encrypted"))) + _, err := cs.Open(0, []byte{}, []byte("forward secure encrypted")) Expect(err).ToNot(HaveOccurred()) - cs.Seal(0, buf, []byte{}, []byte("foobar")) - Expect(buf.Bytes()).To(Equal([]byte("forward secure encrypted"))) + d := cs.Seal(0, []byte{}, []byte("foobar")) + Expect(d).To(Equal([]byte("forward secure encrypted"))) }) }) }) diff --git a/session.go b/session.go index 0906b40d..a1eb3a70 100644 --- a/session.go +++ b/session.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io/ioutil" "net" "github.com/lucas-clemente/quic-go/frames" @@ -75,10 +76,12 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub s.CurrentRemoteAddr = addr } - r, err := s.cryptoSetup.Open(publicHeader.PacketNumber, publicHeaderBinary, r) + ciphertext, _ := ioutil.ReadAll(r) + plaintext, err := s.cryptoSetup.Open(publicHeader.PacketNumber, publicHeaderBinary, ciphertext) if err != nil { return err } + r = bytes.NewReader(plaintext) privateFlag, err := r.ReadByte() if err != nil { @@ -256,7 +259,8 @@ func (s *Session) SendFrame(frame frames.Frame) error { s.EntropySent.Add(packetNumber, entropyBit) s.EntropyHistory[packetNumber] = s.EntropySent - s.cryptoSetup.Seal(s.lastSentPacketNumber, &fullReply, fullReply.Bytes(), framesData.Bytes()) + ciphertext := s.cryptoSetup.Seal(s.lastSentPacketNumber, fullReply.Bytes(), framesData.Bytes()) + fullReply.Write(ciphertext) fmt.Printf("-> Sending packet %d (%d bytes) to %v\n", responsePublicHeader.PacketNumber, len(fullReply.Bytes()), s.CurrentRemoteAddr) _, err = s.Connection.WriteToUDP(fullReply.Bytes(), s.CurrentRemoteAddr)