一道题目带你搞懂回溯算法

Python七号

共 2900字,需浏览 6分钟

 · 2021-02-12

明天就过年了,无论你有没有回到家乡,都祝你新的一年,有所收获。

今天分享一道算法题,希望能让你学会回溯算法的思路。

学会了回溯,你就能解决著名的八皇后问题,数学家高斯穷其一生都没有解出八皇后的解,而借助现代计算机和回溯算法,你分分钟就搞定了,当然,N 皇后也不在话下。

回溯法(back tracking)是一种选优搜素算法,又称为试探法,按选优条件向前搜索,以达到目标。但当探索到某一步时,发现达不到目标,就退回一步重新选择,这种走不通就退回再走的技术为回溯法,当然,回溯也是暴力搜索法中的一种。

昨天看到一道回溯算法题目,非常烧脑,不过我很喜欢这种感觉,程序员应该定期刷一刷算法题,只有刷算法题目的时候,我才觉得那是真正意义上的编程,平时的工作在多数情况下,都是熟练调用编程语言或框架的 API 而已。

这道题目是 leetcode 第 93 题,难度为中等,让我们根据一个包含数字的字符串,复原它所有可能的 IP 地址。具体如下:

给定一个只包含数字的字符串,复原它并返回所有可能的 IP 地址格式。

有效的 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 '.' 分隔。

例如:"0.1.2.201" 和 "192.168.1.1" 是有效的 IP 地址,但是 "0.011.255.245"、"192.168.1.312" 是无效的 IP 地址。

示例 1:

输入:s = "25525511135"
输出:["255.255.11.135","255.255.111.35"]

示例 2:

输入:s = "0000"
输出:["0.0.0.0"]

示例 3:

输入:s = "1111"
输出:["1.1.1.1"]

示例 4:

输入:s = "010010"
输出:["0.10.0.10","0.100.1.0"]

示例 5:

输入:s = "101023"
输出:["1.0.10.23","1.0.102.3","10.1.0.23","10.10.2.3","101.0.2.3"]

来源:力扣(LeetCode) 链接:https://leetcode-cn.com/problems/restore-ip-addresses

暴力穷举

这个题目,我相信你大脑里最先想到的就是找三个点来分隔出 4 个字符串,然后判断分隔出的 4 个字符串是否满足 ip 某一段的要求,假如 4 个字符串都在 0 到 255 之间并且没有前导的零,那就是一个合法的 ip 地址。

但是三个点号的位置不太容易穷举,4 个字符串的长度倒是好穷举的,每个字符串的长度至少是 1,至多是 3,只有 3 种可能,,因此可以穷尽 4 个字符串的所有长度,也就是 3 的 4 次方 81 种可能。

如果 4 个字符串的长度加起来等于给定字符串的长度时,就可以按长度分隔,然后分别进行判断了。能想到这一点,就不难写出如下代码:

class Solution(object):
    def restoreIpAddresses(self, s):
        """
        :type s: str
        :rtype: List[str]
        """

        result = []
        for a in range(1,4):
            for b in range(1,4):
                for c in range(1,4):
                    for d in range(1,4):
                        if a+b+c+d == len(s):
                            s1 = s[0:a]
                            s2 = s[a:a+b]
                            s3 = s[a+b:a+b+c]
                            s4 = s[-d:]
                            if self.isValid(s1) \
                            and self.isValid(s2) \
                            and self.isValid(s3) \
                            and self.isValid(s4):
                                result.append("{}.{}.{}.{}".format(s1,s2,s3,s4))
        return result;
                            
    def isValid(self,s_sub):
        if len(s_sub) > 1 and s_sub.startswith('0'):
            return False
        if int(s_sub) <= 255#全部都由数字组成
            return True
        return False

但这种方法非常易懂,但是却不够通用,无法举一反三,比如说题目改成 ipv6 的地址,这种方法就不太合适了。

回溯思想

接下来我们尝试一下回溯的思路。

比如 25525511135,先来确定 ip 的第一段,第一段最多有 3 种可能:2,25,255,这里可以使用一个小循环。假如先选择 2 做为 ip 的第一段,2 小于等于 255,满足要求。

