ExtensionLoader.java 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. package cn.nosum.common.extension;
  2. import cn.nosum.common.annotation.Adaptive;
  3. import cn.nosum.common.annotation.DisableInject;
  4. import cn.nosum.common.annotation.SPI;
  5. import cn.nosum.common.util.*;
  6. import org.slf4j.Logger;
  7. import org.slf4j.LoggerFactory;
  8. import java.io.BufferedReader;
  9. import java.io.InputStreamReader;
  10. import java.lang.reflect.Method;
  11. import java.lang.reflect.Modifier;
  12. import java.net.URL;
  13. import java.nio.charset.StandardCharsets;
  14. import java.util.*;
  15. import java.util.concurrent.ConcurrentHashMap;
  16. import java.util.concurrent.ConcurrentMap;
  17. import java.util.regex.Pattern;
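/**
 * Loads, caches and instantiates the implementations of a single extension interface.
 *
 * <p>The interface must be annotated with {@code @SPI}; implementation class names are read from
 * descriptor files located under {@code META-INF/gateway/} on the classpath.
 *
 * @param <T> the extension interface type served by this loader
 */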
  22. public class ExtensionLoader<T> {
  23. private static final Logger logger= LoggerFactory.getLogger(ExtensionLoader.class);
  24. // 需要扫描的路径
  25. private static final String GATEWAY_DIRECTORY = "META-INF/gateway/";
  26. private static final Pattern NAME_SEPARATOR = Pattern.compile("\\s*[,]+\\s*");
  27. // 缓存
  28. private static final ConcurrentMap<Class<?>, ExtensionLoader<?>> EXTENSION_LOADERS = new ConcurrentHashMap<>();
  29. private static final ConcurrentMap<Class<?>, Object> EXTENSION_INSTANCES = new ConcurrentHashMap<>();
  30. private final Holder<Map<String, Class<?>>> cachedClasses = new Holder<>();
  31. private final Holder<Object> cachedAdaptiveInstance = new Holder<>();
  32. private final ConcurrentMap<String, Holder<Object>> cachedInstances = new ConcurrentHashMap<>();
  33. private Set<Class<?>> cachedWrapperClasses;
  34. private final ExtensionFactory objectFactory;
  35. private volatile Class<?> cachedAdaptiveClass = null;
  36. private final Class<?> type;
  37. private String cachedDefaultName;
  38. // 记录错误信息
  39. private Map<String, IllegalStateException> exceptions = new ConcurrentHashMap<>();
  40. private volatile Throwable createAdaptiveInstanceError;
  41. private ExtensionLoader(Class<?> type) {
  42. this.type = type;
  43. objectFactory = (type == ExtensionFactory.class ? null : ExtensionLoader.getExtensionLoader(ExtensionFactory.class).getAdaptiveExtension());
  44. }
  45. @SuppressWarnings("unchecked")
  46. public static <T> ExtensionLoader<T> getExtensionLoader(Class<T> type) {
  47. if (type == null) {
  48. throw new IllegalArgumentException("Extension type == null");
  49. }
  50. if (!type.isInterface()) {
  51. throw new IllegalArgumentException("Extension type (" + type + ") is not an interface!");
  52. }
  53. if (!withExtensionAnnotation(type)) {
  54. throw new IllegalArgumentException("Extension type (" + type +
  55. ") is not an extension, because it is NOT annotated with @" + SPI.class.getSimpleName() + "!");
  56. }
  57. ExtensionLoader<T> loader = (ExtensionLoader<T>) EXTENSION_LOADERS.get(type);
  58. if (loader == null) {
  59. EXTENSION_LOADERS.putIfAbsent(type, new ExtensionLoader<T>(type));
  60. loader = (ExtensionLoader<T>) EXTENSION_LOADERS.get(type);
  61. }
  62. return loader;
  63. }
  64. @SuppressWarnings("unchecked")
  65. public T getExtension(String name) {
  66. if (StringUtils.isEmpty(name)) {
  67. throw new IllegalArgumentException("Extension name == null");
  68. }
  69. if ("true".equals(name)) {
  70. return getDefaultExtension();
  71. }
  72. final Holder<Object> holder = getOrCreateHolder(name);
  73. Object instance = holder.get();
  74. if (instance == null) {
  75. synchronized (holder) {
  76. instance = holder.get();
  77. if (instance == null) {
  78. instance = createExtension(name);
  79. holder.set(instance);
  80. }
  81. }
  82. }
  83. return (T) instance;
  84. }
  85. private Holder<Object> getOrCreateHolder(String name) {
  86. Holder<Object> holder = cachedInstances.get(name);
  87. if (holder == null) {
  88. cachedInstances.putIfAbsent(name, new Holder<>());
  89. holder = cachedInstances.get(name);
  90. }
  91. return holder;
  92. }
  93. private T createExtension(String name) {
  94. Class<?> clazz = getExtensionClasses().get(name);
  95. if (clazz == null) {
  96. throw new NullPointerException(name);
  97. }
  98. try {
  99. T instance = (T) EXTENSION_INSTANCES.get(clazz);
  100. if (instance == null) {
  101. EXTENSION_INSTANCES.putIfAbsent(clazz, clazz.newInstance());
  102. instance = (T) EXTENSION_INSTANCES.get(clazz);
  103. }
  104. injectExtension(instance);
  105. Set<Class<?>> wrapperClasses = cachedWrapperClasses;
  106. if (CollectionUtils.isNotEmpty(wrapperClasses)) {
  107. for (Class<?> wrapperClass : wrapperClasses) {
  108. instance = injectExtension((T) wrapperClass.getConstructor(type).newInstance(instance));
  109. }
  110. }
  111. return instance;
  112. } catch (Throwable t) {
  113. throw new IllegalStateException("Extension instance (name: " + name + ", class: " +
  114. type + ") couldn't be instantiated: " + t.getMessage(), t);
  115. }
  116. }
  117. /**
  118. * 实现依赖注入
  119. * @param instance 需要进行依赖注入的实例
  120. */
  121. private T injectExtension(T instance) {
  122. try {
  123. if (objectFactory != null) {
  124. for (Method method : instance.getClass().getMethods()) {
  125. if (isSetter(method)) {
  126. // 存在 @DisableInject 则跳过
  127. if (method.getAnnotation(DisableInject.class) != null) {
  128. continue;
  129. }
  130. Class<?> pt = method.getParameterTypes()[0];
  131. if (ReflectUtils.isPrimitives(pt)) {
  132. continue;
  133. }
  134. try {
  135. String property = getSetterProperty(method);
  136. Object object = objectFactory.getExtension(pt, property);
  137. if (object != null) {
  138. method.invoke(instance, object);
  139. }
  140. } catch (Exception e) {
  141. logger.error("Failed to inject via method " + method.getName()
  142. + " of interface " + type.getName() + ": " + e.getMessage(), e);
  143. }
  144. }
  145. }
  146. }
  147. } catch (Exception e) {
  148. logger.error(e.getMessage(), e);
  149. }
  150. return instance;
  151. }
  152. /**
  153. * @return 根据默认扩展名获取到的实例
  154. */
  155. public T getDefaultExtension() {
  156. getExtensionClasses();
  157. if (StringUtils.isEmpty(cachedDefaultName) || "true".equals(cachedDefaultName)) {
  158. return null;
  159. }
  160. return getExtension(cachedDefaultName);
  161. }
  162. /**
  163. * @return 获取自适应扩展点
  164. */
  165. @SuppressWarnings("unchecked")
  166. public T getAdaptiveExtension() {
  167. Object instance = cachedAdaptiveInstance.get();
  168. if (instance == null) {
  169. if (createAdaptiveInstanceError == null) {
  170. synchronized (cachedAdaptiveInstance) {
  171. instance = cachedAdaptiveInstance.get();
  172. if (instance == null) {
  173. try {
  174. instance = createAdaptiveExtension();
  175. cachedAdaptiveInstance.set(instance);
  176. } catch (Throwable t) {
  177. createAdaptiveInstanceError = t;
  178. throw new IllegalStateException("Failed to create adaptive instance: " + t.toString(), t);
  179. }
  180. }
  181. }
  182. } else {
  183. throw new IllegalStateException("Failed to create adaptive instance: " + createAdaptiveInstanceError.toString(), createAdaptiveInstanceError);
  184. }
  185. }
  186. return (T) instance;
  187. }
  188. @SuppressWarnings("unchecked")
  189. private T createAdaptiveExtension() {
  190. try {
  191. return injectExtension((T) getAdaptiveExtensionClass().newInstance());
  192. } catch (Exception e) {
  193. throw new IllegalStateException("Can't create adaptive extension " + type + ", cause: " + e.getMessage(), e);
  194. }
  195. }
  196. /**
  197. * 获取自适应扩展点
  198. */
  199. private Class<?> getAdaptiveExtensionClass() {
  200. getExtensionClasses();
  201. if (cachedAdaptiveClass != null) {
  202. return cachedAdaptiveClass;
  203. }
  204. return null;
  205. }
  206. private Map<String, Class<?>> getExtensionClasses() {
  207. Map<String, Class<?>> classes = cachedClasses.get();
  208. if (classes == null) {
  209. synchronized (cachedClasses) {
  210. classes = cachedClasses.get();
  211. if (classes == null) {
  212. classes = loadExtensionClasses();
  213. cachedClasses.set(classes);
  214. }
  215. }
  216. }
  217. return classes;
  218. }
  219. /**
  220. * 如果存在默认扩展名,提取并且缓存
  221. */
  222. private void cacheDefaultExtensionName() {
  223. String value;
  224. if (StringUtils.isEmpty(value=PropertiesUtil.getProperty(type.getName()))){
  225. final SPI defaultAnnotation = type.getAnnotation(SPI.class);
  226. if (defaultAnnotation != null) {
  227. value = defaultAnnotation.value();
  228. if ((value = value.trim()).length() > 0) {
  229. String[] names = NAME_SEPARATOR.split(value);
  230. if (names.length > 1) {
  231. throw new IllegalStateException("More than 1 default extension name on extension "
  232. + type.getName() + ": "
  233. + Arrays.toString(names));
  234. }
  235. if (names.length == 1) {
  236. cachedDefaultName = names[0];
  237. }
  238. }
  239. }
  240. }else{
  241. cachedDefaultName=value;
  242. }
  243. }
  244. // synchronized in getExtensionClasses
  245. private Map<String, Class<?>> loadExtensionClasses() {
  246. cacheDefaultExtensionName();
  247. Map<String, Class<?>> extensionClasses = new HashMap<>();
  248. loadDirectory(extensionClasses, GATEWAY_DIRECTORY, type.getName());
  249. return extensionClasses;
  250. }
  251. private void loadDirectory(Map<String, Class<?>> extensionClasses, String dir, String type) {
  252. String fileName = dir + type;
  253. try {
  254. Enumeration<URL> urls;
  255. ClassLoader classLoader = findClassLoader();
  256. if (classLoader != null) {
  257. urls = classLoader.getResources(fileName);
  258. } else {
  259. urls = ClassLoader.getSystemResources(fileName);
  260. }
  261. if (urls != null) {
  262. while (urls.hasMoreElements()) {
  263. java.net.URL resourceURL = urls.nextElement();
  264. loadResource(extensionClasses, classLoader, resourceURL);
  265. }
  266. }
  267. } catch (Throwable t) {
  268. logger.error("Exception occurred when loading extension class (interface: " +
  269. type + ", description file: " + fileName + ").", t);
  270. }
  271. }
  272. private void loadResource(Map<String, Class<?>> extensionClasses, ClassLoader classLoader, java.net.URL resourceURL) {
  273. try {
  274. try (BufferedReader reader = new BufferedReader(new InputStreamReader(resourceURL.openStream(), StandardCharsets.UTF_8))) {
  275. String line;
  276. while ((line = reader.readLine()) != null) {
  277. final int ci = line.indexOf('#');
  278. if (ci >= 0) {
  279. line = line.substring(0, ci);
  280. }
  281. line = line.trim();
  282. if (line.length() > 0) {
  283. try {
  284. String name = null;
  285. int i = line.indexOf('=');
  286. if (i > 0) {
  287. name = line.substring(0, i).trim();
  288. line = line.substring(i + 1).trim();
  289. }
  290. if (line.length() > 0) {
  291. loadClass(extensionClasses, resourceURL, Class.forName(line, true, classLoader), name);
  292. }
  293. } catch (Throwable t) {
  294. IllegalStateException e = new IllegalStateException("Failed to load extension class (interface: " + type + ", class line: " + line + ") in " + resourceURL + ", cause: " + t.getMessage(), t);
  295. exceptions.put(line, e);
  296. }
  297. }
  298. }
  299. }
  300. } catch (Throwable t) {
  301. logger.error("Exception occurred when loading extension class (interface: " +
  302. type + ", class file: " + resourceURL + ") in " + resourceURL, t);
  303. }
  304. }
  305. private void loadClass(Map<String, Class<?>> extensionClasses, java.net.URL resourceURL, Class<?> clazz, String name) throws NoSuchMethodException {
  306. if (!type.isAssignableFrom(clazz)) {
  307. throw new IllegalStateException("Error occurred when loading extension class (interface: " +
  308. type + ", class line: " + clazz.getName() + "), class "
  309. + clazz.getName() + " is not subtype of interface.");
  310. }
  311. if (clazz.isAnnotationPresent(Adaptive.class)) {
  312. cacheAdaptiveClass(clazz,name);
  313. } else if (isWrapperClass(clazz)) {
  314. cacheWrapperClass(clazz);
  315. }
  316. }
  317. private void cacheAdaptiveClass(Class<?> clazz,String name) {
  318. // 只有名称匹配时才会进行保存
  319. if (cachedAdaptiveClass == null && name.equals(cachedDefaultName)) {
  320. cachedAdaptiveClass = clazz;
  321. } else if (!cachedAdaptiveClass.equals(clazz)) {
  322. throw new IllegalStateException("More than 1 adaptive class found: "
  323. + cachedAdaptiveClass.getClass().getName()
  324. + ", " + clazz.getClass().getName());
  325. }
  326. }
  327. private void cacheWrapperClass(Class<?> clazz) {
  328. if (cachedWrapperClasses == null) {
  329. cachedWrapperClasses = new HashSet<>();
  330. }
  331. cachedWrapperClasses.add(clazz);
  332. }
  333. private boolean isWrapperClass(Class<?> clazz) {
  334. try {
  335. clazz.getConstructor(type);
  336. return true;
  337. } catch (NoSuchMethodException e) {
  338. return false;
  339. }
  340. }
  341. private boolean isSetter(Method method) {
  342. return method.getName().startsWith("set")
  343. && method.getParameterTypes().length == 1
  344. && Modifier.isPublic(method.getModifiers());
  345. }
  346. private static ClassLoader findClassLoader() {
  347. return ClassUtils.getClassLoader(ExtensionLoader.class);
  348. }
  349. private static <T> boolean withExtensionAnnotation(Class<T> type) {
  350. return type.isAnnotationPresent(SPI.class);
  351. }
  352. private String getSetterProperty(Method method) {
  353. return method.getName().length() > 3 ? method.getName().substring(3, 4).toLowerCase() + method.getName().substring(4) : "";
  354. }
  355. public Set<String> getSupportedExtensions() {
  356. Map<String, Class<?>> clazzes = getExtensionClasses();
  357. return Collections.unmodifiableSet(new TreeSet<>(clazzes.keySet()));
  358. }
  359. }