package cn.nosum.framework.mvc.v2.servlet; import cn.nosum.framework.annotation.*; import javax.servlet.ServletConfig; import javax.servlet.ServletException; import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.io.InputStream; import java.lang.annotation.Annotation; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.net.URL; import java.util.*; import java.io.File; public class DispatcherServlet extends HttpServlet { private Properties contextConfig = new Properties(); //享元模式,缓存 private List classNames = new ArrayList(); //IoC容器,key默认是类名首字母小写,value就是对应的实例对象 private Map ioc = new HashMap(); // 保存 URL 与对应执行方法的映射 private Map handlerMapping = new HashMap(); @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { this.doPost(req,resp); } @Override protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException { try { //6、委派,根据URL去找到一个对应的Method并通过response返回 doDispatch(req,resp); } catch (Exception e) { e.printStackTrace(); resp.getWriter().write("500 Exception,Detail : " + Arrays.toString(e.getStackTrace())); } } private void doDispatch(HttpServletRequest req, HttpServletResponse resp) throws Exception { String url = req.getRequestURI(); String contextPath = req.getContextPath(); url = url.replaceAll(contextPath,"").replaceAll("/+","/"); if(!this.handlerMapping.containsKey(url)){ resp.getWriter().write("404 Not Found!!!"); return; } Map params = req.getParameterMap(); Method method = this.handlerMapping.get(url); //获取形参列表 Class [] parameterTypes = method.getParameterTypes(); Object [] paramValues = new Object[parameterTypes.length]; for (int i = 0; i < parameterTypes.length; i++) { Class parameterType = parameterTypes[i]; if(parameterType == HttpServletRequest.class){ paramValues[i] = req; }else if(parameterType == HttpServletResponse.class){ paramValues[i] = resp; }else if(parameterType == String.class){ // 需要通过运行时的状态才可以拿到 Annotation[] [] pa = method.getParameterAnnotations(); for (int j = 0; j < pa.length ; j ++) { for(Annotation annotation : pa[i]){ if(annotation instanceof RequestParam){ String paramName = ((RequestParam) annotation).value(); if(!"".equals(paramName.trim())){ String value = Arrays.toString(params.get(paramName)) .replaceAll("\\[|\\]","") .replaceAll("\\s+",","); paramValues[i] = value; } } } } } } // 通过 method 获取对应的类 String beanName = toLowerFirstCase(method.getDeclaringClass().getSimpleName()); // 从 ioc 中获取对应的类并且执行 method.invoke(ioc.get(beanName),paramValues); } @Override public void init(ServletConfig config) throws ServletException { //1、加载配置文件 doLoadConfig(config.getInitParameter("contextConfigLocation")); //2、扫描相关的类 doScanner(contextConfig.getProperty("scanPackage")); //==============IoC部分============== //3、初始化IoC容器,将扫描到的相关的类实例化,保存到IcC容器中 doInstance(); //AOP,新生成的代理对象 //==============DI部分============== //4、完成依赖注入 doAutowired(); //==============MVC部分============== //5、初始化HandlerMapping doInitHandlerMapping(); System.out.println("Spring framework is init."); } private void doInitHandlerMapping() { if(ioc.isEmpty()){ return;} for (Map.Entry entry : ioc.entrySet()) { Class clazz = entry.getValue().getClass(); if(!clazz.isAnnotationPresent(Controller.class)){ continue; } String baseUrl = ""; if(clazz.isAnnotationPresent(RequestMapping.class)){ // 取得类配置的 url RequestMapping requestMapping = clazz.getAnnotation(RequestMapping.class); baseUrl = requestMapping.value(); } // 只获取 public 的方法 for (Method method : clazz.getMethods()) { if(!method.isAnnotationPresent(RequestMapping.class)){continue;} // 提取每个方法配置的 url RequestMapping requestMapping = method.getAnnotation(RequestMapping.class); String url = ("/" + baseUrl + "/" + requestMapping.value()).replaceAll("/+","/"); handlerMapping.put(url,method); System.out.println("Mapped : " + url + "," + method); } } } private void doAutowired() { if(ioc.isEmpty()){return;} for (Map.Entry entry : ioc.entrySet()) { // 把所有 private/protected/default/public 修饰字段都取出来 for (Field field : entry.getValue().getClass().getDeclaredFields()) { if(!field.isAnnotationPresent(Autowired.class)){ continue; } Autowired autowired = field.getAnnotation(Autowired.class); // 没有自定义的beanName,根据类型注入 String beanName = autowired.value().trim(); if("".equals(beanName)){ beanName = field.getType().getName(); } // 赋值 try { field.setAccessible(true); field.set(entry.getValue(),ioc.get(beanName)); } catch (IllegalAccessException e) { e.printStackTrace(); } } } } private void doInstance() { if(classNames.isEmpty()){return;} try { for (String className : classNames) { Class clazz = Class.forName(className); if(clazz.isAnnotationPresent(Controller.class)) { String beanName = toLowerFirstCase(clazz.getSimpleName()); Object instance = clazz.newInstance(); ioc.put(beanName, instance); }else if(clazz.isAnnotationPresent(Service.class)){ // 获取自定义的名称 String beanName = clazz.getAnnotation(Service.class).value(); if("".equals(beanName.trim())){ beanName = toLowerFirstCase(clazz.getSimpleName()); } Object instance = clazz.newInstance(); ioc.put(beanName, instance); // 如果是接口,只能有一个实现,否则抛出异常 for (Class i : clazz.getInterfaces()) { if(ioc.containsKey(i.getName())){ throw new Exception("The " + i.getName() + " is exists!!"); } ioc.put(i.getName(),instance); } }else{ continue; } } }catch (Exception e){ e.printStackTrace(); } } private void doScanner(String scanPackage) { URL url = this.getClass().getClassLoader().getResource("/" + scanPackage.replaceAll("\\.","/")); File classPath = new File(url.getFile()); //当成是一个 ClassPath 文件夹 for (File file : classPath.listFiles()) { if(file.isDirectory()){ doScanner(scanPackage + "." + file.getName()); }else { if(!file.getName().endsWith(".class")){continue;} // 全类名 = 包名.类名 String className = (scanPackage + "." + file.getName().replace(".class", "")); classNames.add(className); } } } private void doLoadConfig(String contextConfigLocation) { InputStream is = this.getClass().getClassLoader().getResourceAsStream(contextConfigLocation); try { contextConfig.load(is); } catch (IOException e) { e.printStackTrace(); }finally { if(null != is){ try { is.close(); } catch (IOException e) { e.printStackTrace(); } } } } private String toLowerFirstCase(String simpleName) { char [] chars = simpleName.toCharArray(); chars[0] += 32; return String.valueOf(chars); } }