pda/cmd/vcs.go

202 lines
4.5 KiB
Go

package cmd
import (
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
)
// ensureVCSInitialized obtains the directory reported by (&Store{}).path()
// and checks that it contains a ".git" entry.
//
// On success the directory is returned. When ".git" does not exist the
// returned error is built with withHint and points the user at 'pda init';
// any other stat failure is returned unchanged.
func ensureVCSInitialized() (string, error) {
	repoDir, err := (&Store{}).path()
	if err != nil {
		return "", err
	}

	_, statErr := os.Stat(filepath.Join(repoDir, ".git"))
	switch {
	case statErr == nil:
		return repoDir, nil
	case os.IsNotExist(statErr):
		return "", withHint(fmt.Errorf("vcs not initialised"), "run 'pda init' first")
	default:
		return "", statErr
	}
}
func writeGitignore(repoDir string) error {
path := filepath.Join(repoDir, ".gitignore")
if _, err := os.Stat(path); os.IsNotExist(err) {
content := strings.Join([]string{
"# generated by pda",
"*",
"!.gitignore",
"!*.ndjson",
"",
}, "\n")
if err := os.WriteFile(path, []byte(content), 0o640); err != nil {
return err
}
if err := runGit(repoDir, "add", ".gitignore"); err != nil {
return err
}
return runGit(repoDir, "commit", "-m", "generated gitignore")
}
okf("existing .gitignore found")
return nil
}
// runGit executes "git" with args using dir as the working directory. The
// child's stdout and stderr are connected to this process's own streams, and
// the error from running the command (if any) is returned as is.
func runGit(dir string, args ...string) error {
	git := exec.Command("git", args...)
	git.Dir, git.Stdout, git.Stderr = dir, os.Stdout, os.Stderr
	return git.Run()
}
// gitRemoteInfo describes the remote side a repository is synced against,
// as assembled by repoRemoteInfo. The zero value means no remote was found.
type gitRemoteInfo struct {
	// Ref is the git revision naming the remote side: "@{u}" when an
	// upstream is configured, otherwise "origin/<branch>".
	Ref string
	// HasUpstream is true when the current branch has an upstream, in which
	// case Remote and Branch are left empty.
	HasUpstream bool
	// Remote is the remote name used when there is no upstream.
	Remote string
	// Branch is the branch name used when there is no upstream.
	Branch string
}
// repoRemoteInfo determines which remote ref the repository in dir should be
// compared and synced against.
//
// Resolution order:
//  1. If the current branch has an upstream, Ref is "@{u}" and HasUpstream is set.
//  2. Otherwise, if no remote named "origin" exists, the zero gitRemoteInfo
//     is returned with a nil error.
//  3. Otherwise Ref is "origin/<branch>", where <branch> is the current
//     branch, or "main" when currentBranch reports an empty name.
//
// Errors from the underlying helpers are returned unchanged.
func repoRemoteInfo(dir string) (gitRemoteInfo, error) {
	var none gitRemoteInfo

	tracked, err := repoHasUpstream(dir)
	switch {
	case err != nil:
		return none, err
	case tracked:
		return gitRemoteInfo{Ref: "@{u}", HasUpstream: true}, nil
	}

	originExists, err := repoHasRemote(dir, "origin")
	if err != nil {
		return none, err
	}
	if !originExists {
		return none, nil
	}

	branch, err := currentBranch(dir)
	if err != nil {
		return none, err
	}
	if branch == "" {
		// No branch name available; fall back to the conventional default.
		branch = "main"
	}

	info := gitRemoteInfo{
		Remote: "origin",
		Branch: branch,
	}
	info.Ref = "origin/" + branch
	return info, nil
}
// repoAheadBehind runs "git rev-list --left-right --count HEAD...<ref>" in
// dir and parses its two whitespace-separated counts.
//
// The first count is returned as ahead and the second as behind. An error is
// returned when git fails, when the output does not consist of exactly two
// fields, or when either field is not an integer.
func repoAheadBehind(dir, ref string) (int, int, error) {
	revList := exec.Command("git", "rev-list", "--left-right", "--count", "HEAD..."+ref)
	revList.Dir = dir

	raw, err := revList.Output()
	if err != nil {
		return 0, 0, err
	}

	text := string(raw)
	counts := strings.Fields(text)
	if len(counts) != 2 {
		return 0, 0, fmt.Errorf("unexpected rev-list output: %q", strings.TrimSpace(text))
	}

	ahead, aheadErr := strconv.Atoi(counts[0])
	if aheadErr != nil {
		return 0, 0, fmt.Errorf("parse ahead count: %w", aheadErr)
	}
	behind, behindErr := strconv.Atoi(counts[1])
	if behindErr != nil {
		return 0, 0, fmt.Errorf("parse behind count: %w", behindErr)
	}
	return ahead, behind, nil
}
// repoHasStagedChanges runs "git diff --cached --quiet" in dir and interprets
// its exit status: a clean exit means nothing is staged, exit code 1 means
// there are staged changes. Any other failure is returned as an error.
func repoHasStagedChanges(dir string) (bool, error) {
	diff := exec.Command("git", "diff", "--cached", "--quiet")
	diff.Dir = dir

	switch runErr := diff.Run().(type) {
	case nil:
		return false, nil
	case *exec.ExitError:
		if runErr.ExitCode() == 1 {
			return true, nil
		}
		return false, runErr
	default:
		return false, runErr
	}
}
// pullRemote runs "git pull --rebase" in dir. When info has no upstream the
// remote and branch from info are passed explicitly; otherwise git's
// configured upstream is used.
func pullRemote(dir string, info gitRemoteInfo) error {
	args := []string{"pull", "--rebase"}
	if !info.HasUpstream {
		args = append(args, info.Remote, info.Branch)
	}
	return runGit(dir, args...)
}
// pushRemote runs "git push" in dir. When info has no upstream it pushes
// with "-u" to the remote and branch named in info; otherwise a plain push
// is issued.
func pushRemote(dir string, info gitRemoteInfo) error {
	args := []string{"push"}
	if !info.HasUpstream {
		args = append(args, "-u", info.Remote, info.Branch)
	}
	return runGit(dir, args...)
}
// repoHasUpstream reports whether "git rev-parse --abbrev-ref
// --symbolic-full-name @{u}" succeeds in dir. Command output is discarded.
//
// A non-zero exit from git is reported as (false, nil); failures that are
// not an exit status (such as git not being runnable) are returned as errors.
func repoHasUpstream(dir string) (bool, error) {
	probe := exec.Command("git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}")
	probe.Dir = dir
	probe.Stdout, probe.Stderr = io.Discard, io.Discard

	runErr := probe.Run()
	if runErr == nil {
		return true, nil
	}

	exitErr, isExit := runErr.(*exec.ExitError)
	if !isExit || exitErr.ExitCode() == 0 {
		return false, runErr
	}
	return false, nil
}
// repoHasRemote reports whether "git remote get-url <name>" succeeds in dir.
// Command output is discarded.
//
// A non-zero exit from git is reported as (false, nil); failures that are
// not an exit status (such as git not being runnable) are returned as errors.
func repoHasRemote(dir, name string) (bool, error) {
	lookup := exec.Command("git", "remote", "get-url", name)
	lookup.Dir = dir
	lookup.Stdout = io.Discard
	lookup.Stderr = io.Discard

	switch runErr := lookup.Run().(type) {
	case nil:
		return true, nil
	case *exec.ExitError:
		if runErr.ExitCode() != 0 {
			return false, nil
		}
		return false, runErr
	default:
		return false, runErr
	}
}
// currentBranch returns the trimmed output of "git rev-parse --abbrev-ref
// HEAD" run in dir. The literal output "HEAD" (what git reports when no
// branch is checked out) is mapped to the empty string with a nil error.
func currentBranch(dir string) (string, error) {
	revParse := exec.Command("git", "rev-parse", "--abbrev-ref", "HEAD")
	revParse.Dir = dir

	raw, err := revParse.Output()
	if err != nil {
		return "", err
	}

	if name := strings.TrimSpace(string(raw)); name != "HEAD" {
		return name, nil
	}
	return "", nil
}
// wipeAllStores deletes the file behind every store listed by
// store.AllStores. A file that is already missing is not an error. The
// function stops at the first failure: listing and path-resolution errors
// are returned unchanged, removal errors are wrapped with the store name.
func wipeAllStores(store *Store) error {
	names, err := store.AllStores()
	if err != nil {
		return err
	}

	for _, name := range names {
		file, pathErr := store.storePath(name)
		if pathErr != nil {
			return pathErr
		}

		rmErr := os.Remove(file)
		if rmErr == nil || os.IsNotExist(rmErr) {
			continue
		}
		return fmt.Errorf("cannot remove store '%s': %w", name, rmErr)
	}
	return nil
}