From 6953ba7583b6e568f1c550cb9d139e227d8d08e1 Mon Sep 17 00:00:00 2001 From: lew Date: Thu, 18 Dec 2025 18:32:31 +0000 Subject: [PATCH] feat(set/mv/restore): adds --interactive, replacing --force --- cmd/mv.go | 52 +++++++++++++++++++++++++++++++++++++++----------- cmd/restore.go | 40 ++++++++++++++++++++++++++++++++++++++ cmd/set.go | 30 +++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+), 11 deletions(-) diff --git a/cmd/mv.go b/cmd/mv.go index 5d9c825..82d9682 100644 --- a/cmd/mv.go +++ b/cmd/mv.go @@ -24,6 +24,7 @@ package cmd import ( "fmt" + "strings" "github.com/dgraph-io/badger/v4" "github.com/spf13/cobra" @@ -52,6 +53,12 @@ func cp(cmd *cobra.Command, args []string) error { func mv(cmd *cobra.Command, args []string) error { store := &Store{} + interactive, err := cmd.Flags().GetBool("interactive") + if err != nil { + return err + } + promptOverwrite := interactive || config.Key.AlwaysPromptOverwrite + fromSpec, err := store.parseKey(args[0], true) if err != nil { return err @@ -67,6 +74,37 @@ func mv(cmd *cobra.Command, args []string) error { fromRef := fromSpec.Full() toRef := toSpec.Full() + var destExists bool + if promptOverwrite { + existsErr := store.Transaction(TransactionArgs{ + key: toRef, + readonly: true, + transact: func(tx *badger.Txn, k []byte) error { + if _, err := tx.Get(k); err == nil { + destExists = true + return nil + } else if err == badger.ErrKeyNotFound { + return nil + } + return err + }, + }) + if existsErr != nil { + return fmt.Errorf("cannot move '%s': %v", fromSpec.Key, existsErr) + } + } + + if promptOverwrite && destExists { + var confirm string + fmt.Printf("overwrite '%s'? (y/n)\n", toSpec.Display()) + if _, err := fmt.Scanln(&confirm); err != nil { + return fmt.Errorf("cannot move '%s': %v", fromSpec.Key, err) + } + if strings.ToLower(confirm) != "y" { + return nil + } + } + readErr := store.Transaction(TransactionArgs{ key: fromRef, readonly: true, @@ -92,13 +130,6 @@ func mv(cmd *cobra.Command, args []string) error { readonly: false, sync: false, transact: func(tx *badger.Txn, k []byte) error { - if !force && config.Key.AlwaysPromptOverwrite { - if _, err := tx.Get(k); err == nil { - return fmt.Errorf("cannot move '%s': '%s' already exists > run with --force to overwrite", fromSpec.Key, toSpec.Key) - } else if err != badger.ErrKeyNotFound { - return fmt.Errorf("cannot move '%s': %v", fromSpec.Key, err) - } - } entry := badger.NewEntry(k, srcVal).WithMeta(srcMeta) if srcExpires > 0 { entry.ExpiresAt = srcExpires @@ -125,14 +156,13 @@ func mv(cmd *cobra.Command, args []string) error { } var ( - copy bool = false - force bool = false + copy bool = false ) func init() { mvCmd.Flags().BoolVar(©, "copy", false, "Copy instead of move (keeps source)") - mvCmd.Flags().BoolVarP(&force, "force", "f", false, "Overwrite destination if it exists") + mvCmd.Flags().BoolP("interactive", "i", false, "Prompt before overwriting destination") rootCmd.AddCommand(mvCmd) - cpCmd.Flags().BoolVarP(&force, "force", "f", false, "Overwrite destination if it exists") + cpCmd.Flags().BoolP("interactive", "i", false, "Prompt before overwriting destination") rootCmd.AddCommand(cpCmd) } diff --git a/cmd/restore.go b/cmd/restore.go index 9b000f5..e07d9b7 100644 --- a/cmd/restore.go +++ b/cmd/restore.go @@ -88,6 +88,12 @@ func restore(cmd *cobra.Command, args []string) error { wb := db.NewWriteBatch() defer wb.Cancel() + interactive, err := cmd.Flags().GetBool("interactive") + if err != nil { + return fmt.Errorf("cannot restore '%s': %v", displayTarget, err) + } + promptOverwrite := interactive || config.Key.AlwaysPromptOverwrite + entryNo := 0 var restored int var matched bool @@ -108,6 +114,23 @@ func restore(cmd *cobra.Command, args []string) error { continue } + if promptOverwrite { + exists, err := keyExistsInDB(db, entry.Key) + if err != nil { + return fmt.Errorf("cannot restore '%s': entry %d: %v", displayTarget, entryNo, err) + } + if exists { + fmt.Printf("overwrite '%s'? (y/n)\n", entry.Key) + var confirm string + if _, err := fmt.Scanln(&confirm); err != nil { + return fmt.Errorf("cannot restore '%s': entry %d: %v", displayTarget, entryNo, err) + } + if strings.ToLower(confirm) != "y" { + continue + } + } + } + value, err := decodeEntryValue(entry) if err != nil { return fmt.Errorf("cannot restore '%s': entry %d: %w", displayTarget, entryNo, err) @@ -179,5 +202,22 @@ func init() { restoreCmd.Flags().StringP("file", "f", "", "Path to an NDJSON dump (defaults to stdin)") restoreCmd.Flags().StringSliceP("glob", "g", nil, "Restore keys matching glob pattern (repeatable)") restoreCmd.Flags().String("glob-sep", "", fmt.Sprintf("Characters treated as separators for globbing (default %q)", defaultGlobSeparatorsDisplay())) + restoreCmd.Flags().BoolP("interactive", "i", false, "Prompt before overwriting existing keys") rootCmd.AddCommand(restoreCmd) } + +func keyExistsInDB(db *badger.DB, key string) (bool, error) { + var exists bool + err := db.View(func(tx *badger.Txn) error { + _, err := tx.Get([]byte(key)) + if err == nil { + exists = true + return nil + } + if err == badger.ErrKeyNotFound { + return nil + } + return err + }) + return exists, err +} diff --git a/cmd/set.go b/cmd/set.go index e60c637..9419022 100644 --- a/cmd/set.go +++ b/cmd/set.go @@ -25,6 +25,7 @@ package cmd import ( "fmt" "io" + "strings" "github.com/dgraph-io/badger/v4" "github.com/spf13/cobra" @@ -53,6 +54,17 @@ For example: func set(cmd *cobra.Command, args []string) error { store := &Store{} + interactive, err := cmd.Flags().GetBool("interactive") + if err != nil { + return err + } + promptOverwrite := interactive || config.Key.AlwaysPromptOverwrite + + spec, err := store.parseKey(args[0], true) + if err != nil { + return fmt.Errorf("cannot set '%s': %v", args[0], err) + } + var value []byte if len(args) == 2 { value = []byte(args[1]) @@ -73,6 +85,23 @@ func set(cmd *cobra.Command, args []string) error { return fmt.Errorf("cannot set '%s': %v", args[0], err) } + if promptOverwrite { + exists, err := keyExists(store, spec.Full()) + if err != nil { + return fmt.Errorf("cannot set '%s': %v", args[0], err) + } + if exists { + fmt.Printf("overwrite '%s'? (y/n)\n", spec.Display()) + var confirm string + if _, err := fmt.Scanln(&confirm); err != nil { + return fmt.Errorf("cannot set '%s': %v", args[0], err) + } + if strings.ToLower(confirm) != "y" { + return nil + } + } + } + trans := TransactionArgs{ key: args[0], readonly: false, @@ -96,4 +125,5 @@ func init() { rootCmd.AddCommand(setCmd) setCmd.Flags().Bool("secret", false, "Mark the stored value as a secret") setCmd.Flags().DurationP("ttl", "t", 0, "Expire the key after the provided duration (e.g. 24h, 30m)") + setCmd.Flags().BoolP("interactive", "i", false, "Prompt before overwriting an existing key") }