forked from quic-go/quic-go
simplify frame type switch, introduce temporary stream 1 offset variable
This commit is contained in:
184
session.go
184
session.go
@@ -33,6 +33,8 @@ type Session struct {
|
||||
Streams map[protocol.StreamID]*Stream
|
||||
|
||||
streamCallback StreamCallback
|
||||
|
||||
s1offset uint64
|
||||
}
|
||||
|
||||
// NewSession makes a new session
|
||||
@@ -51,8 +53,6 @@ func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protoc
|
||||
|
||||
// HandlePacket handles a packet
|
||||
func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, publicHeader *PublicHeader, r *bytes.Reader) error {
|
||||
// TODO: Only do this after authenticating
|
||||
|
||||
if s.lastObservedPacketNumber > 0 { // the first packet doesn't neccessarily need to have packetNumber 1
|
||||
if publicHeader.PacketNumber < s.lastObservedPacketNumber || publicHeader.PacketNumber > s.lastObservedPacketNumber+1 {
|
||||
return errors.New("Out of order packet")
|
||||
@@ -63,6 +63,7 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub
|
||||
}
|
||||
s.lastObservedPacketNumber = publicHeader.PacketNumber
|
||||
|
||||
// TODO: Only do this after authenticating
|
||||
if addr != s.CurrentRemoteAddr {
|
||||
s.CurrentRemoteAddr = addr
|
||||
}
|
||||
@@ -87,89 +88,124 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub
|
||||
|
||||
// read all frames in the packet
|
||||
for r.Len() > 0 {
|
||||
typeByte, err := r.ReadByte()
|
||||
if err != nil {
|
||||
fmt.Println("No more frames in this packet.")
|
||||
break
|
||||
}
|
||||
typeByte, _ := r.ReadByte()
|
||||
r.UnreadByte()
|
||||
|
||||
frameCounter++
|
||||
fmt.Printf("Reading frame %d\n", frameCounter)
|
||||
|
||||
if typeByte&0x80 == 0x80 { // STREAM
|
||||
fmt.Println("Detected STREAM")
|
||||
frame, err := frames.ParseStreamFrame(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID)
|
||||
|
||||
if frame.StreamID == 0 {
|
||||
return errors.New("Session: 0 is not a valid Stream ID")
|
||||
}
|
||||
|
||||
if frame.StreamID == 1 {
|
||||
reply, err := s.cryptoSetup.HandleCryptoMessage(frame.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply != nil {
|
||||
s.SendFrames([]frames.Frame{&frames.StreamFrame{StreamID: 1, Data: reply}})
|
||||
}
|
||||
} else {
|
||||
stream, ok := s.Streams[frame.StreamID]
|
||||
if !ok {
|
||||
stream = NewStream(frame.StreamID)
|
||||
s.Streams[frame.StreamID] = stream
|
||||
}
|
||||
err := stream.AddStreamFrame(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
replyFrames := s.streamCallback(stream)
|
||||
if replyFrames != nil {
|
||||
s.SendFrames(replyFrames)
|
||||
}
|
||||
}
|
||||
continue
|
||||
} else if typeByte&0xC0 == 0x40 { // ACK
|
||||
fmt.Println("Detected ACK")
|
||||
frame, err := frames.ParseAckFrame(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("%#v\n", frame)
|
||||
|
||||
continue
|
||||
} else if typeByte&0xE0 == 0x20 { // CONGESTION_FEEDBACK
|
||||
return errors.New("Detected CONGESTION_FEEDBACK")
|
||||
} else if typeByte&0x06 == 0x06 { // STOP_WAITING
|
||||
fmt.Println("Detected STOP_WAITING")
|
||||
_, err := frames.ParseStopWaitingFrame(r, publicHeader.PacketNumberLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// ToDo: react to receiving this frame
|
||||
} else if typeByte&0x02 == 0x02 { // CONNECTION_CLOSE
|
||||
fmt.Println("Detected CONNECTION_CLOSE")
|
||||
frame, err := frames.ParseConnectionCloseFrame(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("%#v\n", frame)
|
||||
} else if typeByte == 0 {
|
||||
// PAD
|
||||
return nil
|
||||
err = nil
|
||||
if typeByte&0x80 == 0x80 {
|
||||
err = s.handleStreamFrame(r)
|
||||
} else if typeByte == 0x40 {
|
||||
err = s.handleAckFrame(r)
|
||||
} else if typeByte&0xE0 == 0x20 {
|
||||
err = errors.New("unimplemented: CONGESTION_FEEDBACK")
|
||||
} else {
|
||||
return errors.New("Session: invalid Frame Type Field")
|
||||
switch typeByte {
|
||||
case 0x0: // PAD
|
||||
return nil
|
||||
case 0x01:
|
||||
err = errors.New("unimplemented: RST_STREAM")
|
||||
case 0x02:
|
||||
err = s.handleConnectionCloseFrame(r)
|
||||
case 0x03:
|
||||
err = errors.New("unimplemented: GOAWAY")
|
||||
case 0x04:
|
||||
err = errors.New("unimplemented: WINDOW_UPDATE")
|
||||
case 0x05:
|
||||
err = errors.New("unimplemented: BLOCKED")
|
||||
case 0x06:
|
||||
err = s.handleStopWaitingFrame(r, publicHeader)
|
||||
case 0x07:
|
||||
// PING, do nothing
|
||||
r.ReadByte()
|
||||
default:
|
||||
err = fmt.Errorf("unknown frame type: %x", typeByte)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) handleStreamFrame(r *bytes.Reader) error {
|
||||
fmt.Println("Detected STREAM")
|
||||
frame, err := frames.ParseStreamFrame(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID)
|
||||
|
||||
if frame.StreamID == 0 {
|
||||
return errors.New("Session: 0 is not a valid Stream ID")
|
||||
}
|
||||
|
||||
if frame.StreamID == 1 {
|
||||
reply, err := s.cryptoSetup.HandleCryptoMessage(frame.Data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if reply != nil {
|
||||
if len(reply) > 1000 {
|
||||
s.SendFrames([]frames.Frame{&frames.StreamFrame{StreamID: 1, Offset: s.s1offset, Data: reply[:1000]}})
|
||||
s.s1offset += 1000
|
||||
s.SendFrames([]frames.Frame{&frames.StreamFrame{StreamID: 1, Offset: s.s1offset, Data: reply[1000:]}})
|
||||
s.s1offset += uint64(len(reply[1000:]))
|
||||
} else {
|
||||
s.SendFrames([]frames.Frame{&frames.StreamFrame{StreamID: 1, Offset: s.s1offset, Data: reply}})
|
||||
s.s1offset += uint64(len(reply))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
stream, ok := s.Streams[frame.StreamID]
|
||||
if !ok {
|
||||
stream = NewStream(frame.StreamID)
|
||||
s.Streams[frame.StreamID] = stream
|
||||
}
|
||||
err := stream.AddStreamFrame(frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
replyFrames := s.streamCallback(stream)
|
||||
if replyFrames != nil {
|
||||
s.SendFrames(replyFrames)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) handleAckFrame(r *bytes.Reader) error {
|
||||
fmt.Println("Detected ACK")
|
||||
_, err := frames.ParseAckFrame(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) handleConnectionCloseFrame(r *bytes.Reader) error {
|
||||
fmt.Println("Detected CONNECTION_CLOSE")
|
||||
frame, err := frames.ParseConnectionCloseFrame(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("%#v\n", frame)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Session) handleStopWaitingFrame(r *bytes.Reader, publicHeader *PublicHeader) error {
|
||||
fmt.Println("Detected STOP_WAITING")
|
||||
_, err := frames.ParseStopWaitingFrame(r, publicHeader.PacketNumberLen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendFrames sends a number of frames to the client
|
||||
func (s *Session) SendFrames(frames []frames.Frame) error {
|
||||
var framesData bytes.Buffer
|
||||
|
||||
Reference in New Issue
Block a user