package main

import (
	"fmt"
	"log"
	"os"
	"regexp"
	"strconv"

	"git.annabunch.es/annabunches/adventofcode/2020/lib/util"
)

// Instruction is one decoded program line. parseProgram produces two kinds,
// distinguished by Op; the meaning of Value0 and Value1 depends on that kind.
type Instruction struct {
	// Op is "mask" for a mask assignment or "mem" for a memory write.
	Op string
	// Value0 is the target address for "mem". For "mask" it is the first
	// bitmask built by parseProgram: the '0' positions of the mask text in
	// step 1 (the zeromask), the '1' positions in step 2.
	Value0 int64
	// Value1 is the value to store for "mem". For "mask" it is the second
	// bitmask built by parseProgram: the '1' positions of the mask text in
	// step 1 (the onemask), the 'X' positions in step 2.
	Value1 int64
}

// makeMask turns a textual mask into a bitmask. The last character of input
// corresponds to bit 0, the one before it to bit 1, and so on; a bit is set
// wherever the character equals maskType. The result is truncated to the low
// 36 bits.
func makeMask(input string, maskType byte) int64 {
	const low36 = 0xfffffffff

	var bits int64
	last := len(input) - 1
	for pos := last; pos >= 0; pos-- {
		if input[pos] != maskType {
			continue
		}
		// Distance from the end of the string is the bit number.
		bits |= int64(1) << (last - pos)
	}

	return bits & low36
}

// applyFloatMask expands address into every address obtainable by choosing
// 0 or 1 independently for each bit that is set in mask. Only the low 36 bits
// of mask are examined, from bit 0 upwards, and each set bit doubles the
// result via fanoutAddresses. A zero mask yields just the original address.
func applyFloatMask(address int64, mask int64) []int64 {
	const addressBits = 36

	expanded := []int64{address}
	for bit := 0; bit < addressBits; bit++ {
		if (mask>>bit)&1 == 0 {
			continue
		}
		expanded = fanoutAddresses(expanded, bit)
	}
	return expanded
}

// fanoutAddresses returns, for every input address, two addresses: first the
// address with the given bit set, then the address with that bit cleared.
// The input slice is not modified; the result has exactly twice its length.
func fanoutAddresses(addresses []int64, bit int) []int64 {
	flag := int64(1) << bit

	// The final length is known up front, so allocate once instead of
	// letting append regrow the slice while it doubles.
	newAddresses := make([]int64, 0, 2*len(addresses))
	for _, a := range addresses {
		newAddresses = append(newAddresses, a|flag, a&^flag)
	}
	return newAddresses
}

// Patterns for the two kinds of program line. Both are anchored at the start
// of the line; only maskRe is also anchored at the end.
var (
	// maskRe captures the run of '0', '1' and 'X' characters of a
	// "mask = ..." line.
	maskRe = regexp.MustCompile(`^mask = ([01X]+)$`)
	// memRe captures the decimal address and the decimal value of a
	// "mem[address] = value" line.
	memRe = regexp.MustCompile(`^mem\[([0-9]+)\] = ([0-9]+)`)
)

// parseProgram converts program lines into Instructions, one per line.
//
// A line matching maskRe becomes a "mask" instruction whose two bitmasks
// depend on step:
//   - "1": Value0 marks the '0' characters, Value1 marks the '1' characters.
//   - "2": Value0 marks the '1' characters, Value1 marks the 'X' characters.
//
// For any other step both bitmasks are left at zero.
//
// A line matching memRe becomes a "mem" instruction with Value0 as the
// address and Value1 as the value. Any other line, or a number that does not
// fit in an int64, makes the function panic via log.Panicf.
func parseProgram(input []string, step string) []Instruction {
	program := make([]Instruction, 0, len(input))
	for _, line := range input {
		if m := maskRe.FindStringSubmatch(line); m != nil {
			var mask0, mask1 int64
			switch step {
			case "1":
				mask0 = makeMask(m[1], '0')
				mask1 = makeMask(m[1], '1')
			case "2":
				mask0 = makeMask(m[1], '1')
				mask1 = makeMask(m[1], 'X')
			}
			program = append(program, Instruction{
				Op:     "mask",
				Value0: mask0,
				Value1: mask1,
			})
			continue
		}

		if m := memRe.FindStringSubmatch(line); m != nil {
			// Parse as 64-bit explicitly: the operands are up to 36 bits
			// wide, which does not fit in int on 32-bit platforms.
			address, err := strconv.ParseInt(m[1], 10, 64)
			if err != nil {
				log.Panicf("Program parse error: bad address in %q: %v", line, err)
			}
			value, err := strconv.ParseInt(m[2], 10, 64)
			if err != nil {
				log.Panicf("Program parse error: bad value in %q: %v", line, err)
			}
			program = append(program, Instruction{
				Op:     "mem",
				Value0: address,
				Value1: value,
			})
			continue
		}

		log.Panicf("Program parse error: %s", line)
	}

	return program
}

// executeProgram1 runs program under the value-masking rules and returns the
// resulting memory dump, keyed by address.
//
// A "mask" instruction replaces the current masks: Value0 has a 1 at every
// bit position that must be forced to 0, Value1 has a 1 at every position
// that must be forced to 1. A "mem" instruction stores Value1 at address
// Value0 after applying both masks; the stored value is truncated to 36 bits.
// Before the first "mask" instruction both masks are zero, so values are
// stored unchanged.
func executeProgram1(program []Instruction) map[int64]int64 {
	memory := make(map[int64]int64)
	var zeroMask, oneMask int64
	for _, instruction := range program {
		switch instruction.Op {
		case "mask":
			zeroMask = instruction.Value0
			oneMask = instruction.Value1
		case "mem":
			// zeroMask marks the bits to clear, so it must be applied with
			// AND NOT; a plain AND would keep only those bits instead.
			value := (instruction.Value1 &^ zeroMask) | oneMask
			memory[instruction.Value0] = value & 0xfffffffff
		}
	}
	return memory
}

// executeProgram2 runs program under the address-decoding rules and returns
// the resulting memory dump, keyed by address.
//
// A "mask" instruction replaces the current masks: Value0 is OR-ed into every
// subsequent address, and Value1 marks the floating bits. A "mem" instruction
// writes Value1 (truncated to 36 bits) to every address that applyFloatMask
// produces from the OR-ed address and the floating mask. Instructions with
// any other Op are ignored.
func executeProgram2(program []Instruction) map[int64]int64 {
	const valueBits = 0xfffffffff

	var setBits, floatBits int64
	memory := map[int64]int64{}
	for _, ins := range program {
		if ins.Op == "mask" {
			setBits, floatBits = ins.Value0, ins.Value1
			continue
		}
		if ins.Op != "mem" {
			continue
		}

		stored := ins.Value1 & valueBits
		for _, target := range applyFloatMask(ins.Value0|setBits, floatBits) {
			memory[target] = stored
		}
	}
	return memory
}

func main() {
	step := os.Args[1]
	values := util.InputParserStrings(os.Args[2])

	var memory map[int64]int64
	program := parseProgram(values, step)

	switch step {
	case "1":
		memory = executeProgram1(program)
	case "2":
		memory = executeProgram2(program)
	}

	var total int64
	for _, value := range memory {
		total += value
	}
	fmt.Println(total)
}