Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/mp4ff-decrypt/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ options:
-init string
Path to init file with encryption info (scheme, kid, pssh)
-key string
Required: key (32 hex or 24 base64 chars)
Required: key (32 hex or 24 base64 chars) or kid:key pair. Can be repeated
-version
Get mp4ff version
*/
Expand Down
89 changes: 78 additions & 11 deletions cmd/mp4ff-decrypt/main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package main

import (
"encoding/hex"
"errors"
"flag"
"fmt"
"io"
"os"
"strings"

"github.com/Eyevinn/mp4ff/internal"
"github.com/Eyevinn/mp4ff/mp4"
Expand All @@ -23,10 +25,75 @@ Usage of %s:

type options struct {
initFilePath string
keyStr string
keyStrs stringSliceFlag
version bool
}

type stringSliceFlag []string

func (s *stringSliceFlag) String() string {
return strings.Join(*s, ",")
}

func (s *stringSliceFlag) Set(value string) error {
*s = append(*s, value)
return nil
}

func parseKeys(keyStrs []string) (key []byte, keysByKID map[string][]byte, strictKIDMode bool, err error) {
if len(keyStrs) == 0 {
return nil, nil, false, fmt.Errorf("no key specified")
}

hasKIDPair := false
hasLegacyKey := false
for _, keyStr := range keyStrs {
if strings.Contains(keyStr, ":") {
hasKIDPair = true
} else {
hasLegacyKey = true
}
}

if hasKIDPair && hasLegacyKey {
return nil, nil, false, fmt.Errorf("cannot mix legacy key and kid:key key format")
}

if !hasKIDPair {
if len(keyStrs) != 1 {
return nil, nil, false, fmt.Errorf("multiple legacy keys are not supported")
}
key, err = mp4.UnpackKey(keyStrs[0])
if err != nil {
return nil, nil, false, fmt.Errorf("unpacking key: %w", err)
}
return key, nil, false, nil
}

keysByKID = make(map[string][]byte, len(keyStrs))
for _, keyStr := range keyStrs {
parts := strings.SplitN(keyStr, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return nil, nil, false, fmt.Errorf("bad kid:key format %q", keyStr)
}
kid, err := mp4.UnpackKey(parts[0])
if err != nil {
return nil, nil, false, fmt.Errorf("unpacking kid: %w", err)
}
kidHex := hex.EncodeToString(kid)
if _, exists := keysByKID[kidHex]; exists {
return nil, nil, false, fmt.Errorf("duplicate kid %s", kidHex)
}
k, err := mp4.UnpackKey(parts[1])
if err != nil {
return nil, nil, false, fmt.Errorf("unpacking key for kid %s: %w", kidHex, err)
}
keysByKID[kidHex] = k
}

return nil, keysByKID, true, nil
}

func parseOptions(fs *flag.FlagSet, args []string) (*options, error) {
fs.Usage = func() {
fmt.Fprintf(os.Stderr, usg, appName, appName)
Expand All @@ -36,7 +103,7 @@ func parseOptions(fs *flag.FlagSet, args []string) (*options, error) {

opts := options{}
fs.StringVar(&opts.initFilePath, "init", "", "Path to init file with encryption info (scheme, kid, pssh)")
fs.StringVar(&opts.keyStr, "key", "", "Required: key (32 hex or 24 base64 chars)")
fs.Var(&opts.keyStrs, "key", "Required: key (32 hex or 24 base64 chars) or kid:key pair. Can be repeated")
fs.BoolVar(&opts.version, "version", false, "Get mp4ff version")
err := fs.Parse(args[1:])
return &opts, err
Expand Down Expand Up @@ -73,14 +140,10 @@ func run(args []string) error {
var inFilePath = fs.Arg(0)
var outFilePath = fs.Arg(1)

if opts.keyStr == "" {
fs.Usage()
return fmt.Errorf("no key specified")
}

key, err := mp4.UnpackKey(opts.keyStr)
key, keysByKID, strictKIDMode, err := parseKeys(opts.keyStrs)
if err != nil {
return fmt.Errorf("unpacking key: %w", err)
fs.Usage()
return err
}

ifh, err := os.Open(inFilePath)
Expand All @@ -102,14 +165,18 @@ func run(args []string) error {
defer inith.Close()
}

err = decryptFile(ifh, inith, ofh, key)
err = decryptFileWithKeyMap(ifh, inith, ofh, key, keysByKID, strictKIDMode)
if err != nil {
return fmt.Errorf("decryptFile: %w", err)
}
return nil
}

func decryptFile(r, initR io.Reader, w io.Writer, key []byte) error {
return decryptFileWithKeyMap(r, initR, w, key, nil, false)
}

func decryptFileWithKeyMap(r, initR io.Reader, w io.Writer, key []byte, keysByKID map[string][]byte, strictKIDMode bool) error {
inMp4, err := mp4.DecodeFile(r)
if err != nil {
return err
Expand Down Expand Up @@ -145,7 +212,7 @@ func decryptFile(r, initR io.Reader, w io.Writer, key []byte) error {
}

for _, seg := range inMp4.Segments {
err = mp4.DecryptSegment(seg, decryptInfo, key)
err = mp4.DecryptSegmentWithKeys(seg, decryptInfo, key, keysByKID, strictKIDMode)
if err != nil {
return fmt.Errorf("decryptSegment: %w", err)
}
Expand Down
103 changes: 103 additions & 0 deletions cmd/mp4ff-decrypt/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package main

import (
"bytes"
"encoding/hex"
"fmt"
"os"
"path"
"strings"
"testing"

"github.com/Eyevinn/mp4ff/mp4"
Expand Down Expand Up @@ -123,6 +125,107 @@ func TestDecodeFiles(t *testing.T) {
}
}

func TestParseKeys(t *testing.T) {
legacyKey := "00112233445566778899aabbccddeeff"
kidWithDash := "855ca997-b201-5736-f3d6-a59c9eff84d9"
kidNoDash := "855ca997b2015736f3d6a59c9eff84d9"

t.Run("legacy key", func(t *testing.T) {
key, keysByKID, strictMode, err := parseKeys([]string{legacyKey})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if strictMode {
t.Fatal("expected non-strict mode")
}
if len(key) != 16 {
t.Fatalf("unexpected key length: %d", len(key))
}
if len(keysByKID) != 0 {
t.Fatalf("expected no kid map, got %d", len(keysByKID))
}
})

t.Run("duplicate kid fails", func(t *testing.T) {
_, _, _, err := parseKeys([]string{kidWithDash + ":" + legacyKey, kidNoDash + ":" + legacyKey})
if err == nil {
t.Fatal("expected duplicate kid error")
}
if !strings.Contains(err.Error(), "duplicate kid") {
t.Fatalf("unexpected error: %v", err)
}
})

t.Run("mixed mode fails", func(t *testing.T) {
_, _, _, err := parseKeys([]string{legacyKey, kidNoDash + ":" + legacyKey})
if err == nil {
t.Fatal("expected strict mixed mode error")
}
if !strings.Contains(err.Error(), "cannot mix") {
t.Fatalf("unexpected error: %v", err)
}
})
}

func TestStrictKIDKeySelection(t *testing.T) {
inFile := "../../mp4/testdata/cbcs_audio.mp4"
expectedOutFile := "../../mp4/testdata/cbcs_audiodec.mp4"
rawKey := "5ffd93861fa776e96cccd934898fc1c8"
tmpDir := t.TempDir()
outFile := path.Join(tmpDir, "outfile.mp4")

input, err := mp4.ReadMP4File(inFile)
if err != nil {
t.Fatal(err)
}
decInfo, err := mp4.DecryptInit(input.Init)
if err != nil {
t.Fatal(err)
}
if len(decInfo.TrackInfos) == 0 || decInfo.TrackInfos[0].Sinf == nil {
t.Fatal("missing encrypted track info")
}
kid := decInfo.TrackInfos[0].Sinf.Schi.Tenc.DefaultKID
kidHex := hex.EncodeToString(kid)

t.Run("matching kid decrypts", func(t *testing.T) {
args := []string{"mp4ff-decrypt", "-key", kidHex + ":" + rawKey, inFile, outFile}
err := run(args)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

expectedOut, err := os.ReadFile(expectedOutFile)
if err != nil {
t.Fatal(err)
}
out, err := os.ReadFile(outFile)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(expectedOut, out) {
t.Fatal("output file does not match expected")
}
})

t.Run("missing kid fails", func(t *testing.T) {
missingKID := kidHex
if missingKID[0] == '0' {
missingKID = "1" + missingKID[1:]
} else {
missingKID = "0" + missingKID[1:]
}
args := []string{"mp4ff-decrypt", "-key", missingKID + ":" + rawKey, inFile, outFile}
err := run(args)
if err == nil {
t.Fatal("expected missing kid error")
}
if !strings.Contains(err.Error(), "requested key was not found") {
t.Fatalf("unexpected error: %v", err)
}
})
}

func BenchmarkDecodeCenc(b *testing.B) {
inFile := "../../mp4/testdata/prog_8s_enc_dashinit.mp4"
hexKey := "63cb5f7184dd4b689a5c5ff11ee6a328"
Expand Down
52 changes: 50 additions & 2 deletions mp4/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"encoding/hex"
"fmt"

"github.com/Eyevinn/mp4ff/avc"
Expand Down Expand Up @@ -623,6 +624,13 @@ func DecryptInit(init *InitSegment) (DecryptInfo, error) {

// DecryptSegment decrypts a media segment in place
func DecryptSegment(seg *MediaSegment, di DecryptInfo, key []byte) error {
return DecryptSegmentWithKeys(seg, di, key, nil, false)
}

// DecryptSegmentWithKeys decrypts a media segment in place using either a legacy key
// or keys selected by KID. KID values are expected as 32-char lowercase hex without dashes.
// If strictKIDMode is true, encrypted tracks must have a matching key in keysByKID.
func DecryptSegmentWithKeys(seg *MediaSegment, di DecryptInfo, key []byte, keysByKID map[string][]byte, strictKIDMode bool) error {
for _, frag := range seg.Fragments {
for _, traf := range frag.Moof.Trafs {
hasSenc, _ := traf.ContainsSencBox()
Expand All @@ -635,7 +643,7 @@ func DecryptSegment(seg *MediaSegment, di DecryptInfo, key []byte) error {
}
}
for _, frag := range seg.Fragments {
err := DecryptFragment(frag, di, key)
err := DecryptFragmentWithKeys(frag, di, key, keysByKID, strictKIDMode)
if err != nil {
return err
}
Expand All @@ -649,6 +657,42 @@ func DecryptSegment(seg *MediaSegment, di DecryptInfo, key []byte) error {

// DecryptFragment decrypts a fragment in place
func DecryptFragment(frag *Fragment, di DecryptInfo, key []byte) error {
return DecryptFragmentWithKeys(frag, di, key, nil, false)
}

func getTrackKIDHex(ti DecryptTrackInfo) (string, error) {
if ti.Sinf == nil || ti.Sinf.Schi == nil || ti.Sinf.Schi.Tenc == nil {
return "", fmt.Errorf("missing tenc for trackID=%d", ti.TrackID)
}
kid := ti.Sinf.Schi.Tenc.DefaultKID
if len(kid) != 16 {
return "", fmt.Errorf("bad kid length %d for trackID=%d", len(kid), ti.TrackID)
}
return hex.EncodeToString(kid), nil
}

func getTrackKey(ti DecryptTrackInfo, key []byte, keysByKID map[string][]byte, strictKIDMode bool) ([]byte, error) {
if len(keysByKID) == 0 {
return key, nil
}
kidHex, err := getTrackKIDHex(ti)
if err != nil {
return nil, err
}
mappedKey, ok := keysByKID[kidHex]
if !ok {
if strictKIDMode {
return nil, fmt.Errorf("requested key was not found for kid=%s", kidHex)
}
return key, nil
}
return mappedKey, nil
}

// DecryptFragmentWithKeys decrypts a fragment in place using either a legacy key
// or keys selected by KID. KID values are expected as 32-char lowercase hex without dashes.
// If strictKIDMode is true, encrypted tracks must have a matching key in keysByKID.
func DecryptFragmentWithKeys(frag *Fragment, di DecryptInfo, key []byte, keysByKID map[string][]byte, strictKIDMode bool) error {
moof := frag.Moof
var nrBytesRemoved uint64 = 0
for _, traf := range moof.Trafs {
Expand All @@ -671,6 +715,10 @@ func DecryptFragment(frag *Fragment, di DecryptInfo, key []byte) error {
}

tenc := ti.Sinf.Schi.Tenc
trackKey, err := getTrackKey(ti, key, keysByKID, strictKIDMode)
if err != nil {
return err
}
samples, err := frag.GetFullSamples(ti.Trex)
if err != nil {
return err
Expand All @@ -682,7 +730,7 @@ func DecryptFragment(frag *Fragment, di DecryptInfo, key []byte) error {
senc = traf.UUIDSenc.Senc
}

err = decryptSamplesInPlace(schemeType, samples, key, tenc, senc)
err = decryptSamplesInPlace(schemeType, samples, trackKey, tenc, senc)
if err != nil {
return err
}
Expand Down
Loading