forked from quic-go/quic-go
expose the connection tracing ID on the stream context (#4414)
This is especially interesting for HTTP servers: They can now learn which connection a request was received on.
This commit is contained in:
@@ -272,8 +272,8 @@ var newConnection = func(
|
||||
s.queueControlFrame,
|
||||
connIDGenerator,
|
||||
)
|
||||
s.preSetup()
|
||||
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
||||
s.preSetup()
|
||||
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
||||
0,
|
||||
getMaxPacketSize(s.conn.RemoteAddr()),
|
||||
@@ -381,8 +381,8 @@ var newClientConnection = func(
|
||||
s.queueControlFrame,
|
||||
connIDGenerator,
|
||||
)
|
||||
s.preSetup()
|
||||
s.ctx, s.ctxCancel = context.WithCancelCause(context.WithValue(context.Background(), ConnectionTracingKey, tracingID))
|
||||
s.preSetup()
|
||||
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
|
||||
initialPacketNumber,
|
||||
getMaxPacketSize(s.conn.RemoteAddr()),
|
||||
@@ -471,6 +471,7 @@ func (s *connection) preSetup() {
|
||||
)
|
||||
s.earlyConnReadyChan = make(chan struct{})
|
||||
s.streamsMap = newStreamsMap(
|
||||
s.ctx,
|
||||
s,
|
||||
s.newFlowController,
|
||||
uint64(s.config.MaxIncomingStreams),
|
||||
|
||||
@@ -536,6 +536,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
|
||||
It("sets conn context", func() {
|
||||
type ctxKey int
|
||||
var tracingID quic.ConnectionTracingID
|
||||
server.ConnContext = func(ctx context.Context, c quic.Connection) context.Context {
|
||||
serv, ok := ctx.Value(http3.ServerContextKey).(*http3.Server)
|
||||
Expect(ok).To(BeTrue())
|
||||
@@ -543,6 +544,7 @@ var _ = Describe("HTTP tests", func() {
|
||||
|
||||
ctx = context.WithValue(ctx, ctxKey(0), "Hello")
|
||||
ctx = context.WithValue(ctx, ctxKey(1), c)
|
||||
tracingID = c.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
|
||||
return ctx
|
||||
}
|
||||
mux.HandleFunc("/conn-context", func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -558,6 +560,10 @@ var _ = Describe("HTTP tests", func() {
|
||||
serv, ok := r.Context().Value(http3.ServerContextKey).(*http3.Server)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(serv).To(Equal(server))
|
||||
|
||||
id, ok := r.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
|
||||
Expect(ok).To(BeTrue())
|
||||
Expect(id).To(Equal(tracingID))
|
||||
})
|
||||
|
||||
resp, err := client.Get(fmt.Sprintf("https://localhost:%d/conn-context", port))
|
||||
|
||||
@@ -60,6 +60,7 @@ var (
|
||||
)
|
||||
|
||||
func newSendStream(
|
||||
ctx context.Context,
|
||||
streamID protocol.StreamID,
|
||||
sender streamSender,
|
||||
flowController flowcontrol.StreamFlowController,
|
||||
@@ -71,7 +72,7 @@ func newSendStream(
|
||||
writeChan: make(chan struct{}, 1),
|
||||
writeOnce: make(chan struct{}, 1), // cap: 1, to protect against concurrent use of Write
|
||||
}
|
||||
s.ctx, s.ctxCancel = context.WithCancelCause(context.Background())
|
||||
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
|
||||
return s
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ var _ = Describe("Send Stream", func() {
|
||||
BeforeEach(func() {
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
|
||||
str = newSendStream(streamID, mockSender, mockFC)
|
||||
str = newSendStream(context.Background(), streamID, mockSender, mockFC)
|
||||
|
||||
timeout := scaleDuration(250 * time.Millisecond)
|
||||
strWithTimeout = gbytes.TimeoutWriter(str, timeout)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
@@ -85,7 +86,9 @@ type stream struct {
|
||||
var _ Stream = &stream{}
|
||||
|
||||
// newStream creates a new Stream
|
||||
func newStream(streamID protocol.StreamID,
|
||||
func newStream(
|
||||
ctx context.Context,
|
||||
streamID protocol.StreamID,
|
||||
sender streamSender,
|
||||
flowController flowcontrol.StreamFlowController,
|
||||
) *stream {
|
||||
@@ -99,7 +102,7 @@ func newStream(streamID protocol.StreamID,
|
||||
s.completedMutex.Unlock()
|
||||
},
|
||||
}
|
||||
s.sendStream = *newSendStream(streamID, senderForSendStream, flowController)
|
||||
s.sendStream = *newSendStream(ctx, streamID, senderForSendStream, flowController)
|
||||
senderForReceiveStream := &uniStreamSender{
|
||||
streamSender: sender,
|
||||
onStreamCompletedImpl: func() {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package quic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
@@ -41,7 +42,7 @@ var _ = Describe("Stream", func() {
|
||||
BeforeEach(func() {
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
mockFC = mocks.NewMockStreamFlowController(mockCtrl)
|
||||
str = newStream(streamID, mockSender, mockFC)
|
||||
str = newStream(context.Background(), streamID, mockSender, mockFC)
|
||||
|
||||
timeout := scaleDuration(250 * time.Millisecond)
|
||||
strWithTimeout = struct {
|
||||
|
||||
@@ -45,6 +45,7 @@ func (streamOpenErr) Timeout() bool { return false }
|
||||
var errTooManyOpenStreams = errors.New("too many open streams")
|
||||
|
||||
type streamsMap struct {
|
||||
ctx context.Context // not used for cancellations, but carries the values associated with the connection
|
||||
perspective protocol.Perspective
|
||||
|
||||
maxIncomingBidiStreams uint64
|
||||
@@ -64,6 +65,7 @@ type streamsMap struct {
|
||||
var _ streamManager = &streamsMap{}
|
||||
|
||||
func newStreamsMap(
|
||||
ctx context.Context,
|
||||
sender streamSender,
|
||||
newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
|
||||
maxIncomingBidiStreams uint64,
|
||||
@@ -71,6 +73,7 @@ func newStreamsMap(
|
||||
perspective protocol.Perspective,
|
||||
) streamManager {
|
||||
m := &streamsMap{
|
||||
ctx: ctx,
|
||||
perspective: perspective,
|
||||
newFlowController: newFlowController,
|
||||
maxIncomingBidiStreams: maxIncomingBidiStreams,
|
||||
@@ -86,7 +89,7 @@ func (m *streamsMap) initMaps() {
|
||||
protocol.StreamTypeBidi,
|
||||
func(num protocol.StreamNum) streamI {
|
||||
id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
|
||||
return newStream(id, m.sender, m.newFlowController(id))
|
||||
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
|
||||
},
|
||||
m.sender.queueControlFrame,
|
||||
)
|
||||
@@ -94,7 +97,7 @@ func (m *streamsMap) initMaps() {
|
||||
protocol.StreamTypeBidi,
|
||||
func(num protocol.StreamNum) streamI {
|
||||
id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
|
||||
return newStream(id, m.sender, m.newFlowController(id))
|
||||
return newStream(m.ctx, id, m.sender, m.newFlowController(id))
|
||||
},
|
||||
m.maxIncomingBidiStreams,
|
||||
m.sender.queueControlFrame,
|
||||
@@ -103,7 +106,7 @@ func (m *streamsMap) initMaps() {
|
||||
protocol.StreamTypeUni,
|
||||
func(num protocol.StreamNum) sendStreamI {
|
||||
id := num.StreamID(protocol.StreamTypeUni, m.perspective)
|
||||
return newSendStream(id, m.sender, m.newFlowController(id))
|
||||
return newSendStream(m.ctx, id, m.sender, m.newFlowController(id))
|
||||
},
|
||||
m.sender.queueControlFrame,
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ var _ = Describe("Streams Map", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
mockSender = NewMockStreamSender(mockCtrl)
|
||||
m = newStreamsMap(mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap)
|
||||
m = newStreamsMap(context.Background(), mockSender, newFlowController, MaxBidiStreamNum, MaxUniStreamNum, perspective).(*streamsMap)
|
||||
})
|
||||
|
||||
Context("opening", func() {
|
||||
|
||||
Reference in New Issue
Block a user