forked from quic-go/quic-go
replace CachingReader with io.TeeReader
This commit is contained in:
@@ -3,6 +3,7 @@ package handshake
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -70,19 +71,18 @@ func NewCryptoSetup(
|
|||||||
// HandleCryptoStream reads and writes messages on the crypto stream
|
// HandleCryptoStream reads and writes messages on the crypto stream
|
||||||
func (h *CryptoSetup) HandleCryptoStream() error {
|
func (h *CryptoSetup) HandleCryptoStream() error {
|
||||||
for {
|
for {
|
||||||
cachingReader := utils.NewCachingReader(h.cryptoStream)
|
var chloData bytes.Buffer
|
||||||
messageTag, cryptoData, err := ParseHandshakeMessage(cachingReader)
|
messageTag, cryptoData, err := ParseHandshakeMessage(io.TeeReader(h.cryptoStream, &chloData))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return qerr.HandshakeFailed
|
return qerr.HandshakeFailed
|
||||||
}
|
}
|
||||||
if messageTag != TagCHLO {
|
if messageTag != TagCHLO {
|
||||||
return qerr.InvalidCryptoMessageType
|
return qerr.InvalidCryptoMessageType
|
||||||
}
|
}
|
||||||
chloData := cachingReader.Get()
|
|
||||||
|
|
||||||
utils.Debugf("Got CHLO:\n%s", printHandshakeMessage(cryptoData))
|
utils.Debugf("Got CHLO:\n%s", printHandshakeMessage(cryptoData))
|
||||||
|
|
||||||
done, err := h.handleMessage(chloData, cryptoData)
|
done, err := h.handleMessage(chloData.Bytes(), cryptoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
import "bytes"
|
|
||||||
|
|
||||||
// CachingReader wraps a reader and saves all data it reads
|
|
||||||
type CachingReader struct {
|
|
||||||
buf bytes.Buffer
|
|
||||||
r ReadStream
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCachingReader returns a new CachingReader
|
|
||||||
func NewCachingReader(r ReadStream) *CachingReader {
|
|
||||||
return &CachingReader{r: r}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read implements io.Reader
|
|
||||||
func (r *CachingReader) Read(p []byte) (int, error) {
|
|
||||||
n, err := r.r.Read(p)
|
|
||||||
r.buf.Write(p[:n])
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReadByte implements io.ByteReader
|
|
||||||
func (r *CachingReader) ReadByte() (byte, error) {
|
|
||||||
b, err := r.r.ReadByte()
|
|
||||||
if err == nil {
|
|
||||||
r.buf.WriteByte(b)
|
|
||||||
}
|
|
||||||
return b, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the data cached
|
|
||||||
func (r *CachingReader) Get() []byte {
|
|
||||||
return r.buf.Bytes()
|
|
||||||
}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
|
|
||||||
. "github.com/onsi/ginkgo"
|
|
||||||
. "github.com/onsi/gomega"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ = Describe("Caching reader", func() {
|
|
||||||
It("caches Read()", func() {
|
|
||||||
r := bytes.NewReader([]byte("foobar"))
|
|
||||||
cr := NewCachingReader(r)
|
|
||||||
p := make([]byte, 3)
|
|
||||||
n, err := cr.Read(p)
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(n).To(Equal(3))
|
|
||||||
Expect(p).To(Equal([]byte("foo")))
|
|
||||||
Expect(cr.Get()).To(Equal([]byte("foo")))
|
|
||||||
})
|
|
||||||
|
|
||||||
It("caches ReadByte()", func() {
|
|
||||||
r := bytes.NewReader([]byte("foobar"))
|
|
||||||
cr := NewCachingReader(r)
|
|
||||||
b, err := cr.ReadByte()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(b).To(Equal(byte('f')))
|
|
||||||
b, err = cr.ReadByte()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(b).To(Equal(byte('o')))
|
|
||||||
b, err = cr.ReadByte()
|
|
||||||
Expect(err).ToNot(HaveOccurred())
|
|
||||||
Expect(b).To(Equal(byte('o')))
|
|
||||||
Expect(cr.Get()).To(Equal([]byte("foo")))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
Reference in New Issue
Block a user