Skip to content
Merged
104 changes: 64 additions & 40 deletions compliance/compliance.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"os"
"os/signal"
"sync/atomic"
"syscall"
"time"

Expand Down Expand Up @@ -35,14 +36,21 @@ import (

var log = logging.LoggerForModule()

// nodeResourceID is the resource ID under which node scans are tracked by the
// unconfirmed message handlers (UMH). One constant is enough because
// Compliance handles exactly one node.
const nodeResourceID = "this-node"
Comment thread
vikin91 marked this conversation as resolved.

// Compliance represents the Compliance app
type Compliance struct {
nodeNameProvider node.NodeNameProvider
nodeScanner node.NodeScanner
nodeIndexer node.NodeIndexer
umhNodeInventory node.UnconfirmedMessageHandler
umhNodeIndex node.UnconfirmedMessageHandler
cache *sensor.MsgFromCompliance
nodeNameProvider node.NodeNameProvider
nodeScanner node.NodeScanner
nodeIndexer node.NodeIndexer
umhNodeInventory node.UnconfirmedMessageHandler
umhNodeIndex node.UnconfirmedMessageHandler
nodeInventoryCache atomic.Pointer[sensor.MsgFromCompliance]
nodeIndexCache atomic.Pointer[sensor.MsgFromCompliance]
}

// NewComplianceApp constructs the Compliance app object
Expand All @@ -54,7 +62,6 @@ func NewComplianceApp(nnp node.NodeNameProvider, scanner node.NodeScanner, nodeI
nodeIndexer: nodeIndexer,
umhNodeInventory: umhNodeInv,
umhNodeIndex: umhNodeIndex,
cache: nil,
}
}

Expand Down Expand Up @@ -179,13 +186,18 @@ func (c *Compliance) manageNodeInventoryScanLoop(ctx context.Context) <-chan *se
select {
case <-ctx.Done():
return
case _, ok := <-c.umhNodeInventory.RetryCommand():
if c.cache == nil {
log.Debug("Requested to retry but cache is empty. Resetting scan timer.")
case resourceID, ok := <-c.umhNodeInventory.RetryCommand():
if !ok {
log.Info("UMH retry channel for node inventory closed; stopping scan loop")
return
}
cachedMsg := c.nodeInventoryCache.Load()
if cachedMsg == nil {
log.Debugf("Requested to retry %s but cache is empty. Resetting scan timer.", resourceID)
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionResendingCacheMiss, cmetrics.ScannerVersionV2)
t.Reset(time.Second)
} else if ok {
nodeInventoriesC <- c.cache
} else {
nodeInventoriesC <- cachedMsg
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionResendingCacheHit, cmetrics.ScannerVersionV2)
}
case <-t.C:
Expand Down Expand Up @@ -215,13 +227,18 @@ func (c *Compliance) manageNodeIndexScanLoop(ctx context.Context) <-chan *sensor
select {
case <-ctx.Done():
return
case _, ok := <-c.umhNodeIndex.RetryCommand():
if c.cache == nil {
log.Debug("Requested to retry but cache is empty. Resetting scan timer.")
case resourceID, ok := <-c.umhNodeIndex.RetryCommand():
if !ok {
log.Info("UMH retry channel for node index closed; stopping scan loop")
return
}
cachedMsg := c.nodeIndexCache.Load()
if cachedMsg == nil {
log.Debugf("Requested to retry %s but cache is empty. Resetting scan timer.", resourceID)
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionResendingCacheMiss, cmetrics.ScannerVersionV4)
t.Reset(time.Second)
} else if ok {
nodeIndexesC <- c.cache
} else {
nodeIndexesC <- cachedMsg
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionResendingCacheHit, cmetrics.ScannerVersionV4)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
case <-t.C:
Expand Down Expand Up @@ -249,8 +266,8 @@ func (c *Compliance) runNodeInventoryScan(ctx context.Context) *sensor.MsgFromCo
}
cmetrics.ObserveNodeInventoryScan(msg.GetNodeInventory())
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionScan, cmetrics.ScannerVersionV2)
c.umhNodeInventory.ObserveSending()
c.cache = msg.CloneVT()
c.umhNodeInventory.ObserveSending(nodeResourceID)
c.nodeInventoryCache.Store(msg.CloneVT())
return msg
}

