forked from quic-go/quic-go
split SNI and ECH extensions in the ClientHello (#5107)
* create a new type for crypto stream used for Initial data This currently the exact same implementation as the other streams, thus no functional change is expected. * handshake: implement a function to find the SNI and the ECH extension * move the SNI parsing logic to the quic package * implement splitting logic * generalize cutting logic * introduce QUIC_GO_DISABLE_CLIENTHELLO_SCRAMBLING * improve testing
This commit is contained in:
181
crypto_stream.go
181
crypto_stream.go
@@ -1,14 +1,23 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/quic-go/quic-go/internal/protocol"
|
||||
"github.com/quic-go/quic-go/internal/qerr"
|
||||
"github.com/quic-go/quic-go/internal/wire"
|
||||
)
|
||||
|
||||
type cryptoStream struct {
|
||||
const disableClientHelloScramblingEnv = "QUIC_GO_DISABLE_CLIENTHELLO_SCRAMBLING"
|
||||
|
||||
// The baseCryptoStream is used by the cryptoStream and the initialCryptoStream.
|
||||
// This allows us to implement different logic for PopCryptoFrame for the two streams.
|
||||
type baseCryptoStream struct {
|
||||
queue frameSorter
|
||||
|
||||
highestOffset protocol.ByteCount
|
||||
@@ -19,10 +28,10 @@ type cryptoStream struct {
|
||||
}
|
||||
|
||||
func newCryptoStream() *cryptoStream {
|
||||
return &cryptoStream{queue: *newFrameSorter()}
|
||||
return &cryptoStream{baseCryptoStream{queue: *newFrameSorter()}}
|
||||
}
|
||||
|
||||
func (s *cryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||
func (s *baseCryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||
highestOffset := f.Offset + protocol.ByteCount(len(f.Data))
|
||||
if maxOffset := highestOffset; maxOffset > protocol.MaxCryptoStreamOffset {
|
||||
return &qerr.TransportError{
|
||||
@@ -47,12 +56,12 @@ func (s *cryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
|
||||
}
|
||||
|
||||
// GetCryptoData retrieves data that was received in CRYPTO frames
|
||||
func (s *cryptoStream) GetCryptoData() []byte {
|
||||
func (s *baseCryptoStream) GetCryptoData() []byte {
|
||||
_, data, _ := s.queue.Pop()
|
||||
return data
|
||||
}
|
||||
|
||||
func (s *cryptoStream) Finish() error {
|
||||
func (s *baseCryptoStream) Finish() error {
|
||||
if s.queue.HasMoreData() {
|
||||
return &qerr.TransportError{
|
||||
ErrorCode: qerr.ProtocolViolation,
|
||||
@@ -64,19 +73,19 @@ func (s *cryptoStream) Finish() error {
|
||||
}
|
||||
|
||||
// Writes writes data that should be sent out in CRYPTO frames
|
||||
func (s *cryptoStream) Write(p []byte) (int, error) {
|
||||
func (s *baseCryptoStream) Write(p []byte) (int, error) {
|
||||
s.writeBuf = append(s.writeBuf, p...)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *cryptoStream) HasData() bool {
|
||||
func (s *baseCryptoStream) HasData() bool {
|
||||
return len(s.writeBuf) > 0
|
||||
}
|
||||
|
||||
func (s *cryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||
func (s *baseCryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||
n := min(f.MaxDataLen(maxLen), protocol.ByteCount(len(s.writeBuf)))
|
||||
if n == 0 {
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
f.Data = s.writeBuf[:n]
|
||||
@@ -84,3 +93,157 @@ func (s *cryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFra
|
||||
s.writeOffset += n
|
||||
return f
|
||||
}
|
||||
|
||||
type cryptoStream struct {
|
||||
baseCryptoStream
|
||||
}
|
||||
|
||||
type clientHelloCut struct {
|
||||
start protocol.ByteCount
|
||||
end protocol.ByteCount
|
||||
}
|
||||
|
||||
type initialCryptoStream struct {
|
||||
baseCryptoStream
|
||||
|
||||
scramble bool
|
||||
end protocol.ByteCount
|
||||
cuts [2]clientHelloCut
|
||||
}
|
||||
|
||||
func newInitialCryptoStream(isClient bool) *initialCryptoStream {
|
||||
var scramble bool
|
||||
if isClient {
|
||||
disabled, err := strconv.ParseBool(os.Getenv(disableClientHelloScramblingEnv))
|
||||
scramble = err != nil || !disabled
|
||||
}
|
||||
s := &initialCryptoStream{
|
||||
baseCryptoStream: baseCryptoStream{queue: *newFrameSorter()},
|
||||
scramble: scramble,
|
||||
}
|
||||
for i := range len(s.cuts) {
|
||||
s.cuts[i].start = protocol.InvalidByteCount
|
||||
s.cuts[i].end = protocol.InvalidByteCount
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *initialCryptoStream) HasData() bool {
|
||||
// The ClientHello might be written in multiple parts.
|
||||
// In order to correctly split the ClientHello, we need the entire ClientHello has been queued.
|
||||
if s.scramble && s.writeOffset == 0 && s.cuts[0].start == protocol.InvalidByteCount {
|
||||
return false
|
||||
}
|
||||
return s.baseCryptoStream.HasData()
|
||||
}
|
||||
|
||||
func (s *initialCryptoStream) Write(p []byte) (int, error) {
|
||||
s.writeBuf = append(s.writeBuf, p...)
|
||||
if !s.scramble {
|
||||
return len(p), nil
|
||||
}
|
||||
if s.cuts[0].start == protocol.InvalidByteCount {
|
||||
sniPos, sniLen, echPos, err := findSNIAndECH(s.writeBuf)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return len(p), nil
|
||||
}
|
||||
if err != nil {
|
||||
return len(p), err
|
||||
}
|
||||
if sniPos == -1 && echPos == -1 {
|
||||
// Neither SNI nor ECH found.
|
||||
// There's nothing to scramble.
|
||||
s.scramble = false
|
||||
return len(p), nil
|
||||
}
|
||||
s.end = protocol.ByteCount(len(s.writeBuf))
|
||||
s.cuts[0].start = protocol.ByteCount(sniPos + sniLen/2) // right in the middle
|
||||
s.cuts[0].end = protocol.ByteCount(sniPos + sniLen)
|
||||
if echPos > 0 {
|
||||
// ECH extension found, cut the ECH extension type value (a uint16) in half
|
||||
start := protocol.ByteCount(echPos + 1)
|
||||
s.cuts[1].start = start
|
||||
// cut somewhere (16 bytes), most likely in the ECH extension value
|
||||
s.cuts[1].end = min(start+16, s.end)
|
||||
}
|
||||
slices.SortFunc(s.cuts[:], func(a, b clientHelloCut) int {
|
||||
if a.start == protocol.InvalidByteCount {
|
||||
return 1
|
||||
}
|
||||
if a.start > b.start {
|
||||
return 1
|
||||
}
|
||||
return -1
|
||||
})
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *initialCryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
|
||||
if !s.scramble {
|
||||
return s.baseCryptoStream.PopCryptoFrame(maxLen)
|
||||
}
|
||||
|
||||
// send out the skipped parts
|
||||
if s.writeOffset == s.end {
|
||||
var foundCuts bool
|
||||
var f *wire.CryptoFrame
|
||||
for i, c := range s.cuts {
|
||||
if c.start == protocol.InvalidByteCount {
|
||||
continue
|
||||
}
|
||||
foundCuts = true
|
||||
if f != nil {
|
||||
break
|
||||
}
|
||||
f = &wire.CryptoFrame{Offset: c.start}
|
||||
n := min(f.MaxDataLen(maxLen), c.end-c.start)
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
f.Data = s.writeBuf[c.start : c.start+n]
|
||||
s.cuts[i].start += n
|
||||
if s.cuts[i].start == c.end {
|
||||
s.cuts[i].start = protocol.InvalidByteCount
|
||||
s.cuts[i].end = protocol.InvalidByteCount
|
||||
foundCuts = false
|
||||
}
|
||||
}
|
||||
if !foundCuts {
|
||||
// no more cuts found, we're done sending out everything up until s.end
|
||||
s.writeBuf = s.writeBuf[s.end:]
|
||||
s.end = protocol.InvalidByteCount
|
||||
s.scramble = false
|
||||
}
|
||||
return f
|
||||
}
|
||||
|
||||
nextCut := clientHelloCut{start: protocol.InvalidByteCount, end: protocol.InvalidByteCount}
|
||||
for _, c := range s.cuts {
|
||||
if c.start == protocol.InvalidByteCount {
|
||||
continue
|
||||
}
|
||||
if c.start > s.writeOffset {
|
||||
nextCut = c
|
||||
break
|
||||
}
|
||||
}
|
||||
f := &wire.CryptoFrame{Offset: s.writeOffset}
|
||||
maxOffset := nextCut.start
|
||||
if maxOffset == protocol.InvalidByteCount {
|
||||
maxOffset = s.end
|
||||
}
|
||||
n := min(f.MaxDataLen(maxLen), maxOffset-s.writeOffset)
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
f.Data = s.writeBuf[s.writeOffset : s.writeOffset+n]
|
||||
// Don't reslice the writeBuf yet.
|
||||
// This is done once all parts have been sent out.
|
||||
s.writeOffset += n
|
||||
if s.writeOffset == nextCut.start {
|
||||
s.writeOffset = nextCut.end
|
||||
}
|
||||
|
||||
return f
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user