mirror of
https://github.com/kvazar-network/keva-stratum.git
synced 2025-01-11 15:48:00 +00:00
Rework connection handling
This commit is contained in:
parent
5d795dc56d
commit
41710ba0e8
@ -4,7 +4,6 @@ import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@ -144,9 +143,13 @@ func (s *StratumServer) Listen() {
|
||||
func (e *Endpoint) Listen(s *StratumServer) {
|
||||
bindAddr := fmt.Sprintf("%s:%d", e.config.Host, e.config.Port)
|
||||
addr, err := net.ResolveTCPAddr("tcp", bindAddr)
|
||||
checkError(err)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
server, err := net.ListenTCP("tcp", addr)
|
||||
checkError(err)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
log.Printf("Stratum listening on %s", bindAddr)
|
||||
@ -165,17 +168,13 @@ func (e *Endpoint) Listen(s *StratumServer) {
|
||||
|
||||
accept <- n
|
||||
go func() {
|
||||
err = s.handleClient(cs, e)
|
||||
if err != nil {
|
||||
s.removeSession(cs)
|
||||
conn.Close()
|
||||
}
|
||||
s.handleClient(cs, e)
|
||||
<-accept
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error {
|
||||
func (s *StratumServer) handleClient(cs *Session, e *Endpoint) {
|
||||
_, targetHex := util.GetTargetHex(e.config.Difficulty)
|
||||
cs.targetHex = targetHex
|
||||
|
||||
@ -185,15 +184,14 @@ func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error {
|
||||
for {
|
||||
data, isPrefix, err := connbuff.ReadLine()
|
||||
if isPrefix {
|
||||
log.Printf("Socket flood detected")
|
||||
return errors.New("Socket flood")
|
||||
log.Println("Socket flood detected from", cs.ip)
|
||||
break
|
||||
} else if err == io.EOF {
|
||||
log.Println("Client disconnected", cs.ip)
|
||||
s.removeSession(cs)
|
||||
break
|
||||
} else if err != nil {
|
||||
log.Printf("Error reading: %v", err)
|
||||
return err
|
||||
log.Println("Error reading:", err)
|
||||
break
|
||||
}
|
||||
|
||||
// NOTICE: cpuminer-multi sends junk newlines, so we demand at least 1 byte for decode
|
||||
@ -202,77 +200,73 @@ func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error {
|
||||
var req JSONRpcReq
|
||||
err = json.Unmarshal(data, &req)
|
||||
if err != nil {
|
||||
log.Printf("Malformed request: %v", err)
|
||||
return err
|
||||
log.Printf("Malformed request from: %v", cs.ip, err)
|
||||
break
|
||||
}
|
||||
s.setDeadline(cs.conn)
|
||||
cs.handleMessage(s, e, &req)
|
||||
err = cs.handleMessage(s, e, &req)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
s.removeSession(cs)
|
||||
cs.conn.Close()
|
||||
}
|
||||
|
||||
func (cs *Session) handleMessage(s *StratumServer, e *Endpoint, req *JSONRpcReq) {
|
||||
func (cs *Session) handleMessage(s *StratumServer, e *Endpoint, req *JSONRpcReq) error {
|
||||
if req.Id == nil {
|
||||
log.Println("Missing RPC id")
|
||||
cs.conn.Close()
|
||||
return
|
||||
err := fmt.Errorf("Server disconnect request")
|
||||
log.Println(err)
|
||||
return err
|
||||
} else if req.Params == nil {
|
||||
log.Println("Missing RPC params")
|
||||
cs.conn.Close()
|
||||
return
|
||||
err := fmt.Errorf("Server RPC request params")
|
||||
log.Println(err)
|
||||
return err
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// Handle RPC methods
|
||||
switch req.Method {
|
||||
|
||||
case "login":
|
||||
var params LoginParams
|
||||
err = json.Unmarshal(*req.Params, ¶ms)
|
||||
err := json.Unmarshal(*req.Params, ¶ms)
|
||||
if err != nil {
|
||||
log.Println("Unable to parse params")
|
||||
break
|
||||
return err
|
||||
}
|
||||
reply, errReply := s.handleLoginRPC(cs, e, ¶ms)
|
||||
if errReply != nil {
|
||||
err = cs.sendError(req.Id, errReply)
|
||||
break
|
||||
return cs.sendError(req.Id, errReply, true)
|
||||
}
|
||||
err = cs.sendResult(req.Id, &reply)
|
||||
return cs.sendResult(req.Id, &reply)
|
||||
case "getjob":
|
||||
var params GetJobParams
|
||||
err = json.Unmarshal(*req.Params, ¶ms)
|
||||
err := json.Unmarshal(*req.Params, ¶ms)
|
||||
if err != nil {
|
||||
log.Println("Unable to parse params")
|
||||
break
|
||||
return err
|
||||
}
|
||||
reply, errReply := s.handleGetJobRPC(cs, e, ¶ms)
|
||||
if errReply != nil {
|
||||
err = cs.sendError(req.Id, errReply)
|
||||
break
|
||||
return cs.sendError(req.Id, errReply, true)
|
||||
}
|
||||
err = cs.sendResult(req.Id, &reply)
|
||||
return cs.sendResult(req.Id, &reply)
|
||||
case "submit":
|
||||
var params SubmitParams
|
||||
err := json.Unmarshal(*req.Params, ¶ms)
|
||||
if err != nil {
|
||||
log.Println("Unable to parse params")
|
||||
break
|
||||
return err
|
||||
}
|
||||
reply, errReply := s.handleSubmitRPC(cs, e, ¶ms)
|
||||
if errReply != nil {
|
||||
err = cs.sendError(req.Id, errReply)
|
||||
break
|
||||
return cs.sendError(req.Id, errReply, true)
|
||||
}
|
||||
err = cs.sendResult(req.Id, &reply)
|
||||
return cs.sendResult(req.Id, &reply)
|
||||
default:
|
||||
errReply := s.handleUnknownRPC(cs, req)
|
||||
err = cs.sendError(req.Id, errReply)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cs.conn.Close()
|
||||
return cs.sendError(req.Id, errReply, true)
|
||||
}
|
||||
}
|
||||
|
||||
@ -290,15 +284,18 @@ func (cs *Session) pushMessage(method string, params interface{}) error {
|
||||
return cs.enc.Encode(&message)
|
||||
}
|
||||
|
||||
func (cs *Session) sendError(id *json.RawMessage, reply *ErrorReply) error {
|
||||
func (cs *Session) sendError(id *json.RawMessage, reply *ErrorReply, drop bool) error {
|
||||
cs.Lock()
|
||||
defer cs.Unlock()
|
||||
message := JSONRpcResp{Id: id, Version: "2.0", Error: reply}
|
||||
err := cs.enc.Encode(&message)
|
||||
if reply.Close {
|
||||
return errors.New("Force close")
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if drop {
|
||||
return fmt.Errorf("Server disconnect request")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StratumServer) setDeadline(conn *net.TCPConn) {
|
||||
@ -364,9 +361,3 @@ func (s *StratumServer) rpc() *rpc.RPCClient {
|
||||
i := atomic.LoadInt32(&s.upstream)
|
||||
return s.upstreams[i]
|
||||
}
|
||||
|
||||
func checkError(err error) {
|
||||
if err != nil {
|
||||
log.Fatalf("Error: %v", err)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user