mirror of
https://github.com/kvazar-network/keva-stratum.git
synced 2025-01-11 15:48:00 +00:00
Rework connection handling
This commit is contained in:
parent
5d795dc56d
commit
41710ba0e8
@ -4,7 +4,6 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@ -144,9 +143,13 @@ func (s *StratumServer) Listen() {
|
|||||||
func (e *Endpoint) Listen(s *StratumServer) {
|
func (e *Endpoint) Listen(s *StratumServer) {
|
||||||
bindAddr := fmt.Sprintf("%s:%d", e.config.Host, e.config.Port)
|
bindAddr := fmt.Sprintf("%s:%d", e.config.Host, e.config.Port)
|
||||||
addr, err := net.ResolveTCPAddr("tcp", bindAddr)
|
addr, err := net.ResolveTCPAddr("tcp", bindAddr)
|
||||||
checkError(err)
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
}
|
||||||
server, err := net.ListenTCP("tcp", addr)
|
server, err := net.ListenTCP("tcp", addr)
|
||||||
checkError(err)
|
if err != nil {
|
||||||
|
log.Fatalf("Error: %v", err)
|
||||||
|
}
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
log.Printf("Stratum listening on %s", bindAddr)
|
log.Printf("Stratum listening on %s", bindAddr)
|
||||||
@ -165,17 +168,13 @@ func (e *Endpoint) Listen(s *StratumServer) {
|
|||||||
|
|
||||||
accept <- n
|
accept <- n
|
||||||
go func() {
|
go func() {
|
||||||
err = s.handleClient(cs, e)
|
s.handleClient(cs, e)
|
||||||
if err != nil {
|
|
||||||
s.removeSession(cs)
|
|
||||||
conn.Close()
|
|
||||||
}
|
|
||||||
<-accept
|
<-accept
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error {
|
func (s *StratumServer) handleClient(cs *Session, e *Endpoint) {
|
||||||
_, targetHex := util.GetTargetHex(e.config.Difficulty)
|
_, targetHex := util.GetTargetHex(e.config.Difficulty)
|
||||||
cs.targetHex = targetHex
|
cs.targetHex = targetHex
|
||||||
|
|
||||||
@ -185,15 +184,14 @@ func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error {
|
|||||||
for {
|
for {
|
||||||
data, isPrefix, err := connbuff.ReadLine()
|
data, isPrefix, err := connbuff.ReadLine()
|
||||||
if isPrefix {
|
if isPrefix {
|
||||||
log.Printf("Socket flood detected")
|
log.Println("Socket flood detected from", cs.ip)
|
||||||
return errors.New("Socket flood")
|
break
|
||||||
} else if err == io.EOF {
|
} else if err == io.EOF {
|
||||||
log.Println("Client disconnected", cs.ip)
|
log.Println("Client disconnected", cs.ip)
|
||||||
s.removeSession(cs)
|
|
||||||
break
|
break
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
log.Printf("Error reading: %v", err)
|
log.Println("Error reading:", err)
|
||||||
return err
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTICE: cpuminer-multi sends junk newlines, so we demand at least 1 byte for decode
|
// NOTICE: cpuminer-multi sends junk newlines, so we demand at least 1 byte for decode
|
||||||
@ -202,77 +200,73 @@ func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error {
|
|||||||
var req JSONRpcReq
|
var req JSONRpcReq
|
||||||
err = json.Unmarshal(data, &req)
|
err = json.Unmarshal(data, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Malformed request: %v", err)
|
log.Printf("Malformed request from: %v", cs.ip, err)
|
||||||
return err
|
break
|
||||||
}
|
}
|
||||||
s.setDeadline(cs.conn)
|
s.setDeadline(cs.conn)
|
||||||
cs.handleMessage(s, e, &req)
|
err = cs.handleMessage(s, e, &req)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
s.removeSession(cs)
|
||||||
|
cs.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *Session) handleMessage(s *StratumServer, e *Endpoint, req *JSONRpcReq) {
|
func (cs *Session) handleMessage(s *StratumServer, e *Endpoint, req *JSONRpcReq) error {
|
||||||
if req.Id == nil {
|
if req.Id == nil {
|
||||||
log.Println("Missing RPC id")
|
err := fmt.Errorf("Server disconnect request")
|
||||||
cs.conn.Close()
|
log.Println(err)
|
||||||
return
|
return err
|
||||||
} else if req.Params == nil {
|
} else if req.Params == nil {
|
||||||
log.Println("Missing RPC params")
|
err := fmt.Errorf("Server RPC request params")
|
||||||
cs.conn.Close()
|
log.Println(err)
|
||||||
return
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// Handle RPC methods
|
// Handle RPC methods
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
|
|
||||||
case "login":
|
case "login":
|
||||||
var params LoginParams
|
var params LoginParams
|
||||||
err = json.Unmarshal(*req.Params, ¶ms)
|
err := json.Unmarshal(*req.Params, ¶ms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Unable to parse params")
|
log.Println("Unable to parse params")
|
||||||
break
|
return err
|
||||||
}
|
}
|
||||||
reply, errReply := s.handleLoginRPC(cs, e, ¶ms)
|
reply, errReply := s.handleLoginRPC(cs, e, ¶ms)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
err = cs.sendError(req.Id, errReply)
|
return cs.sendError(req.Id, errReply, true)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
err = cs.sendResult(req.Id, &reply)
|
return cs.sendResult(req.Id, &reply)
|
||||||
case "getjob":
|
case "getjob":
|
||||||
var params GetJobParams
|
var params GetJobParams
|
||||||
err = json.Unmarshal(*req.Params, ¶ms)
|
err := json.Unmarshal(*req.Params, ¶ms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Unable to parse params")
|
log.Println("Unable to parse params")
|
||||||
break
|
return err
|
||||||
}
|
}
|
||||||
reply, errReply := s.handleGetJobRPC(cs, e, ¶ms)
|
reply, errReply := s.handleGetJobRPC(cs, e, ¶ms)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
err = cs.sendError(req.Id, errReply)
|
return cs.sendError(req.Id, errReply, true)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
err = cs.sendResult(req.Id, &reply)
|
return cs.sendResult(req.Id, &reply)
|
||||||
case "submit":
|
case "submit":
|
||||||
var params SubmitParams
|
var params SubmitParams
|
||||||
err := json.Unmarshal(*req.Params, ¶ms)
|
err := json.Unmarshal(*req.Params, ¶ms)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("Unable to parse params")
|
log.Println("Unable to parse params")
|
||||||
break
|
return err
|
||||||
}
|
}
|
||||||
reply, errReply := s.handleSubmitRPC(cs, e, ¶ms)
|
reply, errReply := s.handleSubmitRPC(cs, e, ¶ms)
|
||||||
if errReply != nil {
|
if errReply != nil {
|
||||||
err = cs.sendError(req.Id, errReply)
|
return cs.sendError(req.Id, errReply, true)
|
||||||
break
|
|
||||||
}
|
}
|
||||||
err = cs.sendResult(req.Id, &reply)
|
return cs.sendResult(req.Id, &reply)
|
||||||
default:
|
default:
|
||||||
errReply := s.handleUnknownRPC(cs, req)
|
errReply := s.handleUnknownRPC(cs, req)
|
||||||
err = cs.sendError(req.Id, errReply)
|
return cs.sendError(req.Id, errReply, true)
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
cs.conn.Close()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -290,15 +284,18 @@ func (cs *Session) pushMessage(method string, params interface{}) error {
|
|||||||
return cs.enc.Encode(&message)
|
return cs.enc.Encode(&message)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *Session) sendError(id *json.RawMessage, reply *ErrorReply) error {
|
func (cs *Session) sendError(id *json.RawMessage, reply *ErrorReply, drop bool) error {
|
||||||
cs.Lock()
|
cs.Lock()
|
||||||
defer cs.Unlock()
|
defer cs.Unlock()
|
||||||
message := JSONRpcResp{Id: id, Version: "2.0", Error: reply}
|
message := JSONRpcResp{Id: id, Version: "2.0", Error: reply}
|
||||||
err := cs.enc.Encode(&message)
|
err := cs.enc.Encode(&message)
|
||||||
if reply.Close {
|
if err != nil {
|
||||||
return errors.New("Force close")
|
return err
|
||||||
}
|
}
|
||||||
return err
|
if drop {
|
||||||
|
return fmt.Errorf("Server disconnect request")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *StratumServer) setDeadline(conn *net.TCPConn) {
|
func (s *StratumServer) setDeadline(conn *net.TCPConn) {
|
||||||
@ -364,9 +361,3 @@ func (s *StratumServer) rpc() *rpc.RPCClient {
|
|||||||
i := atomic.LoadInt32(&s.upstream)
|
i := atomic.LoadInt32(&s.upstream)
|
||||||
return s.upstreams[i]
|
return s.upstreams[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkError(err error) {
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
Reference in New Issue
Block a user