Expand All @@ -266,11 +283,12 @@ func (c *Compliance) runNodeIndex(ctx context.Context) *sensor.MsgFromCompliance
log.Errorf("Error creating node index: %v", err)
return nil
}
c.umhNodeIndex.ObserveSending()
c.umhNodeIndex.ObserveSending(nodeResourceID)
cmetrics.ObserveNodeIndexReport(report, nodeName)
msg := c.createIndexMsg(report, nodeName)
cmetrics.ObserveReportProtobufMessage(msg, cmetrics.ScannerVersionV4)
cmetrics.ObserveNodePackageReportTransmissions(nodeName, cmetrics.InventoryTransmissionScan, cmetrics.ScannerVersionV4)
c.nodeIndexCache.Store(msg.CloneVT())
return msg
}

Expand Down Expand Up @@ -343,43 +361,49 @@ func (c *Compliance) runRecv(ctx context.Context, client sensor.ComplianceServic
}
}
case *sensor.MsgToCompliance_ComplianceAck:
complianceAck := t.ComplianceAck
log.Debugf("Received ComplianceACK: type=%s, action=%s, resource_id=%s, reason=%s",
complianceAck.GetMessageType(),
complianceAck.GetAction(),
complianceAck.GetResourceId(),
complianceAck.GetReason(),
)
c.handleNodeScanningComplianceAck(complianceAck)
// New ComplianceACK from Sensor 4.10+
c.handleComplianceACK(t.ComplianceAck)
default:
utils.Should(errors.Errorf("Unhandled msg type: %T", t))
}
}
}

