package handshake import ( "encoding/asn1" "fmt" "net" "time" "github.com/lucas-clemente/quic-go/internal/crypto" ) const ( stkPrefixIP byte = iota stkPrefixString ) // An STK is a Source Address token. // It is issued by the server and sent to the client. For the client, it is an opaque blob. // The client can send the STK in subsequent handshakes to prove ownership of its IP address. type STK struct { // The remote address this token was issued for. // If the server is run on a net.UDPConn, this is the string representation of the IP address (net.IP.String()) // Otherwise, this is the string representation of the net.Addr (net.Addr.String()) RemoteAddr string // The time that the STK was issued (resolution 1 second) SentTime time.Time } // token is the struct that is used for ASN1 serialization and deserialization type token struct { Data []byte Timestamp int64 } // An STKGenerator generates STKs type STKGenerator struct { stkSource crypto.StkSource } // NewSTKGenerator initializes a new STKGenerator func NewSTKGenerator() (*STKGenerator, error) { stkSource, err := crypto.NewStkSource() if err != nil { return nil, err } return &STKGenerator{ stkSource: stkSource, }, nil } // NewToken generates a new STK token for a given source address func (g *STKGenerator) NewToken(raddr net.Addr) ([]byte, error) { data, err := asn1.Marshal(token{ Data: encodeRemoteAddr(raddr), Timestamp: time.Now().Unix(), }) if err != nil { return nil, err } return g.stkSource.NewToken(data) } // DecodeToken decodes an STK token func (g *STKGenerator) DecodeToken(encrypted []byte) (*STK, error) { // if the client didn't send any STK, DecodeToken will be called with a nil-slice if len(encrypted) == 0 { return nil, nil } data, err := g.stkSource.DecodeToken(encrypted) if err != nil { return nil, err } t := &token{} rest, err := asn1.Unmarshal(data, t) if err != nil { return nil, err } if len(rest) != 0 { return nil, fmt.Errorf("rest when unpacking token: %d", len(rest)) } return &STK{ RemoteAddr: decodeRemoteAddr(t.Data), SentTime: time.Unix(t.Timestamp, 0), }, nil } // encodeRemoteAddr encodes a remote address such that it can be saved in the STK func encodeRemoteAddr(remoteAddr net.Addr) []byte { if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { return append([]byte{stkPrefixIP}, udpAddr.IP...) } return append([]byte{stkPrefixString}, []byte(remoteAddr.String())...) } // decodeRemoteAddr decodes the remote address saved in the STK func decodeRemoteAddr(data []byte) string { // data will never be empty for an STK that we generated. Check it to be on the safe side if len(data) == 0 { return "" } if data[0] == stkPrefixIP { return net.IP(data[1:]).String() } return string(data[1:]) }