mirror of
https://github.com/kvazar-network/keva-stratum.git
synced 2025-01-28 15:54:19 +00:00
272 lines
5.6 KiB
Go
272 lines
5.6 KiB
Go
package policy
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"os/exec"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"../../pool"
|
|
"../../storage"
|
|
"../../util"
|
|
)
|
|
|
|
// Stats holds the per-IP policy counters.
//
// LastBeat and BannedAt are read and written with 64-bit sync/atomic
// operations. On 32-bit platforms those require 64-bit alignment, and only
// the first word of an allocated struct is guaranteed to be aligned that
// way. So the two int64 fields come first. Previously they followed the
// mutex and five 32-bit fields, which put LastBeat at offset 28 on 386/ARM.
type Stats struct {
	LastBeat int64 // timestamp of the last heartbeat (atomic)
	BannedAt int64 // timestamp the current ban started, 0 if none (atomic)

	// Mutex guards ValidShares/InvalidShares during share accounting.
	sync.Mutex

	ValidShares   uint32
	InvalidShares uint32
	Malformed     uint32 // malformed-request count (atomic)
	ConnLimit     int32  // remaining connection allowance (atomic)
	FailsCount    uint32
	Banned        uint32 // 1 while banned, 0 otherwise (atomic CAS)
}
|
|
|
|
type PolicyServer struct {
|
|
sync.RWMutex
|
|
config *pool.Policy
|
|
stats StatsMap
|
|
banChannel chan string
|
|
startedAt int64
|
|
grace int64
|
|
timeout int64
|
|
blacklist []string
|
|
whitelist []string
|
|
storage *storage.RedisClient
|
|
}
|
|
|
|
func Start(cfg *pool.Config, storage *storage.RedisClient) *PolicyServer {
|
|
s := &PolicyServer{config: &cfg.Policy, startedAt: util.MakeTimestamp()}
|
|
grace, _ := time.ParseDuration(cfg.Policy.Limits.Grace)
|
|
s.grace = int64(grace / time.Millisecond)
|
|
s.banChannel = make(chan string, 64)
|
|
s.stats = NewStatsMap()
|
|
s.storage = storage
|
|
s.refreshState()
|
|
|
|
timeout, _ := time.ParseDuration(s.config.ResetInterval)
|
|
s.timeout = int64(timeout / time.Millisecond)
|
|
|
|
resetIntv, _ := time.ParseDuration(s.config.ResetInterval)
|
|
resetTimer := time.NewTimer(resetIntv)
|
|
log.Printf("Set policy stats reset every %v", resetIntv)
|
|
|
|
refreshIntv, _ := time.ParseDuration(s.config.RefreshInterval)
|
|
refreshTimer := time.NewTimer(refreshIntv)
|
|
log.Printf("Set policy state refresh every %v", refreshIntv)
|
|
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-resetTimer.C:
|
|
s.resetStats()
|
|
resetTimer.Reset(resetIntv)
|
|
case <-refreshTimer.C:
|
|
s.refreshState()
|
|
refreshTimer.Reset(refreshIntv)
|
|
}
|
|
}
|
|
}()
|
|
|
|
for i := 0; i < s.config.Workers; i++ {
|
|
s.startPolicyWorker()
|
|
}
|
|
log.Printf("Running with %v policy workers", s.config.Workers)
|
|
return s
|
|
}
|
|
|
|
func (s *PolicyServer) startPolicyWorker() {
|
|
go func() {
|
|
for {
|
|
select {
|
|
case ip := <-s.banChannel:
|
|
s.doBan(ip)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *PolicyServer) resetStats() {
|
|
now := util.MakeTimestamp()
|
|
banningTimeout := s.config.Banning.Timeout * 1000
|
|
total := 0
|
|
|
|
for m := range s.stats.IterBuffered() {
|
|
lastBeat := atomic.LoadInt64(&m.Val.LastBeat)
|
|
bannedAt := atomic.LoadInt64(&m.Val.BannedAt)
|
|
|
|
if now-bannedAt >= banningTimeout {
|
|
atomic.StoreInt64(&m.Val.BannedAt, 0)
|
|
if atomic.CompareAndSwapUint32(&m.Val.Banned, 1, 0) {
|
|
log.Printf("Ban dropped for %v", m.Key)
|
|
}
|
|
}
|
|
if now-lastBeat >= s.timeout {
|
|
s.stats.Remove(m.Key)
|
|
total++
|
|
}
|
|
}
|
|
log.Printf("Flushed stats for %v IP addresses", total)
|
|
}
|
|
|
|
func (s *PolicyServer) refreshState() {
|
|
s.Lock()
|
|
defer s.Unlock()
|
|
|
|
s.blacklist = s.storage.GetBlacklist()
|
|
s.whitelist = s.storage.GetWhitelist()
|
|
log.Println("Policy state refresh complete")
|
|
}
|
|
|
|
func (s *PolicyServer) NewStats() *Stats {
|
|
x := &Stats{
|
|
ConnLimit: s.config.Limits.Limit,
|
|
Malformed: s.config.Banning.MalformedLimit,
|
|
}
|
|
x.heartbeat()
|
|
return x
|
|
}
|
|
|
|
func (s *PolicyServer) Get(ip string) *Stats {
|
|
if x, ok := s.stats.Get(ip); ok {
|
|
x.heartbeat()
|
|
return x
|
|
}
|
|
x := s.NewStats()
|
|
s.stats.Set(ip, x)
|
|
return x
|
|
}
|
|
|
|
func (s *PolicyServer) ApplyLimitPolicy(ip string) bool {
|
|
if !s.config.Limits.Enabled {
|
|
return true
|
|
}
|
|
now := util.MakeTimestamp()
|
|
if now-s.startedAt > s.grace {
|
|
return s.Get(ip).decrLimit() > 0
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (s *PolicyServer) ApplyLoginPolicy(addy, ip string) bool {
|
|
if s.InBlackList(addy) {
|
|
x := s.Get(ip)
|
|
s.forceBan(x, ip)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (s *PolicyServer) ApplyMalformedPolicy(ip string) {
|
|
x := s.Get(ip)
|
|
n := x.incrMalformed()
|
|
if n >= s.config.Banning.MalformedLimit {
|
|
s.forceBan(x, ip)
|
|
}
|
|
}
|
|
|
|
func (s *PolicyServer) ApplySharePolicy(ip string, validShare bool) bool {
|
|
x := s.Get(ip)
|
|
if validShare && s.config.Limits.Enabled {
|
|
s.Get(ip).incrLimit(s.config.Limits.LimitJump)
|
|
}
|
|
x.Lock()
|
|
|
|
if validShare {
|
|
x.ValidShares++
|
|
if s.config.Limits.Enabled {
|
|
x.incrLimit(s.config.Limits.LimitJump)
|
|
}
|
|
} else {
|
|
x.InvalidShares++
|
|
}
|
|
|
|
totalShares := x.ValidShares + x.InvalidShares
|
|
if totalShares < s.config.Banning.CheckThreshold {
|
|
x.Unlock()
|
|
return true
|
|
}
|
|
validShares := float32(x.ValidShares)
|
|
invalidShares := float32(x.InvalidShares)
|
|
x.resetShares()
|
|
x.Unlock()
|
|
|
|
if invalidShares == 0 {
|
|
return true
|
|
}
|
|
|
|
// Can be +Inf or value, previous check prevents NaN
|
|
ratio := invalidShares / validShares
|
|
|
|
if ratio >= s.config.Banning.InvalidPercent/100.0 {
|
|
s.forceBan(x, ip)
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (x *Stats) resetShares() {
|
|
x.ValidShares = 0
|
|
x.InvalidShares = 0
|
|
}
|
|
|
|
func (s *PolicyServer) forceBan(x *Stats, ip string) {
|
|
if !s.config.Banning.Enabled || s.InWhiteList(ip) {
|
|
return
|
|
}
|
|
|
|
if atomic.CompareAndSwapUint32(&x.Banned, 0, 1) {
|
|
if len(s.config.Banning.IPSet) > 0 {
|
|
s.banChannel <- ip
|
|
}
|
|
}
|
|
}
|
|
|
|
func (x *Stats) incrLimit(n int32) {
|
|
atomic.AddInt32(&x.ConnLimit, n)
|
|
}
|
|
|
|
func (x *Stats) incrMalformed() uint32 {
|
|
return atomic.AddUint32(&x.Malformed, 1)
|
|
}
|
|
|
|
func (x *Stats) decrLimit() int32 {
|
|
return atomic.AddInt32(&x.ConnLimit, -1)
|
|
}
|
|
|
|
func (s *PolicyServer) InBlackList(addy string) bool {
|
|
s.RLock()
|
|
defer s.RUnlock()
|
|
return util.StringInSlice(addy, s.blacklist)
|
|
}
|
|
|
|
func (s *PolicyServer) InWhiteList(ip string) bool {
|
|
s.RLock()
|
|
defer s.RUnlock()
|
|
return util.StringInSlice(ip, s.whitelist)
|
|
}
|
|
|
|
func (s *PolicyServer) doBan(ip string) {
|
|
set, timeout := s.config.Banning.IPSet, s.config.Banning.Timeout
|
|
cmd := fmt.Sprintf("sudo ipset add %s %s timeout %v -!", set, ip, timeout)
|
|
args := strings.Fields(cmd)
|
|
head := args[0]
|
|
args = args[1:]
|
|
|
|
log.Printf("Banned %v with timeout %v", ip, timeout)
|
|
|
|
_, err := exec.Command(head, args...).Output()
|
|
if err != nil {
|
|
log.Printf("CMD Error: %s", err)
|
|
}
|
|
}
|
|
|
|
func (x *Stats) heartbeat() {
|
|
now := util.MakeTimestamp()
|
|
atomic.StoreInt64(&x.LastBeat, now)
|
|
}
|