func (c *Compliance) handleNodeScanningComplianceAck(complianceAck *sensor.MsgToCompliance_ComplianceACK) {
if complianceAck == nil {
// handleComplianceACK handles the new ComplianceACK message from Sensor 4.10+.
// This is the generic ACK/NACK message that replaces the legacy NodeInventoryACK.
func (c *Compliance) handleComplianceACK(ack *sensor.MsgToCompliance_ComplianceACK) {
if ack == nil {
log.Error("Received nil ComplianceACK")
return
}

var handler node.UnconfirmedMessageHandler
switch complianceAck.GetMessageType() {
log.Debugf("Received ComplianceACK: type=%s, action=%s, resource_id=%s, reason=%s",
ack.GetMessageType(), ack.GetAction(), ack.GetResourceId(), ack.GetReason())

switch ack.GetMessageType() {
case sensor.MsgToCompliance_ComplianceACK_NODE_INVENTORY:
handler = c.umhNodeInventory
dispatchACK(c.umhNodeInventory, "node inventory", ack.GetAction(), ack.GetReason())
case sensor.MsgToCompliance_ComplianceACK_NODE_INDEX_REPORT:
handler = c.umhNodeIndex
dispatchACK(c.umhNodeIndex, "node index", ack.GetAction(), ack.GetReason())
case sensor.MsgToCompliance_ComplianceACK_VM_INDEX_REPORT:
// TODO: Implement basic handling of VM_INDEX_REPORT ACK/NACK messages in ROX-33555.
default:
log.Debugf("Ignoring ComplianceACK with unsupported message type: %s", complianceAck.GetMessageType())
return
log.Errorf("Unknown ComplianceACK message type: %s", ack.GetMessageType())
}
}

switch complianceAck.GetAction() {
// dispatchACK routes a ComplianceACK action to the appropriate UMH method.
func dispatchACK(umh node.UnconfirmedMessageHandler, label string, action sensor.MsgToCompliance_ComplianceACK_Action, reason string) {
switch action {
case sensor.MsgToCompliance_ComplianceACK_ACK:
handler.HandleACK()
umh.HandleACK(nodeResourceID)
case sensor.MsgToCompliance_ComplianceACK_NACK:
handler.HandleNACK()
if reason != "" {
log.Infof("%s NACK received: %s", label, reason)
}
umh.HandleNACK(nodeResourceID)
default:
log.Errorf("Unknown ComplianceACK action: %s", complianceAck.GetAction())
log.Errorf("Unknown ComplianceACK action for %s: %s", label, action)
}
}

Expand Down
29 changes: 23 additions & 6 deletions compliance/compliance_ack_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,37 @@ import (
"testing"

"github.com/stackrox/rox/generated/internalapi/sensor"
"github.com/stackrox/rox/pkg/concurrency"
"github.com/stretchr/testify/assert"
)

// fakeUMH is a bare-bones stand-in for node.UnconfirmedMessageHandler in tests.
// Give retryC a non-nil channel if a test needs RetryCommand() to be selectable.
type fakeUMH struct {
	// ackCount and nackCount tally the HandleACK and HandleNACK calls.
	ackCount, nackCount int
	// retryC is the channel handed out by RetryCommand.
	retryC chan string
}

func (f *fakeUMH) HandleACK() { f.ackCount++ }
func (f *fakeUMH) HandleNACK() { f.nackCount++ }
func (f *fakeUMH) ObserveSending() {}
func (f *fakeUMH) RetryCommand() <-chan struct{} { return nil }
// HandleACK counts one ACK; the resource ID is ignored.
func (f *fakeUMH) HandleACK(_ string) {
	f.ackCount++
}

// HandleNACK counts one NACK; the resource ID is ignored.
func (f *fakeUMH) HandleNACK(_ string) {
	f.nackCount++
}

// ObserveSending does nothing in this test double.
func (f *fakeUMH) ObserveSending(_ string) {}

// OnACK drops the callback; this test double never invokes it.
func (f *fakeUMH) OnACK(_ func(string)) {}

func TestHandleNodeScanningComplianceAck(t *testing.T) {
// RetryCommand exposes retryC as a receive-only channel. When retryC is unset
// the result is a nil channel, so a select on it never fires.
func (f *fakeUMH) RetryCommand() <-chan string {
	return f.retryC
}

// Stopped builds a fresh stopper on every call, reports it as stopped right
// away, and returns that stopper's stopped-signal.
// NOTE(review): presumably the returned signal is already triggered — confirm
// against the concurrency package.
func (f *fakeUMH) Stopped() concurrency.ReadOnlyErrorSignal {
	stopper := concurrency.NewStopper()
	stopper.Flow().ReportStopped()
	return stopper.Client().Stopped()
}

func TestHandleComplianceACK(t *testing.T) {
inv := &fakeUMH{}
idx := &fakeUMH{}
c := &Compliance{
Expand Down Expand Up @@ -89,7 +106,7 @@ func TestHandleNodeScanningComplianceAck(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
inv.ackCount, inv.nackCount = 0, 0
idx.ackCount, idx.nackCount = 0, 0
c.handleNodeScanningComplianceAck(tt.ack)
c.handleComplianceACK(tt.ack)
Comment thread
vikin91 marked this conversation as resolved.
assert.Equal(t, tt.wantInvACK, inv.ackCount)
assert.Equal(t, tt.wantInvNACK, inv.nackCount)
assert.Equal(t, tt.wantIdxACK, idx.ackCount)
Expand Down
98 changes: 98 additions & 0 deletions compliance/compliance_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package compliance

import (
"testing"

"github.com/stackrox/rox/generated/internalapi/sensor"
"github.com/stretchr/testify/suite"
)

// TestCompliance is the go-test entry point that runs ComplianceTestSuite.
func TestCompliance(t *testing.T) {
	suite.Run(t, &ComplianceTestSuite{})
}

// ComplianceTestSuite groups the Compliance tests in this file; it embeds
// suite.Suite and carries no state of its own.
type ComplianceTestSuite struct {
suite.Suite
}

// TestHandleComplianceACK feeds ACK and NACK messages for node inventory and
// node index reports into handleComplianceACK and checks how many ACK/NACK
// calls each of the two handlers received.
func (s *ComplianceTestSuite) TestHandleComplianceACK() {
	type scenario struct {
		ack                    *sensor.MsgToCompliance_ComplianceACK
		expectedInventoryACKs  int
		expectedInventoryNACKs int
		expectedIndexACKs      int
		expectedIndexNACKs     int
	}

	scenarios := map[string]scenario{
		"should handle NODE_INVENTORY ACK": {
			ack: &sensor.MsgToCompliance_ComplianceACK{
				Action:      sensor.MsgToCompliance_ComplianceACK_ACK,
				MessageType: sensor.MsgToCompliance_ComplianceACK_NODE_INVENTORY,
				ResourceId:  "node-1",
			},
			expectedInventoryACKs: 1,
		},
		"should handle NODE_INVENTORY NACK": {
			ack: &sensor.MsgToCompliance_ComplianceACK{
				Action:      sensor.MsgToCompliance_ComplianceACK_NACK,
				MessageType: sensor.MsgToCompliance_ComplianceACK_NODE_INVENTORY,
				ResourceId:  "node-1",
				Reason:      "rate limit exceeded",
			},
			expectedInventoryNACKs: 1,
		},
		"should handle NODE_INDEX_REPORT ACK": {
			ack: &sensor.MsgToCompliance_ComplianceACK{
				Action:      sensor.MsgToCompliance_ComplianceACK_ACK,
				MessageType: sensor.MsgToCompliance_ComplianceACK_NODE_INDEX_REPORT,
				ResourceId:  "node-2",
			},
			expectedIndexACKs: 1,
		},
		"should handle NODE_INDEX_REPORT NACK": {
			ack: &sensor.MsgToCompliance_ComplianceACK{
				Action:      sensor.MsgToCompliance_ComplianceACK_NACK,
				MessageType: sensor.MsgToCompliance_ComplianceACK_NODE_INDEX_REPORT,
				ResourceId:  "node-2",
				Reason:      "central unreachable",
			},
			expectedIndexNACKs: 1,
		},
	}

	for title, sc := range scenarios {
		s.Run(title, func() {
			// Fresh handlers per subtest so the counters start at zero.
			inventoryUMH := &fakeUMH{retryC: make(chan string)}
			indexUMH := &fakeUMH{retryC: make(chan string)}
			app := &Compliance{
				umhNodeInventory: inventoryUMH,
				umhNodeIndex:     indexUMH,
			}

			app.handleComplianceACK(sc.ack)

			s.Equal(sc.expectedInventoryACKs, inventoryUMH.ackCount, "inventory ACK count")
			s.Equal(sc.expectedInventoryNACKs, inventoryUMH.nackCount, "inventory NACK count")
			s.Equal(sc.expectedIndexACKs, indexUMH.ackCount, "index ACK count")
			s.Equal(sc.expectedIndexNACKs, indexUMH.nackCount, "index NACK count")
		})
	}
}

// TestHandleComplianceACK_NilACK passes a nil ACK to handleComplianceACK and
// expects no panic and no ACK/NACK call on either handler.
func (s *ComplianceTestSuite) TestHandleComplianceACK_NilACK() {
	inventoryUMH := &fakeUMH{retryC: make(chan string)}
	indexUMH := &fakeUMH{retryC: make(chan string)}
	app := &Compliance{
		umhNodeInventory: inventoryUMH,
		umhNodeIndex:     indexUMH,
	}

	// A nil message must be tolerated and must leave every counter at zero.
	app.handleComplianceACK(nil)

	s.Equal(0, inventoryUMH.ackCount)
	s.Equal(0, inventoryUMH.nackCount)
	s.Equal(0, indexUMH.ackCount)
	s.Equal(0, indexUMH.nackCount)
}
14 changes: 9 additions & 5 deletions compliance/node/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/stackrox/rox/compliance/utils"
v4 "github.com/stackrox/rox/generated/internalapi/scanner/v4"
"github.com/stackrox/rox/generated/internalapi/sensor"
"github.com/stackrox/rox/pkg/concurrency"
)

// NodeNameProvider provides node name
Expand All @@ -29,10 +30,13 @@ type NodeIndexer interface {
GetIntervals() *utils.NodeScanIntervals
}

// UnconfirmedMessageHandler handles the observation of sending, and ACK/NACK messages
// UnconfirmedMessageHandler handles the observation of sending, and ACK/NACK messages.
// Each resource (identified by resourceID) has independent retry state.
type UnconfirmedMessageHandler interface {
HandleACK()
HandleNACK()
ObserveSending()
RetryCommand() <-chan struct{}
HandleACK(resourceID string)
HandleNACK(resourceID string)
ObserveSending(resourceID string)
RetryCommand() <-chan string // Returns resourceID to retry
OnACK(callback func(resourceID string))
Stopped() concurrency.ReadOnlyErrorSignal
}
Loading
Loading