Propagating ThreadLocal to a new Thread fetched from an ExecutorService

傲寒 2020-12-02 11:31

I'm running a process in a separate thread with a timeout, using an ExecutorService and a Future (example code here); the thread "spawning" takes place in an AOP aspect.
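A minimal sketch of the setup described above, assuming a hypothetical `CONTEXT` ThreadLocal that the caller (for example the aspect) sets before submitting. The task runs on a pool thread and is awaited with `Future.get(timeout, unit)`. Because a plain `ThreadLocal` slot belongs to each thread separately, the pool thread does not see the caller's value.

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

class PlainThreadLocalDemo {
    // Hypothetical context value set by the caller (e.g. inside the aspect).
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    static String valueSeenByWorker() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            CONTEXT.set("request-42");
            // The task runs on a pool thread, whose ThreadLocal slot is separate and empty.
            Future<String> result = pool.submit(() -> String.valueOf(CONTEXT.get()));
            // Wait with a timeout, as in the question's setup.
            return result.get(5, TimeUnit.SECONDS);
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdownNow();
        }
    }

    public static void main(String[] args) {
        System.out.println(valueSeenByWorker()); // prints "null"
    }
}
```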

6 answers
  •  -上瘾入骨i
    2020-12-02 11:54

    Based on @erickson's answer, I wrote the code below. It works for InheritableThreadLocals: it builds the list of inheritable thread locals with the same method the Thread constructor uses (ThreadLocal.createInheritedMap), which means reaching into JDK internals through reflection. It also requires overriding the executor class.
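    To illustrate why a custom executor is needed even for `InheritableThreadLocal`: the JDK copies inheritable values only when a `Thread` is constructed, and pool threads are constructed once and then reused. A minimal sketch with a single-thread pool, where `CONTEXT` and `InheritableStaleDemo` are hypothetical names:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

class InheritableStaleDemo {
    // Hypothetical inheritable context value.
    static final InheritableThreadLocal<String> CONTEXT = new InheritableThreadLocal<>();

    static String[] valuesSeenByWorker() {
        ExecutorService pool = Executors.newSingleThreadExecutor();
        try {
            CONTEXT.set("first");
            // The first submit creates the worker on this thread, so it copies "first".
            String a = pool.submit(() -> CONTEXT.get()).get(5, TimeUnit.SECONDS);
            CONTEXT.set("second");
            // The same worker is reused and still holds the value copied at creation.
            String b = pool.submit(() -> CONTEXT.get()).get(5, TimeUnit.SECONDS);
            return new String[] {a, b};
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdownNow();
        }
    }
}
```

    Both tasks see "first": the second task's submitter has moved on to "second", but the reused worker never re-copies it.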

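    import java.util.concurrent.*;

    public class MyThreadPoolExecutor extends ThreadPoolExecutor
    {
       public MyThreadPoolExecutor(int core, int max, long keepAlive,
                                   TimeUnit unit, BlockingQueue<Runnable> queue)
       {
          super(core, max, keepAlive, unit, queue); // ThreadPoolExecutor has no no-arg constructor
       }
       @Override
       public void execute(Runnable command)
       {
          // Runs on the submitting thread, so currentThread() is the caller
          super.execute(new Wrapped(command, Thread.currentThread()));
       }
    }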
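    Overriding only `execute()` appears to be enough because `ThreadPoolExecutor` inherits `submit()` from `AbstractExecutorService`, which wraps the task in a `FutureTask` and passes it to `execute()`. A small sketch to check this, using a hypothetical `CountingExecutor` that only counts `execute()` calls:

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical subclass that counts how often execute() is reached.
class CountingExecutor extends ThreadPoolExecutor {
    final AtomicInteger executeCalls = new AtomicInteger();

    CountingExecutor() {
        super(1, 1, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
    }

    @Override
    public void execute(Runnable command) {
        executeCalls.incrementAndGet();
        super.execute(command);
    }

    static int executeCallsForTwoSubmits() {
        CountingExecutor pool = new CountingExecutor();
        try {
            pool.submit(() -> "done").get(5, TimeUnit.SECONDS); // submit(Callable)
            pool.submit(() -> { }).get(5, TimeUnit.SECONDS);    // submit(Runnable)
            return pool.executeCalls.get();
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            pool.shutdownNow();
        }
    }
}
```

    Both `submit()` overloads go through `execute()`, so the count is 2.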
    

    Wrapper:

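       // Nested inside MyThreadPoolExecutor
       private class Wrapped implements Runnable
       {
          private final Runnable task;
    
          private final Thread caller;
    
          public Wrapped(Runnable task, Thread caller)
          {
             this.task = task;
             this.caller = caller;
          }
    
          @Override
          public void run()
          {
             Iterable<ThreadLocal<?>> vars;
             try
             {
                // Install a copy of the caller's inheritable values on this worker thread
                vars = copy(caller);
             }
             catch (Exception e)
             {
                throw new RuntimeException("error when copying thread locals", e);
             }
             try {
                task.run();
             }
             finally {
                // Clear the copied values so they don't leak into the next task on this worker
                for (ThreadLocal<?> var : vars)
                   var.remove();
             }
          }
       }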
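    If the ThreadLocals to propagate are known up front, a reflection-free variant of the same wrapping idea may be simpler. This is a swapped-in technique, not the code above: it captures the value inside `execute()` on the submitting thread, whereas `Wrapped` reads the caller's map when the task starts. `ContextPropagatingExecutor` and `CONTEXT` are hypothetical names.

```java
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

// Hypothetical reflection-free executor that propagates one known ThreadLocal.
class ContextPropagatingExecutor extends ThreadPoolExecutor {
    static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    ContextPropagatingExecutor(int threads) {
        super(threads, threads, 0L, TimeUnit.MILLISECONDS, new LinkedBlockingQueue<>());
    }

    @Override
    public void execute(Runnable command) {
        // Runs on the submitting thread, so this captures the value at submission time.
        final String captured = CONTEXT.get();
        super.execute(() -> {
            String previous = CONTEXT.get(); // the worker's own value, if any
            CONTEXT.set(captured);
            try {
                command.run();
            } finally {
                // Restore the worker's prior state so nothing leaks into the next task.
                if (previous == null) {
                    CONTEXT.remove();
                } else {
                    CONTEXT.set(previous);
                }
            }
        });
    }

    static String[] demo() {
        ContextPropagatingExecutor pool = new ContextPropagatingExecutor(1);
        try {
            CONTEXT.set("first");
            String a = pool.submit(() -> CONTEXT.get()).get(5, TimeUnit.SECONDS);
            CONTEXT.set("second");
            String b = pool.submit(() -> CONTEXT.get()).get(5, TimeUnit.SECONDS);
            return new String[] {a, b};
        } catch (Exception e) {
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
            pool.shutdownNow();
        }
    }
}
```

    Each task sees the value its submitter held at `submit()` time, even on a reused worker. The trade-off is that every propagated ThreadLocal has to be listed explicitly.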
    

    copy method:

    // Requires imports: java.lang.ref.Reference, java.lang.reflect.{Array, Field, Method},
    // java.util.{ArrayList, List}. On JDK 16+ the setAccessible calls throw
    // InaccessibleObjectException unless the JVM runs with --add-opens java.base/java.lang=ALL-UNNAMED
    public static Iterable<ThreadLocal<?>> copy(Thread caller) throws Exception
       {
          List<ThreadLocal<?>> threadLocals = new ArrayList<>();
          Field field = Thread.class.getDeclaredField("inheritableThreadLocals");
          field.setAccessible(true);
          // Runs on the worker, so this reads the caller's map as it is when the task starts
          Object map = field.get(caller);
          if (map == null)
             return threadLocals; // caller has no inheritable values; createInheritedMap fails on null
          Class<?> mapClass = Class.forName("java.lang.ThreadLocal$ThreadLocalMap");
          Field table = mapClass.getDeclaredField("table");
          table.setAccessible(true);
    
          // Same call the Thread constructor makes: copies each entry through childValue()
          Method method = ThreadLocal.class.getDeclaredMethod("createInheritedMap", mapClass);
          method.setAccessible(true);
          Object o = method.invoke(null, map);
    
          // Install the copy as the worker thread's own inheritable map
          field.set(Thread.currentThread(), o);
    
          Object tbl = table.get(o);
          int length = Array.getLength(tbl);
          for (int i = 0; i < length; i++)
          {
             Object entry = Array.get(tbl, i);
             if (entry != null)
             {
                // Entries are WeakReferences to their ThreadLocal key; get() is null once collected
                ThreadLocal<?> key = (ThreadLocal<?>) ((Reference<?>) entry).get();
                if (key != null)
                   threadLocals.add(key);
             }
          }
          return threadLocals;
       }
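    The `createInheritedMap` call in `copy` is what gives this approach thread-creation semantics: in the OpenJDK sources it builds the child's map by passing each parent value through `InheritableThreadLocal.childValue`. A small sketch of that hook on an ordinary `new Thread(...)`, where `ChildValueDemo` and `CONTEXT` are hypothetical names:

```java
class ChildValueDemo {
    // Hypothetical inheritable value whose child copy is tagged, to make the copy visible.
    static final InheritableThreadLocal<String> CONTEXT = new InheritableThreadLocal<String>() {
        @Override
        protected String childValue(String parentValue) {
            return parentValue + "-child";
        }
    };

    static String valueInNewThread() {
        CONTEXT.set("parent");
        try {
            String[] seen = new String[1];
            // The Thread constructor copies CONTEXT, calling childValue("parent").
            Thread t = new Thread(() -> seen[0] = CONTEXT.get());
            t.start();
            t.join(); // join() makes the new thread's write visible here
            return seen[0];
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        } finally {
            CONTEXT.remove();
        }
    }

    public static void main(String[] args) {
        System.out.println(valueInNewThread()); // prints "parent-child"
    }
}
```

    Since `copy` calls `createInheritedMap` on every task run, a custom `childValue` would likewise be applied per task rather than only when the pool thread is created.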
    
