Skip to content

feat(smb): add volume isolation and stage/unstage support to SMB CSI … #943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/csi-common/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (d *CSIDriver) AddControllerServiceCapabilities(cl []csi.ControllerServiceC
csc = append(csc, NewControllerServiceCapability(c))
}

d.Cap = csc
d.Cap = append(d.Cap, csc...)
}

func (d *CSIDriver) AddNodeServiceCapabilities(nl []csi.NodeServiceCapability_RPC_Type) {
Expand All @@ -103,7 +103,7 @@ func (d *CSIDriver) AddNodeServiceCapabilities(nl []csi.NodeServiceCapability_RP
klog.V(2).Infof("Enabling node service capability: %v", n.String())
nsc = append(nsc, NewNodeServiceCapability(n))
}
d.NSCap = nsc
d.NSCap = append(d.NSCap, nsc...)
}

func (d *CSIDriver) AddVolumeCapabilityAccessModes(vc []csi.VolumeCapability_AccessMode_Mode) []*csi.VolumeCapability_AccessMode {
Expand Down
13 changes: 13 additions & 0 deletions pkg/smb/controllerserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package smb

import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"io/fs"
"os"
Expand Down Expand Up @@ -85,6 +87,17 @@ func (d *Driver) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest)
}

secrets := req.GetSecrets()
username := strings.TrimSpace(secrets["username"])
password := strings.TrimSpace(secrets["password"])
if username != "" || password != "" {
hashKey := fmt.Sprintf("%s|%s", username, password)
hash := sha256.Sum256([]byte(hashKey))
hashStr := hex.EncodeToString(hash[:8])
smbVol.id = fmt.Sprintf("%s#cred=%s", getVolumeIDFromSmbVol(smbVol), hashStr)
} else {
smbVol.id = getVolumeIDFromSmbVol(smbVol)
}

