forked from quic-go/quic-go
buffer writes to the crypto stream
mint performs a Write for every state change. This results in a lot of small packets getting sent when using an unbuffered connection. By buffering, we make sure that packets are filled up properly.
This commit is contained in:
@@ -24,6 +24,7 @@ type cryptoSetupTLS struct {
|
|||||||
aead crypto.AEAD
|
aead crypto.AEAD
|
||||||
|
|
||||||
tls mintTLS
|
tls mintTLS
|
||||||
|
conn *cryptoStreamConn
|
||||||
handshakeEvent chan<- struct{}
|
handshakeEvent chan<- struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -41,9 +42,11 @@ func NewCryptoSetupTLSServer(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
conn := mint.Server(newCryptoStreamConn(cryptoStream), config)
|
conn := newCryptoStreamConn(cryptoStream)
|
||||||
|
tls := mint.Server(conn, config)
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
tls: conn,
|
tls: tls,
|
||||||
|
conn: conn,
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
perspective: protocol.PerspectiveServer,
|
perspective: protocol.PerspectiveServer,
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
@@ -63,9 +66,11 @@ func NewCryptoSetupTLSClient(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
conn := mint.Client(newCryptoStreamConn(cryptoStream), config)
|
conn := newCryptoStreamConn(cryptoStream)
|
||||||
|
tls := mint.Client(conn, config)
|
||||||
return &cryptoSetupTLS{
|
return &cryptoSetupTLS{
|
||||||
tls: conn,
|
tls: tls,
|
||||||
|
conn: conn,
|
||||||
perspective: protocol.PerspectiveClient,
|
perspective: protocol.PerspectiveClient,
|
||||||
nullAEAD: nullAEAD,
|
nullAEAD: nullAEAD,
|
||||||
keyDerivation: crypto.DeriveAESKeys,
|
keyDerivation: crypto.DeriveAESKeys,
|
||||||
@@ -79,6 +84,9 @@ func (h *cryptoSetupTLS) HandleCryptoStream() error {
|
|||||||
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
return fmt.Errorf("TLS handshake error: %s (Alert %d)", alert.String(), alert)
|
||||||
}
|
}
|
||||||
state := h.tls.ConnectionState().HandshakeState
|
state := h.tls.ConnectionState().HandshakeState
|
||||||
|
if err := h.conn.Flush(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
if state == mint.StateClientConnected || state == mint.StateServerConnected {
|
if state == mint.StateClientConnected || state == mint.StateServerConnected {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +1,43 @@
|
|||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type cryptoStreamConn struct {
|
type cryptoStreamConn struct {
|
||||||
io.ReadWriter
|
buffer *bytes.Buffer
|
||||||
|
stream io.ReadWriter
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ net.Conn = &cryptoStreamConn{}
|
var _ net.Conn = &cryptoStreamConn{}
|
||||||
|
|
||||||
func newCryptoStreamConn(stream io.ReadWriter) net.Conn {
|
func newCryptoStreamConn(stream io.ReadWriter) *cryptoStreamConn {
|
||||||
return &cryptoStreamConn{
|
return &cryptoStreamConn{
|
||||||
ReadWriter: stream,
|
stream: stream,
|
||||||
|
buffer: &bytes.Buffer{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *cryptoStreamConn) Read(b []byte) (int, error) {
|
||||||
|
return c.stream.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cryptoStreamConn) Write(p []byte) (int, error) {
|
||||||
|
return c.buffer.Write(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *cryptoStreamConn) Flush() error {
|
||||||
|
if c.buffer.Len() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err := c.stream.Write(c.buffer.Bytes())
|
||||||
|
c.buffer.Reset()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// Close is not implemented
|
// Close is not implemented
|
||||||
func (c *cryptoStreamConn) Close() error {
|
func (c *cryptoStreamConn) Close() error {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
41
internal/handshake/crypto_stream_conn_test.go
Normal file
41
internal/handshake/crypto_stream_conn_test.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package handshake
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
. "github.com/onsi/ginkgo"
|
||||||
|
. "github.com/onsi/gomega"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ = Describe("Crypto Stream Conn", func() {
|
||||||
|
var (
|
||||||
|
stream *bytes.Buffer
|
||||||
|
csc *cryptoStreamConn
|
||||||
|
)
|
||||||
|
|
||||||
|
BeforeEach(func() {
|
||||||
|
stream = &bytes.Buffer{}
|
||||||
|
csc = newCryptoStreamConn(stream)
|
||||||
|
})
|
||||||
|
|
||||||
|
It("buffers writes", func() {
|
||||||
|
_, err := csc.Write([]byte("foo"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(stream.Len()).To(BeZero())
|
||||||
|
_, err = csc.Write([]byte("bar"))
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(stream.Len()).To(BeZero())
|
||||||
|
|
||||||
|
Expect(csc.Flush()).To(Succeed())
|
||||||
|
Expect(stream.Bytes()).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
|
||||||
|
It("reads from the stream", func() {
|
||||||
|
stream.Write([]byte("foobar"))
|
||||||
|
b := make([]byte, 6)
|
||||||
|
n, err := csc.Read(b)
|
||||||
|
Expect(err).ToNot(HaveOccurred())
|
||||||
|
Expect(n).To(Equal(6))
|
||||||
|
Expect(b).To(Equal([]byte("foobar")))
|
||||||
|
})
|
||||||
|
})
|
||||||
Reference in New Issue
Block a user