--- /home/cpdev/src/classpath/java/lang/InheritableThreadLocal.java	2005-07-02 21:03:32.000000000 +0000
+++ java/lang/InheritableThreadLocal.java	2005-06-30 05:34:38.000000000 +0000
@@ -37,7 +37,11 @@
 
 package java.lang;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
 import java.util.WeakHashMap;
 
 /**
@@ -60,11 +64,30 @@
 public class InheritableThreadLocal extends ThreadLocal
 {
   /**
+   * Maps Threads to a List of InheritableThreadLocals (the heritage of that
+   * Thread). Uses a WeakHashMap so if the Thread is garbage collected the
+   * List can be collected, too. Maps to a list in case the user overrides
+   * equals.
+   */
+  private static final Map threadMap
+	  = Collections.synchronizedMap(new WeakHashMap());
+
+  /**
    * Creates a new InheritableThreadLocal that has no values associated
    * with it yet.
    */
   public InheritableThreadLocal()
   {
+    Thread currentThread = Thread.currentThread();
+    // Note that we don't have to synchronize, as only this thread will
+    // ever modify the returned heritage and threadMap is a synchronizedMap.
+    List heritage = (List) threadMap.get(currentThread);
+    if (heritage == null)
+      {
+        heritage = new ArrayList();
+        threadMap.put(currentThread, heritage);
+      }
+    heritage.add(this);
   }
 
   /**
@@ -93,22 +116,25 @@
   {
     // The currentThread is the parent of the new thread.
     Thread parentThread = Thread.currentThread();
-    if (parentThread.locals != null)
+    // Note that we don't have to synchronize, as only this thread will
+    // ever modify the returned heritage and threadMap is a synchronizedMap. 
+    ArrayList heritage = (ArrayList) threadMap.get(parentThread);
+    if (heritage != null)
       {
-        Iterator keys = parentThread.locals.keySet().iterator();
-        while (keys.hasNext())
+        threadMap.put(childThread, heritage.clone());
+        // Perform the inheritance.
+        Iterator it = heritage.iterator();
+        int i = heritage.size();
+        while (--i >= 0)
           {
-            Key key = (Key)keys.next();
-            if (key.get() instanceof InheritableThreadLocal)
+            InheritableThreadLocal local = (InheritableThreadLocal) it.next();
+            Object parentValue = local.valueMap.get(parentThread);
+            if (parentValue != null)
               {
-                InheritableThreadLocal local = (InheritableThreadLocal)key.get();
-                Object parentValue = parentThread.locals.get(key);
                 Object childValue = local.childValue(parentValue == NULL
                                                      ? null : parentValue);
-                if (childThread.locals == null)
-                    childThread.locals = new WeakHashMap();
-                childThread.locals.put(key, (childValue == null
-                                             ? NULL : childValue));
+                local.valueMap.put(childThread, (childValue == null
+                                                 ? NULL : parentValue));
               }
           }
       }
