diff --git a/cmd/clientsim/main.go b/cmd/clientsim/main.go index aefe282..c546b9b 100644 --- a/cmd/clientsim/main.go +++ b/cmd/clientsim/main.go @@ -20,16 +20,16 @@ limitations under the License. package main import ( + "context" "fmt" "log" "os" "sync" "time" + "git.cacert.org/cacert-gosigner/pkg/messages" "github.com/justincpresley/go-cobs" "github.com/shamaton/msgpackgen/msgpack" - - "git.cacert.org/cacert-gosigner/pkg/messages" ) const cobsDelimiter = 0x00 @@ -37,108 +37,157 @@ const cobsDelimiter = 0x00 var cobsConfig = cobs.Config{SpecialByte: cobsDelimiter, Delimiter: true, EndingSave: true} func main() { + errorLog := log.New(os.Stderr, "", log.LstdFlags) + + sim := &clientSimulator{ + errorLog: errorLog, + } + + err := sim.Run() + if err != nil { + errorLog.Printf("simulator returned an error: %v", err) + } +} + +type clientSimulator struct { + errorLog *log.Logger + commands chan messages.Command + responses chan []byte +} + +func (c *clientSimulator) writeTestCommands(ctx context.Context) error { + messages.RegisterGeneratedResolver() + + const healthInterval = 10 * time.Second + + timer := time.NewTimer(healthInterval) + + for { + select { + case <-ctx.Done(): + _ = timer.Stop() + + return nil + case <-timer.C: + c.commands <- messages.Command{ + Code: messages.CmdHealth, + TimeStamp: time.Now().UTC(), + } + + timer.Reset(healthInterval) + } + } +} + +func (c *clientSimulator) handleInput(ctx context.Context) error { const ( bufferSize = 1024 * 1024 readInterval = 50 * time.Millisecond ) - errors := make(chan error) + buf := make([]byte, bufferSize) - errorLog := log.New(os.Stderr, "", log.LstdFlags) + for { + select { + case <-ctx.Done(): + return nil + default: + count, err := os.Stdin.Read(buf) + if err != nil { + return err + } - wg := sync.WaitGroup{} - wg.Add(1) + if count == 0 { + time.Sleep(readInterval) - done := make(chan struct{}) - frame := make(chan []byte) + continue + } - go func(done chan struct{}) { - buf := make([]byte, bufferSize) + data := buf[:count] - for { - select { - case <-done: - wg.Done() + err = cobs.Verify(data, cobsConfig) + if err != nil { + return err + } - return + c.responses <- cobs.Decode(data, cobsConfig) + } + } +} - default: - count, err := os.Stdin.Read(buf) - if err != nil { - errors <- err +func (c *clientSimulator) handleCommands(ctx context.Context) error { + for { + select { + case command := <-c.commands: + commandBytes, err := msgpack.Marshal(command) + if err != nil { + return fmt.Errorf("could not marshal command bytes: %w", err) + } - wg.Done() + _, err = os.Stdout.Write(cobs.Encode(commandBytes, cobsConfig)) + if err != nil { + return fmt.Errorf("write failed: %w", err) + } - return - } + responseBytes := <-c.responses - if count == 0 { - time.Sleep(readInterval) + var response messages.Response - continue - } + err = msgpack.Unmarshal(responseBytes, &response) + if err != nil { + return fmt.Errorf("could not unmarshal msgpack data: %w", err) + } - data := buf[:count] + c.errorLog.Printf("received response: %+v", response) + case <-ctx.Done(): + return nil + } + } +} - err = cobs.Verify(data, cobsConfig) - if err != nil { - errors <- err +func (c *clientSimulator) Run() error { + ctx, cancel := context.WithCancel(context.Background()) - wg.Done() + c.commands = make(chan messages.Command) + c.responses = make(chan []byte) - return - } + wg := sync.WaitGroup{} + wg.Add(2) - frame <- cobs.Decode(data, cobsConfig) - } - } - }(done) + var inputError, commandError error - err := writeTestCommands(frame, errorLog) - if err != nil { - errorLog.Printf("could not write test commands") - } + go func(inputErr error) { + inputError = c.handleInput(ctx) - err = <-errors - if err != nil { - errorLog.Printf("error: %v", err) - } + cancel() - wg.Wait() -} + wg.Done() + }(inputError) -func writeTestCommands(responses chan []byte, errorLog *log.Logger) error { - messages.RegisterGeneratedResolver() + go func(commandErr error) { + commandErr = c.handleCommands(ctx) - commands := []messages.Command{ - { - Code: messages.CmdHealth, - TimeStamp: time.Now().UTC(), - }, - } + cancel() - for _, command := range commands { - commandBytes, err := msgpack.Marshal(command) - if err != nil { - return fmt.Errorf("could not marshal command bytes: %w", err) - } + wg.Done() + }(commandError) - _, err = os.Stdout.Write(cobs.Encode(commandBytes, cobsConfig)) - if err != nil { - return fmt.Errorf("write failed: %w", err) - } + var result error - responseBytes := <-responses + if err := c.writeTestCommands(ctx); err != nil { + c.errorLog.Printf("test commands failed: %v", err) + } - var response messages.Response + cancel() + wg.Wait() - err = msgpack.Unmarshal(responseBytes, &response) - if err != nil { - return fmt.Errorf("could not unmarshal msgpack data: %w", err) - } + if inputError != nil { + c.errorLog.Printf("reading input failed: %v", inputError) + } - errorLog.Printf("received response: %+v", response) + if commandError != nil { + c.errorLog.Printf("sending commands failed: %v", commandError) } - return nil + return result }