接下来确定 ip 的第二段,也就是说对剩余的字符串 5525511135 重复上述操作,同样的,最多有 3 种可能:5,55,525。假如这里选择 5, 5 是小于等于 255 的,因此满足条件。

接下来确定 ip 的第三段。

接下来确定 ip 的第四段。

每一段的选择,都是同样的操作。这就很像是一个决策树,每做一次选择,都是沿着树的某一分支走到叶子节点的过程,我这里使用脑图来展示一下这个决策树。

上图中除了叶子节点,其他节点都是 3 个子节点,某些我没有画出,希望不影响你理解。

每一层的检索都是一个递归操作,递归的退出条件就是到第 5 层结束,第 5 层结束后如果没有剩余字符串,说明找到了一个正确的 IP 地址,保存到结果集里即可。

不可避免地需要遍历这棵决策树的每个节点,比如 2.5.5.2,本质就是多叉树的遍历操作,这也就是回溯思想的核心。写代码时我们可以把多叉树的遍历骨架写出来:


def backtrace(root :str) -> None:
    """
    有 3 个子节点的多叉树的中序遍历。
    """

    if 满足退出条件:
        if 满足要求:
            加入结果集
        退出

    for i in range(0,3):
        if i < len(root): #索引不能超过字符串的长度
            #选择 root[0:i+1]
            #具体做法就是 tmp_list.append(root[0:i+1])
            backtrace(root[i+1:])
            #撤销选择 root[0:i+1] 
            #具体做法就是 tmp_list.pop()

进入下一轮决策(递归)之前,先做选择,把当前 ip 段加入路径 tmp_list 中,决策(递归)完成后,再撤销选择。

这里有人可能不太理解,为什么需要撤销选择?其实不难理解,看上图决策树的最左边的分支,当遍历到 2.5.5.2 发现不合适的时候,需要回溯到 2.5.5,然后选择 25,也就是说最后的 2 加入 tmp_list 之后,判断不合适,递归返回之后,我们需要把 2 删除,然后腾出空间放 25,这也是为什么叫回溯算法的原因,遇到不符合目标的,就回头重新选择。当然了,遇到合适的,也要重新选择,是因为我们要选出所有合法的 ip 地址。

接下来,为这个骨架填充一点血肉。遍历了每个节点,需要把这些节点的顺序保存下来,这里使用一个 tmp_list 来保存,为了编写退出条件,还需要一个变量指示现在是第几层,为了返回最终结果,再传入一个 result 的数组来保存。

def backtrace(root :str, tmp_list:list, levle: int, result:list ) -> None:
    """
    有 3 个子节点的多叉树的中序遍历。
    tmp_list 保存遍历的路径,比如 2.5.5.2
    level 表示现在是第几层,初始调用时传入 1
    """

    ##剩余字符串为空,或者遍历到第 5 层,终止递归。
    if len(root) == 0 or level == 5:
        ##同时满足时,说明已经找到了合法的ip
        if len(root) == 0 and level == 5:
           result.append(".".join(tmp_list)) 
        return 

    for i in range(0,3):
        if i < len(root): #索引不能超过字符串的长度
            #选择 root[0:i+1]
            part = root[0:i+1]
            if isValid(part):
                #合法的部分,才去递归
                #加入选择
                tmp_list.append(part)
                backtrace(root[i+1:],tmp_list,level+1,result)
                #撤销选择
                tmp_list.pop()
            else:
                pass

组装一下,以下是完整代码,可直接在 leetcode 运行的,提交后看看结果:

class Solution(object):

    def restoreIpAddresses(self, s):
        """
        :type s: str
        :rtype: List[str]
        """

        if len(s) < 4:
            return []
        result = []
        tmp_list = []
        self.backtrace(s,tmp_list,1,result);
        return result

    def backtrace(self, root:str , tmp_list:list, level:int, result:list) -> None:
        """
        有 3 个子节点的多叉树的中序遍历。
        tmp_list 保存遍历的路径,比如 2.5.5.2
        level 表示现在是第几层,初始调用时传入 1
        """


        if len(root) == 0 or level == 5:
            if len(root) == 0 and level == 5:
                result.append(".".join(tmp_list)) 
            return 

        for i in range(0,3):
            if i < len(root): #索引不能超过字符串的长度
                #选择 root[0:i+1]
                part = root[0:i+1]
                if self.isValid(part):
                    #合法的部分,才去递归
                    #加入选择
                    tmp_list.append(part)
                    self.backtrace(root[i+1:],tmp_list,level+1,result)
                    #撤销选择
                    tmp_list.pop()
                else:
                    pass

    
    def isValid(self, sub_s : str) -> bool:
        if len(sub_s) > 1 and sub_s.startswith('0'):
            return False
        if 0 <= int(sub_s) <= 255:
            return True
        return False

