diff --git a/pkg/csi/blockstorage/controllerserver.go b/pkg/csi/blockstorage/controllerserver.go index 8de6237e..ef962edd 100644 --- a/pkg/csi/blockstorage/controllerserver.go +++ b/pkg/csi/blockstorage/controllerserver.go @@ -370,6 +370,10 @@ func (cs *controllerServer) ControllerPublishVolume(ctx context.Context, req *cs _, err = cloud.AttachVolume(ctx, instanceID, volumeID) if err != nil { + // Returning ResourceExhausted triggers an immediate `NodeGetInfo` RPC call when MutableCSINodeAllocatableCount is enabled. + if stackiterrors.IsTooManyDevicesError(err) { + return nil, status.Errorf(codes.ResourceExhausted, "[ControllerPublishVolume] Node can't accept any more volumes %v. All PCIe lanes are exhausted!", err) + } klog.Errorf("Failed to AttachVolume: %v", err) return nil, status.Errorf(codes.Internal, "[ControllerPublishVolume] Attach Volume failed with error %v", err) } diff --git a/pkg/csi/blockstorage/nodeserver.go b/pkg/csi/blockstorage/nodeserver.go index 648e5df3..d7c00a63 100644 --- a/pkg/csi/blockstorage/nodeserver.go +++ b/pkg/csi/blockstorage/nodeserver.go @@ -302,19 +302,9 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve instance id of node %v", err) } - flavor, err := ns.Metadata.GetFlavor(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "[NodeGetInfo] unable to retrieve flavor of node %v", err) - } - - maxVolumesPerNode := DetermineMaxVolumesByFlavor(flavor) - // Subtract 1 for root disk and another for configDrive/spare - maxVolumesPerNode -= 2 - klog.V(4).Infof("Determined node to support %d volumes", maxVolumesPerNode) - nodeInfo := &csi.NodeGetInfoResponse{ NodeId: nodeID, - MaxVolumesPerNode: maxVolumesPerNode, + MaxVolumesPerNode: ns.calculateMaxVolumesPerNode(), } zone, err := ns.Metadata.GetAvailabilityZone(ctx) @@ -332,6 +322,22 @@ func (ns *nodeServer) NodeGetInfo(ctx context.Context, _ *csi.NodeGetInfoRequest return nodeInfo, nil } 
+func (ns *nodeServer) calculateMaxVolumesPerNode() int64 { + freePCIeRootPorts, err := mount.CountFreePCIeSlots() + if err != nil { + klog.Errorf("[NodeGetInfo] unable to retrieve PCIe root ports: %v", err) + freePCIeRootPorts = 0 + } + + mountedCSIVolumes, err := mount.CountLocalCSIVolumes(driverName) + if err != nil { + klog.Errorf("[NodeGetInfo] unable to retrieve volume count: %v", err) + mountedCSIVolumes = 0 + } + + return freePCIeRootPorts + mountedCSIVolumes +} + func (ns *nodeServer) NodeGetCapabilities(_ context.Context, req *csi.NodeGetCapabilitiesRequest) (*csi.NodeGetCapabilitiesResponse, error) { klog.V(5).Infof("NodeGetCapabilities called with req: %#v", req) diff --git a/pkg/csi/blockstorage/utils.go b/pkg/csi/blockstorage/utils.go index aaafc864..8fc02143 100644 --- a/pkg/csi/blockstorage/utils.go +++ b/pkg/csi/blockstorage/utils.go @@ -84,23 +84,6 @@ func ParseEndpoint(ep string) (proto, addr string, err error) { return "", "", fmt.Errorf("invalid endpoint: %v", ep) } -func DetermineMaxVolumesByFlavor(flavor string) int64 { - flavorParts := strings.Split(flavor, ".") - - // The following numbers were specified by the IaaS team. They are based on actual tests. - switch { - case strings.HasPrefix(flavor, "n"): - // Flavors starting with 'n' are nvidia GPU flavors, all GPU VM's can only mount 10 volumes - return 10 - case strings.HasSuffix(flavorParts[0], "2a"): - // AMD 2nd Gen - return 159 - default: - // All other flavors can mount 28 volumes - return 25 - } -} - func logGRPC(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { callID := serverGRPCEndpointCallCounter.Add(1) diff --git a/pkg/csi/blockstorage/utils_test.go b/pkg/csi/blockstorage/utils_test.go deleted file mode 100644 index f9261de4..00000000 --- a/pkg/csi/blockstorage/utils_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package blockstorage - -import ( - . "github.com/onsi/ginkgo/v2" - . 
"github.com/onsi/gomega" -) - -var _ = Describe("Util Test", func() { - - Context("DetermineMaxVolumesByFlavor", func() { - DescribeTable("should return the correct maximum volume count for different flavors", func(flavor string, expectedMaxVolumes int) { - maxVolumes := DetermineMaxVolumesByFlavor(flavor) - Expect(maxVolumes).To(Equal(int64(expectedMaxVolumes))) - }, - Entry("Intel 3rd Gen", "c3i.2", 25), - Entry("Intel 2rd Gen", "c2i.2", 25), - Entry("Intel 1st Gen", "c1.2", 25), - Entry("AMD 1st Gen without overprovisioning", "s1a.8d", 25), - Entry("AMD 2nd Gen without overprovisioning", "s2a.8d", 159), - Entry("Nvidia GPU", "n2.14d.g1", 10), - Entry("Nvidia GPU", "n2.56d.g4", 10), - Entry("ARM Gen1Link without CPU-overprovisioning ARM Gen1", "g1r.4d", 25), - ) - }) -}) diff --git a/pkg/csi/util/mount/mount_darwin.go b/pkg/csi/util/mount/mount_darwin.go index 122f4c1c..07dcba11 100644 --- a/pkg/csi/util/mount/mount_darwin.go +++ b/pkg/csi/util/mount/mount_darwin.go @@ -17,3 +17,12 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { UsedInodes: int64(statfs.Files) - int64(statfs.Ffree), } } + +func CountLocalCSIVolumes(_ string) (int64, error) { + // not implemented + return 0, nil +} + +func CountFreePCIeSlots() (int64, error) { + return 0, nil +} diff --git a/pkg/csi/util/mount/mount_helper.go b/pkg/csi/util/mount/mount_helper.go new file mode 100644 index 00000000..9c8264c8 --- /dev/null +++ b/pkg/csi/util/mount/mount_helper.go @@ -0,0 +1,60 @@ +package mount + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "k8s.io/klog/v2" +) + +const ( + // pciClassBridgePCI matches the Linux PCI-to-PCI bridge class prefix. 
+ pciClassBridgePCI = "0x0604" + globalMountDir = "globalmount" +) + +func countFreePCIeSlotsAt(devicesPath string) (int64, error) { + devices, err := os.ReadDir(devicesPath) + if err != nil { + return 0, fmt.Errorf("failed to read PCI bus: %w", err) + } + + var freePCIeSlots int64 + + for _, dev := range devices { + devPath := filepath.Join(devicesPath, dev.Name()) + + classBuf, err := os.ReadFile(filepath.Join(devPath, "class")) + if err != nil { + klog.Errorf("failed to read PCI device class %s: %v", devPath, err) + continue + } + + class := strings.TrimSpace(string(classBuf)) + if !strings.HasPrefix(class, pciClassBridgePCI) { + continue + } + + children, err := filepath.Glob(filepath.Join(devPath, "????:??:??.?")) + if err != nil { + return 0, fmt.Errorf("failed to glob PCI children for %s: %w", devPath, err) + } + + if len(children) == 0 { + freePCIeSlots++ + } + } + + return freePCIeSlots, nil +} + +func countLocalCSIVolumesAt(driverPluginDir string) (int64, error) { + volumeMounts, err := filepath.Glob(filepath.Join(driverPluginDir, "*", globalMountDir)) + if err != nil { + return 0, fmt.Errorf("failed to glob CSI volume mounts in %s: %w", driverPluginDir, err) + } + + return int64(len(volumeMounts)), nil +} diff --git a/pkg/csi/util/mount/mount_helper_test.go b/pkg/csi/util/mount/mount_helper_test.go new file mode 100644 index 00000000..519b440d --- /dev/null +++ b/pkg/csi/util/mount/mount_helper_test.go @@ -0,0 +1,162 @@ +package mount + +import ( + "os" + "path/filepath" + "testing" +) + +func TestCountFreePCIeSlotsAtMissingRoot(t *testing.T) { + t.Parallel() + + _, err := countFreePCIeSlotsAt(filepath.Join(t.TempDir(), "missing")) + if err == nil { + t.Fatal("countFreePCIeSlotsAt() error = nil, want error") + } +} + +func TestCountFreePCIeSlotsAtCountsOnlyFreeBridgeSlots(t *testing.T) { + t.Parallel() + + devicesPath := t.TempDir() + + createPCIDevice(t, devicesPath, "0000:00:00.0", "0x060400") + createPCIDevice(t, devicesPath, "0000:00:01.0", 
"0x060400", "0000:01:00.0") + createPCIDevice(t, devicesPath, "0000:00:02.0", "0x010000", "0000:02:00.0") + + count, err := countFreePCIeSlotsAt(devicesPath) + if err != nil { + t.Fatalf("countFreePCIeSlotsAt() error = %v", err) + } + + if count != 1 { + t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) + } +} + +func TestCountFreePCIeSlotsAtSkipsDevicesWithoutClass(t *testing.T) { + t.Parallel() + + devicesPath := t.TempDir() + + createPCIDevice(t, devicesPath, "0000:00:00.0", "0x060400") + mustMkdirAll(t, filepath.Join(devicesPath, "0000:00:01.0")) + + count, err := countFreePCIeSlotsAt(devicesPath) + if err != nil { + t.Fatalf("countFreePCIeSlotsAt() error = %v", err) + } + + if count != 1 { + t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) + } +} + +func TestCountFreePCIeSlotsAtIgnoresNonPCIChildren(t *testing.T) { + t.Parallel() + + devicesPath := t.TempDir() + devPath := filepath.Join(devicesPath, "0000:00:00.0") + mustMkdirAll(t, devPath) + mustWriteFile(t, filepath.Join(devPath, "class"), "0x060400") + mustMkdirAll(t, filepath.Join(devPath, "driver")) + mustMkdirAll(t, filepath.Join(devPath, "not-a-pci-child")) + + count, err := countFreePCIeSlotsAt(devicesPath) + if err != nil { + t.Fatalf("countFreePCIeSlotsAt() error = %v", err) + } + + if count != 1 { + t.Fatalf("countFreePCIeSlotsAt() = %d, want 1", count) + } +} + +func TestCountLocalCSIVolumesAtMissingDir(t *testing.T) { + t.Parallel() + + count, err := countLocalCSIVolumesAt(filepath.Join(t.TempDir(), "missing")) + if err != nil { + t.Fatalf("countLocalCSIVolumesAt() error = %v", err) + } + + if count != 0 { + t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) + } +} + +func TestCountLocalCSIVolumesAtCountsOnlyGlobalMountDirs(t *testing.T) { + t.Parallel() + + driverPluginDir := t.TempDir() + + mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-a", globalMountDir)) + mustMkdirAll(t, filepath.Join(driverPluginDir, "volume-b", globalMountDir)) + mustMkdirAll(t, 
filepath.Join(driverPluginDir, "volume-c", "not-a-globalmount")) + + count, err := countLocalCSIVolumesAt(driverPluginDir) + if err != nil { + t.Fatalf("countLocalCSIVolumesAt() error = %v", err) + } + + if count != 2 { + t.Fatalf("countLocalCSIVolumesAt() = %d, want 2", count) + } +} + +func TestCountLocalCSIVolumesAtEmptyDir(t *testing.T) { + t.Parallel() + + count, err := countLocalCSIVolumesAt(t.TempDir()) + if err != nil { + t.Fatalf("countLocalCSIVolumesAt() error = %v", err) + } + + if count != 0 { + t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) + } +} + +func TestCountLocalCSIVolumesAtReturnsZeroWhenDriverPathIsFile(t *testing.T) { + t.Parallel() + + driverPluginDir := filepath.Join(t.TempDir(), "driver") + mustWriteFile(t, driverPluginDir, "not a directory") + + count, err := countLocalCSIVolumesAt(driverPluginDir) + if err != nil { + t.Fatalf("countLocalCSIVolumesAt() error = %v", err) + } + + if count != 0 { + t.Fatalf("countLocalCSIVolumesAt() = %d, want 0", count) + } +} + +func createPCIDevice(t *testing.T, rootPath, deviceName, class string, children ...string) { + t.Helper() + + devPath := filepath.Join(rootPath, deviceName) + mustMkdirAll(t, devPath) + mustWriteFile(t, filepath.Join(devPath, "class"), class) + + for _, child := range children { + mustMkdirAll(t, filepath.Join(devPath, child)) + } +} + +func mustMkdirAll(t *testing.T, path string) { + t.Helper() + + if err := os.MkdirAll(path, 0o755); err != nil { + t.Fatalf("MkdirAll(%q) error = %v", path, err) + } +} + +func mustWriteFile(t *testing.T, path string, content string) { + t.Helper() + + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatalf("WriteFile(%q) error = %v", path, err) + } +} diff --git a/pkg/csi/util/mount/mount_linux.go b/pkg/csi/util/mount/mount_linux.go index b525b753..4bdfe55b 100644 --- a/pkg/csi/util/mount/mount_linux.go +++ b/pkg/csi/util/mount/mount_linux.go @@ -2,7 +2,16 @@ package mount -import "golang.org/x/sys/unix" +import ( + 
"path/filepath" + + "golang.org/x/sys/unix" +) + +const ( + pciDevicesPath = "/sys/bus/pci/devices" + kubeletDir = "/var/lib/kubelet" +) func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { return &DeviceStats{ @@ -17,3 +26,14 @@ func newDeviceStats(statfs *unix.Statfs_t) *DeviceStats { UsedInodes: int64(statfs.Files) - int64(statfs.Ffree), } } + +// CountFreePCIeSlots returns the number of PCI bridge devices (class prefix 0x0604) under /sys/bus/pci/devices that have no child PCI device entry. +func CountFreePCIeSlots() (int64, error) { + return countFreePCIeSlotsAt(pciDevicesPath) +} + +// CountLocalCSIVolumes counts the globalmount entries found one level below the kubelet CSI plugin directory of the given driver. +func CountLocalCSIVolumes(driverName string) (int64, error) { + driverPluginDir := filepath.Join(kubeletDir, "plugins", "kubernetes.io", "csi", driverName) + return countLocalCSIVolumesAt(driverPluginDir) +} diff --git a/pkg/stackit/stackiterrors/errors.go b/pkg/stackit/stackiterrors/errors.go index ae19b7d7..1b1f127a 100644 --- a/pkg/stackit/stackiterrors/errors.go +++ b/pkg/stackit/stackiterrors/errors.go @@ -4,22 +4,35 @@ import ( "errors" "fmt" "net/http" + "strings" oapiError "github.com/stackitcloud/stackit-sdk-go/core/oapierror" - wait "github.com/stackitcloud/stackit-sdk-go/services/iaas/v2api/wait" + "github.com/stackitcloud/stackit-sdk-go/services/iaas/v2api/wait" ) +const tooManyDiskDevicesMessageFragment = "maximum allowed number of disk devices" + var ErrNotFound = errors.New("failed to find object") func IsNotFound(err error) bool { - var oAPIError *oapiError.GenericOpenAPIError - if ok := errors.As(err, &oAPIError); !ok { + oAPIError, ok := genericOpenAPIError(err) + if !ok { return false } return oAPIError.StatusCode == http.StatusNotFound } +func IsTooManyDevicesError(err error) bool { + oAPIError, ok := genericOpenAPIError(err) + if !ok { + return false + } + + return oAPIError.StatusCode == http.StatusForbidden && + strings.Contains(string(oAPIError.Body), tooManyDiskDevicesMessageFragment) +} + func IgnoreNotFound(err error) error { if 
IsNotFound(err) { return nil @@ -40,10 +53,19 @@ func WrapErrorWithResponseID(err error, reqID string) error { } func IsInvalidError(err error) bool { - var oAPIError *oapiError.GenericOpenAPIError - if ok := errors.As(err, &oAPIError); !ok { + oAPIError, ok := genericOpenAPIError(err) + if !ok { return false } return oAPIError.StatusCode == http.StatusBadRequest } + +func genericOpenAPIError(err error) (*oapiError.GenericOpenAPIError, bool) { + var oAPIError *oapiError.GenericOpenAPIError + if ok := errors.As(err, &oAPIError); !ok { + return nil, false + } + + return oAPIError, true +}