1 package org.testng.internal.thread;
2 
3 import org.testng.collections.Lists;
4 import org.testng.internal.Utils;
5 
6 import java.util.List;
7 import java.util.concurrent.Callable;
8 import java.util.concurrent.CountDownLatch;
9 import java.util.concurrent.ExecutorService;
10 import java.util.concurrent.LinkedBlockingQueue;
11 import java.util.concurrent.ThreadFactory;
12 import java.util.concurrent.ThreadPoolExecutor;
13 import java.util.concurrent.TimeUnit;
14 
15 /**
16  * A helper class to interface TestNG concurrency usage.
17  *
18  * @author <a href="mailto:the_mindstorm@evolva.ro>Alex Popescu</a>
19  */
20 public class ThreadUtil {
21   private static final String THREAD_NAME = "TestNG";
22 
23   /**
24    * @return true if the current thread was created by TestNG.
25    */
isTestNGThread()26   public static boolean isTestNGThread() {
27     return Thread.currentThread().getName().contains(THREAD_NAME);
28   }
29 
30   /**
31    * Parallel execution of the <code>tasks</code>. The startup is synchronized so this method
32    * emulates a load test.
33    * @param tasks the list of tasks to be run
34    * @param threadPoolSize the size of the parallel threads to be used to execute the tasks
35    * @param timeout a maximum timeout to wait for tasks finalization
36    * @param triggerAtOnce <tt>true</tt> if the parallel execution of tasks should be trigger at once
37    */
execute(List<? extends Runnable> tasks, int threadPoolSize, long timeout, boolean triggerAtOnce)38   public static final void execute(List<? extends Runnable> tasks, int threadPoolSize,
39       long timeout, boolean triggerAtOnce) {
40     final CountDownLatch startGate= new CountDownLatch(1);
41     final CountDownLatch endGate= new CountDownLatch(tasks.size());
42 
43     Utils.log("ThreadUtil", 2, "Starting executor timeOut:" + timeout + "ms"
44         + " workers:" + tasks.size() + " threadPoolSize:" + threadPoolSize);
45     ExecutorService pooledExecutor = // Executors.newFixedThreadPool(threadPoolSize);
46         new ThreadPoolExecutor(threadPoolSize, threadPoolSize,
47         timeout, TimeUnit.MILLISECONDS,
48         new LinkedBlockingQueue<Runnable>(),
49         new ThreadFactory() {
50           @Override
51           public Thread newThread(Runnable r) {
52             Thread result = new Thread(r);
53             result.setName(THREAD_NAME);
54             return result;
55           }
56         });
57 
58     List<Callable<Object>> callables = Lists.newArrayList();
59     for (final Runnable task : tasks) {
60       callables.add(new Callable<Object>() {
61 
62         @Override
63         public Object call() throws Exception {
64           task.run();
65           return null;
66         }
67 
68       });
69     }
70     try {
71       if (timeout != 0) {
72         pooledExecutor.invokeAll(callables, timeout, TimeUnit.MILLISECONDS);
73       } else {
74         pooledExecutor.invokeAll(callables);
75       }
76     } catch (InterruptedException handled) {
77       handled.printStackTrace();
78       Thread.currentThread().interrupt();
79     } finally {
80       pooledExecutor.shutdown();
81     }
82   }
83 
84   /**
85    * Returns a readable name of the current executing thread.
86    */
currentThreadInfo()87   public static final String currentThreadInfo() {
88     Thread thread= Thread.currentThread();
89     return String.valueOf(thread.getName() + "@" + thread.hashCode());
90   }
91 
createExecutor(int threadCount, String threadFactoryName)92   public static final IExecutor createExecutor(int threadCount, String threadFactoryName) {
93     return new ExecutorAdapter(threadCount, createFactory(threadFactoryName));
94   }
95 
createFactory(String name)96   private static final IThreadFactory createFactory(String name) {
97     return new ThreadFactoryImpl(name);
98   }
99 
log(int level, String msg)100   private static void log(int level, String msg) {
101     Utils.log("ThreadUtil:" + ThreadUtil.currentThreadInfo(), level, msg);
102   }
103 
104   public static class ThreadFactoryImpl implements IThreadFactory, ThreadFactory {
105     private String m_methodName;
106     private List<Thread> m_threads = Lists.newArrayList();
107 
ThreadFactoryImpl(String name)108     public ThreadFactoryImpl(String name) {
109       m_methodName= name;
110     }
111 
112     @Override
newThread(Runnable run)113     public Thread newThread(Runnable run) {
114       Thread result = new TestNGThread(run, m_methodName);
115       m_threads.add(result);
116       return result;
117     }
118 
119     @Override
getThreadFactory()120     public Object getThreadFactory() {
121       return this;
122     }
123 
124     @Override
getThreads()125     public List<Thread> getThreads() {
126       return m_threads;
127     }
128   }
129 }
130