package cn.nosum.common.extension; import cn.nosum.common.annotation.Adaptive; import cn.nosum.common.annotation.DisableInject; import cn.nosum.common.annotation.SPI; import cn.nosum.common.util.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.BufferedReader; import java.io.InputStreamReader; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.net.URL; import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.regex.Pattern; /** * Load gateway ExtensionLoader * @param */ public class ExtensionLoader { private static final Logger logger= LoggerFactory.getLogger(ExtensionLoader.class); // 需要扫描的路径 private static final String GATEWAY_DIRECTORY = "META-INF/gateway/"; private static final Pattern NAME_SEPARATOR = Pattern.compile("\\s*[,]+\\s*"); // 缓存 private static final ConcurrentMap, ExtensionLoader> EXTENSION_LOADERS = new ConcurrentHashMap<>(); private static final ConcurrentMap, Object> EXTENSION_INSTANCES = new ConcurrentHashMap<>(); private final Holder>> cachedClasses = new Holder<>(); private final Holder cachedAdaptiveInstance = new Holder<>(); private final ConcurrentMap> cachedInstances = new ConcurrentHashMap<>(); private Set> cachedWrapperClasses; private final ExtensionFactory objectFactory; private volatile Class cachedAdaptiveClass = null; private final Class type; private String cachedDefaultName; // 记录错误信息 private Map exceptions = new ConcurrentHashMap<>(); private volatile Throwable createAdaptiveInstanceError; private ExtensionLoader(Class type) { this.type = type; objectFactory = (type == ExtensionFactory.class ? null : ExtensionLoader.getExtensionLoader(ExtensionFactory.class).getAdaptiveExtension()); } @SuppressWarnings("unchecked") public static ExtensionLoader getExtensionLoader(Class type) { if (type == null) { throw new IllegalArgumentException("Extension type == null"); } if (!type.isInterface()) { throw new IllegalArgumentException("Extension type (" + type + ") is not an interface!"); } if (!withExtensionAnnotation(type)) { throw new IllegalArgumentException("Extension type (" + type + ") is not an extension, because it is NOT annotated with @" + SPI.class.getSimpleName() + "!"); } ExtensionLoader loader = (ExtensionLoader) EXTENSION_LOADERS.get(type); if (loader == null) { EXTENSION_LOADERS.putIfAbsent(type, new ExtensionLoader(type)); loader = (ExtensionLoader) EXTENSION_LOADERS.get(type); } return loader; } @SuppressWarnings("unchecked") public T getExtension(String name) { if (StringUtils.isEmpty(name)) { throw new IllegalArgumentException("Extension name == null"); } if ("true".equals(name)) { return getDefaultExtension(); } final Holder holder = getOrCreateHolder(name); Object instance = holder.get(); if (instance == null) { synchronized (holder) { instance = holder.get(); if (instance == null) { instance = createExtension(name); holder.set(instance); } } } return (T) instance; } private Holder getOrCreateHolder(String name) { Holder holder = cachedInstances.get(name); if (holder == null) { cachedInstances.putIfAbsent(name, new Holder<>()); holder = cachedInstances.get(name); } return holder; } private T createExtension(String name) { Class clazz = getExtensionClasses().get(name); if (clazz == null) { throw new NullPointerException(name); } try { T instance = (T) EXTENSION_INSTANCES.get(clazz); if (instance == null) { EXTENSION_INSTANCES.putIfAbsent(clazz, clazz.newInstance()); instance = (T) EXTENSION_INSTANCES.get(clazz); } injectExtension(instance); Set> wrapperClasses = cachedWrapperClasses; if (CollectionUtils.isNotEmpty(wrapperClasses)) { for (Class wrapperClass : wrapperClasses) { instance = injectExtension((T) wrapperClass.getConstructor(type).newInstance(instance)); } } return instance; } catch (Throwable t) { throw new IllegalStateException("Extension instance (name: " + name + ", class: " + type + ") couldn't be instantiated: " + t.getMessage(), t); } } /** * 实现依赖注入 * @param instance 需要进行依赖注入的实例 */ private T injectExtension(T instance) { try { if (objectFactory != null) { for (Method method : instance.getClass().getMethods()) { if (isSetter(method)) { // 存在 @DisableInject 则跳过 if (method.getAnnotation(DisableInject.class) != null) { continue; } Class pt = method.getParameterTypes()[0]; if (ReflectUtils.isPrimitives(pt)) { continue; } try { String property = getSetterProperty(method); Object object = objectFactory.getExtension(pt, property); if (object != null) { method.invoke(instance, object); } } catch (Exception e) { logger.error("Failed to inject via method " + method.getName() + " of interface " + type.getName() + ": " + e.getMessage(), e); } } } } } catch (Exception e) { logger.error(e.getMessage(), e); } return instance; } /** * @return 根据默认扩展名获取到的实例 */ public T getDefaultExtension() { getExtensionClasses(); if (StringUtils.isEmpty(cachedDefaultName) || "true".equals(cachedDefaultName)) { return null; } return getExtension(cachedDefaultName); } /** * @return 获取自适应扩展点 */ @SuppressWarnings("unchecked") public T getAdaptiveExtension() { Object instance = cachedAdaptiveInstance.get(); if (instance == null) { if (createAdaptiveInstanceError == null) { synchronized (cachedAdaptiveInstance) { instance = cachedAdaptiveInstance.get(); if (instance == null) { try { instance = createAdaptiveExtension(); cachedAdaptiveInstance.set(instance); } catch (Throwable t) { createAdaptiveInstanceError = t; throw new IllegalStateException("Failed to create adaptive instance: " + t.toString(), t); } } } } else { throw new IllegalStateException("Failed to create adaptive instance: " + createAdaptiveInstanceError.toString(), createAdaptiveInstanceError); } } return (T) instance; } @SuppressWarnings("unchecked") private T createAdaptiveExtension() { try { return injectExtension((T) getAdaptiveExtensionClass().newInstance()); } catch (Exception e) { throw new IllegalStateException("Can't create adaptive extension " + type + ", cause: " + e.getMessage(), e); } } /** * 获取自适应扩展点 */ private Class getAdaptiveExtensionClass() { getExtensionClasses(); if (cachedAdaptiveClass != null) { return cachedAdaptiveClass; } return null; } private Map> getExtensionClasses() { Map> classes = cachedClasses.get(); if (classes == null) { synchronized (cachedClasses) { classes = cachedClasses.get(); if (classes == null) { classes = loadExtensionClasses(); cachedClasses.set(classes); } } } return classes; } /** * 如果存在默认扩展名,提取并且缓存 */ private void cacheDefaultExtensionName() { String value; if (StringUtils.isEmpty(value=PropertiesUtil.getProperty(type.getName()))){ final SPI defaultAnnotation = type.getAnnotation(SPI.class); if (defaultAnnotation != null) { value = defaultAnnotation.value(); if ((value = value.trim()).length() > 0) { String[] names = NAME_SEPARATOR.split(value); if (names.length > 1) { throw new IllegalStateException("More than 1 default extension name on extension " + type.getName() + ": " + Arrays.toString(names)); } if (names.length == 1) { cachedDefaultName = names[0]; } } } }else{ cachedDefaultName=value; } } // synchronized in getExtensionClasses private Map> loadExtensionClasses() { cacheDefaultExtensionName(); Map> extensionClasses = new HashMap<>(); loadDirectory(extensionClasses, GATEWAY_DIRECTORY, type.getName()); return extensionClasses; } private void loadDirectory(Map> extensionClasses, String dir, String type) { String fileName = dir + type; try { Enumeration urls; ClassLoader classLoader = findClassLoader(); if (classLoader != null) { urls = classLoader.getResources(fileName); } else { urls = ClassLoader.getSystemResources(fileName); } if (urls != null) { while (urls.hasMoreElements()) { java.net.URL resourceURL = urls.nextElement(); loadResource(extensionClasses, classLoader, resourceURL); } } } catch (Throwable t) { logger.error("Exception occurred when loading extension class (interface: " + type + ", description file: " + fileName + ").", t); } } private void loadResource(Map> extensionClasses, ClassLoader classLoader, java.net.URL resourceURL) { try { try (BufferedReader reader = new BufferedReader(new InputStreamReader(resourceURL.openStream(), StandardCharsets.UTF_8))) { String line; while ((line = reader.readLine()) != null) { final int ci = line.indexOf('#'); if (ci >= 0) { line = line.substring(0, ci); } line = line.trim(); if (line.length() > 0) { try { String name = null; int i = line.indexOf('='); if (i > 0) { name = line.substring(0, i).trim(); line = line.substring(i + 1).trim(); } if (line.length() > 0) { loadClass(extensionClasses, resourceURL, Class.forName(line, true, classLoader), name); } } catch (Throwable t) { IllegalStateException e = new IllegalStateException("Failed to load extension class (interface: " + type + ", class line: " + line + ") in " + resourceURL + ", cause: " + t.getMessage(), t); exceptions.put(line, e); } } } } } catch (Throwable t) { logger.error("Exception occurred when loading extension class (interface: " + type + ", class file: " + resourceURL + ") in " + resourceURL, t); } } private void loadClass(Map> extensionClasses, java.net.URL resourceURL, Class clazz, String name) throws NoSuchMethodException { if (!type.isAssignableFrom(clazz)) { throw new IllegalStateException("Error occurred when loading extension class (interface: " + type + ", class line: " + clazz.getName() + "), class " + clazz.getName() + " is not subtype of interface."); } if (clazz.isAnnotationPresent(Adaptive.class)) { cacheAdaptiveClass(clazz,name); } else if (isWrapperClass(clazz)) { cacheWrapperClass(clazz); } } private void cacheAdaptiveClass(Class clazz,String name) { // 只有名称匹配时才会进行保存 if (cachedAdaptiveClass == null && name.equals(cachedDefaultName)) { cachedAdaptiveClass = clazz; } else if (!cachedAdaptiveClass.equals(clazz)) { throw new IllegalStateException("More than 1 adaptive class found: " + cachedAdaptiveClass.getClass().getName() + ", " + clazz.getClass().getName()); } } private void cacheWrapperClass(Class clazz) { if (cachedWrapperClasses == null) { cachedWrapperClasses = new HashSet<>(); } cachedWrapperClasses.add(clazz); } private boolean isWrapperClass(Class clazz) { try { clazz.getConstructor(type); return true; } catch (NoSuchMethodException e) { return false; } } private boolean isSetter(Method method) { return method.getName().startsWith("set") && method.getParameterTypes().length == 1 && Modifier.isPublic(method.getModifiers()); } private static ClassLoader findClassLoader() { return ClassUtils.getClassLoader(ExtensionLoader.class); } private static boolean withExtensionAnnotation(Class type) { return type.isAnnotationPresent(SPI.class); } private String getSetterProperty(Method method) { return method.getName().length() > 3 ? method.getName().substring(3, 4).toLowerCase() + method.getName().substring(4) : ""; } public Set getSupportedExtensions() { Map> clazzes = getExtensionClasses(); return Collections.unmodifiableSet(new TreeSet<>(clazzes.keySet())); } }