forked from quic-go/quic-go
buffer writes to the crypto stream
mint performs a Write for every state change. This results in a lot of small packets getting sent when using an unbuffered connection. By buffering, we make sure that packets are filled up properly.
This commit is contained in:
@@ -24,6 +24,7 @@ type cryptoSetupTLS struct {
|
||||
aead crypto.AEAD
|
||||
|
||||
tls mintTLS
|
||||
conn *cryptoStreamConn
|
||||
handshakeEvent chan<- struct{}
|
||||
}
|
||||
|
||||
@@ -41,9 +42,11 @@ func NewCryptoSetupTLSServer(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := mint.Server(newCryptoStreamConn(cryptoStream), config)
|
||||
conn := newCryptoStreamConn(cryptoStream)
|
||||
tls := mint.Server(conn, config)
|
||||
return &cryptoSetupTLS{
|
||||
tls: conn,
|
||||
tls: tls,
|
||||
conn: conn,
|
||||
nullAEAD: nullAEAD,
|
||||
perspective: protocol.PerspectiveServer,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
@@ -63,9 +66,11 @@ func NewCryptoSetupTLSClient(
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conn := mint.Client(newCryptoStreamConn(cryptoStream), config)
|
||||
conn := newCryptoStreamConn(cryptoStream)
|
||||
tls := mint.Client(conn, config)
|
||||
return &cryptoSetupTLS{
|
||||
tls: conn,
|
||||
tls: tls,
|
||||
conn: conn,
|
||||
perspective: protocol.PerspectiveClient,
|
||||
nullAEAD: nullAEAD,
|
||||
keyDerivation: crypto.DeriveAESKeys,
|
||||
@@ -79,6 +84,9 @@ func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||
}
|
||||
state := h.tls.ConnectionState().HandshakeState
|
||||
if err := h.conn.Flush(); err != nil {
|
||||
return err
|
||||
}
|
||||
if state == mint.StateClientConnected || state == mint.StateServerConnected {
|
||||
break
|
||||
}
|
||||
|
||||
@@ -1,23 +1,43 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type cryptoStreamConn struct {
|
||||
io.ReadWriter
|
||||
buffer *bytes.Buffer
|
||||
stream io.ReadWriter
|
||||
}
|
||||
|
||||
var _ net.Conn = &cryptoStreamConn{}
|
||||
|
||||
func newCryptoStreamConn(stream io.ReadWriter) net.Conn {
|
||||
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
|
||||
return &cryptoStreamConn{
|
||||
ReadWriter: stream,
|
||||
stream: stream,
|
||||
buffer: &bytes.Buffer{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
|
||||
return c.stream.Read(b)
|
||||
}
|
||||
|
||||
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
|
||||
return c.buffer.Write(p)
|
||||
}
|
||||
|
||||
func (c *cryptoStreamConn) Flush() error {
|
||||
if c.buffer.Len() == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := c.stream.Write(c.buffer.Bytes())
|
||||
c.buffer.Reset()
|
||||
return err
|
||||
}
|
||||
|
||||
// Close is not implemented
|
||||
func (c *cryptoStreamConn) Close() error {
|
||||
return nil
|
||||
|
||||
41
internal/handshake/crypto_stream_conn_test.go
Normal file
41
internal/handshake/crypto_stream_conn_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/onsi/ginkgo"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Crypto Stream Conn", func() {
|
||||
var (
|
||||
stream *bytes.Buffer
|
||||
csc *cryptoStreamConn
|
||||
)
|
||||
|
||||
BeforeEach(func() {
|
||||
stream = &bytes.Buffer{}
|
||||
csc = newCryptoStreamConn(stream)
|
||||
})
|
||||
|
||||
It("buffers writes", func() {
|
||||
_, err := csc.Write([]byte("foo"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.Len()).To(BeZero())
|
||||
_, err = csc.Write([]byte("bar"))
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(stream.Len()).To(BeZero())
|
||||
|
||||
Expect(csc.Flush()).To(Succeed())
|
||||
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||
})
|
||||
|
||||
It("reads from the stream", func() {
|
||||
stream.Write([]byte("foobar"))
|
||||
b := make([]byte, 6)
|
||||
n, err := csc.Read(b)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(n).To(Equal(6))
|
||||
Expect(b).To(Equal([]byte("foobar")))
|
||||
})
|
||||
})
|
||||
Reference in New Issue
Block a user