package main

import (
	"fmt"
	"log"
	"os"
	"regexp"
	"strconv"

	"git.annabunch.es/annabunches/adventofcode/2020/lib/fileutils"
)

// Node is one bag color in the containment graph built from the rule lines.
type Node struct {
	// Color is the bag's color name; findOrCreateNode also uses it as the
	// node's key in nodeMap.
	Color    string
	// Children maps each directly contained bag to how many of it this bag
	// holds. It stays nil until parseRule processes this color's own rule.
	Children map[*Node]int
}

// Package-level parsing state shared by the rule parser and the queries.
var (
	// nodeMap indexes every bag seen so far by its color. main allocates it
	// before parsing any rules.
	nodeMap map[string]*Node

	// colorRe captures the container bag's color at the start of a rule line.
	colorRe = regexp.MustCompile(`^(.*?) bags contain`)

	// childRe captures (count, color) for each contained-bag clause. The first
	// clause follows "contain"; every later clause follows a comma.
	childRe = regexp.MustCompile(`(?:contain|,) (\d+) (.*?) bag`)
)

// findOrCreateNode returns the unique *Node for color from the package-level
// nodeMap. On first use it creates and registers a node with only Color set
// (Children is nil). Repeated calls with the same color return the same
// pointer, which is what links parents and children into a single graph.
//
// If nodeMap has not been allocated yet, it is allocated here. Writing to a
// nil map would otherwise panic.
func findOrCreateNode(color string) *Node {
	if node, ok := nodeMap[color]; ok { // reading a nil map is safe
		return node
	}

	if nodeMap == nil {
		nodeMap = make(map[string]*Node)
	}

	node := &Node{Color: color}
	nodeMap[color] = node
	return node
}

// parseRule parses one rule line and records it in the bag graph.
//
// The container's color is the text before " bags contain" (colorRe). Each
// contained bag is a "<count> <color> bag" clause following "contain" or a
// comma (childRe). A line with no such clauses (e.g. "... contain no other
// bags.") gets an empty Children map.
//
// The container node's Children is replaced with a map from each child node to
// its count. Empty lines are ignored. A non-empty line that does not match
// colorRe is malformed input and triggers a panic via log.Panicf.
func parseRule(line string) {
	if line == "" {
		return
	}

	match := colorRe.FindStringSubmatch(line)
	if match == nil {
		log.Panicf("malformed rule %q: expected \"<color> bags contain ...\"", line)
	}
	node := findOrCreateNode(match[1])

	// n = -1 returns every clause; a rule has no fixed upper bound on children.
	childrenData := childRe.FindAllStringSubmatch(line, -1)
	children := make(map[*Node]int, len(childrenData))

	for _, childData := range childrenData {
		child := findOrCreateNode(childData[2])
		count, err := strconv.Atoi(childData[1])
		if err != nil {
			// Only reachable on integer overflow, since childRe guarantees digits.
			log.Panic(err)
		}
		children[child] = count
	}

	node.Children = children
}

// findInDescendants reports whether any bag below node in the containment
// graph has the given color. The node's own color is not tested directly;
// only its children and what lies beneath them are searched.
func findInDescendants(node *Node, color string) bool {
	found := false
	for sub := range node.Children {
		if found = rFindInDescendants(sub, color); found {
			break
		}
	}
	return found
}

// rFindInDescendants reports whether node itself, or any bag reachable from it
// through Children, has the given color. A nil node matches nothing.
//
// The search is an iterative depth-first traversal that records visited nodes.
// Each bag is therefore examined at most once per call, so shared sub-graphs
// are not re-walked once per path (exponential in the worst case). A cyclic
// rule set also cannot recurse forever.
func rFindInDescendants(node *Node, color string) bool {
	if node == nil {
		return false
	}

	visited := map[*Node]bool{node: true}
	stack := []*Node{node}

	for len(stack) > 0 {
		cur := stack[len(stack)-1]
		stack = stack[:len(stack)-1]

		if cur.Color == color {
			return true
		}

		for child := range cur.Children {
			if child != nil && !visited[child] {
				visited[child] = true
				stack = append(stack, child)
			}
		}
	}
	return false
}

// countChildren returns the total number of bags represented by node. That is
// the node itself (1) plus, for each child, the child's count multiplied by the
// child's own total. Callers wanting only the bags inside node subtract 1.
// A nil node counts as 0.
//
// Totals are memoised per node for the duration of one call. A bag type nested
// under several parents is therefore expanded once, not once per path through
// the graph. The rules are assumed to be acyclic; a bag that transitively
// contains itself has no finite total.
func countChildren(node *Node) int {
	return countChildrenMemo(node, make(map[*Node]int))
}

// countChildrenMemo is countChildren with a cache of already-computed totals.
func countChildrenMemo(node *Node, memo map[*Node]int) int {
	if node == nil {
		return 0
	}
	if total, ok := memo[node]; ok {
		return total
	}

	total := 1 // count ourself
	for child, count := range node.Children {
		total += count * countChildrenMemo(child, memo)
	}

	memo[node] = total
	return total
}

// main parses the rule file and answers one of two queries about "shiny gold"
// bags.
//
// Usage: <prog> <step> <input-file>
//
//	step "1": print how many bag colors eventually contain a shiny gold bag.
//	step "2": print how many bags a single shiny gold bag must contain.
//
// Missing arguments, an unknown step, or (for step 2) an input with no
// shiny gold bag are reported via log.Fatalf instead of an index or nil panic.
func main() {
	if len(os.Args) < 3 {
		log.Fatalf("usage: %s <step 1|2> <input file>", os.Args[0])
	}
	step := os.Args[1]
	values := fileutils.InputParserStrings(os.Args[2])

	nodeMap = make(map[string]*Node)
	for _, line := range values {
		parseRule(line)
	}

	const target = "shiny gold"

	switch step {
	case "1":
		total := 0
		for _, node := range nodeMap {
			if findInDescendants(node, target) {
				total++
			}
		}
		fmt.Println("Total:", total)

	case "2":
		root, ok := nodeMap[target]
		if !ok {
			log.Fatalf("input has no %q bag", target)
		}
		// countChildren includes the root bag itself; exclude it.
		fmt.Println("Total:", countChildren(root)-1)

	default:
		log.Fatalf("unknown step %q (want 1 or 2)", step)
	}
}