From 579e6a1eee2006efe660a3ff17dac1d477300651 Mon Sep 17 00:00:00 2001 From: lew Date: Fri, 13 Feb 2026 15:12:22 +0000 Subject: [PATCH] feat(identity): added --add-recipient and --remove-recipient flags for multi-recipient keys --- README.md | 20 ++++ cmd/doctor.go | 5 +- cmd/doctor_test.go | 2 +- cmd/identity.go | 139 ++++++++++++++++++++++++- cmd/list.go | 11 +- cmd/mv.go | 13 ++- cmd/ndjson.go | 12 +-- cmd/restore.go | 13 +-- cmd/secret.go | 150 +++++++++++++++++++++++++-- cmd/secret_test.go | 251 +++++++++++++++++++++++++++++++++++++++++++-- cmd/set.go | 8 +- main_test.go | 2 +- 12 files changed, 575 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index a1d6a66..7c5e3b3 100644 --- a/README.md +++ b/README.md @@ -777,6 +777,26 @@ pda identity --new

+By default, secrets are encrypted only for your own identity. To encrypt for additional recipients (e.g. a teammate or another device), use `--add-recipient` with their age public key. All existing secrets are automatically re-encrypted for every recipient. +```bash +# Add a recipient. All secrets are re-encrypted for both keys. +pda identity --add-recipient age1ql3z7hjy54pw3hyww5ayyfg7zqgvc7w3j2elw8zmrj2kg5sfn9aqmcac8p +# ok re-encrypted api-key +# ok added recipient age1ql3z... +# ok re-encrypted 1 secret(s) + +# Remove a recipient. Secrets are re-encrypted without their key. +pda identity --remove-recipient age1ql3z7hjy54pw3hyww5ayyfg7zqgvc7w3j2elw8zmrj2kg5sfn9aqmcac8p + +# Additional recipients are shown in the default identity display. +pda identity +# ok pubkey age1abc... +# ok identity ~/.local/share/pda/identity.txt +# ok recipient age1ql3z... +``` + +

