diff --git a/src/java.base/share/classes/java/util/IdentityHashMap.java b/src/java.base/share/classes/java/util/IdentityHashMap.java
index 4795c30b3d5ee..83428c1a62c77 100644
--- a/src/java.base/share/classes/java/util/IdentityHashMap.java
+++ b/src/java.base/share/classes/java/util/IdentityHashMap.java
@@ -31,6 +31,7 @@
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
 import java.util.function.Consumer;
+import java.util.function.Function;
 import jdk.internal.access.SharedSecrets;
 
 /**
@@ -308,6 +309,43 @@ private static int nextKeyIndex(int i, int len) {
         return (i + 2 < len ? i + 2 : 0);
     }
 
+    /**
+     * Finds an interesting index for a key such that the key is present
+     * at that index in the table, or an open index where the key could
+     * be stored. Guaranteed that {@code tab[ret] == null} or {@code
+     * tab[ret] == key}.
+     *
+     * @param tab the hash table
+     * @param key the key, null-masked
+     * @return the index found
+     */
+    private static int findInterestingIndex(Object[] tab, Object key) {
+        final int len = tab.length;
+        Object item;
+        int i = hash(key, len);
+        while ((item = tab[i]) != key && item != null) {
+            i = nextKeyIndex(i, len);
+        }
+        return i;
+    }
+
+    /**
+     * Finds an open index where the key could be stored. Guaranteed that
+     * {@code tab[ret] == null}.
+     *
+     * @param tab the hash table
+     * @param key the key, null-masked
+     * @return the index found
+     */
+    private static int findOpenIndex(Object[] tab, Object key) {
+        final int len = tab.length;
+        int i = hash(key, len);
+        while (tab[i] != null) {
+            i = nextKeyIndex(i, len);
+        }
+        return i;
+    }
+
     /**
      * Returns the value to which the specified key is mapped,
      * or {@code null} if this map contains no mapping for the key.
@@ -325,20 +363,8 @@ private static int nextKeyIndex(int i, int len) {
      *
      * @see #put(Object, Object)
      */
