Rework connection handling

This commit is contained in:
Sammy Libre 2016-12-07 02:19:42 +05:00
parent 5d795dc56d
commit 41710ba0e8

View File

@ -4,7 +4,6 @@ import (
"bufio" "bufio"
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -144,9 +143,13 @@ func (s *StratumServer) Listen() {
func (e *Endpoint) Listen(s *StratumServer) { func (e *Endpoint) Listen(s *StratumServer) {
bindAddr := fmt.Sprintf("%s:%d", e.config.Host, e.config.Port) bindAddr := fmt.Sprintf("%s:%d", e.config.Host, e.config.Port)
addr, err := net.ResolveTCPAddr("tcp", bindAddr) addr, err := net.ResolveTCPAddr("tcp", bindAddr)
checkError(err) if err != nil {
log.Fatalf("Error: %v", err)
}
server, err := net.ListenTCP("tcp", addr) server, err := net.ListenTCP("tcp", addr)
checkError(err) if err != nil {
log.Fatalf("Error: %v", err)
}
defer server.Close() defer server.Close()
log.Printf("Stratum listening on %s", bindAddr) log.Printf("Stratum listening on %s", bindAddr)
@ -165,17 +168,13 @@ func (e *Endpoint) Listen(s *StratumServer) {
accept <- n accept <- n
go func() { go func() {
err = s.handleClient(cs, e) s.handleClient(cs, e)
if err != nil {
s.removeSession(cs)
conn.Close()
}
<-accept <-accept
}() }()
} }
} }
func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error { func (s *StratumServer) handleClient(cs *Session, e *Endpoint) {
_, targetHex := util.GetTargetHex(e.config.Difficulty) _, targetHex := util.GetTargetHex(e.config.Difficulty)
cs.targetHex = targetHex cs.targetHex = targetHex
@ -185,15 +184,14 @@ func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error {
for { for {
data, isPrefix, err := connbuff.ReadLine() data, isPrefix, err := connbuff.ReadLine()
if isPrefix { if isPrefix {
log.Printf("Socket flood detected") log.Println("Socket flood detected from", cs.ip)
return errors.New("Socket flood") break
} else if err == io.EOF { } else if err == io.EOF {
log.Println("Client disconnected", cs.ip) log.Println("Client disconnected", cs.ip)
s.removeSession(cs)
break break
} else if err != nil { } else if err != nil {
log.Printf("Error reading: %v", err) log.Println("Error reading:", err)
return err break
} }
// NOTICE: cpuminer-multi sends junk newlines, so we demand at least 1 byte for decode // NOTICE: cpuminer-multi sends junk newlines, so we demand at least 1 byte for decode
@ -202,77 +200,73 @@ func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error {
var req JSONRpcReq var req JSONRpcReq
err = json.Unmarshal(data, &req) err = json.Unmarshal(data, &req)
if err != nil { if err != nil {
log.Printf("Malformed request: %v", err) log.Printf("Malformed request from: %v", cs.ip, err)
return err break
} }
s.setDeadline(cs.conn) s.setDeadline(cs.conn)
cs.handleMessage(s, e, &req) err = cs.handleMessage(s, e, &req)
if err != nil {
break
}
} }
} }
return nil s.removeSession(cs)
cs.conn.Close()
} }
func (cs *Session) handleMessage(s *StratumServer, e *Endpoint, req *JSONRpcReq) { func (cs *Session) handleMessage(s *StratumServer, e *Endpoint, req *JSONRpcReq) error {
if req.Id == nil { if req.Id == nil {
log.Println("Missing RPC id") err := fmt.Errorf("Server disconnect request")
cs.conn.Close() log.Println(err)
return return err
} else if req.Params == nil { } else if req.Params == nil {
log.Println("Missing RPC params") err := fmt.Errorf("Server RPC request params")
cs.conn.Close() log.Println(err)
return return err
} }
var err error
// Handle RPC methods // Handle RPC methods
switch req.Method { switch req.Method {
case "login": case "login":
var params LoginParams var params LoginParams
err = json.Unmarshal(*req.Params, &params) err := json.Unmarshal(*req.Params, &params)
if err != nil { if err != nil {
log.Println("Unable to parse params") log.Println("Unable to parse params")
break return err
} }
reply, errReply := s.handleLoginRPC(cs, e, &params) reply, errReply := s.handleLoginRPC(cs, e, &params)
if errReply != nil { if errReply != nil {
err = cs.sendError(req.Id, errReply) return cs.sendError(req.Id, errReply, true)
break
} }
err = cs.sendResult(req.Id, &reply) return cs.sendResult(req.Id, &reply)
case "getjob": case "getjob":
var params GetJobParams var params GetJobParams
err = json.Unmarshal(*req.Params, &params) err := json.Unmarshal(*req.Params, &params)
if err != nil { if err != nil {
log.Println("Unable to parse params") log.Println("Unable to parse params")
break return err
} }
reply, errReply := s.handleGetJobRPC(cs, e, &params) reply, errReply := s.handleGetJobRPC(cs, e, &params)
if errReply != nil { if errReply != nil {
err = cs.sendError(req.Id, errReply) return cs.sendError(req.Id, errReply, true)
break
} }
err = cs.sendResult(req.Id, &reply) return cs.sendResult(req.Id, &reply)
case "submit": case "submit":
var params SubmitParams var params SubmitParams
err := json.Unmarshal(*req.Params, &params) err := json.Unmarshal(*req.Params, &params)
if err != nil { if err != nil {
log.Println("Unable to parse params") log.Println("Unable to parse params")
break return err
} }
reply, errReply := s.handleSubmitRPC(cs, e, &params) reply, errReply := s.handleSubmitRPC(cs, e, &params)
if errReply != nil { if errReply != nil {
err = cs.sendError(req.Id, errReply) return cs.sendError(req.Id, errReply, true)
break
} }
err = cs.sendResult(req.Id, &reply) return cs.sendResult(req.Id, &reply)
default: default:
errReply := s.handleUnknownRPC(cs, req) errReply := s.handleUnknownRPC(cs, req)
err = cs.sendError(req.Id, errReply) return cs.sendError(req.Id, errReply, true)
}
if err != nil {
cs.conn.Close()
} }
} }
@ -290,15 +284,18 @@ func (cs *Session) pushMessage(method string, params interface{}) error {
return cs.enc.Encode(&message) return cs.enc.Encode(&message)
} }
func (cs *Session) sendError(id *json.RawMessage, reply *ErrorReply) error { func (cs *Session) sendError(id *json.RawMessage, reply *ErrorReply, drop bool) error {
cs.Lock() cs.Lock()
defer cs.Unlock() defer cs.Unlock()
message := JSONRpcResp{Id: id, Version: "2.0", Error: reply} message := JSONRpcResp{Id: id, Version: "2.0", Error: reply}
err := cs.enc.Encode(&message) err := cs.enc.Encode(&message)
if reply.Close { if err != nil {
return errors.New("Force close") return err
} }
return err if drop {
return fmt.Errorf("Server disconnect request")
}
return nil
} }
func (s *StratumServer) setDeadline(conn *net.TCPConn) { func (s *StratumServer) setDeadline(conn *net.TCPConn) {
@ -364,9 +361,3 @@ func (s *StratumServer) rpc() *rpc.RPCClient {
i := atomic.LoadInt32(&s.upstream) i := atomic.LoadInt32(&s.upstream)
return s.upstreams[i] return s.upstreams[i]
} }
func checkError(err error) {
if err != nil {
log.Fatalf("Error: %v", err)
}
}