package main

import (
	"fmt"
	"log"
	"os"

	"git.annabunch.es/annabunches/adventofcode/2020/lib/util"
)

func getMax(values map[int]bool) int {
	// getMax returns the largest key in values whose entry is true. The result
	// never drops below 0, so an empty or nil map (or one holding only
	// negative or false entries) yields 0. Keys mapped to false are treated as
	// absent, consistent with the values[k] membership checks used elsewhere.
	best := 0 // named to avoid shadowing the builtin max (Go 1.21+)
	for value, present := range values {
		if present && value > best {
			best = value
		}
	}

	return best
}

// Node is one value in the path graph built by buildTree. A node for a given
// value is shared by every parent that links to it, so the structure is a
// DAG rather than a strict tree.
type Node struct {
	Value    int     // the integer value this node represents
	Children []*Node // nodes one step away (buildTree links values 1-3 higher)
	Paths    int     // the total number of paths from here to some target; written by countPaths, 0 until then
}

// Build a tree of possible paths, then count them.
func findArrangements(values map[int]bool, target int) int {
	root := buildTree(values, target)
	return countPaths(root, target)
}

// buildTree builds the graph of possible paths over the values present in
// values, rooted at an implicit 0, and returns that root.
//
// Each present value i in [0, target) gets an edge to every present value
// i+1, i+2 and i+3 that does not exceed target. Nodes are shared through a
// value->node map, so the result is a DAG.
//
// values is only read, never modified. 0 is always treated as present, even
// if values lacks it. The returned root is never nil. When target <= 0 the
// root has no children.
func buildTree(values map[int]bool, target int) *Node {
	nodeMap := make(map[int]*Node)

	// Create the root up front so callers always get a usable node, even when
	// the loop below never runs (target <= 0).
	root := createOrFetchNode(nodeMap, 0)

	for i := 0; i < target; i++ {
		// 0 is the implicit starting point; every other value must be listed.
		if i != 0 && !values[i] {
			continue
		}

		node := createOrFetchNode(nodeMap, i)
		for j := 1; j <= 3 && i+j <= target; j++ {
			if values[i+j] {
				node.Children = append(node.Children, createOrFetchNode(nodeMap, i+j))
			}
		}
	}

	return root
}

// createOrFetchNode looks value up in nodeMap and returns the node stored
// there. On a miss it first registers a fresh node for value, with an empty
// (non-nil) Children slice, and returns that node instead.
func createOrFetchNode(nodeMap map[int]*Node, value int) *Node {
	existing, found := nodeMap[value]
	if !found {
		existing = &Node{Value: value, Children: make([]*Node, 0)}
		nodeMap[value] = existing
	}
	return existing
}

// countPaths returns the number of distinct paths from node to a node whose
// Value equals target, following Children edges.
//
// A node whose Value is target counts as exactly one path, and its children
// are not explored. A node with no route to target counts as zero. A nil
// node yields 0.
//
// As a side effect, every visited node has its Paths field set to its count.
// Results are memoized per call in a local map instead of treating
// Paths == 0 as "not yet computed". Zero is a legitimate count for a dead
// end, and that sentinel would re-walk such subgraphs on every visit, which
// is exponential on ladder-shaped graphs. Paths values left over from an
// earlier call, possibly with a different target, are ignored.
func countPaths(node *Node, target int) int {
	memo := make(map[*Node]int)

	var walk func(n *Node) int
	walk = func(n *Node) int {
		if n == nil {
			return 0
		}
		if cached, ok := memo[n]; ok {
			return cached
		}

		paths := 0
		if n.Value == target {
			paths = 1
		} else {
			for _, child := range n.Children {
				paths += walk(child)
			}
		}

		memo[n] = paths
		n.Paths = paths
		return paths
	}

	return walk(node)
}

// main is the entry point for this Advent of Code 2020 solver.
//
// Usage: <program> <step> <input-file>
//
// The input is parsed into a set of integers. A device value 3 above the
// largest input value is then added to the set. Step "1" prints the number
// of 1-sized gaps multiplied by the number of 3-sized gaps along the sorted
// chain from 0 to the device, and panics if any gap exceeds 3. Step "2"
// prints the number of distinct paths from 0 to the device. Missing
// arguments or an unknown step exit with an error instead of panicking or
// doing nothing.
func main() {
	if len(os.Args) < 3 {
		log.Fatalf("usage: %s <step> <input-file>", os.Args[0])
	}
	step := os.Args[1]
	values := util.InputParserIntMap(os.Args[2])

	device := getMax(values) + 3
	values[device] = true

	switch step {
	case "1":
		diffMap := make(map[int]int)
		diff := 0
		for i := 0; i < device; i++ {
			// increment diff, make sure we haven't broken the chain
			// Note that we're actually logically checking the joltage at i+1
			// but that serves us well here
			diff++
			if diff > 3 {
				log.Panic("Diff too big, bailing.")
			}

			// if we have a device at this joltage, register and reset count
			if values[i+1] {
				diffMap[diff]++
				diff = 0
			}
		}
		fmt.Println(diffMap[1] * diffMap[3])

	case "2":
		fmt.Println(findArrangements(values, device))

	default:
		log.Fatalf("unknown step %q: want \"1\" or \"2\"", step)
	}
}