mm: memcontrol: don't count limit-setting reclaim as memory pressure

When an outside process lowers one of the memory limits of a cgroup (or
uses the force_empty knob in cgroup1), direct reclaim is performed in the
context of the write(), in order to directly enforce the new limit and
have it being met by the time the write() returns.

Currently, this reclaim activity is accounted as memory pressure in the
cgroup that the writer(!) belongs to.  This is unexpected.  It
specifically causes problems for senpai
(https://github.com/facebookincubator/senpai), which is an agent that
routinely adjusts the memory limits and performs associated reclaim work
in tens or even hundreds of cgroups running on the host.  The cgroup that
senpai is running in itself will report elevated levels of memory
pressure, even though it itself is under no memory shortage or any sort of
distress.

Move the psi annotation from the central cgroup reclaim function to
callsites in the allocation context, and thereby no longer count any
limit-setting reclaim as memory pressure.  If the newly set limit causes
the workload inside the cgroup into direct reclaim, that of course will
continue to count as memory pressure.

Signed-off-by: Johannes Weiner <hannes@cmpxchg.org>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>
Reviewed-by: Shakeel Butt <shakeelb@google.com>
Reviewed-by: Roman Gushchin <guro@fb.com>
Acked-by: Chris Down <chris@chrisdown.name>
Acked-by: Michal Hocko <mhocko@suse.com>
Link: http://lkml.kernel.org/r/20200728135210.379885-2-hannes@cmpxchg.org
Signed-off-by: Linus Torvalds <torvalds@linux-foundation.org>
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 856db3b..8d9ceea 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -2374,12 +2374,18 @@
 	unsigned long nr_reclaimed = 0;
 
 	do {
+		unsigned long pflags;
+
 		if (page_counter_read(&memcg->memory) <=
 		    READ_ONCE(memcg->memory.high))
 			continue;
+
 		memcg_memory_event(memcg, MEMCG_HIGH);
+
+		psi_memstall_enter(&pflags);
 		nr_reclaimed += try_to_free_mem_cgroup_pages(memcg, nr_pages,
 							     gfp_mask, true);
+		psi_memstall_leave(&pflags);
 	} while ((memcg = parent_mem_cgroup(memcg)) &&
 		 !mem_cgroup_is_root(memcg));
 
@@ -2621,10 +2627,11 @@
 	int nr_retries = MAX_RECLAIM_RETRIES;
 	struct mem_cgroup *mem_over_limit;
 	struct page_counter *counter;
+	enum oom_status oom_status;
 	unsigned long nr_reclaimed;
 	bool may_swap = true;
 	bool drained = false;
-	enum oom_status oom_status;
+	unsigned long pflags;
 
 	if (mem_cgroup_is_root(memcg))
 		return 0;
@@ -2684,8 +2691,10 @@
 
 	memcg_memory_event(mem_over_limit, MEMCG_MAX);
 
+	psi_memstall_enter(&pflags);
 	nr_reclaimed = try_to_free_mem_cgroup_pages(mem_over_limit, nr_pages,
 						    gfp_mask, may_swap);
+	psi_memstall_leave(&pflags);
 
 	if (mem_cgroup_margin(mem_over_limit) >= nr_pages)
 		goto retry;
diff --git a/mm/vmscan.c b/mm/vmscan.c
index 5747867..23156c2 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -3310,7 +3310,6 @@
 					   bool may_swap)
 {
 	unsigned long nr_reclaimed;
-	unsigned long pflags;
 	unsigned int noreclaim_flag;
 	struct scan_control sc = {
 		.nr_to_reclaim = max(nr_pages, SWAP_CLUSTER_MAX),
@@ -3331,17 +3330,12 @@
 	struct zonelist *zonelist = node_zonelist(numa_node_id(), sc.gfp_mask);
 
 	set_task_reclaim_state(current, &sc.reclaim_state);
-
 	trace_mm_vmscan_memcg_reclaim_begin(0, sc.gfp_mask);
-
-	psi_memstall_enter(&pflags);
 	noreclaim_flag = memalloc_noreclaim_save();
 
 	nr_reclaimed = do_try_to_free_pages(zonelist, &sc);
 
 	memalloc_noreclaim_restore(noreclaim_flag);
-	psi_memstall_leave(&pflags);
-
 	trace_mm_vmscan_memcg_reclaim_end(nr_reclaimed);
 	set_task_reclaim_state(current, NULL);