diff --git a/crypto/proof_rsa.go b/crypto/proof_rsa.go index a580d298c..de83b1a70 100644 --- a/crypto/proof_rsa.go +++ b/crypto/proof_rsa.go @@ -44,10 +44,15 @@ func LoadKeyData(certFileName string, keyFileName string) (*KeyData, error) { // SignServerProof signs CHLO and server config for use in the server proof func (kd *KeyData) SignServerProof(chlo []byte, serverConfigData []byte) ([]byte, error) { hash := sha256.New() - hash.Write([]byte("QUIC CHLO and server config signature\x00")) - chloHash := sha256.Sum256(chlo) - hash.Write([]byte{32, 0, 0, 0}) - hash.Write(chloHash[:]) + if len(chlo) > 0 { + // Version >= 31 + hash.Write([]byte("QUIC CHLO and server config signature\x00")) + chloHash := sha256.Sum256(chlo) + hash.Write([]byte{32, 0, 0, 0}) + hash.Write(chloHash[:]) + } else { + hash.Write([]byte("QUIC server config signature\x00")) + } hash.Write(serverConfigData) return rsa.SignPSS(rand.Reader, kd.key, crypto.SHA256, hash.Sum(nil), &rsa.PSSOptions{SaltLength: 32}) } diff --git a/example/main.go b/example/main.go index 609796888..b84b8d349 100644 --- a/example/main.go +++ b/example/main.go @@ -15,6 +15,11 @@ import ( "github.com/lucas-clemente/quic-go/utils" ) +var supportedVersions = map[protocol.VersionNumber]bool{ + 30: true, + 32: true, +} + func main() { path := os.Getenv("GOPATH") + "/src/github.com/lucas-clemente/quic-go/example/" keyData, err := crypto.LoadKeyData(path+"cert.der", path+"key.der") @@ -55,7 +60,7 @@ func main() { fmt.Printf("Got packet # %d\n", publicHeader.PacketNumber) // Send Version Negotiation Packet if the client is speaking a different protocol version - if publicHeader.VersionFlag && publicHeader.VersionNumber != 32 { + if publicHeader.VersionFlag && !supportedVersions[publicHeader.VersionNumber] { fmt.Println("Sending VersionNegotiationPacket") fullReply := &bytes.Buffer{} responsePublicHeader := quic.PublicHeader{ConnectionID: publicHeader.ConnectionID, PacketNumber: 1, VersionFlag: true} @@ -63,6 +68,7 @@ func main() { if err != nil { panic(err) } + // TODO: Send all versions utils.WriteUint32(fullReply, protocol.VersionNumberToTag(protocol.VersionNumber(32))) _, err = conn.WriteToUDP(fullReply.Bytes(), remoteAddr) if err != nil { @@ -73,7 +79,7 @@ func main() { session, ok := sessions[publicHeader.ConnectionID] if !ok { - session = quic.NewSession(conn, publicHeader.ConnectionID, serverConfig, handleStream) + session = quic.NewSession(conn, publicHeader.VersionNumber, publicHeader.ConnectionID, serverConfig, handleStream) sessions[publicHeader.ConnectionID] = session } err = session.HandlePacket(remoteAddr, data[0:n-r.Len()], publicHeader, r) diff --git a/session.go b/session.go index 82bae67ec..9a0e9e19f 100644 --- a/session.go +++ b/session.go @@ -16,8 +16,9 @@ type StreamCallback func(*StreamFrame) []Frame // A Session is a QUIC session type Session struct { - ConnectionID protocol.ConnectionID - ServerConfig *ServerConfig + VersionNumber protocol.VersionNumber + ConnectionID protocol.ConnectionID + ServerConfig *ServerConfig Connection *net.UDPConn CurrentRemoteAddr *net.UDPAddr @@ -32,9 +33,10 @@ type Session struct { } // NewSession makes a new session -func NewSession(conn *net.UDPConn, connectionID protocol.ConnectionID, sCfg *ServerConfig, streamCallback StreamCallback) *Session { +func NewSession(conn *net.UDPConn, v protocol.VersionNumber, connectionID protocol.ConnectionID, sCfg *ServerConfig, streamCallback StreamCallback) *Session { return &Session{ Connection: conn, + VersionNumber: v, ConnectionID: connectionID, ServerConfig: sCfg, aead: &crypto.NullAEAD{}, @@ -172,7 +174,11 @@ func (s *Session) HandleCryptoHandshake(frame *StreamFrame) error { return nil } - proof, err := s.ServerConfig.Sign(frame.Data) + var chloOrNil []byte + if s.VersionNumber > protocol.VersionNumber(30) { + chloOrNil = frame.Data + } + proof, err := s.ServerConfig.Sign(chloOrNil) if err != nil { return err }