Use a context for closing channels. (#10)

Reviewed-on: #10
Co-authored-by: Anna Rose Wiggins <annabunches@gmail.com>
Co-committed-by: Anna Rose Wiggins <annabunches@gmail.com>
This commit is contained in:
Anna Rose Wiggins 2025-07-28 17:44:40 +00:00 committed by Anna Rose Wiggins
parent 4c04a9215d
commit 7b520af24a
2 changed files with 27 additions and 28 deletions

View file

@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"os" "os"
@ -71,7 +72,7 @@ func main() {
// Initialize physical devices // Initialize physical devices
pDevices := initPhysicalDevices(config) pDevices := initPhysicalDevices(config)
rules, eventChannel, doneChannel, wg := loadRules(config, pDevices, getVirtualDevices(vBuffersByName)) rules, eventChannel, cancel, wg := loadRules(config, pDevices, getVirtualDevices(vBuffersByName))
// initialize the mode variable // initialize the mode variable
mode := config.GetModes()[0] mode := config.GetModes()[0]
@ -115,12 +116,12 @@ func main() {
case ChannelEventReload: case ChannelEventReload:
// stop existing channels // stop existing channels
fmt.Println("Reloading rules.") fmt.Println("Reloading rules.")
doneChannel <- true cancel()
fmt.Println("Waiting for existing listeners to exit. Provide input from each of your devices.") fmt.Println("Waiting for existing listeners to exit. Provide input from each of your devices.")
wg.Wait() wg.Wait()
fmt.Println("Listeners exited. Parsing config.") fmt.Println("Listeners exited. Parsing config.")
config := readConfig(configDir) // reload the config config := readConfig(configDir) // reload the config
rules, eventChannel, doneChannel, wg = loadRules(config, pDevices, getVirtualDevices(vBuffersByName)) rules, eventChannel, cancel, wg = loadRules(config, pDevices, getVirtualDevices(vBuffersByName))
fmt.Println("Config re-loaded. Only rule changes applied. Device and Mode changes require restart.") fmt.Println("Config re-loaded. Only rule changes applied. Device and Mode changes require restart.")
} }
} }
@ -129,11 +130,11 @@ func main() {
func loadRules( func loadRules(
config *config.ConfigParser, config *config.ConfigParser,
pDevices map[string]*evdev.InputDevice, pDevices map[string]*evdev.InputDevice,
vDevices map[string]*evdev.InputDevice) ([]mappingrules.MappingRule, <-chan ChannelEvent, chan bool, *sync.WaitGroup) { vDevices map[string]*evdev.InputDevice) ([]mappingrules.MappingRule, <-chan ChannelEvent, func(), *sync.WaitGroup) {
var wg sync.WaitGroup var wg sync.WaitGroup
eventChannel := make(chan ChannelEvent, 1000) eventChannel := make(chan ChannelEvent, 1000)
doneChannel := make(chan bool) ctx, cancel := context.WithCancel(context.Background())
// Initialize rules // Initialize rules
rules := config.BuildRules(pDevices, vDevices) rules := config.BuildRules(pDevices, vDevices)
@ -142,20 +143,20 @@ func loadRules(
// start listening for events on devices and timers // start listening for events on devices and timers
for _, device := range pDevices { for _, device := range pDevices {
wg.Add(1) wg.Add(1)
go eventWatcher(device, eventChannel, doneChannel, &wg) go eventWatcher(device, eventChannel, ctx, &wg)
} }
timerCount := 0 timerCount := 0
for _, rule := range rules { for _, rule := range rules {
if timedRule, ok := rule.(mappingrules.TimedEventEmitter); ok { if timedRule, ok := rule.(mappingrules.TimedEventEmitter); ok {
wg.Add(1) wg.Add(1)
go timerWatcher(timedRule, eventChannel, doneChannel, &wg) go timerWatcher(timedRule, eventChannel, ctx, &wg)
timerCount++ timerCount++
} }
} }
logger.Logf("registered %d timers", timerCount) logger.Logf("registered %d timers", timerCount)
go consoleWatcher(eventChannel, &wg) go consoleWatcher(eventChannel)
return rules, eventChannel, doneChannel, &wg return rules, eventChannel, cancel, &wg
} }

View file

@ -2,6 +2,7 @@ package main
import ( import (
"bufio" "bufio"
"context"
"os" "os"
"sync" "sync"
"time" "time"
@ -19,26 +20,25 @@ const (
func eventWatcher( func eventWatcher(
device *evdev.InputDevice, device *evdev.InputDevice,
channel chan<- ChannelEvent, channel chan<- ChannelEvent,
done chan bool, ctx context.Context,
wg *sync.WaitGroup) { wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
for { for {
select {
case cancel := <-done:
if cancel {
done <- true
return
}
default:
}
event, err := device.ReadOne() event, err := device.ReadOne()
if err != nil { if err != nil {
logger.LogError(err, "Error while reading event. Disconnecting device.") logger.LogError(err, "Error while reading event. Disconnecting device.")
return return
} }
select {
case <-ctx.Done():
return
default:
// Proceed
}
channel <- ChannelEvent{Device: device, Event: event, Type: ChannelEventInput} channel <- ChannelEvent{Device: device, Event: event, Type: ChannelEventInput}
if event.Type == evdev.EV_SYN { if event.Type == evdev.EV_SYN {
@ -50,19 +50,17 @@ func eventWatcher(
func timerWatcher( func timerWatcher(
rule mappingrules.TimedEventEmitter, rule mappingrules.TimedEventEmitter,
channel chan<- ChannelEvent, channel chan<- ChannelEvent,
done chan bool, ctx context.Context,
wg *sync.WaitGroup) { wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
for { for {
select { select {
case cancel := <-done: case <-ctx.Done():
if cancel { return
done <- true
return
}
default: default:
// Proceed
} }
event := rule.TimerEvent() event := rule.TimerEvent()
@ -77,9 +75,9 @@ func timerWatcher(
} }
} }
// consoleWatcher reads input from stdin, and on receiving anything // consoleWatcher reads input from stdin, and on receiving anything,
func consoleWatcher(channel chan<- ChannelEvent, wg *sync.WaitGroup) { // closes the current threading context
defer wg.Done() func consoleWatcher(channel chan<- ChannelEvent) {
stdin := bufio.NewReader(os.Stdin) stdin := bufio.NewReader(os.Stdin)
for { for {
_, err := stdin.ReadString('\n') _, err := stdin.ReadString('\n')