-    @SuppressWarnings("unchecked")
     public V get(Object key) {
-        Object k = maskNull(key);
-        Object[] tab = table;
-        int len = tab.length;
-        int i = hash(k, len);
-        while (true) {
-            Object item = tab[i];
-            if (item == k)
-                return (V) tab[i + 1];
-            if (item == null)
-                return null;
-            i = nextKeyIndex(i, len);
-        }
+        return getOrDefault(key, null);
     }
 
     /**
@@ -353,16 +379,7 @@ public V get(Object key) {
     public boolean containsKey(Object key) {
         Object k = maskNull(key);
         Object[] tab = table;
-        int len = tab.length;
-        int i = hash(k, len);
-        while (true) {
-            Object item = tab[i];
-            if (item == k)
-                return true;
-            if (item == null)
-                return false;
-            i = nextKeyIndex(i, len);
-        }
+        return tab[findInterestingIndex(tab, k)] == k;
     }
 
     /**
@@ -394,16 +411,8 @@ public boolean containsValue(Object value) {
     private boolean containsMapping(Object key, Object value) {
         Object k = maskNull(key);
         Object[] tab = table;
-        int len = tab.length;
-        int i = hash(k, len);
-        while (true) {
-            Object item = tab[i];
-            if (item == k)
-                return tab[i + 1] == value;
-            if (item == null)
-                return false;
-            i = nextKeyIndex(i, len);
-        }
+        int i = findInterestingIndex(tab, k);
+        return tab[i] == k && tab[i + 1] == value;
     }
 
     /**
@@ -422,35 +431,7 @@ private boolean containsMapping(Object key, Object value) {
      * @see     #containsKey(Object)
      */
     public V put(K key, V value) {
-        final Object k = maskNull(key);
-
-        retryAfterResize: for (;;) {
-            final Object[] tab = table;
-            final int len = tab.length;
-            int i = hash(k, len);
-
-            for (Object item; (item = tab[i]) != null;
-                 i = nextKeyIndex(i, len)) {
-                if (item == k) {
-                    @SuppressWarnings("unchecked")
-                        V oldValue = (V) tab[i + 1];
-                    tab[i + 1] = value;
-                    return oldValue;
-                }
-            }
-
-            final int s = size + 1;
-            // Use optimized form of 3 * s.
-            // Next capacity is len, 2 * current capacity.
-            if (s + (s << 1) > len && resize(len))
-                continue retryAfterResize;
-
-            modCount++;
-            tab[i] = k;
-            tab[i + 1] = value;
-            size = s;
-            return null;
-        }
+        return put(key, value, true);
     }
 
     /**
@@ -481,9 +462,7 @@ private boolean resize(int newCapacity) {
                 Object value = oldTable[j+1];
                 oldTable[j] = null;
                 oldTable[j+1] = null;
-                int i = hash(key, newLength);
-                while (newTable[i] != null)
-                    i = nextKeyIndex(i, newLength);
+                int i = findOpenIndex(newTable, key);
                 newTable[i] = key;
                 newTable[i + 1] = value;
             }
@@ -523,25 +502,14 @@ public void putAll(Map<? extends K, ? extends V> m) {
     public V remove(Object key) {
         Object k = maskNull(key);
         Object[] tab = table;
-        int len = tab.length;
-        int i = hash(k, len);
-
-        while (true) {
-            Object item = tab[i];
-            if (item == k) {
-                modCount++;
-                size--;
-                @SuppressWarnings("unchecked")
-                    V oldValue = (V) tab[i + 1];
-                tab[i + 1] = null;
-                tab[i] = null;
-                closeDeletion(i);
-                return oldValue;
-            }
-            if (item == null)
-                return null;
-            i = nextKeyIndex(i, len);
-        }
+        int i = findInterestingIndex(tab, k);
+        if (tab[i] == k) {
+            @SuppressWarnings("unchecked")
+            V oldValue = (V) tab[i + 1];
+            delete(tab, i);
+            return oldValue;
+        } else
+            return null;
     }
 
     /**
@@ -555,25 +523,14 @@ public V remove(Object key) {
     private boolean removeMapping(Object key, Object value) {
         Object k = maskNull(key);
         Object[] tab = table;
-        int len = tab.length;
-        int i = hash(k, len);
-
-        while (true) {
-            Object item = tab[i];
-            if (item == k) {
-                if (tab[i + 1] != value)
-                    return false;
-                modCount++;
-                size--;
-                tab[i] = null;
-                tab[i + 1] = null;
-                closeDeletion(i);
-                return true;
-            }
-            if (item == null)
+        int i = findInterestingIndex(tab, k);
+        if (tab[i] == k) {
+            if (tab[i + 1] != value)
                 return false;
-            i = nextKeyIndex(i, len);
-        }
+            delete(tab, i);
+            return true;
+        } else
+            return false;
     }
 
     /**
@@ -1379,6 +1336,277 @@ public void replaceAll(BiFunction<? super K, ? super V, ? extends V> function) {
         }
     }
 
+    @SuppressWarnings("unchecked")
+    @Override
+    public V getOrDefault(Object key, V fallback) {
+        Object k = maskNull(key);
+        Object[] tab = table;
+        int i = findInterestingIndex(tab, k);
+        return tab[i] == k ? (V) tab[i + 1] : fallback;
+    }
+
+    @Override
+    public V putIfAbsent(K key, V value) {
+        return put(key, value, false);
+    }
+
+    @Override
+    public V replace(K key, V value) {
+        Object k = maskNull(key);
+        Object[] tab = table;
+        int i = findInterestingIndex(tab, k);
+        if (tab[i] == k) {
+            @SuppressWarnings("unchecked")
+            final V prev = (V) tab[i + 1];
+            tab[i + 1] = value;
+            return prev;
+        } else
+            return null;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>This method will, on a best-effort basis, throw a
+     * {@link ConcurrentModificationException} if it is detected that the
+     * mapping function modifies this map during computation.
+     *
+     * @throws ConcurrentModificationException if it is detected that the
+     * mapping function modified this map
+     */
+    @Override
+    public V computeIfAbsent(K key, Function<? super K, ? extends V> mappingFunction) {
+        Objects.requireNonNull(mappingFunction);
+
+        Object k = maskNull(key);
+        Object[] tab = table;
+        int i = findInterestingIndex(tab, k);
+        if (tab[i] == k) {
+            @SuppressWarnings("unchecked")
+            final V oldValue = (V) tab[i + 1];
+            if (oldValue != null)
+                return oldValue;
+
+            // replace null old value, per specification
+            final V newValue = callFunction(key, mappingFunction);
+            if (newValue != null) {
+                tab[i + 1] = newValue;
+            }
+            return newValue;
+        } else
+            return maybeAddNewEntry(tab, i, k, callFunction(key, mappingFunction));
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>This method will, on a best-effort basis, throw a
+     * {@link ConcurrentModificationException} if it is detected that the
+     * remapping function modifies this map during computation.
+     *
+     * @throws ConcurrentModificationException if it is detected that the
+     * remapping function modified this map
+     */
+    @Override
+    public V computeIfPresent(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
+        Objects.requireNonNull(remappingFunction);
+
+        Object k = maskNull(key);
+        Object[] tab = table;
+        int i = findInterestingIndex(tab, k);
+        if (tab[i] == k) {
+            @SuppressWarnings("unchecked")
+            final V oldValue = (V) tab[i + 1];
+            if (oldValue == null) {
+                return null;
+            }
+
+            return updateByNewValue(tab, i, callFunction(key, oldValue, remappingFunction));
+        } else
+            return null;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>This method will, on a best-effort basis, throw a
+     * {@link ConcurrentModificationException} if it is detected that the
+     * remapping function modifies this map during computation.
+     *
+     * @throws ConcurrentModificationException if it is detected that the
+     * remapping function modified this map
+     */
+    @Override
+    public V compute(K key, BiFunction<? super K, ? super V, ? extends V> remappingFunction) {
+        Objects.requireNonNull(remappingFunction);
+
+        Object k = maskNull(key);
+        Object[] tab = table;
+        int i = findInterestingIndex(tab, k);
+        if (tab[i] == k) {
+            @SuppressWarnings("unchecked")
+            final V oldValue = (V) tab[i + 1];
+            return updateByNewValue(tab, i, callFunction(key, oldValue, remappingFunction));
+        } else
+            return maybeAddNewEntry(tab, i, k, callFunction(key, null, remappingFunction));
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>This method will, on a best-effort basis, throw a
+     * {@link ConcurrentModificationException} if it is detected that the
+     * remapping function modifies this map during computation.
+     *
+     * @throws ConcurrentModificationException if it is detected that the
+     * remapping function modified this map
+     */
+    @Override
+    public V merge(K key, V value, BiFunction<? super V, ? super V, ? extends V> remappingFunction) {
+        Objects.requireNonNull(value);
+        Objects.requireNonNull(remappingFunction);
+
+        Object k = maskNull(key);
+        Object[] tab = table;
+        int i = findInterestingIndex(tab, k);
+        if (tab[i] == k) {
+            @SuppressWarnings("unchecked")
+            final V oldValue = (V) tab[i + 1];
+            return updateByNewValue(tab, i, mergeValue(oldValue, value, remappingFunction));
+        } else
+            return maybeAddNewEntry(tab, i, k, value);
+    }
+
+    private V callFunction(K key, Function<? super K, ? extends V> function) {
+        final int expectedModCount = modCount;
+        V result = function.apply(key);
+        if (expectedModCount != modCount) {
+            throw new ConcurrentModificationException();
+        }
+        return result;
+    }
+
+    private V callFunction(K key, V value, BiFunction<? super K, ? super V, ? extends V> function) {
+        final int expectedModCount = modCount;
+        V result = function.apply(key, value);
+        if (expectedModCount != modCount) {
+            throw new ConcurrentModificationException();
+        }
+        return result;
+    }
+
+    private V mergeValue(V oldValue, V value, BiFunction<? super V, ? super V, ? extends V> function) {
+        if (oldValue == null) {
+            return value;
+        }
+
+        final int expectedModCount = modCount;
+        V result = function.apply(oldValue, value);
+        if (expectedModCount != modCount) {
+            throw new ConcurrentModificationException();
+        }
+        return result;
+    }
+
+    private V updateByNewValue(Object[] tab, int i, V newValue) {
+        if (newValue != null) {
+            tab[i + 1] = newValue;
+        } else {
+            delete(tab, i);
+        }
+        return newValue;
+    }
+
+    /**
+     * Deletes a mapping from this map's table. Increases modCount as it
+     * changes map size.
+     *
+     * @param tab the table, should be equivalent to {@code this.table}
+     * @param i the index of the object to delete
+     */
+    private void delete(Object[] tab, int i) {
+        modCount++;
+        size--;
+        tab[i] = null;
+        tab[i + 1] = null;
+        closeDeletion(i);
+    }
+
+    /**
+     * Shared implementation of put and putIfAbsent.
+     *
+     * @param key key with which the specified value is to be associated
+     * @param value value to be associated with the specified key
+     * @param replace whether a non-null existing value is to be replaced if the key is present
+     * @return the value associated to the key before the call, or {@code null} if
+     * there was no previously associated value
+     */
+    private V put(K key, V value, boolean replace) {
+        final Object k = maskNull(key);
+        Object[] tab = table;
+        int i = findInterestingIndex(tab, k);
+        if (tab[i] == k) {
+            @SuppressWarnings("unchecked")
+            V oldValue = (V) tab[i + 1];
+            if (replace || oldValue == null) {
+                tab[i + 1] = value;
+            }
+            return oldValue;
+        }
+
+        addNewEntry(tab, i, k, value);
+        return null;
+    }
+
+    /**
+     * Adds an entry to this map if and only if {@code newValue} is not null.
+     *
+     * @param tab the hash table of this map, may be reused
+     * @param i the current index of k in the table
+     * @param k the key
+     * @param newValue the value
+     * @return the value
+     */
+    private V maybeAddNewEntry(Object[] tab, int i, Object k, V newValue) {
+        if (newValue == null) {
+            return null;
+        }
+
+        return addNewEntry(tab, i, k, newValue);
+    }
+
+    /**
+     * Adds a new entry to this map, associating {@code k} with the {@code newValue}.
+     * Accepts null values. Increases modCount as it changes map size.
+     *
+     * @param tab the hash table of this map, may be reused
+     * @param i the current index of k in the table
+     * @param k the key
+     * @param newValue the value
+     * @return the value
+     */
+    private V addNewEntry(Object[] tab, int i, Object k, V newValue) {
+        int len = tab.length;
+        do {
+            final int s = size + 1;
+            // Use optimized form of 3 * s.
+            // Next capacity is len, 2 * current capacity.
+            if (!(s + (s << 1) > len && resize(len)))
+                break;
+
+            tab = table;
+            len = tab.length;
+            // findInterestingIndex should return the same value here
+            i = findOpenIndex(tab, k);
+        } while (true);
+
+        modCount++;
+        tab[i] = k;
+        tab[i + 1] = newValue;
+        size++;
+        return newValue;
+    }
+
     /**
      * Similar form as array-based Spliterators, but skips blank elements,
      * and guestimates size as decreasing by half per split.
diff --git a/test/jdk/java/util/Map/FunctionalCMEs.java b/test/jdk/java/util/Map/FunctionalCMEs.java
index 51cc85efe80f7..f42d5c84c4e05 100644
--- a/test/jdk/java/util/Map/FunctionalCMEs.java
+++ b/test/jdk/java/util/Map/FunctionalCMEs.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2015, 2020, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2015, 2021, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -27,6 +27,7 @@
 import java.util.ConcurrentModificationException;
 import java.util.HashMap;
 import java.util.Hashtable;
+import java.util.IdentityHashMap;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.Map;
@@ -55,6 +56,7 @@ private static Iterator<Object[]> makeMaps() {
                 new Object[]{new Hashtable<>(), true},
                 new Object[]{new LinkedHashMap<>(), true},
                 new Object[]{new TreeMap<>(), true},
+                new Object[]{new IdentityHashMap<>(), true},
                 // Test default Map methods - no CME
                 new Object[]{new Defaults.ExtendsAbstractMap<>(), false}
         ).iterator();
diff --git a/test/micro/org/openjdk/bench/java/util/IdentityHashMapBench.java b/test/micro/org/openjdk/bench/java/util/IdentityHashMapBench.java
new file mode 100644
index 0000000000000..54d860e703e5b
--- /dev/null
+++ b/test/micro/org/openjdk/bench/java/util/IdentityHashMapBench.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright (c) 2018, Red Hat, Inc. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.bench.java.util;
+
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.TearDown;
+import org.openjdk.jmh.infra.Blackhole;
+
+import java.util.IdentityHashMap;
+import java.util.Map;
+import java.util.concurrent.ThreadLocalRandom;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.MILLISECONDS)
+@State(Scope.Thread)
+public class IdentityHashMapBench {
+    private Supplier<Map<Object, Object>> mapSupplier;
+    private Object[] objects;
+    // orders, 2 * i th is key, 2* i+1 th is value
+    private Object[] orders;
+
+    @Param("1000000")
+    private int orderSize;
+
+    @Param("100000")
+    private int size;
+
+    @Setup @TearDown
+    public void setup() {
+        mapSupplier = IdentityHashMap::new;
+
+        {
+            final int size = this.size;
+            final Object[] objects = new Object[size];
+            for (int i = 0; i < size; i++) {
+                objects[i] = new Object();
+            }
+            this.objects = objects;
+        }
+
+        {
+            ThreadLocalRandom rnd = ThreadLocalRandom.current();
+            final int poolSize = this.size;
+            final Object[] objects = this.objects;
+            final int size = this.orderSize;
+            final Object[] orders = new Object[size];
+            for (int i = 0; i < size; i++) {
+                orders[i] = objects[rnd.nextInt(poolSize)];
+            }
+            this.orders = orders;
+        }
+    }
+
+    @Benchmark
+    public int putBench(Blackhole blackhole) {
+        var map = mapSupplier.get();
+        final Object[] data = this.objects;
+        final int len = data.length;
+        for (int i = 0; i < len; i += 2) {
+            blackhole.consume(map.put(data[i], data[i + 1]));
+        }
+        return map.size();
+    }
+
+    @Benchmark
+    public int putIfAbsentBench(Blackhole blackhole) {
+        var map = mapSupplier.get();
+        final Object[] data = this.objects;
+        final int len = data.length;
+        for (int i = 0; i < len; i += 2) {
+            blackhole.consume(map.putIfAbsent(data[i], data[i + 1]));
+        }
+        return map.size();
+    }
+}