运行结果如下:

心心苦苦搞了半天,看来还没有第一段暴力解法来得快,别灰心,一定有什么可以优化的地方,其实,只要有某一段 ip 的长度大于 1,且是 0 开头的时候,后面就不需要向下递归了,可以提升点效率。

比如:输入:s = "010010",当 "01"做为第一段时就可以 break 跳出循环了。

优化一下 backtrace 函数和 isValid 函数:

    def backtrace(self, root:str , tmp_list:list, level:int, result:list) -> None:
        """
        有 3 个子节点的多叉树的中序遍历。
        tmp_list 保存遍历的路径,比如 2.5.5.2
        level 表示现在是第几层,初始调用时传入 1
        """


        if len(root) == 0 or level == 5:
            if len(root) == 0 and level == 5:
                result.append(".".join(tmp_list)) 
            return 

        for i in range(0,3):
            if i < len(root): #索引不能超过字符串的长度
                #选择 root[0:i+1]
                part = root[0:i+1]

                ##如果某段以0开头,且长度超过 1 ,那么跳出循环,提升效率
                if part.startswith('0'and len(part)>1:
                    break;

                if self.isValid(part):
                    #合法的部分,才去递归
                    #加入选择
                    tmp_list.append(part)
                    self.backtrace(root[i+1:],tmp_list,level+1,result)
                    #撤销选择
                    tmp_list.pop()
                else:
                    pass

    
    def isValid(self, sub_s : str) -> bool:
        # if len(sub_s) > 1 and sub_s.startswith('0'):
        #     return False
        if int(sub_s) <= 255:
            return True
        return False

一个小小的优化,再次提交看结果,确实提升了不少:

由于 leetcode 同时有很多人使用,因此不同的时间提交,服务器的计算压力是不同的,得出的结果会有少量的差异,这个理解就好。

到这里不知道你是否理解了回溯算法的思路。如果有不理解的地方,请在文末留言交流。

最后的话

其实不管多么复杂的算法,归根结底都逃离不开最基本的循环语句、if、else 的组合,再高级一点的,就是与栈、队列、递归的组合应用。

本文提到的回溯算法,本质就是暴力遍历多叉树(本题是 3 叉树)求解,先确定决策树,写出多叉树的遍历框架,然后填充内容。不要忘记在递归完成后撤销选择。如果还有点不理解,这里我提个问题:

请问二叉树前、中、后序遍历的区别是什么,你可能会说不就是访问根节点的顺序不同么,先访问根节点就是前序遍历....

其实这样的回答是错的,无论哪一种遍历,都是要先访问根节点的,不访问根节点,你怎么可能访问得到子节点?

真正的区别在于对根节点的处理是放在进入子节点的递归调用之前,还是在递归调用之后。前序遍历的代码在进⼊某⼀个节点之前的那个时间点执⾏,后序遍历代码在离开某个节点之后的那个时间点执⾏,如下图所示:

def trace(root):
    ##前序
    trace(root.left)
    ##中序
    trace(root.right)
    ##后序

因此后序遍历之后,需要撤销选择的 child,加入新的 child 进行遍历。

PS:如果你也在刷 Leetcode,我这里有一份从 Leetcode 中精选大概 200 左右的题目,去除了某些繁杂但是没有多少算法思想的题目,同时保留了面试中经常被问到的经典题目,对本号发消息回复「算法」即可获取,让你更高效地刷力扣。

如果觉得本文对你有用,请点赞在看转发支持,感谢老铁。

留言讨论


浏览 55
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报