diff --git a/internal/handshake/crypto_setup_tls.go b/internal/handshake/crypto_setup_tls.go index 3ca34821c..2e5dd0259 100644 --- a/internal/handshake/crypto_setup_tls.go +++ b/internal/handshake/crypto_setup_tls.go @@ -24,6 +24,7 @@ type cryptoSetupTLS struct { aead crypto.AEAD tls mintTLS + conn *cryptoStreamConn handshakeEvent chan<- struct{} } @@ -41,9 +42,11 @@ func NewCryptoSetupTLSServer( if err != nil { return nil, err } - conn := mint.Server(newCryptoStreamConn(cryptoStream), config) + conn := newCryptoStreamConn(cryptoStream) + tls := mint.Server(conn, config) return &cryptoSetupTLS{ - tls: conn, + tls: tls, + conn: conn, nullAEAD: nullAEAD, perspective: protocol.PerspectiveServer, keyDerivation: crypto.DeriveAESKeys, @@ -63,9 +66,11 @@ func NewCryptoSetupTLSClient( if err != nil { return nil, err } - conn := mint.Client(newCryptoStreamConn(cryptoStream), config) + conn := newCryptoStreamConn(cryptoStream) + tls := mint.Client(conn, config) return &cryptoSetupTLS{ - tls: conn, + tls: tls, + conn: conn, perspective: protocol.PerspectiveClient, nullAEAD: nullAEAD, keyDerivation: crypto.DeriveAESKeys, @@ -79,6 +84,9 @@ func (h *cryptoSetupTLS) HandleCryptoStream() error { return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert) } state := h.tls.ConnectionState().HandshakeState + if err := h.conn.Flush(); err != nil { + return err + } if state == mint.StateClientConnected || state == mint.StateServerConnected { break } diff --git a/internal/handshake/crypto_stream_conn.go b/internal/handshake/crypto_stream_conn.go index a787826da..a031f90c0 100644 --- a/internal/handshake/crypto_stream_conn.go +++ b/internal/handshake/crypto_stream_conn.go @@ -1,23 +1,43 @@ package handshake import ( + "bytes" "io" "net" "time" ) type cryptoStreamConn struct { - io.ReadWriter + buffer *bytes.Buffer + stream io.ReadWriter } var _ net.Conn = &cryptoStreamConn{} -func newCryptoStreamConn(stream io.ReadWriter) net.Conn { +func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn { return &cryptoStreamConn{ - ReadWriter: stream, + stream: stream, + buffer: &bytes.Buffer{}, } } +func (c *cryptoStreamConn) Read(b []byte) (int, error) { + return c.stream.Read(b) +} + +func (c *cryptoStreamConn) Write(p []byte) (int, error) { + return c.buffer.Write(p) +} + +func (c *cryptoStreamConn) Flush() error { + if c.buffer.Len() == 0 { + return nil + } + _, err := c.stream.Write(c.buffer.Bytes()) + c.buffer.Reset() + return err +} + // Close is not implemented func (c *cryptoStreamConn) Close() error { return nil diff --git a/internal/handshake/crypto_stream_conn_test.go b/internal/handshake/crypto_stream_conn_test.go new file mode 100644 index 000000000..64bb6cbd2 --- /dev/null +++ b/internal/handshake/crypto_stream_conn_test.go @@ -0,0 +1,41 @@ +package handshake + +import ( + "bytes" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Crypto Stream Conn", func() { + var ( + stream *bytes.Buffer + csc *cryptoStreamConn + ) + + BeforeEach(func() { + stream = &bytes.Buffer{} + csc = newCryptoStreamConn(stream) + }) + + It("buffers writes", func() { + _, err := csc.Write([]byte("foo")) + Expect(err).ToNot(HaveOccurred()) + Expect(stream.Len()).To(BeZero()) + _, err = csc.Write([]byte("bar")) + Expect(err).ToNot(HaveOccurred()) + Expect(stream.Len()).To(BeZero()) + + Expect(csc.Flush()).To(Succeed()) + Expect(stream.Bytes()).To(Equal([]byte("foobar"))) + }) + + It("reads from the stream", func() { + stream.Write([]byte("foobar")) + b := make([]byte, 6) + n, err := csc.Read(b) + Expect(err).ToNot(HaveOccurred()) + Expect(n).To(Equal(6)) + Expect(b).To(Equal([]byte("foobar"))) + }) +})