diff --git a/src/main/java/w/core/Swapper.java b/src/main/java/w/core/Swapper.java index 8e48a1d..329d716 100644 --- a/src/main/java/w/core/Swapper.java +++ b/src/main/java/w/core/Swapper.java @@ -50,19 +50,7 @@ public boolean swap(Message message) { Set> classes = Global.allLoadedClasses.getOrDefault(transformer.getClassName(), new HashSet<>()); - boolean classExists = false; - for (Class aClass : classes) { - if (aClass.isInterface() || Modifier.isAbstract(aClass.getModifiers())) { - Set candidates = new HashSet<>(); - for (Object instances : Global.getInstances(aClass)) { - candidates.add(instances.getClass().getName()); - } - Global.error("!Error: Should use a simple pojo, but " + aClass.getName() + " is a Interface or Abstract class or something wired, \nmaybe you should use: " + candidates); - return false; - } - classExists = true; - } - + boolean classExists = !classes.isEmpty(); if (!classExists) { Global.error("Class not exist: " + transformer.getClassName()); return false; diff --git a/src/main/java/w/core/model/BaseClassTransformer.java b/src/main/java/w/core/model/BaseClassTransformer.java index 6742aff..2a7ec94 100644 --- a/src/main/java/w/core/model/BaseClassTransformer.java +++ b/src/main/java/w/core/model/BaseClassTransformer.java @@ -1,6 +1,9 @@ package w.core.model; +import javassist.CtClass; +import javassist.CtMethod; import javassist.LoaderClassPath; +import javassist.NotFoundException; import lombok.Getter; import lombok.Setter; import w.Global; @@ -10,8 +13,7 @@ import java.lang.instrument.ClassFileTransformer; import java.lang.instrument.IllegalClassFormatException; import java.security.ProtectionDomain; -import java.util.Objects; -import java.util.UUID; +import java.util.*; import java.util.concurrent.CompletableFuture; /** diff --git a/src/main/java/w/core/model/ChangeBodyTransformer.java b/src/main/java/w/core/model/ChangeBodyTransformer.java index 8fe166c..ec124ad 100644 --- a/src/main/java/w/core/model/ChangeBodyTransformer.java +++ b/src/main/java/w/core/model/ChangeBodyTransformer.java @@ -3,6 +3,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import javassist.CtClass; import javassist.CtMethod; +import javassist.Modifier; import lombok.Data; import w.Global; import w.web.message.ChangeBodyMessage; @@ -44,12 +45,18 @@ public byte[] transform(String className, byte[] origin) throws Exception { Arrays.equals(paramTypes.toArray(new String[0]), Arrays.stream(declaredMethod.getParameterTypes()).map(CtClass::getName).toArray()) ) { + if ((declaredMethod.getModifiers() & Modifier.ABSTRACT) != 0) { + throw new IllegalArgumentException("Cannot change abstract method."); + } + if ((declaredMethod.getModifiers() & Modifier.NATIVE) != 0) { + throw new IllegalArgumentException("Cannot change native method."); + } declaredMethod.setBody(message.getBody()); effect = true; } } if (!effect) { - throw new IllegalArgumentException("Class or Method not exist."); + throw new IllegalArgumentException("Method not declared here."); } byte[] result = ctClass.toBytecode(); ctClass.detach(); diff --git a/src/main/java/w/core/model/ChangeResultTransformer.java b/src/main/java/w/core/model/ChangeResultTransformer.java index 1baab17..b94d007 100644 --- a/src/main/java/w/core/model/ChangeResultTransformer.java +++ b/src/main/java/w/core/model/ChangeResultTransformer.java @@ -5,6 +5,7 @@ import javassist.CannotCompileException; import javassist.CtClass; import javassist.CtMethod; +import javassist.Modifier; import javassist.expr.ExprEditor; import javassist.expr.MethodCall; import lombok.Data; @@ -54,6 +55,12 @@ public byte[] transform(String className, byte[] origin) throws Exception { Arrays.equals(paramTypes.toArray(new String[0]), Arrays.stream(declaredMethod.getParameterTypes()).map(CtClass::getName).toArray()) ) { + if ((declaredMethod.getModifiers() & Modifier.ABSTRACT) != 0) { + throw new IllegalArgumentException("Cannot change abstract method."); + } + if ((declaredMethod.getModifiers() & Modifier.NATIVE) != 0) { + throw new IllegalArgumentException("Cannot change native method."); + } declaredMethod.instrument(new ExprEditor() { public void edit(MethodCall m) throws CannotCompileException { if (m.getMethodName().equals(innerMethod)) { @@ -67,7 +74,7 @@ public void edit(MethodCall m) throws CannotCompileException { } } if (!effect) { - throw new IllegalArgumentException("Class or Method not exist."); + throw new IllegalArgumentException("Method not declared here."); } byte[] result = ctClass.toBytecode(); ctClass.detach(); diff --git a/src/main/java/w/core/model/OuterWatchTransformer.java b/src/main/java/w/core/model/OuterWatchTransformer.java index ed5c35b..6963053 100644 --- a/src/main/java/w/core/model/OuterWatchTransformer.java +++ b/src/main/java/w/core/model/OuterWatchTransformer.java @@ -46,12 +46,18 @@ public byte[] transform(String className, byte[] origin) throws Exception { boolean effect = false; for (CtMethod declaredMethod : ctClass.getDeclaredMethods()) { if (Objects.equals(declaredMethod.getName(), method)) { + if ((declaredMethod.getModifiers() & Modifier.ABSTRACT) != 0) { + throw new IllegalArgumentException("Cannot change abstract method."); + } + if ((declaredMethod.getModifiers() & Modifier.NATIVE) != 0) { + throw new IllegalArgumentException("Cannot change native method."); + } addOuterWatchCodeToMethod(declaredMethod); effect = true; } } if (!effect) { - throw new IllegalArgumentException("Class or Method not exist."); + throw new IllegalArgumentException("Method not declared here."); } byte[] result = ctClass.toBytecode(); ctClass.detach(); diff --git a/src/main/java/w/core/model/TraceTransformer.java b/src/main/java/w/core/model/TraceTransformer.java index 9f2ecec..c34b9bc 100644 --- a/src/main/java/w/core/model/TraceTransformer.java +++ b/src/main/java/w/core/model/TraceTransformer.java @@ -4,10 +4,7 @@ import java.lang.reflect.Method; import java.util.*; -import javassist.CannotCompileException; -import javassist.CtClass; -import javassist.CtMethod; -import javassist.NotFoundException; +import javassist.*; import javassist.expr.ExprEditor; import javassist.expr.MethodCall; import lombok.Data; @@ -42,12 +39,18 @@ public byte[] transform(String className, byte[] origin) throws Exception { boolean effect = false; for (CtMethod declaredMethod : ctClass.getDeclaredMethods()) { if (Objects.equals(declaredMethod.getName(), method)) { + if ((declaredMethod.getModifiers() & Modifier.ABSTRACT) != 0) { + throw new IllegalArgumentException("Cannot change abstract method."); + } + if ((declaredMethod.getModifiers() & Modifier.NATIVE) != 0) { + throw new IllegalArgumentException("Cannot change native method."); + } addTraceCodeToMethod(declaredMethod); effect = true; } } if (!effect) { - throw new IllegalArgumentException("Class or Method not exist."); + throw new IllegalArgumentException("Method not declared here."); } byte[] result = ctClass.toBytecode(); ctClass.detach(); diff --git a/src/main/java/w/core/model/WatchTransformer.java b/src/main/java/w/core/model/WatchTransformer.java index f13f533..8bd49f9 100644 --- a/src/main/java/w/core/model/WatchTransformer.java +++ b/src/main/java/w/core/model/WatchTransformer.java @@ -8,9 +8,7 @@ import java.io.ByteArrayInputStream; import java.lang.reflect.Method; -import java.util.Map; -import java.util.Objects; -import java.util.UUID; +import java.util.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; @@ -44,12 +42,18 @@ public byte[] transform(String className, byte[] origin) throws Exception { boolean effect = false; for (CtMethod declaredMethod : ctClass.getDeclaredMethods()) { if (Objects.equals(declaredMethod.getName(), method)) { + if ((declaredMethod.getModifiers() & Modifier.ABSTRACT) != 0) { + throw new IllegalArgumentException("Cannot change abstract method."); + } + if ((declaredMethod.getModifiers() & Modifier.NATIVE) != 0) { + throw new IllegalArgumentException("Cannot change native method."); + } addWatchCodeToMethod(declaredMethod); effect = true; } } if (!effect) { - throw new IllegalArgumentException("Class or Method not exist."); + throw new IllegalArgumentException("Method not declared here."); } byte[] result = ctClass.toBytecode(); ctClass.detach(); diff --git a/src/test/java/w/core/AbstractService.java b/src/test/java/w/core/AbstractService.java new file mode 100644 index 0000000..43033c2 --- /dev/null +++ b/src/test/java/w/core/AbstractService.java @@ -0,0 +1,13 @@ +package w.core; + +/** + * @author Frank + * @date 2024/4/30 19:26 + */ +public abstract class AbstractService implements MyInterface { + public String normalParentMethod() { + return "a"; + } + + abstract String abstractParentMethod(); +} diff --git a/src/test/java/w/core/MyInterface.java b/src/test/java/w/core/MyInterface.java new file mode 100644 index 0000000..f7c13aa --- /dev/null +++ b/src/test/java/w/core/MyInterface.java @@ -0,0 +1,13 @@ +package w.core; + +/** + * @author Frank + * @date 2024/4/30 19:29 + */ +public interface MyInterface { + default String interfaceDefaultMethod() { + return "default"; + } + + String interfaceMethod(); +} diff --git a/src/test/java/w/core/R.java b/src/test/java/w/core/R.java new file mode 100644 index 0000000..70d5e96 --- /dev/null +++ b/src/test/java/w/core/R.java @@ -0,0 +1,25 @@ +package w.core; + +import java.util.UUID; + +/** + * @author Frank + * @date 2024/4/30 19:30 + */ +public class R extends AbstractService { + @Override + String abstractParentMethod() { + return "R.abstractParentMethod"; + } + + @Override + public String interfaceMethod() { + return "R.interfaceMethod"; + } + + public int recursive(int n) { + if (n <= 1) return 1; + UUID.randomUUID(); + return recursive(n - 1) + recursive(n - 2); + } +} diff --git a/src/test/java/w/core/R2.java b/src/test/java/w/core/R2.java new file mode 100644 index 0000000..5e6e9ff --- /dev/null +++ b/src/test/java/w/core/R2.java @@ -0,0 +1,17 @@ +package w.core; + +/** + * @author Frank + * @date 2024/4/30 19:30 + */ +public class R2 extends AbstractService { + @Override + String abstractParentMethod() { + return "R2.abstractParentMethod"; + } + + @Override + public String interfaceMethod() { + return "R2.interfaceMethod"; + } +} diff --git a/src/test/java/w/core/SwapperTest.java b/src/test/java/w/core/SwapperTest.java index abcc17a..e95a557 100644 --- a/src/test/java/w/core/SwapperTest.java +++ b/src/test/java/w/core/SwapperTest.java @@ -20,7 +20,10 @@ class SwapperTest { Swapper swapper = Swapper.getInstance();; - TestClass t = new TestClass();; + TestClass t = new TestClass(); + + R r = new R(); + R2 r2 = new R2(); @BeforeAll public static void setUp() throws Exception { @@ -37,10 +40,23 @@ public void reset() { @Test public void watchTest() { WatchMessage watchMessage = new WatchMessage(); - watchMessage.setSignature("w.core.TestClass#hello"); - watchMessage.setMinCost(0); + watchMessage.setSignature("w.core.R#abstractParentMethod"); Assertions.assertTrue(swapper.swap(watchMessage)); - t.hello("world"); + WatchMessage watchMessage2 = new WatchMessage(); + watchMessage2.setSignature("w.core.R#interfaceMethod"); + Assertions.assertTrue(swapper.swap(watchMessage2)); + WatchMessage watchMessage3 = new WatchMessage(); + watchMessage3.setSignature("w.core.AbstractService#normalParentMethod"); + Assertions.assertTrue(swapper.swap(watchMessage3)); + WatchMessage watchMessage4 = new WatchMessage(); + watchMessage4.setSignature("w.core.MyInterface#interfaceDefaultMethod"); + Assertions.assertTrue(swapper.swap(watchMessage4)); + r.abstractParentMethod(); + r.interfaceMethod(); + r.normalParentMethod(); + r.interfaceDefaultMethod(); + r2.normalParentMethod(); + r2.interfaceDefaultMethod(); } @Test @@ -115,11 +131,11 @@ public void replaceClassTest() throws Exception { @Test public void traceRecursiveTest() { TraceMessage message = new TraceMessage(); - message.setSignature("w.core.TestClass#recursive"); + message.setSignature("w.core.R#recursive"); message.setIgnoreZero(false); Assertions.assertTrue(swapper.swap(message)); for (int i = 0; i < 10; i++) { - t.recursive(3); + r.recursive(3); } } } \ No newline at end of file diff --git a/src/test/java/w/core/TestClass.java b/src/test/java/w/core/TestClass.java index e5e4ad0..3008d27 100644 --- a/src/test/java/w/core/TestClass.java +++ b/src/test/java/w/core/TestClass.java @@ -16,10 +16,4 @@ public String wrapperHello(String name) { return hello(name); } - - public int recursive(int n) { - if (n <= 1) return 1; - UUID.randomUUID(); - return recursive(n - 1) + recursive(n - 2); - } }