diff --git a/go-pool/stratum/stratum.go b/go-pool/stratum/stratum.go index 89f3f09..2df3143 100644 --- a/go-pool/stratum/stratum.go +++ b/go-pool/stratum/stratum.go @@ -4,7 +4,6 @@ import ( "bufio" "crypto/rand" "encoding/json" - "errors" "fmt" "io" "log" @@ -144,9 +143,13 @@ func (s *StratumServer) Listen() { func (e *Endpoint) Listen(s *StratumServer) { bindAddr := fmt.Sprintf("%s:%d", e.config.Host, e.config.Port) addr, err := net.ResolveTCPAddr("tcp", bindAddr) - checkError(err) + if err != nil { + log.Fatalf("Error: %v", err) + } server, err := net.ListenTCP("tcp", addr) - checkError(err) + if err != nil { + log.Fatalf("Error: %v", err) + } defer server.Close() log.Printf("Stratum listening on %s", bindAddr) @@ -165,17 +168,13 @@ func (e *Endpoint) Listen(s *StratumServer) { accept <- n go func() { - err = s.handleClient(cs, e) - if err != nil { - s.removeSession(cs) - conn.Close() - } + s.handleClient(cs, e) <-accept }() } } -func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error { +func (s *StratumServer) handleClient(cs *Session, e *Endpoint) { _, targetHex := util.GetTargetHex(e.config.Difficulty) cs.targetHex = targetHex @@ -185,15 +184,14 @@ func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error { for { data, isPrefix, err := connbuff.ReadLine() if isPrefix { - log.Printf("Socket flood detected") - return errors.New("Socket flood") + log.Println("Socket flood detected from", cs.ip) + break } else if err == io.EOF { log.Println("Client disconnected", cs.ip) - s.removeSession(cs) break } else if err != nil { - log.Printf("Error reading: %v", err) - return err + log.Println("Error reading:", err) + break } // NOTICE: cpuminer-multi sends junk newlines, so we demand at least 1 byte for decode @@ -202,77 +200,73 @@ func (s *StratumServer) handleClient(cs *Session, e *Endpoint) error { var req JSONRpcReq err = json.Unmarshal(data, &req) if err != nil { - log.Printf("Malformed request: %v", err) - return err + log.Printf("Malformed request from: %v", cs.ip, err) + break } s.setDeadline(cs.conn) - cs.handleMessage(s, e, &req) + err = cs.handleMessage(s, e, &req) + if err != nil { + break + } } } - return nil + s.removeSession(cs) + cs.conn.Close() } -func (cs *Session) handleMessage(s *StratumServer, e *Endpoint, req *JSONRpcReq) { +func (cs *Session) handleMessage(s *StratumServer, e *Endpoint, req *JSONRpcReq) error { if req.Id == nil { - log.Println("Missing RPC id") - cs.conn.Close() - return + err := fmt.Errorf("Server disconnect request") + log.Println(err) + return err } else if req.Params == nil { - log.Println("Missing RPC params") - cs.conn.Close() - return + err := fmt.Errorf("Server RPC request params") + log.Println(err) + return err } - var err error - // Handle RPC methods switch req.Method { + case "login": var params LoginParams - err = json.Unmarshal(*req.Params, ¶ms) + err := json.Unmarshal(*req.Params, ¶ms) if err != nil { log.Println("Unable to parse params") - break + return err } reply, errReply := s.handleLoginRPC(cs, e, ¶ms) if errReply != nil { - err = cs.sendError(req.Id, errReply) - break + return cs.sendError(req.Id, errReply, true) } - err = cs.sendResult(req.Id, &reply) + return cs.sendResult(req.Id, &reply) case "getjob": var params GetJobParams - err = json.Unmarshal(*req.Params, ¶ms) + err := json.Unmarshal(*req.Params, ¶ms) if err != nil { log.Println("Unable to parse params") - break + return err } reply, errReply := s.handleGetJobRPC(cs, e, ¶ms) if errReply != nil { - err = cs.sendError(req.Id, errReply) - break + return cs.sendError(req.Id, errReply, true) } - err = cs.sendResult(req.Id, &reply) + return cs.sendResult(req.Id, &reply) case "submit": var params SubmitParams err := json.Unmarshal(*req.Params, ¶ms) if err != nil { log.Println("Unable to parse params") - break + return err } reply, errReply := s.handleSubmitRPC(cs, e, ¶ms) if errReply != nil { - err = cs.sendError(req.Id, errReply) - break + return cs.sendError(req.Id, errReply, true) } - err = cs.sendResult(req.Id, &reply) + return cs.sendResult(req.Id, &reply) default: errReply := s.handleUnknownRPC(cs, req) - err = cs.sendError(req.Id, errReply) - } - - if err != nil { - cs.conn.Close() + return cs.sendError(req.Id, errReply, true) } } @@ -290,15 +284,18 @@ func (cs *Session) pushMessage(method string, params interface{}) error { return cs.enc.Encode(&message) } -func (cs *Session) sendError(id *json.RawMessage, reply *ErrorReply) error { +func (cs *Session) sendError(id *json.RawMessage, reply *ErrorReply, drop bool) error { cs.Lock() defer cs.Unlock() message := JSONRpcResp{Id: id, Version: "2.0", Error: reply} err := cs.enc.Encode(&message) - if reply.Close { - return errors.New("Force close") + if err != nil { + return err + } + if drop { + return fmt.Errorf("Server disconnect request") } - return err + return nil } func (s *StratumServer) setDeadline(conn *net.TCPConn) { @@ -364,9 +361,3 @@ func (s *StratumServer) rpc() *rpc.RPCClient { i := atomic.LoadInt32(&s.upstream) return s.upstreams[i] } - -func checkError(err error) { - if err != nil { - log.Fatalf("Error: %v", err) - } -}