+ ### Doctor `pda doctor` runs a set of health checks of your environment. diff --git a/cmd/doctor.go b/cmd/doctor.go index dcb1b26..8e6fd7f 100644 --- a/cmd/doctor.go +++ b/cmd/doctor.go @@ -134,8 +134,7 @@ func runDoctor(w io.Writer) bool { issues = append(issues, "Fix with 'pda config edit' or 'pda config init --new'") } if unexpectedFiles(cfgDir, map[string]bool{ - "config.toml": true, - "identity.txt": true, + "config.toml": true, }) { issues = append(issues, "Unexpected file(s) in directory") } @@ -353,7 +352,7 @@ func unexpectedDataFiles(dir string) bool { if e.IsDir() && name == ".git" { continue } - if !e.IsDir() && (name == ".gitignore" || filepath.Ext(name) == ".ndjson") { + if !e.IsDir() && (name == ".gitignore" || name == "identity.txt" || name == "recipients.txt" || filepath.Ext(name) == ".ndjson") { continue } return true diff --git a/cmd/doctor_test.go b/cmd/doctor_test.go index 9efd9e3..3bb5f8c 100644 --- a/cmd/doctor_test.go +++ b/cmd/doctor_test.go @@ -71,7 +71,7 @@ func TestDoctorIdentityPermissions(t *testing.T) { t.Setenv("PDA_DATA", dataDir) t.Setenv("PDA_CONFIG", configDir) - idPath := filepath.Join(configDir, "identity.txt") + idPath := filepath.Join(dataDir, "identity.txt") if err := os.WriteFile(idPath, []byte("placeholder\n"), 0o644); err != nil { t.Fatal(err) } diff --git a/cmd/identity.go b/cmd/identity.go index 89e81a9..27a71c8 100644 --- a/cmd/identity.go +++ b/cmd/identity.go @@ -3,6 +3,7 @@ package cmd import ( "fmt" + "filippo.io/age" "github.com/spf13/cobra" ) @@ -24,6 +25,14 @@ func identityRun(cmd *cobra.Command, args []string) error { if err != nil { return err } + addRecipient, err := cmd.Flags().GetString("add-recipient") + if err != nil { + return err + } + removeRecipient, err := cmd.Flags().GetString("remove-recipient") + if err != nil { + return err + } if createNew { existing, err := loadIdentity() @@ -45,6 +54,14 @@ func identityRun(cmd *cobra.Command, args []string) error { return nil } + if addRecipient != "" { + return identityAddRecipient(addRecipient) + } + + if removeRecipient != "" { + return identityRemoveRecipient(removeRecipient) + } + if showPath { path, err := identityPath() if err != nil { @@ -66,12 +83,132 @@ func identityRun(cmd *cobra.Command, args []string) error { path, _ := identityPath() okf("pubkey %s", id.Recipient()) okf("identity %s", path) + + extra, err := loadRecipients() + if err != nil { + return fmt.Errorf("cannot load recipients: %v", err) + } + for _, r := range extra { + okf("recipient %s", r) + } + return nil } +func identityAddRecipient(key string) error { + r, err := age.ParseX25519Recipient(key) + if err != nil { + return fmt.Errorf("cannot add recipient: %v", err) + } + + identity, err := loadIdentity() + if err != nil { + return fmt.Errorf("cannot add recipient: %v", err) + } + if identity == nil { + return withHint( + fmt.Errorf("cannot add recipient: no identity found"), + "create one first with 'pda identity --new'", + ) + } + + if r.String() == identity.Recipient().String() { + return fmt.Errorf("cannot add recipient: key is your own identity") + } + + existing, err := loadRecipients() + if err != nil { + return fmt.Errorf("cannot add recipient: %v", err) + } + for _, e := range existing { + if e.String() == r.String() { + return fmt.Errorf("cannot add recipient: key already present") + } + } + + existing = append(existing, r) + if err := saveRecipients(existing); err != nil { + return fmt.Errorf("cannot add recipient: %v", err) + } + + recipients, err := allRecipients(identity) + if err != nil { + return fmt.Errorf("cannot add recipient: %v", err) + } + + count, err := reencryptAllStores(identity, recipients) + if err != nil { + return fmt.Errorf("cannot add recipient: %v", err) + } + + okf("added recipient %s", r) + if count > 0 { + okf("re-encrypted %d secret(s)", count) + } + return autoSync("added recipient") +} + +func identityRemoveRecipient(key string) error { + r, err := age.ParseX25519Recipient(key) + if err != nil { + return fmt.Errorf("cannot remove recipient: %v", err) + } + + identity, err := loadIdentity() + if err != nil { + return fmt.Errorf("cannot remove recipient: %v", err) + } + if identity == nil { + return withHint( + fmt.Errorf("cannot remove recipient: no identity found"), + "create one first with 'pda identity --new'", + ) + } + + existing, err := loadRecipients() + if err != nil { + return fmt.Errorf("cannot remove recipient: %v", err) + } + + found := false + var updated []*age.X25519Recipient + for _, e := range existing { + if e.String() == r.String() { + found = true + continue + } + updated = append(updated, e) + } + if !found { + return fmt.Errorf("cannot remove recipient: key not found") + } + + if err := saveRecipients(updated); err != nil { + return fmt.Errorf("cannot remove recipient: %v", err) + } + + recipients, err := allRecipients(identity) + if err != nil { + return fmt.Errorf("cannot remove recipient: %v", err) + } + + count, err := reencryptAllStores(identity, recipients) + if err != nil { + return fmt.Errorf("cannot remove recipient: %v", err) + } + + okf("removed recipient %s", r) + if count > 0 { + okf("re-encrypted %d secret(s)", count) + } + return autoSync("removed recipient") +} + func init() { identityCmd.Flags().Bool("new", false, "generate a new identity (errors if one already exists)") identityCmd.Flags().Bool("path", false, "print only the identity file path") - identityCmd.MarkFlagsMutuallyExclusive("new", "path") + identityCmd.Flags().String("add-recipient", "", "add an age public key as an additional encryption recipient") + identityCmd.Flags().String("remove-recipient", "", "remove an age public key from the recipient list") + identityCmd.MarkFlagsMutuallyExclusive("new", "path", "add-recipient", "remove-recipient") rootCmd.AddCommand(identityCmd) } diff --git a/cmd/list.go b/cmd/list.go index ee604e3..336f18d 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -33,7 +33,6 @@ import ( "strings" "unicode/utf8" - "filippo.io/age" "github.com/jedib0t/go-pretty/v6/table" "github.com/jedib0t/go-pretty/v6/text" "github.com/spf13/cobra" @@ -218,9 +217,9 @@ func list(cmd *cobra.Command, args []string) error { } identity, _ := loadIdentity() - var recipient *age.X25519Recipient - if identity != nil { - recipient = identity.Recipient() + recipients, err := allRecipients(identity) + if err != nil { + return fmt.Errorf("cannot ls '%s': %v", targetDB, err) } var entries []Entry @@ -297,7 +296,7 @@ func list(cmd *cobra.Command, args []string) error { // NDJSON format: emit JSON lines directly (encrypted form for secrets) if listFormat.String() == "ndjson" { for _, e := range filtered { - je, err := encodeJsonEntry(e, recipient) + je, err := encodeJsonEntry(e, recipients) if err != nil { return fmt.Errorf("cannot ls '%s': %v", targetDB, err) } @@ -315,7 +314,7 @@ func list(cmd *cobra.Command, args []string) error { if listFormat.String() == "json" { var jsonEntries []jsonEntry for _, e := range filtered { - je, err := encodeJsonEntry(e, recipient) + je, err := encodeJsonEntry(e, recipients) if err != nil { return fmt.Errorf("cannot ls '%s': %v", targetDB, err) } diff --git a/cmd/mv.go b/cmd/mv.go index ecc5e92..2e1bf7e 100644 --- a/cmd/mv.go +++ b/cmd/mv.go @@ -26,7 +26,6 @@ import ( "fmt" "strings" - "filippo.io/age" "github.com/spf13/cobra" ) @@ -75,9 +74,9 @@ func mvImpl(cmd *cobra.Command, args []string, keepSource bool) error { promptOverwrite := !yes && (interactive || config.Key.AlwaysPromptOverwrite) identity, _ := loadIdentity() - var recipient *age.X25519Recipient - if identity != nil { - recipient = identity.Recipient() + recipients, err := allRecipients(identity) + if err != nil { + return err } fromSpec, err := store.parseKey(args[0], true) @@ -161,7 +160,7 @@ func mvImpl(cmd *cobra.Command, args []string, keepSource bool) error { dstEntries = append(dstEntries[:idx], dstEntries[idx+1:]...) } } - if err := writeStoreFile(dstPath, dstEntries, recipient); err != nil { + if err := writeStoreFile(dstPath, dstEntries, recipients); err != nil { return err } } else { @@ -171,12 +170,12 @@ func mvImpl(cmd *cobra.Command, args []string, keepSource bool) error { } else { dstEntries = append(dstEntries, newEntry) } - if err := writeStoreFile(dstPath, dstEntries, recipient); err != nil { + if err := writeStoreFile(dstPath, dstEntries, recipients); err != nil { return err } if !keepSource { srcEntries = append(srcEntries[:srcIdx], srcEntries[srcIdx+1:]...) - if err := writeStoreFile(srcPath, srcEntries, recipient); err != nil { + if err := writeStoreFile(srcPath, srcEntries, recipients); err != nil { return err } } diff --git a/cmd/ndjson.go b/cmd/ndjson.go index 3908232..2e7f855 100644 --- a/cmd/ndjson.go +++ b/cmd/ndjson.go @@ -98,8 +98,8 @@ func readStoreFile(path string, identity *age.X25519Identity) ([]Entry, error) { // writeStoreFile atomically writes entries to an NDJSON file, sorted by key. // Expired entries are excluded. Empty entry list writes an empty file. -// If recipient is nil, secret entries are written as-is (locked passthrough). -func writeStoreFile(path string, entries []Entry, recipient *age.X25519Recipient) error { +// If recipients is empty, secret entries are written as-is (locked passthrough). +func writeStoreFile(path string, entries []Entry, recipients []age.Recipient) error { // Sort by key for deterministic output slices.SortFunc(entries, func(a, b Entry) int { return strings.Compare(a.Key, b.Key) @@ -121,7 +121,7 @@ func writeStoreFile(path string, entries []Entry, recipient *age.X25519Recipient if e.ExpiresAt > 0 && e.ExpiresAt <= now { continue } - je, err := encodeJsonEntry(e, recipient) + je, err := encodeJsonEntry(e, recipients) if err != nil { return fmt.Errorf("key '%s': %w", e.Key, err) } @@ -182,7 +182,7 @@ func decodeJsonEntry(je jsonEntry, identity *age.X25519Identity) (Entry, error) return Entry{Key: je.Key, Value: value, ExpiresAt: expiresAt}, nil } -func encodeJsonEntry(e Entry, recipient *age.X25519Recipient) (jsonEntry, error) { +func encodeJsonEntry(e Entry, recipients []age.Recipient) (jsonEntry, error) { je := jsonEntry{Key: e.Key} if e.ExpiresAt > 0 { ts := int64(e.ExpiresAt) @@ -196,10 +196,10 @@ func encodeJsonEntry(e Entry, recipient *age.X25519Recipient) (jsonEntry, error) return je, nil } if e.Secret { - if recipient == nil { + if len(recipients) == 0 { return je, fmt.Errorf("no recipient available to encrypt") } - ciphertext, err := encrypt(e.Value, recipient) + ciphertext, err := encrypt(e.Value, recipients...) if err != nil { return je, fmt.Errorf("encrypt: %w", err) } diff --git a/cmd/restore.go b/cmd/restore.go index ba4a577..70948ba 100644 --- a/cmd/restore.go +++ b/cmd/restore.go @@ -31,6 +31,7 @@ import ( "strings" "filippo.io/age" + "github.com/gobwas/glob" "github.com/spf13/cobra" ) @@ -97,9 +98,9 @@ func restore(cmd *cobra.Command, args []string) error { } identity, _ := loadIdentity() - var recipient *age.X25519Recipient - if identity != nil { - recipient = identity.Recipient() + recipients, err := allRecipients(identity) + if err != nil { + return fmt.Errorf("cannot restore '%s': %v", displayTarget, err) } var promptReader io.Reader @@ -121,7 +122,7 @@ func restore(cmd *cobra.Command, args []string) error { promptOverwrite: promptOverwrite, drop: drop, identity: identity, - recipient: recipient, + recipients: recipients, promptReader: promptReader, } @@ -193,7 +194,7 @@ type restoreOpts struct { promptOverwrite bool drop bool identity *age.X25519Identity - recipient *age.X25519Recipient + recipients []age.Recipient promptReader io.Reader } @@ -310,7 +311,7 @@ func restoreEntries(decoder *json.Decoder, storePaths map[string]string, default for _, acc := range stores { if restored > 0 || opts.drop { - if err := writeStoreFile(acc.path, acc.entries, opts.recipient); err != nil { + if err := writeStoreFile(acc.path, acc.entries, opts.recipients); err != nil { return 0, err } } diff --git a/cmd/secret.go b/cmd/secret.go index b71f272..52eb848 100644 --- a/cmd/secret.go +++ b/cmd/secret.go @@ -1,24 +1,26 @@ package cmd import ( + "bufio" "bytes" "fmt" "io" "os" "path/filepath" + "strings" "filippo.io/age" gap "github.com/muesli/go-app-paths" ) // identityPath returns the path to the age identity file, -// respecting PDA_CONFIG the same way configPath() does. +// respecting PDA_DATA the same way Store.path() does. func identityPath() (string, error) { - if override := os.Getenv("PDA_CONFIG"); override != "" { + if override := os.Getenv("PDA_DATA"); override != "" { return filepath.Join(override, "identity.txt"), nil } scope := gap.NewScope(gap.User, "pda") - dir, err := scope.ConfigPath("") + dir, err := scope.DataPath("") if err != nil { return "", err } @@ -77,10 +79,100 @@ func ensureIdentity() (*age.X25519Identity, error) { return id, nil } -// encrypt encrypts plaintext for the given recipient using age. -func encrypt(plaintext []byte, recipient *age.X25519Recipient) ([]byte, error) { +// recipientsPath returns the path to the additional recipients file, +// respecting PDA_DATA the same way identityPath does. +func recipientsPath() (string, error) { + if override := os.Getenv("PDA_DATA"); override != "" { + return filepath.Join(override, "recipients.txt"), nil + } + scope := gap.NewScope(gap.User, "pda") + dir, err := scope.DataPath("") + if err != nil { + return "", err + } + return filepath.Join(dir, "recipients.txt"), nil +} + +// loadRecipients loads additional age recipients from disk. +// Returns (nil, nil) if the recipients file does not exist. +func loadRecipients() ([]*age.X25519Recipient, error) { + path, err := recipientsPath() + if err != nil { + return nil, err + } + f, err := os.Open(path) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, err + } + defer f.Close() + + var recipients []*age.X25519Recipient + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + r, err := age.ParseX25519Recipient(line) + if err != nil { + return nil, fmt.Errorf("parse recipient %q: %w", line, err) + } + recipients = append(recipients, r) + } + return recipients, scanner.Err() +} + +// saveRecipients writes the recipients file. If the list is empty, +// the file is deleted. +func saveRecipients(recipients []*age.X25519Recipient) error { + path, err := recipientsPath() + if err != nil { + return err + } + if len(recipients) == 0 { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil + } + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return err + } var buf bytes.Buffer - w, err := age.Encrypt(&buf, recipient) + for _, r := range recipients { + fmt.Fprintln(&buf, r.String()) + } + return os.WriteFile(path, buf.Bytes(), 0o600) +} + +// allRecipients combines the identity's own recipient with any additional +// recipients from the recipients file into a single []age.Recipient slice. +// Returns nil if identity is nil and no recipients file exists. +func allRecipients(identity *age.X25519Identity) ([]age.Recipient, error) { + extra, err := loadRecipients() + if err != nil { + return nil, err + } + if identity == nil && len(extra) == 0 { + return nil, nil + } + var recipients []age.Recipient + if identity != nil { + recipients = append(recipients, identity.Recipient()) + } + for _, r := range extra { + recipients = append(recipients, r) + } + return recipients, nil +} + +// encrypt encrypts plaintext for the given recipients using age. +func encrypt(plaintext []byte, recipients ...age.Recipient) ([]byte, error) { + var buf bytes.Buffer + w, err := age.Encrypt(&buf, recipients...) if err != nil { return nil, err } @@ -93,6 +185,52 @@ func encrypt(plaintext []byte, recipient *age.X25519Recipient) ([]byte, error) { return buf.Bytes(), nil } +// reencryptAllStores decrypts all secrets across all stores with the +// given identity, then re-encrypts them for the new recipient list. +// Returns the count of re-encrypted secrets. +func reencryptAllStores(identity *age.X25519Identity, recipients []age.Recipient) (int, error) { + store := &Store{} + storeNames, err := store.AllStores() + if err != nil { + return 0, err + } + + count := 0 + for _, name := range storeNames { + p, err := store.storePath(name) + if err != nil { + return 0, err + } + entries, err := readStoreFile(p, identity) + if err != nil { + return 0, err + } + hasSecrets := false + for _, e := range entries { + if e.Secret { + if e.Locked { + return 0, fmt.Errorf("cannot re-encrypt: secret '%s@%s' is locked (identity cannot decrypt it)", e.Key, name) + } + hasSecrets = true + } + } + if !hasSecrets { + continue + } + if err := writeStoreFile(p, entries, recipients); err != nil { + return 0, err + } + for _, e := range entries { + if e.Secret { + spec := KeySpec{Key: e.Key, DB: name} + okf("re-encrypted %s", spec.Display()) + count++ + } + } + } + return count, nil +} + // decrypt decrypts age ciphertext with the given identity. func decrypt(ciphertext []byte, identity *age.X25519Identity) ([]byte, error) { r, err := age.Decrypt(bytes.NewReader(ciphertext), identity) diff --git a/cmd/secret_test.go b/cmd/secret_test.go index 6db1bb1..fb209a1 100644 --- a/cmd/secret_test.go +++ b/cmd/secret_test.go @@ -46,7 +46,7 @@ func TestEncryptDecryptRoundtrip(t *testing.T) { } func TestLoadIdentityMissing(t *testing.T) { - t.Setenv("PDA_CONFIG", t.TempDir()) + t.Setenv("PDA_DATA", t.TempDir()) id, err := loadIdentity() if err != nil { t.Fatal(err) @@ -58,7 +58,7 @@ func TestLoadIdentityMissing(t *testing.T) { func TestEnsureIdentityCreatesFile(t *testing.T) { dir := t.TempDir() - t.Setenv("PDA_CONFIG", dir) + t.Setenv("PDA_DATA", dir) id, err := ensureIdentity() if err != nil { @@ -89,7 +89,7 @@ func TestEnsureIdentityCreatesFile(t *testing.T) { func TestEnsureIdentityIdempotent(t *testing.T) { dir := t.TempDir() - t.Setenv("PDA_CONFIG", dir) + t.Setenv("PDA_DATA", dir) id1, err := ensureIdentity() if err != nil { @@ -109,7 +109,7 @@ func TestSecretEntryRoundtrip(t *testing.T) { if err != nil { t.Fatal(err) } - recipient := id.Recipient() + recipients := []age.Recipient{id.Recipient()} dir := t.TempDir() path := filepath.Join(dir, "test.ndjson") @@ -118,7 +118,7 @@ func TestSecretEntryRoundtrip(t *testing.T) { {Key: "encrypted", Value: []byte("secret-value"), Secret: true}, } - if err := writeStoreFile(path, entries, recipient); err != nil { + if err := writeStoreFile(path, entries, recipients); err != nil { t.Fatal(err) } @@ -153,14 +153,14 @@ func TestSecretEntryLockedWithoutIdentity(t *testing.T) { if err != nil { t.Fatal(err) } - recipient := id.Recipient() + recipients := []age.Recipient{id.Recipient()} dir := t.TempDir() path := filepath.Join(dir, "test.ndjson") entries := []Entry{ {Key: "encrypted", Value: []byte("secret-value"), Secret: true}, } - if err := writeStoreFile(path, entries, recipient); err != nil { + if err := writeStoreFile(path, entries, recipients); err != nil { t.Fatal(err) } @@ -185,7 +185,7 @@ func TestLockedPassthrough(t *testing.T) { if err != nil { t.Fatal(err) } - recipient := id.Recipient() + recipients := []age.Recipient{id.Recipient()} dir := t.TempDir() path := filepath.Join(dir, "test.ndjson") @@ -193,7 +193,7 @@ func TestLockedPassthrough(t *testing.T) { entries := []Entry{ {Key: "encrypted", Value: []byte("secret-value"), Secret: true}, } - if err := writeStoreFile(path, entries, recipient); err != nil { + if err := writeStoreFile(path, entries, recipients); err != nil { t.Fatal(err) } @@ -224,9 +224,240 @@ func TestLockedPassthrough(t *testing.T) { } } +func TestMultiRecipientEncryptDecrypt(t *testing.T) { + id1, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + id2, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + + recipients := []age.Recipient{id1.Recipient(), id2.Recipient()} + plaintext := []byte("shared secret") + + ciphertext, err := encrypt(plaintext, recipients...) + if err != nil { + t.Fatalf("encrypt: %v", err) + } + + // Both identities should be able to decrypt + for i, id := range []*age.X25519Identity{id1, id2} { + got, err := decrypt(ciphertext, id) + if err != nil { + t.Fatalf("identity %d decrypt: %v", i, err) + } + if string(got) != string(plaintext) { + t.Errorf("identity %d: got %q, want %q", i, got, plaintext) + } + } +} + +func TestMultiRecipientStoreRoundtrip(t *testing.T) { + id1, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + id2, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + + recipients := []age.Recipient{id1.Recipient(), id2.Recipient()} + dir := t.TempDir() + path := filepath.Join(dir, "test.ndjson") + + entries := []Entry{ + {Key: "secret", Value: []byte("multi-recipient-value"), Secret: true}, + } + if err := writeStoreFile(path, entries, recipients); err != nil { + t.Fatal(err) + } + + // Both identities should decrypt the store + for i, id := range []*age.X25519Identity{id1, id2} { + got, err := readStoreFile(path, id) + if err != nil { + t.Fatalf("identity %d read: %v", i, err) + } + if len(got) != 1 { + t.Fatalf("identity %d: got %d entries, want 1", i, len(got)) + } + if string(got[0].Value) != "multi-recipient-value" { + t.Errorf("identity %d: value = %q, want %q", i, got[0].Value, "multi-recipient-value") + } + } +} + +func TestLoadRecipientsMissing(t *testing.T) { + t.Setenv("PDA_DATA", t.TempDir()) + recipients, err := loadRecipients() + if err != nil { + t.Fatal(err) + } + if recipients != nil { + t.Fatal("expected nil recipients for missing file") + } +} + +func TestSaveLoadRecipientsRoundtrip(t *testing.T) { + dir := t.TempDir() + t.Setenv("PDA_DATA", dir) + + id1, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + id2, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + + toSave := []*age.X25519Recipient{id1.Recipient(), id2.Recipient()} + if err := saveRecipients(toSave); err != nil { + t.Fatal(err) + } + + // Check file permissions + path := filepath.Join(dir, "recipients.txt") + info, err := os.Stat(path) + if err != nil { + t.Fatalf("recipients file not created: %v", err) + } + if perm := info.Mode().Perm(); perm != 0o600 { + t.Errorf("recipients file permissions = %o, want 0600", perm) + } + + loaded, err := loadRecipients() + if err != nil { + t.Fatal(err) + } + if len(loaded) != 2 { + t.Fatalf("got %d recipients, want 2", len(loaded)) + } + if loaded[0].String() != id1.Recipient().String() { + t.Errorf("recipient 0 = %s, want %s", loaded[0], id1.Recipient()) + } + if loaded[1].String() != id2.Recipient().String() { + t.Errorf("recipient 1 = %s, want %s", loaded[1], id2.Recipient()) + } +} + +func TestSaveRecipientsEmptyDeletesFile(t *testing.T) { + dir := t.TempDir() + t.Setenv("PDA_DATA", dir) + + // Create a recipients file first + id, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + if err := saveRecipients([]*age.X25519Recipient{id.Recipient()}); err != nil { + t.Fatal(err) + } + + // Save empty list should delete the file + if err := saveRecipients(nil); err != nil { + t.Fatal(err) + } + + path := filepath.Join(dir, "recipients.txt") + if _, err := os.Stat(path); !os.IsNotExist(err) { + t.Error("expected recipients file to be deleted") + } +} + +func TestAllRecipientsNoIdentityNoFile(t *testing.T) { + t.Setenv("PDA_DATA", t.TempDir()) + recipients, err := allRecipients(nil) + if err != nil { + t.Fatal(err) + } + if recipients != nil { + t.Fatal("expected nil recipients") + } +} + +func TestAllRecipientsCombines(t *testing.T) { + dir := t.TempDir() + t.Setenv("PDA_DATA", dir) + + id, err := ensureIdentity() + if err != nil { + t.Fatal(err) + } + + extra, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + if err := saveRecipients([]*age.X25519Recipient{extra.Recipient()}); err != nil { + t.Fatal(err) + } + + recipients, err := allRecipients(id) + if err != nil { + t.Fatal(err) + } + if len(recipients) != 2 { + t.Fatalf("got %d recipients, want 2", len(recipients)) + } +} + +func TestReencryptAllStores(t *testing.T) { + dir := t.TempDir() + t.Setenv("PDA_DATA", dir) + + id, err := ensureIdentity() + if err != nil { + t.Fatal(err) + } + + // Write a store with a secret + storePath := filepath.Join(dir, "test.ndjson") + entries := []Entry{ + {Key: "plain", Value: []byte("hello")}, + {Key: "secret", Value: []byte("secret-value"), Secret: true}, + } + if err := writeStoreFile(storePath, entries, []age.Recipient{id.Recipient()}); err != nil { + t.Fatal(err) + } + + // Generate a second identity and re-encrypt for both + id2, err := age.GenerateX25519Identity() + if err != nil { + t.Fatal(err) + } + newRecipients := []age.Recipient{id.Recipient(), id2.Recipient()} + + count, err := reencryptAllStores(id, newRecipients) + if err != nil { + t.Fatal(err) + } + if count != 1 { + t.Fatalf("re-encrypted %d secrets, want 1", count) + } + + // Both identities should be able to decrypt + for i, identity := range []*age.X25519Identity{id, id2} { + got, err := readStoreFile(storePath, identity) + if err != nil { + t.Fatalf("identity %d read: %v", i, err) + } + idx := findEntry(got, "secret") + if idx < 0 { + t.Fatalf("identity %d: secret key not found", i) + } + if string(got[idx].Value) != "secret-value" { + t.Errorf("identity %d: value = %q, want %q", i, got[idx].Value, "secret-value") + } + } +} + func generateTestIdentity(t *testing.T) (*age.X25519Identity, error) { t.Helper() dir := t.TempDir() - t.Setenv("PDA_CONFIG", dir) + t.Setenv("PDA_DATA", dir) return ensureIdentity() } diff --git a/cmd/set.go b/cmd/set.go index e39586f..7ba38e8 100644 --- a/cmd/set.go +++ b/cmd/set.go @@ -119,9 +119,9 @@ func set(cmd *cobra.Command, args []string) error { } else { identity, _ = loadIdentity() } - var recipient *age.X25519Recipient - if identity != nil { - recipient = identity.Recipient() + recipients, err := allRecipients(identity) + if err != nil { + return fmt.Errorf("cannot set '%s': %v", args[0], err) } p, err := store.storePath(spec.DB) @@ -172,7 +172,7 @@ func set(cmd *cobra.Command, args []string) error { entries = append(entries, entry) } - if err := writeStoreFile(p, entries, recipient); err != nil { + if err := writeStoreFile(p, entries, recipients); err != nil { return fmt.Errorf("cannot set '%s': %v", args[0], err) } diff --git a/main_test.go b/main_test.go index 0eee94b..60d648b 100644 --- a/main_test.go +++ b/main_test.go @@ -66,7 +66,7 @@ func TestMain(t *testing.T) { if err != nil { return err } - return os.WriteFile(filepath.Join(configDir, "identity.txt"), []byte(id.String()+"\n"), 0o600) + return os.WriteFile(filepath.Join(dataDir, "identity.txt"), []byte(id.String()+"\n"), 0o600) } ts.Run(t, *update)