Use a context for closing channels. (#10)

Reviewed-on: #10
Co-authored-by: Anna Rose Wiggins <annabunches@gmail.com>
Co-committed-by: Anna Rose Wiggins <annabunches@gmail.com>
This commit is contained in:
Anna Rose Wiggins 2025-07-28 17:44:40 +00:00 committed by Anna Rose Wiggins
parent 4c04a9215d
commit 7b520af24a
2 changed files with 27 additions and 28 deletions

View file

@ -1,6 +1,7 @@
package main
import (
"context"
"flag"
"fmt"
"os"
@ -71,7 +72,7 @@ func main() {
// Initialize physical devices
pDevices := initPhysicalDevices(config)
rules, eventChannel, doneChannel, wg := loadRules(config, pDevices, getVirtualDevices(vBuffersByName))
rules, eventChannel, cancel, wg := loadRules(config, pDevices, getVirtualDevices(vBuffersByName))
// initialize the mode variable
mode := config.GetModes()[0]
@ -115,12 +116,12 @@ func main() {
case ChannelEventReload:
// stop existing channels
fmt.Println("Reloading rules.")
doneChannel <- true
cancel()
fmt.Println("Waiting for existing listeners to exit. Provide input from each of your devices.")
wg.Wait()
fmt.Println("Listeners exited. Parsing config.")
config := readConfig(configDir) // reload the config
rules, eventChannel, doneChannel, wg = loadRules(config, pDevices, getVirtualDevices(vBuffersByName))
rules, eventChannel, cancel, wg = loadRules(config, pDevices, getVirtualDevices(vBuffersByName))
fmt.Println("Config re-loaded. Only rule changes applied. Device and Mode changes require restart.")
}
}
@ -129,11 +130,11 @@ func main() {
func loadRules(
config *config.ConfigParser,
pDevices map[string]*evdev.InputDevice,
vDevices map[string]*evdev.InputDevice) ([]mappingrules.MappingRule, <-chan ChannelEvent, chan bool, *sync.WaitGroup) {
vDevices map[string]*evdev.InputDevice) ([]mappingrules.MappingRule, <-chan ChannelEvent, func(), *sync.WaitGroup) {
var wg sync.WaitGroup
eventChannel := make(chan ChannelEvent, 1000)
doneChannel := make(chan bool)
ctx, cancel := context.WithCancel(context.Background())
// Initialize rules
rules := config.BuildRules(pDevices, vDevices)
@ -142,20 +143,20 @@ func loadRules(
// start listening for events on devices and timers
for _, device := range pDevices {
wg.Add(1)
go eventWatcher(device, eventChannel, doneChannel, &wg)
go eventWatcher(device, eventChannel, ctx, &wg)
}
timerCount := 0
for _, rule := range rules {
if timedRule, ok := rule.(mappingrules.TimedEventEmitter); ok {
wg.Add(1)
go timerWatcher(timedRule, eventChannel, doneChannel, &wg)
go timerWatcher(timedRule, eventChannel, ctx, &wg)
timerCount++
}
}
logger.Logf("registered %d timers", timerCount)
go consoleWatcher(eventChannel, &wg)
go consoleWatcher(eventChannel)
return rules, eventChannel, doneChannel, &wg
return rules, eventChannel, cancel, &wg
}

View file

@ -2,6 +2,7 @@ package main
import (
"bufio"
"context"
"os"
"sync"
"time"
@ -19,26 +20,25 @@ const (
func eventWatcher(
device *evdev.InputDevice,
channel chan<- ChannelEvent,
done chan bool,
ctx context.Context,
wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case cancel := <-done:
if cancel {
done <- true
return
}
default:
}
event, err := device.ReadOne()
if err != nil {
logger.LogError(err, "Error while reading event. Disconnecting device.")
return
}
select {
case <-ctx.Done():
return
default:
// Proceed
}
channel <- ChannelEvent{Device: device, Event: event, Type: ChannelEventInput}
if event.Type == evdev.EV_SYN {
@ -50,19 +50,17 @@ func eventWatcher(
func timerWatcher(
rule mappingrules.TimedEventEmitter,
channel chan<- ChannelEvent,
done chan bool,
ctx context.Context,
wg *sync.WaitGroup) {
defer wg.Done()
for {
select {
case cancel := <-done:
if cancel {
done <- true
return
}
case <-ctx.Done():
return
default:
// Proceed
}
event := rule.TimerEvent()
@ -77,9 +75,9 @@ func timerWatcher(
}
}
// consoleWatcher reads input from stdin, and on receiving anything
func consoleWatcher(channel chan<- ChannelEvent, wg *sync.WaitGroup) {
defer wg.Done()
// consoleWatcher reads input from stdin, and on receiving anything,
// closes the current threading context
func consoleWatcher(channel chan<- ChannelEvent) {
stdin := bufio.NewReader(os.Stdin)
for {
_, err := stdin.ReadString('\n')