From 9f439a062902ca9c2fc5d80bbf373f0b3c5a6820 Mon Sep 17 00:00:00 2001
From: Naoki Kosaka <n.k@mail.yukimochi.net>
Date: Thu, 2 Jan 2020 01:26:43 +0900
Subject: [PATCH] Improve state test.

---
 State/state.go      |  71 +++++++--------
 State/state_test.go | 217 +++++++++++++++++++++++++++-----------------
 cli/cli_test.go     |   2 -
 cli/config_test.go  |   4 +
 cli/domain.go       |   2 +-
 main.go             |   2 +-
 6 files changed, 170 insertions(+), 128 deletions(-)

diff --git a/State/state.go b/State/state.go
index 942c8b9..e708bce 100644
--- a/State/state.go
+++ b/State/state.go
@@ -19,20 +19,10 @@ const (
 	CreateAsAnnounce
 )
 
-// NewState : Create new RelayState instance with redis client
-func NewState(redisClient *redis.Client, notify bool) RelayState {
-	var config RelayState
-	config.RedisClient = redisClient
-	config.notify = notify
-
-	config.Load()
-	return config
-}
-
 // RelayState : Store subscriptions and relay configrations
 type RelayState struct {
 	RedisClient *redis.Client
-	notify      bool
+	notifiable  bool
 
 	RelayConfig    relayConfig    `json:"relayConfig,omitempty"`
 	LimitedDomains []string       `json:"limitedDomains,omitempty"`
@@ -40,17 +30,31 @@ type RelayState struct {
 	Subscriptions  []Subscription `json:"subscriptions,omitempty"`
 }
 
-func (config *RelayState) ListenNotify() {
-	go func() {
-		_, err := config.RedisClient.Subscribe("relay_refresh").Receive()
-		if err != nil {
-			panic(err)
-		}
-		ch := config.RedisClient.Subscribe("relay_refresh").Channel()
+// NewState : Create new RelayState instance with redis client
+func NewState(redisClient *redis.Client, notifiable bool) RelayState {
+	var config RelayState
+	config.RedisClient = redisClient
+	config.notifiable = notifiable
 
+	config.Load()
+	return config
+}
+
+func (config *RelayState) ListenNotify(c chan<- bool) {
+	_, err := config.RedisClient.Subscribe("relay_refresh").Receive()
+	if err != nil {
+		panic(err)
+	}
+	ch := config.RedisClient.Subscribe("relay_refresh").Channel()
+
+	cNotify := c != nil
+	go func() {
 		for range ch {
 			fmt.Println("Config refreshed from state changed notify.")
 			config.Load()
+			if cNotify {
+				c <- true
+			}
 		}
 	}()
 }
@@ -102,11 +106,8 @@ func (config *RelayState) SetConfig(key Config, value bool) {
 	case CreateAsAnnounce:
 		config.RedisClient.HSet("relay:config", "create_as_announce", strValue).Result()
 	}
-	if config.notify {
-		config.RedisClient.Publish("relay_refresh", "Config refreshing request.")
-	} else {
-		config.Load()
-	}
+
+	config.refresh()
 }
 
 // AddSubscription : Add new instance for subscription list
@@ -117,11 +118,7 @@ func (config *RelayState) AddSubscription(domain Subscription) {
 		"actor_id":    domain.ActorID,
 	})
 
-	if config.notify {
-		config.RedisClient.Publish("relay_refresh", "Config refreshing request.")
-	} else {
-		config.Load()
-	}
+	config.refresh()
 }
 
 // DelSubscription : Delete instance from subscription list
@@ -129,11 +126,7 @@ func (config *RelayState) DelSubscription(domain string) {
 	config.RedisClient.Del("relay:subscription:" + domain).Result()
 	config.RedisClient.Del("relay:pending:" + domain).Result()
 
-	if config.notify {
-		config.RedisClient.Publish("relay_refresh", "Config refreshing request.")
-	} else {
-		config.Load()
-	}
+	config.refresh()
 }
 
 // SelectSubscription : Select instance from string
@@ -154,11 +147,7 @@ func (config *RelayState) SetBlockedDomain(domain string, value bool) {
 		config.RedisClient.HDel("relay:config:blockedDomain", domain).Result()
 	}
 
