Sammy Libre
8 years ago
2 changed files with 0 additions and 407 deletions
@ -1,271 +0,0 @@ |
|||||||
package policy |
|
||||||
|
|
||||||
import ( |
|
||||||
"fmt" |
|
||||||
"log" |
|
||||||
"os/exec" |
|
||||||
"strings" |
|
||||||
"sync" |
|
||||||
"sync/atomic" |
|
||||||
"time" |
|
||||||
|
|
||||||
"../../pool" |
|
||||||
"../../storage" |
|
||||||
"../../util" |
|
||||||
) |
|
||||||
|
|
||||||
type Stats struct { |
|
||||||
sync.Mutex |
|
||||||
ValidShares uint32 |
|
||||||
InvalidShares uint32 |
|
||||||
Malformed uint32 |
|
||||||
ConnLimit int32 |
|
||||||
FailsCount uint32 |
|
||||||
LastBeat int64 |
|
||||||
Banned uint32 |
|
||||||
BannedAt int64 |
|
||||||
} |
|
||||||
|
|
||||||
type PolicyServer struct { |
|
||||||
sync.RWMutex |
|
||||||
config *pool.Policy |
|
||||||
stats StatsMap |
|
||||||
banChannel chan string |
|
||||||
startedAt int64 |
|
||||||
grace int64 |
|
||||||
timeout int64 |
|
||||||
blacklist []string |
|
||||||
whitelist []string |
|
||||||
storage *storage.RedisClient |
|
||||||
} |
|
||||||
|
|
||||||
func Start(cfg *pool.Config, storage *storage.RedisClient) *PolicyServer { |
|
||||||
s := &PolicyServer{config: &cfg.Policy, startedAt: util.MakeTimestamp()} |
|
||||||
grace, _ := time.ParseDuration(cfg.Policy.Limits.Grace) |
|
||||||
s.grace = int64(grace / time.Millisecond) |
|
||||||
s.banChannel = make(chan string, 64) |
|
||||||
s.stats = NewStatsMap() |
|
||||||
s.storage = storage |
|
||||||
s.refreshState() |
|
||||||
|
|
||||||
timeout, _ := time.ParseDuration(s.config.ResetInterval) |
|
||||||
s.timeout = int64(timeout / time.Millisecond) |
|
||||||
|
|
||||||
resetIntv, _ := time.ParseDuration(s.config.ResetInterval) |
|
||||||
resetTimer := time.NewTimer(resetIntv) |
|
||||||
log.Printf("Set policy stats reset every %v", resetIntv) |
|
||||||
|
|
||||||
refreshIntv, _ := time.ParseDuration(s.config.RefreshInterval) |
|
||||||
refreshTimer := time.NewTimer(refreshIntv) |
|
||||||
log.Printf("Set policy state refresh every %v", refreshIntv) |
|
||||||
|
|
||||||
go func() { |
|
||||||
for { |
|
||||||
select { |
|
||||||
case <-resetTimer.C: |
|
||||||
s.resetStats() |
|
||||||
resetTimer.Reset(resetIntv) |
|
||||||
case <-refreshTimer.C: |
|
||||||
s.refreshState() |
|
||||||
refreshTimer.Reset(refreshIntv) |
|
||||||
} |
|
||||||
} |
|
||||||
}() |
|
||||||
|
|
||||||
for i := 0; i < s.config.Workers; i++ { |
|
||||||
s.startPolicyWorker() |
|
||||||
} |
|
||||||
log.Printf("Running with %v policy workers", s.config.Workers) |
|
||||||
return s |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) startPolicyWorker() { |
|
||||||
go func() { |
|
||||||
for { |
|
||||||
select { |
|
||||||
case ip := <-s.banChannel: |
|
||||||
s.doBan(ip) |
|
||||||
} |
|
||||||
} |
|
||||||
}() |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) resetStats() { |
|
||||||
now := util.MakeTimestamp() |
|
||||||
banningTimeout := s.config.Banning.Timeout * 1000 |
|
||||||
total := 0 |
|
||||||
|
|
||||||
for m := range s.stats.IterBuffered() { |
|
||||||
lastBeat := atomic.LoadInt64(&m.Val.LastBeat) |
|
||||||
bannedAt := atomic.LoadInt64(&m.Val.BannedAt) |
|
||||||
|
|
||||||
if now-bannedAt >= banningTimeout { |
|
||||||
atomic.StoreInt64(&m.Val.BannedAt, 0) |
|
||||||
if atomic.CompareAndSwapUint32(&m.Val.Banned, 1, 0) { |
|
||||||
log.Printf("Ban dropped for %v", m.Key) |
|
||||||
} |
|
||||||
} |
|
||||||
if now-lastBeat >= s.timeout { |
|
||||||
s.stats.Remove(m.Key) |
|
||||||
total++ |
|
||||||
} |
|
||||||
} |
|
||||||
log.Printf("Flushed stats for %v IP addresses", total) |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) refreshState() { |
|
||||||
s.Lock() |
|
||||||
defer s.Unlock() |
|
||||||
|
|
||||||
s.blacklist = s.storage.GetBlacklist() |
|
||||||
s.whitelist = s.storage.GetWhitelist() |
|
||||||
log.Println("Policy state refresh complete") |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) NewStats() *Stats { |
|
||||||
x := &Stats{ |
|
||||||
ConnLimit: s.config.Limits.Limit, |
|
||||||
Malformed: s.config.Banning.MalformedLimit, |
|
||||||
} |
|
||||||
x.heartbeat() |
|
||||||
return x |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) Get(ip string) *Stats { |
|
||||||
if x, ok := s.stats.Get(ip); ok { |
|
||||||
x.heartbeat() |
|
||||||
return x |
|
||||||
} |
|
||||||
x := s.NewStats() |
|
||||||
s.stats.Set(ip, x) |
|
||||||
return x |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) ApplyLimitPolicy(ip string) bool { |
|
||||||
if !s.config.Limits.Enabled { |
|
||||||
return true |
|
||||||
} |
|
||||||
now := util.MakeTimestamp() |
|
||||||
if now-s.startedAt > s.grace { |
|
||||||
return s.Get(ip).decrLimit() > 0 |
|
||||||
} |
|
||||||
return true |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) ApplyLoginPolicy(addy, ip string) bool { |
|
||||||
if s.InBlackList(addy) { |
|
||||||
x := s.Get(ip) |
|
||||||
s.forceBan(x, ip) |
|
||||||
return false |
|
||||||
} |
|
||||||
return true |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) ApplyMalformedPolicy(ip string) { |
|
||||||
x := s.Get(ip) |
|
||||||
n := x.incrMalformed() |
|
||||||
if n >= s.config.Banning.MalformedLimit { |
|
||||||
s.forceBan(x, ip) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) ApplySharePolicy(ip string, validShare bool) bool { |
|
||||||
x := s.Get(ip) |
|
||||||
if validShare && s.config.Limits.Enabled { |
|
||||||
s.Get(ip).incrLimit(s.config.Limits.LimitJump) |
|
||||||
} |
|
||||||
x.Lock() |
|
||||||
|
|
||||||
if validShare { |
|
||||||
x.ValidShares++ |
|
||||||
if s.config.Limits.Enabled { |
|
||||||
x.incrLimit(s.config.Limits.LimitJump) |
|
||||||
} |
|
||||||
} else { |
|
||||||
x.InvalidShares++ |
|
||||||
} |
|
||||||
|
|
||||||
totalShares := x.ValidShares + x.InvalidShares |
|
||||||
if totalShares < s.config.Banning.CheckThreshold { |
|
||||||
x.Unlock() |
|
||||||
return true |
|
||||||
} |
|
||||||
validShares := float32(x.ValidShares) |
|
||||||
invalidShares := float32(x.InvalidShares) |
|
||||||
x.resetShares() |
|
||||||
x.Unlock() |
|
||||||
|
|
||||||
if invalidShares == 0 { |
|
||||||
return true |
|
||||||
} |
|
||||||
|
|
||||||
// Can be +Inf or value, previous check prevents NaN
|
|
||||||
ratio := invalidShares / validShares |
|
||||||
|
|
||||||
if ratio >= s.config.Banning.InvalidPercent/100.0 { |
|
||||||
s.forceBan(x, ip) |
|
||||||
return false |
|
||||||
} |
|
||||||
return true |
|
||||||
} |
|
||||||
|
|
||||||
func (x *Stats) resetShares() { |
|
||||||
x.ValidShares = 0 |
|
||||||
x.InvalidShares = 0 |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) forceBan(x *Stats, ip string) { |
|
||||||
if !s.config.Banning.Enabled || s.InWhiteList(ip) { |
|
||||||
return |
|
||||||
} |
|
||||||
|
|
||||||
if atomic.CompareAndSwapUint32(&x.Banned, 0, 1) { |
|
||||||
if len(s.config.Banning.IPSet) > 0 { |
|
||||||
s.banChannel <- ip |
|
||||||
} |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (x *Stats) incrLimit(n int32) { |
|
||||||
atomic.AddInt32(&x.ConnLimit, n) |
|
||||||
} |
|
||||||
|
|
||||||
func (x *Stats) incrMalformed() uint32 { |
|
||||||
return atomic.AddUint32(&x.Malformed, 1) |
|
||||||
} |
|
||||||
|
|
||||||
func (x *Stats) decrLimit() int32 { |
|
||||||
return atomic.AddInt32(&x.ConnLimit, -1) |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) InBlackList(addy string) bool { |
|
||||||
s.RLock() |
|
||||||
defer s.RUnlock() |
|
||||||
return util.StringInSlice(addy, s.blacklist) |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) InWhiteList(ip string) bool { |
|
||||||
s.RLock() |
|
||||||
defer s.RUnlock() |
|
||||||
return util.StringInSlice(ip, s.whitelist) |
|
||||||
} |
|
||||||
|
|
||||||
func (s *PolicyServer) doBan(ip string) { |
|
||||||
set, timeout := s.config.Banning.IPSet, s.config.Banning.Timeout |
|
||||||
cmd := fmt.Sprintf("sudo ipset add %s %s timeout %v -!", set, ip, timeout) |
|
||||||
args := strings.Fields(cmd) |
|
||||||
head := args[0] |
|
||||||
args = args[1:] |
|
||||||
|
|
||||||
log.Printf("Banned %v with timeout %v", ip, timeout) |
|
||||||
|
|
||||||
_, err := exec.Command(head, args...).Output() |
|
||||||
if err != nil { |
|
||||||
log.Printf("CMD Error: %s", err) |
|
||||||
} |
|
||||||
} |
|
||||||
|
|
||||||
func (x *Stats) heartbeat() { |
|
||||||
now := util.MakeTimestamp() |
|
||||||
atomic.StoreInt64(&x.LastBeat, now) |
|
||||||
} |
|
@ -1,136 +0,0 @@ |
|||||||
// Generated from https://github.com/streamrail/concurrent-map
|
|
||||||
package policy |
|
||||||
|
|
||||||
import ( |
|
||||||
"hash/fnv" |
|
||||||
"sync" |
|
||||||
) |
|
||||||
|
|
||||||
var SHARD_COUNT = 32 |
|
||||||
|
|
||||||
// TODO: Add Keys function which returns an array of keys for the map.
|
|
||||||
|
|
||||||
// A "thread" safe map of type string:*Stats.
|
|
||||||
// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
|
|
||||||
type StatsMap []*StatsMapShared |
|
||||||
type StatsMapShared struct { |
|
||||||
items map[string]*Stats |
|
||||||
sync.RWMutex // Read Write mutex, guards access to internal map.
|
|
||||||
} |
|
||||||
|
|
||||||
// Creates a new concurrent map.
|
|
||||||
func NewStatsMap() StatsMap { |
|
||||||
m := make(StatsMap, SHARD_COUNT) |
|
||||||
for i := 0; i < SHARD_COUNT; i++ { |
|
||||||
m[i] = &StatsMapShared{items: make(map[string]*Stats)} |
|
||||||
} |
|
||||||
return m |
|
||||||
} |
|
||||||
|
|
||||||
// Returns shard under given key
|
|
||||||
func (m StatsMap) GetShard(key string) *StatsMapShared { |
|
||||||
hasher := fnv.New32() |
|
||||||
hasher.Write([]byte(key)) |
|
||||||
return m[int(hasher.Sum32())%SHARD_COUNT] |
|
||||||
} |
|
||||||
|
|
||||||
// Sets the given value under the specified key.
|
|
||||||
func (m *StatsMap) Set(key string, value *Stats) { |
|
||||||
// Get map shard.
|
|
||||||
shard := m.GetShard(key) |
|
||||||
shard.Lock() |
|
||||||
defer shard.Unlock() |
|
||||||
shard.items[key] = value |
|
||||||
} |
|
||||||
|
|
||||||
// Retrieves an element from map under given key.
|
|
||||||
func (m StatsMap) Get(key string) (*Stats, bool) { |
|
||||||
// Get shard
|
|
||||||
shard := m.GetShard(key) |
|
||||||
shard.RLock() |
|
||||||
defer shard.RUnlock() |
|
||||||
|
|
||||||
// Get item from shard.
|
|
||||||
val, ok := shard.items[key] |
|
||||||
return val, ok |
|
||||||
} |
|
||||||
|
|
||||||
// Returns the number of elements within the map.
|
|
||||||
func (m StatsMap) Count() int { |
|
||||||
count := 0 |
|
||||||
for i := 0; i < SHARD_COUNT; i++ { |
|
||||||
shard := m[i] |
|
||||||
shard.RLock() |
|
||||||
count += len(shard.items) |
|
||||||
shard.RUnlock() |
|
||||||
} |
|
||||||
return count |
|
||||||
} |
|
||||||
|
|
||||||
// Looks up an item under specified key
|
|
||||||
func (m *StatsMap) Has(key string) bool { |
|
||||||
// Get shard
|
|
||||||
shard := m.GetShard(key) |
|
||||||
shard.RLock() |
|
||||||
defer shard.RUnlock() |
|
||||||
|
|
||||||
// See if element is within shard.
|
|
||||||
_, ok := shard.items[key] |
|
||||||
return ok |
|
||||||
} |
|
||||||
|
|
||||||
// Removes an element from the map.
|
|
||||||
func (m *StatsMap) Remove(key string) { |
|
||||||
// Try to get shard.
|
|
||||||
shard := m.GetShard(key) |
|
||||||
shard.Lock() |
|
||||||
defer shard.Unlock() |
|
||||||
delete(shard.items, key) |
|
||||||
} |
|
||||||
|
|
||||||
// Checks if map is empty.
|
|
||||||
func (m *StatsMap) IsEmpty() bool { |
|
||||||
return m.Count() == 0 |
|
||||||
} |
|
||||||
|
|
||||||
// Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
|
|
||||||
type Tuple struct { |
|
||||||
Key string |
|
||||||
Val *Stats |
|
||||||
} |
|
||||||
|
|
||||||
// Returns an iterator which could be used in a for range loop.
|
|
||||||
func (m StatsMap) Iter() <-chan Tuple { |
|
||||||
ch := make(chan Tuple) |
|
||||||
go func() { |
|
||||||
// Foreach shard.
|
|
||||||
for _, shard := range m { |
|
||||||
// Foreach key, value pair.
|
|
||||||
shard.RLock() |
|
||||||
for key, val := range shard.items { |
|
||||||
ch <- Tuple{key, val} |
|
||||||
} |
|
||||||
shard.RUnlock() |
|
||||||
} |
|
||||||
close(ch) |
|
||||||
}() |
|
||||||
return ch |
|
||||||
} |
|
||||||
|
|
||||||
// Returns a buffered iterator which could be used in a for range loop.
|
|
||||||
func (m StatsMap) IterBuffered() <-chan Tuple { |
|
||||||
ch := make(chan Tuple, m.Count()) |
|
||||||
go func() { |
|
||||||
// Foreach shard.
|
|
||||||
for _, shard := range m { |
|
||||||
// Foreach key, value pair.
|
|
||||||
shard.RLock() |
|
||||||
for key, val := range shard.items { |
|
||||||
ch <- Tuple{key, val} |
|
||||||
} |
|
||||||
shard.RUnlock() |
|
||||||
} |
|
||||||
close(ch) |
|
||||||
}() |
|
||||||
return ch |
|
||||||
} |
|
Loading…
Reference in new issue