package main

import (
	"fmt"
	"log"
	"math"
	"math/big"
	"os"
	"strconv"
	"strings"

	"git.annabunch.es/annabunches/adventofcode/2020/lib/util"
)

//
// Begin code borrowed from https://golang.hotexamples.com/examples/math.big/Int/GCD/golang-int-gcd-method-examples.html
//
// crt solves the system of congruences x ≡ a[i] (mod n[i]) with the
// Chinese Remainder Theorem and returns the unique solution in [0, N),
// where N is the product of all moduli.
//
// Local changes to the borrowed version: an error (instead of a panic) is
// returned when n is empty, when a and n differ in length, or when a
// modulus is not positive. An error is also returned when the moduli are
// not pairwise coprime.
func crt(a, n []*big.Int) (*big.Int, error) {
	if len(n) == 0 {
		return nil, fmt.Errorf("no moduli given")
	}
	if len(a) != len(n) {
		return nil, fmt.Errorf("got %d residues for %d moduli", len(a), len(n))
	}
	one := big.NewInt(1)
	// p is the product of all moduli; the solution is unique modulo p.
	p := big.NewInt(1)
	for _, m := range n {
		if m.Sign() <= 0 {
			return nil, fmt.Errorf("modulus %d is not positive", m)
		}
		p.Mul(p, m)
	}
	var x, q, s, z big.Int
	for i, n1 := range n {
		// q = p / n1 is the product of every other modulus.
		q.Div(p, n1)
		// Extended Euclid: z = gcd(n1, q) and s satisfies s*q ≡ z (mod n1).
		// When z == 1, s is the inverse of q modulo n1.
		z.GCD(nil, &s, n1, &q)
		if z.Cmp(one) != 0 {
			return nil, fmt.Errorf("%d not coprime", n1)
		}
		// a[i]*s*q is ≡ a[i] (mod n1) and ≡ 0 modulo every other modulus.
		x.Add(&x, s.Mul(a[i], s.Mul(&s, &q)))
	}
	// big.Int.Mod is Euclidean, so the result is non-negative even if some
	// residues or inverse coefficients were negative.
	return x.Mod(&x, p), nil
}

//
// End borrowed code
//

// parseInput parses the puzzle input and returns the earliest timestamp and
// the list of bus IDs.
//
// The first line holds the earliest departure timestamp. The second line is
// a comma-separated list of bus IDs, where "x" marks an out-of-service slot.
// Each "x" becomes -1, so every bus keeps its position (its offset) in the
// returned list. Surrounding whitespace, including a trailing '\r', is
// ignored.
//
// It panics via log.Panicf if there are fewer than two lines or if a value
// is not an integer.
func parseInput(input []string) (int, []int) {
	if len(input) < 2 {
		log.Panicf("expected at least 2 input lines, got %d", len(input))
	}

	earliest, err := strconv.Atoi(strings.TrimSpace(input[0]))
	if err != nil {
		log.Panicf("parsing earliest timestamp: %v", err)
	}

	fields := strings.Split(strings.TrimSpace(input[1]), ",")
	busList := make([]int, 0, len(fields))
	for _, value := range fields {
		value = strings.TrimSpace(value)
		if value == "x" {
			busList = append(busList, -1)
			continue
		}
		bus, err := strconv.Atoi(value)
		if err != nil {
			log.Panicf("parsing bus ID %q: %v", value, err)
		}
		busList = append(busList, bus)
	}

	return earliest, busList
}

// findBus chooses the bus that departs soonest at or after earliest.
//
// Entries equal to -1 (out-of-service slots) are ignored. A bus departs at
// every multiple of its ID, so its first usable departure is earliest
// rounded up to the next multiple of that ID. On a tie, the bus that appears
// first in busList wins. It returns the chosen bus ID and its departure
// time, or (-1, -1) when busList contains no buses.
func findBus(busList []int, earliest int) (int, int) {
	chosenID, chosenDeparture := -1, -1
	for _, id := range busList {
		if id != -1 {
			trips := math.Ceil(float64(earliest) / float64(id))
			departure := id * int(trips)
			if chosenID == -1 || departure < chosenDeparture {
				chosenID, chosenDeparture = id, departure
			}
		}
	}
	return chosenID, chosenDeparture
}

// This uses the Chinese Remainder Theorem to calculate the answer.
// I don't actually understand the underlying logic, I just found this
// while looking around for modulus-related algorithms.
func findTimestampWithCRT(busList []int) *big.Int {
	bigBusList := make([]*big.Int, 0)
	offsetList := make([]*big.Int, 0)

	for i, bus := range busList {
		if bus == -1 {
			continue
		}

		bigBus := big.NewInt(int64(bus))
		bigBusList = append(bigBusList, bigBus)

		offset := big.NewInt(int64(i))
		offset.Sub(bigBus, offset)
		offsetList = append(offsetList, offset)
	}

	fmt.Println(offsetList)
	fmt.Println(bigBusList)
	value, err := crt(offsetList, bigBusList)
	if err != nil {
		log.Panicf(err.Error())
	}
	return value
}

// main runs one step of the bus-schedule puzzle.
//
// Usage: <program> <step> <input file>
//
// Step "1" prints the ID of the soonest departing bus multiplied by the wait
// (in timestamp units) after the earliest timestamp. Step "2" prints the
// earliest timestamp at which each bus departs at its list offset.
// util.InputParserStrings presumably returns the file's lines; confirm
// against lib/util.
func main() {
	if len(os.Args) < 3 {
		log.Fatalf("usage: %s <step> <input file>", os.Args[0])
	}
	step := os.Args[1]
	values := util.InputParserStrings(os.Args[2])

	earliest, busList := parseInput(values)

	switch step {
	case "1":
		busID, departureTime := findBus(busList, earliest)
		if busID == -1 {
			log.Fatal("no buses in schedule")
		}
		fmt.Println(busID * (departureTime - earliest))
	case "2":
		fmt.Println(findTimestampWithCRT(busList))
	default:
		log.Fatalf("unknown step %q (want \"1\" or \"2\")", step)
	}
}