diff --git a/cmd/mp4ff-decrypt/doc.go b/cmd/mp4ff-decrypt/doc.go index 74f3e0f8..09a9637a 100644 --- a/cmd/mp4ff-decrypt/doc.go +++ b/cmd/mp4ff-decrypt/doc.go @@ -10,7 +10,7 @@ options: -init string Path to init file with encryption info (scheme, kid, pssh) -key string - Required: key (32 hex or 24 base64 chars) + Required: key (32 hex or 24 base64 chars) or kid:key pair. Can be repeated -version Get mp4ff version */ diff --git a/cmd/mp4ff-decrypt/main.go b/cmd/mp4ff-decrypt/main.go index 61b1bb62..62b4bd8b 100644 --- a/cmd/mp4ff-decrypt/main.go +++ b/cmd/mp4ff-decrypt/main.go @@ -1,11 +1,13 @@ package main import ( + "encoding/hex" "errors" "flag" "fmt" "io" "os" + "strings" "github.com/Eyevinn/mp4ff/internal" "github.com/Eyevinn/mp4ff/mp4" @@ -23,10 +25,75 @@ Usage of %s: type options struct { initFilePath string - keyStr string + keyStrs stringSliceFlag version bool } +type stringSliceFlag []string + +func (s *stringSliceFlag) String() string { + return strings.Join(*s, ",") +} + +func (s *stringSliceFlag) Set(value string) error { + *s = append(*s, value) + return nil +} + +func parseKeys(keyStrs []string) (key []byte, keysByKID map[string][]byte, strictKIDMode bool, err error) { + if len(keyStrs) == 0 { + return nil, nil, false, fmt.Errorf("no key specified") + } + + hasKIDPair := false + hasLegacyKey := false + for _, keyStr := range keyStrs { + if strings.Contains(keyStr, ":") { + hasKIDPair = true + } else { + hasLegacyKey = true + } + } + + if hasKIDPair && hasLegacyKey { + return nil, nil, false, fmt.Errorf("cannot mix legacy key and kid:key key format") + } + + if !hasKIDPair { + if len(keyStrs) != 1 { + return nil, nil, false, fmt.Errorf("multiple legacy keys are not supported") + } + key, err = mp4.UnpackKey(keyStrs[0]) + if err != nil { + return nil, nil, false, fmt.Errorf("unpacking key: %w", err) + } + return key, nil, false, nil + } + + keysByKID = make(map[string][]byte, len(keyStrs)) + for _, keyStr := range keyStrs { + parts := strings.SplitN(keyStr, ":", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return nil, nil, false, fmt.Errorf("bad kid:key format %q", keyStr) + } + kid, err := mp4.UnpackKey(parts[0]) + if err != nil { + return nil, nil, false, fmt.Errorf("unpacking kid: %w", err) + } + kidHex := hex.EncodeToString(kid) + if _, exists := keysByKID[kidHex]; exists { + return nil, nil, false, fmt.Errorf("duplicate kid %s", kidHex) + } + k, err := mp4.UnpackKey(parts[1]) + if err != nil { + return nil, nil, false, fmt.Errorf("unpacking key for kid %s: %w", kidHex, err) + } + keysByKID[kidHex] = k + } + + return nil, keysByKID, true, nil +} + func parseOptions(fs *flag.FlagSet, args []string) (*options, error) { fs.Usage = func() { fmt.Fprintf(os.Stderr, usg, appName, appName) @@ -36,7 +103,7 @@ func parseOptions(fs *flag.FlagSet, args []string) (*options, error) { opts := options{} fs.StringVar(&opts.initFilePath, "init", "", "Path to init file with encryption info (scheme, kid, pssh)") - fs.StringVar(&opts.keyStr, "key", "", "Required: key (32 hex or 24 base64 chars)") + fs.Var(&opts.keyStrs, "key", "Required: key (32 hex or 24 base64 chars) or kid:key pair. Can be repeated") fs.BoolVar(&opts.version, "version", false, "Get mp4ff version") err := fs.Parse(args[1:]) return &opts, err @@ -73,14 +140,10 @@ func run(args []string) error { var inFilePath = fs.Arg(0) var outFilePath = fs.Arg(1) - if opts.keyStr == "" { - fs.Usage() - return fmt.Errorf("no key specified") - } - - key, err := mp4.UnpackKey(opts.keyStr) + key, keysByKID, strictKIDMode, err := parseKeys(opts.keyStrs) if err != nil { - return fmt.Errorf("unpacking key: %w", err) + fs.Usage() + return err } ifh, err := os.Open(inFilePath) @@ -102,7 +165,7 @@ func run(args []string) error { defer inith.Close() } - err = decryptFile(ifh, inith, ofh, key) + err = decryptFileWithKeyMap(ifh, inith, ofh, key, keysByKID, strictKIDMode) if err != nil { return fmt.Errorf("decryptFile: %w", err) } @@ -110,6 +173,10 @@ func run(args []string) error { } func decryptFile(r, initR io.Reader, w io.Writer, key []byte) error { + return decryptFileWithKeyMap(r, initR, w, key, nil, false) +} + +func decryptFileWithKeyMap(r, initR io.Reader, w io.Writer, key []byte, keysByKID map[string][]byte, strictKIDMode bool) error { inMp4, err := mp4.DecodeFile(r) if err != nil { return err @@ -145,7 +212,7 @@ func decryptFile(r, initR io.Reader, w io.Writer, key []byte) error { } for _, seg := range inMp4.Segments { - err = mp4.DecryptSegment(seg, decryptInfo, key) + err = mp4.DecryptSegmentWithKeys(seg, decryptInfo, key, keysByKID, strictKIDMode) if err != nil { return fmt.Errorf("decryptSegment: %w", err) } diff --git a/cmd/mp4ff-decrypt/main_test.go b/cmd/mp4ff-decrypt/main_test.go index 90b8ef72..13698faa 100644 --- a/cmd/mp4ff-decrypt/main_test.go +++ b/cmd/mp4ff-decrypt/main_test.go @@ -2,9 +2,11 @@ package main import ( "bytes" + "encoding/hex" "fmt" "os" "path" + "strings" "testing" "github.com/Eyevinn/mp4ff/mp4" @@ -123,6 +125,107 @@ func TestDecodeFiles(t *testing.T) { } } +func TestParseKeys(t *testing.T) { + legacyKey := "00112233445566778899aabbccddeeff" + kidWithDash := "855ca997-b201-5736-f3d6-a59c9eff84d9" + kidNoDash := "855ca997b2015736f3d6a59c9eff84d9" + + t.Run("legacy key", func(t *testing.T) { + key, keysByKID, strictMode, err := parseKeys([]string{legacyKey}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if strictMode { + t.Fatal("expected non-strict mode") + } + if len(key) != 16 { + t.Fatalf("unexpected key length: %d", len(key)) + } + if len(keysByKID) != 0 { + t.Fatalf("expected no kid map, got %d", len(keysByKID)) + } + }) + + t.Run("duplicate kid fails", func(t *testing.T) { + _, _, _, err := parseKeys([]string{kidWithDash + ":" + legacyKey, kidNoDash + ":" + legacyKey}) + if err == nil { + t.Fatal("expected duplicate kid error") + } + if !strings.Contains(err.Error(), "duplicate kid") { + t.Fatalf("unexpected error: %v", err) + } + }) + + t.Run("mixed mode fails", func(t *testing.T) { + _, _, _, err := parseKeys([]string{legacyKey, kidNoDash + ":" + legacyKey}) + if err == nil { + t.Fatal("expected strict mixed mode error") + } + if !strings.Contains(err.Error(), "cannot mix") { + t.Fatalf("unexpected error: %v", err) + } + }) +} + +func TestStrictKIDKeySelection(t *testing.T) { + inFile := "../../mp4/testdata/cbcs_audio.mp4" + expectedOutFile := "../../mp4/testdata/cbcs_audiodec.mp4" + rawKey := "5ffd93861fa776e96cccd934898fc1c8" + tmpDir := t.TempDir() + outFile := path.Join(tmpDir, "outfile.mp4") + + input, err := mp4.ReadMP4File(inFile) + if err != nil { + t.Fatal(err) + } + decInfo, err := mp4.DecryptInit(input.Init) + if err != nil { + t.Fatal(err) + } + if len(decInfo.TrackInfos) == 0 || decInfo.TrackInfos[0].Sinf == nil { + t.Fatal("missing encrypted track info") + } + kid := decInfo.TrackInfos[0].Sinf.Schi.Tenc.DefaultKID + kidHex := hex.EncodeToString(kid) + + t.Run("matching kid decrypts", func(t *testing.T) { + args := []string{"mp4ff-decrypt", "-key", kidHex + ":" + rawKey, inFile, outFile} + err := run(args) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + expectedOut, err := os.ReadFile(expectedOutFile) + if err != nil { + t.Fatal(err) + } + out, err := os.ReadFile(outFile) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(expectedOut, out) { + t.Fatal("output file does not match expected") + } + }) + + t.Run("missing kid fails", func(t *testing.T) { + missingKID := kidHex + if missingKID[0] == '0' { + missingKID = "1" + missingKID[1:] + } else { + missingKID = "0" + missingKID[1:] + } + args := []string{"mp4ff-decrypt", "-key", missingKID + ":" + rawKey, inFile, outFile} + err := run(args) + if err == nil { + t.Fatal("expected missing kid error") + } + if !strings.Contains(err.Error(), "requested key was not found") { + t.Fatalf("unexpected error: %v", err) + } + }) +} + func BenchmarkDecodeCenc(b *testing.B) { inFile := "../../mp4/testdata/prog_8s_enc_dashinit.mp4" hexKey := "63cb5f7184dd4b689a5c5ff11ee6a328" diff --git a/mp4/crypto.go b/mp4/crypto.go index b6eb5ff1..67dfd53e 100644 --- a/mp4/crypto.go +++ b/mp4/crypto.go @@ -4,6 +4,7 @@ import ( "crypto/aes" "crypto/cipher" "encoding/binary" + "encoding/hex" "fmt" "github.com/Eyevinn/mp4ff/avc" @@ -623,6 +624,13 @@ func DecryptInit(init *InitSegment) (DecryptInfo, error) { // DecryptSegment decrypts a media segment in place func DecryptSegment(seg *MediaSegment, di DecryptInfo, key []byte) error { + return DecryptSegmentWithKeys(seg, di, key, nil, false) +} + +// DecryptSegmentWithKeys decrypts a media segment in place using either a legacy key +// or keys selected by KID. KID values are expected as 32-char lowercase hex without dashes. +// If strictKIDMode is true, encrypted tracks must have a matching key in keysByKID. +func DecryptSegmentWithKeys(seg *MediaSegment, di DecryptInfo, key []byte, keysByKID map[string][]byte, strictKIDMode bool) error { for _, frag := range seg.Fragments { for _, traf := range frag.Moof.Trafs { hasSenc, _ := traf.ContainsSencBox() @@ -635,7 +643,7 @@ func DecryptSegment(seg *MediaSegment, di DecryptInfo, key []byte) error { } } for _, frag := range seg.Fragments { - err := DecryptFragment(frag, di, key) + err := DecryptFragmentWithKeys(frag, di, key, keysByKID, strictKIDMode) if err != nil { return err } @@ -649,6 +657,42 @@ func DecryptSegment(seg *MediaSegment, di DecryptInfo, key []byte) error { // DecryptFragment decrypts a fragment in place func DecryptFragment(frag *Fragment, di DecryptInfo, key []byte) error { + return DecryptFragmentWithKeys(frag, di, key, nil, false) +} + +func getTrackKIDHex(ti DecryptTrackInfo) (string, error) { + if ti.Sinf == nil || ti.Sinf.Schi == nil || ti.Sinf.Schi.Tenc == nil { + return "", fmt.Errorf("missing tenc for trackID=%d", ti.TrackID) + } + kid := ti.Sinf.Schi.Tenc.DefaultKID + if len(kid) != 16 { + return "", fmt.Errorf("bad kid length %d for trackID=%d", len(kid), ti.TrackID) + } + return hex.EncodeToString(kid), nil +} + +func getTrackKey(ti DecryptTrackInfo, key []byte, keysByKID map[string][]byte, strictKIDMode bool) ([]byte, error) { + if len(keysByKID) == 0 { + return key, nil + } + kidHex, err := getTrackKIDHex(ti) + if err != nil { + return nil, err + } + mappedKey, ok := keysByKID[kidHex] + if !ok { + if strictKIDMode { + return nil, fmt.Errorf("requested key was not found for kid=%s", kidHex) + } + return key, nil + } + return mappedKey, nil +} + +// DecryptFragmentWithKeys decrypts a fragment in place using either a legacy key +// or keys selected by KID. KID values are expected as 32-char lowercase hex without dashes. +// If strictKIDMode is true, encrypted tracks must have a matching key in keysByKID. +func DecryptFragmentWithKeys(frag *Fragment, di DecryptInfo, key []byte, keysByKID map[string][]byte, strictKIDMode bool) error { moof := frag.Moof var nrBytesRemoved uint64 = 0 for _, traf := range moof.Trafs { @@ -671,6 +715,10 @@ func DecryptFragment(frag *Fragment, di DecryptInfo, key []byte) error { } tenc := ti.Sinf.Schi.Tenc + trackKey, err := getTrackKey(ti, key, keysByKID, strictKIDMode) + if err != nil { + return err + } samples, err := frag.GetFullSamples(ti.Trex) if err != nil { return err @@ -682,7 +730,7 @@ func DecryptFragment(frag *Fragment, di DecryptInfo, key []byte) error { senc = traf.UUIDSenc.Senc } - err = decryptSamplesInPlace(schemeType, samples, key, tenc, senc) + err = decryptSamplesInPlace(schemeType, samples, trackKey, tenc, senc) if err != nil { return err }