flatMap in Java

Mohammad Irfan Oct 12, 2023
  1. The flatMap Function in Java
  2. How To Use flatMap() Method in Java
  3. Remove Duplicate From flatMap in Java
  4. Filter flatMap Elements in Java
  5. flatMap and Primitive type
  6. Conclusion

This tutorial introduces flatMap and how to use it in Java.

flatMap() is an operation on Java streams that applies a function to each element and returns a new stream built from the results. Here, we will talk about the flatMap() operation.

This operation is an extension of the map() operation: its mapping function is applied to each stream element and, instead of a single value, produces a stream of new values.

The elements of these generated streams are then combined into a single new stream, which serves as the method's return value.

The flatMap Function in Java

The flatMap() method is declared in the Stream<T> interface, where T is the type of the stream's elements. Its signature is:

<R> Stream<R> flatMap(Function<? super T,? extends Stream<? extends R>> mapper)

flatMap() is an intermediate operation. Intermediate operations are invoked on a Stream instance and return a new Stream instance, and they are lazy: no elements are processed until a terminal operation, such as collect() or forEach(), is executed.
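To make the laziness concrete, the following sketch counts how many times the flatMap() mapper runs before and after the terminal collect() call. The class LazyFlatMap and its method name are illustrative only, not part of any library.

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Illustrative sketch: counts mapper invocations before and after the terminal operation.
class LazyFlatMap {
  static int[] mapperCallsBeforeAndAfterCollect() {
    AtomicInteger calls = new AtomicInteger();

    // Building the pipeline only describes the work; the mapper has not run yet.
    Stream<String> words = Stream.of("I Love", "Delft Stack")
        .flatMap(line -> {
          calls.incrementAndGet();
          return Arrays.stream(line.split(" "));
        });
    int before = calls.get();

    // collect() is the terminal operation that actually pulls elements through.
    List<String> result = words.collect(Collectors.toList());
    int after = calls.get();

    System.out.println(result);
    return new int[] {before, after};
  }

  public static void main(String[] args) {
    int[] calls = mapperCallsBeforeAndAfterCollect();
    System.out.println("before collect: " + calls[0] + ", after collect: " + calls[1]);
  }
}
```

Running it prints [I, Love, Delft, Stack] followed by before collect: 0, after collect: 2.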

Conceptually, the flatMap() operation combines map() with a flattening step: it first maps each element to a stream, as map() would, and then flattens those streams into one. Note that Java's Stream API has no standalone flat() method.
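The difference between map() and flatMap() shows up in the result type. The sketch below applies the same word-splitting function with both operations; the class and method names are hypothetical, and List.of assumes Java 9 or later.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Sketch comparing map() and flatMap() with the same mapping function,
// which splits a line into a stream of words.
class MapVersusFlatMap {
  // map(): every line becomes its own Stream<String>, so the result is nested.
  static List<Stream<String>> withMap(List<String> lines) {
    return lines.stream()
        .map(line -> Arrays.stream(line.split(" ")))
        .collect(Collectors.toList());
  }

  // flatMap(): the per-line streams are merged into one Stream<String>.
  static List<String> withFlatMap(List<String> lines) {
    return lines.stream()
        .flatMap(line -> Arrays.stream(line.split(" ")))
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<String> lines = List.of("I Love", "Delft Stack");
    System.out.println(withMap(lines).size()); // 2: one inner stream per line
    System.out.println(withFlatMap(lines));    // [I, Love, Delft, Stack]
  }
}
```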

This method takes a mapper, a function that is applied to each element of the incoming stream and must return a Stream (possibly empty) of results for that element.
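Because the mapper only has to return a Stream, it can return an empty stream to drop an element or a multi-element stream to expand one. The following sketch, built around a hypothetical parseValid() helper, uses Stream.empty() to skip strings that are not valid integers, so flatMap() acts as a combined filter and map. It assumes Java 9 or later for List.of.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

// Sketch: the mapper may return a stream with zero, one, or many elements.
// Strings that are not valid integers map to an empty stream and disappear.
class ZeroOrMoreMapper {
  static List<Integer> parseValid(List<String> inputs) {
    return inputs.stream()
        .flatMap(s -> {
          try {
            return Stream.of(Integer.parseInt(s.trim())); // exactly one element
          } catch (NumberFormatException e) {
            return Stream.<Integer>empty();               // zero elements
          }
        })
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    System.out.println(parseValid(List.of("1", "two", " 3 "))); // [1, 3]
  }
}
```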

Flattening in Java

Let us first understand what flattening is. Flattening converts a nested list, that is, a list with two or more levels, into a single-level list.

An example of a two-level list is:

[[ "I" ], [ "Love" ], [ "Delft", "Stack" ]]

The above list after flattening gets converted into:

["I", "Love", "Delft", "Stack"]

The list produced is a single-level list.
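The flattening shown above can be expressed with flatMap() by mapping each inner list to its own stream. This is a minimal sketch that assumes Java 9 or later for List.of; FlattenTwoLevels is an illustrative name.

```java
import java.util.List;
import java.util.stream.Collectors;

// Sketch: flattening the two-level list from the text with flatMap().
class FlattenTwoLevels {
  static List<String> flatten(List<List<String>> twoLevel) {
    return twoLevel.stream()      // Stream<List<String>>
        .flatMap(List::stream)    // Stream<String>
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<List<String>> twoLevel =
        List.of(List.of("I"), List.of("Love"), List.of("Delft", "Stack"));
    System.out.println(flatten(twoLevel)); // [I, Love, Delft, Stack]
  }
}
```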

Need for Flattening a List

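Flattening is needed because working with nested structures, such as a stream of lists or a stream of streams, is harder: every later operation has to handle both levels, which makes the code more complicated and error-prone.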

We can use the Stream.flatMap() operation to convert the two stream levels into a single-level stream. We will see this in an example later in this article.
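For lists nested more deeply, each flatMap() call removes exactly one level, so flattening a three-level list takes two calls. A minimal sketch, again assuming Java 9 or later and using an illustrative class name:

```java
import java.util.List;
import java.util.stream.Collectors;

// Sketch: each flatMap() call removes one level of nesting,
// so a three-level list needs two of them.
class FlattenThreeLevels {
  static List<String> flatten(List<List<List<String>>> threeLevel) {
    return threeLevel.stream()   // Stream<List<List<String>>>
        .flatMap(List::stream)   // Stream<List<String>>
        .flatMap(List::stream)   // Stream<String>
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<List<List<String>>> threeLevel = List.of(
        List.of(List.of("I"), List.of("Love")),
        List.of(List.of("Delft", "Stack")));
    System.out.println(flatten(threeLevel)); // [I, Love, Delft, Stack]
  }
}
```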

How To Use flatMap() Method in Java

In this example, we create a stream from a List of objects using the stream() method, where each object represents a programmer in a company.

First, we create a class to represent a programmer.

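import java.util.HashSet;
import java.util.Set;

// A programmer and the set of languages they know.
// HashSet ignores duplicate languages but does not guarantee iteration order.
class Programmer {
  private final String name;
  private final Set<String> languagesKnown;

  public Programmer(String name) {
    this.name = name;
    this.languagesKnown = new HashSet<>();
  }

  public String getName() {
    return name;
  }

  public void addLanguage(String lang) {
    this.languagesKnown.add(lang);
  }

  public Set<String> getLanguages() {
    return languagesKnown;
  }
}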

We now create a few Programmer objects and add them to a list. Mapping each programmer to their languages gives a stream of sets (one Set<String> per programmer); we then flatten it to get all the languages known across the team.

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

public class SimpleTesting {
  public static void main(String[] args) {
    Programmer raj = new Programmer("Raj");
    raj.addLanguage("Java");
    raj.addLanguage("Dart");
    raj.addLanguage("go");
    raj.addLanguage("groovy");

    Programmer karan = new Programmer("Karan");
    karan.addLanguage("Java");
    karan.addLanguage("Python");

    Programmer chahal = new Programmer("Chahal");
    chahal.addLanguage("Dart");
    chahal.addLanguage("Javascript");

    List<Programmer> team = new ArrayList<>();
    team.add(raj);
    team.add(karan);
    team.add(chahal);

    System.out.println("Programming languages in the team: ");
    List<String> languages = team.stream()
                                 .map(Programmer::getLanguages)
                                 .flatMap(Collection::stream)
                                 .collect(Collectors.toList());
    System.out.println(languages);
  }
}

Output:

Programming languages in the team: 
[Java, groovy, go, Dart, Java, Python, Javascript, Dart]

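In the above example, we first created a stream of all the programmers by calling stream() on the team list. The map() operation then turned it into a stream of Set<String> objects, one per programmer, and flatMap(Collection::stream) flattened those sets into a single stream of language names, which we collected into a list. The order of each programmer's languages follows HashSet iteration order, which is not guaranteed and can differ from the order in which the languages were added.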
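The map() and flatMap() steps can also be merged into a single flatMap() call with a lambda that returns each programmer's languages as a stream. In this sketch, Programmer is a hypothetical, trimmed-down nested stand-in for the class above; it stores languages in a List so the output order is predictable, and List.of assumes Java 9 or later.

```java
import java.util.List;
import java.util.stream.Collectors;

// Sketch: flatMap() with a lambda does map-then-flatten in one step.
class FlatMapWithLambda {
  // Hypothetical stand-in for the article's Programmer class.
  static class Programmer {
    private final List<String> languages;

    Programmer(List<String> languages) {
      this.languages = languages;
    }

    List<String> getLanguages() {
      return languages;
    }
  }

  static List<String> allLanguages(List<Programmer> team) {
    return team.stream()
        .flatMap(p -> p.getLanguages().stream()) // no separate map() step
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<Programmer> team = List.of(
        new Programmer(List.of("Java", "Dart")),
        new Programmer(List.of("Java", "Python")));
    System.out.println(allLanguages(team)); // [Java, Dart, Java, Python]
  }
}
```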

Remove Duplicate From flatMap in Java

The list produced above contains duplicate values: Java and Dart each appear twice, because flatMap() keeps every element of every inner set. We can add the distinct() operation after flatMap() to eliminate them.

Look at the code below.

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

public class SimpleTesting {
  public static void main(String[] args) {
    Programmer raj = new Programmer("Raj");
    raj.addLanguage("Java");
    raj.addLanguage("Dart");
    raj.addLanguage("go");
    raj.addLanguage("groovy");
    Programmer karan = new Programmer("Karan");
    karan.addLanguage("Java");
    karan.addLanguage("Python");
    Programmer chahal = new Programmer("Chahal");
    chahal.addLanguage("Dart");
    chahal.addLanguage("Javascript");
    List<Programmer> team = new ArrayList<>();
    team.add(raj);
    team.add(karan);
    team.add(chahal);
    System.out.println("Programming languages in the team: ");
    List<String> languages = team.stream()
                                 .map(Programmer::getLanguages)
                                 .flatMap(Collection::stream)
                                 .distinct()
                                 .collect(Collectors.toList());
    System.out.println(languages);
  }
}

Output:

Programming languages in the team: 
[Java, groovy, go, Dart, Python, Javascript]
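As an alternative to distinct(), the flattened stream can be collected straight into a Set. The sketch below collects into a TreeSet ordered by String.CASE_INSENSITIVE_ORDER, which removes duplicates and sorts the result in one step. The class name is illustrative, each programmer's languages are passed as a plain List<String> rather than Programmer objects, and List.of assumes Java 9 or later.

```java
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

// Sketch: a TreeSet both drops duplicates and keeps its elements sorted.
class UniqueSortedLanguages {
  static Set<String> uniqueSorted(List<List<String>> languagesPerProgrammer) {
    return languagesPerProgrammer.stream()
        .flatMap(List::stream)
        // CASE_INSENSITIVE_ORDER sorts "go" before "Java" instead of
        // putting all upper-case names first.
        .collect(Collectors.toCollection(() -> new TreeSet<>(String.CASE_INSENSITIVE_ORDER)));
  }

  public static void main(String[] args) {
    List<List<String>> perProgrammer = List.of(
        List.of("Java", "Dart", "go", "groovy"),
        List.of("Java", "Python"),
        List.of("Dart", "Javascript"));
    System.out.println(uniqueSorted(perProgrammer)); // [Dart, go, groovy, Java, Javascript, Python]
  }
}
```

One design consequence: a case-insensitive TreeSet also treats, say, "java" and "Java" as the same entry, which may or may not be what you want.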

Filter flatMap Elements in Java

If we want to get all the languages except Dart, we can use the filter() function with flatMap(). Look at the code below.

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;

public class SimpleTesting {
  public static void main(String[] args) {
    Programmer raj = new Programmer("Raj");
    raj.addLanguage("Java");
    raj.addLanguage("Dart");
    raj.addLanguage("go");
    raj.addLanguage("groovy");
    Programmer karan = new Programmer("Karan");
    karan.addLanguage("Java");
    karan.addLanguage("Python");
    Programmer chahal = new Programmer("Chahal");
    chahal.addLanguage("Dart");
    chahal.addLanguage("Javascript");
    List<Programmer> team = new ArrayList<>();
    team.add(raj);
    team.add(karan);
    team.add(chahal);
    System.out.println("Programming languages in the team: ");
    List<String> languages = team.stream()
                                 .map(Programmer::getLanguages)
                                 .flatMap(Collection::stream)
                                 .distinct()
                                 .filter(x -> !x.equals("Dart"))
                                 .collect(Collectors.toList());
    System.out.println(languages);
  }
}

Output:

Programming languages in the team:
[Java, groovy, go, Python, Javascript]
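filter() can also run before flatMap(), in which case it keeps or drops whole outer elements rather than individual languages. This sketch keeps only programmers who know Java and then flattens their languages. The names are hypothetical, each programmer is represented as a plain List<String> of languages, and List.of assumes Java 9 or later.

```java
import java.util.List;
import java.util.stream.Collectors;

// Sketch: filtering the outer elements first, then flattening.
class FilterBeforeFlatMap {
  static List<String> languagesOfJavaProgrammers(List<List<String>> perProgrammer) {
    return perProgrammer.stream()
        .filter(langs -> langs.contains("Java")) // keep or drop whole programmers
        .flatMap(List::stream)
        .distinct()
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<List<String>> perProgrammer = List.of(
        List.of("Java", "Dart", "go", "groovy"),
        List.of("Java", "Python"),
        List.of("Dart", "Javascript"));
    System.out.println(languagesOfJavaProgrammers(perProgrammer)); // [Java, Dart, go, groovy, Python]
  }
}
```

With these lists, the third programmer is dropped entirely, so Javascript never reaches flatMap().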

flatMap and Primitive type

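The Java Stream API also provides separate operations for primitive types: flatMapToInt(), flatMapToLong(), and flatMapToDouble(). They flatten each element into an IntStream, LongStream, or DoubleStream respectively, so the values stay unboxed. There is no float variant.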

import java.util.Arrays;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public class SimpleTesting {
  public static void main(String[] args) {
    int[] numbers = {3, 5, 2, 35, 85, 32, 23, 43, 12};
    // Stream.of() with a single int[] produces a Stream<int[]> holding one element.
    Stream<int[]> arrayStream = Stream.of(numbers);
    // flatMapToInt() turns each int[] into an IntStream and merges them.
    IntStream intStream = arrayStream.flatMapToInt(Arrays::stream);
    intStream.forEach(System.out::println);
  }
}

Output:

3
5
2
35
85
32
23
43
12
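In the example above, Stream.of(numbers) produces a stream holding just one array, so there is little to flatten. This sketch, with an illustrative class name, passes several arrays (including an empty one) and sums everything with IntStream.sum(), which works on the unboxed values directly.

```java
import java.util.Arrays;
import java.util.stream.Stream;

// Sketch: flatMapToInt() merges several int[] arrays into one IntStream.
class FlatMapToIntSum {
  static int sumAll(Stream<int[]> arrays) {
    return arrays
        .flatMapToInt(Arrays::stream) // each int[] -> IntStream, no boxing
        .sum();
  }

  public static void main(String[] args) {
    Stream<int[]> arrays = Stream.of(new int[] {3, 5, 2}, new int[] {}, new int[] {35, 85});
    System.out.println(sumAll(arrays)); // 130
  }
}
```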

Conclusion

We have discussed the flatMap() operation and why it is needed. We also saw that the Java Stream API provides separate flatMapToInt(), flatMapToLong(), and flatMapToDouble() operations for primitive data types.

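Note that the regular flatMap() operation also works with primitive values, but only in boxed form (for example, a Stream<Integer> obtained with IntStream.boxed()); the primitive variants avoid that boxing.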
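For comparison, the regular flatMap() can flatten arrays of int if each IntStream is converted to a Stream<Integer> with boxed(). This sketch uses illustrative names and assumes Java 9 or later for List.of; it returns a List<Integer>, at the cost of boxing every value.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

// Sketch: flattening int[] arrays with the regular, boxed flatMap().
class BoxedFlatMap {
  static List<Integer> flattenBoxed(List<int[]> arrays) {
    return arrays.stream()
        // Arrays.stream(int[]) gives an IntStream; boxed() turns it into a
        // Stream<Integer>, which is what the regular flatMap() expects.
        .flatMap(a -> Arrays.stream(a).boxed())
        .collect(Collectors.toList());
  }

  public static void main(String[] args) {
    List<int[]> arrays = List.of(new int[] {3, 5, 2}, new int[] {35, 85});
    System.out.println(flattenBoxed(arrays)); // [3, 5, 2, 35, 85]
  }
}
```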
