在spring中关于对bean的扩张可以分为两种:
1 基于所有bean:可以使用beanfactorypostproccessor进行修改,但是这种修改是全局的,也就是所有的bean都会被进行修改
2 基于单个bean:只针对单个特定的bean的实例化修改,这种情况可以使用BeanPostProcessor,也是接下来要简单介绍和使用的
从spring源码的来看,主要可以包含下面的几个流程
1. 初始化时,spring容器有特别处理,会直接调用beanFactory.addBeanPostProcessor进行注册(例如AbstractApplicationContext类的prepareBeanFactory方法中就有); 2. 找出所有实现了BeanPostProcessor接口的bean,注册到容器,注册顺序如下: 第一:实现了PriorityOrdered接口的,排序后; 第二:实现了Ordered接口的,排序后; 第三:既没实现PriorityOrdered接口,也没有实现Ordered接口的; 第四:实现了MergedBeanDefinitionPostProcessor接口的(这些也按照PriorityOrdered、Ordered等逻辑拍过续); 第五:实例化一个ApplicationListenerDetector对象; 3. 实例化bean的时候,对于每个bean,先用MergedBeanDefinitionPostProcessor实现类的postProcessMergedBeanDefinition方法处理每个bean的定义类; 4. 再用BeanPostProcessor的postProcessBeforeInitialization方法处理每个bean实例; 5. bean实例初始化; 6. 用BeanPostProcessor的postProcessAfterInitialization方法处理每个bean实例;它的核心方法只有两个
public interface BeanPostProcessor { Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException; Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException; }可以自定义个bean的后置处理器来控制bean
@Component public class MyBeanPostProcessor implements BeanPostProcessor { @Override public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { if("calculateService".equals(beanName)) { CalculateService calculateService = (CalculateService)bean; //修改calculateService实例的成员变量serviceDesc的值 calculateService.setServiceDesc("desc from " + this.getClass().getSimpleName()); } return bean; } @Override public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { if("calculateService".equals(beanName)) { Utils.printTrack("do postProcess after initialization"); } return bean; } }下面再介绍一个实际开发中可能会用到的例子:就是一个接口有多个实现类
当然对于一个接口有多个实现类,我们可以直接使用spring自带的注解,只不过需要加入@Primary,但是这样看起来不友好,接下来通过自定义注解+后置处理器来实现动态切换
首先定义一个接口和两个实现类
public interface HelloService{ public void sayHello(); } @Service public class HelloServiceImpl1 implements HelloService { @Override public void sayHello() { System.out.println("你好我是HelloServiceImpl1"); } } @Service public class HelloServiceImpl2 implements HelloService { @Override public void sayHello() { System.out.println("你好我是HelloServiceImpl2"); } } 定义一个自定义注解 @Target({ElementType.FIELD}) @Retention(RetentionPolicy.RUNTIME) @Documented @Component public @interface RountingInjected { String value() default "helloServiceImpl1"; } 定义bean的后置处理器 @Component public class HelloServiceInjectProcessor implements BeanPostProcessor { @Autowired private ApplicationContext applicationContext; @Override public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException { return bean; } @Override public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException { Class<?> targetCls = bean.getClass(); Field[] targetFld = targetCls.getDeclaredFields(); for (Field field : targetFld) { //找到制定目标的注解类 if (field.isAnnotationPresent(RountingInjected.class)) { if (!field.getType().isInterface()) { throw new BeanCreationException("RoutingInjected field must be declared as an interface:" + field.getName() + " @Class " + targetCls.getName()); } try { this.handleRoutingInjected(field, bean, field.getType()); } catch (IllegalAccessException e) { e.printStackTrace(); } } } return bean; } /** * @param field * @param bean * @param type * @throws IllegalAccessException */ private void handleRoutingInjected(Field field, Object bean, Class type) throws IllegalAccessException { Map<String, Object> candidates = this.applicationContext.getBeansOfType(type); field.setAccessible(true); if (candidates.size() == 1) { field.set(bean, candidates.values().iterator().next()); } else if (candidates.size() == 2) { String injectVal = field.getAnnotation(RountingInjected.class).value(); Object proxy = RoutingBeanProxyFactory.createProxy(injectVal, type, candidates); field.set(bean, proxy); } else { throw new IllegalArgumentException("Find more than 2 beans for type: " + type); } } } 还需要一个代理工厂 public class RoutingBeanProxyFactory { private final static String DEFAULT_BEAN_NAME = "helloServiceImpl1"; public static Object createProxy(String name, Class type, Map<String, Object> candidates) { ProxyFactory proxyFactory = new ProxyFactory(); proxyFactory.setInterfaces(type); proxyFactory.addAdvice(new VersionRoutingMethodInterceptor(name, candidates)); return proxyFactory.getProxy(); } static class VersionRoutingMethodInterceptor implements MethodInterceptor { private Object targetObject; public VersionRoutingMethodInterceptor(String name, Map<String, Object> beans) { this.targetObject = beans.get(name); if (this.targetObject == null) { this.targetObject = beans.get(DEFAULT_BEAN_NAME); } } @Override public Object invoke(MethodInvocation invocation) throws Throwable { return invocation.getMethod().invoke(this.targetObject, invocation.getArguments()); } } } 最后启动一个application 测试一下效果 @SpringBootApplication @MapperScan("com.lx.mapper") public class MlxcApplication { public static void main(final String[] args) { try (ConfigurableApplicationContext applicationContext = SpringApplication.run(MlxcApplication.class, args)) { HelloServiceTest helloService = applicationContext.getBean(HelloServiceTest.class); helloService.testSayHello(); } } }