apcssh: add shell cmd timeout

It was possible for scanner.Scan() to block indefinitely if the UPS never returned the expected prompt regex pattern. This could occur with a UPS using a prompt format I'm not aware of, or if the UPS responds in a non-standard way.

This change ensures that Scan() is aborted after a fixed amount of blocking time and the shell cmd function accordingly returns an error.

Some error messages, comments, and var names are also updated for clarity.
This commit is contained in:
Greg T. Wallace 2024-06-19 19:56:17 -04:00
parent 841a459dca
commit 703c26bd27
2 changed files with 92 additions and 44 deletions

View file

@ -2,14 +2,25 @@ package apcssh
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"strings" "strings"
"time"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
// upsCmdResponse is a structure that holds all of a shell commands results // Abort shell connection if UPS doesn't send a recognizable response within
type upsCmdResponse struct { // the specified timeouts; Cmd timeout is very long as it is unlikely to be
// needed but still exists to avoid an indefinite hang in the unlikely event
// something does go wrong at that part of the app
const (
shellTimeoutLogin = 20 * time.Second
shellTimeoutCmd = 5 * time.Minute
)
// upsCmdResult is a structure that holds all of a shell commands results
type upsCmdResult struct {
command string command string
code string code string
codeText string codeText string
@ -17,75 +28,112 @@ type upsCmdResponse struct {
} }
// cmd creates an interactive shell and executes the specified command // cmd creates an interactive shell and executes the specified command
func (cli *Client) cmd(command string) (*upsCmdResponse, error) { func (cli *Client) cmd(command string) (*upsCmdResult, error) {
// connect // connect
sshClient, err := ssh.Dial("tcp", cli.hostname, cli.sshCfg) sshClient, err := ssh.Dial("tcp", cli.hostname, cli.sshCfg)
if err != nil { if err != nil {
return nil, fmt.Errorf("apcssh: failed to dial session (%w)", err) return nil, fmt.Errorf("failed to dial client (%w)", err)
} }
defer sshClient.Close() defer sshClient.Close()
session, err := sshClient.NewSession() session, err := sshClient.NewSession()
if err != nil { if err != nil {
return nil, fmt.Errorf("apcssh: failed to create session (%w)", err) return nil, fmt.Errorf("failed to create session (%w)", err)
} }
defer session.Close() defer session.Close()
// pipes to send shell command to; and to receive repsonse // pipes to send shell command to; and to receive repsonse
sshInput, err := session.StdinPipe() sshInput, err := session.StdinPipe()
if err != nil { if err != nil {
return nil, fmt.Errorf("apcssh: failed to make shell input pipe (%w)", err) return nil, fmt.Errorf("failed to make shell input pipe (%w)", err)
} }
sshOutput, err := session.StdoutPipe() sshOutput, err := session.StdoutPipe()
if err != nil { if err != nil {
return nil, fmt.Errorf("apcssh: failed to make shell output pipe (%w)", err) return nil, fmt.Errorf("failed to make shell output pipe (%w)", err)
} }
// make scanner to read shell continuously // make scanner to read shell output continuously
scanner := bufio.NewScanner(sshOutput) scanner := bufio.NewScanner(sshOutput)
scanner.Split(scanAPCShell) scanner.Split(scanAPCShell)
// start interactive shell // start interactive shell
if err := session.Shell(); err != nil { if err := session.Shell(); err != nil {
return nil, fmt.Errorf("apcssh: failed to start shell (%w)", err) return nil, fmt.Errorf("failed to start shell (%w)", err)
} }
// discard the initial shell response (login message(s) / initial shell prompt)
for { // use a timer to close the session early in case Scan() hangs (which can
if token := scanner.Scan(); token { // happen if the UPS provides output this app does not understand)
_ = scanner.Bytes() cancelAbort := make(chan struct{})
break defer close(cancelAbort)
go func() {
select {
case <-time.After(shellTimeoutLogin):
_ = session.Close()
case <-cancelAbort:
// aborted cancel (i.e., succesful Scan())
} }
}()
// check shell response after connect
scannedOk := scanner.Scan()
// if failed to scan (e.g., timer closed the session after timeout)
if !scannedOk {
return nil, errors.New("shell did not return parsable login response")
} }
// success; cancel abort timer
cancelAbort <- struct{}{}
// discard the initial shell response (login message(s) / initial shell prompt)
_ = scanner.Bytes()
// send command // send command
_, err = fmt.Fprint(sshInput, command+"\n") _, err = fmt.Fprint(sshInput, command+"\n")
if err != nil { if err != nil {
return nil, fmt.Errorf("apcssh: failed to send shell command (%w)", err) return nil, fmt.Errorf("failed to send shell command (%w)", err)
} }
res := &upsCmdResponse{} // use a timer to close the session early in case Scan() hangs (which can
for { // happen if the UPS provides output this app does not understand);
if tkn := scanner.Scan(); tkn { // since initial login message Scan() was okay, it is relatively unlikely this
result := string(scanner.Bytes()) // will hang
go func() {
select {
case <-time.After(shellTimeoutCmd):
_ = session.Close()
cmdIndx := strings.Index(result, "\n") case <-cancelAbort:
res.command = result[:cmdIndx-1] // aborted cancel (i.e., succesful Scan())
result = result[cmdIndx+1:]
codeIndx := strings.Index(result, ": ")
res.code = result[:codeIndx]
result = result[codeIndx+2:]
codeTxtIndx := strings.Index(result, "\n")
res.codeText = result[:codeTxtIndx-1]
// avoid out of bounds if no result text
if codeTxtIndx+1 <= len(result)-2 {
res.resultText = result[codeTxtIndx+1 : len(result)-2]
}
break
} }
}()
// check shell response to command
scannedOk = scanner.Scan()
// if failed to scan (e.g., timer closed the session after timeout)
if !scannedOk {
return nil, fmt.Errorf("shell did not return parsable response to cmd '%s'", command)
}
// success; cancel abort timer
cancelAbort <- struct{}{}
// parse the UPS response into result struct and return
upsRawResponse := string(scanner.Bytes())
result := &upsCmdResult{}
cmdIndx := strings.Index(upsRawResponse, "\n")
result.command = upsRawResponse[:cmdIndx-1]
upsRawResponse = upsRawResponse[cmdIndx+1:]
codeIndx := strings.Index(upsRawResponse, ": ")
result.code = upsRawResponse[:codeIndx]
upsRawResponse = upsRawResponse[codeIndx+2:]
codeTxtIndx := strings.Index(upsRawResponse, "\n")
result.codeText = upsRawResponse[:codeTxtIndx-1]
// avoid out of bounds if no result text
if codeTxtIndx+1 <= len(upsRawResponse)-2 {
result.resultText = upsRawResponse[codeTxtIndx+1 : len(upsRawResponse)-2]
} }
return res, nil return result, nil
} }

View file

@ -1,31 +1,31 @@
package apcssh package apcssh
import ( import (
"io"
"regexp" "regexp"
) )
// scanAPCShell is a SplitFunc to capture shell output after each interactive // scanAPCShell is a SplitFunc to capture shell output after each interactive
// shell command is run // shell command is run
func scanAPCShell(data []byte, atEOF bool) (advance int, token []byte, err error) { func scanAPCShell(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 { // EOF is not an expected response and should error (e.g., when the output pipe
// gets closed by timeout)
if atEOF {
return len(data), dropCR(data), io.ErrUnexpectedEOF
} else if len(data) == 0 {
// no data to process, request more data
return 0, nil, nil return 0, nil, nil
} }
// regex for shell prompt (e.g., `apc@apc>`, `apc>`, `some@dev>`, `other123>`, etc.) // regex for shell prompt (e.g., `apc@apc>`, `apc>`, `some@dev>`, `other123>`, etc.)
re := regexp.MustCompile(`(\r\n|\r|\n)([A-Za-z0-9.]+@?)?[A-Za-z0-9.]+>`) re := regexp.MustCompile(`(\r\n|\r|\n)([A-Za-z0-9.]+@?)?[A-Za-z0-9.]+>`)
// find match for prompt // find match for prompt
if index := re.FindStringIndex(string(data)); index != nil { if index := re.FindStringIndex(string(data)); index != nil {
// advance starts after the prompt; token is everything before the prompt // advance starts after the prompt; token is everything before the prompt
return index[1], dropCR(data[0:index[0]]), nil return index[1], dropCR(data[0:index[0]]), nil
} }
// If we're at EOF, we have a final, non-terminated line. Return it. // no match, request more data
if atEOF {
return len(data), dropCR(data), nil
}
// Request more data.
return 0, nil, nil return 0, nil, nil
} }