package main

import (
	"flag"
	"os/exec"
	"path/filepath"
	"fmt"
	"os"
	"strings"
)

// removeArchives is bound to the -r command-line flag in main. When it points
// to true, processPath deletes each archive after the directory extracted from
// it has been processed without error.
var removeArchives *bool

// processPath walks curPath recursively and extracts every supported archive
// it finds (see extractIfPossible). Directories are descended into. For a file
// that is an archive, the directory extracted from it (which sits next to the
// archive) is processed in turn, so archives nested inside archives are also
// unpacked. When the -r flag is set, each archive is deleted once its
// extracted contents have been processed without error.
//
// Processing stops at the first error, which is returned wrapped with the
// path that caused it. os.Stat is used, so symbolic links are followed.
func processPath(curPath string) error {
	info, err := os.Stat(curPath)
	if err != nil {
		return fmt.Errorf("Unable to stat %s: %w", curPath, err)
	}
	if info.IsDir() {
		entries, err := os.ReadDir(curPath)
		if err != nil {
			return fmt.Errorf("Unable to call readdir on %s: %w", curPath, err)
		}
		// entries is read once up front, so directories created by extracting
		// an archive during this loop are not picked up again by the loop;
		// they are handled by the explicit recursive call below instead.
		for _, entry := range entries {
			if err := processPath(filepath.Join(curPath, entry.Name())); err != nil {
				return err
			}
		}
		return nil
	}

	extractedName, err := extractIfPossible(curPath)
	if err != nil {
		return err
	}
	if extractedName == "" {
		// Not a recognised archive: nothing to do.
		return nil
	}
	// Only the final path element of the extractor's result is joined with
	// the archive's directory, so the resulting path is correct whether the
	// extractor reports a bare name ("a") or a path that already contains the
	// archive's directory ("sub/a"). Joining the latter directly would yield
	// "sub/sub/a".
	extractedPath := filepath.Join(filepath.Dir(curPath), filepath.Base(extractedName))
	if err := processPath(extractedPath); err != nil {
		return err
	}
	// removeArchives stays nil until main registers the flag; treat that as
	// "keep the archive" rather than dereferencing a nil pointer.
	if removeArchives != nil && *removeArchives {
		if err := os.Remove(curPath); err != nil {
			return fmt.Errorf("Unable to remove archive file %s: %w", curPath, err)
		}
		fmt.Printf("== rm %s\n", curPath)
	}
	return nil
}

// extractIfPossible dispatches curPath to the extractor for its archive type,
// chosen by a case-sensitive file-name suffix match: ".tgz" and ".tar.gz" go
// to extractTgz, ".7z" to extract7z and ".zip" to extractZip. It returns
// whatever that extractor returns. A path with none of these suffixes is not
// treated as an archive: the result is "" with a nil error.
func extractIfPossible(curPath string) (string, error) {
	switch {
	case strings.HasSuffix(curPath, ".tgz"):
		return extractTgz(curPath, ".tgz")
	case strings.HasSuffix(curPath, ".tar.gz"):
		return extractTgz(curPath, ".tar.gz")
	case strings.HasSuffix(curPath, ".7z"):
		return extract7z(curPath)
	case strings.HasSuffix(curPath, ".zip"):
		return extractZip(curPath)
	default:
		return "", nil
	}
}

// extractTgz unpacks the tarball at curPath by running `tar xvf` inside the
// archive's own directory, so its members land next to the archive. suffix is
// the extension to strip from the file name (callers pass ".tgz" or
// ".tar.gz").
//
// On success it returns the archive's base name with suffix removed. This
// assumes the tarball's contents live under a top-level directory of that
// name; the assumption is not verified here. If tar fails, the error includes
// tar's stderr output (when there is any) rather than only the exit status.
func extractTgz(curPath string, suffix string) (string, error) {
	fmt.Printf("== decompressing tar file %s\n", curPath)
	cmd := exec.Command("tar", "xvf", filepath.Base(curPath))
	cmd.Dir = filepath.Dir(curPath)
	// Stdout (the verbose listing) is still discarded, but stderr is captured
	// so that a failure can explain itself.
	var stderr strings.Builder
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		if msg := strings.TrimSpace(stderr.String()); msg != "" {
			return "", fmt.Errorf("error running tar xvf on %s: %w: %s", curPath, err, msg)
		}
		return "", fmt.Errorf("error running tar xvf on %s: %w", curPath, err)
	}
	return filepath.Base(strings.TrimSuffix(curPath, suffix)), nil
}