-	if config.notify {
-		config.RedisClient.Publish("relay_refresh", "Config refreshing request.")
-	} else {
-		config.Load()
-	}
+	config.refresh()
 }
 
 // SetLimitedDomain : Set/Unset instance for limited domain
@@ -169,7 +158,11 @@ func (config *RelayState) SetLimitedDomain(domain string, value bool) {
 		config.RedisClient.HDel("relay:config:limitedDomain", domain).Result()
 	}
 
-	if config.notify {
+	config.refresh()
+}
+
+func (config *RelayState) refresh() {
+	if config.notifiable {
 		config.RedisClient.Publish("relay_refresh", "Config refreshing request.")
 	} else {
 		config.Load()
diff --git a/State/state_test.go b/State/state_test.go
index e6f8114..1e0b87e 100644
--- a/State/state_test.go
+++ b/State/state_test.go
@@ -30,7 +30,7 @@ func TestMain(m *testing.M) {
 	redisClient.FlushAll().Result()
 }
 
-func TestInitialLoad(t *testing.T) {
+func TestLoadEmpty(t *testing.T) {
 	redisClient.FlushAll().Result()
 	testState := NewState(redisClient, false)
 
@@ -47,74 +47,58 @@ func TestInitialLoad(t *testing.T) {
 	redisClient.FlushAll().Result()
 }
 
-func TestAddLimited(t *testing.T) {
+func TestSetConfig(t *testing.T) {
+	ch := make(chan bool)
 	redisClient.FlushAll().Result()
-	testState := NewState(redisClient, false)
+	testState := NewState(redisClient, true)
+	testState.ListenNotify(ch)
 
-	testState.SetLimitedDomain("example.com", true)
-
-	valid := false
-	for _, domain := range testState.LimitedDomains {
-		if domain == "example.com" {
-			valid = true
-		}
+	testState.SetConfig(BlockService, true)
+	<-ch
+	if testState.RelayConfig.BlockService != true {
+		t.Fatalf("Failed enable config.")
 	}
-	if !valid {
-		t.Fatalf("Failed write config.")
+	testState.SetConfig(CreateAsAnnounce, true)
+	<-ch
+	if testState.RelayConfig.CreateAsAnnounce != true {
+		t.Fatalf("Failed enable config.")
+	}
+	testState.SetConfig(ManuallyAccept, true)
+	<-ch
+	if testState.RelayConfig.ManuallyAccept != true {
+		t.Fatalf("Failed enable config.")
 	}
 
-	testState.SetLimitedDomain("example.com", false)
-
-	for _, domain := range testState.LimitedDomains {
-		if domain == "example.com" {
-			valid = false
-		}
+	testState.SetConfig(BlockService, false)
+	<-ch
+	if testState.RelayConfig.BlockService != false {
+		t.Fatalf("Failed disable config.")
 	}
-	if !valid {
-		t.Fatalf("Failed write config.")
+	testState.SetConfig(CreateAsAnnounce, false)
+	<-ch
+	if testState.RelayConfig.CreateAsAnnounce != false {
+		t.Fatalf("Failed disable config.")
+	}
+	testState.SetConfig(ManuallyAccept, false)
+	<-ch
+	if testState.RelayConfig.ManuallyAccept != false {
+		t.Fatalf("Failed disable config.")
 	}
 
 	redisClient.FlushAll().Result()
 }
 
-func TestAddBlocked(t *testing.T) {
+func TestTreatSubscriptionNotify(t *testing.T) {
+	ch := make(chan bool)
 	redisClient.FlushAll().Result()
-	testState := NewState(redisClient, false)
-
-	testState.SetBlockedDomain("example.com", true)
-
-	valid := false
-	for _, domain := range testState.BlockedDomains {
-		if domain == "example.com" {
-			valid = true
-		}
-	}
-	if !valid {
-		t.Fatalf("Failed write config.")
-	}
-
-	testState.SetBlockedDomain("example.com", false)
-
-	for _, domain := range testState.BlockedDomains {
-		if domain == "example.com" {
-			valid = false
-		}
-	}
-	if !valid {
-		t.Fatalf("Failed write config.")
-	}
-
-	redisClient.FlushAll().Result()
-}
-
-func TestAddSubscription(t *testing.T) {
-	redisClient.FlushAll().Result()
-	testState := NewState(redisClient, false)
+	testState := NewState(redisClient, true)
+	testState.ListenNotify(ch)
 
 	testState.AddSubscription(Subscription{
 		Domain:   "example.com",
 		InboxURL: "https://example.com/inbox",
 	})
+	<-ch
 
 	valid := false
 	for _, domain := range testState.Subscriptions {
@@ -127,6 +111,7 @@ func TestAddSubscription(t *testing.T) {
 	}
 
 	testState.DelSubscription("example.com")
+	<-ch
 
 	for _, domain := range testState.Subscriptions {
 		if domain.Domain == "example.com" {
@@ -140,6 +125,101 @@ func TestAddSubscription(t *testing.T) {
 	redisClient.FlushAll().Result()
 }
 
+func TestSelectDomain(t *testing.T) {
+	ch := make(chan bool)
+	redisClient.FlushAll().Result()
+	testState := NewState(redisClient, true)
+	testState.ListenNotify(ch)
+
+	exampleSubscription := Subscription{
+		Domain:   "example.com",
+		InboxURL: "https://example.com/inbox",
+	}
+
+	testState.AddSubscription(exampleSubscription)
+	<-ch
+
+	subscription := testState.SelectSubscription("example.com")
+	if *subscription != exampleSubscription {
+		t.Fatalf("Failed select domain.")
+	}
+
+	subscription = testState.SelectSubscription("example.org")
+	if subscription != nil {
+		t.Fatalf("Failed select domain.")
+	}
+
+	redisClient.FlushAll().Result()
+}
+
+func TestBlockedDomain(t *testing.T) {
+	ch := make(chan bool)
+	redisClient.FlushAll().Result()
+	testState := NewState(redisClient, true)
+	testState.ListenNotify(ch)
+
+	testState.SetBlockedDomain("example.com", true)
+	<-ch
+
+	valid := false
+	for _, domain := range testState.BlockedDomains {
+		if domain == "example.com" {
+			valid = true
+		}
+	}
+	if !valid {
+		t.Fatalf("Failed write config.")
+	}
+
+	testState.SetBlockedDomain("example.com", false)
+	<-ch
+
+	for _, domain := range testState.BlockedDomains {
+		if domain == "example.com" {
+			valid = false
+		}
+	}
+	if !valid {
+		t.Fatalf("Failed write config.")
+	}
+
+	redisClient.FlushAll().Result()
+}
+
+func TestLimitedDomain(t *testing.T) {
+	ch := make(chan bool)
+	redisClient.FlushAll().Result()
+	testState := NewState(redisClient, true)
+	testState.ListenNotify(ch)
+
+	testState.SetLimitedDomain("example.com", true)
+	<-ch
+
+	valid := false
+	for _, domain := range testState.LimitedDomains {
+		if domain == "example.com" {
+			valid = true
+		}
+	}
+	if !valid {
+		t.Fatalf("Failed write config.")
+	}
+
+	testState.SetLimitedDomain("example.com", false)
+	<-ch
+
+	for _, domain := range testState.LimitedDomains {
+		if domain == "example.com" {
+			valid = false
+		}
+	}
+	if !valid {
+		t.Fatalf("Failed write config.")
+	}
+
+	redisClient.FlushAll().Result()
+}
+
 func TestLoadCompatiSubscription(t *testing.T) {
 	redisClient.FlushAll().Result()
 	testState := NewState(redisClient, false)
@@ -164,36 +244,3 @@ func TestLoadCompatiSubscription(t *testing.T) {
 
 	redisClient.FlushAll().Result()
 }
-
-func TestSetConfig(t *testing.T) {
-	redisClient.FlushAll().Result()
-	testState := NewState(redisClient, false)
-
-	testState.SetConfig(BlockService, true)
-	if testState.RelayConfig.BlockService != true {
-		t.Fatalf("Failed enable config.")
-	}
-	testState.SetConfig(CreateAsAnnounce, true)
-	if testState.RelayConfig.CreateAsAnnounce != true {
-		t.Fatalf("Failed enable config.")
-	}
-	testState.SetConfig(ManuallyAccept, true)
-	if testState.RelayConfig.ManuallyAccept != true {
-		t.Fatalf("Failed enable config.")
-	}
-
-	testState.SetConfig(BlockService, false)
-	if testState.RelayConfig.BlockService != false {
-		t.Fatalf("Failed disable config.")
-	}
-	testState.SetConfig(CreateAsAnnounce, false)
-	if testState.RelayConfig.CreateAsAnnounce != false {
-		t.Fatalf("Failed disable config.")
-	}
-	testState.SetConfig(ManuallyAccept, false)
-	if testState.RelayConfig.ManuallyAccept != false {
-		t.Fatalf("Failed disable config.")
-	}
-
-	redisClient.FlushAll().Result()
-}
diff --git a/cli/cli_test.go b/cli/cli_test.go
index 5993338..14cf9b7 100644
--- a/cli/cli_test.go
+++ b/cli/cli_test.go
@@ -5,14 +5,12 @@ import (
 	"testing"
 
 	"github.com/spf13/viper"
-	state "github.com/yukimochi/Activity-Relay/State"
 )
 
 func TestMain(m *testing.M) {
 	viper.Set("actor_pem", "../misc/testKey.pem")
 	viper.Set("relay_domain", "relay.yukimochi.example.org")
 	initConfig()
-	relayState = state.NewState(relayState.RedisClient, false)
 
 	relayState.RedisClient.FlushAll().Result()
 	code := m.Run()
diff --git a/cli/config_test.go b/cli/config_test.go
index 66210fa..13d587e 100644
--- a/cli/config_test.go
+++ b/cli/config_test.go
@@ -24,6 +24,7 @@ func TestServiceBlock(t *testing.T) {
 		t.Fatalf("Not Disabled Blocking feature for service-type actor")
 	}
 }
+
 func TestManuallyAccept(t *testing.T) {
 	app := buildNewCmd()
 
@@ -40,6 +41,7 @@ func TestManuallyAccept(t *testing.T) {
 		t.Fatalf("Not Disabled Manually accept follow-request feature")
 	}
 }
+
 func TestCreateAsAnnounce(t *testing.T) {
 	app := buildNewCmd()
 
@@ -56,6 +58,7 @@ func TestCreateAsAnnounce(t *testing.T) {
 		t.Fatalf("Enable announce activity instead of relay create activity")
 	}
 }
+
 func TestInvalidConfig(t *testing.T) {
 	app := buildNewCmd()
 	buffer := new(bytes.Buffer)
@@ -115,6 +118,7 @@ func TestExportConfig(t *testing.T) {
 		t.Fatalf("Invalid Responce.")
 	}
 }
+
 func TestImportConfig(t *testing.T) {
 	app := buildNewCmd()
 
diff --git a/cli/domain.go b/cli/domain.go
index 9f688f4..6e1d443 100644
--- a/cli/domain.go
+++ b/cli/domain.go
@@ -121,9 +121,9 @@ func unfollowDomains(cmd *cobra.Command, args []string) error {
 	for _, domain := range args {
 		if contains(subscriptions, domain) {
 			subscription := *relayState.SelectSubscription(domain)
-			cmd.Println("Unfollow [" + subscription.Domain + "]")
 			createUnfollowRequestResponse(subscription)
 			relayState.DelSubscription(subscription.Domain)
+			cmd.Println("Unfollow [" + subscription.Domain + "]")
 			break
 		} else {
 			cmd.Println("Invalid domain [" + domain + "] given")
diff --git a/main.go b/main.go
index 2c12dda..0578448 100644
--- a/main.go
+++ b/main.go
@@ -59,7 +59,7 @@ func initConfig() {
 	}
 	redisClient := redis.NewClient(redisOption)
 	relayState = state.NewState(redisClient, true)
-	relayState.ListenNotify()
+	relayState.ListenNotify(nil)
 	machineryConfig := &config.Config{
 		Broker:          viper.GetString("redis_url"),
 		DefaultQueue:    "relay",