From e34256beed7302eeb42e11c9707b696c4da4bb9a Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Sun, 12 Aug 2018 15:26:40 +0700 Subject: [PATCH] update mint --- vendor/github.com/bifurcation/mint/README.md | 6 + .../bifurcation/mint/client-state-machine.go | 59 ++-- vendor/github.com/bifurcation/mint/common.go | 12 + vendor/github.com/bifurcation/mint/conn.go | 279 ++++++++++-------- vendor/github.com/bifurcation/mint/dtls.go | 200 ++++++++++++- .../bifurcation/mint/frame-reader.go | 4 +- .../bifurcation/mint/handshake-layer.go | 120 ++++++-- .../bifurcation/mint/negotiation.go | 9 +- .../bifurcation/mint/record-layer.go | 191 ++++++++---- .../bifurcation/mint/server-state-machine.go | 143 ++++++--- .../bifurcation/mint/state-machine.go | 20 +- vendor/github.com/bifurcation/mint/timer.go | 122 ++++++++ vendor/vendor.json | 6 +- 13 files changed, 883 insertions(+), 288 deletions(-) create mode 100644 vendor/github.com/bifurcation/mint/timer.go diff --git a/vendor/github.com/bifurcation/mint/README.md b/vendor/github.com/bifurcation/mint/README.md index 0ac41e0f..9fa05ddd 100644 --- a/vendor/github.com/bifurcation/mint/README.md +++ b/vendor/github.com/bifurcation/mint/README.md @@ -18,6 +18,12 @@ library](https://golang.org/pkg/crypto/tls/), especially where TLS 1.3 aligns with earlier TLS versions. However, unnecessary parts will be ruthlessly cut off. +## DTLS Support + +Mint has partial support for DTLS, but that support is not yet complete +and may still contain serious defects. + + ## Quickstart Installation is the same as for any other Go package: diff --git a/vendor/github.com/bifurcation/mint/client-state-machine.go b/vendor/github.com/bifurcation/mint/client-state-machine.go index ffca45ef..07e7f53f 100644 --- a/vendor/github.com/bifurcation/mint/client-state-machine.go +++ b/vendor/github.com/bifurcation/mint/client-state-machine.go @@ -58,7 +58,7 @@ type clientStateStart struct { cookie []byte firstClientHello *HandshakeMessage helloRetryRequest *HandshakeMessage - hsCtx HandshakeContext + hsCtx *HandshakeContext } var _ HandshakeState = &clientStateStart{} @@ -172,8 +172,10 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ } ch.CipherSuites = compatibleSuites + // TODO(ekr@rtfm.com): Check that the ticket can be used for early + // data. // Signal early data if we're going to do it - if len(state.Opts.EarlyData) > 0 { + if state.Config.AllowEarlyData && state.helloRetryRequest == nil { state.Params.ClientSendingEarlyData = true ed = &EarlyDataExtension{} err = ch.Extensions.Add(ed) @@ -255,9 +257,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ earlyTrafficSecret := deriveSecret(params, earlySecret, labelEarlyTrafficSecret, chHash) logf(logTypeCrypto, "early traffic secret: [%d] %x", len(earlyTrafficSecret), earlyTrafficSecret) clientEarlyTrafficKeys = makeTrafficKeys(params, earlyTrafficSecret) - } else if len(state.Opts.EarlyData) > 0 { - logf(logTypeHandshake, "[ClientStateWaitSH] Early data without PSK") - return nil, nil, AlertInternalError } else { clientHello, err = state.hsCtx.hOut.HandshakeMessageFromBody(ch) if err != nil { @@ -291,7 +290,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ if state.Params.ClientSendingEarlyData { toSend = append(toSend, []HandshakeAction{ RekeyOut{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, - SendEarlyData{}, }...) } @@ -302,7 +300,7 @@ type clientStateWaitSH struct { Config *Config Opts ConnectionOptions Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext OfferedDH map[NamedGroup][]byte OfferedPSK PreSharedKey PSK []byte @@ -412,6 +410,11 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, body: h.Sum(nil), } + state.hsCtx.receivedEndOfFlight() + + // TODO(ekr@rtfm.com): Need to rekey with cleartext if we are on 0-RTT + // mode. In DTLS, we also need to bump the sequence number. + // This is a pre-existing defect in Mint. Issue #175. logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateStart]") return clientStateStart{ Config: state.Config, @@ -420,7 +423,7 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, cookie: serverCookie.Cookie, firstClientHello: firstClientHello, helloRetryRequest: hm, - }, nil, AlertNoAlert + }, []HandshakeAction{ResetOut{1}}, AlertNoAlert } // This is SH. @@ -515,7 +518,6 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, logf(logTypeCrypto, "master secret: [%d] %x", len(masterSecret), masterSecret) serverHandshakeKeys := makeTrafficKeys(params, serverHandshakeTrafficSecret) - logf(logTypeHandshake, "[ClientStateWaitSH] -> [ClientStateWaitEE]") nextState := clientStateWaitEE{ Config: state.Config, @@ -530,13 +532,20 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, toSend := []HandshakeAction{ RekeyIn{epoch: EpochHandshakeData, KeySet: serverHandshakeKeys}, } + // We're definitely not going to have to send anything with + // early data. + if !state.Params.ClientSendingEarlyData { + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(params, clientHandshakeTrafficSecret)}) + } + return nextState, toSend, AlertNoAlert } type clientStateWaitEE struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash masterSecret []byte @@ -596,6 +605,14 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, state.handshakeHash.Write(hm.Marshal()) + toSend := []HandshakeAction{} + + if state.Params.ClientSendingEarlyData && !state.Params.UsingEarlyData { + // We didn't get 0-RTT, so rekey to handshake. + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) + } + if state.Params.UsingPSK { logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") nextState := clientStateWaitFinished{ @@ -608,7 +625,7 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, } - return nextState, nil, AlertNoAlert + return nextState, toSend, AlertNoAlert } logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitCertCR]") @@ -622,13 +639,13 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, clientHandshakeTrafficSecret: state.clientHandshakeTrafficSecret, serverHandshakeTrafficSecret: state.serverHandshakeTrafficSecret, } - return nextState, nil, AlertNoAlert + return nextState, toSend, AlertNoAlert } type clientStateWaitCertCR struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash masterSecret []byte @@ -706,7 +723,7 @@ func (state clientStateWaitCertCR) Next(hr handshakeMessageReader) (HandshakeSta type clientStateWaitCert struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -760,7 +777,7 @@ func (state clientStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState type clientStateWaitCV struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -861,7 +878,7 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, type clientStateWaitFinished struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams handshakeHash hash.Hash @@ -933,6 +950,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS toSend := []HandshakeAction{} if state.Params.UsingEarlyData { + logf(logTypeHandshake, "Sending end of early data") // Note: We only send EOED if the server is actually going to use the early // data. Otherwise, it will never see it, and the transcripts will // mismatch. @@ -942,10 +960,11 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS state.handshakeHash.Write(eoedm.Marshal()) logf(logTypeCrypto, "input to handshake hash [%d]: %x", len(eoedm.Marshal()), eoedm.Marshal()) - } - clientHandshakeKeys := makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret) - toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}) + // And then rekey to handshake + toSend = append(toSend, RekeyOut{epoch: EpochHandshakeData, + KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) + } if state.Params.UsingClientAuth { // Extract constraints from certicateRequest @@ -1045,6 +1064,8 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS RekeyOut{epoch: EpochApplicationData, KeySet: clientTrafficKeys}, }...) + state.hsCtx.receivedEndOfFlight() + logf(logTypeHandshake, "[ClientStateWaitFinished] -> [StateConnected]") nextState := stateConnected{ Params: state.Params, diff --git a/vendor/github.com/bifurcation/mint/common.go b/vendor/github.com/bifurcation/mint/common.go index 0fdba602..05af3e95 100644 --- a/vendor/github.com/bifurcation/mint/common.go +++ b/vendor/github.com/bifurcation/mint/common.go @@ -25,6 +25,7 @@ const ( RecordTypeAlert RecordType = 21 RecordTypeHandshake RecordType = 22 RecordTypeApplicationData RecordType = 23 + RecordTypeAck RecordType = 25 ) // enum {...} HandshakeType; @@ -166,6 +167,8 @@ const ( type State uint8 const ( + StateInit = 0 + // states valid for the client StateClientStart State = iota StateClientWaitSH @@ -179,6 +182,7 @@ const ( StateServerStart State = iota StateServerRecvdCH StateServerNegotiated + StateServerReadPastEarlyData StateServerWaitEOED StateServerWaitFlight2 StateServerWaitCert @@ -211,6 +215,8 @@ func (s State) String() string { return "Server RECVD_CH" case StateServerNegotiated: return "Server NEGOTIATED" + case StateServerReadPastEarlyData: + return "Server READ_PAST_EARLY_DATA" case StateServerWaitEOED: return "Server WAIT_EOED" case StateServerWaitFlight2: @@ -252,3 +258,9 @@ func (e Epoch) label() string { } return "Application data (updated)" } + +func assert(b bool) { + if !b { + panic("Assertion failed") + } +} diff --git a/vendor/github.com/bifurcation/mint/conn.go b/vendor/github.com/bifurcation/mint/conn.go index 0ce05b2a..12a99171 100644 --- a/vendor/github.com/bifurcation/mint/conn.go +++ b/vendor/github.com/bifurcation/mint/conn.go @@ -13,8 +13,6 @@ import ( "time" ) -var WouldBlock = fmt.Errorf("Would have blocked") - type Certificate struct { Chain []*x509.Certificate PrivateKey crypto.Signer @@ -253,6 +251,8 @@ type ConnectionState struct { PeerCertificates []*x509.Certificate // certificate chain presented by remote peer VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates NextProto string // Selected ALPN proto + UsingPSK bool // Are we using PSK. + UsingEarlyData bool // Did we negotiate 0-RTT. } // Conn implements the net.Conn interface, as with "crypto/tls" @@ -263,8 +263,6 @@ type Conn struct { conn net.Conn isClient bool - EarlyData []byte - state stateConnected hState HandshakeState handshakeMutex sync.Mutex @@ -273,22 +271,27 @@ type Conn struct { readBuffer []byte in, out *RecordLayer - hsCtx HandshakeContext + hsCtx *HandshakeContext } func NewConn(conn net.Conn, config *Config, isClient bool) *Conn { - c := &Conn{conn: conn, config: config, isClient: isClient} + c := &Conn{conn: conn, config: config, isClient: isClient, hsCtx: &HandshakeContext{}} if !config.UseDTLS { - c.in = NewRecordLayerTLS(c.conn) - c.out = NewRecordLayerTLS(c.conn) - c.hsCtx.hIn = NewHandshakeLayerTLS(c.in) - c.hsCtx.hOut = NewHandshakeLayerTLS(c.out) + c.in = NewRecordLayerTLS(c.conn, directionRead) + c.out = NewRecordLayerTLS(c.conn, directionWrite) + c.hsCtx.hIn = NewHandshakeLayerTLS(c.hsCtx, c.in) + c.hsCtx.hOut = NewHandshakeLayerTLS(c.hsCtx, c.out) } else { - c.in = NewRecordLayerDTLS(c.conn) - c.out = NewRecordLayerDTLS(c.conn) - c.hsCtx.hIn = NewHandshakeLayerDTLS(c.in) - c.hsCtx.hOut = NewHandshakeLayerDTLS(c.out) + c.in = NewRecordLayerDTLS(c.conn, directionRead) + c.out = NewRecordLayerDTLS(c.conn, directionWrite) + c.hsCtx.hIn = NewHandshakeLayerDTLS(c.hsCtx, c.in) + c.hsCtx.hOut = NewHandshakeLayerDTLS(c.hsCtx, c.out) + c.hsCtx.timeoutMS = initialTimeout + c.hsCtx.timers = newTimerSet() + c.hsCtx.waitingNextFlight = true } + c.in.label = c.label() + c.out.label = c.label() c.hsCtx.hIn.nonblocking = c.config.NonBlocking return c } @@ -374,20 +377,54 @@ func (c *Conn) consumeRecord() error { return io.EOF } + case RecordTypeAck: + if !c.hsCtx.hIn.datagram { + logf(logTypeHandshake, "Received ACK in TLS mode") + return AlertUnexpectedMessage + } + return c.hsCtx.processAck(pt.fragment) + case RecordTypeApplicationData: c.readBuffer = append(c.readBuffer, pt.fragment...) logf(logTypeIO, "extended buffer: [%d] %x", len(c.readBuffer), c.readBuffer) + } return err } +func readPartial(in *[]byte, buffer []byte) int { + logf(logTypeIO, "conn.Read input buffer now has len %d", len((*in))) + read := copy(buffer, *in) + *in = (*in)[read:] + + logf(logTypeVerbose, "Returning %v", string(buffer)) + return read +} + // Read application data up to the size of buffer. Handshake and alert records // are consumed by the Conn object directly. func (c *Conn) Read(buffer []byte) (int, error) { if _, connected := c.hState.(stateConnected); !connected { - return 0, errors.New("Read called before the handshake completed") + // Clients can't call Read prior to handshake completion. + if c.isClient { + return 0, errors.New("Read called before the handshake completed") + } + + // Neither can servers that don't allow early data. + if !c.config.AllowEarlyData { + return 0, errors.New("Read called before the handshake completed") + } + + // If there's no early data, then return WouldBlock + if len(c.hsCtx.earlyData) == 0 { + return 0, AlertWouldBlock + } + + return readPartial(&c.hsCtx.earlyData, buffer), nil } + + // The handshake is now connected. logf(logTypeHandshake, "conn.Read with buffer = %d", len(buffer)) if alert := c.Handshake(); alert != AlertNoAlert { return 0, alert @@ -397,6 +434,13 @@ func (c *Conn) Read(buffer []byte) (int, error) { return 0, nil } + // Run our timers. + if c.config.UseDTLS { + if err := c.hsCtx.timers.check(time.Now()); err != nil { + return 0, AlertInternalError + } + } + // Lock the input channel c.in.Lock() defer c.in.Unlock() @@ -406,30 +450,14 @@ func (c *Conn) Read(buffer []byte) (int, error) { // err can be nil if consumeRecord processed a non app-data // record. if err != nil { - if c.config.NonBlocking || err != WouldBlock { + if c.config.NonBlocking || err != AlertWouldBlock { logf(logTypeIO, "conn.Read returns err=%v", err) return 0, err } } } - var read int - n := len(buffer) - logf(logTypeIO, "conn.Read input buffer now has len %d", len(c.readBuffer)) - if len(c.readBuffer) <= n { - buffer = buffer[:len(c.readBuffer)] - copy(buffer, c.readBuffer) - read = len(c.readBuffer) - c.readBuffer = c.readBuffer[:0] - } else { - logf(logTypeIO, "read buffer larger than input buffer (%d > %d)", len(c.readBuffer), n) - copy(buffer[:n], c.readBuffer[:n]) - c.readBuffer = c.readBuffer[n:] - read = n - } - - logf(logTypeVerbose, "Returning %v", string(buffer)) - return read, nil + return readPartial(&c.readBuffer, buffer), nil } // Write application data @@ -438,6 +466,10 @@ func (c *Conn) Write(buffer []byte) (int, error) { c.out.Lock() defer c.out.Unlock() + if !c.Writable() { + return 0, errors.New("Write called before the handshake completed (and early data not in use)") + } + // Send full-size fragments var start int sent := 0 @@ -549,13 +581,23 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { } case SendQueuedHandshake: - err := c.hsCtx.hOut.SendQueuedMessages() + _, err := c.hsCtx.hOut.SendQueuedMessages() if err != nil { logf(logTypeHandshake, "%s Error writing handshake message: %v", label, err) return AlertInternalError } + if c.config.UseDTLS { + c.hsCtx.timers.start(retransmitTimerLabel, + c.hsCtx.handshakeRetransmit, + c.hsCtx.timeoutMS) + } case RekeyIn: logf(logTypeHandshake, "%s Rekeying in to %s: %+v", label, action.epoch.label(), action.KeySet) + // Check that we don't have an input data in the handshake frame parser. + if len(c.hsCtx.hIn.frame.remainder) > 0 { + logf(logTypeHandshake, "%s Rekey with data still in handshake buffers", label) + return AlertDecodeError + } err := c.in.Rekey(action.epoch, action.KeySet.cipher, action.KeySet.key, action.KeySet.iv) if err != nil { logf(logTypeHandshake, "%s Unable to rekey inbound: %v", label, err) @@ -570,61 +612,9 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { return AlertInternalError } - case SendEarlyData: - logf(logTypeHandshake, "%s Sending early data...", label) - _, err := c.Write(c.EarlyData) - if err != nil { - logf(logTypeHandshake, "%s Error writing early data: %v", label, err) - return AlertInternalError - } - - case ReadPastEarlyData: - logf(logTypeHandshake, "%s Reading past early data...", label) - // Scan past all records that fail to decrypt - _, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok := err.(DecryptError) - - for ok { - _, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err == nil { - break - } - _, ok = err.(DecryptError) - } - - case ReadEarlyData: - logf(logTypeHandshake, "%s Reading early data...", label) - t, err := c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (1): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type(1): %v", label, t) - - for t == RecordTypeApplicationData { - // Read a record into the buffer. Note that this is safe - // in blocking mode because we read the record in in - // PeekRecordType. - pt, err := c.in.ReadRecord() - if err != nil { - logf(logTypeHandshake, "%s Error reading early data record: %v", label, err) - return AlertInternalError - } - - logf(logTypeHandshake, "%s Read early data: %x", label, pt.fragment) - c.EarlyData = append(c.EarlyData, pt.fragment...) - - t, err = c.in.PeekRecordType(!c.config.NonBlocking) - if err != nil { - logf(logTypeHandshake, "%s Error reading record type (2): %v", label, err) - return AlertInternalError - } - logf(logTypeHandshake, "%s Got record type (2): %v", label, t) - } - logf(logTypeHandshake, "%s Done reading early data", label) + case ResetOut: + logf(logTypeHandshake, "%s Rekeying out to %s seq=%v", label, EpochClear, action.seq) + c.out.ResetClear(action.seq) case StorePSK: logf(logTypeHandshake, "%s Storing new session ticket with identity [%x]", label, action.PSK.Identity) @@ -637,7 +627,8 @@ func (c *Conn) takeAction(actionGeneric HandshakeAction) Alert { } default: - logf(logTypeHandshake, "%s Unknown actionuction type", label) + logf(logTypeHandshake, "%s Unknown action type", label) + assert(false) return AlertInternalError } @@ -657,7 +648,6 @@ func (c *Conn) HandshakeSetup() Alert { opts := ConnectionOptions{ ServerName: c.config.ServerName, NextProtos: c.config.NextProtos, - EarlyData: c.EarlyData, } if c.isClient { @@ -706,18 +696,21 @@ type handshakeMessageReaderImpl struct { var _ handshakeMessageReader = &handshakeMessageReaderImpl{} func (r *handshakeMessageReaderImpl) ReadMessage() (*HandshakeMessage, Alert) { - hm, err := r.hsCtx.hIn.ReadMessage() - if err == WouldBlock { - return nil, AlertWouldBlock + var hm *HandshakeMessage + var err error + for { + hm, err = r.hsCtx.hIn.ReadMessage() + if err == AlertWouldBlock { + return nil, AlertWouldBlock + } + if err != nil { + logf(logTypeHandshake, "Error reading message: %v", err) + return nil, AlertCloseNotify + } + if hm != nil { + break + } } - if err != nil { - logf(logTypeHandshake, "[client] Error reading message: %v", err) - return nil, AlertCloseNotify - } - - // Once you have read a message, you no longer need the outgoing queue - // for DTLS. - r.hsCtx.hOut.ClearQueuedMessages() return hm, AlertNoAlert } @@ -753,14 +746,21 @@ func (c *Conn) Handshake() Alert { state := c.hState _, connected := state.(stateConnected) - hmr := &handshakeMessageReaderImpl{hsCtx: &c.hsCtx} + hmr := &handshakeMessageReaderImpl{hsCtx: c.hsCtx} for !connected { var alert Alert var actions []HandshakeAction + // Advance the state machine state, actions, alert = state.Next(hmr) - if alert == WouldBlock { + if alert == AlertWouldBlock { logf(logTypeHandshake, "%s Would block reading message: %s", label, alert) + // If we blocked, then run our timers to see if any have expired. + if c.hsCtx.hIn.datagram { + if err := c.hsCtx.timers.check(time.Now()); err != nil { + return AlertInternalError + } + } return AlertWouldBlock } if alert == AlertCloseNotify { @@ -788,6 +788,34 @@ func (c *Conn) Handshake() Alert { if connected { c.state = state.(stateConnected) c.handshakeComplete = true + + if !c.isClient { + // Send NewSessionTicket if configured to + if c.config.SendSessionTickets { + actions, alert := c.state.NewSessionTicket( + c.config.TicketLen, + c.config.TicketLifetime, + c.config.EarlyDataLifetime) + + for _, action := range actions { + alert = c.takeAction(action) + if alert != AlertNoAlert { + logf(logTypeHandshake, "Error during handshake actions: %v", alert) + c.sendAlert(alert) + return alert + } + } + } + + // If there is early data, move it into the main buffer + if c.hsCtx.earlyData != nil { + c.readBuffer = c.hsCtx.earlyData + c.hsCtx.earlyData = nil + } + + } else { + assert(c.hsCtx.earlyData == nil) + } } if c.config.NonBlocking { @@ -798,23 +826,6 @@ func (c *Conn) Handshake() Alert { } } - // Send NewSessionTicket if acting as server - if !c.isClient && c.config.SendSessionTickets { - actions, alert := c.state.NewSessionTicket( - c.config.TicketLen, - c.config.TicketLifetime, - c.config.EarlyDataLifetime) - - for _, action := range actions { - alert = c.takeAction(action) - if alert != AlertNoAlert { - logf(logTypeHandshake, "Error during handshake actions: %v", alert) - c.sendAlert(alert) - return alert - } - } - } - return AlertNoAlert } @@ -848,6 +859,9 @@ func (c *Conn) SendKeyUpdate(requestUpdate bool) error { } func (c *Conn) GetHsState() State { + if c.hState == nil { + return StateInit + } return c.hState.State() } @@ -878,7 +892,30 @@ func (c *Conn) ConnectionState() ConnectionState { state.NextProto = c.state.Params.NextProto state.VerifiedChains = c.state.verifiedChains state.PeerCertificates = c.state.peerCertificates + state.UsingPSK = c.state.Params.UsingPSK + state.UsingEarlyData = c.state.Params.UsingEarlyData } return state } + +func (c *Conn) Writable() bool { + // If we're connected, we're writable. + if _, connected := c.hState.(stateConnected); connected { + return true + } + + // If we're a client in 0-RTT, then we're writable. + if c.isClient && c.out.cipher.epoch == EpochEarlyData { + return true + } + + return false +} + +func (c *Conn) label() string { + if c.isClient { + return "client" + } + return "server" +} diff --git a/vendor/github.com/bifurcation/mint/dtls.go b/vendor/github.com/bifurcation/mint/dtls.go index df4f1aa1..aa914e3e 100644 --- a/vendor/github.com/bifurcation/mint/dtls.go +++ b/vendor/github.com/bifurcation/mint/dtls.go @@ -2,14 +2,33 @@ package mint import ( "fmt" + "github.com/bifurcation/mint/syntax" + "time" ) -// This file is a placeholder. DTLS-specific stuff (timer management, -// ACKs, retransmits, etc. will eventually go here. const ( - initialMtu = 1200 + initialMtu = 1200 + initialTimeout = 100 ) +// labels for timers +const ( + retransmitTimerLabel = "handshake retransmit" + ackTimerLabel = "ack timer" +) + +type SentHandshakeFragment struct { + seq uint32 + offset int + fragLength int + record uint64 + acked bool +} + +type DtlsAck struct { + RecordNumbers []uint64 `tls:"head=2"` +} + func wireVersion(h *HandshakeLayer) uint16 { if h.datagram { return dtls12WireVersion @@ -26,3 +45,178 @@ func dtlsConvertVersion(version uint16) uint16 { } panic(fmt.Sprintf("Internal error, unexpected version=%d", version)) } + +// TODO(ekr@rtfm.com): Move these to state-machine.go +func (h *HandshakeContext) handshakeRetransmit() error { + if _, err := h.hOut.SendQueuedMessages(); err != nil { + return err + } + + h.timers.start(retransmitTimerLabel, + h.handshakeRetransmit, + h.timeoutMS) + + // TODO(ekr@rtfm.com): Back off timer + return nil +} + +func (h *HandshakeContext) sendAck() error { + toack := h.hIn.recvdRecords + + count := (initialMtu - 2) / 8 // TODO(ekr@rtfm.com): Current MTU + if len(toack) > count { + toack = toack[:count] + } + logf(logTypeHandshake, "Sending ACK: [%x]", toack) + + ack := &DtlsAck{toack} + body, err := syntax.Marshal(&ack) + if err != nil { + return err + } + err = h.hOut.conn.WriteRecord(&TLSPlaintext{ + contentType: RecordTypeAck, + fragment: body, + }) + if err != nil { + return err + } + return nil +} + +func (h *HandshakeContext) processAck(data []byte) error { + // Cancel the retransmit timer because we will be resending + // and possibly re-arming later. + h.timers.cancel(retransmitTimerLabel) + + ack := &DtlsAck{} + read, err := syntax.Unmarshal(data, &ack) + if err != nil { + return err + } + if len(data) != read { + return fmt.Errorf("Invalid encoding: Extra data not consumed") + } + logf(logTypeHandshake, "ACK: [%x]", ack.RecordNumbers) + + for _, r := range ack.RecordNumbers { + for _, m := range h.sentFragments { + if r == m.record { + logf(logTypeHandshake, "Marking %v %v(%v) as acked", + m.seq, m.offset, m.fragLength) + m.acked = true + } + } + } + + count, err := h.hOut.SendQueuedMessages() + if err != nil { + return err + } + + if count == 0 { + logf(logTypeHandshake, "All messages ACKed") + h.hOut.ClearQueuedMessages() + return nil + } + + // Reset the timer + h.timers.start(retransmitTimerLabel, + h.handshakeRetransmit, + h.timeoutMS) + + return nil +} + +func (c *Conn) GetDTLSTimeout() (bool, time.Duration) { + return c.hsCtx.timers.remaining() +} + +func (h *HandshakeContext) receivedHandshakeMessage() { + logf(logTypeHandshake, "%p Received handshake, waiting for start of flight = %v", h, h.waitingNextFlight) + // This just enables tests. + if h.hIn == nil { + return + } + + if !h.hIn.datagram { + return + } + + if h.waitingNextFlight { + logf(logTypeHandshake, "Received the start of the flight") + + // Clear the outgoing DTLS queue and terminate the retransmit timer + h.hOut.ClearQueuedMessages() + h.timers.cancel(retransmitTimerLabel) + + // OK, we're not waiting any more. + h.waitingNextFlight = false + } + + // Now pre-emptively arm the ACK timer if it's not armed already. + // We'll automatically dis-arm it at the end of the handshake. + if h.timers.getTimer(ackTimerLabel) == nil { + h.timers.start(ackTimerLabel, h.sendAck, h.timeoutMS/4) + } +} + +func (h *HandshakeContext) receivedEndOfFlight() { + logf(logTypeHandshake, "%p Received the end of the flight", h) + if !h.hIn.datagram { + return + } + + // Empty incoming queue + h.hIn.queued = nil + + // Note that we are waiting for the next flight. + h.waitingNextFlight = true + + // Clear the ACK queue. + h.hIn.recvdRecords = nil + + // Disarm the ACK timer + h.timers.cancel(ackTimerLabel) +} + +func (h *HandshakeContext) receivedFinalFlight() { + logf(logTypeHandshake, "%p Received final flight", h) + if !h.hIn.datagram { + return + } + + // Disarm the ACK timer + h.timers.cancel(ackTimerLabel) + + // But send an ACK immediately. + h.sendAck() +} + +func (h *HandshakeContext) fragmentAcked(seq uint32, offset int, fraglen int) bool { + logf(logTypeHandshake, "Looking to see if fragment %v %v(%v) was acked", seq, offset, fraglen) + for _, f := range h.sentFragments { + if !f.acked { + continue + } + + if f.seq != seq { + continue + } + + if f.offset > offset { + continue + } + + // At this point, we know that the stored fragment starts + // at or before what we want to send, so check where the end + // is. + if f.offset+f.fragLength < offset+fraglen { + continue + } + + return true + } + + return false +} diff --git a/vendor/github.com/bifurcation/mint/frame-reader.go b/vendor/github.com/bifurcation/mint/frame-reader.go index 54f40ce2..4ccfc23f 100644 --- a/vendor/github.com/bifurcation/mint/frame-reader.go +++ b/vendor/github.com/bifurcation/mint/frame-reader.go @@ -67,7 +67,7 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) { f.writeOffset += copied if f.writeOffset < len(f.working) { logf(logTypeVerbose, "Read would have blocked 1") - return nil, nil, WouldBlock + return nil, nil, AlertWouldBlock } // Reset the write offset, because we are now full. f.writeOffset = 0 @@ -94,5 +94,5 @@ func (f *frameReader) process() (hdr []byte, body []byte, err error) { } logf(logTypeVerbose, "Read would have blocked 2") - return nil, nil, WouldBlock + return nil, nil, AlertWouldBlock } diff --git a/vendor/github.com/bifurcation/mint/handshake-layer.go b/vendor/github.com/bifurcation/mint/handshake-layer.go index 888c5f36..de17b30b 100644 --- a/vendor/github.com/bifurcation/mint/handshake-layer.go +++ b/vendor/github.com/bifurcation/mint/handshake-layer.go @@ -35,7 +35,6 @@ type HandshakeMessage struct { datagram bool offset uint32 // Used for DTLS length uint32 - records []uint64 // Used for DTLS cipher *cipherState } @@ -119,6 +118,7 @@ func (h *HandshakeLayer) HandshakeMessageFromBody(body HandshakeMessageBody) (*H } type HandshakeLayer struct { + ctx *HandshakeContext // The handshake we are attached to nonblocking bool // Should we operate in nonblocking mode conn *RecordLayer // Used for reading/writing records frame *frameReader // The buffered frame reader @@ -126,6 +126,7 @@ type HandshakeLayer struct { msgSeq uint32 // The DTLS message sequence number queued []*HandshakeMessage // In/out queue sent []*HandshakeMessage // Sent messages for DTLS + recvdRecords []uint64 // Records we have received. maxFragmentLen int } @@ -152,8 +153,9 @@ func (d handshakeLayerFrameDetails) frameLen(hdr []byte) (int, error) { return int(val), nil } -func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer { +func NewHandshakeLayerTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { h := HandshakeLayer{} + h.ctx = c h.conn = r h.datagram = false h.frame = newFrameReader(&handshakeLayerFrameDetails{false}) @@ -161,8 +163,9 @@ func NewHandshakeLayerTLS(r *RecordLayer) *HandshakeLayer { return &h } -func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer { +func NewHandshakeLayerDTLS(c *HandshakeContext, r *RecordLayer) *HandshakeLayer { h := HandshakeLayer{} + h.ctx = c h.conn = r h.datagram = true h.frame = newFrameReader(&handshakeLayerFrameDetails{true}) @@ -172,16 +175,25 @@ func NewHandshakeLayerDTLS(r *RecordLayer) *HandshakeLayer { func (h *HandshakeLayer) readRecord() error { logf(logTypeVerbose, "Trying to read record") - pt, err := h.conn.ReadRecord() + pt, err := h.conn.readRecordAnyEpoch() if err != nil { return err } - if pt.contentType != RecordTypeHandshake && - pt.contentType != RecordTypeAlert { + switch pt.contentType { + case RecordTypeHandshake, RecordTypeAlert, RecordTypeAck: + default: return fmt.Errorf("tls.handshakelayer: Unexpected record type %d", pt.contentType) } + if pt.contentType == RecordTypeAck { + if !h.datagram { + return fmt.Errorf("tls.handshakelayer: can't have ACK with TLS") + } + logf(logTypeIO, "read ACK") + return h.ctx.processAck(pt.fragment) + } + if pt.contentType == RecordTypeAlert { logf(logTypeIO, "read alert %v", pt.fragment[1]) if len(pt.fragment) < 2 { @@ -191,6 +203,19 @@ func (h *HandshakeLayer) readRecord() error { return Alert(pt.fragment[1]) } + assert(h.ctx.hIn.conn != nil) + if pt.epoch != h.ctx.hIn.conn.cipher.epoch { + // This is out of order but we're dropping it. + // TODO(ekr@rtfm.com): If server, need to retransmit Finished. + if pt.epoch == EpochClear || pt.epoch == EpochHandshakeData { + return nil + } + + // Anything else shouldn't happen. + return AlertIllegalParameter + } + + h.recvdRecords = append(h.recvdRecords, pt.seq) h.frame.addChunk(pt.fragment) return nil @@ -227,9 +252,13 @@ func (h *HandshakeLayer) noteMessageDelivered(seq uint32) { func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMessage, error) { if hm.seq < h.msgSeq { - return nil, WouldBlock + return nil, nil } + // TODO(ekr@rtfm.com): Send an ACK immediately if we got something + // out of order. + h.ctx.receivedHandshakeMessage() + if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { // TODO(ekr@rtfm.com): Check the length? // This is complete. @@ -259,12 +288,12 @@ func (h *HandshakeLayer) newFragmentReceived(hm *HandshakeMessage) (*HandshakeMe func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { if len(h.queued) == 0 { - return nil, WouldBlock + return nil, nil } hm := h.queued[0] if hm.seq != h.msgSeq { - return nil, WouldBlock + return nil, nil } if hm.seq == h.msgSeq && hm.offset == 0 && hm.length == uint32(len(hm.body)) { @@ -307,7 +336,7 @@ func (h *HandshakeLayer) checkMessageAvailable() (*HandshakeMessage, error) { } - return nil, WouldBlock + return nil, nil } func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { @@ -315,19 +344,19 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { var err error hm, err := h.checkMessageAvailable() - if err == nil { - return hm, err - } - if err != WouldBlock { + if err != nil { return nil, err } + if hm != nil { + return hm, nil + } for { logf(logTypeVerbose, "ReadMessage() buffered=%v", len(h.frame.remainder)) if h.frame.needed() > 0 { logf(logTypeVerbose, "Trying to read a new record") err = h.readRecord() - if err != nil && (h.nonblocking || err != WouldBlock) { + if err != nil && (h.nonblocking || err != AlertWouldBlock) { return nil, err } } @@ -336,7 +365,7 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { if err == nil { break } - if err != nil && (h.nonblocking || err != WouldBlock) { + if err != nil && (h.nonblocking || err != AlertWouldBlock) { return nil, err } } @@ -370,12 +399,13 @@ func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error { return nil } -func (h *HandshakeLayer) SendQueuedMessages() error { +func (h *HandshakeLayer) SendQueuedMessages() (int, error) { logf(logTypeHandshake, "Sending outgoing messages") - err := h.WriteMessages(h.queued) - h.ClearQueuedMessages() // This isn't going to work for DTLS, but we'll - // get there. - return err + count, err := h.WriteMessages(h.queued) + if !h.datagram { + h.ClearQueuedMessages() + } + return count, err } func (h *HandshakeLayer) ClearQueuedMessages() { @@ -383,7 +413,7 @@ func (h *HandshakeLayer) ClearQueuedMessages() { h.queued = nil } -func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (int, error) { +func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int) (bool, int, error) { var buf []byte // Figure out if we're going to want the full header or just @@ -408,17 +438,35 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int } body := hm.body[start : start+bodylen] + // Now see if this chunk has been ACKed. This doesn't produce ideal + // retransmission but is simple. + if h.ctx.fragmentAcked(hm.seq, start, bodylen) { + logf(logTypeHandshake, "Fragment %v %v(%v) already acked. Skipping", hm.seq, start, bodylen) + return false, start + bodylen, nil + } + // Encode the data. if hdrlen > 0 { hm2 := *hm hm2.offset = uint32(start) hm2.body = body buf = hm2.Marshal() + hm = &hm2 } else { buf = body } - return start + bodylen, h.conn.writeRecordWithPadding( + if h.datagram { + // Remember that we sent this. + h.ctx.sentFragments = append(h.ctx.sentFragments, &SentHandshakeFragment{ + hm.seq, + start, + len(body), + h.conn.cipher.combineSeq(true), + false, + }) + } + return true, start + bodylen, h.conn.writeRecordWithPadding( &TLSPlaintext{ contentType: RecordTypeHandshake, fragment: buf, @@ -426,38 +474,46 @@ func (h *HandshakeLayer) writeFragment(hm *HandshakeMessage, start int, room int hm.cipher, 0) } -func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) error { +func (h *HandshakeLayer) WriteMessage(hm *HandshakeMessage) (int, error) { start := int(0) if len(hm.body) > maxHandshakeMessageLen { - return fmt.Errorf("Tried to write a handshake message that's too long") + return 0, fmt.Errorf("Tried to write a handshake message that's too long") } + written := 0 + wrote := false + // Always make one pass through to allow EOED (which is empty). for { var err error - start, err = h.writeFragment(hm, start, h.maxFragmentLen) + wrote, start, err = h.writeFragment(hm, start, h.maxFragmentLen) if err != nil { - return err + return 0, err + } + if wrote { + written++ } if start >= len(hm.body) { break } } - return nil + return written, nil } -func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) error { +func (h *HandshakeLayer) WriteMessages(hms []*HandshakeMessage) (int, error) { + written := 0 for _, hm := range hms { logf(logTypeHandshake, "WriteMessage [%d] %x", hm.msgType, hm.body) - err := h.WriteMessage(hm) + wrote, err := h.WriteMessage(hm) if err != nil { - return err + return 0, err } + written += wrote } - return nil + return written, nil } func encodeUint(v uint64, size int, out []byte) []byte { diff --git a/vendor/github.com/bifurcation/mint/negotiation.go b/vendor/github.com/bifurcation/mint/negotiation.go index 4697bbc8..2c80b8d7 100644 --- a/vendor/github.com/bifurcation/mint/negotiation.go +++ b/vendor/github.com/bifurcation/mint/negotiation.go @@ -168,10 +168,11 @@ func CertificateSelection(serverName *string, signatureSchemes []SignatureScheme return nil, 0, fmt.Errorf("No certificates compatible with signature schemes") } -func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) bool { - usingEarlyData := gotEarlyData && usingPSK && allowEarlyData - logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v", usingPSK, gotEarlyData, allowEarlyData, usingEarlyData) - return usingEarlyData +func EarlyDataNegotiation(usingPSK, gotEarlyData, allowEarlyData bool) (using bool, rejected bool) { + using = gotEarlyData && usingPSK && allowEarlyData + rejected = gotEarlyData && !using + logf(logTypeNegotiation, "Early data negotiation (%v, %v, %v) => %v, %v", usingPSK, gotEarlyData, allowEarlyData, using, rejected) + return } func CipherSuiteNegotiation(psk *PreSharedKey, offered, supported []CipherSuite) (CipherSuite, error) { diff --git a/vendor/github.com/bifurcation/mint/record-layer.go b/vendor/github.com/bifurcation/mint/record-layer.go index 761a868d..5cf8ae2c 100644 --- a/vendor/github.com/bifurcation/mint/record-layer.go +++ b/vendor/github.com/bifurcation/mint/record-layer.go @@ -1,7 +1,6 @@ package mint import ( - "bytes" "crypto/cipher" "fmt" "io" @@ -21,6 +20,13 @@ func (err DecryptError) Error() string { return string(err) } +type direction uint8 + +const ( + directionWrite = direction(1) + directionRead = direction(2) +) + // struct { // ContentType type; // ProtocolVersion record_version [0301 for CH, 0303 for others] @@ -31,20 +37,23 @@ type TLSPlaintext struct { // Omitted: record_version (static) // Omitted: length (computed from fragment) contentType RecordType + epoch Epoch + seq uint64 fragment []byte } type cipherState struct { epoch Epoch // DTLS epoch ivLength int // Length of the seq and nonce fields - seq []byte // Zero-padded sequence number + seq uint64 // Zero-padded sequence number iv []byte // Buffer for the IV cipher cipher.AEAD // AEAD cipher } type RecordLayer struct { sync.Mutex - + label string + direction direction version uint16 // The current version number conn io.ReadWriter // The underlying connection frame *frameReader // The buffered frame reader @@ -52,7 +61,9 @@ type RecordLayer struct { cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" cachedError error // Error on the last record read - cipher *cipherState + cipher *cipherState + readCiphers map[Epoch]*cipherState + datagram bool } @@ -76,7 +87,7 @@ func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { } func newCipherStateNull() *cipherState { - return &cipherState{EpochClear, 0, bytes.Repeat([]byte{0}, sequenceNumberLen), nil, nil} + return &cipherState{EpochClear, 0, 0, nil, nil} } func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) { @@ -85,11 +96,13 @@ func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) return nil, err } - return &cipherState{epoch, len(iv), bytes.Repeat([]byte{0}, sequenceNumberLen), iv, cipher}, nil + return &cipherState{epoch, len(iv), 0, iv, cipher}, nil } -func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer { +func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer { r := RecordLayer{} + r.label = "" + r.direction = dir r.conn = conn r.frame = newFrameReader(recordLayerFrameDetails{false}) r.cipher = newCipherStateNull() @@ -97,11 +110,15 @@ func NewRecordLayerTLS(conn io.ReadWriter) *RecordLayer { return &r } -func NewRecordLayerDTLS(conn io.ReadWriter) *RecordLayer { +func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer { r := RecordLayer{} + r.label = "" + r.direction = dir r.conn = conn r.frame = newFrameReader(recordLayerFrameDetails{true}) r.cipher = newCipherStateNull() + r.readCiphers = make(map[Epoch]*cipherState, 0) + r.readCiphers[0] = r.cipher r.datagram = true return &r } @@ -110,53 +127,67 @@ func (r *RecordLayer) SetVersion(v uint16) { r.version = v } +func (r *RecordLayer) ResetClear(seq uint64) { + r.cipher = newCipherStateNull() + r.cipher.seq = seq +} + func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error { cipher, err := newCipherStateAead(epoch, factory, key, iv) if err != nil { return err } r.cipher = cipher + if r.datagram && r.direction == directionRead { + r.readCiphers[epoch] = cipher + } return nil } -func (c *cipherState) formatSeq(datagram bool) []byte { - seq := append([]byte{}, c.seq...) +// TODO(ekr@rtfm.com): This is never used, which is a bug. +func (r *RecordLayer) DiscardReadKey(epoch Epoch) { + if !r.datagram { + return + } + + _, ok := r.readCiphers[epoch] + assert(ok) + delete(r.readCiphers, epoch) +} + +func (c *cipherState) combineSeq(datagram bool) uint64 { + seq := c.seq if datagram { - seq[0] = byte(c.epoch >> 8) - seq[1] = byte(c.epoch & 0xff) + seq |= uint64(c.epoch) << 48 } return seq } -func (c *cipherState) computeNonce(seq []byte) []byte { +func (c *cipherState) computeNonce(seq uint64) []byte { nonce := make([]byte, len(c.iv)) copy(nonce, c.iv) - offset := len(c.iv) - len(seq) - for i, b := range seq { - nonce[i+offset] ^= b + s := seq + + offset := len(c.iv) + for i := 0; i < 8; i++ { + nonce[(offset-i)-1] ^= byte(s & 0xff) + s >>= 8 } + logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce) return nonce } func (c *cipherState) incrementSequenceNumber() { - var i int - for i = len(c.seq) - 1; i >= 0; i-- { - c.seq[i]++ - if c.seq[i] != 0 { - break - } - } - - if i < 0 { + if c.seq >= (1<<48 - 1) { // Not allowed to let sequence number wrap. // Instead, must renegotiate before it does. - // Not likely enough to bother. - // TODO(ekr@rtfm.com): Check for DTLS here - // because the limit is sooner. + // Not likely enough to bother. This is the + // DTLS limit. panic("TLS: sequence number wraparound") } + c.seq++ } func (c *cipherState) overhead() int { @@ -166,8 +197,9 @@ func (c *cipherState) overhead() int { return c.cipher.Overhead() } -func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, padLen int) *TLSPlaintext { - logf(logTypeIO, "Encrypt seq=[%x]", seq) +func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext { + assert(r.direction == directionWrite) + logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq) // Expand the fragment to hold contentType, padding, and overhead originalLen := len(pt.fragment) plaintextLen := originalLen + 1 + padLen @@ -191,8 +223,9 @@ func (r *RecordLayer) encrypt(cipher *cipherState, seq []byte, pt *TLSPlaintext, return out } -func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, error) { - logf(logTypeIO, "Decrypt seq=[%x]", seq) +func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) { + assert(r.direction == directionRead) + logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq) if len(pt.fragment) < r.cipher.overhead() { msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead()) return nil, 0, DecryptError(msg) @@ -207,7 +240,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, // Decrypt _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil) if err != nil { - logf(logTypeIO, "AEAD decryption failure [%x]", pt) + logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt) return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") } @@ -222,6 +255,7 @@ func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq []byte) (*TLSPlaintext, int, // Truncate the message to remove contentType, padding, overhead out.fragment = out.fragment[:newLen] + out.seq = seq return out, padLen, nil } @@ -230,11 +264,11 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { var err error for { - pt, err = r.nextRecord() + pt, err = r.nextRecord(false) if err == nil { break } - if !block || err != WouldBlock { + if !block || err != AlertWouldBlock { return 0, err } } @@ -242,7 +276,7 @@ func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { } func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { - pt, err := r.nextRecord() + pt, err := r.nextRecord(false) // Consume the cached record if there was one r.cachedRecord = nil @@ -251,10 +285,20 @@ func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { return pt, err } -func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { +func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) { + pt, err := r.nextRecord(true) + + // Consume the cached record if there was one + r.cachedRecord = nil + r.cachedError = nil + + return pt, err +} + +func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) { cipher := r.cipher if r.cachedRecord != nil { - logf(logTypeIO, "Returning cached record") + logf(logTypeIO, "%s Returning cached record", r.label) return r.cachedRecord, r.cachedError } @@ -262,9 +306,10 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { // // 1. We get a frame // 2. We try to read off the socket and get nothing, in which case - // return WouldBlock + // returnAlertWouldBlock // 3. We get an error. - err := WouldBlock + var err error + err = AlertWouldBlock var header, body []byte for err != nil { @@ -272,24 +317,24 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) n, err := r.conn.Read(buf) if err != nil { - logf(logTypeIO, "Error reading, %v", err) + logf(logTypeIO, "%s Error reading, %v", r.label, err) return nil, err } if n == 0 { - return nil, WouldBlock + return nil, AlertWouldBlock } - logf(logTypeIO, "Read %v bytes", n) + logf(logTypeIO, "%s Read %v bytes", r.label, n) buf = buf[:n] r.frame.addChunk(buf) } header, body, err = r.frame.process() - // Loop around on WouldBlock to see if some + // Loop around onAlertWouldBlock to see if some // data is now available. - if err != nil && err != WouldBlock { + if err != nil && err != AlertWouldBlock { return nil, err } } @@ -299,7 +344,7 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { switch RecordType(header[0]) { default: return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) - case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData: + case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck: pt.contentType = RecordType(header[0]) } @@ -318,28 +363,48 @@ func (r *RecordLayer) nextRecord() (*TLSPlaintext, error) { pt.fragment = make([]byte, size) copy(pt.fragment, body) + // TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data. + // Attempt to decrypt fragment - if cipher.cipher != nil { - seq := cipher.seq - if r.datagram { - seq = header[3:11] - } - // TODO(ekr@rtfm.com): Handle the wrong epoch. + seq := cipher.seq + if r.datagram { // TODO(ekr@rtfm.com): Handle duplicates. - logf(logTypeIO, "RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), seq, pt.contentType, pt.fragment) + seq, _ = decodeUint(header[3:11], 8) + epoch := Epoch(seq >> 48) + + // Look up the cipher suite from the epoch + c, ok := r.readCiphers[epoch] + if !ok { + logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch) + return nil, AlertWouldBlock + } + + if epoch != cipher.epoch { + logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch, + cipher.epoch, allowOldEpoch) + if !allowOldEpoch { + return nil, AlertWouldBlock + } + cipher = c + } + } + + if cipher.cipher != nil { + logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment) pt, _, err = r.decrypt(pt, seq) if err != nil { - logf(logTypeIO, "Decryption failed") + logf(logTypeIO, "%s Decryption failed", r.label) return nil, err } } + pt.epoch = cipher.epoch // Check that plaintext length is not too long if len(pt.fragment) > maxFragmentLen { return nil, fmt.Errorf("tls.record: Plaintext size too big") } - logf(logTypeIO, "RecordLayer.ReadRecord [%d] [%x]", pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment) r.cachedRecord = pt cipher.incrementSequenceNumber() @@ -355,10 +420,9 @@ func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error } func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error { - seq := cipher.formatSeq(r.datagram) - + seq := cipher.combineSeq(r.datagram) if cipher.cipher != nil { - logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) pt = r.encrypt(cipher, seq, pt, padLen) } else if padLen > 0 { return fmt.Errorf("tls.record: Padding can only be done on encrypted records") @@ -376,16 +440,17 @@ func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherSta byte(r.version >> 8), byte(r.version & 0xff), byte(length >> 8), byte(length)} } else { + header = make([]byte, 13) version := dtlsConvertVersion(r.version) - header = []byte{byte(pt.contentType), + copy(header, []byte{byte(pt.contentType), byte(version >> 8), byte(version & 0xff), - seq[0], seq[1], seq[2], seq[3], - seq[4], seq[5], seq[6], seq[7], - byte(length >> 8), byte(length)} + }) + encodeUint(seq, 8, header[3:]) + encodeUint(uint64(length), 2, header[11:]) } record := append(header, pt.fragment...) - logf(logTypeIO, "RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) + logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) cipher.incrementSequenceNumber() _, err := r.conn.Write(record) diff --git a/vendor/github.com/bifurcation/mint/server-state-machine.go b/vendor/github.com/bifurcation/mint/server-state-machine.go index 0b851f40..f91b22e4 100644 --- a/vendor/github.com/bifurcation/mint/server-state-machine.go +++ b/vendor/github.com/bifurcation/mint/server-state-machine.go @@ -24,14 +24,17 @@ import ( // | [Send CertificateRequest] // Can send | [Send Certificate + CertificateVerify] // app data --> | Send Finished -// after +--------+--------+ -// here No 0-RTT | | 0-RTT -// | v -// | WAIT_EOED <---+ -// | Recv | | | Recv -// | EndOfEarlyData | | | early data -// | | +-----+ -// +> WAIT_FLIGHT2 <-+ +// after here | +// +-----------+--------+ +// | | | +// Rejected 0-RTT | No | | 0-RTT +// | 0-RTT | | +// | | v +// +---->READ_PAST | WAIT_EOED <---+ +// Decrypt | | | Decrypt | Recv | | | Recv +// error | | | OK + HS | EOED | | | early data +// +-----+ | V | +-----+ +// +---> WAIT_FLIGHT2 <-+ // | // +--------+--------+ // No auth | | Client auth @@ -50,16 +53,17 @@ import ( // // NB: Not using state RECVD_CH // -// State Instructions -// START {} -// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] -// WAIT_EOED RekeyIn; -// WAIT_FLIGHT2 {} -// WAIT_CERT_CR {} -// WAIT_CERT {} -// WAIT_CV {} -// WAIT_FINISHED RekeyIn; RekeyOut; -// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) +// State Instructions +// START {} +// NEGOTIATED Send(SH); [RekeyIn;] RekeyOut; Send(EE); [Send(CertReq);] [Send(Cert); Send(CV)] +// WAIT_EOED RekeyIn; +// READ_PAST {} +// WAIT_FLIGHT2 {} +// WAIT_CERT_CR {} +// WAIT_CERT {} +// WAIT_CV {} +// WAIT_FINISHED RekeyIn; RekeyOut; +// CONNECTED StoreTicket || (RekeyIn; [RekeyOut]) // A cookie can be sent to the client in a HRR. type cookie struct { @@ -74,7 +78,7 @@ type cookie struct { type serverStateStart struct { Config *Config conn *Conn - hsCtx HandshakeContext + hsCtx *HandshakeContext } var _ HandshakeState = &serverStateStart{} @@ -235,10 +239,6 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ logf(logTypeHandshake, "[ServerStateStart] Error in PSK negotiation [%v]", err) return nil, nil, AlertInternalError } - if clientSentCookie && initialCipherSuite.Suite != params.Suite { - logf(logTypeHandshake, "[ServerStateStart] Would have selected a different CipherSuite after receiving the client's Cookie") - return nil, nil, AlertInternalError - } } // Figure out if we actually should do DH / PSK @@ -361,7 +361,7 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ // Figure out if we're going to do early data var clientEarlyTrafficSecret []byte connParams.ClientSendingEarlyData = foundExts[ExtensionTypeEarlyData] - connParams.UsingEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) + connParams.UsingEarlyData, connParams.RejectedEarlyData = EarlyDataNegotiation(connParams.UsingPSK, foundExts[ExtensionTypeEarlyData], state.Config.AllowEarlyData) if connParams.UsingEarlyData { h := params.Hash.New() h.Write(clientHello.Marshal()) @@ -379,6 +379,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertNoApplicationProtocol } + state.hsCtx.receivedEndOfFlight() + logf(logTypeHandshake, "[ServerStateStart] -> [ServerStateNegotiated]") state.hsCtx.SetVersion(tls12Version) // Everything after this should be 1.2. return serverStateNegotiated{ @@ -445,7 +447,7 @@ func (state *serverStateStart) generateHRR(cs CipherSuite, legacySessionId []byt type serverStateNegotiated struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext dhGroup NamedGroup dhPublic []byte dhSecret []byte @@ -731,7 +733,6 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat } toSend = append(toSend, []HandshakeAction{ RekeyIn{epoch: EpochEarlyData, KeySet: clientEarlyTrafficKeys}, - ReadEarlyData{}, }...) return nextState, toSend, AlertNoAlert } @@ -739,9 +740,9 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat logf(logTypeHandshake, "[ServerStateNegotiated] -> [ServerStateWaitFlight2]") toSend = append(toSend, []HandshakeAction{ RekeyIn{epoch: EpochHandshakeData, KeySet: clientHandshakeKeys}, - ReadPastEarlyData{}, }...) - waitFlight2 := serverStateWaitFlight2{ + var nextState HandshakeState + nextState = serverStateWaitFlight2{ Config: state.Config, Params: state.Params, hsCtx: state.hsCtx, @@ -753,13 +754,19 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat serverTrafficSecret: serverTrafficSecret, exporterSecret: exporterSecret, } - return waitFlight2, toSend, AlertNoAlert + if state.Params.RejectedEarlyData { + nextState = serverStateReadPastEarlyData{ + hsCtx: state.hsCtx, + next: &nextState, + } + } + return nextState, toSend, AlertNoAlert } type serverStateWaitEOED struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -776,6 +783,38 @@ func (state serverStateWaitEOED) State() State { } func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + for { + logf(logTypeHandshake, "Server reading early data...") + assert(state.hsCtx.hIn.conn.cipher.epoch == EpochEarlyData) + t, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) + if err == AlertWouldBlock { + return nil, nil, AlertWouldBlock + } + + if err != nil { + logf(logTypeHandshake, "Server Error reading record type (1): %v", err) + return nil, nil, AlertBadRecordMAC + } + + logf(logTypeHandshake, "Server got record type(1): %v", t) + + if t != RecordTypeApplicationData { + break + } + + // Read a record into the buffer. Note that this is safe + // in blocking mode because we read the record in + // PeekRecordType. + pt, err := state.hsCtx.hIn.conn.ReadRecord() + if err != nil { + logf(logTypeHandshake, "Server error reading early data record: %v", err) + return nil, nil, AlertInternalError + } + + logf(logTypeHandshake, "Server read early data: %x", pt.fragment) + state.hsCtx.earlyData = append(state.hsCtx.earlyData, pt.fragment...) + } + hm, alert := hr.ReadMessage() if alert != AlertNoAlert { return nil, nil, alert @@ -813,10 +852,44 @@ func (state serverStateWaitEOED) Next(hr handshakeMessageReader) (HandshakeState return waitFlight2, toSend, AlertNoAlert } +var _ HandshakeState = &serverStateReadPastEarlyData{} + +type serverStateReadPastEarlyData struct { + hsCtx *HandshakeContext + next *HandshakeState +} + +func (state serverStateReadPastEarlyData) Next(hr handshakeMessageReader) (HandshakeState, []HandshakeAction, Alert) { + for { + logf(logTypeHandshake, "Server reading past early data...") + // Scan past all records that fail to decrypt + _, err := state.hsCtx.hIn.conn.PeekRecordType(!state.hsCtx.hIn.nonblocking) + if err == nil { + break + } + + if err == AlertWouldBlock { + return nil, nil, AlertWouldBlock + } + + // Continue on DecryptError + _, ok := err.(DecryptError) + if !ok { + return nil, nil, AlertInternalError // Really need something else. + } + } + + return *state.next, nil, AlertNoAlert +} + +func (state serverStateReadPastEarlyData) State() State { + return StateServerReadPastEarlyData +} + type serverStateWaitFlight2 struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -868,7 +941,7 @@ func (state serverStateWaitFlight2) Next(_ handshakeMessageReader) (HandshakeSta type serverStateWaitCert struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte clientHandshakeTrafficSecret []byte @@ -940,7 +1013,7 @@ func (state serverStateWaitCert) Next(hr handshakeMessageReader) (HandshakeState type serverStateWaitCV struct { Config *Config Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte @@ -1023,7 +1096,7 @@ func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, type serverStateWaitFinished struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext cryptoParams CipherSuiteParams masterSecret []byte @@ -1082,6 +1155,8 @@ func (state serverStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS // Compute client traffic keys clientTrafficKeys := makeTrafficKeys(state.cryptoParams, state.clientTrafficSecret) + state.hsCtx.receivedFinalFlight() + logf(logTypeHandshake, "[ServerStateWaitFinished] -> [StateConnected]") nextState := stateConnected{ Params: state.Params, diff --git a/vendor/github.com/bifurcation/mint/state-machine.go b/vendor/github.com/bifurcation/mint/state-machine.go index 7639c5f6..558b76cc 100644 --- a/vendor/github.com/bifurcation/mint/state-machine.go +++ b/vendor/github.com/bifurcation/mint/state-machine.go @@ -17,10 +17,6 @@ type SendQueuedHandshake struct{} type SendEarlyData struct{} -type ReadEarlyData struct{} - -type ReadPastEarlyData struct{} - type RekeyIn struct { epoch Epoch KeySet keySet @@ -31,6 +27,10 @@ type RekeyOut struct { KeySet keySet } +type ResetOut struct { + seq uint64 +} + type StorePSK struct { PSK PreSharedKey } @@ -50,7 +50,6 @@ type AppExtensionHandler interface { type ConnectionOptions struct { ServerName string NextProtos []string - EarlyData []byte } // ConnectionParameters objects represent the parameters negotiated for a @@ -60,6 +59,7 @@ type ConnectionParameters struct { UsingDH bool ClientSendingEarlyData bool UsingEarlyData bool + RejectedEarlyData bool UsingClientAuth bool CipherSuite CipherSuite @@ -69,7 +69,13 @@ type ConnectionParameters struct { // Working state for the handshake. type HandshakeContext struct { - hIn, hOut *HandshakeLayer + timeoutMS uint32 + timers *timerSet + recvdRecords []uint64 + sentFragments []*SentHandshakeFragment + hIn, hOut *HandshakeLayer + waitingNextFlight bool + earlyData []byte } func (hc *HandshakeContext) SetVersion(version uint16) { @@ -84,7 +90,7 @@ func (hc *HandshakeContext) SetVersion(version uint16) { // stateConnected is symmetric between client and server type stateConnected struct { Params ConnectionParameters - hsCtx HandshakeContext + hsCtx *HandshakeContext isClient bool cryptoParams CipherSuiteParams resumptionSecret []byte diff --git a/vendor/github.com/bifurcation/mint/timer.go b/vendor/github.com/bifurcation/mint/timer.go new file mode 100644 index 00000000..0b7f7aff --- /dev/null +++ b/vendor/github.com/bifurcation/mint/timer.go @@ -0,0 +1,122 @@ +package mint + +import ( + "time" +) + +// This is a simple timer implementation. Timers are stored in a sorted +// list. +// TODO(ekr@rtfm.com): Add a way to uncouple these from the system +// clock. +type timerCb func() error + +type timer struct { + label string + cb timerCb + deadline time.Time + duration uint32 +} + +type timerSet struct { + ts []*timer +} + +func newTimerSet() *timerSet { + return &timerSet{} +} + +func (ts *timerSet) start(label string, cb timerCb, delayMs uint32) *timer { + now := time.Now() + t := timer{ + label, + cb, + now.Add(time.Millisecond * time.Duration(delayMs)), + delayMs, + } + logf(logTypeHandshake, "Timer %s set [%v -> %v]", t.label, now, t.deadline) + + var i int + ntimers := len(ts.ts) + for i = 0; i < ntimers; i++ { + if t.deadline.Before(ts.ts[i].deadline) { + break + } + } + + tmp := make([]*timer, 0, ntimers+1) + tmp = append(tmp, ts.ts[:i]...) + tmp = append(tmp, &t) + tmp = append(tmp, ts.ts[i:]...) + ts.ts = tmp + + return &t +} + +// TODO(ekr@rtfm.com): optimize this now that the list is sorted. +// We should be able to do just one list manipulation, as long +// as we're careful about how we handle inserts during callbacks. +func (ts *timerSet) check(now time.Time) error { + for i, t := range ts.ts { + if now.After(t.deadline) { + ts.ts = append(ts.ts[:i], ts.ts[:i+1]...) + if t.cb != nil { + logf(logTypeHandshake, "Timer %s expired [%v > %v]", t.label, now, t.deadline) + cb := t.cb + t.cb = nil + err := cb() + if err != nil { + return err + } + } + } else { + break + } + } + return nil +} + +// Returns the next time any of the timers would fire. +func (ts *timerSet) remaining() (bool, time.Duration) { + for _, t := range ts.ts { + if t.cb != nil { + return true, time.Until(t.deadline) + } + } + + return false, time.Duration(0) +} + +func (ts *timerSet) cancel(label string) { + for _, t := range ts.ts { + if t.label == label { + t.cancel() + } + } +} + +func (ts *timerSet) getTimer(label string) *timer { + for _, t := range ts.ts { + if t.label == label && t.cb != nil { + return t + } + } + return nil +} + +func (ts *timerSet) getAllTimers() []string { + var ret []string + + for _, t := range ts.ts { + if t.cb != nil { + ret = append(ret, t.label) + } + } + + return ret +} + +func (t *timer) cancel() { + logf(logTypeHandshake, "Timer %s cancelled", t.label) + t.cb = nil + t.label = "" +} diff --git a/vendor/vendor.json b/vendor/vendor.json index f1f184fd..2329cb55 100644 --- a/vendor/vendor.json +++ b/vendor/vendor.json @@ -3,10 +3,10 @@ "ignore": "test", "package": [ { - "checksumSHA1": "uPcaX0wCGnS167g9KfKOrDWydWQ=", + "checksumSHA1": "Mb7/IC/rqrfJzLKVV5EIvZ+R8ls=", "path": "github.com/bifurcation/mint", - "revision": "30a67d8540b4f721cfea9f9ae1cd7f22d227a054", - "revisionTime": "2018-02-24T18:21:15Z" + "revision": "93c51c6ce11597a26e246fc33a301d62d3439cd2", + "revisionTime": "2018-07-15T13:32:06Z" }, { "checksumSHA1": "PZNcjO1c9gV/LZzppwpVRl6+QAY=",