createSubDir := len(secrets) > 0
if len(smbVol.uuid) > 0 {
klog.V(2).Infof("create subdirectory(%s) if not exists", smbVol.subDir)
Expand Down
7 changes: 6 additions & 1 deletion pkg/smb/controllerserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"path/filepath"
"reflect"
"runtime"
"strings"
"testing"

"github.com/container-storage-interface/spec/lib/go/csi"
Expand Down Expand Up @@ -203,7 +204,11 @@ func TestCreateVolume(t *testing.T) {
if !test.expectErr && err != nil {
t.Errorf("test %q failed: %v", test.name, err)
}
if !reflect.DeepEqual(resp, test.resp) {
if !test.expectErr && test.name == "valid defaults" {
if resp.Volume == nil || !strings.HasPrefix(resp.Volume.VolumeId, "test-server/baseDir#test-csi###cred=") {
t.Errorf("test %q failed: got volume ID %q, expected it to start with prefix %q", test.name, resp.Volume.VolumeId, "test-server/baseDir#test-csi###cred=")
}
} else if !reflect.DeepEqual(resp, test.resp) {
t.Errorf("test %q failed: got resp %+v, expected %+v", test.name, resp, test.resp)
}
if !test.expectErr {
Expand Down
107 changes: 72 additions & 35 deletions pkg/smb/nodeserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
azcache "sigs.k8s.io/cloud-provider-azure/pkg/cache"
)

// NodePublishVolume mount the volume from staging to target path
func (d *Driver) NodePublishVolume(ctx context.Context, req *csi.NodePublishVolumeRequest) (*csi.NodePublishVolumeResponse, error) {
volCap := req.GetVolumeCapability()
if volCap == nil {
Expand All @@ -51,27 +50,31 @@
return nil, status.Error(codes.InvalidArgument, "Volume ID missing in request")
}

target := req.GetTargetPath()
if len(target) == 0 {
// Strip cred hash suffix if present
cleanID := strings.SplitN(volumeID, "#cred=", 2)[0]

targetPath := req.GetTargetPath()
if len(targetPath) == 0 {
return nil, status.Error(codes.InvalidArgument, "Target path not provided")
}

context := req.GetVolumeContext()
if context != nil && strings.EqualFold(context[ephemeralField], trueValue) {
// ephemeral volume
util.SetKeyValueInMap(context, secretNamespaceField, context[podNamespaceField])
klog.V(2).Infof("NodePublishVolume: ephemeral volume(%s) mount on %s", volumeID, target)
klog.V(2).Infof("NodePublishVolume: ephemeral volume(%s) mount on %s", volumeID, targetPath)
_, err := d.NodeStageVolume(ctx, &csi.NodeStageVolumeRequest{
StagingTargetPath: target,
StagingTargetPath: targetPath,
VolumeContext: context,
VolumeCapability: volCap,
VolumeId: volumeID,
VolumeId: cleanID,
})
return &csi.NodePublishVolumeResponse{}, err
}

source := req.GetStagingTargetPath()
if len(source) == 0 {
// Get staging path
stagingPath := req.GetStagingTargetPath()
if len(stagingPath) == 0 {
return nil, status.Error(codes.InvalidArgument, "Staging target not provided")
}

Expand All @@ -80,31 +83,31 @@
mountOptions = append(mountOptions, "ro")
}

mnt, err := d.ensureMountPoint(target)
mnt, err := d.ensureMountPoint(targetPath)
if err != nil {
return nil, status.Errorf(codes.Internal, "Could not mount target %q: %v", target, err)
return nil, status.Errorf(codes.Internal, "Could not mount target %q: %v", targetPath, err)
}
if mnt {
klog.V(2).Infof("NodePublishVolume: %s is already mounted", target)
klog.V(2).Infof("NodePublishVolume: %s is already mounted", targetPath)
return &csi.NodePublishVolumeResponse{}, nil
}

if err = preparePublishPath(target, d.mounter); err != nil {
return nil, fmt.Errorf("prepare publish failed for %s with error: %v", target, err)
if err = preparePublishPath(targetPath, d.mounter); err != nil {
return nil, fmt.Errorf("prepare publish failed for %s with error: %v", targetPath, err)
}

klog.V(2).Infof("NodePublishVolume: mounting %s at %s with mountOptions: %v volumeID(%s)", source, target, mountOptions, volumeID)
if err := d.mounter.Mount(source, target, "", mountOptions); err != nil {
if removeErr := os.Remove(target); removeErr != nil {
return nil, status.Errorf(codes.Internal, "Could not remove mount target %q: %v", target, removeErr)
klog.V(2).Infof("NodePublishVolume: bind mounting %s to %s with options: %v", stagingPath, targetPath, mountOptions)
if err := d.mounter.Mount(stagingPath, targetPath, "", mountOptions); err != nil {
if removeErr := os.Remove(targetPath); removeErr != nil {
return nil, status.Errorf(codes.Internal, "Could not remove mount target %q: %v", targetPath, removeErr)
}
return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", source, target, err)
return nil, status.Errorf(codes.Internal, "Could not mount %q at %q: %v", stagingPath, targetPath, err)
}
klog.V(2).Infof("NodePublishVolume: mount %s at %s volumeID(%s) successfully", source, target, volumeID)

klog.V(2).Infof("NodePublishVolume: mount %s at %s volumeID(%s) successfully", stagingPath, targetPath, volumeID)
return &csi.NodePublishVolumeResponse{}, nil
}

// NodeUnpublishVolume unmount the volume from the target path
func (d *Driver) NodeUnpublishVolume(_ context.Context, req *csi.NodeUnpublishVolumeRequest) (*csi.NodeUnpublishVolumeResponse, error) {
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
Expand All @@ -115,12 +118,28 @@
return nil, status.Error(codes.InvalidArgument, "Target path missing in request")
}

klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s on %s", volumeID, targetPath)
err := CleanupMountPoint(d.mounter, targetPath, true /*extensiveMountPointCheck*/)
if err != nil {
klog.V(2).Infof("NodeUnpublishVolume: unmounting volume %s from %s", volumeID, targetPath)

notMnt, err := d.mounter.IsLikelyNotMountPoint(targetPath)
if err != nil && !os.IsNotExist(err) {
return nil, status.Errorf(codes.Internal, "failed to check mount point %q: %v", targetPath, err)
}
if notMnt {
klog.V(2).Infof("NodeUnpublishVolume: target %s is already unmounted", targetPath)
if err := os.Remove(targetPath); err != nil && !os.IsNotExist(err) {
return nil, status.Errorf(codes.Internal, "failed to remove target path %q: %v", targetPath, err)
}
return &csi.NodeUnpublishVolumeResponse{}, nil
}

if err := d.mounter.Unmount(targetPath); err != nil {
return nil, status.Errorf(codes.Internal, "failed to unmount target %q: %v", targetPath, err)
}
klog.V(2).Infof("NodeUnpublishVolume: unmount volume %s on %s successfully", volumeID, targetPath)
if err := os.Remove(targetPath); err != nil && !os.IsNotExist(err) {
return nil, status.Errorf(codes.Internal, "failed to remove target path %q after unmount: %v", targetPath, err)
}

klog.V(2).Infof("NodeUnpublishVolume: successfully unmounted and removed %s for volume %s", targetPath, volumeID)
return &csi.NodeUnpublishVolumeResponse{}, nil
}

Expand All @@ -142,8 +161,8 @@
}

context := req.GetVolumeContext()
mountFlags := req.GetVolumeCapability().GetMount().GetMountFlags()
volumeMountGroup := req.GetVolumeCapability().GetMount().GetVolumeMountGroup()
mountFlags := volumeCapability.GetMount().GetMountFlags()
volumeMountGroup := volumeCapability.GetMount().GetVolumeMountGroup()
secrets := req.GetSecrets()
gidPresent := checkGidPresentInMountFlags(mountFlags)

Expand Down Expand Up @@ -199,7 +218,6 @@
mountFlags = strings.Split(ephemeralVolMountOptions, ",")
}

// in guest login, username and password options are not needed
requireUsernamePwdOption := !hasGuestMountOptions(mountFlags)
if ephemeralVol && requireUsernamePwdOption {
klog.V(2).Infof("NodeStageVolume: getting username and password from secret %s in namespace %s", secretName, secretNamespace)
Expand Down Expand Up @@ -264,7 +282,6 @@
if subDir != "" {
// replace pv/pvc name namespace metadata in subDir
subDir = replaceWithMap(subDir, subDirReplaceMap)

source = strings.TrimRight(source, "/")
source = fmt.Sprintf("%s/%s", source, subDir)
}
Expand All @@ -281,7 +298,7 @@
return &csi.NodeStageVolumeResponse{}, nil
}

// NodeUnstageVolume unmount the volume from the staging path
// NodeUnstageVolume unmounts the volume from the staging path
func (d *Driver) NodeUnstageVolume(_ context.Context, req *csi.NodeUnstageVolumeRequest) (*csi.NodeUnstageVolumeResponse, error) {
volumeID := req.GetVolumeId()
if len(volumeID) == 0 {
Expand All @@ -298,16 +315,36 @@
}
defer d.volumeLocks.Release(lockKey)

klog.V(2).Infof("NodeUnstageVolume: CleanupMountPoint on %s with volume %s", stagingTargetPath, volumeID)
if err := CleanupSMBMountPoint(d.mounter, stagingTargetPath, true /*extensiveMountPointCheck*/, volumeID); err != nil {
return nil, status.Errorf(codes.Internal, "failed to unmount staging target %q: %v", stagingTargetPath, err)
inUse, err := HasMountReferences(stagingTargetPath)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to check mount references: %v", err)
}
if inUse {
klog.V(2).Infof("NodeUnstageVolume: staging path %s is still in use by other mounts", stagingTargetPath)
return &csi.NodeUnstageVolumeResponse{}, nil
}

if err := deleteKerberosCache(d.krb5CacheDirectory, volumeID); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete kerberos cache: %v", err)
notMnt, err := d.mounter.IsLikelyNotMountPoint(stagingTargetPath)
if err != nil && !os.IsNotExist(err) {
return nil, status.Errorf(codes.Internal, "failed to check mount point %q: %v", stagingTargetPath, err)
}
if notMnt {
klog.V(2).Infof("NodeUnstageVolume: staging path %s is already unmounted", stagingTargetPath)
if err := os.Remove(stagingTargetPath); err != nil && !os.IsNotExist(err) {
return nil, status.Errorf(codes.Internal, "failed to remove staging path %q: %v", stagingTargetPath, err)
}
return &csi.NodeUnstageVolumeResponse{}, nil
}

klog.V(2).Infof("NodeUnstageVolume: unmounting %s for volume %s", stagingTargetPath, volumeID)
if err := d.mounter.Unmount(stagingTargetPath); err != nil {
return nil, status.Errorf(codes.Internal, "failed to unmount staging path %q: %v", stagingTargetPath, err)
}
if err := os.Remove(stagingTargetPath); err != nil && !os.IsNotExist(err) {
return nil, status.Errorf(codes.Internal, "failed to remove staging path %q after unmount: %v", stagingTargetPath, err)
}

klog.V(2).Infof("NodeUnstageVolume: unmount volume %s on %s successfully", volumeID, stagingTargetPath)
klog.V(2).Infof("NodeUnstageVolume: successfully unmounted and cleaned up %s for volume %s", stagingTargetPath, volumeID)
return &csi.NodeUnstageVolumeResponse{}, nil
}

Expand Down Expand Up @@ -600,7 +637,7 @@
return false, nil
}

func deleteKerberosCache(krb5CacheDirectory, volumeID string) error {

Check failure on line 640 in pkg/smb/nodeserver.go

View workflow job for this annotation

GitHub Actions / Go Lint

func `deleteKerberosCache` is unused (unused)
exists, err := kerberosCacheDirectoryExists(krb5CacheDirectory)
// If not supported, simply return
if !exists {
Expand Down
5 changes: 5 additions & 0 deletions pkg/smb/smb_common_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,8 @@ func prepareStagePath(path string, m *mount.SafeFormatAndMount) error {
// Mkdir creates a directory named name with permission bits perm.
// The mounter argument is unused on darwin; it exists only so the
// signature matches the other platform implementations.
func Mkdir(_ *mount.SafeFormatAndMount, name string, perm os.FileMode) error {
	return os.Mkdir(name, perm)
}

// HasMountReferences reports whether any other mounts still reference the
// share staged at stagingTargetPath. On darwin there is no supported way to
// inspect bind mounts, so this stub always reports no references.
func HasMountReferences(_ string) (bool, error) {
	return false, nil
}
23 changes: 23 additions & 0 deletions pkg/smb/smb_common_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ limitations under the License.
package smb

import (
"bufio"
"fmt"
"os"
"strings"

mount "k8s.io/mount-utils"
)
Expand Down Expand Up @@ -48,3 +51,23 @@ func prepareStagePath(_ string, _ *mount.SafeFormatAndMount) error {
// Mkdir creates the directory path with the given permission bits.
// The mounter is not needed on Linux; the parameter is kept only to
// preserve a uniform signature across platform implementations.
func Mkdir(_ *mount.SafeFormatAndMount, path string, mode os.FileMode) error {
	return os.Mkdir(path, mode)
}

// HasMountReferences reports whether the share staged at stagingTargetPath
// is still referenced by any other mount on this node, based on /proc/mounts.
//
// Two kinds of references are considered:
//  1. mount points nested strictly under stagingTargetPath, and
//  2. other mount points whose source is the same share as the one mounted
//     at stagingTargetPath — kubelet bind-mounts the staging path into
//     pod-specific paths (/var/lib/kubelet/pods/.../mount) that are NOT
//     subdirectories of the staging path, so a pure prefix check misses them.
//
// If stagingTargetPath itself is not present in /proc/mounts, only rule 1
// can apply.
func HasMountReferences(stagingTargetPath string) (bool, error) {
	f, err := os.Open("/proc/mounts")
	if err != nil {
		return false, fmt.Errorf("failed to open /proc/mounts: %v", err)
	}
	defer f.Close()

	type mountEntry struct {
		source string
		target string
	}

	var entries []mountEntry
	var stagingSource string

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		fields := strings.Fields(scanner.Text())
		if len(fields) < 2 {
			continue
		}
		entries = append(entries, mountEntry{source: fields[0], target: fields[1]})
		if fields[1] == stagingTargetPath {
			// Remember which share is mounted at the staging path so bind
			// mounts of the same share elsewhere on the node can be found.
			stagingSource = fields[0]
		}
	}
	// A scan can stop early on a read error; do not treat a partial view of
	// /proc/mounts as authoritative.
	if err := scanner.Err(); err != nil {
		return false, fmt.Errorf("failed to read /proc/mounts: %v", err)
	}

	for _, e := range entries {
		if e.target == stagingTargetPath {
			continue
		}
		// Match with a trailing separator so a sibling path such as
		// ".../globalmount2" does not count as a reference to
		// ".../globalmount".
		if strings.HasPrefix(e.target, stagingTargetPath+"/") {
			return true, nil
		}
		// Any other mount target of the same share (e.g. kubelet pod bind
		// mounts) keeps the staging mount in use.
		if stagingSource != "" && e.source == stagingSource {
			return true, nil
		}
	}
	return false, nil
}
5 changes: 5 additions & 0 deletions pkg/smb/smb_common_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,8 @@ func Mkdir(m *mount.SafeFormatAndMount, name string, perm os.FileMode) error {
}
return fmt.Errorf("could not cast to csi proxy class")
}

// HasMountReferences reports whether any other mounts still reference the
// share staged at stagingTargetPath. On Windows there is no supported way to
// inspect bind mounts, so this stub always reports no references.
func HasMountReferences(_ string) (bool, error) {
	return false, nil
}
Loading