Sammy Libre
8 years ago
2 changed files with 0 additions and 407 deletions
@ -1,271 +0,0 @@
@@ -1,271 +0,0 @@
|
||||
package policy |
||||
|
||||
import ( |
||||
"fmt" |
||||
"log" |
||||
"os/exec" |
||||
"strings" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"../../pool" |
||||
"../../storage" |
||||
"../../util" |
||||
) |
||||
|
||||
type Stats struct { |
||||
sync.Mutex |
||||
ValidShares uint32 |
||||
InvalidShares uint32 |
||||
Malformed uint32 |
||||
ConnLimit int32 |
||||
FailsCount uint32 |
||||
LastBeat int64 |
||||
Banned uint32 |
||||
BannedAt int64 |
||||
} |
||||
|
||||
type PolicyServer struct { |
||||
sync.RWMutex |
||||
config *pool.Policy |
||||
stats StatsMap |
||||
banChannel chan string |
||||
startedAt int64 |
||||
grace int64 |
||||
timeout int64 |
||||
blacklist []string |
||||
whitelist []string |
||||
storage *storage.RedisClient |
||||
} |
||||
|
||||
func Start(cfg *pool.Config, storage *storage.RedisClient) *PolicyServer { |
||||
s := &PolicyServer{config: &cfg.Policy, startedAt: util.MakeTimestamp()} |
||||
grace, _ := time.ParseDuration(cfg.Policy.Limits.Grace) |
||||
s.grace = int64(grace / time.Millisecond) |
||||
s.banChannel = make(chan string, 64) |
||||
s.stats = NewStatsMap() |
||||
s.storage = storage |
||||
s.refreshState() |
||||
|
||||
timeout, _ := time.ParseDuration(s.config.ResetInterval) |
||||
s.timeout = int64(timeout / time.Millisecond) |
||||
|
||||
resetIntv, _ := time.ParseDuration(s.config.ResetInterval) |
||||
resetTimer := time.NewTimer(resetIntv) |
||||
log.Printf("Set policy stats reset every %v", resetIntv) |
||||
|
||||
refreshIntv, _ := time.ParseDuration(s.config.RefreshInterval) |
||||
refreshTimer := time.NewTimer(refreshIntv) |
||||
log.Printf("Set policy state refresh every %v", refreshIntv) |
||||
|
||||
go func() { |
||||
for { |
||||
select { |
||||
case <-resetTimer.C: |
||||
s.resetStats() |
||||
resetTimer.Reset(resetIntv) |
||||
case <-refreshTimer.C: |
||||
s.refreshState() |
||||
refreshTimer.Reset(refreshIntv) |
||||
} |
||||
} |
||||
}() |
||||
|
||||
for i := 0; i < s.config.Workers; i++ { |
||||
s.startPolicyWorker() |
||||
} |
||||
log.Printf("Running with %v policy workers", s.config.Workers) |
||||
return s |
||||
} |
||||
|
||||
func (s *PolicyServer) startPolicyWorker() { |
||||
go func() { |
||||
for { |
||||
select { |
||||
case ip := <-s.banChannel: |
||||
s.doBan(ip) |
||||
} |
||||
} |
||||
}() |
||||
} |
||||
|
||||
func (s *PolicyServer) resetStats() { |
||||
now := util.MakeTimestamp() |
||||
banningTimeout := s.config.Banning.Timeout * 1000 |
||||
total := 0 |
||||
|
||||
for m := range s.stats.IterBuffered() { |
||||
lastBeat := atomic.LoadInt64(&m.Val.LastBeat) |
||||
bannedAt := atomic.LoadInt64(&m.Val.BannedAt) |
||||
|
||||
if now-bannedAt >= banningTimeout { |
||||
atomic.StoreInt64(&m.Val.BannedAt, 0) |
||||
if atomic.CompareAndSwapUint32(&m.Val.Banned, 1, 0) { |
||||
log.Printf("Ban dropped for %v", m.Key) |
||||
} |
||||
} |
||||
if now-lastBeat >= s.timeout { |
||||
s.stats.Remove(m.Key) |
||||
total++ |
||||
} |
||||
} |
||||
log.Printf("Flushed stats for %v IP addresses", total) |
||||
} |
||||
|
||||
func (s *PolicyServer) refreshState() { |
||||
s.Lock() |
||||
defer s.Unlock() |
||||
|
||||
s.blacklist = s.storage.GetBlacklist() |
||||
s.whitelist = s.storage.GetWhitelist() |
||||
log.Println("Policy state refresh complete") |
||||
} |
||||
|
||||
func (s *PolicyServer) NewStats() *Stats { |
||||
x := &Stats{ |
||||
ConnLimit: s.config.Limits.Limit, |
||||
Malformed: s.config.Banning.MalformedLimit, |
||||
} |
||||
x.heartbeat() |
||||
return x |
||||
} |
||||
|
||||
func (s *PolicyServer) Get(ip string) *Stats { |
||||
if x, ok := s.stats.Get(ip); ok { |
||||
x.heartbeat() |
||||
return x |
||||
} |
||||
x := s.NewStats() |
||||
s.stats.Set(ip, x) |
||||
return x |
||||
} |
||||
|
||||
func (s *PolicyServer) ApplyLimitPolicy(ip string) bool { |
||||
if !s.config.Limits.Enabled { |
||||
return true |
||||
} |
||||
now := util.MakeTimestamp() |
||||
if now-s.startedAt > s.grace { |
||||
return s.Get(ip).decrLimit() > 0 |
||||
} |
||||
return true |
||||
} |
||||
|
||||
func (s *PolicyServer) ApplyLoginPolicy(addy, ip string) bool { |
||||
if s.InBlackList(addy) { |
||||
x := s.Get(ip) |
||||
s.forceBan(x, ip) |
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
func (s *PolicyServer) ApplyMalformedPolicy(ip string) { |
||||
x := s.Get(ip) |
||||
n := x.incrMalformed() |
||||
if n >= s.config.Banning.MalformedLimit { |
||||
s.forceBan(x, ip) |
||||
} |
||||
} |
||||
|
||||
func (s *PolicyServer) ApplySharePolicy(ip string, validShare bool) bool { |
||||
x := s.Get(ip) |
||||
if validShare && s.config.Limits.Enabled { |
||||
s.Get(ip).incrLimit(s.config.Limits.LimitJump) |
||||
} |
||||
x.Lock() |
||||
|
||||
if validShare { |
||||
x.ValidShares++ |
||||
if s.config.Limits.Enabled { |
||||
x.incrLimit(s.config.Limits.LimitJump) |
||||
} |
||||
} else { |
||||
x.InvalidShares++ |
||||
} |
||||
|
||||
totalShares := x.ValidShares + x.InvalidShares |
||||
if totalShares < s.config.Banning.CheckThreshold { |
||||
x.Unlock() |
||||
return true |
||||
} |
||||
validShares := float32(x.ValidShares) |
||||
invalidShares := float32(x.InvalidShares) |
||||
x.resetShares() |
||||
x.Unlock() |
||||
|
||||
if invalidShares == 0 { |
||||
return true |
||||
} |
||||
|
||||
// Can be +Inf or value, previous check prevents NaN
|
||||
ratio := invalidShares / validShares |
||||
|
||||
if ratio >= s.config.Banning.InvalidPercent/100.0 { |
||||
s.forceBan(x, ip) |
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
func (x *Stats) resetShares() { |
||||
x.ValidShares = 0 |
||||
x.InvalidShares = 0 |
||||
} |
||||
|
||||
func (s *PolicyServer) forceBan(x *Stats, ip string) { |
||||
if !s.config.Banning.Enabled || s.InWhiteList(ip) { |
||||
return |
||||
} |
||||
|
||||
if atomic.CompareAndSwapUint32(&x.Banned, 0, 1) { |
||||
if len(s.config.Banning.IPSet) > 0 { |
||||
s.banChannel <- ip |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (x *Stats) incrLimit(n int32) { |
||||
atomic.AddInt32(&x.ConnLimit, n) |
||||
} |
||||
|
||||
func (x *Stats) incrMalformed() uint32 { |
||||
return atomic.AddUint32(&x.Malformed, 1) |
||||
} |
||||
|
||||
func (x *Stats) decrLimit() int32 { |
||||
return atomic.AddInt32(&x.ConnLimit, -1) |
||||
} |
||||
|
||||
func (s *PolicyServer) InBlackList(addy string) bool { |
||||
s.RLock() |
||||
defer s.RUnlock() |
||||
return util.StringInSlice(addy, s.blacklist) |
||||
} |
||||
|
||||
func (s *PolicyServer) InWhiteList(ip string) bool { |
||||
s.RLock() |
||||
defer s.RUnlock() |
||||
return util.StringInSlice(ip, s.whitelist) |
||||
} |
||||
|
||||
func (s *PolicyServer) doBan(ip string) { |
||||
set, timeout := s.config.Banning.IPSet, s.config.Banning.Timeout |
||||
cmd := fmt.Sprintf("sudo ipset add %s %s timeout %v -!", set, ip, timeout) |
||||
args := strings.Fields(cmd) |
||||
head := args[0] |
||||
args = args[1:] |
||||
|
||||
log.Printf("Banned %v with timeout %v", ip, timeout) |
||||
|
||||
_, err := exec.Command(head, args...).Output() |
||||
if err != nil { |
||||
log.Printf("CMD Error: %s", err) |
||||
} |
||||
} |
||||
|
||||
func (x *Stats) heartbeat() { |
||||
now := util.MakeTimestamp() |
||||
atomic.StoreInt64(&x.LastBeat, now) |
||||
} |
@ -1,136 +0,0 @@
@@ -1,136 +0,0 @@
|
||||
// Generated from https://github.com/streamrail/concurrent-map
|
||||
package policy |
||||
|
||||
import ( |
||||
"hash/fnv" |
||||
"sync" |
||||
) |
||||
|
||||
var SHARD_COUNT = 32 |
||||
|
||||
// TODO: Add Keys function which returns an array of keys for the map.
|
||||
|
||||
// A "thread" safe map of type string:*Stats.
|
||||
// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
|
||||
type StatsMap []*StatsMapShared |
||||
type StatsMapShared struct { |
||||
items map[string]*Stats |
||||
sync.RWMutex // Read Write mutex, guards access to internal map.
|
||||
} |
||||
|
||||
// Creates a new concurrent map.
|
||||
func NewStatsMap() StatsMap { |
||||
m := make(StatsMap, SHARD_COUNT) |
||||
for i := 0; i < SHARD_COUNT; i++ { |
||||
m[i] = &StatsMapShared{items: make(map[string]*Stats)} |
||||
} |
||||
return m |
||||
} |
||||
|
||||
// Returns shard under given key
|
||||
func (m StatsMap) GetShard(key string) *StatsMapShared { |
||||
hasher := fnv.New32() |
||||
hasher.Write([]byte(key)) |
||||
return m[int(hasher.Sum32())%SHARD_COUNT] |
||||
} |
||||
|
||||
// Sets the given value under the specified key.
|
||||
func (m *StatsMap) Set(key string, value *Stats) { |
||||
// Get map shard.
|
||||
shard := m.GetShard(key) |
||||
shard.Lock() |
||||
defer shard.Unlock() |
||||
shard.items[key] = value |
||||
} |
||||
|
||||
// Retrieves an element from map under given key.
|
||||
func (m StatsMap) Get(key string) (*Stats, bool) { |
||||
// Get shard
|
||||
shard := m.GetShard(key) |
||||
shard.RLock() |
||||
defer shard.RUnlock() |
||||
|
||||
// Get item from shard.
|
||||
val, ok := shard.items[key] |
||||
return val, ok |
||||
} |
||||
|
||||
// Returns the number of elements within the map.
|
||||
func (m StatsMap) Count() int { |
||||
count := 0 |
||||
for i := 0; i < SHARD_COUNT; i++ { |
||||
shard := m[i] |
||||
shard.RLock() |
||||
count += len(shard.items) |
||||
shard.RUnlock() |
||||
} |
||||
return count |
||||
} |
||||
|
||||
// Looks up an item under specified key
|
||||
func (m *StatsMap) Has(key string) bool { |
||||
// Get shard
|
||||
shard := m.GetShard(key) |
||||
shard.RLock() |
||||
defer shard.RUnlock() |
||||
|
||||
// See if element is within shard.
|
||||
_, ok := shard.items[key] |
||||
return ok |
||||
} |
||||
|
||||
// Removes an element from the map.
|
||||
func (m *StatsMap) Remove(key string) { |
||||
// Try to get shard.
|
||||
shard := m.GetShard(key) |
||||
shard.Lock() |
||||
defer shard.Unlock() |
||||
delete(shard.items, key) |
||||
} |
||||
|
||||
// Checks if map is empty.
|
||||
func (m *StatsMap) IsEmpty() bool { |
||||
return m.Count() == 0 |
||||
} |
||||
|
||||
// Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
|
||||
type Tuple struct { |
||||
Key string |
||||
Val *Stats |
||||
} |
||||
|
||||
// Returns an iterator which could be used in a for range loop.
|
||||
func (m StatsMap) Iter() <-chan Tuple { |
||||
ch := make(chan Tuple) |
||||
go func() { |
||||
// Foreach shard.
|
||||
for _, shard := range m { |
||||
// Foreach key, value pair.
|
||||
shard.RLock() |
||||
for key, val := range shard.items { |
||||
ch <- Tuple{key, val} |
||||
} |
||||
shard.RUnlock() |
||||
} |
||||
close(ch) |
||||
}() |
||||
return ch |
||||
} |
||||
|
||||
// Returns a buffered iterator which could be used in a for range loop.
|
||||
func (m StatsMap) IterBuffered() <-chan Tuple { |
||||
ch := make(chan Tuple, m.Count()) |
||||
go func() { |
||||
// Foreach shard.
|
||||
for _, shard := range m { |
||||
// Foreach key, value pair.
|
||||
shard.RLock() |
||||
for key, val := range shard.items { |
||||
ch <- Tuple{key, val} |
||||
} |
||||
shard.RUnlock() |
||||
} |
||||
close(ch) |
||||
}() |
||||
return ch |
||||
} |
Loading…
Reference in new issue