feat(set/mv/restore): adds --interactive, replacing --force

This commit is contained in:
Lewis Wynne 2025-12-18 18:32:31 +00:00
parent f0be9c42d3
commit 6953ba7583
3 changed files with 111 additions and 11 deletions

View file

@ -24,6 +24,7 @@ package cmd
import ( import (
"fmt" "fmt"
"strings"
"github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -52,6 +53,12 @@ func cp(cmd *cobra.Command, args []string) error {
func mv(cmd *cobra.Command, args []string) error { func mv(cmd *cobra.Command, args []string) error {
store := &Store{} store := &Store{}
interactive, err := cmd.Flags().GetBool("interactive")
if err != nil {
return err
}
promptOverwrite := interactive || config.Key.AlwaysPromptOverwrite
fromSpec, err := store.parseKey(args[0], true) fromSpec, err := store.parseKey(args[0], true)
if err != nil { if err != nil {
return err return err
@ -67,6 +74,37 @@ func mv(cmd *cobra.Command, args []string) error {
fromRef := fromSpec.Full() fromRef := fromSpec.Full()
toRef := toSpec.Full() toRef := toSpec.Full()
var destExists bool
if promptOverwrite {
existsErr := store.Transaction(TransactionArgs{
key: toRef,
readonly: true,
transact: func(tx *badger.Txn, k []byte) error {
if _, err := tx.Get(k); err == nil {
destExists = true
return nil
} else if err == badger.ErrKeyNotFound {
return nil
}
return err
},
})
if existsErr != nil {
return fmt.Errorf("cannot move '%s': %v", fromSpec.Key, existsErr)
}
}
if promptOverwrite && destExists {
var confirm string
fmt.Printf("overwrite '%s'? (y/n)\n", toSpec.Display())
if _, err := fmt.Scanln(&confirm); err != nil {
return fmt.Errorf("cannot move '%s': %v", fromSpec.Key, err)
}
if strings.ToLower(confirm) != "y" {
return nil
}
}
readErr := store.Transaction(TransactionArgs{ readErr := store.Transaction(TransactionArgs{
key: fromRef, key: fromRef,
readonly: true, readonly: true,
@ -92,13 +130,6 @@ func mv(cmd *cobra.Command, args []string) error {
readonly: false, readonly: false,
sync: false, sync: false,
transact: func(tx *badger.Txn, k []byte) error { transact: func(tx *badger.Txn, k []byte) error {
if !force && config.Key.AlwaysPromptOverwrite {
if _, err := tx.Get(k); err == nil {
return fmt.Errorf("cannot move '%s': '%s' already exists > run with --force to overwrite", fromSpec.Key, toSpec.Key)
} else if err != badger.ErrKeyNotFound {
return fmt.Errorf("cannot move '%s': %v", fromSpec.Key, err)
}
}
entry := badger.NewEntry(k, srcVal).WithMeta(srcMeta) entry := badger.NewEntry(k, srcVal).WithMeta(srcMeta)
if srcExpires > 0 { if srcExpires > 0 {
entry.ExpiresAt = srcExpires entry.ExpiresAt = srcExpires
@ -126,13 +157,12 @@ func mv(cmd *cobra.Command, args []string) error {
var ( var (
copy bool = false copy bool = false
force bool = false
) )
func init() { func init() {
mvCmd.Flags().BoolVar(&copy, "copy", false, "Copy instead of move (keeps source)") mvCmd.Flags().BoolVar(&copy, "copy", false, "Copy instead of move (keeps source)")
mvCmd.Flags().BoolVarP(&force, "force", "f", false, "Overwrite destination if it exists") mvCmd.Flags().BoolP("interactive", "i", false, "Prompt before overwriting destination")
rootCmd.AddCommand(mvCmd) rootCmd.AddCommand(mvCmd)
cpCmd.Flags().BoolVarP(&force, "force", "f", false, "Overwrite destination if it exists") cpCmd.Flags().BoolP("interactive", "i", false, "Prompt before overwriting destination")
rootCmd.AddCommand(cpCmd) rootCmd.AddCommand(cpCmd)
} }

View file

@ -88,6 +88,12 @@ func restore(cmd *cobra.Command, args []string) error {
wb := db.NewWriteBatch() wb := db.NewWriteBatch()
defer wb.Cancel() defer wb.Cancel()
interactive, err := cmd.Flags().GetBool("interactive")
if err != nil {
return fmt.Errorf("cannot restore '%s': %v", displayTarget, err)
}
promptOverwrite := interactive || config.Key.AlwaysPromptOverwrite
entryNo := 0 entryNo := 0
var restored int var restored int
var matched bool var matched bool
@ -108,6 +114,23 @@ func restore(cmd *cobra.Command, args []string) error {
continue continue
} }
if promptOverwrite {
exists, err := keyExistsInDB(db, entry.Key)
if err != nil {
return fmt.Errorf("cannot restore '%s': entry %d: %v", displayTarget, entryNo, err)
}
if exists {
fmt.Printf("overwrite '%s'? (y/n)\n", entry.Key)
var confirm string
if _, err := fmt.Scanln(&confirm); err != nil {
return fmt.Errorf("cannot restore '%s': entry %d: %v", displayTarget, entryNo, err)
}
if strings.ToLower(confirm) != "y" {
continue
}
}
}
value, err := decodeEntryValue(entry) value, err := decodeEntryValue(entry)
if err != nil { if err != nil {
return fmt.Errorf("cannot restore '%s': entry %d: %w", displayTarget, entryNo, err) return fmt.Errorf("cannot restore '%s': entry %d: %w", displayTarget, entryNo, err)
@ -179,5 +202,22 @@ func init() {
restoreCmd.Flags().StringP("file", "f", "", "Path to an NDJSON dump (defaults to stdin)") restoreCmd.Flags().StringP("file", "f", "", "Path to an NDJSON dump (defaults to stdin)")
restoreCmd.Flags().StringSliceP("glob", "g", nil, "Restore keys matching glob pattern (repeatable)") restoreCmd.Flags().StringSliceP("glob", "g", nil, "Restore keys matching glob pattern (repeatable)")
restoreCmd.Flags().String("glob-sep", "", fmt.Sprintf("Characters treated as separators for globbing (default %q)", defaultGlobSeparatorsDisplay())) restoreCmd.Flags().String("glob-sep", "", fmt.Sprintf("Characters treated as separators for globbing (default %q)", defaultGlobSeparatorsDisplay()))
restoreCmd.Flags().BoolP("interactive", "i", false, "Prompt before overwriting existing keys")
rootCmd.AddCommand(restoreCmd) rootCmd.AddCommand(restoreCmd)
} }
func keyExistsInDB(db *badger.DB, key string) (bool, error) {
var exists bool
err := db.View(func(tx *badger.Txn) error {
_, err := tx.Get([]byte(key))
if err == nil {
exists = true
return nil
}
if err == badger.ErrKeyNotFound {
return nil
}
return err
})
return exists, err
}

View file

@ -25,6 +25,7 @@ package cmd
import ( import (
"fmt" "fmt"
"io" "io"
"strings"
"github.com/dgraph-io/badger/v4" "github.com/dgraph-io/badger/v4"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -53,6 +54,17 @@ For example:
func set(cmd *cobra.Command, args []string) error { func set(cmd *cobra.Command, args []string) error {
store := &Store{} store := &Store{}
interactive, err := cmd.Flags().GetBool("interactive")
if err != nil {
return err
}
promptOverwrite := interactive || config.Key.AlwaysPromptOverwrite
spec, err := store.parseKey(args[0], true)
if err != nil {
return fmt.Errorf("cannot set '%s': %v", args[0], err)
}
var value []byte var value []byte
if len(args) == 2 { if len(args) == 2 {
value = []byte(args[1]) value = []byte(args[1])
@ -73,6 +85,23 @@ func set(cmd *cobra.Command, args []string) error {
return fmt.Errorf("cannot set '%s': %v", args[0], err) return fmt.Errorf("cannot set '%s': %v", args[0], err)
} }
if promptOverwrite {
exists, err := keyExists(store, spec.Full())
if err != nil {
return fmt.Errorf("cannot set '%s': %v", args[0], err)
}
if exists {
fmt.Printf("overwrite '%s'? (y/n)\n", spec.Display())
var confirm string
if _, err := fmt.Scanln(&confirm); err != nil {
return fmt.Errorf("cannot set '%s': %v", args[0], err)
}
if strings.ToLower(confirm) != "y" {
return nil
}
}
}
trans := TransactionArgs{ trans := TransactionArgs{
key: args[0], key: args[0],
readonly: false, readonly: false,
@ -96,4 +125,5 @@ func init() {
rootCmd.AddCommand(setCmd) rootCmd.AddCommand(setCmd)
setCmd.Flags().Bool("secret", false, "Mark the stored value as a secret") setCmd.Flags().Bool("secret", false, "Mark the stored value as a secret")
setCmd.Flags().DurationP("ttl", "t", 0, "Expire the key after the provided duration (e.g. 24h, 30m)") setCmd.Flags().DurationP("ttl", "t", 0, "Expire the key after the provided duration (e.g. 24h, 30m)")
setCmd.Flags().BoolP("interactive", "i", false, "Prompt before overwriting an existing key")
} }