diff --git a/frame.go b/frame.go index 4579b21ab..3ba31e204 100644 --- a/frame.go +++ b/frame.go @@ -21,13 +21,9 @@ type StreamFrame struct { } // ParseStreamFrame reads a stream frame. The type byte must not have been read yet. -func ParseStreamFrame(r *bytes.Reader) (*StreamFrame, error) { +func ParseStreamFrame(r *bytes.Reader, typeByte byte) (*StreamFrame, error) { frame := &StreamFrame{} - typeByte, err := r.ReadByte() - if err != nil { - return nil, err - } frame.FinBit = typeByte&0x40 > 0 dataLenPresent := typeByte&0x20 > 0 offsetLen := typeByte & 0x1C >> 2 diff --git a/frame_test.go b/frame_test.go index 61ccd7f20..9a50abf2f 100644 --- a/frame_test.go +++ b/frame_test.go @@ -11,8 +11,8 @@ var _ = Describe("Frame", func() { Context("stream frames", func() { Context("when parsing", func() { It("accepts sample frame", func() { - b := bytes.NewReader([]byte{0xa0, 0x1, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'}) - frame, err := ParseStreamFrame(b) + b := bytes.NewReader([]byte{0x1, 0x06, 0x00, 'f', 'o', 'o', 'b', 'a', 'r'}) + frame, err := ParseStreamFrame(b, 0xa0) Expect(err).ToNot(HaveOccurred()) Expect(frame.FinBit).To(BeFalse()) Expect(frame.StreamID).To(Equal(uint32(1))) @@ -21,8 +21,8 @@ var _ = Describe("Frame", func() { }) It("accepts frame without datalength", func() { - b := bytes.NewReader([]byte{0x80, 0x1, 'f', 'o', 'o', 'b', 'a', 'r'}) - frame, err := ParseStreamFrame(b) + b := bytes.NewReader([]byte{0x1, 'f', 'o', 'o', 'b', 'a', 'r'}) + frame, err := ParseStreamFrame(b, 0x80) Expect(err).ToNot(HaveOccurred()) Expect(frame.FinBit).To(BeFalse()) Expect(frame.StreamID).To(Equal(uint32(1))) diff --git a/session.go b/session.go index 99a7f4690..11878a640 100644 --- a/session.go +++ b/session.go @@ -54,19 +54,84 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub } s.Entropy.Add(publicHeader.PacketNumber, privateFlag&0x01 > 0) - // TODO: Switch frame type here + frameCounter := 0 - frame, err := ParseStreamFrame(r) - if err != nil { + // read all frames in the packet + for r.Len() > 0 { + typeByte, err := r.ReadByte() + if err != nil { + fmt.Println("No more frames in this packet.") + break + } + + frameCounter++ + fmt.Printf("Reading frame %d\n", frameCounter) + fmt.Printf("\ttype byte: %b\n", typeByte) + + if (typeByte&0x80)>>7 == 1 { // STREAM + fmt.Println("Detected STREAM") + frame, err := ParseStreamFrame(r, typeByte) + if err != nil { + return err + } + fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID) + + if frame.StreamID == 0 { + return errors.New("Session: 0 is not a valid Stream ID") + } + + // TODO: Switch stream here + if frame.StreamID == 1 { + s.HandleCryptoHandshake(frame) + } else { + fmt.Printf("%#v\n", frame) + panic("streamid not 1") + } + + } else if (typeByte&0xC0)>>6 == 1 { // ACK + fmt.Println("Detected ACK") + continue // not yet implemented + } else if (typeByte&0xE0)>>5 == 1 { // CONGESTION_FEEDBACK + fmt.Println("Detected CONGESTION_FEEDBACK") + continue // not yet implemented + } else { + fmt.Println("Detected invalid frame type. Not looking for any further frames in this packet.") + // at least one of the first three bits of the Type field has be 1 + // ToDo: sometimes there are packets that have this kind of "frame". Find out what's going wrong there. Ignore for the moment + // return errors.New("Session: invalid Frame Type Field") + break + } + } + return nil +} + +// SendFrames sends a number of frames to the client +func (s *Session) SendFrames(frames []Frame) error { + var framesData bytes.Buffer + framesData.WriteByte(0) // TODO: entropy + for _, f := range frames { + if err := f.Write(&framesData); err != nil { + return err + } + } + + s.lastSentPacketNumber++ + + var fullReply bytes.Buffer + responsePublicHeader := PublicHeader{ConnectionID: s.ConnectionID, PacketNumber: s.lastSentPacketNumber} + fmt.Printf("Sending packet # %d\n", responsePublicHeader.PacketNumber) + if err := responsePublicHeader.WritePublicHeader(&fullReply); err != nil { return err } - fmt.Printf("Got %d bytes for stream %d\n", len(frame.Data), frame.StreamID) - // TODO: Switch stream here - if frame.StreamID != 1 { - panic("streamid not 1") - } + s.aead.Seal(s.lastSentPacketNumber, &fullReply, fullReply.Bytes(), framesData.Bytes()) + _, err := s.Connection.WriteToUDP(fullReply.Bytes(), s.CurrentRemoteAddr) + return err +} + +// HandleCryptoHandshake handles the crypto handshake +func (s *Session) HandleCryptoHandshake(frame *StreamFrame) error { messageTag, cryptoData, err := handshake.ParseHandshakeMessage(frame.Data) if err != nil { panic(err) @@ -105,7 +170,7 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub handshake.TagPROF: proof, }) - s.SendFrames([]Frame{ + return s.SendFrames([]Frame{ &AckFrame{ Entropy: s.Entropy.Get(), LargestObserved: 1, @@ -115,31 +180,4 @@ func (s *Session) HandlePacket(addr *net.UDPAddr, publicHeaderBinary []byte, pub Data: serverReply.Bytes(), }, }) - - return nil -} - -// SendFrames sends a number of frames to the client -func (s *Session) SendFrames(frames []Frame) error { - var framesData bytes.Buffer - framesData.WriteByte(0) // TODO: entropy - for _, f := range frames { - if err := f.Write(&framesData); err != nil { - return err - } - } - - s.lastSentPacketNumber++ - - var fullReply bytes.Buffer - responsePublicHeader := PublicHeader{ConnectionID: s.ConnectionID, PacketNumber: s.lastSentPacketNumber} - fmt.Printf("Sending packet # %d\n", responsePublicHeader.PacketNumber) - if err := responsePublicHeader.WritePublicHeader(&fullReply); err != nil { - return err - } - - s.aead.Seal(s.lastSentPacketNumber, &fullReply, fullReply.Bytes(), framesData.Bytes()) - - _, err := s.Connection.WriteToUDP(fullReply.Bytes(), s.CurrentRemoteAddr) - return err }