// extract7z extracts the 7-Zip archive at curPath into a new directory next to
// it, named after the archive with the ".7z" suffix removed, by running
// `7z x` with that directory as the working directory. The directory must not
// already exist.
//
// On success it returns the base name of the new directory (e.g. "foo" for
// "some/dir/foo.7z"), which callers join with the archive's directory. If 7z
// fails, the directory created here is removed again so that a later run can
// retry, and the error includes 7z's stderr output when there is any.
func extract7z(curPath string) (string, error) {
	fmt.Printf("== decompressing 7z file %s\n", curPath)
	extractedBaseName := strings.TrimSuffix(curPath, ".7z")
	// Resolve absolute paths before creating anything, so that a failure here
	// cannot leave a stray directory behind.
	absCurPath, err := filepath.Abs(curPath)
	if err != nil {
		return "", fmt.Errorf("unable to get absolute path for '%s': %w", curPath, err)
	}
	absExtractedBaseName, err := filepath.Abs(extractedBaseName)
	if err != nil {
		return "", fmt.Errorf("unable to get absolute path for '%s': %w", extractedBaseName, err)
	}
	if err := os.Mkdir(extractedBaseName, 0755); err != nil {
		return "", fmt.Errorf("unable to mkdir '%s': %w", extractedBaseName, err)
	}
	cmd := exec.Command("7z", "x", absCurPath)
	// `7z x` without -o writes into its working directory.
	cmd.Dir = absExtractedBaseName
	var stderr strings.Builder
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		// The Mkdir above succeeded, so this directory is new and holds only
		// the output of this failed run. Removal is best-effort; the 7z error
		// is the one worth reporting.
		_ = os.RemoveAll(absExtractedBaseName)
		if msg := strings.TrimSpace(stderr.String()); msg != "" {
			return "", fmt.Errorf("error running 7z on %s: %w: %s", curPath, err, msg)
		}
		return "", fmt.Errorf("error running 7z on %s: %w", curPath, err)
	}
	return filepath.Base(extractedBaseName), nil
}

// extractZip extracts the zip archive at curPath into a new directory next to
// it, named after the archive with the ".zip" suffix removed, by running
// `unzip` with that directory as the working directory. The directory must
// not already exist.
//
// On success it returns the base name of the new directory (e.g. "foo" for
// "some/dir/foo.zip"), which callers join with the archive's directory. If
// unzip fails, the directory created here is removed again so that a later
// run can retry, and the error includes unzip's stderr output when there is
// any.
func extractZip(curPath string) (string, error) {
	fmt.Printf("== decompressing zip file %s\n", curPath)
	extractedBaseName := strings.TrimSuffix(curPath, ".zip")
	// Resolve absolute paths before creating anything, so that a failure here
	// cannot leave a stray directory behind.
	absCurPath, err := filepath.Abs(curPath)
	if err != nil {
		return "", fmt.Errorf("unable to get absolute path for '%s': %w", curPath, err)
	}
	absExtractedBaseName, err := filepath.Abs(extractedBaseName)
	if err != nil {
		return "", fmt.Errorf("unable to get absolute path for '%s': %w", extractedBaseName, err)
	}
	if err := os.Mkdir(extractedBaseName, 0755); err != nil {
		return "", fmt.Errorf("unable to mkdir '%s': %w", extractedBaseName, err)
	}
	cmd := exec.Command("unzip", absCurPath)
	// unzip without -d writes into its working directory.
	cmd.Dir = absExtractedBaseName
	var stderr strings.Builder
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		// The Mkdir above succeeded, so this directory is new and holds only
		// the output of this failed run. Removal is best-effort; the unzip
		// error is the one worth reporting.
		_ = os.RemoveAll(absExtractedBaseName)
		if msg := strings.TrimSpace(stderr.String()); msg != "" {
			return "", fmt.Errorf("error running unzip on %s: %w: %s", curPath, err, msg)
		}
		return "", fmt.Errorf("error running unzip on %s: %w", curPath, err)
	}
	return filepath.Base(extractedBaseName), nil
}

// main parses the command line and recursively decompresses every path given
// as an argument, or the current directory when no path is given. On the
// first error it prints the error to stderr and exits with status 1.
func main() {
	removeArchives = flag.Bool("r", false, "Remove archive files after decompressing them")
	flag.Usage = func() {
		// Send the whole usage text to the same writer flag.PrintDefaults
		// uses (stderr by default), so the header and the flag list are not
		// split across two streams.
		out := flag.CommandLine.Output()
		fmt.Fprintf(out, "recursive_decompress: recursively decompresses a path.\n")
		fmt.Fprintf(out, "usage: recursive_decompress [flags] [paths...]\n")
		fmt.Fprintf(out, "\n")
		flag.PrintDefaults()
	}
	flag.Parse()

	paths := flag.Args()
	if len(paths) == 0 {
		paths = []string{"."}
	}
	for _, curPath := range paths {
		if curPath == "" {
			// An empty argument means the current directory.
			curPath = "."
		}
		if err := processPath(curPath); err != nil {
			fmt.Fprintf(os.Stderr, "error: %s\n", err.Error())
			os.Exit(1)
		}
	}
}