diff --git a/session.go b/session.go index ba23368d..bd58f489 100644 --- a/session.go +++ b/session.go @@ -33,6 +33,8 @@ type Session struct { Streams map[protocol.StreamID]*Stream streamCallback StreamCallback + + s1offset uint64 } // NewSession makes a new session @@ -51,8 +53,6 @@ func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protoc // HandlePacket handles a packet func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, publicHeader *PublicHeader, r *bytes.Reader) error { - // TODO: Only do this after authenticating - if s.lastObservedPacketNumber > 0 { // the first packet doesn't neccessarily need to have packetNumber 1 if publicHeader.PacketNumber < s.lastObservedPacketNumber || publicHeader.PacketNumber > s.lastObservedPacketNumber+1 { return errors.New("Out of order packet") @@ -63,6 +63,7 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub } s.lastObservedPacketNumber = publicHeader.PacketNumber + // TODO: Only do this after authenticating if addr != s.CurrentRemoteAddr { s.CurrentRemoteAddr = addr } @@ -87,89 +88,124 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub // read all frames in the packet for r.Len() > 0 { - typeByte, err := r.ReadByte() - if err != nil { - fmt.Println("No more frames in this packet.") - break - } + typeByte, _ := r.ReadByte() r.UnreadByte() frameCounter++ fmt.Printf("Reading frame %d\n", frameCounter) - if typeByte&0x80 == 0x80 { // STREAM - fmt.Println("Detected STREAM") - frame, err := frames.ParseStreamFrame(r) - if err != nil { - return err - } - fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID) - - if frame.StreamID == 0 { - return errors.New("Session: 0 is not a valid Stream ID") - } - - if frame.StreamID == 1 { - reply, err := s.cryptoSetup.HandleCryptoMessage(frame.Data) - if err != nil { - return err - } - if reply != nil { - s.SendFrames([]frames.Frame{&frames.StreamFrame{StreamID: 1, Data: reply}}) - } - } else { - stream, ok := s.Streams[frame.StreamID] - if !ok { - stream = NewStream(frame.StreamID) - s.Streams[frame.StreamID] = stream - } - err := stream.AddStreamFrame(frame) - if err != nil { - return err - } - - replyFrames := s.streamCallback(stream) - if replyFrames != nil { - s.SendFrames(replyFrames) - } - } - continue - } else if typeByte&0xC0 == 0x40 { // ACK - fmt.Println("Detected ACK") - frame, err := frames.ParseAckFrame(r) - if err != nil { - return err - } - - fmt.Printf("%#v\n", frame) - - continue - } else if typeByte&0xE0 == 0x20 { // CONGESTION_FEEDBACK - return errors.New("Detected CONGESTION_FEEDBACK") - } else if typeByte&0x06 == 0x06 { // STOP_WAITING - fmt.Println("Detected STOP_WAITING") - _, err := frames.ParseStopWaitingFrame(r, publicHeader.PacketNumberLen) - if err != nil { - return err - } - // ToDo: react to receiving this frame - } else if typeByte&0x02 == 0x02 { // CONNECTION_CLOSE - fmt.Println("Detected CONNECTION_CLOSE") - frame, err := frames.ParseConnectionCloseFrame(r) - if err != nil { - return err - } - fmt.Printf("%#v\n", frame) - } else if typeByte == 0 { - // PAD - return nil + err = nil + if typeByte&0x80 == 0x80 { + err = s.handleStreamFrame(r) + } else if typeByte == 0x40 { + err = s.handleAckFrame(r) + } else if typeByte&0xE0 == 0x20 { + err = errors.New("unimplemented: CONGESTION_FEEDBACK") } else { - return errors.New("Session: invalid Frame Type Field") + switch typeByte { + case 0x0: // PAD + return nil + case 0x01: + err = errors.New("unimplemented: RST_STREAM") + case 0x02: + err = s.handleConnectionCloseFrame(r) + case 0x03: + err = errors.New("unimplemented: GOAWAY") + case 0x04: + err = errors.New("unimplemented: WINDOW_UPDATE") + case 0x05: + err = errors.New("unimplemented: BLOCKED") + case 0x06: + err = s.handleStopWaitingFrame(r, publicHeader) + case 0x07: + // PING, do nothing + r.ReadByte() + default: + err = fmt.Errorf("unknown frame type: %x", typeByte) + } + if err != nil { + return err + } } } return nil } +func (s *Session) handleStreamFrame(r *bytes.Reader) error { + fmt.Println("Detected STREAM") + frame, err := frames.ParseStreamFrame(r) + if err != nil { + return err + } + fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID) + + if frame.StreamID == 0 { + return errors.New("Session: 0 is not a valid Stream ID") + } + + if frame.StreamID == 1 { + reply, err := s.cryptoSetup.HandleCryptoMessage(frame.Data) + if err != nil { + return err + } + if reply != nil { + if len(reply) > 1000 { + s.SendFrames([]frames.Frame{&frames.StreamFrame{StreamID: 1, Offset: s.s1offset, Data: reply[:1000]}}) + s.s1offset += 1000 + s.SendFrames([]frames.Frame{&frames.StreamFrame{StreamID: 1, Offset: s.s1offset, Data: reply[1000:]}}) + s.s1offset += uint64(len(reply[1000:])) + } else { + s.SendFrames([]frames.Frame{&frames.StreamFrame{StreamID: 1, Offset: s.s1offset, Data: reply}}) + s.s1offset += uint64(len(reply)) + } + } + } else { + stream, ok := s.Streams[frame.StreamID] + if !ok { + stream = NewStream(frame.StreamID) + s.Streams[frame.StreamID] = stream + } + err := stream.AddStreamFrame(frame) + if err != nil { + return err + } + + replyFrames := s.streamCallback(stream) + if replyFrames != nil { + s.SendFrames(replyFrames) + } + } + return nil +} + +func (s *Session) handleAckFrame(r *bytes.Reader) error { + fmt.Println("Detected ACK") + _, err := frames.ParseAckFrame(r) + if err != nil { + return err + } + return nil +} + +func (s *Session) handleConnectionCloseFrame(r *bytes.Reader) error { + fmt.Println("Detected CONNECTION_CLOSE") + frame, err := frames.ParseConnectionCloseFrame(r) + if err != nil { + return err + } + fmt.Printf("%#v\n", frame) + return nil +} + +func (s *Session) handleStopWaitingFrame(r *bytes.Reader, publicHeader *PublicHeader) error { + fmt.Println("Detected STOP_WAITING") + _, err := frames.ParseStopWaitingFrame(r, publicHeader.PacketNumberLen) + if err != nil { + return err + } + return nil +} + // SendFrames sends a number of frames to the client func (s *Session) SendFrames(frames []frames.Frame) error { var framesData bytes.Buffer