Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,20 @@
import java.lang.reflect.Method;
import java.security.Principal;
import java.util.List;
import java.util.Map;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

public class ArtemisMBeanServerGuard implements GuardInvocationHandler {

private static final Logger logger = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private JMXAccessControlList jmxAccessControlList = JMXAccessControlList.createDefaultList();

private final Map<ObjectName, Boolean> bypassRBACCache = new ConcurrentHashMap<>();

public void init() {
ArtemisMBeanServerBuilder.setGuard(this);
}
Expand Down Expand Up @@ -122,18 +129,12 @@ private void handleSetAttributes(MBeanServer proxy, ObjectName objectName, Attri
}

private boolean canBypassRBAC(ObjectName objectName) {
return jmxAccessControlList.isInAllowList(objectName);
return bypassRBACCache.computeIfAbsent(objectName, name -> jmxAccessControlList.isInAllowList(name));
}

@Override
public boolean canInvoke(String object, String operationName) {
ObjectName objectName = null;
try {
objectName = ObjectName.getInstance(object);
} catch (MalformedObjectNameException e) {
logger.debug("can't check invoke rights as object name invalid: {}", object, e);
return false;
}

/*
* HawtIO calls this with a null operationName as a coarse grained way of authenticating against all the
* operations on an mbean. Until this addition this was throwing a null pointer on operationName later in this
Expand All @@ -142,7 +143,19 @@ public boolean canInvoke(String object, String operationName) {
* it. Since it is just an optimisation it is fine to always return true. Note that the alternative
* ArtemisRbacInvocationHandler does allow the ability to restrict a whole mbean.
*/
if (operationName == null || canBypassRBAC(objectName)) {
if (operationName == null) {
return true;
}

ObjectName objectName = null;
try {
objectName = ObjectName.getInstance(object);
} catch (MalformedObjectNameException e) {
logger.debug("can't check invoke rights as object name invalid: {}", object, e);
return false;
}

if (canBypassRBAC(objectName)) {
return true;
}

Expand All @@ -151,15 +164,22 @@ public boolean canInvoke(String object, String operationName) {
if (paramListIndex > 0) {
operationName = operationName.substring(0, paramListIndex);
}
Set<String> currentUserRoles = getCurrentUserRoles();

List<String> requiredRoles = getRequiredRoles(objectName, operationName);
for (String role : requiredRoles) {
if (currentUserHasRole(role)) {
return true;
}
if (currentUserRoles.isEmpty()) {
return false;
}
logger.debug("{} {} false", object, operationName);
return false;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This extra line can be removed.

boolean authorized = authorizeUserForMethod(objectName, operationName, currentUserRoles);

if (authorized) {
logger.debug("{} {} true", object, operationName);
return true;
} else {
logger.debug("{} {} false", object, operationName);
return false;
}

}

void handleInvoke(ObjectName objectName, String operationName) throws IOException {
Expand All @@ -182,6 +202,10 @@ List<String> getRequiredRoles(ObjectName objectName, String methodName) {
return jmxAccessControlList.getRolesForObject(objectName, methodName);
}

boolean authorizeUserForMethod(ObjectName objectName, String operationName, Set<String> currentUserRoles) {
return jmxAccessControlList.authorizeUserForMethod(objectName, operationName, currentUserRoles);
}

public void setJMXAccessControlList(JMXAccessControlList JMXAccessControlList) {
this.jmxAccessControlList = JMXAccessControlList;
}
Expand Down Expand Up @@ -210,4 +234,18 @@ public static boolean currentUserHasRole(String requestedRole) {
}
return false;
}

public static Set<String> getCurrentUserRoles() {
Subject subject = SecurityManagerShim.currentSubject();
if (subject == null) {
return Collections.emptySet();
}

Set<String> roles = new HashSet<>();
for (Principal p : subject.getPrincipals()) {
roles.add(p.getName());
}
return roles;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@

import javax.management.ObjectName;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand All @@ -31,6 +33,36 @@
public class JMXAccessControlList {
private static final String WILDCARD = "*";

private record AccessEntry(Access access, String rawPattern) { }
private record Bucket(
Map<String, AccessEntry> exactMatches,
List<AccessEntry> regexPatterns
) { }

private final Map<String, Map<String, String>> keyPropertyCache =
Collections.synchronizedMap(new LinkedHashMap<String, Map<String, String>>(128, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, Map<String, String>> eldest) {
return size() > 5000;
}
});

private final Map<String, TreeMap<String, Access>> domainCache =
Collections.synchronizedMap(new LinkedHashMap<String, TreeMap<String, Access>>(128, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, TreeMap<String, Access>> eldest) {
return size() > 5000;
}
});

private final Map<String, Map<String, Bucket>> bucketedDomainCache =
Collections.synchronizedMap(new LinkedHashMap<>(128, 0.75f, true) {
@Override
protected boolean removeEldestEntry(Map.Entry<String, Map<String, Bucket>> eldest) {
return size() > 1000;
}
});

private Access defaultAccess = new Access(WILDCARD);
private ConcurrentMap<String, TreeMap<String, Access>> domainAccess = new ConcurrentHashMap<>();
private ConcurrentMap<String, TreeMap<String, Access>> allowList = new ConcurrentHashMap<>();
Expand All @@ -48,6 +80,7 @@ public class JMXAccessControlList {
return key2.length() - key1.length();
};


public void addToAllowList(String domain, String key) {
TreeMap<String, Access> domainMap = new TreeMap<>(keyComparator);
domainMap = allowList.putIfAbsent(domain, domainMap);
Expand Down Expand Up @@ -81,6 +114,86 @@ public List<String> getRolesForObject(ObjectName objectName, String methodName)
return defaultAccess.getMatchingRolesForMethod(methodName);
}


public boolean authorizeUserForMethod(ObjectName objectName, String methodName, Set<String> userRoles) {

String domainKey = objectName.getDomain();

TreeMap<String, Access> domainMap = domainCache.computeIfAbsent(objectName.getDomain(), key ->
domainAccess.get(key)
);

Map<String, Bucket> bucketedMap = bucketedDomainCache.computeIfAbsent(domainKey, d -> {
TreeMap<String, Access> rawMap = domainAccess.get(d);
if (rawMap == null) {
return null;
}

Map<String, Bucket> grouped = new HashMap<>();
for (Access access : rawMap.values()) {
String rawPattern = access.getKeyPattern().pattern();
int eqIndex = rawPattern.indexOf('=');
String prefix = (eqIndex != -1) ? rawPattern.substring(0, eqIndex) : "";

// Initialize the Bucket (Map + List) instead of just an ArrayList
Bucket bucket = grouped.computeIfAbsent(prefix, k ->
new Bucket(new HashMap<>(), new ArrayList<>())
);

AccessEntry entry = new AccessEntry(access, rawPattern);

// Sort into Exact or Regex
if (rawPattern.contains("*") || rawPattern.contains("?") || rawPattern.contains("[")) {
bucket.regexPatterns().add(entry);
} else {
bucket.exactMatches().put(rawPattern, entry);
}
}
return grouped;
});

if (bucketedMap != null) {

String cacheKey = objectName.getCanonicalName();
Map<String, String> keyPropertyList = keyPropertyCache.get(cacheKey);
if (keyPropertyList == null) {
keyPropertyList = objectName.getKeyPropertyList();
keyPropertyCache.put(cacheKey, keyPropertyList);
}


for (Map.Entry<String, String> entry : keyPropertyList.entrySet()) {
String propKey = entry.getKey();
Bucket bucket = bucketedMap.get(propKey);

if (bucket != null) {
String normalizedValue = normalizeKey(propKey + "=" + entry.getValue());

// Priority 1: O(1) Exact Match Check
if (bucket.exactMatches().containsKey(normalizedValue)) {
return bucket.exactMatches().get(normalizedValue).access().authorizeUserForMethod(methodName, userRoles);
}

// Priority 2: O(N) Regex Match (but only for actual regexes)
for (AccessEntry regexEntry : bucket.regexPatterns()) {
if (regexEntry.access().getKeyPattern().matcher(normalizedValue).matches()) {
return regexEntry.access().authorizeUserForMethod(methodName, userRoles);
}
}
}
}

Access access = domainMap.get("");
if (access != null) {
return access.authorizeUserForMethod(methodName, userRoles);
}
}

return defaultAccess.authorizeUserForMethod(methodName, userRoles);
}



public boolean isInAllowList(ObjectName objectName) {
TreeMap<String, Access> domainMap = allowList.get(objectName.getDomain());

Expand Down Expand Up @@ -223,6 +336,20 @@ public List<String> getMatchingRolesForMethod(String methodName) {
}
return catchAllRoles;
}

public boolean authorizeUserForMethod(String methodName, Set<String> userRoles) {
List<String> roles = methodRoles.get(methodName);
if (roles != null) {
return !Collections.disjoint(roles, userRoles);

}
for (Map.Entry<String, List<String>> entry : methodPrefixRoles.entrySet()) {
if (methodName.startsWith(entry.getKey())) {
return !Collections.disjoint(entry.getValue(), userRoles);
}
}
return !Collections.disjoint(catchAllRoles, userRoles);
}
}

public static JMXAccessControlList createDefaultList() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void testCanInvokeMethodHasRole() throws Throwable {


@Test
public void testCanInvokeMethodDoeNotHasRole() throws Throwable {
public void testCanInvokeMethodDoesNotHaveRole() throws Throwable {
ArtemisMBeanServerGuard guard = new ArtemisMBeanServerGuard();
JMXAccessControlList controlList = new JMXAccessControlList();
guard.setJMXAccessControlList(controlList);
Expand Down