Skip to content

Commit

Permalink
refactor(device): do not expose fields
Browse files Browse the repository at this point in the history
  • Loading branch information
iyear committed Oct 23, 2023
1 parent 9db8941 commit ddf7085
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 20 deletions.
8 changes: 4 additions & 4 deletions cdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func (c *CDM) GetLicenseChallenge(pssh *PSSH, typ wvpb.LicenseType, privacyMode

req.EncryptedClientId = encClientID
} else {
req.ClientId = c.device.ClientID
req.ClientId = c.device.ClientID()
}

reqData, err := proto.Marshal(req)
Expand All @@ -127,7 +127,7 @@ func (c *CDM) GetLicenseChallenge(pssh *PSSH, typ wvpb.LicenseType, privacyMode
hashed := sha1.Sum(reqData)
pss, err := rsa.SignPSS(
rand.New(c.rand),
c.device.PrivateKey,
c.device.PrivateKey(),
crypto.SHA1,
hashed[:],
&rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash})
Expand Down Expand Up @@ -161,7 +161,7 @@ func (c *CDM) encryptClientID(cert *wvpb.DrmCertificate) (*wvpb.EncryptedClientI
}

// encryptedClientID
clientID, err := proto.Marshal(c.device.ClientID)
clientID, err := proto.Marshal(c.device.ClientID())
if err != nil {
return nil, fmt.Errorf("marshal client id: %w", err)
}
Expand Down Expand Up @@ -211,7 +211,7 @@ func (c *CDM) parseLicense(license, licenseRequest []byte) ([]*Key, error) {
return nil, fmt.Errorf("invalid license type: %v", signedMsg.GetType())
}

sessionKey, err := c.rsaOAEPDecrypt(c.device.PrivateKey, signedMsg.SessionKey)
sessionKey, err := c.rsaOAEPDecrypt(c.device.PrivateKey(), signedMsg.SessionKey)
if err != nil {
return nil, fmt.Errorf("decrypt session key: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cdm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ var l3cdm *device.Device

func init() {
for _, l3 := range device.L3 {
if l3.SystemID == 4464 {
if l3.DrmCertificate().GetSystemId() == 4464 {
l3cdm = l3
break
}
Expand Down
47 changes: 37 additions & 10 deletions device/device.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package device

import (
"bytes"
"crypto/rsa"
"crypto/x509"
"embed"
Expand All @@ -21,9 +22,25 @@ const (
)

type Device struct {
SystemID uint32
ClientID *wvpb.ClientIdentification
PrivateKey *rsa.PrivateKey
clientID *wvpb.ClientIdentification
cert *wvpb.DrmCertificate
privateKey *rsa.PrivateKey
}

func New(clientID, privateKey []byte) (*Device, error) {
return toDevice(clientID, privateKey)
}

func (d *Device) ClientID() *wvpb.ClientIdentification {
return d.clientID
}

func (d *Device) DrmCertificate() *wvpb.DrmCertificate {
return d.cert
}

func (d *Device) PrivateKey() *rsa.PrivateKey {
return d.privateKey
}

//go:embed l3
Expand Down Expand Up @@ -69,9 +86,8 @@ func readBuildIns() error {
if err != nil {
return fmt.Errorf("read private key: %w", err)
}
block, _ := pem.Decode(privateKeyData)

device, err := toDevice(clientIDData, block.Bytes)
device, err := toDevice(clientIDData, privateKeyData)
if err != nil {
return fmt.Errorf("to device: %w", err)
}
Expand Down Expand Up @@ -149,18 +165,29 @@ func toDevice(clientID, privateKey []byte) (*Device, error) {
}

return &Device{
SystemID: cert.GetSystemId(),
ClientID: c,
PrivateKey: key,
clientID: c,
cert: cert,
privateKey: key,
}, nil
}

// parsePrivateKey modified from https://go.dev/src/crypto/tls/tls.go#L339
func parsePrivateKey(data []byte) (*rsa.PrivateKey, error) {
if key, err := x509.ParsePKCS1PrivateKey(data); err == nil {
var b = make([]byte, len(data))
copy(b, data)

if bytes.HasPrefix(data, []byte("-----")) {
block, _ := pem.Decode(data)
if block == nil {
return nil, fmt.Errorf("failed to decode PEM block containing private key")
}
b = block.Bytes
}

if key, err := x509.ParsePKCS1PrivateKey(b); err == nil {
return key, nil
}
if key, err := x509.ParsePKCS8PrivateKey(data); err == nil {
if key, err := x509.ParsePKCS8PrivateKey(b); err == nil {
switch k := key.(type) {
case *rsa.PrivateKey:
return k, nil
Expand Down
27 changes: 22 additions & 5 deletions device/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,33 @@ func TestL3Device(t *testing.T) {
t.Logf("L3 Devices: %v", len(L3))
}

//go:embed l3/4464/client_id
var clientID []byte

//go:embed l3/4464/private_key
var privateKey []byte

func TestNewDevice(t *testing.T) {
device, err := New(clientID, privateKey)
require.NoError(t, err)

assert.Equal(t, uint32(4464), device.DrmCertificate().GetSystemId())
assert.Equal(t, wvpb.ClientIdentification_DRM_DEVICE_CERTIFICATE, device.ClientID().GetType())
assert.Equal(t, 1434, len(device.ClientID().GetToken()))
assert.Equal(t, 8, len(device.ClientID().GetClientInfo()))
assert.Equal(t, 256, device.PrivateKey().Size())
}

//go:embed testdata/samsung.wvd
var wvd []byte

func TestFromWVD(t *testing.T) {
device, err := FromWVD(bytes.NewReader(wvd))
require.NoError(t, err)

assert.Equal(t, uint32(5536), device.SystemID)
assert.Equal(t, wvpb.ClientIdentification_DRM_DEVICE_CERTIFICATE, device.ClientID.GetType())
assert.Equal(t, 1434, len(device.ClientID.GetToken()))
assert.Equal(t, 8, len(device.ClientID.GetClientInfo()))
assert.Equal(t, 256, device.PrivateKey.Size())
assert.Equal(t, uint32(5536), device.DrmCertificate().GetSystemId())
assert.Equal(t, wvpb.ClientIdentification_DRM_DEVICE_CERTIFICATE, device.ClientID().GetType())
assert.Equal(t, 1434, len(device.ClientID().GetToken()))
assert.Equal(t, 8, len(device.ClientID().GetClientInfo()))
assert.Equal(t, 256, device.PrivateKey().Size())
}

0 comments on commit ddf7085

Please sign in to comment.