Add context_test.go

master
Ivan Klymenchenko 3 years ago
parent 50c2641b6e
commit 0880b52738

@ -6,8 +6,9 @@ type Context struct {
InsideTmuxSession bool InsideTmuxSession bool
} }
func CreateContext() *Context { func CreateContext() Context {
_, tmux := os.LookupEnv("TMUX") _, tmux := os.LookupEnv("TMUX")
os.Environ()
insideTmuxSession := os.Getenv("TERM") == "screen" || tmux insideTmuxSession := os.Getenv("TERM") == "screen" || tmux
return &Context{insideTmuxSession} return Context{insideTmuxSession}
} }

@ -0,0 +1,51 @@
package main
import (
"os"
"reflect"
"testing"
)
var environmentTestTable = []struct {
environment map[string]string
context Context
}{
{
map[string]string{},
Context{InsideTmuxSession: false},
},
{
map[string]string{
"TMUX": "",
},
Context{InsideTmuxSession: true},
},
{
map[string]string{
"TERM": "screen",
},
Context{InsideTmuxSession: true},
},
{
map[string]string{
"TERM": "xterm",
"TMUX": "",
},
Context{InsideTmuxSession: true},
},
}
func TestCreateContext(t *testing.T) {
os.Clearenv()
for _, v := range environmentTestTable {
for key, value := range v.environment {
os.Setenv(key, value)
}
context := CreateContext()
if !reflect.DeepEqual(v.context, context) {
t.Errorf("expected context %v, got %v", v.context, context)
}
}
}

@ -90,10 +90,10 @@ func main() {
} else { } else {
fmt.Println("Starting new windows...") fmt.Println("Starting new windows...")
} }
err = smug.Start(*config, options, *context) err = smug.Start(*config, options, context)
if err != nil { if err != nil {
fmt.Println("Oops, an error occurred! Rolling back...") fmt.Println("Oops, an error occurred! Rolling back...")
smug.Stop(*config, options, *context) smug.Stop(*config, options, context)
} }
case CommandStop: case CommandStop:
if len(options.Windows) == 0 { if len(options.Windows) == 0 {
@ -101,7 +101,7 @@ func main() {
} else { } else {
fmt.Println("Killing windows...") fmt.Println("Killing windows...")
} }
err = smug.Stop(*config, options, *context) err = smug.Stop(*config, options, context)
} }
if err != nil { if err != nil {

@ -1,6 +1,7 @@
package main package main
import ( import (
"errors"
"reflect" "reflect"
"testing" "testing"
@ -91,6 +92,12 @@ var usageTestTable = []struct {
ErrHelp, ErrHelp,
1, 1,
}, },
{
[]string{"start", "--test"},
Options{},
errors.New("unknown flag: --test"),
0,
},
} }
func TestParseOptions(t *testing.T) { func TestParseOptions(t *testing.T) {
@ -116,7 +123,7 @@ func TestParseOptions(t *testing.T) {
t.Errorf("expected to get %d help calls, got %d", v.helpCalls, helpCalls) t.Errorf("expected to get %d help calls, got %d", v.helpCalls, helpCalls)
} }
if err != v.err { if v.err != nil && err.Error() != v.err.Error() {
t.Errorf("expected to get error %v, got %v", v.err, err) t.Errorf("expected to get error %v, got %v", v.err, err)
} }

Loading…
Cancel
Save